import abc
import dataclasses
from collections.abc import Callable, Mapping
from typing import Any, ClassVar, Self, TypeVar, final
import dacite
import torch
from torch import nn
from fme.core.dataset_info import DatasetInfo
from fme.core.normalizer import StandardNormalizer
from fme.core.ocean import OceanConfig
from fme.core.registry.registry import Registry
from fme.core.step.args import StepArgs
from fme.core.typing_ import TensorDict, TensorMapping
# Children still need to decorate with @dataclass, otherwise
# they will be a dataclass with no dataclass fields.
@dataclasses.dataclass
class StepConfigABC(abc.ABC):
    """
    Abstract configuration of a step, and factory for the step itself.

    Subclasses declare which variables the step reads, writes and trains on,
    and build the corresponding :class:`StepABC` via :meth:`get_step`.
    """

    @abc.abstractmethod
    def get_step(
        self,
        dataset_info: DatasetInfo,
        init_weights: Callable[[list[nn.Module]], None],
    ) -> "StepABC":
        """
        Build the step described by this configuration.

        Args:
            dataset_info: Information about the training dataset.
            init_weights: Function to initialize the weights of the step before
                wrapping in DistributedDataParallel. This is particularly useful
                when freezing parameters, as the DistributedDataParallel will
                otherwise expect frozen weights to have gradients, and will
                raise an exception.

        Returns:
            The step object built from this configuration.
        """

    @property
    @abc.abstractmethod
    def n_ic_timesteps(self) -> int:
        """
        Number of initial-condition ("ic") timesteps used by the step.
        """
        pass

    @property
    @abc.abstractmethod
    def input_names(self) -> list[str]:
        """
        Names of variables input to the step.
        """
        pass

    @property
    @abc.abstractmethod
    def output_names(self) -> list[str]:
        """
        Names of variables output by the step.
        """
        pass

    @property
    @abc.abstractmethod
    def next_step_input_names(self) -> list[str]:
        """
        Names of variables required in next_step_input_data for .step.
        """
        pass

    @property
    @final
    def prognostic_names(self) -> list[str]:
        """
        Names of variables that are both inputs and outputs of the step.

        The result is ordered by first appearance in ``input_names`` and
        contains no duplicates. A plain set intersection would be ordered by
        string hash, which is randomized per process (PYTHONHASHSEED). That
        order could then differ between runs or between distributed ranks.
        """
        outputs = set(self.output_names)
        return [name for name in dict.fromkeys(self.input_names) if name in outputs]

    @property
    @abc.abstractmethod
    def loss_names(self) -> list[str]:
        """
        Names of variables to be included in the loss function.
        """
        pass

    @abc.abstractmethod
    def get_next_step_forcing_names(self) -> list[str]:
        """
        Names of input variables which come from the output timestep.
        """
        pass

    @abc.abstractmethod
    def get_loss_normalizer(
        self,
        extra_names: list[str] | None = None,
        extra_residual_scaled_names: list[str] | None = None,
    ) -> StandardNormalizer:
        """
        Args:
            extra_names: Names of additional variables to include in the
                loss normalizer.
            extra_residual_scaled_names: extra_names which use residual scale
                factors, if enabled.

        Returns:
            The loss normalizer.
        """

    @abc.abstractmethod
    def replace_ocean(self, ocean: OceanConfig | None):
        """
        Replace the ocean configuration of this step configuration.

        Args:
            ocean: The new ocean configuration, or None.
        """
        pass

    @abc.abstractmethod
    def get_ocean(self) -> OceanConfig | None:
        """
        Returns:
            The ocean configuration, or None if there is none.
        """
        pass

    def replace_prescribed_prognostic_names(self, names: list[str]) -> None:
        """
        Replace prescribed prognostic names (e.g. when loading from checkpoint).

        The default implementation does nothing; subclasses that support
        prescribed prognostic variables override it.
        """
        pass

    @abc.abstractmethod
    def load(self):
        """
        Update configuration in-place so it does not depend on external files.
        """
        pass

    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> Self:
        """
        Build an instance of this class from a state mapping.

        Uses ``dacite.from_dict`` with ``strict=True``, so keys in ``state``
        that do not correspond to fields of the class raise an error.
        """
        return dacite.from_dict(cls, state, config=dacite.Config(strict=True))
@dataclasses.dataclass
class StepSelector(StepConfigABC):
    """
    Step configuration chosen by name from a registry.

    ``type`` names a step configuration registered through :meth:`register`.
    ``config`` holds the data the registry uses to build it. Every query is
    forwarded to that built instance.

    The mutating methods (``replace_ocean``,
    ``replace_prescribed_prognostic_names`` and ``load``) write the instance
    back into ``config`` afterwards, using ``dataclasses.asdict``, so that
    ``config`` reflects the change.

    Parameters:
        type: Name of the registered step configuration type.
        config: Data used to build the selected step configuration.
    """

    type: str
    config: dict[str, Any]
    registry: ClassVar[Registry[StepConfigABC]] = Registry[StepConfigABC]()

    def __post_init__(self):
        # Build the concrete configuration once; all members below delegate to it.
        self._step_config_instance = self.registry.get(self.type, self.config)

    def _sync_config(self) -> None:
        """Overwrite ``config`` with the serialized wrapped configuration."""
        self.config = dataclasses.asdict(self._step_config_instance)

    # ---- registry access ----

    @classmethod
    def register(cls, name: str):
        """
        Register a step configuration type under ``name``.

        Returns whatever ``Registry.register(name)`` returns. This is
        presumably a class decorator for StepConfigABC subclasses; confirm
        against Registry.
        """
        return cls.registry.register(name)

    @classmethod
    def get_available_types(cls) -> set[str]:
        """Return the names of every registered step configuration type."""
        return {name for name in cls.registry._types.keys()}

    # ---- step construction ----

    def get_step(
        self,
        dataset_info: DatasetInfo,
        init_weights: Callable[[list[nn.Module]], None] = lambda x: None,
    ) -> "StepABC":
        """
        Build the step of the selected configuration.

        Args:
            dataset_info: Information about the training dataset.
            init_weights: Callback given the step's modules before they are
                wrapped in DistributedDataParallel. This is useful when
                freezing parameters: DistributedDataParallel would otherwise
                expect gradients for frozen weights and raise. Defaults to
                a no-op.

        Returns:
            The step built by the selected configuration.
        """
        selected = self._step_config_instance
        return selected.get_step(dataset_info, init_weights)

    # ---- variable-name metadata ----

    @property
    def n_ic_timesteps(self) -> int:
        """Number of initial-condition timesteps of the selected configuration."""
        return self._step_config_instance.n_ic_timesteps

    @property
    def input_names(self) -> list[str]:
        """Names of variables input to the selected step."""
        return self._step_config_instance.input_names

    @property
    def output_names(self) -> list[str]:
        """Names of variables output by the selected step."""
        return self._step_config_instance.output_names

    @property
    def next_step_input_names(self) -> list[str]:
        """Names of variables required in next_step_input_data for .step."""
        return self._step_config_instance.next_step_input_names

    @property
    def loss_names(self) -> list[str]:
        """Names of variables included in the loss function."""
        return self._step_config_instance.loss_names

    def get_next_step_forcing_names(self) -> list[str]:
        """Forcing names for the next step, as given by the selected config."""
        return self._step_config_instance.get_next_step_forcing_names()

    def get_loss_normalizer(
        self,
        extra_names: list[str] | None = None,
        extra_residual_scaled_names: list[str] | None = None,
    ) -> StandardNormalizer:
        """Forward to the selected configuration's ``get_loss_normalizer``."""
        selected = self._step_config_instance
        return selected.get_loss_normalizer(
            extra_names=extra_names,
            extra_residual_scaled_names=extra_residual_scaled_names,
        )

    # ---- ocean ----

    def get_ocean(self) -> OceanConfig | None:
        """Ocean configuration of the selected step configuration, if any."""
        return self._step_config_instance.get_ocean()

    def replace_ocean(self, ocean: OceanConfig | None):
        """Replace the selected configuration's ocean and refresh ``config``."""
        self._step_config_instance.replace_ocean(ocean)
        self._sync_config()

    # ---- other in-place updates ----

    def replace_prescribed_prognostic_names(self, names: list[str]) -> None:
        """Replace prescribed prognostic names, then refresh ``config``."""
        self._step_config_instance.replace_prescribed_prognostic_names(names)
        self._sync_config()

    def load(self):
        """Make the selected configuration self-contained, then refresh ``config``."""
        self._step_config_instance.load()
        self._sync_config()
class StepABC(abc.ABC):
    """
    Abstract interface for a single model step.

    The variable-name metadata and the loss normalizer are read from
    ``self.config``. Subclasses supply the modules, the normalizer and the
    step computation.
    """

    SelfType = TypeVar("SelfType", bound="StepABC")

    # ---- abstract properties ----

    @property
    @abc.abstractmethod
    def config(self) -> StepConfigABC:
        """The step configuration; all name metadata below is read from it."""
        ...

    @property
    @abc.abstractmethod
    def modules(self) -> nn.ModuleList:
        """The modules of this step, as an ``nn.ModuleList``."""
        ...

    @property
    @abc.abstractmethod
    def normalizer(self) -> StandardNormalizer:
        """The normalizer of this step."""
        ...

    @property
    @abc.abstractmethod
    def surface_temperature_name(self) -> str | None:
        """Name of the surface temperature variable, or None if unavailable."""
        ...

    @property
    @abc.abstractmethod
    def ocean_fraction_name(self) -> str | None:
        """Name of the ocean fraction variable, or None if unavailable."""
        ...

    # ---- values derived from the configuration (not overridable) ----

    @final
    def get_loss_normalizer(
        self,
        extra_names: list[str] | None = None,
        extra_residual_scaled_names: list[str] | None = None,
    ) -> StandardNormalizer:
        """Forward to ``self.config.get_loss_normalizer``."""
        cfg = self.config
        return cfg.get_loss_normalizer(
            extra_names=extra_names,
            extra_residual_scaled_names=extra_residual_scaled_names,
        )

    @property
    @final
    def n_ic_timesteps(self) -> int:
        """Number of initial-condition timesteps, taken from the config."""
        return self.config.n_ic_timesteps

    @property
    @final
    def input_names(self) -> list[str]:
        """Input variable names, taken from the config."""
        return self.config.input_names

    @property
    @final
    def output_names(self) -> list[str]:
        """Output variable names, taken from the config."""
        return self.config.output_names

    @property
    @final
    def prognostic_names(self) -> list[str]:
        """Prognostic variable names, taken from the config."""
        return self.config.prognostic_names

    @property
    @final
    def loss_names(self) -> list[str]:
        """Names of variables in the loss, taken from the config."""
        return self.config.loss_names

    @property
    @final
    def next_step_input_names(self) -> list[str]:
        """Names of variables required in next_step_input_data for .step."""
        return self.config.next_step_input_names

    @property
    @final
    def next_step_forcing_names(self) -> list[str]:
        """Names of input variables which come from the output timestep."""
        return self.config.get_next_step_forcing_names()

    # ---- abstract behavior ----

    @abc.abstractmethod
    def prescribe_sst(
        self,
        mask_data: TensorMapping,
        gen_data: TensorMapping,
        target_data: TensorMapping,
    ) -> TensorDict:
        """Prescribe target_data SST onto gen_data according to mask_data."""
        ...

    @abc.abstractmethod
    def get_regularizer_loss(self) -> torch.Tensor:
        """Return the regularizer loss."""
        ...

    @abc.abstractmethod
    def step(
        self: SelfType,
        args: StepArgs,
        wrapper: Callable[[nn.Module], nn.Module] = lambda x: x,
    ) -> TensorDict:
        """
        Advance the model by one timestep.

        Args:
            args: The arguments to the step function.
            wrapper: Applied to each nn.Module before it is called;
                defaults to the identity.

        Returns:
            The denormalized output data at the next time step.
        """
        ...

    @abc.abstractmethod
    def get_state(self) -> dict[str, Any]:
        """
        Returns:
            The state of the step object in the form load_state expects.
            It may or may not include initialization parameters.
        """
        ...

    @abc.abstractmethod
    def load_state(self, state: dict[str, Any]):
        """Load a state previously produced by get_state."""
        ...