torchgeo.datamodules

Geospatial DataModules

AgriFieldNet

class torchgeo.datamodules.AgriFieldNetDataModule(batch_size=64, patch_size=256, length=None, num_workers=0, **kwargs)

Bases: GeoDataModule

LightningDataModule implementation for the AgriFieldNet dataset.

Added in version 0.6.

__init__(batch_size=64, patch_size=256, length=None, num_workers=0, **kwargs)

Initialize a new AgriFieldNetDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | tuple[int, int]) – Size of each patch, either size or (height, width).

  • length (int | None) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to AgriFieldNet.
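patch_size accepts either a single int or a (height, width) pair. A minimal sketch of how such a value can be normalized to a tuple before sampling; the helper to_hw is hypothetical and is not part of the torchgeo API.

```python
from __future__ import annotations


def to_hw(patch_size: int | tuple[int, int]) -> tuple[int, int]:
    """Normalize a patch size to (height, width). Hypothetical helper, not torchgeo API."""
    if isinstance(patch_size, int):
        # A single int describes a square patch.
        return (patch_size, patch_size)
    height, width = patch_size
    return (height, width)


print(to_hw(256))         # (256, 256)
print(to_hw((128, 256)))  # (128, 256)
```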

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.
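In the usual Lightning convention, setup(stage) only builds the datasets that the given stage needs. The mapping below is a hypothetical sketch of that convention, not torchgeo's code; individual DataModules may handle 'predict' differently or not at all.

```python
from __future__ import annotations

VALID_STAGES = ("fit", "validate", "test", "predict")


def splits_for_stage(stage: str) -> list[str]:
    """Return the dataset splits a setup(stage) call would typically build (hypothetical)."""
    if stage not in VALID_STAGES:
        raise ValueError(f"stage must be one of {VALID_STAGES}, got {stage!r}")
    splits = []
    if stage == "fit":
        splits.append("train")
    if stage in ("fit", "validate"):
        # Validation data is needed both while fitting and for a standalone validate run.
        splits.append("val")
    if stage == "test":
        splits.append("test")
    if stage == "predict":
        splits.append("predict")
    return splits


print(splits_for_stage("fit"))       # ['train', 'val']
print(splits_for_stage("validate"))  # ['val']
```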

Chesapeake Land Cover

class torchgeo.datamodules.ChesapeakeCVPRDataModule(train_splits, val_splits, test_splits, batch_size=64, patch_size=256, length=None, num_workers=0, class_set=7, use_prior_labels=False, prior_smoothing_constant=0.0001, **kwargs)

Bases: GeoDataModule

LightningDataModule implementation for the Chesapeake CVPR Land Cover dataset.

Uses the random splits defined per state to partition tiles into train, val, and test sets.

__init__(train_splits, val_splits, test_splits, batch_size=64, patch_size=256, length=None, num_workers=0, class_set=7, use_prior_labels=False, prior_smoothing_constant=0.0001, **kwargs)

Initialize a new ChesapeakeCVPRDataModule instance.

Parameters:
  • train_splits (list[str]) – Splits used to train the model, e.g., [“ny-train”].

  • val_splits (list[str]) – Splits used to validate the model, e.g., [“ny-val”].

  • test_splits (list[str]) – Splits used to test the model, e.g., [“ny-test”].

  • batch_size (int) – Size of each mini-batch.

  • patch_size (int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • length (int | None) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • class_set (int) – The high-resolution land cover class set to use (5 or 7).

  • use_prior_labels (bool) – Flag for using a prior over high-resolution classes instead of the high-resolution labels themselves.

  • prior_smoothing_constant (float) – Additive smoothing to add when using prior labels.

  • **kwargs (Any) – Additional keyword arguments passed to ChesapeakeCVPR.

Raises:

AssertionError – If use_prior_labels=True is used with class_set=7.
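prior_smoothing_constant is described as additive smoothing. Assuming the prior for each pixel is a vector of non-negative class weights, additive smoothing adds the constant to every entry and renormalizes, so no class ends up with exactly zero probability. The sketch also restates the documented class_set check. Both functions are hypothetical illustrations, not the DataModule's internals.

```python
from __future__ import annotations


def check_prior_config(class_set: int, use_prior_labels: bool) -> None:
    """Mirror the documented constraint on class_set and use_prior_labels."""
    assert class_set in (5, 7), "class_set must be 5 or 7"
    assert not (use_prior_labels and class_set == 7), "prior labels require class_set=5"


def smooth_prior(weights: list[float], constant: float = 1e-4) -> list[float]:
    """Additive smoothing: add `constant` to every class weight, then renormalize to sum to 1."""
    smoothed = [w + constant for w in weights]
    total = sum(smoothed)
    return [s / total for s in smoothed]


check_prior_config(class_set=5, use_prior_labels=True)  # passes silently
probs = smooth_prior([0.0, 3.0, 1.0], constant=1.0)
print(probs)  # approximately [0.143, 0.571, 0.286]; the zero-weight class is now non-zero
```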

setup(stage)

Set up datasets and samplers.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

on_after_batch_transfer(batch, dataloader_idx)

Apply batch augmentations to the batch after it is transferred to the device.

Parameters:
  • batch (dict[str, Any]) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Returns:

A batch of data.

Return type:

dict[str, Any]
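For segmentation, the essential property of a batch-level augmentation is that each image and its mask receive the identical geometric transform. The real hook operates on tensors already on the device; the pure-Python sketch below only illustrates the joint-transform idea on nested lists, assuming a batch dict with 'image' and 'mask' keys. augment_batch is a hypothetical name.

```python
from __future__ import annotations

from typing import Any


def hflip(grid: list[list[int]]) -> list[list[int]]:
    """Mirror a 2D grid left to right by reversing each row."""
    return [row[::-1] for row in grid]


def augment_batch(batch: dict[str, Any], flip: bool) -> dict[str, Any]:
    """Hypothetical stand-in for on_after_batch_transfer.

    The same transform is applied to every image and its mask,
    so pixel labels stay aligned with the pixels they describe.
    """
    if not flip:
        return batch
    out = dict(batch)  # shallow copy; the caller's dict is left untouched
    out["image"] = [hflip(img) for img in batch["image"]]
    out["mask"] = [hflip(msk) for msk in batch["mask"]]
    return out


batch = {"image": [[[1, 2], [3, 4]]], "mask": [[[0, 1], [1, 0]]]}
print(augment_batch(batch, flip=True))
# {'image': [[[2, 1], [4, 3]]], 'mask': [[[1, 0], [0, 1]]]}
```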

L7 Irish

class torchgeo.datamodules.L7IrishDataModule(batch_size=1, patch_size=224, length=None, num_workers=0, **kwargs)

Bases: GeoDataModule

LightningDataModule implementation for the L7 Irish dataset.

Added in version 0.5.

__init__(batch_size=1, patch_size=224, length=None, num_workers=0, **kwargs)

Initialize a new L7IrishDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | tuple[int, int]) – Size of each patch, either size or (height, width).

  • length (int | None) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to L7Irish.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

L8 Biome

class torchgeo.datamodules.L8BiomeDataModule(batch_size=1, patch_size=224, length=None, num_workers=0, **kwargs)

Bases: GeoDataModule

LightningDataModule implementation for the L8 Biome dataset.

Added in version 0.5.

__init__(batch_size=1, patch_size=224, length=None, num_workers=0, **kwargs)

Initialize a new L8BiomeDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | tuple[int, int]) – Size of each patch, either size or (height, width).

  • length (int | None) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to L8Biome.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

MMFlood

class torchgeo.datamodules.MMFloodDataModule(batch_size=32, patch_size=512, length=None, num_workers=0, **kwargs)

Bases: GeoDataModule

LightningDataModule implementation for the MMFlood dataset.

Added in version 0.7.

__init__(batch_size=32, patch_size=512, length=None, num_workers=0, **kwargs)

Initialize a new MMFloodDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | tuple[int, int]) – Size of each patch, either size or (height, width).

  • length (int | None) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to MMFlood.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

NAIP

class torchgeo.datamodules.NAIPChesapeakeDataModule(batch_size=64, patch_size=256, length=None, num_workers=0, **kwargs)

Bases: GeoDataModule

LightningDataModule implementation for the NAIP and Chesapeake datasets.

Uses the train/val/test splits from the dataset.

__init__(batch_size=64, patch_size=256, length=None, num_workers=0, **kwargs)

Initialize a new NAIPChesapeakeDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | tuple[int, int]) – Size of each patch, either size or (height, width).

  • length (int | None) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to NAIP (prefix keys with naip_) and Chesapeake (prefix keys with chesapeake_).
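Prefixed keyword arguments let a single call configure both wrapped datasets. A minimal sketch of the routing, assuming each key is matched against its prefix and the prefix is stripped before the value is forwarded. split_prefixed_kwargs is hypothetical, rejecting unmatched keys is a choice made here (not documented behavior), and the key names in the example are illustrative only.

```python
from __future__ import annotations

from typing import Any


def split_prefixed_kwargs(
    kwargs: dict[str, Any], prefixes: list[str]
) -> dict[str, dict[str, Any]]:
    """Route kwargs to sub-datasets by prefix, stripping the prefix (hypothetical)."""
    routed: dict[str, dict[str, Any]] = {p: {} for p in prefixes}
    for key, value in kwargs.items():
        for prefix in prefixes:
            if key.startswith(prefix):
                routed[prefix][key[len(prefix):]] = value
                break
        else:
            # No prefix matched; surface the likely typo instead of silently dropping it.
            raise ValueError(f"{key!r} does not start with any of {prefixes}")
    return routed


routed = split_prefixed_kwargs(
    {"naip_paths": "data/naip", "chesapeake_paths": "data/chesapeake", "chesapeake_download": True},
    ["naip_", "chesapeake_"],
)
print(routed["naip_"])        # {'paths': 'data/naip'}
print(routed["chesapeake_"])  # {'paths': 'data/chesapeake', 'download': True}
```

The Sentinel-2 DataModules in this section use the same prefix convention (for example cdl_, eurocrops_, nccm_, and sentinel2_).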

setup(stage)

Set up datasets and samplers.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

plot(*args, **kwargs)

Run NAIP plot method.

Parameters:
  • *args (Any) – Arguments passed to plot method.

  • **kwargs (Any) – Keyword arguments passed to plot method.

Returns:

A matplotlib Figure with the image, ground truth, and predictions.

Return type:

Figure

Added in version 0.4.

I/O Bench

class torchgeo.datamodules.IOBenchDataModule(batch_size=32, patch_size=256, length=None, num_workers=0, **kwargs)

Bases: GeoDataModule

LightningDataModule implementation for the I/O benchmark dataset.

Added in version 0.6.

__init__(batch_size=32, patch_size=256, length=None, num_workers=0, **kwargs)

Initialize a new IOBenchDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | tuple[int, int]) – Size of each patch, either size or (height, width).

  • length (int | None) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to IOBench.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

Sentinel

class torchgeo.datamodules.Sentinel2CDLDataModule(batch_size=64, patch_size=64, length=None, num_workers=0, **kwargs)

Bases: GeoDataModule

LightningDataModule implementation for the Sentinel-2 and CDL datasets.

Added in version 0.6.

__init__(batch_size=64, patch_size=64, length=None, num_workers=0, **kwargs)

Initialize a new Sentinel2CDLDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | tuple[int, int]) – Size of each patch, either size or (height, width).

  • length (int | None) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to CDL (prefix keys with cdl_) and Sentinel2 (prefix keys with sentinel2_).

setup(stage)

Set up datasets and samplers.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

plot(*args, **kwargs)

Run CDL plot method.

Parameters:
  • *args (Any) – Arguments passed to plot method.

  • **kwargs (Any) – Keyword arguments passed to plot method.

Returns:

A matplotlib Figure with the image, ground truth, and predictions.

Return type:

Figure

class torchgeo.datamodules.Sentinel2EuroCropsDataModule(batch_size=64, patch_size=256, length=None, num_workers=0, **kwargs)

Bases: GeoDataModule

LightningDataModule implementation for the EuroCrops and Sentinel2 datasets.

Uses the train/val/test splits from the dataset.

Added in version 0.6.

__init__(batch_size=64, patch_size=256, length=None, num_workers=0, **kwargs)

Initialize a new Sentinel2EuroCropsDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | tuple[int, int]) – Size of each patch, either size or (height, width).

  • length (int | None) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to EuroCrops (prefix keys with eurocrops_) and Sentinel2 (prefix keys with sentinel2_).

setup(stage)

Set up datasets and samplers.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

plot(*args, **kwargs)

Run EuroCrops plot method.

Parameters:
  • *args (Any) – Arguments passed to plot method.

  • **kwargs (Any) – Keyword arguments passed to plot method.

Returns:

A matplotlib Figure with the image, ground truth, and predictions.

Return type:

Figure

class torchgeo.datamodules.Sentinel2NCCMDataModule(batch_size=64, patch_size=64, length=None, num_workers=0, **kwargs)

Bases: GeoDataModule

LightningDataModule implementation for the Sentinel-2 and NCCM datasets.

Added in version 0.6.

__init__(batch_size=64, patch_size=64, length=None, num_workers=0, **kwargs)

Initialize a new Sentinel2NCCMDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | tuple[int, int]) – Size of each patch, either size or (height, width).

  • length (int | None) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to NCCM (prefix keys with nccm_) and Sentinel2 (prefix keys with sentinel2_).

setup(stage)

Set up datasets and samplers.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

plot(*args, **kwargs)

Run NCCM plot method.

Parameters:
  • *args (Any) – Arguments passed to plot method.

  • **kwargs (Any) – Keyword arguments passed to plot method.

Returns:

A matplotlib Figure with the image, ground truth, and predictions.

Return type:

Figure

class torchgeo.datamodules.Sentinel2SouthAmericaSoybeanDataModule(batch_size=64, patch_size=64, length=None, num_workers=0, **kwargs)

Bases: GeoDataModule

LightningDataModule implementation for the SouthAmericaSoybean and Sentinel2 datasets.

Added in version 0.6.

__init__(batch_size=64, patch_size=64, length=None, num_workers=0, **kwargs)

Initialize a new Sentinel2SouthAmericaSoybeanDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | tuple[int, int]) – Size of each patch, either size or (height, width).

  • length (int | None) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to SouthAmericaSoybean (prefix keys with south_america_soybean_) and Sentinel2 (prefix keys with sentinel2_).

setup(stage)

Set up datasets and samplers.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

plot(*args, **kwargs)

Run SouthAmericaSoybean plot method.

Parameters:
  • *args (Any) – Arguments passed to plot method.

  • **kwargs (Any) – Keyword arguments passed to plot method.

Returns:

A matplotlib Figure with the image, ground truth, and predictions.

Return type:

Figure

SouthAfricaCropType

class torchgeo.datamodules.SouthAfricaCropTypeDataModule(batch_size=64, patch_size=16, length=None, num_workers=0, **kwargs)

Bases: GeoDataModule

LightningDataModule implementation for the SouthAfricaCropType dataset.

Added in version 0.6.

__init__(batch_size=64, patch_size=16, length=None, num_workers=0, **kwargs)

Initialize a new SouthAfricaCropTypeDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | tuple[int, int]) – Size of each patch, either size or (height, width).

  • length (int | None) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to SouthAfricaCropType.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

Non-geospatial DataModules

BigEarthNet

class torchgeo.datamodules.BigEarthNetDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the BigEarthNet dataset.

Uses the train/val/test splits from the dataset.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new BigEarthNetDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to BigEarthNet.

BRIGHT

class torchgeo.datamodules.BRIGHTDFC2025DataModule(batch_size=32, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the BRIGHT dataset.

Added in version 0.8.

__init__(batch_size=32, num_workers=0, **kwargs)

Initialize a new BRIGHTDFC2025DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to BRIGHTDFC2025.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

CaBuAr

class torchgeo.datamodules.CaBuArDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the CaBuAr dataset.

Uses the train/val/test splits from the dataset.

Added in version 0.6.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new CaBuArDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to CaBuAr.

CaFFe

class torchgeo.datamodules.CaFFeDataModule(batch_size=64, num_workers=0, size=512, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the CaFFe dataset.

Implements the default splits that come with the dataset.

Added in version 0.7.

__init__(batch_size=64, num_workers=0, size=512, **kwargs)

Initialize a new CaFFeDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • size (int) – Resize input images of size 512x512 to size x size.

  • **kwargs (Any) – Additional keyword arguments passed to CaFFe.

ChaBuD

class torchgeo.datamodules.ChaBuDDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the ChaBuD dataset.

Uses the train/val splits from the dataset.

Added in version 0.6.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new ChaBuDDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to ChaBuD.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

Cloud Cover Detection

class torchgeo.datamodules.CloudCoverDetectionDataModule(batch_size=64, num_workers=0, val_split_pct=0.2, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the Cloud Cover Detection dataset.

Splits the training split into train/val subsets using val_split_pct.

Added in version 0.9.

__init__(batch_size=64, num_workers=0, val_split_pct=0.2, **kwargs)

Initialize a new CloudCoverDetectionDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Percentage of the training data to reserve for validation.

  • **kwargs (Any) – Additional keyword arguments passed to CloudCoverDetection.
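A minimal sketch of a reproducible train/val split, treating val_split_pct as a fraction of the training split (the default of 0.2 means 20%). It uses only the standard library; torchgeo's own split utilities and random generator may differ, and train_val_indices is a hypothetical helper.

```python
from __future__ import annotations

import random


def train_val_indices(
    n: int, val_split_pct: float = 0.2, seed: int = 0
) -> tuple[list[int], list[int]]:
    """Shuffle indices 0..n-1 and hold out a val_split_pct fraction for validation."""
    if not 0.0 <= val_split_pct < 1.0:
        raise ValueError("val_split_pct must be in [0, 1)")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # fixed seed -> the same split on every run
    n_val = round(n * val_split_pct)
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = train_val_indices(100, val_split_pct=0.2)
print(len(train_idx), len(val_idx))  # 80 20
```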

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

COWC

class torchgeo.datamodules.COWCCountingDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the COWC Counting dataset.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new COWCCountingDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to COWCCounting.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

DeepGlobe Land Cover Challenge

class torchgeo.datamodules.DeepGlobeLandCoverDataModule(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the DeepGlobe Land Cover dataset.

Uses the train/test splits from the dataset.

__init__(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)

Initialize a new DeepGlobeLandCoverDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to DeepGlobeLandCover.
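The "multiple of 32" advice follows from the fact that many encoder-decoder segmentation networks halve the spatial size five times (2**5 = 32). A side length that is not divisible by 32 leaves a fractional feature map, which typically causes shape mismatches when decoder features are merged with skip connections. A small sketch of that arithmetic; both helpers are hypothetical, and the five-stage depth is an assumption that varies by architecture.

```python
def bottleneck_side(patch_size: int, num_downsamples: int = 5) -> float:
    """Spatial side length after `num_downsamples` stride-2 stages (2**5 == 32)."""
    return patch_size / 2**num_downsamples


def round_to_multiple(patch_size: int, multiple: int = 32) -> int:
    """Round down to a multiple of `multiple`, but never below one multiple."""
    return max(multiple, (patch_size // multiple) * multiple)


print(bottleneck_side(64))     # 2.0: a whole number, so skip connections line up
print(bottleneck_side(100))    # 3.125: fractional, so the shapes no longer line up
print(round_to_multiple(100))  # 96
```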

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

Digital Typhoon

class torchgeo.datamodules.DigitalTyphoonDataModule(split_by='time', batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the Digital Typhoon dataset.

Added in version 0.6.

__init__(split_by='time', batch_size=64, num_workers=0, **kwargs)

Initialize a new DigitalTyphoonDataModule instance.

Parameters:
  • split_by (str) – Either ‘time’ or ‘typhoon_id’; determines how the dataset is split into train, val, and test sets.

  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to DigitalTyphoon.
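The two split_by modes trade off differently: grouping by typhoon ID keeps every frame of a storm in one split, which avoids leaking near-duplicate frames across splits, while a time-based split evaluates on later observations. The exact rule torchgeo applies in each mode is not stated here; the sketch below assumes a simple two-way split over hypothetical (typhoon_id, timestamp) records, and both functions are illustrative only.

```python
from __future__ import annotations

import random

Record = tuple[str, int]  # (typhoon_id, timestamp); hypothetical record layout


def split_by_typhoon_id(
    records: list[Record], val_frac: float = 0.2, seed: int = 0
) -> tuple[list[Record], list[Record]]:
    """Hold out whole typhoons so no storm contributes frames to both splits."""
    ids = sorted({typhoon_id for typhoon_id, _ in records})
    random.Random(seed).shuffle(ids)
    n_val = max(1, round(len(ids) * val_frac))
    val_ids = set(ids[:n_val])
    train = [r for r in records if r[0] not in val_ids]
    val = [r for r in records if r[0] in val_ids]
    return train, val


def split_by_time(
    records: list[Record], val_frac: float = 0.2
) -> tuple[list[Record], list[Record]]:
    """Hold out the most recent frames, so validation never precedes training in time."""
    ordered = sorted(records, key=lambda r: r[1])
    n_val = max(1, round(len(ordered) * val_frac))
    return ordered[:-n_val], ordered[-n_val:]


records = [("A", 1), ("A", 2), ("B", 3), ("B", 4), ("C", 5)]
print(split_by_time(records))  # ([('A', 1), ('A', 2), ('B', 3), ('B', 4)], [('C', 5)])
```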

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

ETCI2021 Flood Detection

class torchgeo.datamodules.ETCI2021DataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the ETCI2021 dataset.

Splits the existing train split from the dataset into train/val with 80/20 proportions, then uses the existing val dataset as the test data.

Added in version 0.2.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new ETCI2021DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to ETCI2021.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

on_after_batch_transfer(batch, dataloader_idx)

Apply batch augmentations to the batch after it is transferred to the device.

Parameters:
  • batch (dict[str, Any]) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Returns:

A batch of data.

Return type:

dict[str, Any]

EuroSAT

class torchgeo.datamodules.EuroSATDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the EuroSAT dataset.

Uses the train/val/test splits from the dataset.

Added in version 0.2.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new EuroSATDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to EuroSAT.

class torchgeo.datamodules.EuroSATSpatialDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the EuroSATSpatial dataset.

Uses the spatial train/val/test splits from the dataset.

Added in version 0.6.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new EuroSATSpatialDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to EuroSATSpatial.

class torchgeo.datamodules.EuroSAT100DataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the EuroSAT100 dataset.

Intended for tutorials and demonstrations, not for benchmarking.

Added in version 0.5.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new EuroSAT100DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to EuroSAT100.

FAIR1M

class torchgeo.datamodules.FAIR1MDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the FAIR1M dataset.

Added in version 0.2.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new FAIR1MDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to FAIR1M.

Changed in version 0.5: Removed val_split_pct and test_split_pct parameters.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

Fields Of The World

class torchgeo.datamodules.FieldsOfTheWorldDataModule(train_countries=['austria'], val_countries=['austria'], test_countries=['austria'], batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the FTW dataset.

Added in version 0.7.

__init__(train_countries=['austria'], val_countries=['austria'], test_countries=['austria'], batch_size=64, num_workers=0, **kwargs)

Initialize a new FieldsOfTheWorldDataModule instance.

Parameters:
  • train_countries (list[str]) – List of countries to use for training.

  • val_countries (list[str]) – List of countries to use for validation.

  • test_countries (list[str]) – List of countries to use for testing.

  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to FieldsOfTheWorld.

Raises:

AssertionError – If ‘countries’ is passed in kwargs.
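Because each split gets its own country list, a shared countries keyword would be ambiguous. A minimal sketch of building one set of dataset arguments per split and rejecting a stray countries key; per_split_kwargs is hypothetical, and the split and countries argument names passed to the dataset are assumptions for illustration.

```python
from __future__ import annotations

from typing import Any


def per_split_kwargs(
    train_countries: list[str],
    val_countries: list[str],
    test_countries: list[str],
    **kwargs: Any,
) -> dict[str, dict[str, Any]]:
    """Build one kwargs dict per split (hypothetical sketch).

    Each split receives its own country list, so a shared 'countries' kwarg
    is rejected, matching the documented AssertionError.
    """
    assert "countries" not in kwargs, "pass train/val/test_countries instead of countries"
    return {
        "train": {**kwargs, "split": "train", "countries": train_countries},
        "val": {**kwargs, "split": "val", "countries": val_countries},
        "test": {**kwargs, "split": "test", "countries": test_countries},
    }


cfg = per_split_kwargs(["austria"], ["austria"], ["france"], root="data")
print(cfg["test"])  # {'root': 'data', 'split': 'test', 'countries': ['france']}
```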

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, or ‘test’.

FireRisk

class torchgeo.datamodules.FireRiskDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the FireRisk dataset.

Added in version 0.5.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new FireRiskDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to FireRisk.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

GeoNRW

class torchgeo.datamodules.GeoNRWDataModule(batch_size=64, num_workers=0, size=256, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the GeoNRW dataset.

Implements 80/20 train/val splits based on city locations. See setup() for more details.

Added in version 0.6.

__init__(batch_size=64, num_workers=0, size=256, **kwargs)

Initialize a new GeoNRWDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • size (int) – Resize input images of size 1000x1000 to size x size.

  • **kwargs (Any) – Additional keyword arguments passed to GeoNRW.
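Going from 1000x1000 to the default 256x256 is not an integer factor (1000 / 256 = 3.90625 input pixels per output pixel). For label masks, nearest-neighbour interpolation is a common choice because it never produces class values that were absent from the input. The sketch assumes that choice; the interpolation torchgeo actually uses is not stated here, and resize_nearest is a hypothetical helper.

```python
def resize_nearest(grid: list[list[int]], size: int) -> list[list[int]]:
    """Nearest-neighbour resize of a square 2D grid to size x size."""
    n = len(grid)
    scale = n / size  # e.g. 1000 / 256 = 3.90625 input pixels per output pixel
    # Sample the input pixel under the centre of each output pixel.
    src = [int((i + 0.5) * scale) for i in range(size)]
    return [[grid[y][x] for x in src] for y in src]


grid = [[r * 4 + c for c in range(4)] for r in range(4)]
print(resize_nearest(grid, 2))  # [[5, 7], [13, 15]]
```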

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

GID-15

class torchgeo.datamodules.GID15DataModule(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the GID-15 dataset.

Uses the train/test splits from the dataset.

Added in version 0.4.

__init__(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)

Initialize a new GID15DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to GID15.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

HySpecNet-11k

class torchgeo.datamodules.HySpecNet11kDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the HySpecNet11k dataset.

Added in version 0.7.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new HySpecNet11kDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to HySpecNet11k.

Inria Aerial Image Labeling

class torchgeo.datamodules.InriaAerialImageLabelingDataModule(batch_size=64, patch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the InriaAerialImageLabeling dataset.

Uses the train/test splits from the dataset and further splits the train split into train/val splits.

Added in version 0.3.

__init__(batch_size=64, patch_size=64, num_workers=0, **kwargs)

Initialize a new InriaAerialImageLabelingDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to InriaAerialImageLabeling.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

LandCover.ai

class torchgeo.datamodules.LandCoverAIDataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the LandCover.ai dataset.

Uses the train/val/test splits from the dataset.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new LandCoverAIDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to LandCoverAI.

class torchgeo.datamodules.LandCoverAI100DataModule(batch_size=64, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the LandCoverAI100 dataset.

Uses the train/val/test splits from the dataset.

Added in version 0.7.

__init__(batch_size=64, num_workers=0, **kwargs)

Initialize a new LandCoverAI100DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to LandCoverAI100.

LEVIR-CD

class torchgeo.datamodules.LEVIRCDDataModule(batch_size=8, patch_size=256, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the LEVIR-CD dataset.

Added in version 0.6.

__init__(batch_size=8, patch_size=256, num_workers=0, **kwargs)

Initialize a new LEVIRCDDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to LEVIRCD.

LEVIR-CD+

class torchgeo.datamodules.LEVIRCDPlusDataModule(batch_size=8, patch_size=256, val_split_pct=0.2, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the LEVIR-CD+ dataset.

Uses the train/test splits from the dataset and further splits the train split into train/val splits.

Added in version 0.6.

__init__(batch_size=8, patch_size=256, val_split_pct=0.2, num_workers=0, **kwargs)

Initialize a new LEVIRCDPlusDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to LEVIRCDPlus.

setup(stage)

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

LoveDA

class torchgeo.datamodules.LoveDADataModule(batch_size=32, num_workers=0, **kwargs)

Bases: NonGeoDataModule

LightningDataModule implementation for the LoveDA dataset.

Uses the train/val/test splits from the dataset.

Added in version 0.2.

__init__(batch_size=32, num_workers=0, **kwargs)[source]#

Initialize a new LoveDADataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to LoveDA.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

NASA Marine Debris#

class torchgeo.datamodules.NASAMarineDebrisDataModule(batch_size=64, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the NASA Marine Debris dataset.

Added in version 0.2.

__init__(batch_size=64, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, **kwargs)[source]#

Initialize a new NASAMarineDebrisDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • test_split_pct (float) – Percentage of the dataset to use as a test set.

  • **kwargs (Any) – Additional keyword arguments passed to NASAMarineDebris.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

OSCD#

class torchgeo.datamodules.OSCDDataModule(batch_size=32, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the OSCD dataset.

Uses the train/test splits from the dataset and further splits the train split into train/val splits.

Added in version 0.2.

__init__(batch_size=32, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)[source]#

Initialize a new OSCDDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to OSCD.
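OSCD is a change-detection dataset, so a random crop of `patch_size` pixels must be taken at the same location in both acquisition dates and in the change mask. Otherwise the pixels no longer correspond. The pure-Python sketch below shows this constraint, using nested lists as stand-ins for image tensors. `aligned_random_crop` is hypothetical and is not TorchGeo's implementation, which operates on tensors.

```python
import random

Grid = list[list[int]]


def crop(grid: Grid, top: int, left: int, height: int, width: int) -> Grid:
    """Return the height x width window whose top-left corner is (top, left)."""
    return [row[left : left + width] for row in grid[top : top + height]]


def aligned_random_crop(
    image1: Grid, image2: Grid, mask: Grid, size: int, seed: int = 0
) -> tuple[Grid, Grid, Grid]:
    """Crop the same random size x size window from both dates and the change mask."""
    rows, cols = len(mask), len(mask[0])
    rng = random.Random(seed)
    # Draw the window position once and reuse it for all three arrays.
    top = rng.randint(0, rows - size)
    left = rng.randint(0, cols - size)
    return (
        crop(image1, top, left, size, size),
        crop(image2, top, left, size, size),
        crop(mask, top, left, size, size),
    )


# Toy 10 x 10 scene: the "after" image is the "before" image plus 100.
before = [[r * 10 + c for c in range(10)] for r in range(10)]
after = [[v + 100 for v in row] for row in before]
change = [[v % 2 for v in row] for row in before]

patch1, patch2, patch_mask = aligned_random_crop(before, after, change, size=4)
# Corresponding pixels still line up after cropping.
print(all(b == a + 100 for ra, rb in zip(patch1, patch2) for a, b in zip(ra, rb)))  # True
```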

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

class torchgeo.datamodules.OSCD100DataModule(batch_size=8, patch_size=64, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the OSCD100 dataset.

Intended for tutorials and demonstrations, not benchmarking.

Added in version 0.9.

__init__(batch_size=8, patch_size=64, num_workers=0, **kwargs)[source]#

Initialize a new OSCD100DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to OSCD100.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

PASTIS#

class torchgeo.datamodules.PASTISDataModule(batch_size=32, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, padding_length=61, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the PASTIS dataset.

Added in version 0.8.

__init__(batch_size=32, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, padding_length=61, **kwargs)[source]#

Initialize a new PASTISDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • test_split_pct (float) – Percentage of the dataset to use as a test set.

  • padding_length (int) – Padding length of the time series.

  • **kwargs (Any) – Additional keyword arguments passed to PASTIS.
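The `padding_length` parameter suggests that time series shorter than this length are padded, so that every sample has the same number of dates and samples can be stacked into a batch. The sketch below shows zero-padding a `T x C` series. Truncating longer series and returning a validity mask are assumptions added for illustration. `pad_time_series` is a hypothetical helper.

```python
def pad_time_series(
    frames: list[list[float]], padding_length: int, fill: float = 0.0
) -> tuple[list[list[float]], list[bool]]:
    """Pad (or, as an assumption, truncate) a T x C series to padding_length frames.

    Returns the padded series and a mask marking which frames are real.
    """
    frames = frames[:padding_length]  # assumption: drop frames beyond the limit
    n_channels = len(frames[0]) if frames else 0
    n_missing = padding_length - len(frames)
    padded = frames + [[fill] * n_channels for _ in range(n_missing)]
    is_real = [True] * len(frames) + [False] * n_missing
    return padded, is_real


series = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]  # 3 dates, 2 channels
padded, is_real = pad_time_series(series, padding_length=5)
print(len(padded), is_real)  # 5 [True, True, True, False, False]
```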

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

class torchgeo.datamodules.PASTIS100DataModule(batch_size=32, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, padding_length=61, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the PASTIS-R-100 dataset.

Added in version 0.9.

__init__(batch_size=32, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, padding_length=61, **kwargs)[source]#

Initialize a new PASTIS100DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • test_split_pct (float) – Percentage of the dataset to use as a test set.

  • padding_length (int) – Padding length of the time series.

  • **kwargs (Any) – Additional keyword arguments passed to PASTIS100.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

PatternNet#

class torchgeo.datamodules.PatternNetDataModule(batch_size=64, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the PatternNet dataset.

Uses random train/val/test splits.

Added in version 0.8.

__init__(batch_size=64, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, **kwargs)[source]#

Initialize a new PatternNetDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Fraction of the dataset to use for validation.

  • test_split_pct (float) – Fraction of the dataset to use for testing.

  • **kwargs (Any) – Additional keyword arguments passed to PatternNet.
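Because the data module uses random splits, the subset sizes come from `val_split_pct` and `test_split_pct`, and the remainder forms the training set. The sketch below shows one way to compute those sizes and draw disjoint subsets. Both helpers are hypothetical, and TorchGeo may round or shuffle differently.

```python
import random


def split_sizes(n: int, val_split_pct: float, test_split_pct: float) -> tuple[int, int, int]:
    """Return (train, val, test) sample counts; the remainder goes to train."""
    if val_split_pct < 0 or test_split_pct < 0 or val_split_pct + test_split_pct >= 1:
        raise ValueError("need non-negative fractions whose sum is < 1")
    # int() rounds down, so any rounding remainder ends up in the training set.
    n_val = int(n * val_split_pct)
    n_test = int(n * test_split_pct)
    return n - n_val - n_test, n_val, n_test


def random_three_way_split(
    n: int, val_split_pct: float = 0.2, test_split_pct: float = 0.2, seed: int = 0
) -> tuple[list[int], list[int], list[int]]:
    """Shuffle sample indices once and cut them into train/val/test subsets."""
    n_train, n_val, _ = split_sizes(n, val_split_pct, test_split_pct)
    order = list(range(n))
    random.Random(seed).shuffle(order)
    return order[:n_train], order[n_train : n_train + n_val], order[n_train + n_val :]


train_idx, val_idx, test_idx = random_three_way_split(1000)
print(len(train_idx), len(val_idx), len(test_idx))  # 600 200 200
```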

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

Potsdam#

class torchgeo.datamodules.Potsdam2DDataModule(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the Potsdam2D dataset.

Uses the train/test splits from the dataset.

Added in version 0.2.

__init__(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)[source]#

Initialize a new Potsdam2DDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to Potsdam2D.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

QuakeSet#

class torchgeo.datamodules.QuakeSetDataModule(batch_size=64, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the QuakeSet dataset.

Added in version 0.6.

__init__(batch_size=64, num_workers=0, **kwargs)[source]#

Initialize a new QuakeSetDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to QuakeSet.

ReforesTree#

class torchgeo.datamodules.ReforesTreeDataModule(batch_size=64, patch_size=64, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the ReforesTree dataset.

Added in version 0.7.

__init__(batch_size=64, patch_size=64, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, **kwargs)[source]#

Initialize a new ReforesTreeDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • test_split_pct (float) – Percentage of the dataset to use as a test set.

  • **kwargs (Any) – Additional keyword arguments passed to ReforesTree.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

RESISC45#

class torchgeo.datamodules.RESISC45DataModule(batch_size=64, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the RESISC45 dataset.

Uses the train/val/test splits from the dataset.

__init__(batch_size=64, num_workers=0, **kwargs)[source]#

Initialize a new RESISC45DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to RESISC45.

Seasonal Contrast#

class torchgeo.datamodules.SeasonalContrastS2DataModule(batch_size=64, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the Seasonal Contrast dataset.

Added in version 0.5.

__init__(batch_size=64, num_workers=0, **kwargs)[source]#

Initialize a new SeasonalContrastS2DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to SeasonalContrastS2.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

SEN12MS#

class torchgeo.datamodules.SEN12MSDataModule(batch_size=64, num_workers=0, band_set='all', **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the SEN12MS dataset.

Implements 80/20 geographic train/val splits and uses the test split from the classification dataset definitions.

Uses the Simplified IGBP scheme defined in the 2020 Data Fusion Competition. See https://arxiv.org/abs/2002.08254.

DFC2020_CLASS_MAPPING = tensor([ 0,  1,  1,  1,  1,  1,  2,  2,  3,  3,  4,  5,  6,  7,  6,  8,  9, 10])#

Mapping from the IGBP class definitions to the DFC2020 classes, taken from the dataloader in lukasliebel/dfc2020_baseline.
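The mapping is a lookup table indexed by IGBP class, so remapping a label mask means indexing the table with each pixel value. For example, IGBP classes 12 and 14 both map to DFC2020 class 6. With a PyTorch tensor and an integer `mask`, this is a single advanced-indexing expression, `DFC2020_CLASS_MAPPING[mask]`. The sketch below spells out the same lookup in pure Python. This section does not state where the data module applies the remapping, so treat `igbp_to_dfc2020` as a hypothetical helper.

```python
# Values copied from DFC2020_CLASS_MAPPING above: index = IGBP class, value = DFC2020 class.
DFC2020_CLASS_MAPPING = [0, 1, 1, 1, 1, 1, 2, 2, 3, 3, 4, 5, 6, 7, 6, 8, 9, 10]


def igbp_to_dfc2020(mask: list[list[int]]) -> list[list[int]]:
    """Remap every pixel of an IGBP label mask to its simplified DFC2020 class."""
    return [[DFC2020_CLASS_MAPPING[label] for label in row] for row in mask]


igbp_mask = [[1, 5, 6], [12, 14, 17]]
print(igbp_to_dfc2020(igbp_mask))  # [[1, 1, 2], [6, 6, 10]]
```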

__init__(batch_size=64, num_workers=0, band_set='all', **kwargs)[source]#

Initialize a new SEN12MSDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • band_set (str) – Subset of S1/S2 bands to use. Options are: “all”, “s1”, “s2-all”, and “s2-reduced” where the “s2-reduced” set includes: B2, B3, B4, B8, B11, and B12.

  • **kwargs (Any) – Additional keyword arguments passed to SEN12MS.
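Each `band_set` option selects a subset of channels from the combined Sentinel-1/Sentinel-2 stack. The sketch below turns a `band_set` name into channel indices. It assumes the image stores the two Sentinel-1 bands first, followed by the 13 Sentinel-2 bands in their standard order. That layout, and the `BAND_SETS` dictionary, are assumptions for illustration, not TorchGeo's internal definitions.

```python
# Assumed channel order: 2 Sentinel-1 bands, then the 13 Sentinel-2 bands in standard order.
S1_BANDS = ["VV", "VH"]
S2_BANDS = ["B1", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B9", "B10", "B11", "B12"]
ALL_BANDS = S1_BANDS + S2_BANDS

# Hypothetical table mirroring the band_set options described above.
BAND_SETS = {
    "all": ALL_BANDS,
    "s1": S1_BANDS,
    "s2-all": S2_BANDS,
    "s2-reduced": ["B2", "B3", "B4", "B8", "B11", "B12"],
}


def band_indices(band_set: str) -> list[int]:
    """Return the positions in ALL_BANDS of the channels selected by band_set."""
    if band_set not in BAND_SETS:
        raise ValueError(f"band_set must be one of {sorted(BAND_SETS)}")
    return [ALL_BANDS.index(band) for band in BAND_SETS[band_set]]


print(band_indices("s2-reduced"))  # [3, 4, 5, 9, 13, 14]
```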

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

on_after_batch_transfer(batch, dataloader_idx)[source]#

Apply batch augmentations to the batch after it is transferred to the device.

Parameters:
  • batch (dict[str, Any]) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Returns:

A batch of data.

Return type:

dict[str, Any]

SKIPP’D#

class torchgeo.datamodules.SKIPPDDataModule(batch_size=64, num_workers=0, val_split_pct=0.2, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the SKIPP’D dataset.

Implements 80/20 train/val splits on the train_val set by default (see val_split_pct). See setup() for more details.

Added in version 0.5.

__init__(batch_size=64, num_workers=0, val_split_pct=0.2, **kwargs)[source]#

Initialize a new SKIPPDDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • **kwargs (Any) – Additional keyword arguments passed to SKIPPD.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

So2Sat#

class torchgeo.datamodules.So2SatDataModule(batch_size=64, num_workers=0, band_set='all', val_split_pct=0.2, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the So2Sat dataset.

If using the version 2 dataset, we use the train/val/test splits from the dataset. If using the version 3 datasets, we use a random 80/20 train/val split from the “train” set and use the “test” set as the test set.

__init__(batch_size=64, num_workers=0, band_set='all', val_split_pct=0.2, **kwargs)[source]#

Initialize a new So2SatDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • band_set (str) – One of ‘all’, ‘s1’, ‘s2’, or ‘rgb’.

  • val_split_pct (float) – Percentage of the training data to use for validation with the version 3 datasets.

  • **kwargs (Any) – Additional keyword arguments passed to So2Sat.

Added in version 0.5: The val_split_pct parameter, and the ‘rgb’ argument to band_set.

setup(stage)[source]#

Set up datasets.

Called at the beginning of fit, validate, test, or predict. During distributed training, this method is called from every process across all the nodes. Setting state here is recommended.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

Solar Plants Brazil#

class torchgeo.datamodules.SolarPlantsBrazilDataModule(batch_size=16, patch_size=256, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the SolarPlantsBrazil dataset.

This datamodule wraps the SolarPlantsBrazil dataset, which provides predefined train/val/test splits. The splits are separated by solar plant, so samples from the same plant never appear in more than one split. This spatial separation prevents data leakage between training and evaluation.

Added in version 0.8.

__init__(batch_size=16, patch_size=256, num_workers=0, **kwargs)[source]#

Initialize the SolarPlantsBrazilDataModule.

Parameters:
  • batch_size (int) – Number of samples per batch.

  • patch_size (tuple[int, int] | int) – Spatial dimensions (H, W) to crop from images.

  • num_workers (int) – Number of subprocesses used to load the data.

  • **kwargs (Any) – Additional keyword arguments passed to SolarPlantsBrazil.

SpaceNet#

class torchgeo.datamodules.SpaceNetBaseDataModule(spacenet_ds_class, batch_size=64, num_workers=0, val_split_pct=0.1, test_split_pct=0.2, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the SpaceNet datasets.

Randomly splits the dataset's train split into train/val/test subsets. The dataset's own test split has no labels and is used only for prediction.

Added in version 0.7.

__init__(spacenet_ds_class, batch_size=64, num_workers=0, val_split_pct=0.1, test_split_pct=0.2, **kwargs)[source]#

Initialize a new SpaceNetBaseDataModule instance.

Parameters:
  • spacenet_ds_class (type[SpaceNet]) – The SpaceNet dataset class to use.

  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • test_split_pct (float) – Percentage of the dataset to use as a test set.

  • **kwargs (Any) – Additional keyword arguments passed to the SpaceNet dataset.
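Under this scheme, the labeled train split feeds the `fit`, `validate` and `test` stages through random subsets. The unlabeled test split is reserved for `predict`. The sketch below records that stage-to-data mapping explicitly. The `Source` dataclass and the `STAGE_SOURCES` table are illustrative assumptions, not TorchGeo attributes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Source:
    """Where a stage's samples come from (hypothetical description)."""

    split: str  # split of the SpaceNet dataset: "train" or "test"
    subset: str  # random subset of that split, or "all"
    labeled: bool


STAGE_SOURCES: dict[str, tuple[Source, ...]] = {
    "fit": (Source("train", "train", True), Source("train", "val", True)),
    "validate": (Source("train", "val", True),),
    "test": (Source("train", "test", True),),
    # The dataset's own test split has no labels, so it only feeds prediction.
    "predict": (Source("test", "all", False),),
}


def sources_for(stage: str) -> tuple[Source, ...]:
    """Look up the data sources used by a Lightning stage."""
    if stage not in STAGE_SOURCES:
        raise ValueError(f"stage must be one of {sorted(STAGE_SOURCES)}")
    return STAGE_SOURCES[stage]


print(sources_for("predict"))  # (Source(split='test', subset='all', labeled=False),)
```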

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

on_after_batch_transfer(batch, dataloader_idx)[source]#

Apply batch augmentations to the batch after it is transferred to the device.

Parameters:
  • batch (dict[str, Any]) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Returns:

A batch of data.

Return type:

dict[str, Any]

class torchgeo.datamodules.SpaceNet1DataModule(batch_size=64, num_workers=0, val_split_pct=0.1, test_split_pct=0.2, **kwargs)[source]#

Bases: SpaceNetBaseDataModule

LightningDataModule implementation for the SpaceNet1 dataset.

Randomly splits the dataset's train split into train/val/test subsets. The dataset's own test split has no labels and is used only for prediction.

Added in version 0.4.

__init__(batch_size=64, num_workers=0, val_split_pct=0.1, test_split_pct=0.2, **kwargs)[source]#

Initialize a new SpaceNet1DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • test_split_pct (float) – Percentage of the dataset to use as a test set.

  • **kwargs (Any) – Additional keyword arguments passed to SpaceNet1.

class torchgeo.datamodules.SpaceNet6DataModule(batch_size=64, num_workers=0, val_split_pct=0.1, test_split_pct=0.2, **kwargs)[source]#

Bases: SpaceNetBaseDataModule

LightningDataModule implementation for the SpaceNet6 dataset.

Randomly splits the dataset's train split into train/val/test subsets. The dataset's own test split has no labels and is used only for prediction.

Added in version 0.7.

__init__(batch_size=64, num_workers=0, val_split_pct=0.1, test_split_pct=0.2, **kwargs)[source]#

Initialize a new SpaceNet6DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • test_split_pct (float) – Percentage of the dataset to use as a test set.

  • **kwargs (Any) – Additional keyword arguments passed to SpaceNet6.

SSL4EO#

class torchgeo.datamodules.SSL4EOLDataModule(batch_size=64, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the SSL4EO-L dataset.

Added in version 0.5.

__init__(batch_size=64, num_workers=0, **kwargs)[source]#

Initialize a new SSL4EOLDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to SSL4EOL.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

class torchgeo.datamodules.SSL4EOS12DataModule(batch_size=64, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the SSL4EO-S12 dataset.

Added in version 0.5.

__init__(batch_size=64, num_workers=0, **kwargs)[source]#

Initialize a new SSL4EOS12DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to SSL4EOS12.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

SSL4EO-L Benchmark#

class torchgeo.datamodules.SSL4EOLBenchmarkDataModule(batch_size=64, patch_size=224, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the SSL4EO-L Benchmark dataset.

Added in version 0.5.

__init__(batch_size=64, patch_size=224, num_workers=0, **kwargs)[source]#

Initialize a new SSL4EOLBenchmarkDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | tuple[int, int]) – Size of each patch, either size or (height, width).

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to SSL4EOLBenchmark.

Substation#

class torchgeo.datamodules.SubstationDataModule(batch_size=64, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, size=256, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the Substation dataset, with train/val/test splits and data transformations.

Added in version 0.7.

__init__(batch_size=64, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, size=256, **kwargs)[source]#

Initialize a new SubstationDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for data loading.

  • val_split_pct (float) – Percentage of data to use for validation.

  • test_split_pct (float) – Percentage of data to use for testing.

  • size (int) – Size of the input images.

  • **kwargs (Any) – Additional keyword arguments passed to Substation.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – One of ‘fit’, ‘validate’, ‘test’, or ‘predict’.

SustainBench Crop Yield#

class torchgeo.datamodules.SustainBenchCropYieldDataModule(batch_size=32, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the SustainBench Crop Yield dataset.

Added in version 0.5.

__init__(batch_size=32, num_workers=0, **kwargs)[source]#

Initialize a new SustainBenchCropYieldDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to SustainBenchCropYield.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

TreeSatAI#

class torchgeo.datamodules.TreeSatAIDataModule(batch_size=64, patch_size=304, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the TreeSatAI dataset.

Added in version 0.7.

__init__(batch_size=64, patch_size=304, num_workers=0, **kwargs)[source]#

Initialize a new TreeSatAIDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | tuple[int, int]) – Size of each patch, either size or (height, width).

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to TreeSatAI.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

on_after_batch_transfer(batch, dataloader_idx)[source]#

Apply batch augmentations to the batch after it is transferred to the device.

Parameters:
  • batch (dict[str, Any]) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Returns:

A batch of data.

Return type:

dict[str, Any]

Tropical Cyclone#

class torchgeo.datamodules.TropicalCycloneDataModule(batch_size=64, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the NASA Cyclone dataset.

Implements 80/20 train/val splits based on hurricane storm ids. See setup() for more details.

Changed in version 0.4: Class name changed from CycloneDataModule to TropicalCycloneDataModule to be consistent with TropicalCyclone dataset.
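Splitting by storm ID rather than by individual image keeps every image of a given storm on one side of the split. As a result, the model is never validated on a storm it has already seen at a different time step. The sketch below shows a group-wise 80/20 split under that idea. The toy storm IDs and the helper `split_by_group` are hypothetical. The data module's actual procedure is the one described in its setup().

```python
import random


def split_by_group(
    groups: list[str], val_split_pct: float = 0.2, seed: int = 0
) -> tuple[list[int], list[int]]:
    """Split sample indices so that no group (e.g. storm ID) lands in both subsets.

    groups[i] is the group of sample i. The fraction applies to groups, not samples.
    """
    unique_groups = sorted(set(groups))  # sort first so the shuffle is reproducible
    random.Random(seed).shuffle(unique_groups)
    n_val_groups = round(len(unique_groups) * val_split_pct)
    val_groups = set(unique_groups[:n_val_groups])

    train_idx = [i for i, g in enumerate(groups) if g not in val_groups]
    val_idx = [i for i, g in enumerate(groups) if g in val_groups]
    return train_idx, val_idx


# Toy data: 12 images from 5 storms (hypothetical IDs).
storm_ids = ["abc"] * 3 + ["def"] * 2 + ["ghi"] * 4 + ["jkl"] + ["mno"] * 2
train_idx, val_idx = split_by_group(storm_ids)
train_storms = {storm_ids[i] for i in train_idx}
val_storms = {storm_ids[i] for i in val_idx}
print(train_storms.isdisjoint(val_storms))  # True
```

Because the 80/20 ratio applies to storms, the fraction of images in each subset can differ from 80/20 when storms have different numbers of images.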

__init__(batch_size=64, num_workers=0, **kwargs)[source]#

Initialize a new TropicalCycloneDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to TropicalCyclone.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

UC Merced#

class torchgeo.datamodules.UCMercedDataModule(batch_size=64, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the UC Merced dataset.

Uses random train/val/test splits.

__init__(batch_size=64, num_workers=0, **kwargs)[source]#

Initialize a new UCMercedDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to UCMerced.

USAVars#

class torchgeo.datamodules.USAVarsDataModule(batch_size=64, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the USAVars dataset.

Uses random train/val/test splits.

Added in version 0.3.

__init__(batch_size=64, num_workers=0, **kwargs)[source]#

Initialize a new USAVarsDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to USAVars.

Vaihingen#

class torchgeo.datamodules.Vaihingen2DDataModule(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the Vaihingen2D dataset.

Uses the train/test splits from the dataset.

Added in version 0.2.

__init__(batch_size=64, patch_size=64, val_split_pct=0.2, num_workers=0, **kwargs)[source]#

Initialize a new Vaihingen2DDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (tuple[int, int] | int) – Size of each patch, either size or (height, width). Should be a multiple of 32 for most segmentation architectures.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to Vaihingen2D.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

VHR-10#

class torchgeo.datamodules.VHR10DataModule(batch_size=64, patch_size=512, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the VHR10 dataset.

Added in version 0.6.

__init__(batch_size=64, patch_size=512, num_workers=0, val_split_pct=0.2, test_split_pct=0.2, **kwargs)[source]#

Initialize a new VHR10DataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • patch_size (tuple[int, int] | int) – Size of each patch, either size or (height, width).

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • test_split_pct (float) – Percentage of the dataset to use as a test set.

  • **kwargs (Any) – Additional keyword arguments passed to VHR10.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

xBD#

class torchgeo.datamodules.xBDDataModule(batch_size=64, num_workers=0, val_split_pct=0.2, **kwargs)[source]#

Bases: NonGeoDataModule

LightningDataModule implementation for the xBD dataset.

Uses the train/val/test splits from the dataset.

Added in version 0.2.

__init__(batch_size=64, num_workers=0, val_split_pct=0.2, **kwargs)[source]#

Initialize a new xBDDataModule instance.

Parameters:
  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • val_split_pct (float) – Percentage of the dataset to use as a validation set.

  • **kwargs (Any) – Additional keyword arguments passed to xBD.

setup(stage)[source]#

Set up datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

Base Classes#

BaseDataModule#

class torchgeo.datamodules.BaseDataModule(dataset_class, batch_size=1, num_workers=0, **kwargs)[source]#

Bases: LightningDataModule

Base class for all TorchGeo data modules.

Added in version 0.5.

__init__(dataset_class, batch_size=1, num_workers=0, **kwargs)[source]#

Initialize a new BaseDataModule instance.

Parameters:
  • dataset_class (type[Dataset[dict[str, Any]]]) – Class used to instantiate a new dataset.

  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to dataset_class.
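Concrete data modules such as SpaceNet1DataModule or LoveDADataModule take no `dataset_class` argument. Instead, each typically passes a fixed dataset class up to its base class, and its `**kwargs` are later forwarded to that dataset's constructor. The framework-free sketch below shows this pattern. `ToyDataset`, `ToyBaseDataModule` and `ToyDataModule` are hypothetical stand-ins, not TorchGeo classes. The real base class also inherits from Lightning's LightningDataModule.

```python
from typing import Any


class ToyDataset:
    """Stand-in dataset that records the arguments it was built with."""

    def __init__(self, split: str = "train", root: str = "data") -> None:
        self.split = split
        self.root = root


class ToyBaseDataModule:
    """Stores the dataset class and forwards **kwargs to it later."""

    def __init__(
        self, dataset_class: type, batch_size: int = 1, num_workers: int = 0, **kwargs: Any
    ) -> None:
        self.dataset_class = dataset_class
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.kwargs = kwargs  # forwarded verbatim to dataset_class

    def make_dataset(self, split: str) -> Any:
        return self.dataset_class(split=split, **self.kwargs)


class ToyDataModule(ToyBaseDataModule):
    """Concrete module: fixes the dataset class and exposes only the common options."""

    def __init__(self, batch_size: int = 64, num_workers: int = 0, **kwargs: Any) -> None:
        super().__init__(ToyDataset, batch_size, num_workers, **kwargs)


dm = ToyDataModule(batch_size=8, root="/tmp/toy")
val_ds = dm.make_dataset("val")
print(type(val_ds).__name__, val_ds.split, val_ds.root)  # ToyDataset val /tmp/toy
```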

prepare_data()[source]#

Download and prepare data.

During distributed training, this method is called only within a single process to avoid corrupted data. It should not set state, since it is not called on every device; use setup() instead.
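The split of responsibilities works like this. `prepare_data` may touch the disk (for example, to download files) but should not assign attributes. `setup` runs in every process and is where datasets are attached to `self`. The framework-free sketch below simulates this split with a temporary directory. The class, the file name and the simulated "download" are all hypothetical.

```python
from __future__ import annotations

import json
import os
import tempfile


class ToyDataModule:
    """Hypothetical module illustrating prepare_data vs. setup."""

    def __init__(self, root: str) -> None:
        self.root = root
        self.train_dataset: list[int] | None = None

    def prepare_data(self) -> None:
        # Disk side effects only; no attributes are set here.
        path = os.path.join(self.root, "samples.json")
        if not os.path.exists(path):  # skip the simulated download if already present
            with open(path, "w") as f:
                json.dump(list(range(10)), f)

    def setup(self, stage: str) -> None:
        # Runs in every process, so this is the place to attach state to self.
        with open(os.path.join(self.root, "samples.json")) as f:
            samples = json.load(f)
        if stage == "fit":
            self.train_dataset = samples


with tempfile.TemporaryDirectory() as root:
    dm = ToyDataModule(root)
    dm.prepare_data()  # a single process in distributed training
    dm.setup("fit")  # every process
    print(len(dm.train_dataset))  # 10
```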

on_after_batch_transfer(batch, dataloader_idx)[source]#

Apply batch augmentations to the batch after it is transferred to the device.

Parameters:
  • batch (dict[str, Any]) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Returns:

A batch of data.

Return type:

dict[str, Any]
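This hook runs after the batch is already on the device, so augmentations applied here run on the accelerator and can differ by stage. The sketch below shows that dispatch. The `train_aug`/`aug` attributes, the `current_stage` field standing in for the Lightning trainer state, and the toy flip are all assumptions made for illustration. The real hook receives only `batch` and `dataloader_idx`.

```python
from __future__ import annotations

from typing import Any, Callable

Batch = dict[str, Any]
Augmentation = Callable[[Batch], Batch]


def identity(batch: Batch) -> Batch:
    """Default augmentation: return the batch unchanged."""
    return batch


def hflip(batch: Batch) -> Batch:
    """Toy augmentation: flip image and mask left-right together."""
    flipped = dict(batch)
    for key in ("image", "mask"):
        if key in flipped:
            flipped[key] = [row[::-1] for row in flipped[key]]
    return flipped


class ToyDataModule:
    """Hypothetical stand-in; not TorchGeo's BaseDataModule."""

    def __init__(self) -> None:
        self.aug: Augmentation = identity  # fallback for every stage
        self.train_aug: Augmentation | None = hflip  # training-only override
        self.current_stage = "train"  # stand-in for the Lightning trainer state

    def on_after_batch_transfer(self, batch: Batch, dataloader_idx: int) -> Batch:
        # Prefer a stage-specific augmentation (e.g. train_aug), else use self.aug.
        aug = getattr(self, f"{self.current_stage}_aug", None) or self.aug
        return aug(batch)


dm = ToyDataModule()
batch = {"image": [[1, 2, 3]], "mask": [[0, 1, 1]]}
train_out = dm.on_after_batch_transfer(batch, dataloader_idx=0)
dm.current_stage = "val"
val_out = dm.on_after_batch_transfer(batch, dataloader_idx=0)
print(train_out["image"], val_out["image"])  # [[3, 2, 1]] [[1, 2, 3]]
```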

plot(*args, **kwargs)[source]#

Run the plot method of the validation dataset if one exists.

Should only be called during ‘fit’ or ‘validate’ stages as val_dataset may not exist during other stages.

Parameters:
  • *args (Any) – Arguments passed to plot method.

  • **kwargs (Any) – Keyword arguments passed to plot method.

Returns:

A matplotlib Figure with the image, ground truth, and predictions, or None if no validation dataset with a plot method is available.

Return type:

Figure | None

GeoDataModule#

class torchgeo.datamodules.GeoDataModule(dataset_class, batch_size=1, patch_size=64, length=None, num_workers=0, **kwargs)[source]#

Bases: BaseDataModule

Base class for data modules containing geospatial information.

Added in version 0.4.

__init__(dataset_class, batch_size=1, patch_size=64, length=None, num_workers=0, **kwargs)[source]#

Initialize a new GeoDataModule instance.

Parameters:
  • dataset_class (type[GeoDataset]) – Class used to instantiate a new dataset.

  • batch_size (int) – Size of each mini-batch.

  • patch_size (int | tuple[int, int]) – Size of each patch, either size or (height, width).

  • length (int | None) – Length of each training epoch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to dataset_class.
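For geospatial data, a training epoch has no natural size, because patches can be cut from anywhere in a scene. This is the role of `length`: a random sampler draws a fixed number of `patch_size` windows per epoch. TorchGeo provides samplers such as RandomGeoSampler for this purpose. The pure-Python sketch below works in pixel coordinates and treats `length` as a number of patches. It only illustrates the idea and does not reproduce those samplers, which work with geographic bounding boxes.

```python
from __future__ import annotations

import random


def random_windows(
    scene_height: int,
    scene_width: int,
    patch_size: int | tuple[int, int],
    length: int,
    seed: int = 0,
) -> list[tuple[int, int, int, int]]:
    """Draw `length` random (top, left, height, width) windows inside a scene."""
    height, width = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
    if height > scene_height or width > scene_width:
        raise ValueError("patch_size does not fit inside the scene")
    rng = random.Random(seed)
    windows = []
    for _ in range(length):  # one window per sample in the epoch
        top = rng.randint(0, scene_height - height)
        left = rng.randint(0, scene_width - width)
        windows.append((top, left, height, width))
    return windows


windows = random_windows(scene_height=1000, scene_width=800, patch_size=64, length=5)
print(len(windows))  # 5
```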

setup(stage)[source]#

Set up datasets and samplers.

Called at the beginning of fit, validate, test, or predict. During distributed training, this method is called from every process across all the nodes. Setting state here is recommended.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

train_dataloader()[source]#

Implement one or more PyTorch DataLoaders for training.

Returns:

A collection of data loaders specifying training samples.

Raises:

MisconfigurationException – If setup() does not define a dataset or sampler, or if the dataset or sampler has length 0.

Return type:

DataLoader[dict[str, Any]]
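The Raises clause above describes a check that runs before the loader is built. The framework-free sketch below shows that guard. It defines a local MisconfigurationException stand-in (Lightning's own class lives in `lightning.pytorch.utilities.exceptions`) and uses a hypothetical `check_split` helper.

```python
from collections.abc import Sized
from typing import Optional


class MisconfigurationException(Exception):
    """Local stand-in for Lightning's MisconfigurationException."""


def check_split(split: str, dataset: Optional[Sized], sampler: Optional[Sized]) -> None:
    """Raise unless setup() produced a non-empty dataset and sampler for `split`."""
    for kind, obj in (("dataset", dataset), ("sampler", sampler)):
        if obj is None:
            raise MisconfigurationException(f"setup() does not define a {split}_{kind}")
        if len(obj) == 0:
            raise MisconfigurationException(f"{split}_{kind} has length 0")


check_split("train", dataset=[0, 1, 2], sampler=[0, 1])  # passes silently
```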

val_dataloader()[source]#

Implement one or more PyTorch DataLoaders for validation.

Returns:

A collection of data loaders specifying validation samples.

Raises:

MisconfigurationException – If setup() does not define a dataset or sampler, or if the dataset or sampler has length 0.

Return type:

DataLoader[dict[str, Any]]

test_dataloader()[source]#

Implement one or more PyTorch DataLoaders for testing.

Returns:

A collection of data loaders specifying testing samples.

Raises:

MisconfigurationException – If setup() does not define a dataset or sampler, or if the dataset or sampler has length 0.

Return type:

DataLoader[dict[str, Any]]

predict_dataloader()[source]#

Implement one or more PyTorch DataLoaders for prediction.

Returns:

A collection of data loaders specifying prediction samples.

Raises:

MisconfigurationException – If setup() does not define a dataset or sampler, or if the dataset or sampler has length 0.

Return type:

DataLoader[dict[str, Any]]

NonGeoDataModule#

class torchgeo.datamodules.NonGeoDataModule(dataset_class, batch_size=1, num_workers=0, **kwargs)[source]#

Bases: BaseDataModule

Base class for data modules lacking geospatial information.

Added in version 0.4.

__init__(dataset_class, batch_size=1, num_workers=0, **kwargs)

Initialize a new NonGeoDataModule instance.

Parameters:
  • dataset_class (type[NonGeoDataset]) – Class used to instantiate a new dataset.

  • batch_size (int) – Size of each mini-batch.

  • num_workers (int) – Number of workers for parallel data loading.

  • **kwargs (Any) – Additional keyword arguments passed to dataset_class.
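dataset_class is passed as a class, not an instance, and **kwargs are kept so that a fresh dataset can be built for each split. The sketch below illustrates that factory pattern with plain Python. ToyDataset, ToyDataModule, make_dataset, and the split and root parameters are hypothetical stand-ins, not torchgeo classes or arguments.

```python
from __future__ import annotations

from typing import Any


class ToyDataset:
    """Hypothetical stand-in for a NonGeoDataset subclass."""

    def __init__(self, split: str = "train", root: str = "data") -> None:
        self.split = split
        self.root = root


class ToyDataModule:
    """Sketch of how dataset_class and **kwargs are stored and reused."""

    def __init__(
        self, dataset_class: type, batch_size: int = 1, num_workers: int = 0, **kwargs: Any
    ) -> None:
        self.dataset_class = dataset_class
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.kwargs = kwargs  # forwarded verbatim to every dataset instance

    def make_dataset(self, split: str) -> Any:
        # Each split gets its own dataset built from the same keyword arguments.
        return self.dataset_class(split=split, **self.kwargs)
```

Usage: ToyDataModule(ToyDataset, batch_size=4, root="/tmp/toy").make_dataset("val") builds a validation dataset that shares root with every other split.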

setup(stage)

Set up datasets.

Called at the beginning of fit, validate, test, or predict. During distributed training, this method is called from every process on every node, so it is the recommended place to set state such as datasets.

Parameters:

stage (str) – Either ‘fit’, ‘validate’, ‘test’, or ‘predict’.

train_dataloader()

Implement one or more PyTorch DataLoaders for training.

Returns:

A collection of data loaders specifying training samples.

Raises:

MisconfigurationException – If setup() does not define a dataset, or if the dataset has length 0.

Return type:

DataLoader[dict[str, Any]]
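To our understanding, the training loader of a NonGeoDataModule differs from the evaluation loaders mainly in sample order: shuffling is enabled only for the training split. Treat that as an assumption about the default. The plain-Python sketch below shows the resulting batching behaviour without a PyTorch dependency. iter_batches is hypothetical, and the fixed seed exists only to make the sketch reproducible (a real loader reshuffles every epoch).

```python
from __future__ import annotations

import random
from typing import Sequence, TypeVar

T = TypeVar("T")


def iter_batches(items: Sequence[T], batch_size: int, shuffle: bool, seed: int = 0) -> list[list[T]]:
    """Split items into batches, optionally in shuffled order (hypothetical)."""
    indices = list(range(len(items)))
    if shuffle:
        # Training: randomize the order of samples before batching.
        random.Random(seed).shuffle(indices)
    # The last batch may be smaller than batch_size; nothing is dropped.
    return [
        [items[i] for i in indices[start:start + batch_size]]
        for start in range(0, len(indices), batch_size)
    ]
```

With shuffle=False, the evaluation case, batches preserve dataset order, so metrics and predictions line up with sample indices.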

val_dataloader()

Implement one or more PyTorch DataLoaders for validation.

Returns:

A collection of data loaders specifying validation samples.

Raises:

MisconfigurationException – If setup() does not define a dataset, or if the dataset has length 0.

Return type:

DataLoader[dict[str, Any]]

test_dataloader()

Implement one or more PyTorch DataLoaders for testing.

Returns:

A collection of data loaders specifying testing samples.

Raises:

MisconfigurationException – If setup() does not define a dataset, or if the dataset has length 0.

Return type:

DataLoader[dict[str, Any]]

predict_dataloader()

Implement one or more PyTorch DataLoaders for prediction.

Returns:

A collection of data loaders specifying prediction samples.

Raises:

MisconfigurationException – If setup() does not define a dataset, or if the dataset has length 0.

Return type:

DataLoader[dict[str, Any]]

Utilities

class torchgeo.datamodules.MisconfigurationException

Bases: Exception

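Exception raised to inform users that Lightning is being misused, for example when a data module's setup() does not define a required dataset or sampler, or defines one with length 0.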
