Substation

class torchgeo.datasets.Substation(root='data', bands=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), mask_2d=True, num_of_timepoints=4, timepoint_aggregation=None, transforms=None, download=False, checksum=False)

Bases: NonGeoDataset

Substation dataset.

The Substation dataset is curated by TransitionZero and sourced from publicly available data repositories, including OpenStreetMap (OSM) and Copernicus Sentinel data. The dataset consists of Sentinel-2 images from 27k+ locations; the task is to segment power substations, which appear in the majority of locations in the dataset. Most locations have 4-5 images taken at different timepoints (i.e., revisits).

Dataset Format:

  • .npz file for each datapoint
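To inspect a datapoint by hand, NumPy can open the archive and list the arrays it holds. The sketch below uses a synthetic file as a stand-in; the array shape and the file name are assumptions made for illustration, and the internal layout of the real files is not described on this page.

```python
import tempfile
from pathlib import Path

import numpy as np

# Synthetic stand-in for one datapoint (assumed shape: 4 timepoints,
# 13 channels, 228 x 228 pixels). Not a real file from the dataset.
fake_image = np.zeros((4, 13, 228, 228), dtype=np.float32)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "example.npz"  # hypothetical file name
    # Arrays passed positionally are stored under default names ('arr_0', ...).
    np.savez_compressed(path, fake_image)

    with np.load(path) as archive:
        keys = archive.files                   # names of the stored arrays
        loaded_shape = archive[keys[0]].shape  # shape of the first array
```

Listing `archive.files` first is a safe way to discover the key names in a real file rather than guessing them.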

Dataset Features:

  • 26,522 image-mask pairs stored as numpy files.

  • Data from 5 revisits for most locations.

  • Multi-temporal, multi-spectral images (13 channels) paired with masks, each 228x228 pixels in size. When timepoint_aggregation is None, images are returned as T x C x H x W tensors.

If you use this dataset in your research, please cite the following paper:

__init__(root='data', bands=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12), mask_2d=True, num_of_timepoints=4, timepoint_aggregation=None, transforms=None, download=False, checksum=False)

Initialize a new Substation dataset instance.

Parameters:
  • root (str | PathLike[str]) – Path to the directory containing the dataset.

  • bands (Sequence[int]) – Channels to use from the image.

  • mask_2d (bool) – Whether to use a 2D mask.

  • num_of_timepoints (int) – Number of timepoints to use for each image.

  • timepoint_aggregation (Literal['concat', 'median', 'first', 'random'] | None) – How to aggregate multiple timepoints. If None, returns time-series as T x C x H x W.

  • transforms (Callable[[dict[str, Any]], dict[str, Any]] | None) – A function/transform that takes an input sample and returns a transformed version.

  • download (bool) – Whether to download the dataset if it is not found.

  • checksum (bool) – Whether to verify the dataset after downloading.
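The timepoint_aggregation options are easiest to compare by the output shape each one produces. The following NumPy sketch is an illustration of what the four option names plausibly compute on a T x C x H x W array; it is not torchgeo's implementation, and the helper name `aggregate_timepoints` is hypothetical.

```python
import numpy as np


def aggregate_timepoints(series, mode=None, rng=None):
    """Reduce a T x C x H x W array along the time axis (illustrative only)."""
    if mode is None:
        return series                       # keep the full time series
    t, c, h, w = series.shape
    if mode == "concat":
        return series.reshape(t * c, h, w)  # stack timepoints along channels
    if mode == "median":
        return np.median(series, axis=0)    # per-pixel median over time
    if mode == "first":
        return series[0]                    # earliest timepoint only
    if mode == "random":
        if rng is None:
            rng = np.random.default_rng()
        return series[rng.integers(t)]      # one timepoint chosen at random
    raise ValueError(f"unknown mode: {mode!r}")


# Synthetic input: 4 timepoints, 13 channels, 228 x 228 pixels.
series = np.random.default_rng(0).random((4, 13, 228, 228), dtype=np.float32)
modes = (None, "concat", "median", "first", "random")
shapes = {m: aggregate_timepoints(series, m).shape for m in modes}
```

Under these assumptions, 'concat' yields T*C channels, while 'median', 'first', and 'random' each collapse the time axis to a single C x H x W image.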

__getitem__(index)

Get an item from the dataset by index.

Parameters:

index (int) – Index of the item to retrieve.

Returns:

A dictionary containing the image and corresponding mask.

Return type:

dict[str, Any]

__len__()

Return the number of items in the dataset.
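A minimal sketch of how the constructor, __len__, and __getitem__ fit together. The helper names `build_substation` and `summarize` are hypothetical; the keyword arguments are the ones in the signature above, and the example assumes the data already exists under `root`.

```python
from typing import Any


def build_substation(root: str = "data") -> Any:
    """Construct the dataset with a median composite over timepoints."""
    # Imported lazily so this module loads even where torchgeo is not installed.
    from torchgeo.datasets import Substation

    return Substation(
        root=root,
        num_of_timepoints=4,
        timepoint_aggregation="median",
        download=False,  # set to True to fetch the data if it is not found
    )


def summarize(dataset: Any, index: int = 0) -> dict[str, Any]:
    """Return the dataset length and the shape of each entry in one sample."""
    sample = dataset[index]  # a dictionary, as documented for __getitem__
    shapes = {key: tuple(getattr(value, "shape", ())) for key, value in sample.items()}
    return {"length": len(dataset), "shapes": shapes}
```

Calling `summarize(build_substation("data"))` reports the key names and tensor shapes of a sample without assuming them in advance.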

plot(sample, show_titles=True, suptitle=None)

Plot a sample from the dataset.

When the image is 4D (T x C x H x W), the first two timepoints are plotted.

Parameters:
  • sample (dict[str, Any]) – a sample returned by __getitem__()

  • show_titles (bool) – flag indicating whether to show titles above each panel

  • suptitle (str | None) – optional string to use as a suptitle

Returns:

A matplotlib Figure containing the rendered sample.

Raises:

RGBBandsMissingError – If bands does not include all RGB bands.

Return type:

Figure
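Because plot() returns a matplotlib Figure, the result can be saved to disk rather than displayed. The helper below is a hypothetical convenience wrapper, not part of torchgeo; it relies only on the plot() signature documented above and on the standard Figure.savefig method.

```python
from typing import Any


def save_sample_figure(dataset: Any, index: int, path: str) -> Any:
    """Plot one sample with the dataset's own plot() method and save it."""
    sample = dataset[index]
    fig = dataset.plot(sample, show_titles=True, suptitle=f"Sample {index}")
    fig.savefig(path)  # write the rendered figure to the given path
    return fig
```

If the dataset was constructed with a bands selection that omits the RGB bands, plot() raises RGBBandsMissingError as documented, so callers that restrict bands may want to handle that case.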