Shortcuts

Source code for torchgeo.models.api

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""APIs for querying and loading pre-trained model weights.

See the following references for design details:

* https://pytorch.org/blog/easily-list-and-initialize-models-with-new-apis-in-torchvision/
* https://pytorch.org/vision/stable/models.html
* https://github.com/pytorch/vision/blob/main/torchvision/models/_api.py
"""  # noqa: E501

from typing import Any, Callable, Union

import torch.nn as nn
from torchvision.models._api import WeightsEnum

from .resnet import ResNet18_Weights, ResNet50_Weights, resnet18, resnet50
from .vit import ViTSmall16_Weights, vit_small_patch16_224

# Registry of model builder functions, keyed by the name they are exposed under.
_model = dict(
    resnet18=resnet18,
    resnet50=resnet50,
    vit_small_patch16_224=vit_small_patch16_224,
)

# Weights enum class for each registered model. Every enum is reachable both
# through the builder function itself and through the builder's registered name.
_model_weights = {
    # Lookup by builder function.
    resnet18: ResNet18_Weights,
    resnet50: ResNet50_Weights,
    vit_small_patch16_224: ViTSmall16_Weights,
    # Lookup by registered name.
    "resnet18": ResNet18_Weights,
    "resnet50": ResNet50_Weights,
    "vit_small_patch16_224": ViTSmall16_Weights,
}


def get_model(name: str, *args: Any, **kwargs: Any) -> nn.Module:
    """Build a registered model given its name.

    .. versionadded:: 0.4

    Args:
        name: Name under which the model builder is registered.
        *args: Positional arguments forwarded to the model builder.
        **kwargs: Keyword arguments forwarded to the model builder.

    Returns:
        The model instance created by the builder.
    """
    # Look the builder up first, then call it with the forwarded arguments.
    builder: Callable[..., nn.Module] = _model[name]
    return builder(*args, **kwargs)
def get_model_weights(name: Union[Callable[..., nn.Module], str]) -> WeightsEnum:
    """Look up the weights enum class registered for a model.

    .. versionadded:: 0.4

    Args:
        name: Either the model builder function or the name under which it
            is registered.

    Returns:
        The weights enum class associated with the model.

    Raises:
        KeyError: If *name* is not present in the weights registry.
    """
    weights_enum = _model_weights[name]
    return weights_enum
def get_weight(name: str) -> WeightsEnum:
    """Get the weights enum value by its full name.

    The name is resolved against the weights enums registered in
    ``_model_weights`` instead of being handed to :func:`eval`, so a string
    coming from a config file or the command line can never execute
    arbitrary code.

    .. versionadded:: 0.4

    Args:
        name: Name of the weight enum entry, in the form
            ``"<EnumClass>.<MEMBER>"``.

    Returns:
        The requested weight enum.

    Raises:
        ValueError: If *name* does not identify a member of a registered
            weights enum.
    """
    # Split "<EnumClass>.<MEMBER>" once; a name without a dot yields an empty
    # member name, which never matches and falls through to the ValueError.
    enum_name, _, member_name = name.partition(".")

    for weights_enum in _model_weights.values():
        if weights_enum.__name__ != enum_name:
            continue
        # __members__ includes aliases, so every member name that attribute
        # access on the enum class would resolve is still accepted.
        if member_name in weights_enum.__members__:
            return weights_enum[member_name]

    raise ValueError(f"{name} is not a valid WeightsEnum")
def list_models() -> list[str]:
    """Return the names of all registered models.

    .. versionadded:: 0.4

    Returns:
        A new list holding the registered model names, in registration order.
    """
    # Iterating a dict yields its keys, so no explicit .keys() call is needed.
    return list(_model)

© Copyright 2021, Microsoft Corporation. Revision fe546bf1.

Built with Sphinx using a theme provided by Read the Docs.

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources