Source code for fme.core.registry.module

import abc
import dataclasses
from collections.abc import Callable, Mapping

# we use Type to distinguish from type attr of ModuleSelector
from typing import Any, ClassVar, Type  # noqa: UP035

import dacite
import torch
from torch import nn

from fme.core.dataset_info import DatasetInfo
from fme.core.labels import BatchLabels, LabelEncoding

from .registry import Registry


@dataclasses.dataclass
class ModuleConfig(abc.ABC):
    """
    Builds a nn.Module given information about the input and output channels
    and dataset information.

    This is a "Config" as in practice it is a dataclass loaded directly from yaml,
    allowing us to specify details of the network architecture in a config file.
    Concrete subclasses add their own dataclass fields and implement `build`.
    """

    @abc.abstractmethod
    def build(
        self,
        n_in_channels: int,
        n_out_channels: int,
        dataset_info: DatasetInfo,
    ) -> nn.Module:
        """
        Build a nn.Module given information about the input and output channels
        and the dataset.

        Args:
            n_in_channels: number of input channels
            n_out_channels: number of output channels
            dataset_info: Information about the dataset, including img_shape,
                horizontal coordinates, vertical coordinate, etc.

        Returns:
            a nn.Module
        """
        ...

    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> "ModuleConfig":
        """
        Create an instance of this ModuleConfig subclass (``cls``) from a
        mapping of its dataclass field names to values.

        Args:
            state: mapping from field name to value for ``cls``.

        Returns:
            an instance of ``cls``.

        Note:
            dacite is called with ``strict=True``, so keys in ``state`` that
            are not fields of ``cls`` cause an error instead of being ignored.
        """
        return dacite.from_dict(
            data_class=cls, data=state, config=dacite.Config(strict=True)
        )


# Registry type names whose built modules may be called with a ``labels``
# keyword. ModuleSelector rejects ``conditional=True`` for any other type.
CONDITIONAL_BUILDERS = ["NoiseConditionedSFNO"]


class Module:
    """
    Wraps a torch nn.Module together with an optional LabelEncoding.

    When a label encoding is present the model is conditional: calls must
    supply batch labels, which are conformed to the encoding and forwarded to
    the wrapped module as a ``labels`` tensor. Without an encoding, labels are
    rejected.
    """

    def __init__(self, module: nn.Module, label_encoding: LabelEncoding | None):
        self._module = module
        self._label_encoding = label_encoding

    def __call__(
        self, input: torch.Tensor, labels: BatchLabels | None = None
    ) -> torch.Tensor:
        """
        Run the wrapped module on ``input``.

        Raises:
            TypeError: if labels are given to an unconditional model, or are
                missing for a conditional one.
        """
        encoding = self._label_encoding
        if encoding is None:
            if labels is not None:
                raise TypeError("Labels are not allowed for unconditional models")
            return self._module(input)
        if labels is None:
            raise TypeError("Labels are required for conditional models")
        conformed = labels.conform_to_encoding(encoding)
        return self._module(input, labels=conformed.tensor)

    @property
    def torch_module(self) -> nn.Module:
        """The underlying torch module."""
        return self._module

    def get_state(self) -> dict[str, Any]:
        """
        Return the wrapped module's state_dict entries plus a
        ``"label_encoding"`` entry (None when there is no label encoding).
        """
        encoding = self._label_encoding
        encoding_state = None if encoding is None else encoding.get_state()
        combined = dict(self._module.state_dict())
        combined["label_encoding"] = encoding_state
        return combined

    def load_state(self, state: dict[str, Any]) -> None:
        """
        Load state as produced by ``get_state``; ``state`` itself is not mutated.

        If the state carries a non-None label encoding, it is adopted when this
        wrapper has none, otherwise the existing encoding is conformed to it.
        The remaining entries are loaded into the wrapped module.
        """
        remaining = state.copy()
        encoding_state = remaining.pop("label_encoding", None)
        if encoding_state is not None:
            if self._label_encoding is None:
                self._label_encoding = LabelEncoding.from_state(encoding_state)
            else:
                self._label_encoding.conform_to_state(encoding_state)
        self._module.load_state_dict(remaining)

    def wrap_module(self, callable: Callable[[nn.Module], nn.Module]) -> "Module":
        """Return a new Module whose torch module is ``callable(torch_module)``."""
        wrapped = callable(self._module)
        return Module(wrapped, self._label_encoding)

    def to(self, device: torch.device) -> "Module":
        """Return a new Module with the torch module moved to ``device``."""
        moved = self._module.to(device)
        return Module(moved, self._label_encoding)


@dataclasses.dataclass
class ModuleSelector:
    """
    A dataclass containing all the information needed to build a ModuleConfig,
    including the type of the ModuleConfig and the data needed to build it.

    This is helpful as ModuleSelector can be serialized and deserialized
    without any additional information, whereas to load a ModuleConfig you
    would need to know the type of the ModuleConfig being loaded.

    It is also convenient because ModuleSelector is a single class that can be
    used to represent any ModuleConfig, whereas ModuleConfig is an abstract
    base class that can be implemented by many different classes.

    Parameters:
        type: the type of the ModuleConfig
        config: data for a ModuleConfig instance of the indicated type
        conditional: whether to condition the predictions on batch labels.
    """

    type: str
    config: Mapping[str, Any]
    conditional: bool = False
    registry: ClassVar[Registry[ModuleConfig]] = Registry[ModuleConfig]()

    def __post_init__(self):
        if not isinstance(self.registry, Registry):
            raise ValueError("ModuleSelector.registry should not be set manually")
        if self.conditional and self.type not in CONDITIONAL_BUILDERS:
            raise ValueError(
                "Conditional predictions require a conditional builder, "
                f"got {self.type} (available: {CONDITIONAL_BUILDERS})"
            )
        # Resolve the config eagerly so an invalid type/config surfaces at
        # construction time rather than at build time (exact errors depend on
        # Registry.get, which is not visible here).
        self._instance = self.registry.get(self.type, self.config)

    @classmethod
    def register(
        cls, type_name: str
    ) -> Callable[[Type[ModuleConfig]], Type[ModuleConfig]]:  # noqa: UP006
        """Return a decorator registering a ModuleConfig class under ``type_name``."""
        return cls.registry.register(type_name)

    def build(
        self,
        n_in_channels: int,
        n_out_channels: int,
        dataset_info: DatasetInfo,
    ) -> Module:
        """
        Build a nn.Module given information about the input and output channels
        and the dataset.

        Args:
            n_in_channels: number of input channels
            n_out_channels: number of output channels
            dataset_info: Information about the dataset, including img_shape
                (shape of last two dimensions of data, e.g. latitude and
                longitude), horizontal coordinates, vertical coordinate, etc.

        Returns:
            a Module object wrapping the built nn.Module; when ``conditional``
            is set it carries a LabelEncoding over the sorted dataset labels,
            otherwise its label encoding is None.

        Raises:
            ValueError: if ``conditional`` is set but the dataset has no labels.
        """
        label_encoding: LabelEncoding | None = None
        if self.conditional:
            if len(dataset_info.all_labels) == 0:
                raise ValueError("Conditional predictions require labels")
            # Sorting makes the label order deterministic.
            label_encoding = LabelEncoding(sorted(dataset_info.all_labels))
        module = self._instance.build(
            n_in_channels=n_in_channels,
            n_out_channels=n_out_channels,
            dataset_info=dataset_info,
        )
        return Module(module, label_encoding)

    @classmethod
    def get_available_types(cls):
        """This class method is used to expose all available types of Modules."""
        return cls.registry._types.keys()