import dataclasses
import datetime
import logging
from copy import copy
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
import dacite
import torch
import xarray as xr
from torch import nn
from fme.ace.data_loading.batch_data import BatchData, PairedData, PrognosticState
from fme.ace.inference.derived_variables import compute_derived_quantities
from fme.ace.requirements import PrognosticStateDataRequirements
from fme.core.coordinates import HybridSigmaPressureCoordinate
from fme.core.corrector.corrector import CorrectorConfig
from fme.core.dataset.requirements import DataRequirements
from fme.core.dataset.utils import decode_timestep, encode_timestep
from fme.core.device import get_device
from fme.core.distributed import Distributed
from fme.core.generics.inference import PredictFunction
from fme.core.generics.optimization import OptimizationABC
from fme.core.generics.train_stepper import TrainOutputABC, TrainStepperABC
from fme.core.gridded_ops import GriddedOperations, LatLonOperations
from fme.core.loss import WeightedMappingLossConfig
from fme.core.normalizer import NormalizationConfig, StandardNormalizer
from fme.core.ocean import Ocean, OceanConfig
from fme.core.optimization import NullOptimization
from fme.core.packer import Packer
from fme.core.parameter_init import ParameterInitializationConfig
from fme.core.registry import CorrectorSelector, ModuleSelector
from fme.core.timing import GlobalTimer
from fme.core.typing_ import TensorDict, TensorMapping
# Timestep assumed when a serialized stepper state carries no
# "encoded_timestep" entry (used as the fallback in SingleModuleStepper.from_state).
DEFAULT_TIMESTEP = datetime.timedelta(hours=6)
DEFAULT_ENCODED_TIMESTEP = encode_timestep(DEFAULT_TIMESTEP)
class AtmosphericDeriveFn:
    """
    Callable that computes derived atmospheric quantities.

    The vertical coordinate is stored on the CPU, which is required for the
    multiprocessing fork context. It is moved to the active device each time
    the function is called.
    """

    def __init__(
        self,
        vertical_coordinate: HybridSigmaPressureCoordinate,
        timestep: datetime.timedelta,
    ):
        # must be on cpu for multiprocessing fork context
        cpu_coordinate = vertical_coordinate.to("cpu")
        self.vertical_coordinate = cpu_coordinate
        self.timestep = timestep

    def __call__(self, data: TensorMapping, forcing_data: TensorMapping) -> TensorDict:
        """
        Compute derived quantities for ``data``.

        Args:
            data: Variables to derive from; copied into a new dict before use.
            forcing_data: Forcing variables; copied into a new dict before use.

        Returns:
            The output of ``compute_derived_quantities``.
        """
        data_copy = dict(data)
        device_coordinate = self.vertical_coordinate.to(get_device())
        forcing_copy = dict(forcing_data)
        return compute_derived_quantities(
            data_copy,
            vertical_coordinate=device_coordinate,
            timestep=self.timestep,
            forcing_data=forcing_copy,
        )
@dataclasses.dataclass
class SingleModuleStepperConfig:
    """
    Configuration for a single module stepper.

    Parameters:
        builder: The module builder.
        in_names: Names of input variables.
        out_names: Names of output variables.
        normalization: The normalization configuration.
        parameter_init: The parameter initialization configuration.
        ocean: The ocean configuration.
        loss: The loss configuration.
        corrector: The corrector configuration.
        next_step_forcing_names: Names of forcing variables for the next timestep.
            Each must appear in ``in_names`` and must not appear in ``out_names``.
        loss_normalization: The normalization configuration for the loss.
            If provided, it will be used for all variables in loss scaling.
        residual_normalization: Optional alternative to configure loss normalization.
            If provided, it will be used for all *prognostic* variables in loss
            scaling. At most one of ``loss_normalization`` and
            ``residual_normalization`` may be set.
    """

    builder: ModuleSelector
    in_names: List[str]
    out_names: List[str]
    normalization: NormalizationConfig
    parameter_init: ParameterInitializationConfig = dataclasses.field(
        default_factory=lambda: ParameterInitializationConfig()
    )
    ocean: Optional[OceanConfig] = None
    loss: WeightedMappingLossConfig = dataclasses.field(
        default_factory=lambda: WeightedMappingLossConfig()
    )
    corrector: Union[CorrectorConfig, CorrectorSelector] = dataclasses.field(
        default_factory=lambda: CorrectorConfig()
    )
    next_step_forcing_names: List[str] = dataclasses.field(default_factory=list)
    loss_normalization: Optional[NormalizationConfig] = None
    residual_normalization: Optional[NormalizationConfig] = None

    def __post_init__(self):
        """
        Validate constraints that span several fields.

        Raises:
            ValueError: If a next-step forcing name is not an input or is an
                output, or if both ``loss_normalization`` and
                ``residual_normalization`` are set.
        """
        for name in self.next_step_forcing_names:
            if name not in self.in_names:
                raise ValueError(
                    f"next_step_forcing_name '{name}' not in in_names: {self.in_names}"
                )
            if name in self.out_names:
                raise ValueError(
                    f"next_step_forcing_name is an output variable: '{name}'"
                )
        if (
            self.residual_normalization is not None
            and self.loss_normalization is not None
        ):
            raise ValueError(
                "Only one of residual_normalization, loss_normalization can "
                "be provided. "
                "If residual_normalization is provided, it will be used for all "
                "*prognostic* variables in loss scaling. "
                "If loss_normalization is provided, it will be used for all variables "
                "in loss scaling."
            )

    @property
    def n_ic_timesteps(self) -> int:
        """Number of initial-condition timesteps (always 1 for this stepper)."""
        return 1

    def get_evaluation_window_data_requirements(
        self, n_forward_steps: int
    ) -> DataRequirements:
        """Requirements for all variables over an evaluation window."""
        return DataRequirements(
            names=self.all_names,
            n_timesteps=self._window_steps_required(n_forward_steps),
        )

    def get_prognostic_state_data_requirements(self) -> PrognosticStateDataRequirements:
        """Requirements for the prognostic variables of an initial condition."""
        return PrognosticStateDataRequirements(
            names=self.prognostic_names,
            n_timesteps=self.n_ic_timesteps,
        )

    def get_forcing_window_data_requirements(
        self, n_forward_steps: int
    ) -> DataRequirements:
        """Requirements for forcing variables (plus ocean forcings, if any)."""
        if self.ocean is None:
            names = self.forcing_names
        else:
            names = list(set(self.forcing_names).union(self.ocean.forcing_names))
        return DataRequirements(
            names=names,
            n_timesteps=self._window_steps_required(n_forward_steps),
        )

    def _window_steps_required(self, n_forward_steps: int) -> int:
        # the window includes the initial condition timesteps
        return n_forward_steps + self.n_ic_timesteps

    def get_state(self):
        """Return this config as a plain dict (via ``dataclasses.asdict``)."""
        return dataclasses.asdict(self)

    def get_base_weights(self) -> Optional[List[Mapping[str, Any]]]:
        """
        If the model is being initialized from another model's weights for fine-tuning,
        returns those weights. Otherwise, returns None.

        The list mirrors the order of `modules` in the `SingleModuleStepper` class.
        """
        base_weights = self.parameter_init.get_base_weights()
        if base_weights is not None:
            return [base_weights]
        else:
            return None

    def get_stepper(
        self,
        img_shape: Tuple[int, int],
        gridded_operations: GriddedOperations,
        vertical_coordinate: HybridSigmaPressureCoordinate,
        timestep: datetime.timedelta,
    ):
        """
        Build a freshly initialized SingleModuleStepper from this config.

        Args:
            img_shape: Shape of domain as (n_lat, n_lon).
            gridded_operations: The gridded operations, e.g. for area weighting.
            vertical_coordinate: The vertical coordinate.
            timestep: Timestep of the model.
        """
        logging.info("Initializing stepper from provided config")
        derive_func = AtmosphericDeriveFn(vertical_coordinate, timestep)
        return SingleModuleStepper(
            config=self,
            img_shape=img_shape,
            gridded_operations=gridded_operations,
            vertical_coordinate=vertical_coordinate,
            timestep=timestep,
            derive_func=derive_func,
        )

    @classmethod
    def from_state(cls, state) -> "SingleModuleStepperConfig":
        """
        Build a config from a serialized state dict.

        Deprecated keys are handled by ``remove_deprecated_keys`` first. The
        result is then parsed with dacite in strict mode.
        """
        state = cls.remove_deprecated_keys(state)
        return dacite.from_dict(
            data_class=cls, data=state, config=dacite.Config(strict=True)
        )

    @property
    def all_names(self):
        """Names of all variables required, including auxiliary ones."""
        extra_names = []
        if self.ocean is not None:
            extra_names.extend(self.ocean.forcing_names)
        all_names = list(set(self.in_names).union(self.out_names).union(extra_names))
        return all_names

    @property
    def normalize_names(self):
        """Names of variables which require normalization. I.e. inputs/outputs."""
        return list(set(self.in_names).union(self.out_names))

    @property
    def forcing_names(self) -> List[str]:
        """Names of variables which are inputs only."""
        return list(set(self.in_names) - set(self.out_names))

    @property
    def prognostic_names(self) -> List[str]:
        """Names of variables which are both inputs and outputs."""
        return list(set(self.out_names).intersection(self.in_names))

    @property
    def diagnostic_names(self) -> List[str]:
        """Names of variables which are outputs only."""
        return list(set(self.out_names).difference(self.in_names))

    @classmethod
    def remove_deprecated_keys(cls, state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Return a copy of ``state`` with deprecated options removed or translated.

        The input ``state``, including its nested normalization dicts, is not
        modified.

        - Deprecated keys whose value is their default (or None) are dropped.
          Any other value raises, because the feature is no longer implemented.
        - ``exclude_names`` is dropped from normalization configs when it is
          None. Otherwise an error is raised.
        - A legacy ``prescriber`` entry is translated into an ``ocean`` entry
          when not None. The ``prescriber`` key is removed in either case.

        Raises:
            ValueError: If a deprecated option has a non-default value, if
                ``exclude_names`` is set, or if both ``prescriber`` and
                ``ocean`` are specified.
        """
        _unsupported_key_defaults = {
            "conserve_dry_air": False,
            "optimization": None,
            "conservation_loss": {"dry_air_penalty": None},
        }
        state_copy = state.copy()
        for key, default in _unsupported_key_defaults.items():
            if key in state_copy:
                if state_copy[key] == default or state_copy[key] is None:
                    del state_copy[key]
                else:
                    raise ValueError(
                        f"The stepper config option {key} is deprecated and the setting"
                        f" provided, {state_copy[key]}, is no longer implemented. The "
                        "SingleModuleStepper being loaded from state cannot be run by "
                        "this version of the code."
                    )
        for normalization_key in [
            "normalization",
            "loss_normalization",
            "residual_normalization",
        ]:
            normalization_state = state_copy.get(normalization_key)
            if normalization_state is None or "exclude_names" not in normalization_state:
                continue
            if normalization_state["exclude_names"] is not None:
                raise ValueError(
                    "The exclude_names option in normalization config is no "
                    "longer supported, but excluded names were found in "
                    f"{normalization_key}."
                )
            # state_copy is only a shallow copy, so deleting the key in place
            # would also mutate the caller's nested dict; rebuild it instead.
            state_copy[normalization_key] = {
                k: v for k, v in normalization_state.items() if k != "exclude_names"
            }
        if "prescriber" in state_copy:
            # want to maintain backwards compatibility for this particular feature
            if state_copy["prescriber"] is not None:
                if state_copy.get("ocean") is not None:
                    raise ValueError("Cannot specify both prescriber and ocean.")
                state_copy["ocean"] = {
                    "surface_temperature_name": state_copy["prescriber"][
                        "prescribed_name"
                    ],
                    "ocean_fraction_name": state_copy["prescriber"]["mask_name"],
                    "interpolate": state_copy["prescriber"]["interpolate"],
                }
            del state_copy["prescriber"]
        return state_copy
@dataclasses.dataclass
class ExistingStepperConfig:
    """
    Configuration for an existing stepper. This is only designed to point to
    a serialized stepper checkpoint for loading, e.g., in the case of training
    resumption.

    Parameters:
        checkpoint_path: The path to the serialized checkpoint.
    """

    checkpoint_path: str

    def __post_init__(self):
        # Read the stored stepper config once at construction. The
        # data-requirement and base-weight queries below delegate to it.
        checkpoint = self._load_checkpoint()
        serialized_config = checkpoint["stepper"]["config"]
        self._stepper_config = SingleModuleStepperConfig.from_state(serialized_config)

    def _load_checkpoint(self) -> Mapping[str, Any]:
        """Load the whole checkpoint file, mapping tensors to the active device."""
        # NOTE(review): depending on the torch version, torch.load may unpickle
        # arbitrary objects; only point this at trusted checkpoints.
        return torch.load(self.checkpoint_path, map_location=get_device())

    def get_evaluation_window_data_requirements(
        self, n_forward_steps: int
    ) -> DataRequirements:
        """Delegate to the checkpointed config."""
        stored = self._stepper_config
        return stored.get_evaluation_window_data_requirements(n_forward_steps)

    def get_prognostic_state_data_requirements(self) -> PrognosticStateDataRequirements:
        """Delegate to the checkpointed config."""
        stored = self._stepper_config
        return stored.get_prognostic_state_data_requirements()

    def get_forcing_window_data_requirements(
        self, n_forward_steps: int
    ) -> DataRequirements:
        """Delegate to the checkpointed config."""
        stored = self._stepper_config
        return stored.get_forcing_window_data_requirements(n_forward_steps)

    def get_base_weights(self) -> Optional[List[Mapping[str, Any]]]:
        """Delegate to the checkpointed config."""
        stored = self._stepper_config
        return stored.get_base_weights()

    def get_stepper(self, img_shape, gridded_operations, vertical_coordinate, timestep):
        """
        Rebuild the stepper saved in the checkpoint.

        The arguments exist for interface parity with
        ``SingleModuleStepperConfig.get_stepper`` and are ignored. Everything is
        restored from the checkpoint's "stepper" entry.
        """
        del img_shape, gridded_operations, vertical_coordinate, timestep  # unused
        logging.info(f"Initializing stepper from {self.checkpoint_path}")
        stepper_state = self._load_checkpoint()["stepper"]
        return SingleModuleStepper.from_state(stepper_state)
def _combine_normalizers(
    residual_normalizer: StandardNormalizer,
    model_normalizer: StandardNormalizer,
) -> StandardNormalizer:
    """
    Merge two normalizers into a new StandardNormalizer.

    The result starts from the means/stds of ``model_normalizer``. Every entry
    that ``residual_normalizer`` also defines is then overwritten. Neither
    input's mappings are modified, because shallow copies are updated. The
    residual normalizer is assumed to cover a subset of prognostic keys only.
    """
    merged_means = copy(model_normalizer.means)
    merged_means.update(residual_normalizer.means)
    merged_stds = copy(model_normalizer.stds)
    merged_stds.update(residual_normalizer.stds)
    return StandardNormalizer(means=merged_means, stds=merged_stds)
def _prepend_timesteps(
data: TensorMapping, timesteps: TensorMapping, time_dim: int = 1
) -> TensorDict:
return {k: torch.cat([timesteps[k], v], dim=time_dim) for k, v in data.items()}
@dataclasses.dataclass
class TrainOutput(TrainOutputABC):
    """
    Output of a stepper training/evaluation step.

    Attributes:
        metrics: Metrics keyed by name.
        gen_data: Generated data keyed by variable name; dim 1 is time.
        target_data: Target data keyed by variable name; dim 1 is time.
        time: Times of the samples; dim 1 is sliced alongside the data.
        normalize: Function used to normalize data.
        derive_func: ``(data, forcing_data) -> data`` function used to compute
            derived variables. The default returns a plain-dict copy of ``data``.
    """

    metrics: TensorDict
    gen_data: TensorDict
    target_data: TensorDict
    time: xr.DataArray
    normalize: Callable[[TensorDict], TensorDict]
    derive_func: Callable[[TensorMapping, TensorMapping], TensorDict] = (
        lambda x, _: dict(x)
    )

    def _rebuild(
        self,
        gen_data: TensorDict,
        target_data: TensorDict,
        time: xr.DataArray,
    ) -> "TrainOutput":
        """Make a new TrainOutput sharing metrics, normalize and derive_func."""
        return TrainOutput(
            metrics=self.metrics,
            gen_data=gen_data,
            target_data=target_data,
            time=time,
            normalize=self.normalize,
            derive_func=self.derive_func,
        )

    def remove_initial_condition(self, n_ic_timesteps: int) -> "TrainOutput":
        """Drop the first ``n_ic_timesteps`` entries along dim 1 of data and time."""

        def trim(mapping):
            return {name: tensor[:, n_ic_timesteps:] for name, tensor in mapping.items()}

        return self._rebuild(
            gen_data=trim(self.gen_data),
            target_data=trim(self.target_data),
            time=self.time[:, n_ic_timesteps:],
        )

    def copy(self) -> "TrainOutput":
        """Creates new dictionaries for the data but with the same tensors."""
        return self._rebuild(
            gen_data=dict(self.gen_data.items()),
            target_data=dict(self.target_data.items()),
            time=self.time,
        )

    def prepend_initial_condition(
        self,
        initial_condition: PrognosticState,
    ) -> "TrainOutput":
        """
        Prepends an initial condition to the existing stepped data.

        Assumes data are on the same device. The same initial-condition data
        is prepended to both ``gen_data`` and ``target_data``, and its time is
        concatenated in front of ``time`` along the "time" dimension.

        Args:
            initial_condition: Initial condition data.
        """
        ic_batch = initial_condition.as_batch_data()
        ic_data = ic_batch.data
        return self._rebuild(
            gen_data=_prepend_timesteps(self.gen_data, ic_data),
            target_data=_prepend_timesteps(self.target_data, ic_data),
            time=xr.concat([ic_batch.time, self.time], dim="time"),
        )

    def compute_derived_variables(
        self,
    ) -> "TrainOutput":
        """Apply ``derive_func`` to gen and target data, using target as forcing."""
        derive = self.derive_func
        return self._rebuild(
            gen_data=derive(self.gen_data, self.target_data),
            target_data=derive(self.target_data, self.target_data),
            time=self.time,
        )

    def get_metrics(self) -> TensorDict:
        """Return the stored metrics."""
        return self.metrics
class SingleModuleStepper(
    TrainStepperABC[
        PrognosticState,
        BatchData,
        BatchData,
        PairedData,
        TrainOutput,
    ],
):
    """
    Stepper class for a single pytorch module.
    """

    TIME_DIM = 1
    CHANNEL_DIM = -3

    def __init__(
        self,
        config: SingleModuleStepperConfig,
        img_shape: Tuple[int, int],
        gridded_operations: GriddedOperations,
        vertical_coordinate: HybridSigmaPressureCoordinate,
        derive_func: Callable[[TensorMapping, TensorMapping], TensorDict],
        timestep: datetime.timedelta,
        init_weights: bool = True,
    ):
        """
        Args:
            config: The configuration.
            img_shape: Shape of domain as (n_lat, n_lon).
            gridded_operations: The gridded operations, e.g. for area weighting.
            vertical_coordinate: The vertical coordinate.
            derive_func: Function to compute derived variables.
            timestep: Timestep of the model.
            init_weights: Whether to initialize the weights. Should pass False if
                the weights are about to be overwritten by a checkpoint.
        """
        self._gridded_operations = gridded_operations  # stored for serializing
        n_in_channels = len(config.in_names)
        n_out_channels = len(config.out_names)
        self.in_packer = Packer(config.in_names)
        self.out_packer = Packer(config.out_names)
        self.normalizer = config.normalization.build(config.normalize_names)
        if config.ocean is not None:
            self.ocean: Optional[Ocean] = config.ocean.build(
                config.in_names, config.out_names, timestep
            )
        else:
            self.ocean = None
        self.module = config.builder.build(
            n_in_channels=n_in_channels,
            n_out_channels=n_out_channels,
            img_shape=img_shape,
        )
        module, self._l2_sp_tuning_regularizer = config.parameter_init.apply(
            self.module, init_weights=init_weights
        )
        self.module = module.to(get_device())
        self.derive_func = derive_func
        self._img_shape = img_shape
        self._config = config
        self._no_optimization = NullOptimization()

        dist = Distributed.get_instance()
        self._is_distributed = dist.is_distributed()
        self.module = dist.wrap_module(self.module)

        self._vertical_coordinates = vertical_coordinate.to(get_device())
        self._timestep = timestep

        self.loss_obj = config.loss.build(
            gridded_operations.area_weighted_mean, config.out_names, self.CHANNEL_DIM
        )
        self._corrector = config.corrector.build(
            gridded_operations=gridded_operations,
            vertical_coordinate=self.vertical_coordinate,
            timestep=timestep,
        )
        if config.loss_normalization is not None:
            self.loss_normalizer = config.loss_normalization.build(
                names=config.normalize_names
            )
        elif config.residual_normalization is not None:
            # Use residual norm for prognostic variables and input/output
            # normalizer for diagnostic variables in loss
            self.loss_normalizer = _combine_normalizers(
                residual_normalizer=config.residual_normalization.build(
                    config.prognostic_names
                ),
                model_normalizer=self.normalizer,
            )
        else:
            self.loss_normalizer = self.normalizer
        self.in_names = config.in_names
        self.out_names = config.out_names

        _1: PredictFunction[  # for type checking
            PrognosticState,
            BatchData,
            BatchData,
        ] = self.predict

        _2: PredictFunction[  # for type checking
            PrognosticState,
            BatchData,
            PairedData,
        ] = self.predict_paired

    @property
    def vertical_coordinate(self) -> HybridSigmaPressureCoordinate:
        return self._vertical_coordinates

    @property
    def timestep(self) -> datetime.timedelta:
        return self._timestep

    @property
    def surface_temperature_name(self) -> Optional[str]:
        if self._config.ocean is not None:
            return self._config.ocean.surface_temperature_name
        return None

    @property
    def ocean_fraction_name(self) -> Optional[str]:
        if self._config.ocean is not None:
            return self._config.ocean.ocean_fraction_name
        return None

    @property
    def effective_loss_scaling(self) -> TensorDict:
        """
        Effective loss scalings used to normalize outputs before computing loss.
        y_loss_normalized_i = (y_i - y_mean_i) / loss_scaling_i
        where loss_scaling_i = loss_normalizer_std_i / weight_i.
        """
        custom_weights = self._config.loss.weights
        loss_normalizer_stds = self.loss_normalizer.stds
        return {
            k: loss_normalizer_stds[k] / custom_weights.get(k, 1.0)
            for k in self._config.out_names
        }

    def replace_ocean(self, ocean: Ocean):
        """
        Replace the ocean model with a new one.

        Args:
            ocean: The new ocean model.
        """
        self.ocean = ocean

    @property
    def forcing_names(self) -> List[str]:
        """Names of variables which are inputs only."""
        return self._config.forcing_names

    @property
    def prognostic_names(self) -> List[str]:
        """Sorted names of variables which are both inputs and outputs."""
        return sorted(
            list(set(self.out_packer.names).intersection(self.in_packer.names))
        )

    @property
    def diagnostic_names(self) -> List[str]:
        """Sorted names of variables which are outputs only."""
        return sorted(list(set(self.out_packer.names).difference(self.in_packer.names)))

    @property
    def n_ic_timesteps(self) -> int:
        return 1

    @property
    def modules(self) -> nn.ModuleList:
        """
        Returns:
            A list of modules being trained.
        """
        return nn.ModuleList([self.module])

    def step(
        self,
        input: TensorMapping,
        next_step_forcing_data: TensorMapping,
    ) -> TensorDict:
        """
        Step the model forward one timestep given input data.

        Args:
            input: Mapping from variable name to tensor of shape
                [n_batch, n_lat, n_lon]. This data is used as input for `self.module`
                and is assumed to contain all input variables and be denormalized.
            next_step_forcing_data: Mapping from variable name to tensor of shape
                [n_batch, n_lat, n_lon]. This must contain the necessary forcing
                data at the output timestep for the ocean model and corrector.

        Returns:
            The denormalized output data at the next time step.
        """
        input_norm = self.normalizer.normalize(input)
        input_tensor = self.in_packer.pack(input_norm, axis=self.CHANNEL_DIM)
        output_tensor = self.module(input_tensor)
        output_norm = self.out_packer.unpack(output_tensor, axis=self.CHANNEL_DIM)
        output = self.normalizer.denormalize(output_norm)
        if self._corrector is not None:
            output = self._corrector(input, output, next_step_forcing_data)
        if self.ocean is not None:
            output = self.ocean(input, output, next_step_forcing_data)
        return output

    def _predict(
        self,
        initial_condition: TensorMapping,
        forcing_data: TensorMapping,
        n_forward_steps: int,
    ) -> TensorDict:
        """
        Predict multiple steps forward given initial condition and forcing data.

        Uses low-level inputs and does not compute derived variables, to separate
        concerns from the public `predict` method.

        Args:
            initial_condition: The initial condition, containing tensors of shape
                [n_batch, self.n_ic_timesteps, <horizontal_dims>].
            forcing_data: The forcing data, containing tensors of shape
                [n_batch, n_forward_steps + self.n_ic_timesteps, <horizontal_dims>].
            n_forward_steps: The number of forward steps to predict, corresponding
                to the data shapes of forcing_data.

        Returns:
            The output data at each timestep.
        """
        state = {
            k: initial_condition[k].squeeze(self.TIME_DIM) for k in initial_condition
        }
        ml_forcing_names = self._config.forcing_names
        # Loop-invariant lookups, hoisted out of the per-step loop.
        next_step_names = set(self._config.next_step_forcing_names)
        all_forcing_names = self._forcing_names()
        output_list = []
        for step in range(n_forward_steps):
            # ML inputs use forcing at the current step, except variables
            # configured as next-step forcings, which use step + 1.
            ml_input_forcing = {
                k: (
                    forcing_data[k][:, step + 1]
                    if k in next_step_names
                    else forcing_data[k][:, step]
                )
                for k in ml_forcing_names
            }
            # corrector and ocean see forcing at the output timestep
            next_step_forcing_data = {
                k: forcing_data[k][:, step + 1] for k in all_forcing_names
            }
            input_data = {**state, **ml_input_forcing}
            state = self.step(input_data, next_step_forcing_data)
            output_list.append(state)
        output_timeseries = {}
        for name in state:
            output_timeseries[name] = torch.stack(
                [x[name] for x in output_list], dim=self.TIME_DIM
            )
        return output_timeseries

    def predict(
        self,
        initial_condition: PrognosticState,
        forcing: BatchData,
        compute_derived_variables: bool = False,
    ) -> Tuple[BatchData, PrognosticState]:
        """
        Predict multiple steps forward given initial condition and reference data.

        Args:
            initial_condition: Prognostic state data with tensors of shape
                [n_batch, self.n_ic_timesteps, <horizontal_dims>]. This data is assumed
                to contain all prognostic variables and be denormalized.
            forcing: Contains tensors of shape
                [n_batch, self.n_ic_timesteps + n_forward_steps, n_lat, n_lon]. This
                contains the forcing and ocean data for the initial condition and all
                subsequent timesteps.
            compute_derived_variables: Whether to compute derived variables for the
                prediction.

        Returns:
            A batch data containing the prediction and the prediction's final state
            which can be used as a new initial condition.
        """
        timer = GlobalTimer.get_instance()
        with timer.context("forward_prediction"):
            forcing_data = forcing.subset_names(self._forcing_names())
            initial_condition_state = initial_condition.as_batch_data()
            if initial_condition_state.time.shape[1] != self.n_ic_timesteps:
                raise ValueError(
                    f"Initial condition must have {self.n_ic_timesteps} timesteps, got "
                    f"{initial_condition_state.time.shape[1]}."
                )
            n_forward_steps = forcing_data.time.shape[1] - self.n_ic_timesteps
            output_timeseries = self._predict(
                initial_condition_state.data, forcing_data.data, n_forward_steps
            )
            data = BatchData.new_on_device(
                output_timeseries,
                forcing_data.time[:, self.n_ic_timesteps :],
                horizontal_dims=forcing_data.horizontal_dims,
            )
        if compute_derived_variables:
            with timer.context("compute_derived_variables"):
                data = (
                    data.prepend(initial_condition)
                    .compute_derived_variables(
                        derive_func=self.derive_func,
                        forcing_data=forcing_data,
                    )
                    .remove_initial_condition(self.n_ic_timesteps)
                )
        return data, data.get_end(self.prognostic_names, self.n_ic_timesteps)

    def predict_paired(
        self,
        initial_condition: PrognosticState,
        forcing: BatchData,
        compute_derived_variables: bool = False,
    ) -> Tuple[PairedData, PrognosticState]:
        """
        Predict multiple steps forward given initial condition and reference data.

        Args:
            initial_condition: Prognostic state data with tensors of shape
                [n_batch, self.n_ic_timesteps, <horizontal_dims>]. This data is assumed
                to contain all prognostic variables and be denormalized.
            forcing: Contains tensors of shape
                [n_batch, self.n_ic_timesteps + n_forward_steps, n_lat, n_lon]. This
                contains the forcing and ocean data for the initial condition and all
                subsequent timesteps.
            compute_derived_variables: Whether to compute derived variables for the
                prediction.

        Returns:
            A paired data containing the prediction paired with all forcing data at the
            same timesteps and the prediction's final state which can be used as a
            new initial condition.
        """
        prediction, new_initial_condition = self.predict(
            initial_condition, forcing, compute_derived_variables
        )
        return (
            PairedData.from_batch_data(
                prediction=prediction,
                target=self.get_forward_data(
                    forcing, compute_derived_variables=compute_derived_variables
                ),
            ),
            new_initial_condition,
        )

    def get_forward_data(
        self, data: BatchData, compute_derived_variables: bool = False
    ) -> BatchData:
        """
        Return ``data`` with its initial-condition timesteps removed.

        If ``compute_derived_variables`` is True, derived variables are computed
        first, using ``data`` itself as the forcing data.
        """
        if compute_derived_variables:
            timer = GlobalTimer.get_instance()
            with timer.context("compute_derived_variables"):
                data = data.compute_derived_variables(
                    derive_func=self.derive_func,
                    forcing_data=data,
                )
        return data.remove_initial_condition(self.n_ic_timesteps)

    def _forcing_names(self) -> List[str]:
        # ML forcing names plus any extra forcings required by the ocean model
        if self.ocean is None:
            return self._config.forcing_names
        return list(set(self._config.forcing_names).union(self.ocean.forcing_names))

    def train_on_batch(
        self,
        data: BatchData,
        optimization: OptimizationABC,
        compute_derived_variables: bool = False,
    ) -> TrainOutput:
        """
        Step the model forward multiple steps on a batch of data.

        Args:
            data: The batch data where each tensor in data.data has shape
                [n_sample, n_forward_steps + self.n_ic_timesteps, <horizontal_dims>].
            optimization: The optimization class to use for updating the module.
                Use `NullOptimization` to disable training.
            compute_derived_variables: Whether to compute derived variables for the
                prediction and target data.

        Returns:
            A TrainOutput holding per-step ("loss_step_<i>") and total ("loss")
            loss metrics as detached tensors. It also holds the generated and
            target data with the initial condition (all variables in ``data``)
            prepended, and their times.
        """
        time_dim = self.TIME_DIM
        loss = torch.tensor(0.0, device=get_device())
        metrics: TensorDict = {}

        input_data = data.get_start(self.prognostic_names, self.n_ic_timesteps)

        optimization.set_mode(self.module)
        with optimization.autocast():
            # output from self.predict does not include initial condition
            output, _ = self.predict_paired(
                input_data,
                forcing=data,
            )
            gen_data = output.prediction
            target_data = output.target
            n_forward_steps = output.time.shape[1]

            # compute loss for each timestep
            for step in range(n_forward_steps):
                # Note: here we examine the loss for a single timestep,
                # not a single model call (which may contain multiple timesteps).
                gen_step = {k: v.select(time_dim, step) for k, v in gen_data.items()}
                target_step = {
                    k: v.select(time_dim, step) for k, v in target_data.items()
                }
                gen_norm_step = self.loss_normalizer.normalize(gen_step)
                target_norm_step = self.loss_normalizer.normalize(target_step)

                step_loss = self.loss_obj(gen_norm_step, target_norm_step)
                loss += step_loss
                metrics[f"loss_step_{step}"] = step_loss.detach()

        loss += self._l2_sp_tuning_regularizer()

        metrics["loss"] = loss.detach()
        optimization.step_weights(loss)

        stepped = TrainOutput(
            metrics=metrics,
            gen_data=dict(gen_data),
            target_data=dict(target_data),
            time=output.time,
            normalize=self.normalizer.normalize,
            derive_func=self.derive_func,
        )
        ic = data.get_start(
            set(data.data.keys()), self.n_ic_timesteps
        )  # full data and not just prognostic get prepended
        stepped = stepped.prepend_initial_condition(ic)
        if compute_derived_variables:
            stepped = stepped.compute_derived_variables()
        return stepped

    def get_state(self):
        """
        Returns:
            The state of the stepper.
        """
        return {
            "module": self.module.state_dict(),
            "normalizer": self.normalizer.get_state(),
            "img_shape": self._img_shape,
            "config": self._config.get_state(),
            "gridded_operations": self._gridded_operations.to_state(),
            "vertical_coordinate": self.vertical_coordinate.as_dict(),
            "encoded_timestep": encode_timestep(self.timestep),
            "loss_normalizer": self.loss_normalizer.get_state(),
        }

    def load_state(self, state: Dict[str, Any]) -> None:
        """
        Load the state of the stepper.

        Only the module weights under ``state["module"]`` are loaded. The input
        ``state`` is not modified.

        Args:
            state: The state to load.
        """
        if "module" in state:
            module = state["module"]
            if "module.device_buffer" in module:
                # for backwards compatibility with old checkpoints. Shallow-copy
                # first so the caller's checkpoint dict is not mutated; copy()
                # keeps the state_dict type and its attributes (e.g. _metadata).
                module = copy(module)
                del module["module.device_buffer"]
            self.module.load_state_dict(module)

    @classmethod
    def from_state(cls, state) -> "SingleModuleStepper":
        """
        Load the state of the stepper.

        The input ``state`` is not modified.

        Args:
            state: The state to load.

        Returns:
            The stepper.

        Raises:
            ValueError: If an old-style checkpoint has neither "img_shape" nor
                any entries in "data_shapes".
        """
        config = {**state["config"]}  # make a copy to avoid mutating input
        config["normalization"] = state["normalizer"]

        # for backwards compatibility with previous steppers created w/o
        # loss_normalization or residual_normalization
        loss_normalizer_state = state.get("loss_normalizer", state["normalizer"])
        config["loss_normalization"] = loss_normalizer_state

        # Overwrite the residual_normalization key if it exists, since the combined
        # loss scalings are saved in initial training as the loss_normalization
        config["residual_normalization"] = None

        if "area" in state:
            # backwards-compatibility, these older checkpoints are always lat-lon
            gridded_operations: GriddedOperations = LatLonOperations(state["area"])
        else:
            gridded_operations = GriddedOperations.from_state(
                state["gridded_operations"]
            )

        if "sigma_coordinates" in state:
            # for backwards compatibility with old checkpoints; the legacy key
            # takes precedence and is read without writing into `state`
            vertical_coordinate_state = state["sigma_coordinates"]
        else:
            vertical_coordinate_state = state["vertical_coordinate"]

        vertical_coordinate = dacite.from_dict(
            data_class=HybridSigmaPressureCoordinate,
            data=vertical_coordinate_state,
            config=dacite.Config(strict=True),
        )
        # for backwards compatibility with original ACE checkpoint which
        # serialized vertical coordinates as float64
        if vertical_coordinate.ak.dtype == torch.float64:
            vertical_coordinate.ak = vertical_coordinate.ak.to(dtype=torch.float32)
        if vertical_coordinate.bk.dtype == torch.float64:
            vertical_coordinate.bk = vertical_coordinate.bk.to(dtype=torch.float32)

        encoded_timestep = state.get("encoded_timestep", DEFAULT_ENCODED_TIMESTEP)
        timestep = decode_timestep(encoded_timestep)

        if "img_shape" in state:
            img_shape = state["img_shape"]
        else:
            # this is for backwards compatibility with old checkpoints:
            # take the trailing two dims of the first recorded data shape
            data_shapes = list(state["data_shapes"].values())
            if len(data_shapes) == 0:
                raise ValueError(
                    "Cannot infer img_shape: state has no 'img_shape' and "
                    "'data_shapes' is empty."
                )
            img_shape = data_shapes[0][-2:]

        derive_func = AtmosphericDeriveFn(vertical_coordinate, timestep)
        stepper = cls(
            config=SingleModuleStepperConfig.from_state(config),
            img_shape=img_shape,
            gridded_operations=gridded_operations,
            vertical_coordinate=vertical_coordinate,
            timestep=timestep,
            derive_func=derive_func,
            # don't need to initialize weights, we're about to load_state
            init_weights=False,
        )
        stepper.load_state(state)
        return stepper