# Source code for fme.core.registry

import abc
import dataclasses
from typing import Any, Callable, Dict, KeysView, Mapping, Tuple, Type

import dacite
from torch import nn


@dataclasses.dataclass
class ModuleConfig(abc.ABC):
    """
    Abstract base class for objects that build a nn.Module from the number
    of input/output channels and the image shape.

    It is called a "Config" because in practice it is a dataclass loaded
    directly from yaml, which lets the details of the network architecture
    live in a config file.
    """

    @abc.abstractmethod
    def build(
        self,
        n_in_channels: int,
        n_out_channels: int,
        img_shape: Tuple[int, int],
    ) -> nn.Module:
        """
        Construct the nn.Module described by this configuration.

        Subclasses must override this method; the base class cannot be
        instantiated while it is still abstract.

        Args:
            n_in_channels: number of input channels
            n_out_channels: number of output channels
            img_shape: last two dimensions of data, corresponding to lat and
                lon when using FourCastNet conventions

        Returns:
            a nn.Module
        """

    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> "ModuleConfig":
        """
        Create an instance of this ModuleConfig type from a dictionary
        containing all the information needed to build it.

        Args:
            state: mapping of field names to values for the dataclass ``cls``

        Returns:
            an instance of ``cls`` populated from ``state``
        """
        # strict=True asks dacite to reject keys in ``state`` that are not
        # fields of the target dataclass, rather than silently ignoring them.
        strict_config = dacite.Config(strict=True)
        return dacite.from_dict(data_class=cls, data=state, config=strict_config)


# Module-level registry mapping a module type name to the ModuleConfig
# subclass registered under that name. Entries are added by the ``register``
# decorator and looked up by ``get_from_registry``.
NET_REGISTRY: Dict[str, Type[ModuleConfig]] = {}


def get_available_module_types() -> KeysView[str]:
    """
    Return the names of all ModuleConfig types currently registered.

    Returns:
        the keys view of NET_REGISTRY; being a dict view, it reflects
        registrations made after this call as well
    """
    return NET_REGISTRY.keys()
def register(name: str) -> Callable[[Type[ModuleConfig]], Type[ModuleConfig]]:
    """
    Register a new ModuleConfig type with the NET_REGISTRY.

    This is useful for adding new ModuleConfig types to the registry from
    other modules.

    Args:
        name: name of the ModuleConfig type to register

    Returns:
        a decorator which registers the decorated class with the NET_REGISTRY
        and returns the class unchanged

    Raises:
        TypeError: if ``name`` is not a string
    """
    # The check runs when the decorator factory is called, so a misuse such as
    # a bare ``@register`` (which would pass the class itself as ``name``)
    # fails immediately with a hint about the intended usage.
    if not isinstance(name, str):
        raise TypeError(
            f"name must be a string, got {name}, "
            "make sure to use as @register('module_name')"
        )

    def decorator(cls: Type[ModuleConfig]) -> Type[ModuleConfig]:
        # Plain assignment: an existing entry under the same name is replaced.
        NET_REGISTRY[name] = cls
        return cls

    return decorator
def get_from_registry(name) -> Type[ModuleConfig]:
    """
    Look up the ModuleConfig type registered under ``name``.

    Args:
        name: key used when the type was added to NET_REGISTRY

    Returns:
        the registered ModuleConfig class

    Raises:
        KeyError: if nothing is registered under ``name``
    """
    config_cls = NET_REGISTRY[name]
    return config_cls
@dataclasses.dataclass
class ModuleSelector:
    """
    A dataclass containing all the information needed to build a ModuleConfig,
    including the type of the ModuleConfig and the data needed to build it.

    This is helpful as ModuleSelector can be serialized and deserialized
    without any additional information, whereas to load a ModuleConfig you
    would need to know the type of the ModuleConfig being loaded.

    It is also convenient because ModuleSelector is a single class that can be
    used to represent any ModuleConfig, whereas ModuleConfig is an abstract
    base class that can be implemented by many different classes.

    Attributes:
        type: the type of the ModuleConfig
        config: data for a ModuleConfig instance of the indicated type
    """

    type: str
    config: Mapping[str, Any]

    def __post_init__(self):
        """
        Resolve ``type`` in the registry and build the ModuleConfig instance.

        Raises:
            ValueError: if ``type`` is not a registered module type
        """
        # Guard only the registry lookup. A KeyError raised while constructing
        # the config from ``self.config`` is a different failure and must not
        # be reported as an unknown module type.
        try:
            config_cls = get_from_registry(self.type)
        except KeyError as err:
            raise ValueError(
                f"unknown module type {self.type}, "
                f"known module types are {list(NET_REGISTRY.keys())}"
            ) from err
        self._config = config_cls.from_state(self.config)

    def build(
        self,
        n_in_channels: int,
        n_out_channels: int,
        img_shape: Tuple[int, int],
    ) -> nn.Module:
        """
        Build a nn.Module given information about the input and output channels
        and the image shape.

        Delegates to the ModuleConfig instance created in ``__post_init__``.

        Args:
            n_in_channels: number of input channels
            n_out_channels: number of output channels
            img_shape: last two dimensions of data, corresponding to lat and
                lon when using FourCastNet conventions

        Returns:
            a nn.Module
        """
        return self._config.build(
            n_in_channels=n_in_channels,
            n_out_channels=n_out_channels,
            img_shape=img_shape,
        )

    def get_state(self) -> Mapping[str, Any]:
        """
        Get a dictionary containing all the information needed to build a
        ModuleConfig.

        Returns:
            a mapping with keys "type" and "config"; "config" is the same
            object held by this selector, not a copy
        """
        return {"type": self.type, "config": self.config}

    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> "ModuleSelector":
        """
        Create a ModuleSelector from a dictionary containing all the
        information needed to build a ModuleConfig.

        Args:
            state: mapping with "type" and "config" entries, such as the one
                returned by ``get_state``

        Returns:
            an instance of ``cls`` populated from ``state``
        """
        # Use ``cls`` (not a hard-coded class) so subclasses round-trip to
        # their own type. strict=True asks dacite to reject unexpected keys.
        return dacite.from_dict(
            data_class=cls, data=state, config=dacite.Config(strict=True)
        )