Steps
ACE’s code uses a “step” registry system that lets different emulation step configurations be specified. A step object is a specific configuration of NN module calls together with other operations such as normalization, denormalization, and correction.
In ACE’s hierarchy, a stepper contains the step object, which in turn may contain one or more NN modules.
The step registry is managed by the fme.ace.StepSelector configuration class, which selects a step type and holds the configuration for that type.
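The registry pattern described above can be sketched in a few lines. This is a minimal, self-contained illustration, not fme's implementation: `Registry`, `ToyStepSelector`, and `ToySingleModuleConfig` are hypothetical stand-ins that only mimic the documented shape of `StepSelector(type, config)`, its `register(name)` class method, and `get_available_types()`.

```python
import dataclasses
from typing import Any, Callable, ClassVar


class Registry:
    """Hypothetical stand-in for fme.core.registry.registry.Registry."""

    def __init__(self) -> None:
        self._types: dict[str, type] = {}

    def register(self, name: str) -> Callable[[type], type]:
        """Return a class decorator that stores the class under `name`."""
        def decorator(cls: type) -> type:
            if name in self._types:
                raise ValueError(f"step type {name!r} is already registered")
            self._types[name] = cls
            return cls  # returned unchanged, so it stacks with @dataclass
        return decorator

    def names(self) -> set[str]:
        return set(self._types)

    def build(self, name: str, config: dict[str, Any]) -> Any:
        """Instantiate the class registered under `name` from a config dict."""
        if name not in self._types:
            raise ValueError(
                f"unknown step type {name!r}; available: {sorted(self._types)}"
            )
        return self._types[name](**config)


@dataclasses.dataclass
class ToyStepSelector:
    type: str
    config: dict[str, Any]
    registry: ClassVar[Registry] = Registry()

    @classmethod
    def register(cls, name: str) -> Callable[[type], type]:
        return cls.registry.register(name)

    @classmethod
    def get_available_types(cls) -> set[str]:
        return cls.registry.names()

    def build_config(self) -> Any:
        return self.registry.build(self.type, self.config)


@ToyStepSelector.register("toy_single_module")
@dataclasses.dataclass
class ToySingleModuleConfig:
    in_names: list[str]
    out_names: list[str]


selector = ToyStepSelector(
    type="toy_single_module",
    config={"in_names": ["a", "b"], "out_names": ["b"]},
)
step_config = selector.build_config()
print(ToyStepSelector.get_available_types())  # {'toy_single_module'}
print(step_config.out_names)  # ['b']
```

Keying classes by a string is what lets a configuration select a step by its type name alone; the selector only needs the name and a plain configuration mapping to find and build the right class.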
- class fme.ace.StepSelector(type, config)
Bases: StepConfigABC
- classmethod get_available_types()
This class method exposes all available step types.
- get_loss_normalizer(extra_names=None, extra_residual_scaled_names=None)
- Returns:
The loss normalizer.
- get_ocean()
- get_step(dataset_info, init_weights=<function StepSelector.<lambda>>)
- Parameters:
  - dataset_info (DatasetInfo) – Information about the training dataset.
  - init_weights (Callable[[list[Module]], None], default: <function StepSelector.<lambda>>) – Function to initialize the weights of the step before it is wrapped in DistributedDataParallel. This is particularly useful when freezing parameters, because DistributedDataParallel otherwise expects frozen weights to have gradients and raises an exception.
- Return type: StepABC
- Returns:
The constructed step.
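A sketch of an `init_weights` callback that freezes selected modules before the step is wrapped in DistributedDataParallel. `ModuleLike` and `make_freezer` are hypothetical helpers; `torch.nn.Module` satisfies the protocol, but the sketch avoids importing PyTorch so it runs standalone. Which module sits at which index of the list passed to `init_weights` is not documented here, so the indices are an assumption.

```python
from typing import Any, Callable, Iterable, Protocol


class ModuleLike(Protocol):
    """Anything with a parameters() method, such as torch.nn.Module."""

    def parameters(self) -> Iterable[Any]: ...


def make_freezer(indices: set[int]) -> Callable[[list[ModuleLike]], None]:
    """Build a hypothetical init_weights callback freezing modules[i] for i in indices."""

    def init_weights(modules: list[ModuleLike]) -> None:
        for i, module in enumerate(modules):
            if i not in indices:
                continue
            for param in module.parameters():
                # Frozen parameters are excluded from gradient computation.
                param.requires_grad = False

    return init_weights
```

Under the indexing assumption above, passing `init_weights=make_freezer({0})` to `get_step` would freeze the first module before DistributedDataParallel sees it.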
- load()
Update configuration in-place so it does not depend on external files.
- property n_ic_timesteps: int
- property next_step_input_names: list[str]
Names of variables required in next_step_input_data for .step.
- classmethod register(name)
Register a virtual subclass of an ABC.
Returns the subclass, to allow usage as a class decorator.
- Parameters:
  - name (str)
- registry: ClassVar[Registry[StepConfigABC]] = <fme.core.registry.registry.Registry object>
- replace_ocean(ocean)
- Parameters:
  - ocean (OceanConfig | None)
- replace_prescribed_prognostic_names(names)
Replace prescribed prognostic names (e.g. when loading from checkpoint).
- type: str
The following step types are available:
{'multi_call', 'separate_radiation', 'single_module', 'FCN3', 'secondary_module', 'default'}
- fme.core.step.StepSelector.get_available_types()
This class method exposes all available step types.
The following step builders are available:
- class fme.core.step.SingleModuleStepConfig(builder, in_names, out_names, normalization, secondary_decoder=None, ocean=None, corrector=<factory>, next_step_forcing_names=<factory>, prescribed_prognostic_names=<factory>, residual_prediction=False)
Bases: StepConfigABC
Configuration for a single-module stepper.
- Parameters:
  - builder (ModuleSelector) – The module builder.
  - normalization (NetworkAndLossNormalizationConfig) – The normalization configuration.
  - secondary_decoder (Optional[SecondaryDecoderConfig], default: None) – Configuration for the secondary decoder that computes additional diagnostic variables from outputs.
  - ocean (Optional[OceanConfig], default: None) – The ocean configuration.
  - corrector (AtmosphereCorrectorConfig | CorrectorSelector, default: <factory>) – The corrector configuration.
  - next_step_forcing_names (list[str], default: <factory>) – Names of forcing variables for the next timestep.
  - prescribed_prognostic_names (list[str], default: <factory>) – Prognostic variable names to overwrite from forcing data at each step (e.g. for inference with observed values).
  - residual_prediction (bool, default: False) – Whether to use residual prediction.
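The effect of `prescribed_prognostic_names` can be sketched as a per-step overwrite. This is a hypothetical illustration: `overwrite_prescribed` is not an fme function, plain floats stand in for tensors, the variable names are made up, and it assumes the overwrite is applied to the predicted state.

```python
from typing import Mapping, TypeVar

T = TypeVar("T")


def overwrite_prescribed(
    state: Mapping[str, T],
    forcing: Mapping[str, T],
    prescribed_prognostic_names: list[str],
) -> dict[str, T]:
    """Return a copy of `state` whose prescribed variables come from `forcing`."""
    missing = [name for name in prescribed_prognostic_names if name not in forcing]
    if missing:
        raise KeyError(f"prescribed variables missing from forcing data: {missing}")
    updated = dict(state)  # leave the caller's state untouched
    for name in prescribed_prognostic_names:
        updated[name] = forcing[name]
    return updated


predicted = {"sst": 290.0, "air_temperature": 250.0}
observed = {"sst": 291.5, "solar_flux": 340.0}
print(overwrite_prescribed(predicted, observed, ["sst"]))
# {'sst': 291.5, 'air_temperature': 250.0}
```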
- builder: ModuleSelector
- corrector: AtmosphereCorrectorConfig | CorrectorSelector
- classmethod from_state(state)
- get_loss_normalizer(extra_names=None, extra_residual_scaled_names=None)
- Returns:
The loss normalizer.
- get_next_step_forcing_names()
Names of input-only variables which come from the output timestep.
- get_ocean()
- get_state()
- get_step(dataset_info, init_weights)
- Parameters:
  - dataset_info (DatasetInfo) – Information about the training dataset.
  - init_weights (Callable[[list[Module]], None]) – Function to initialize the weights of the step before it is wrapped in DistributedDataParallel. This is particularly useful when freezing parameters, because DistributedDataParallel otherwise expects frozen weights to have gradients and raises an exception.
- Return type: SingleModuleStep
- Returns:
The constructed step.
- property input_names: list[str]
Names of variables required as inputs to step, either in input or next_step_input_data.
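The split between `input` and `next_step_input_data` can be sketched by partitioning `in_names` with `next_step_forcing_names`. `split_input_names` is a hypothetical helper, the variable names are made up, and treating next-step forcings as a subset of `in_names` is an assumption based on their description as input-only variables.

```python
def split_input_names(
    in_names: list[str], next_step_forcing_names: list[str]
) -> tuple[list[str], list[str]]:
    """Split in_names into (current-timestep inputs, next-timestep inputs)."""
    next_step = set(next_step_forcing_names)
    # Assumption: next-step forcings are a subset of the step's input names.
    unknown = next_step - set(in_names)
    if unknown:
        raise ValueError(f"next-step forcing names not in in_names: {sorted(unknown)}")
    current = [name for name in in_names if name not in next_step]
    from_next = [name for name in in_names if name in next_step]
    return current, from_next


current, from_next = split_input_names(
    ["air_temperature", "sst", "solar_flux"], ["solar_flux"]
)
print(current)    # ['air_temperature', 'sst']
print(from_next)  # ['solar_flux']
```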
- load()
Update configuration in-place so it does not depend on external files.
- property n_ic_timesteps: int
- normalization: NetworkAndLossNormalizationConfig
- ocean: Optional[OceanConfig] = None
- replace_ocean(ocean)
Replace the ocean model with a new one.
- Parameters:
  - ocean (Optional[OceanConfig]) – The new ocean model configuration or None.
- replace_prescribed_prognostic_names(names)
Replace prescribed prognostic names (e.g. when loading from checkpoint).
- residual_prediction: bool = False
- secondary_decoder: Optional[SecondaryDecoderConfig] = None
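The documentation describes `residual_prediction` only as "whether to use residual prediction". In the usual sense of the term, the network predicts an increment that is added to the current value of each prognostic variable, while diagnostic outputs are used as-is. The sketch below shows that generic technique; `combine_residual` is hypothetical, floats stand in for tensors, and it does not capture whether fme applies the residual in normalized or physical units.

```python
from typing import Mapping


def combine_residual(
    inputs: Mapping[str, float],
    network_output: Mapping[str, float],
    prognostic_names: list[str],
    residual_prediction: bool,
) -> dict[str, float]:
    """Turn raw network output into the predicted state."""
    result = dict(network_output)
    if residual_prediction:
        for name in prognostic_names:
            # The network predicts an increment; add it to the current value.
            result[name] = inputs[name] + network_output[name]
    return result


inputs = {"air_temperature": 250.0}
raw = {"air_temperature": 0.5, "precipitation": 2.0}
print(combine_residual(inputs, raw, ["air_temperature"], residual_prediction=True))
# {'air_temperature': 250.5, 'precipitation': 2.0}
print(combine_residual(inputs, raw, ["air_temperature"], residual_prediction=False))
# {'air_temperature': 0.5, 'precipitation': 2.0}
```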
- class fme.core.step.MultiCallStepConfig(wrapped_step, config=None, include_multi_call_in_loss=True)
Bases: StepConfigABC
Configuration for a multi-call step.
- Parameters:
  - wrapped_step (StepSelector) – The step to wrap.
  - config (Optional[MultiCallConfig], default: None) – The multi-call configuration.
  - include_multi_call_in_loss (bool, default: True) – Whether to include multi-call diagnostics in the loss.
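Because `wrapped_step` is itself a `StepSelector`, a multi-call step nests another step's selector inside its own configuration. The dict below is a hypothetical sketch of that nesting: the `type`/`config` keys mirror `StepSelector(type, config)` and the inner keys mirror the dataclass fields documented here, but `...` marks values that must come from the real configuration classes, and `step_type_chain` is an illustrative helper, not part of fme.

```python
from typing import Any

multi_call_selector: dict[str, Any] = {
    "type": "multi_call",
    "config": {
        # MultiCallStepConfig fields
        "wrapped_step": {
            "type": "single_module",
            "config": {
                # Required SingleModuleStepConfig fields; "..." marks
                # values that must come from the real config classes.
                "builder": ...,  # ModuleSelector configuration
                "in_names": [...],
                "out_names": [...],
                "normalization": ...,  # NetworkAndLossNormalizationConfig
            },
        },
        "config": None,  # Optional[MultiCallConfig]; None is the default
        "include_multi_call_in_loss": True,
    },
}


def step_type_chain(selector: dict[str, Any]) -> list[str]:
    """List step types from the outermost selector to the innermost wrapped step."""
    chain = [selector["type"]]
    wrapped = selector["config"].get("wrapped_step")
    while wrapped is not None:
        chain.append(wrapped["type"])
        wrapped = wrapped["config"].get("wrapped_step")
    return chain


print(step_type_chain(multi_call_selector))  # ['multi_call', 'single_module']
```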
- build(step_method)
- config: Optional[MultiCallConfig] = None
- extend_normalizer_with_multi_call_outputs(normalizer)
Extend the normalizer by setting multi-call output names to use the same normalization as their base counterparts.
- Parameters:
  - normalizer (StandardNormalizer)
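What `extend_normalizer_with_multi_call_outputs` does to the statistics can be sketched with plain dicts: each multi-call output reuses its base variable's mean and standard deviation. This is a hypothetical illustration. It does not use `StandardNormalizer`'s real attributes, the multi-call-to-base name mapping is passed in explicitly because its derivation is not documented here, and `olr_2xco2` is a made-up name. This contrasts with `get_loss_normalizer`, which draws multi-call statistics from the stats dataset.

```python
def extend_stats_with_multi_call_outputs(
    means: dict[str, float],
    stds: dict[str, float],
    multi_call_to_base: dict[str, str],
) -> tuple[dict[str, float], dict[str, float]]:
    """Give each multi-call output the mean/std of its base variable."""
    new_means, new_stds = dict(means), dict(stds)  # copies; inputs unchanged
    for multi_call_name, base_name in multi_call_to_base.items():
        new_means[multi_call_name] = means[base_name]
        new_stds[multi_call_name] = stds[base_name]
    return new_means, new_stds


means, stds = extend_stats_with_multi_call_outputs(
    means={"olr": 240.0}, stds={"olr": 30.0}, multi_call_to_base={"olr_2xco2": "olr"}
)
print(means)  # {'olr': 240.0, 'olr_2xco2': 240.0}
print(stds)   # {'olr': 30.0, 'olr_2xco2': 30.0}
```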
- get_loss_normalizer(extra_names=None, extra_residual_scaled_names=None)
Get the loss normalizer for the multi-call step.
The normalizer uses statistics for the multi-call variables from the stats dataset, so the normalization of the multi-call outputs differs from the normalization of their base variables.
- get_ocean()
- get_step(dataset_info, init_weights)
- Parameters:
  - dataset_info (DatasetInfo) – Information about the training dataset.
  - init_weights (Callable[[list[Module]], None]) – Function to initialize the weights of the step before it is wrapped in DistributedDataParallel. This is particularly useful when freezing parameters, because DistributedDataParallel otherwise expects frozen weights to have gradients and raises an exception.
- Return type: MultiCallStep
- Returns:
The constructed step.
- include_multi_call_in_loss: bool = True
- load()
Update configuration in-place so it does not depend on external files.
- property n_ic_timesteps: int
- property next_step_input_names: list[str]
Names of variables required in next_step_input_data for .step.
- replace_multi_call(multi_call)
- Parameters:
  - multi_call (MultiCallConfig | None)
- replace_ocean(ocean)
- Parameters:
  - ocean (OceanConfig | None)
- replace_prescribed_prognostic_names(names)
Replace prescribed prognostic names (e.g. when loading from checkpoint).
- wrapped_step: StepSelector
- class fme.core.step.SeparateRadiationStepConfig(builder, radiation_builder, main_prognostic_names, shared_forcing_names, radiation_only_forcing_names, radiation_diagnostic_names, main_diagnostic_names, normalization, next_step_forcing_names=<factory>, ocean=None, corrector=<factory>, detach_radiation=False, residual_prediction=False)
Bases: StepConfigABC
Configuration for a separate-radiation stepper.
- Parameters:
  - builder (ModuleSelector) – The module builder.
  - radiation_builder (ModuleSelector) – The radiation module builder.
  - main_prognostic_names (list[str]) – Names of prognostic variables. These are provided as input to both the main and radiation models, and output by the main model.
  - shared_forcing_names (list[str]) – Names of forcing variables provided to both the main and radiation models.
  - radiation_only_forcing_names (list[str]) – Names of forcing variables for the radiation model, in addition to the ones specified in shared_forcing_names.
  - radiation_diagnostic_names (list[str]) – Names of diagnostic variables for the radiation model.
  - main_diagnostic_names (list[str]) – Names of diagnostic variables for the main model.
  - normalization (NetworkAndLossNormalizationConfig) – The normalization configuration.
  - next_step_forcing_names (list[str], default: <factory>) – Names of forcing variables which come from the output timestep.
  - ocean (Optional[OceanConfig], default: None) – The ocean configuration.
  - corrector (AtmosphereCorrectorConfig | CorrectorSelector, default: <factory>) – The corrector configuration.
  - detach_radiation (bool, default: False) – Whether to detach the output of the radiation model before passing it to the main model. The radiation outputs returned by .step() are not detached.
  - residual_prediction (bool, default: False) – Whether to use residual prediction.
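The parameter descriptions above imply a routing of variable names between the two models, sketched below. `NameRouting` and `route_separate_radiation_names` are hypothetical, the list order is illustrative only, and appending the radiation diagnostics to the main model's inputs is an assumption drawn from the description of `detach_radiation` (radiation output is "passed to the main model").

```python
from dataclasses import dataclass


@dataclass
class NameRouting:
    radiation_in: list[str]
    radiation_out: list[str]
    main_in: list[str]
    main_out: list[str]


def route_separate_radiation_names(
    main_prognostic_names: list[str],
    shared_forcing_names: list[str],
    radiation_only_forcing_names: list[str],
    radiation_diagnostic_names: list[str],
    main_diagnostic_names: list[str],
) -> NameRouting:
    # Prognostics and shared forcings are inputs to both models.
    common = [*main_prognostic_names, *shared_forcing_names]
    return NameRouting(
        # The radiation model additionally sees its own extra forcings.
        radiation_in=[*common, *radiation_only_forcing_names],
        radiation_out=list(radiation_diagnostic_names),
        # Assumption: radiation outputs are appended to the main model's inputs.
        # With detach_radiation=True, gradients from the main model would not
        # flow back into the radiation model through these inputs.
        main_in=[*common, *radiation_diagnostic_names],
        # The main model outputs the prognostics plus its own diagnostics.
        main_out=[*main_prognostic_names, *main_diagnostic_names],
    )


routing = route_separate_radiation_names(
    main_prognostic_names=["air_temperature"],
    shared_forcing_names=["solar_flux"],
    radiation_only_forcing_names=["co2"],
    radiation_diagnostic_names=["olr"],
    main_diagnostic_names=["precipitation"],
)
print(routing.main_in)  # ['air_temperature', 'solar_flux', 'olr']
```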
- builder: ModuleSelector
- corrector: AtmosphereCorrectorConfig | CorrectorSelector
- detach_radiation: bool = False
- classmethod from_state(state)
- get_loss_normalizer(extra_names=None, extra_residual_scaled_names=None)
- Returns:
The loss normalizer.
- get_ocean()
- get_state()
- get_step(dataset_info, init_weights)
- Parameters:
  - dataset_info (DatasetInfo) – Information about the training dataset.
  - init_weights (Callable[[list[Module]], None]) – Function to initialize the weights of the step before it is wrapped in DistributedDataParallel. This is particularly useful when freezing parameters, because DistributedDataParallel otherwise expects frozen weights to have gradients and raises an exception.
- Return type: SeparateRadiationStep
- Returns:
The constructed step.
- load()
Update configuration in-place so it does not depend on external files.
- property n_ic_timesteps: int
- normalization: NetworkAndLossNormalizationConfig
- ocean: Optional[OceanConfig] = None
- radiation_builder: ModuleSelector
- replace_ocean(ocean)
- Parameters:
  - ocean (OceanConfig | None)
- residual_prediction: bool = False