TreeSatAI

class torchgeo.datasets.TreeSatAI(root='data', split='train', sensors=('aerial', 's1', 's2'), transforms=None, download=False, checksum=False)

Bases: NonGeoDataset

TreeSatAI Benchmark Archive.

TreeSatAI Benchmark Archive is a multi-sensor, multi-label dataset for tree species classification in remote sensing. It was created by combining labels from the federal forest inventory of Lower Saxony, Germany, with 20 cm Color-Infrared (CIR) aerial imagery and 10 m Sentinel imagery.

The TreeSatAI Benchmark Archive contains:

  • 50,381 image triplets (aerial, Sentinel-1, Sentinel-2)

  • synchronized time steps and locations

  • all original spectral bands/polarizations from the sensors

  • 20 species classes (single-label)

  • 12 age classes (single-label)

  • 15 genus classes (multi-label)

  • 60 m and 200 m patches

  • fixed split for train (90%) and test (10%) data

  • additional single labels such as English species name, genus, forest stand type, foliage type, and land cover
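The genus labels are multi-label, meaning one patch can belong to several of the 15 genus classes at once, whereas the species and age labels assign exactly one class per patch. The sketch below shows one common way to represent a multi-label target, as a multi-hot vector. The helper `to_multi_hot` and the example genus indices are hypothetical and are not taken from torchgeo; they do not describe how TreeSatAI encodes its labels internally.

```python
from collections.abc import Iterable

NUM_GENUS_CLASSES = 15  # number of genus classes listed above


def to_multi_hot(genus_indices: Iterable[int], num_classes: int = NUM_GENUS_CLASSES) -> list[int]:
    """Hypothetical helper: turn genus class indices into a multi-hot vector."""
    vector = [0] * num_classes
    for idx in genus_indices:
        if not 0 <= idx < num_classes:
            raise ValueError(f"genus index {idx} out of range [0, {num_classes})")
        vector[idx] = 1  # mark every genus present in the patch
    return vector


# A patch containing two genera (indices chosen for illustration only).
print(to_multi_hot([2, 7]))  # → [0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
```

A single-label target, by contrast, would be just one integer class index rather than a vector.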

If you use this dataset in your research, please cite the following paper:

Added in version 0.7.

__init__(root='data', split='train', sensors=('aerial', 's1', 's2'), transforms=None, download=False, checksum=False)

Initialize a new TreeSatAI instance.

Parameters:
  • root (str | PathLike[str]) – Root directory where the dataset can be found.

  • split (str) – Either ‘train’ or ‘test’.

  • sensors (Sequence[str]) – One or more of ‘aerial’, ‘s1’, and/or ‘s2’.

  • transforms (Callable[[dict[str, Any]], dict[str, Any]] | None) – A function/transform that takes an input sample (a dictionary containing the data and its target) and returns a transformed version.

  • download (bool) – If True, download dataset and store it in the root directory.

  • checksum (bool) – If True, check the MD5 of the downloaded files (may be slow).
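The `transforms` parameter expects a callable that maps one sample dictionary to another. A minimal sketch of such a callable follows, assuming a sample with an `"image"` entry and a `"label"` entry; these key names and the function `scale_image` are hypothetical placeholders, so inspect a real sample to see which keys TreeSatAI actually returns.

```python
from typing import Any


def scale_image(sample: dict[str, Any], max_value: float = 255.0) -> dict[str, Any]:
    """Hypothetical transform: rescale pixel values, pass other keys through.

    The key name "image" is a placeholder, not necessarily a TreeSatAI key.
    """
    out = dict(sample)  # shallow copy so the caller's sample is not mutated
    out["image"] = [value / max_value for value in sample["image"]]
    return out


sample = {"image": [0.0, 127.5, 255.0], "label": 3}
transformed = scale_image(sample)
print(transformed["image"])  # [0.0, 0.5, 1.0]
print(transformed["label"])  # 3
```

Plain lists keep the sketch self-contained; with the real dataset the sample values are tensors, so a real transform would operate on tensors before being passed as `transforms=...`.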

Raises:
__len__()

Return the number of data points in the dataset.

Returns:

Length of the dataset.

Return type:

int

__getitem__(index)

Return the sample at the given index within the dataset.

Parameters:

index (int) – Index to return.

Returns:

Data and label at that index.

Return type:

dict[str, Any]
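Together, __len__() and __getitem__() follow the usual map-style dataset contract: each integer index from 0 to len(dataset) - 1 maps to one sample dictionary holding the data and label. The toy class below is a hypothetical stand-in (not TreeSatAI and not a NonGeoDataset subclass) that illustrates this contract without downloading any data; its keys and values are invented for illustration.

```python
from typing import Any


class ToyTripletDataset:
    """Hypothetical stand-in mimicking the __len__/__getitem__ contract."""

    def __init__(self, labels: list[int]) -> None:
        self.labels = labels

    def __len__(self) -> int:
        # Number of data points, as reported by len(dataset).
        return len(self.labels)

    def __getitem__(self, index: int) -> dict[str, Any]:
        # Return the data and label at *index* as a dictionary.
        if not 0 <= index < len(self):
            raise IndexError(f"index {index} out of range for dataset of length {len(self)}")
        return {"image": [float(index)], "label": self.labels[index]}


dataset = ToyTripletDataset(labels=[4, 0, 19])
print(len(dataset))         # 3
print(dataset[2]["label"])  # 19
```

Because __getitem__ raises IndexError past the last index, a plain `for sample in dataset` loop also works through Python's sequence-iteration fallback.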

plot(sample, show_titles=True)

Plot a sample from the dataset.

Parameters:
  • sample (dict[str, Any]) – A sample returned by __getitem__().

  • show_titles (bool) – Flag indicating whether to show titles above each panel.

Returns:

A matplotlib Figure with the rendered sample.

Return type:

Figure
