Source code for fme.ace.stepper.single_module

import dataclasses
import datetime
import logging
import pathlib
import warnings
from collections.abc import Callable, Generator, Mapping
from typing import Any, Literal, cast

import dacite
import dacite.exceptions
import torch
import xarray as xr
from torch import nn

from fme.ace.data_loading.batch_data import BatchData, PairedData, PrognosticState
from fme.ace.requirements import DataRequirements, PrognosticStateDataRequirements
from fme.ace.stepper.parameter_init import (
    ParameterInitializationConfig,
    ParameterInitializer,
)
from fme.core.coordinates import (
    NullPostProcessFn,
    SerializableVerticalCoordinate,
    VerticalCoordinate,
)
from fme.core.corrector.atmosphere import AtmosphereCorrectorConfig
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.dataset.utils import encode_timestep
from fme.core.dataset_info import DatasetInfo, MissingDatasetInfo
from fme.core.device import get_device
from fme.core.generics.inference import PredictFunction
from fme.core.generics.optimization import OptimizationABC
from fme.core.generics.train_stepper import TrainOutputABC, TrainStepperABC
from fme.core.loss import WeightedMappingLoss, WeightedMappingLossConfig
from fme.core.masking import NullMasking, StaticMaskingConfig
from fme.core.multi_call import MultiCallConfig
from fme.core.normalizer import (
    NetworkAndLossNormalizationConfig,
    NormalizationConfig,
    StandardNormalizer,
)
from fme.core.ocean import OceanConfig
from fme.core.optimization import NullOptimization
from fme.core.registry import CorrectorSelector, ModuleSelector
from fme.core.step.multi_call import MultiCallStepConfig, replace_multi_call
from fme.core.step.single_module import SingleModuleStepConfig
from fme.core.step.step import StepABC, StepSelector
from fme.core.tensors import (
    add_ensemble_dim,
    fold_ensemble_dim,
    fold_sized_ensemble_dim,
    repeat_interleave_batch_dim,
    unfold_ensemble_dim,
)
from fme.core.timing import GlobalTimer
from fme.core.training_history import TrainingHistory, TrainingJob
from fme.core.typing_ import EnsembleTensorDict, TensorDict, TensorMapping

DEFAULT_TIMESTEP = datetime.timedelta(hours=6)
DEFAULT_ENCODED_TIMESTEP = encode_timestep(DEFAULT_TIMESTEP)

Weights = list[Mapping[str, Any]]
StepperWeightsAndHistory = tuple[Weights, TrainingHistory]


def _load_weights_and_history(path: str) -> StepperWeightsAndHistory:
    stepper = load_stepper(path)
    return_weights: Weights = []
    for module in stepper.modules:
        return_weights.append(module.state_dict())
    return return_weights, stepper.training_history


[docs]@dataclasses.dataclass class SingleModuleStepperConfig: """ Configuration for a single module stepper. Parameters: builder: The module builder. in_names: Names of input variables. out_names: Names of output variables. normalization: The normalization configuration. parameter_init: The parameter initialization configuration. ocean: The ocean configuration. loss: The loss configuration. corrector: The corrector configuration. next_step_forcing_names: Names of forcing variables for the next timestep. loss_normalization: The normalization configuration for the loss. residual_normalization: Optional alternative to configure loss normalization. If provided, it will be used for all *prognostic* variables in loss scaling. multi_call: The configuration of multi-called diagnostics. include_multi_call_in_loss: Whether to include multi-call diagnostics in the loss. The same loss configuration as specified in 'loss' is used. crps_training: Whether to use CRPS training for stochastic models. residual_prediction: Whether to have ML module predict tendencies for prognostic variables. """ builder: ModuleSelector in_names: list[str] out_names: list[str] normalization: NormalizationConfig parameter_init: ParameterInitializationConfig = dataclasses.field( default_factory=lambda: ParameterInitializationConfig() ) ocean: OceanConfig | None = None loss: WeightedMappingLossConfig = dataclasses.field( default_factory=lambda: WeightedMappingLossConfig() ) corrector: AtmosphereCorrectorConfig | CorrectorSelector = dataclasses.field( default_factory=lambda: AtmosphereCorrectorConfig() ) next_step_forcing_names: list[str] = dataclasses.field(default_factory=list) loss_normalization: NormalizationConfig | None = None residual_normalization: NormalizationConfig | None = None multi_call: MultiCallConfig | None = None include_multi_call_in_loss: bool = False crps_training: bool = False residual_prediction: bool = False def __post_init__(self): for name in self.next_step_forcing_names: if name not in self.in_names: raise ValueError( f"next_step_forcing_name '{name}' not in in_names: {self.in_names}" ) if name in self.out_names: raise ValueError( f"next_step_forcing_name is an output variable: '{name}'" ) if ( self.residual_normalization is not None and self.loss_normalization is not None ): raise ValueError( "Only one of residual_normalization, loss_normalization can " "be provided." "If residual_normalization is provided, it will be used for all " "*prognostic* variables in loss scalng. " "If loss_normalization is provided, it will be used for all variables " "in loss scaling." ) if self.multi_call is not None: self.multi_call.validate(self.in_names, self.out_names) if self.include_multi_call_in_loss: if self.multi_call is None: raise ValueError( "include_multi_calls_in_loss is True but no multi_call config " "was provided." ) def load(self): self.normalization.load() if self.loss_normalization is not None: self.loss_normalization.load() if self.residual_normalization is not None: self.residual_normalization.load() @property def n_ic_timesteps(self) -> int: return 1 def get_evaluation_window_data_requirements( self, n_forward_steps: int ) -> DataRequirements: return DataRequirements( names=self.all_names, n_timesteps=self._window_steps_required(n_forward_steps), ) def get_prognostic_state_data_requirements(self) -> PrognosticStateDataRequirements: return PrognosticStateDataRequirements( names=self.prognostic_names, n_timesteps=self.n_ic_timesteps, ) def _window_steps_required(self, n_forward_steps: int) -> int: return n_forward_steps + self.n_ic_timesteps def get_state(self): self.load() return dataclasses.asdict(self)
[docs] def get_parameter_initializer(self) -> ParameterInitializer: """ Get the parameter initializer for this stepper configuration. """ return self.parameter_init.build( load_weights_and_history=_load_weights_and_history )
[docs] def get_stepper( self, dataset_info: DatasetInfo, apply_parameter_init: bool = True, ) -> "Stepper": """ Args: dataset_info: Information about the training dataset. apply_parameter_init: Whether to apply parameter initialization. """ logging.info("Initializing stepper from provided legacy config") normalizer = self.normalization.build(self.normalize_names) combined_normalization_config = NetworkAndLossNormalizationConfig( network=self.normalization, loss=self.loss_normalization, residual=self.residual_normalization, ) loss_normalizer = combined_normalization_config.get_loss_normalizer( self.normalize_names, residual_scaled_names=self.prognostic_names ) new_config = self.to_stepper_config( normalizer=normalizer, loss_normalizer=loss_normalizer ) return new_config.get_stepper( dataset_info=dataset_info, apply_parameter_init=apply_parameter_init, )
def get_ocean(self) -> OceanConfig | None: return self.ocean @classmethod def from_state(cls, state) -> "SingleModuleStepperConfig": state = cls.remove_deprecated_keys(state) return dacite.from_dict( data_class=cls, data=state, config=dacite.Config(strict=True) ) @property def input_names(self) -> list[str]: return self.in_names @property def output_names(self) -> list[str]: return self.out_names @property def all_names(self): """Names of all variables required, including auxiliary ones.""" extra_names = [] if self.ocean is not None: extra_names.extend(self.ocean.forcing_names) if self.multi_call is not None: extra_names.extend(self.multi_call.names) all_names = list(set(self.in_names).union(self.out_names).union(extra_names)) return all_names @property def normalize_names(self): """Names of variables which require normalization. I.e. inputs/outputs.""" extra_names = [] if self.multi_call is not None: extra_names.extend(self.multi_call.names) return list(set(self.in_names).union(self.out_names).union(extra_names)) @property def input_only_names(self) -> list[str]: """Names of variables which are inputs only.""" return list(set(self.all_names) - set(self.out_names)) @property def prognostic_names(self) -> list[str]: """Names of variables which both inputs and outputs.""" return list(set(self.out_names).intersection(self.in_names)) @property def loss_names(self) -> list[str]: extra_names = [] if self.multi_call is not None: extra_names.extend(self.multi_call.names) return list(set(self.out_names).union(extra_names)) @property def diagnostic_names(self) -> list[str]: """Names of variables which are outputs only.""" extra_names = [] if self.multi_call is not None: extra_names = self.multi_call.names out_names = list(set(self.out_names).union(extra_names)) return list(set(out_names).difference(self.in_names)) @classmethod def remove_deprecated_keys(cls, state: dict[str, Any]) -> dict[str, Any]: _unsupported_key_defaults = { "conserve_dry_air": False, "optimization": None, "conservation_loss": {"dry_air_penalty": None}, } state_copy = state.copy() for key, default in _unsupported_key_defaults.items(): if key in state_copy: if state_copy[key] == default or state_copy[key] is None: del state_copy[key] else: raise ValueError( f"The stepper config option {key} is deprecated and the setting" f" provided, {state_copy[key]}, is no longer implemented. The " "SingleModuleStepper being loaded from state cannot be run by " "this version of the code." ) for normalization_key in [ "normalization", "loss_normalization", "residual_normalization", ]: if state_copy.get(normalization_key) is not None: if "exclude_names" in state_copy[normalization_key]: if state_copy[normalization_key]["exclude_names"] is not None: raise ValueError( "The exclude_names option in normalization config is no " "longer supported, but excluded names were found in " f"{normalization_key}." ) else: del state_copy[normalization_key]["exclude_names"] if "prescriber" in state_copy: # want to maintain backwards compatibility for this particular feature if state_copy["prescriber"] is not None: if state_copy.get("ocean") is not None: raise ValueError("Cannot specify both prescriber and ocean.") state_copy["ocean"] = { "surface_temperature_name": state_copy["prescriber"][ "prescribed_name" ], "ocean_fraction_name": state_copy["prescriber"]["mask_name"], "interpolate": state_copy["prescriber"]["interpolate"], } del state_copy["prescriber"] if "activation_checkpointing" in state_copy: del state_copy["activation_checkpointing"] return state_copy
[docs] def to_stepper_config( self, normalizer: StandardNormalizer, loss_normalizer: StandardNormalizer, ) -> "StepperConfig": """ Convert the current config to a stepper config. Overwriting normalization configuration is needed to avoid a checkpoint trying to load normalization data from netCDF files which are no longer present when running inference. Args: normalizer: overwrite the normalization config with data from this normalizer loss_normalizer: overwrite the loss normalization config with data from this normalizer Returns: A stepper config. """ return StepperConfig( step=self._to_step_config(normalizer, loss_normalizer), loss=self.loss, crps_training=self.crps_training, parameter_init=self.parameter_init, )
def _to_step_config( self, normalizer: StandardNormalizer | None = None, loss_normalizer: StandardNormalizer | None = None, ) -> StepSelector: return StepSelector( type="multi_call", config=dataclasses.asdict( MultiCallStepConfig( wrapped_step=StepSelector( type="single_module", config=dataclasses.asdict( self._to_single_module_step_config( normalizer=normalizer, loss_normalizer=loss_normalizer, ) ), ), config=self.multi_call, include_multi_call_in_loss=self.include_multi_call_in_loss, ) ), ) def _to_single_module_step_config( self, normalizer: StandardNormalizer | None = None, loss_normalizer: StandardNormalizer | None = None, ) -> "SingleModuleStepConfig": if normalizer is not None: normalization = normalizer.get_normalization_config() else: normalization = self.normalization if loss_normalizer is not None: loss_normalization: NormalizationConfig | None = ( loss_normalizer.get_normalization_config() ) residual_normalization: NormalizationConfig | None = None else: loss_normalization = self.loss_normalization residual_normalization = self.residual_normalization return SingleModuleStepConfig( builder=self.builder, in_names=self.in_names, out_names=self.out_names, normalization=NetworkAndLossNormalizationConfig( network=normalization, loss=loss_normalization, residual=residual_normalization, ), ocean=self.ocean, corrector=self.corrector, next_step_forcing_names=self.next_step_forcing_names, crps_training=self.crps_training, residual_prediction=self.residual_prediction, ) def replace_multi_call(self, multi_call: MultiCallConfig | None): self.multi_call = multi_call def replace_ocean(self, ocean: OceanConfig | None): self.ocean = ocean
[docs]@dataclasses.dataclass class ExistingStepperConfig: """ Configuration for an existing stepper. This allows loading a serialized stepper from a checkpoint without loading its configuration of the training and optimization schedule, i.e., this allows for specifying a new schedule in fine-tuning. Not used for training resumption. Parameters: checkpoint_path: The path to the serialized checkpoint; should be different than the experiment output directory. """ checkpoint_path: str def __post_init__(self): self._stepper_config = StepperConfig.from_stepper_state( self._load_checkpoint()["stepper"] ) def _load_checkpoint(self) -> Mapping[str, Any]: return torch.load( self.checkpoint_path, map_location=get_device(), weights_only=False ) def get_evaluation_window_data_requirements( self, n_forward_steps: int ) -> DataRequirements: return self._stepper_config.get_evaluation_window_data_requirements( n_forward_steps ) def get_prognostic_state_data_requirements(self) -> PrognosticStateDataRequirements: return self._stepper_config.get_prognostic_state_data_requirements() def get_forcing_window_data_requirements( self, n_forward_steps: int ) -> DataRequirements: return self._stepper_config.get_forcing_window_data_requirements( n_forward_steps )
[docs] def get_parameter_initializer(self) -> ParameterInitializer: """Get a parameter initializer for this stepper configuration.""" return self._stepper_config.get_parameter_initializer()
def get_stepper( self, dataset_info: DatasetInfo, apply_parameter_init: bool = True, ): logging.info(f"Initializing stepper from {self.checkpoint_path}") return Stepper.from_state(self._load_checkpoint()["stepper"])
def _prepend_timesteps( data: EnsembleTensorDict, timesteps: TensorMapping, time_dim: int = 2 ) -> EnsembleTensorDict: for v in data.values(): n_ensemble = v.shape[1] break else: return data # data is length zero timesteps = add_ensemble_dim(timesteps, repeats=n_ensemble) return EnsembleTensorDict( {k: torch.cat([timesteps[k], v], dim=time_dim) for k, v in data.items()} ) @dataclasses.dataclass class TrainOutput(TrainOutputABC): metrics: TensorDict gen_data: EnsembleTensorDict target_data: EnsembleTensorDict time: xr.DataArray normalize: Callable[[TensorDict], TensorDict] derive_func: Callable[[TensorMapping, TensorMapping], TensorDict] = ( lambda x, _: dict(x) ) def __post_init__(self): for v in self.target_data.values(): if v.shape[1] != 1: raise ValueError( f"target_data can only have one ensemble member, got {v.shape[1]}" ) def ensemble_derive_func( self, data: EnsembleTensorDict, forcing_data: TensorMapping ) -> EnsembleTensorDict: flattened_data, n_ensemble = fold_ensemble_dim(data) if n_ensemble > 1: ensemble_forcing_data = add_ensemble_dim(forcing_data, repeats=n_ensemble) flattened_forcing_data = fold_sized_ensemble_dim( ensemble_forcing_data, n_ensemble ) else: flattened_forcing_data = dict(forcing_data) derived_data = self.derive_func(flattened_data, flattened_forcing_data) return unfold_ensemble_dim(derived_data, n_ensemble) def remove_initial_condition(self, n_ic_timesteps: int) -> "TrainOutput": return TrainOutput( metrics=self.metrics, gen_data=EnsembleTensorDict( {k: v[:, :, n_ic_timesteps:] for k, v in self.gen_data.items()} ), target_data=EnsembleTensorDict( {k: v[:, :, n_ic_timesteps:] for k, v in self.target_data.items()} ), time=self.time[:, n_ic_timesteps:], normalize=self.normalize, derive_func=self.derive_func, ) def copy(self) -> "TrainOutput": """Creates new dictionaries for the data but with the same tensors.""" return TrainOutput( metrics=self.metrics, gen_data=EnsembleTensorDict({k: v for k, v in self.gen_data.items()}), target_data=EnsembleTensorDict({k: v for k, v in self.target_data.items()}), time=self.time, normalize=self.normalize, derive_func=self.derive_func, ) def prepend_initial_condition( self, initial_condition: PrognosticState, ) -> "TrainOutput": """ Prepends an initial condition to the existing stepped data. Assumes data are on the same device. For data windows > 0, the target IC is different from the generated IC and may be provided for correct calculation of tendencies. Args: initial_condition: Initial condition data. """ batch_data = initial_condition.as_batch_data() return TrainOutput( metrics=self.metrics, gen_data=_prepend_timesteps(self.gen_data, batch_data.data), target_data=_prepend_timesteps( self.target_data, batch_data.data, ), time=xr.concat([batch_data.time, self.time], dim="time"), normalize=self.normalize, derive_func=self.derive_func, ) def compute_derived_variables( self, ) -> "TrainOutput": gen_data = self.ensemble_derive_func( self.gen_data, fold_sized_ensemble_dim(self.target_data, 1) ) target_data = self.ensemble_derive_func( self.target_data, fold_sized_ensemble_dim(self.target_data, 1) ) return TrainOutput( metrics=self.metrics, gen_data=gen_data, target_data=target_data, time=self.time, normalize=self.normalize, derive_func=self.derive_func, ) def get_metrics(self) -> TensorDict: return self.metrics def stack_list_of_tensor_dicts( dict_list: list[TensorDict], time_dim: int, ) -> TensorDict: keys = next(iter(dict_list)).keys() stack_dict = {} for k in keys: stack_dict[k] = torch.stack([d[k] for d in dict_list], dim=time_dim) return stack_dict def process_ensemble_prediction_generator_list( output_list: list[EnsembleTensorDict], ) -> EnsembleTensorDict: output_timeseries = stack_list_of_tensor_dicts( cast(list[TensorDict], output_list), time_dim=2 ) return EnsembleTensorDict( {k: v for k, v in output_timeseries.items()}, ) def process_prediction_generator_list( output_list: list[TensorDict], time: xr.DataArray, horizontal_dims: list[str] | None = None, ) -> BatchData: output_timeseries = stack_list_of_tensor_dicts(output_list, time_dim=1) return BatchData.new_on_device( data=output_timeseries, time=time, horizontal_dims=horizontal_dims, )
[docs]@dataclasses.dataclass class StepperConfig: """ Configuration for a stepper. Parameters: step: The step configuration. loss: The loss configuration. n_ensemble: The number of ensemble members evaluated for each training batch member. Default is 2 if the loss type is EnsembleLoss, otherwise the default is 1. Must be 2 for EnsembleLoss to be valid. crps_training: Deprecated, kept for backwards compatibility. Use n_ensemble=2 with a CRPS loss instead. parameter_init: The parameter initialization configuration. input_masking: Config for masking step inputs. """ step: StepSelector loss: WeightedMappingLossConfig = dataclasses.field( default_factory=lambda: WeightedMappingLossConfig() ) n_ensemble: int = -1 # sentinel value to avoid None typing of attribute crps_training: bool = False parameter_init: ParameterInitializationConfig = dataclasses.field( default_factory=lambda: ParameterInitializationConfig() ) input_masking: StaticMaskingConfig | None = None def __post_init__(self): if self.crps_training: warnings.warn( "crps_training is deprecated, use n_ensemble=2 " "with a CRPS loss instead", DeprecationWarning, ) self.n_ensemble = 2 self.loss = WeightedMappingLossConfig( type="EnsembleLoss", kwargs={"crps_weight": 1.0}, ) if self.n_ensemble == -1: if self.loss.type == "EnsembleLoss": self.n_ensemble = 2 else: self.n_ensemble = 1 @property def n_ic_timesteps(self) -> int: return self.step.n_ic_timesteps def get_evaluation_window_data_requirements( self, n_forward_steps: int ) -> DataRequirements: return DataRequirements( names=self.all_names, n_timesteps=self._window_steps_required(n_forward_steps), ) def get_prognostic_state_data_requirements(self) -> PrognosticStateDataRequirements: return PrognosticStateDataRequirements( names=self.prognostic_names, n_timesteps=self.n_ic_timesteps, ) @property def input_only_names(self) -> list[str]: return list(set(self.input_names) - set(self.output_names)) def get_forcing_window_data_requirements( self, n_forward_steps: int ) -> DataRequirements: return DataRequirements( names=list( set(self.input_only_names).union(self.step.next_step_input_names) ), n_timesteps=self._window_steps_required(n_forward_steps), ) def _window_steps_required(self, n_forward_steps: int) -> int: return n_forward_steps + self.n_ic_timesteps def as_loaded_dict(self): self.step.load() return dataclasses.asdict(self)
[docs] def get_stepper( self, dataset_info: DatasetInfo, apply_parameter_init: bool = True, training_history: TrainingHistory | None = None, ): """ Args: dataset_info: Information about the training dataset. apply_parameter_init: Whether to apply parameter initialization. training_history: History of the stepper's training jobs. """ logging.info("Initializing stepper from provided config") if apply_parameter_init: parameter_initializer = self.get_parameter_initializer() else: parameter_initializer = ParameterInitializer() step = self.step.get_step( dataset_info, init_weights=parameter_initializer.freeze_weights ) derive_func = dataset_info.vertical_coordinate.build_derive_function( dataset_info.timestep ) if self.input_masking is None: input_masking = NullMasking() else: input_masking = self.input_masking.build( mask=dataset_info.mask_provider, means=step.normalizer.means, ) try: output_process_func = dataset_info.mask_provider.build_output_masker() except MissingDatasetInfo: output_process_func = NullPostProcessFn() return Stepper( config=self, step=step, dataset_info=dataset_info, input_process_func=input_masking, output_process_func=output_process_func, derive_func=derive_func, parameter_initializer=parameter_initializer, training_history=training_history, )
[docs] @classmethod def from_stepper_state(cls, state) -> "StepperConfig": """ Initialize a StepperConfig from a stepper state. This is required for backwards compatibility with older steppers, whose configuration did not provide normalization constants, but rather pointed to files on disk. Newer stepper configurations load these constants into the configuration before checkpoints are saved. Args: state: The state of the stepper. Returns: The stepper config. """ try: legacy_config = SingleModuleStepperConfig.from_state(state["config"]) normalizer = StandardNormalizer.from_state( state.get("normalizer", state.get("normalization")) ) if normalizer is None: raise KeyError( "No normalization found in state, available keys: " + ", ".join(state.keys()) ) loss_normalizer_config = state.get( "loss_normalizer", state.get("loss_normalization") ) if loss_normalizer_config is None: loss_normalizer = normalizer else: loss_normalizer = StandardNormalizer.from_state(loss_normalizer_config) return legacy_config.to_stepper_config( normalizer=normalizer, loss_normalizer=loss_normalizer ) except (dacite.exceptions.DaciteError, KeyError): state = cls.remove_deprecated_keys(state["config"]) return dacite.from_dict( data_class=cls, data=state, config=dacite.Config(strict=True) )
@property def loss_names(self): """Names of variables to include in loss.""" return self.step.loss_names @property def input_names(self) -> list[str]: """Names of variables which are required as inputs.""" return self.step.input_names @property def all_names(self) -> list[str]: """Names of all variables.""" return list(set(self.input_names + self.output_names)) @property def next_step_forcing_names(self) -> list[str]: """ Names of variables which are given as inputs but taken from the output timestep. An example might be solar insolation taken during the output window period. """ return self.step.get_next_step_forcing_names() @property def prognostic_names(self) -> list[str]: """Names of variables which both inputs and outputs.""" return self.step.prognostic_names @property def output_names(self) -> list[str]: """Names of variables which are outputs only.""" return self.step.output_names @classmethod def remove_deprecated_keys(cls, state: dict[str, Any]) -> dict[str, Any]: state_copy = state.copy() return state_copy def replace_ocean(self, ocean: OceanConfig | None): self.step.replace_ocean(ocean) def get_ocean(self) -> OceanConfig | None: return self.step.get_ocean()
[docs] def replace_multi_call( self, multi_call: MultiCallConfig | None, state: dict[str, Any] ) -> dict[str, Any]: """Replace the multi-call configuration of self.step and ensure the associated state can be loaded as a multi-call step. A value of `None` for `multi_call` will remove the multi-call configuration. If the selected type supports it, the multi-call configuration will be updated in place. Otherwise, it will be wrapped in the multi_call step configuration with the given multi_call config or None. Note this updates self.step in place, but returns a new state dictionary. Args: multi_call: MultiCallConfig for the resulting self.step. state: state dictionary associated with the loaded step. Returns: The state dictionary updated to ensure consistency with that of a serialized multi-call step. """ self.step, new_state = replace_multi_call(self.step, multi_call, state) return new_state
[docs] def get_parameter_initializer(self) -> ParameterInitializer: """ Get the parameter initializer for this stepper configuration. """ return self.parameter_init.build( load_weights_and_history=_load_weights_and_history )
class Stepper( TrainStepperABC[ PrognosticState, BatchData, BatchData, PairedData, TrainOutput, ] ): """ Stepper class for selectable step configurations. """ TIME_DIM = 1 CHANNEL_DIM = -3 def __init__( self, config: StepperConfig, step: StepABC, dataset_info: DatasetInfo, input_process_func: Callable[[TensorMapping], TensorDict], output_process_func: Callable[[TensorMapping], TensorDict], derive_func: Callable[[TensorMapping, TensorMapping], TensorDict], parameter_initializer: ParameterInitializer, training_history: TrainingHistory | None = None, ): """ Args: config: The configuration. step: The step object. dataset_info: Information about dataset used for training. output_process_func: Function to post-process the output of the step function. derive_func: Function to compute derived variables. input_process_func: Optional function for processing inputs and next-step inputs before passing them to the step object, e.g., by masking specific regions. parameter_initializer: The parameter initializer to use for loading weights from an external source. training_history: History of the stepper's training jobs. """ self._config = config self._step_obj = step self._dataset_info = dataset_info self._derive_func = derive_func self._output_process_func = output_process_func self._input_process_func = input_process_func self._no_optimization = NullOptimization() self._parameter_initializer = parameter_initializer def get_loss_obj(): loss_normalizer = step.get_loss_normalizer() if config.loss is None: raise ValueError("Loss is not configured") return config.loss.build( dataset_info.gridded_operations, out_names=config.loss_names, channel_dim=self.CHANNEL_DIM, normalizer=loss_normalizer, ) self._loss_normalizer: StandardNormalizer | None = None self._get_loss_obj = get_loss_obj self._loss_obj: WeightedMappingLoss | None = None self._parameter_initializer.apply_weights( step.modules, ) self._l2_sp_tuning_regularizer = ( self._parameter_initializer.get_l2_sp_tuning_regularizer( step.modules, ) ) self._training_history = ( training_history if training_history is not None else TrainingHistory() ) self._append_training_history_from( base_training_history=self._parameter_initializer.training_history ) _1: PredictFunction[ # for type checking PrognosticState, BatchData, BatchData, ] = self.predict _2: PredictFunction[ # for type checking PrognosticState, BatchData, PairedData, ] = self.predict_paired self._dataset_info = dataset_info @property def _loaded_loss_normalizer(self) -> StandardNormalizer: if self._loss_normalizer is None: loss_normalizer = self._step_obj.get_loss_normalizer() self._loss_normalizer = loss_normalizer return self._loss_normalizer @property def loss_obj(self) -> WeightedMappingLoss: if self._loss_obj is None: self._loss_obj = self._get_loss_obj() return self._loss_obj @property def config(self) -> StepperConfig: return self._config @property def derive_func(self) -> Callable[[TensorMapping, TensorMapping], TensorDict]: return self._derive_func @property def surface_temperature_name(self) -> str | None: return self._step_obj.surface_temperature_name @property def ocean_fraction_name(self) -> str | None: return self._step_obj.ocean_fraction_name @property def training_dataset_info(self) -> DatasetInfo: return self._dataset_info @property def training_variable_metadata(self) -> Mapping[str, VariableMetadata]: return self._dataset_info.variable_metadata @property def training_history(self) -> TrainingHistory: return self._training_history def _append_training_history_from( self, base_training_history: TrainingHistory | None ): """ When the stepper receives weights from a base stepper via parameter initialization, this helper is used to extend its training history to include the training history of the base stepper. Args: base_training_history: The training history from a base stepper to append. """ if base_training_history is not None: self._training_history.extend(base_training_history) @property def effective_loss_scaling(self) -> TensorDict: """ Effective loss scalings used to normalize outputs before computing loss. y_loss_normalized_i = (y_i - y_mean_i) / loss_scaling_i where loss_scaling_i = loss_normalizer_std_i / weight_i. """ return self.loss_obj.effective_loss_scaling def replace_multi_call(self, multi_call: MultiCallConfig | None): """ Replace the MultiCall object with a new one. Note this is only meant to be used at inference time and may result in the loss function being unusable. Args: multi_call: The new multi_call configuration or None. """ state = self._step_obj.get_state() new_state = self._config.replace_multi_call(multi_call, state) new_stepper: Stepper = self._config.get_stepper( dataset_info=self._dataset_info, apply_parameter_init=False ) new_stepper._step_obj.load_state(new_state) self._step_obj = new_stepper._step_obj def replace_ocean(self, ocean: OceanConfig | None): """ Replace the ocean model with a new one. Args: ocean: The new ocean model configuration or None. """ self._config.replace_ocean(ocean) new_stepper: Stepper = self._config.get_stepper( dataset_info=self._dataset_info, apply_parameter_init=False, ) new_stepper._step_obj.load_state(self._step_obj.get_state()) self._step_obj = new_stepper._step_obj def get_base_weights(self) -> Weights | None: """ Get the base weights of the stepper. Returns: A list of weight dictionaries for each module in the stepper. """ return self._parameter_initializer.base_weights @property def prognostic_names(self) -> list[str]: return self._step_obj.prognostic_names @property def out_names(self) -> list[str]: return self._step_obj.output_names @property def loss_names(self) -> list[str]: return self._step_obj.loss_names @property def n_ic_timesteps(self) -> int: return self._step_obj.n_ic_timesteps @property def modules(self) -> nn.ModuleList: """ Returns: A list of modules being trained. """ return self._step_obj.modules @property def normalizer(self) -> StandardNormalizer: return self._step_obj.normalizer def step( self, input: TensorMapping, next_step_input_data: TensorMapping, wrapper: Callable[[nn.Module], nn.Module] = lambda x: x, ) -> TensorDict: """ Step the model forward one timestep given input data. Args: input: Mapping from variable name to tensor of shape [n_batch, n_lat, n_lon] containing denormalized data from the initial timestep. next_step_input_data: Mapping from variable name to tensor of shape [n_batch, n_lat, n_lon] containing denormalized data from the output timestep. wrapper: Wrapper to apply over each nn.Module before calling. Returns: The denormalized output data at the next time step. """ input = self._input_process_func(input) next_step_input_data = self._input_process_func(next_step_input_data) output = self._step_obj.step(input, next_step_input_data, wrapper=wrapper) return self._output_process_func(output) def get_prediction_generator( self, initial_condition: PrognosticState, forcing_data: BatchData, n_forward_steps: int, optimizer: OptimizationABC, ) -> Generator[TensorDict, None, None]: """ Predict multiple steps forward given initial condition and forcing data. Uses low-level inputs and does not compute derived variables, to separate concerns from the public `predict` method. Args: initial_condition: The initial condition, containing tensors of shape [n_batch, self.n_ic_timesteps, <horizontal_dims>]. forcing_data: The forcing data, containing tensors of shape [n_batch, n_forward_steps + self.n_ic_timesteps, <horizontal_dims>]. n_forward_steps: The number of forward steps to predict, corresponding to the data shapes of forcing_data. optimizer: The optimizer to use for updating the module. Returns: Generator yielding the output data at each timestep. """ ic_dict = initial_condition.as_batch_data().data forcing_dict = forcing_data.data return self._predict_generator( ic_dict, forcing_dict, n_forward_steps, optimizer ) @property def _input_only_names(self) -> list[str]: return list( set(self._step_obj.input_names).difference(set(self._step_obj.output_names)) ) def _predict_generator( self, ic_dict: TensorMapping, forcing_dict: TensorMapping, n_forward_steps: int, optimizer: OptimizationABC, ) -> Generator[TensorDict, None, None]: state = {k: ic_dict[k].squeeze(self.TIME_DIM) for k in ic_dict} for step in range(n_forward_steps): input_forcing = { k: ( forcing_dict[k][:, step] if k not in self._step_obj.next_step_forcing_names else forcing_dict[k][:, step + 1] ) for k in self._input_only_names } next_step_input_dict = { k: forcing_dict[k][:, step + 1] for k in self._step_obj.next_step_input_names } input_data = {**state, **input_forcing} def checkpoint(module): return optimizer.checkpoint(module, step=step) state = self.step( input_data, next_step_input_dict, wrapper=checkpoint, ) yield state state = optimizer.detach_if_using_gradient_accumulation(state) def predict( self, initial_condition: PrognosticState, forcing: BatchData, compute_derived_variables: bool = False, ) -> tuple[BatchData, PrognosticState]: """ Predict multiple steps forward given initial condition and reference data. Args: initial_condition: Prognostic state data with tensors of shape [n_batch, self.n_ic_timesteps, <horizontal_dims>]. This data is assumed to contain all prognostic variables and be denormalized. forcing: Contains tensors of shape [n_batch, self.n_ic_timesteps + n_forward_steps, n_lat, n_lon]. This contains the forcing and ocean data for the initial condition and all subsequent timesteps. compute_derived_variables: Whether to compute derived variables for the prediction. Returns: A batch data containing the prediction and the prediction's final state which can be used as a new initial condition. """ timer = GlobalTimer.get_instance() forcing_names = set(self._input_only_names).union( self._step_obj.next_step_input_names ) with timer.context("forward_prediction"): forcing_data = forcing.subset_names(forcing_names) if initial_condition.as_batch_data().n_timesteps != self.n_ic_timesteps: raise ValueError( f"Initial condition must have {self.n_ic_timesteps} timesteps, got " f"{initial_condition.as_batch_data().n_timesteps}." ) n_forward_steps = forcing_data.n_timesteps - self.n_ic_timesteps output_list = list( self.get_prediction_generator( initial_condition, forcing_data, n_forward_steps, NullOptimization(), ) ) data = process_prediction_generator_list( output_list, time=forcing_data.time[:, self.n_ic_timesteps :], horizontal_dims=forcing_data.horizontal_dims, ) if compute_derived_variables: with timer.context("compute_derived_variables"): data = ( data.prepend(initial_condition) .compute_derived_variables( derive_func=self.derive_func, forcing_data=forcing_data, ) .remove_initial_condition(self.n_ic_timesteps) ) prognostic_state = data.get_end(self.prognostic_names, self.n_ic_timesteps) data = BatchData.new_on_device( data=data.data, time=data.time, horizontal_dims=data.horizontal_dims, ) return data, prognostic_state def predict_paired( self, initial_condition: PrognosticState, forcing: BatchData, compute_derived_variables: bool = False, ) -> tuple[PairedData, PrognosticState]: """ Predict multiple steps forward given initial condition and reference data. Args: initial_condition: Prognostic state data with tensors of shape [n_batch, self.n_ic_timesteps, <horizontal_dims>]. This data is assumed to contain all prognostic variables and be denormalized. forcing: Contains tensors of shape [n_batch, self.n_ic_timesteps + n_forward_steps, n_lat, n_lon]. This contains the forcing and ocean data for the initial condition and all subsequent timesteps. compute_derived_variables: Whether to compute derived variables for the prediction. Returns: A tuple of 1) a paired data object, containing the prediction paired with all target/forcing data at the same timesteps, and 2) the prediction's final state, which can be used as a new initial condition. """ prediction, new_initial_condition = self.predict( initial_condition, forcing, compute_derived_variables ) forward_data = self.get_forward_data( forcing, compute_derived_variables=compute_derived_variables ) return ( PairedData.from_batch_data( prediction=prediction, reference=BatchData.new_on_device( data=forward_data.data, time=forward_data.time, horizontal_dims=forward_data.horizontal_dims, ), ), new_initial_condition, ) def get_forward_data( self, data: BatchData, compute_derived_variables: bool = False ) -> BatchData: if compute_derived_variables: timer = GlobalTimer.get_instance() with timer.context("compute_derived_variables"): data = data.compute_derived_variables( derive_func=self.derive_func, forcing_data=data, ) return data.remove_initial_condition(self.n_ic_timesteps) def _get_regularizer_loss(self) -> torch.Tensor: return self._l2_sp_tuning_regularizer() + self._step_obj.get_regularizer_loss() def train_on_batch( self, data: BatchData, optimization: OptimizationABC, compute_derived_variables: bool = False, ) -> TrainOutput: """ Train the model on a batch of data with one or more forward steps. If gradient accumulation is used by the optimization, the computational graph is detached between steps to reduce memory consumption. This means the model learns how to deal with inputs on step N but does not try to improve the behavior at step N by modifying the behavior for step N-1. Args: data: The batch data where each tensor in data.data has shape [n_sample, n_forward_steps + self.n_ic_timesteps, <horizontal_dims>]. optimization: The optimization class to use for updating the module. Use `NullOptimization` to disable training. compute_derived_variables: Whether to compute derived variables for the prediction and target data. Returns: The loss metrics, the generated data, the normalized generated data, and the normalized batch data. """ metrics: dict[str, float] = {} input_data = data.get_start(self.prognostic_names, self.n_ic_timesteps) target_data = self.get_forward_data(data, compute_derived_variables=False) optimization.set_mode(self._step_obj.modules) output_list = self._accumulate_loss( input_data, data, target_data, optimization, metrics, ) regularizer_loss = self._get_regularizer_loss() if torch.any(regularizer_loss > 0): optimization.accumulate_loss(regularizer_loss) metrics["loss"] = optimization.get_accumulated_loss().detach() optimization.step_weights() gen_data = process_ensemble_prediction_generator_list(output_list) stepped = TrainOutput( metrics=metrics, gen_data=gen_data, target_data=add_ensemble_dim(target_data.data), time=target_data.time, normalize=self.normalizer.normalize, derive_func=self.derive_func, ) ic = data.get_start( set(data.data.keys()), self.n_ic_timesteps ) # full data and not just prognostic get prepended stepped = stepped.prepend_initial_condition(ic) if compute_derived_variables: stepped = stepped.compute_derived_variables() # apply post-processing and return return stepped def _accumulate_loss( self, input_data: PrognosticState, data: BatchData, target_data: BatchData, optimization: OptimizationABC, metrics: dict[str, float], ) -> list[EnsembleTensorDict]: input_data = data.get_start(self.prognostic_names, self.n_ic_timesteps) # output from self.predict_paired does not include initial condition n_forward_steps = data.time.shape[1] - self.n_ic_timesteps n_ensemble = self._config.n_ensemble input_ensemble_data: TensorMapping = repeat_interleave_batch_dim( input_data.as_batch_data().data, repeats=n_ensemble ) forcing_ensemble_data: TensorMapping = repeat_interleave_batch_dim( data.data, repeats=n_ensemble ) output_generator = self._predict_generator( input_ensemble_data, forcing_ensemble_data, n_forward_steps, optimization, ) output_list: list[EnsembleTensorDict] = [] for step, gen_step in enumerate(output_generator): gen_step = unfold_ensemble_dim(gen_step, n_ensemble=n_ensemble) output_list.append(gen_step) # Note: here we examine the loss for a single timestep, # not a single model call (which may contain multiple timesteps). target_step = add_ensemble_dim( {k: v.select(self.TIME_DIM, step) for k, v in target_data.data.items()} ) step_loss = self.loss_obj(gen_step, target_step) metrics[f"loss_step_{step}"] = step_loss.detach() optimization.accumulate_loss(step_loss) return output_list def update_training_history(self, training_job: TrainingJob) -> None: """ Update the stepper's history of training jobs. Args: training_job: The training job to add to the history. """ self._training_history.append(training_job) def get_state(self): """ Returns: The state of the stepper. """ return { "config": self._config.as_loaded_dict(), "dataset_info": self._dataset_info.to_state(), "step": self._step_obj.get_state(), "training_history": self._training_history.get_state(), } def load_state(self, state: dict[str, Any]) -> None: """ Load the state of the stepper. Args: state: The state to load. """ self._step_obj.load_state(state["step"]) @classmethod def from_state(cls, state) -> "Stepper": """ Load the state of the stepper. Args: state: The state to load. Returns: The stepper. """ try: legacy_config = SingleModuleStepperConfig.from_state(state["config"]) dataset_state = {} dataset_state["timestep"] = state.get( "encoded_timestep", DEFAULT_ENCODED_TIMESTEP ) if "sigma_coordinates" in state: # for backwards compatibility with old checkpoints dataset_state["vertical_coordinate"] = state["sigma_coordinates"] else: dataset_state["vertical_coordinate"] = state["vertical_coordinate"] if "area" in state: # backwards-compatibility, these older checkpoints are always lat-lon dataset_state["gridded_operations"] = { "type": "LatLonOperations", "state": {"area_weights": state["area"]}, } else: dataset_state["gridded_operations"] = state["gridded_operations"] if "img_shape" in state: dataset_state["img_shape"] = state["img_shape"] normalizer = StandardNormalizer.from_state( state.get("normalizer", state.get("normalization")) ) if normalizer is None: raise ValueError( f"No normalizer state found, keys include {state.keys()}" ) loss_normalizer = StandardNormalizer.from_state( state.get("loss_normalizer", state.get("loss_normalization")) ) if loss_normalizer is None: loss_normalizer = normalizer config = legacy_config.to_stepper_config( normalizer=normalizer, loss_normalizer=loss_normalizer ) dataset_info = DatasetInfo.from_state(dataset_state) state["step"] = { # SingleModuleStep inside MultiCallStep "wrapped_step": {"module": state["module"]} } except dacite.exceptions.DaciteError: config = StepperConfig.from_stepper_state(state) dataset_info = DatasetInfo.from_state(state["dataset_info"]) training_history = TrainingHistory.from_state(state.get("training_history", [])) stepper = config.get_stepper( dataset_info=dataset_info, training_history=training_history, # don't need to initialize weights, we're about to load_state apply_parameter_init=False, ) stepper.load_state(state) return stepper def get_serialized_stepper_vertical_coordinate( state: dict[str, Any], ) -> VerticalCoordinate: if "vertical_coordinate" in state: return dacite.from_dict( data_class=SerializableVerticalCoordinate, data={"vertical_coordinate": state["vertical_coordinate"]}, config=dacite.Config(strict=True), ).vertical_coordinate elif "sigma_coordinates" in state: return dacite.from_dict( data_class=SerializableVerticalCoordinate, data={"vertical_coordinate": state["sigma_coordinates"]}, config=dacite.Config(strict=True), ).vertical_coordinate else: dataset_info = DatasetInfo.from_state(state["dataset_info"]) return dataset_info.vertical_coordinate
[docs]@dataclasses.dataclass class StepperOverrideConfig: """ Configuration for overriding stepper configuration options. The default value for each parameter is ``"keep"``, which denotes that the serialized stepper's configuration will not be modified when loaded. Passing other values will override the configuration of the loaded stepper. Parameters: ocean: Ocean configuration to override that used in producing a serialized stepper. multi_call: MultiCall configuration to override that used in producing a serialized stepper. """ ocean: Literal["keep"] | OceanConfig | None = "keep" multi_call: Literal["keep"] | MultiCallConfig | None = "keep"
def load_stepper_config( checkpoint_path: str | pathlib.Path, override_config: StepperOverrideConfig | None = None, ) -> StepperConfig: """Load a stepper configuration, optionally overriding certain aspects. Args: checkpoint_path: The path to the serialized checkpoint. override_config: Configuration options to override (optional). Returns: The configuration of the stepper serialized in the checkpoint, with appropriate options overridden. """ stepper = load_stepper(checkpoint_path, override_config) return stepper._config def load_stepper( checkpoint_path: str | pathlib.Path, override_config: StepperOverrideConfig | None = None, ) -> Stepper: """Load a stepper, optionally overriding certain aspects. Args: checkpoint_path: The path to the serialized checkpoint. override_config: Configuration options to override (optional). Returns: The stepper serialized in the checkpoint, with appropriate options overridden. """ if override_config is None: override_config = StepperOverrideConfig() checkpoint = torch.load( checkpoint_path, map_location=get_device(), weights_only=False ) stepper = Stepper.from_state(checkpoint["stepper"]) if override_config.ocean != "keep": logging.info( "Overriding training ocean configuration with a new ocean configuration." ) stepper.replace_ocean(override_config.ocean) if override_config.multi_call != "keep": logging.info( "Overriding training multi_call configuration with a new " "multi_call configuration." ) stepper.replace_multi_call(override_config.multi_call) return stepper