import dataclasses
import datetime
import logging
import os
from collections.abc import Callable, Mapping, Sequence
import cftime
import dacite
import numpy as np
import numpy.typing as npt
import torch
import fme
from fme.ace.aggregator import OneStepAggregatorConfig
from fme.ace.aggregator.inference import InferenceEvaluatorAggregatorConfig
from fme.ace.data_loading.batch_data import BatchData, PrognosticState
from fme.ace.data_loading.config import DataLoaderConfig
from fme.ace.data_loading.getters import get_gridded_data, get_inference_data
from fme.ace.data_loading.inference import InferenceDataLoaderConfig
from fme.ace.inference.data_writer import DataWriterConfig, PairedDataWriter
from fme.ace.inference.data_writer.dataset_metadata import DatasetMetadata
from fme.ace.inference.default_metadata import get_default_variable_metadata
from fme.ace.inference.loop import DeriverABC, run_dataset_comparison
from fme.ace.stepper import (
Stepper,
StepperOverrideConfig,
load_stepper,
load_stepper_config_with_override,
)
from fme.ace.stepper.single_module import (
StepperConfig,
TrainStepper,
TrainStepperConfig,
)
from fme.ace.stepper.time_length_probabilities import (
TimeLengthProbabilities,
TimeLengthSchedule,
)
from fme.core.cli import prepare_config, prepare_directory
from fme.core.cloud import makedirs
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.dataset_info import IncompatibleDatasetInfo
from fme.core.derived_variables import get_derived_variable_metadata
from fme.core.generics.inference import get_record_to_wandb, run_inference
from fme.core.generics.validation import run_validation
from fme.core.logging_utils import LoggingConfig
from fme.core.timing import GlobalTimer
from fme.core.typing_ import TensorDict, TensorMapping
def resolve_variable_metadata(
    dataset_metadata: Mapping[str, VariableMetadata],
    stepper_metadata: Mapping[str, VariableMetadata],
    stepper_all_names: Sequence[str],
) -> dict[str, VariableMetadata]:
    """
    Resolve per-variable metadata for the stepper's variables.

    Metadata is combined from four sources: derived variables, the dataset,
    the stepper, and a set of ``era5_v1`` defaults. For the stepper's
    variables, stepper values take precedence over dataset values, which take
    precedence over defaults. Derived-variable metadata is layered underneath
    the result, so any entry resolved for a stepper variable overrides a
    derived-variable entry of the same name.

    A warning is logged listing stepper variables that had to fall back to
    default values. Note that if not saved with the stepper, the variable
    metadata is not guaranteed to be the same as that in the dataset used for
    training the stepper.

    Args:
        dataset_metadata: Metadata from the dataset.
        stepper_metadata: Metadata from the stepper.
        stepper_all_names: Variable names associated with the stepper.

    Returns:
        A mapping of variable names to metadata. It contains every stepper
        variable for which stepper, dataset, or default metadata exists,
        plus all derived-variable metadata.
    """
    default_metadata = get_default_variable_metadata(version="era5_v1")
    provided_names = dataset_metadata.keys() | stepper_metadata.keys()
    # Stepper variables that only the defaults can describe.
    names_from_default = (
        set(stepper_all_names) - provided_names
    ) & default_metadata.keys()
    if names_from_default:
        logging.warning(
            "Variable metadata for the following stepper variables were not found in "
            "the variable metadata of the forcing dataset or stepper: "
            f"{names_from_default}. Using default values for these variables instead. "
            "Users should ensure that the default values are consistent with the "
            "training dataset of the stepper."
        )
    # Later unpacked mappings win: stepper > dataset > defaults.
    merged = {**default_metadata, **dataset_metadata, **stepper_metadata}
    selected: dict[str, VariableMetadata] = {}
    for name in stepper_all_names:
        if name in merged:
            selected[name] = merged[name]
    return get_derived_variable_metadata() | selected
@dataclasses.dataclass
class ValidationConfig:
    """
    Configuration for running "validation" within an inference evaluator job.

    This mirrors the validation loop performed at the end of each training
    epoch, producing metrics like ``val/mean/weighted_rmse`` and
    ``val/mean/loss``. A possible use case is to configure ``loader`` so that it
    matches the validation data loader used during training, but other periods or
    datasets that are compatible with the checkpoint may also be used.

    Parameters:
        loader: Data loader configuration for validation data. Uses the same
            :class:`~fme.ace.data_loading.config.DataLoaderConfig` as training
            data loaders.
        aggregator: Configuration for the one-step validation aggregator.
        stepper_training: Training-specific configuration including loss, ensemble
            settings, and forward step scheduling. Set this to match the training
            configuration if you want ``val/mean/loss`` to be directly comparable.
            The number of forward steps is derived from
            ``stepper_training.n_forward_steps`` (defaults to 1 if unset).
    """

    loader: DataLoaderConfig
    aggregator: OneStepAggregatorConfig = dataclasses.field(
        default_factory=OneStepAggregatorConfig
    )
    stepper_training: TrainStepperConfig = dataclasses.field(
        default_factory=TrainStepperConfig
    )

    def __post_init__(self):
        # Loading initial weights is not used here; reject it so a copied
        # training config does not appear to have an effect it lacks.
        if self.stepper_training.parameter_init.weights_path is not None:
            raise ValueError(
                "stepper_training.parameter_init is not used for validation within "
                "inference evaluator jobs."
            )
        # get_n_forward_steps can only resolve None, an int, or a
        # TimeLengthProbabilities into a fixed step count.
        if isinstance(self.stepper_training.n_forward_steps, TimeLengthSchedule):
            raise ValueError(
                "stepper_training.n_forward_steps may not be a "
                "TimeLengthSchedule for validation within inference evaluator jobs. "
                "Use TimeLengthProbabilities or an int instead."
            )

    def get_n_forward_steps(self) -> int:
        """Resolve the effective number of forward steps for validation.

        Derives the value from ``stepper_training.n_forward_steps``:

        - ``None``: defaults to 1 (standard single-step validation).
        - ``int``: used as-is.
        - :class:`TimeLengthProbabilities`: its ``max_n_forward_steps``.

        Raises:
            TypeError: If ``n_forward_steps`` has any other type, e.g. a
                ``TimeLengthSchedule`` assigned after construction.
        """
        train_n = self.stepper_training.n_forward_steps
        if train_n is None:
            logging.info(
                "stepper_training.n_forward_steps was not configured for "
                "validation within the inference evaluator job, defaulting to "
                "n_forward_steps=1."
            )
            return 1
        if isinstance(train_n, int):
            return train_n
        if isinstance(train_n, TimeLengthProbabilities):
            return train_n.max_n_forward_steps
        # Explicit error instead of an assert, which would be stripped under -O.
        raise TypeError(
            "Unsupported type for stepper_training.n_forward_steps in validation: "
            f"{type(train_n).__name__}"
        )
@dataclasses.dataclass
class InferenceEvaluatorConfig:
    """
    Configuration for running inference including comparison to reference data.

    Parameters:
        experiment_dir: Directory to save results to. This can be a local
            directory, like ``/results``, or a remote directory prefixed with a
            protocol recognized by ``fsspec``, like ``gs://bucket/results``.

            .. note::
                While most types of output can be written to a remote
                ``experiment_dir``, there are some limitations:

                - To write raw or time-coarsened data, the zarr writer must be
                  used. See the ``files`` parameter of the
                  :class:`fme.ace.DataWriterConfig` for more details on how this
                  can be configured. Note that monthly coarsened data cannot
                  currently be written to zarr, and hence a remote directory,
                  since it uses a different code path than uniformly coarsened
                  data.
                - Piping logging output to a file in the ``experiment_dir``
                  is not supported. To silence the warning related to this, set
                  ``log_to_file`` to ``False`` in the
                  :class:`fme.ace.LoggingConfig`.

                There are no restrictions on the types of output that can be
                written to a local ``experiment_dir``.
        n_forward_steps: Number of steps to run the model forward for.
        checkpoint_path: Path to stepper checkpoint to load.
        logging: configuration for logging.
        loader: Configuration for data to be used as initial conditions, forcing, and
            target in inference.
        prediction_loader: Configuration for prediction data to evaluate. If given,
            model evaluation will not run, and instead predictions will be evaluated.
            Model checkpoint will still be used to determine inputs and outputs.
        forward_steps_in_memory: Number of forward steps to complete in memory
            at a time, will load one more step for initial condition.
        data_writer: Configuration for data writers.
        aggregator: Configuration for inference evaluator aggregator.
        stepper_override: Configuration for overriding select stepper configuration
            options at inference time (optional).
        allow_incompatible_dataset: If True, allow the forcing dataset used
            for inference to be incompatible with the dataset used for stepper
            training. This should be used with caution, as it may allow the
            stepper to make scientifically invalid predictions, but it can allow
            running inference with incorrectly formatted or missing grid
            information.
        validation: Optional configuration for running a validation loop
            before inference (one or more forward steps, see
            :class:`ValidationConfig`). When provided, validation runs first and
            produces metrics prefixed with ``val/`` (e.g.
            ``val/mean/weighted_rmse``), mirroring the validation done at the
            end of each training epoch.
        n_ensemble_per_ic: Number of ensemble members per initial condition.
            Useful for stochastic model weather inference. Must be at least 1;
            n_ensemble_per_ic = 1 is default inference behavior.
    """

    experiment_dir: str
    n_forward_steps: int
    checkpoint_path: str
    logging: LoggingConfig
    loader: InferenceDataLoaderConfig
    forward_steps_in_memory: int
    prediction_loader: InferenceDataLoaderConfig | None = None
    data_writer: DataWriterConfig = dataclasses.field(
        default_factory=DataWriterConfig
    )
    aggregator: InferenceEvaluatorAggregatorConfig = dataclasses.field(
        default_factory=InferenceEvaluatorAggregatorConfig
    )
    stepper_override: StepperOverrideConfig | None = None
    allow_incompatible_dataset: bool = False
    validation: ValidationConfig | None = None
    n_ensemble_per_ic: int = 1

    def __post_init__(self):
        # A non-positive ensemble size has no meaning; fail at config time
        # rather than deep inside inference.
        if self.n_ensemble_per_ic < 1:
            raise ValueError(
                f"n_ensemble_per_ic must be at least 1, got {self.n_ensemble_per_ic}"
            )
        if self.data_writer.time_coarsen is not None:
            self.data_writer.time_coarsen.validate(
                self.forward_steps_in_memory,
                self.n_forward_steps,
            )
        if self.data_writer.files is not None:
            for file_config in self.data_writer.files:
                if file_config.time_coarsen is not None:
                    file_config.time_coarsen.validate(
                        self.forward_steps_in_memory,
                        self.n_forward_steps,
                    )
        for log_step_mean in self.aggregator.log_step_means:
            log_step_mean.validate(self.n_forward_steps)

    def configure_logging(self, log_filename: str):
        """Configure logging under ``experiment_dir`` for this (non-resumable) job.

        The full config is passed along as a dict via ``dataclasses.asdict``.
        """
        config = dataclasses.asdict(self)
        self.logging.configure_logging(
            self.experiment_dir, log_filename, config=config, resumable=False
        )

    def load_stepper(self) -> Stepper:
        """Load the stepper from ``checkpoint_path``, applying ``stepper_override``."""
        logging.info(f"Loading trained model checkpoint from {self.checkpoint_path}")
        return load_stepper(self.checkpoint_path, self.stepper_override)

    def load_stepper_config(self) -> StepperConfig:
        """Load the stepper config from ``checkpoint_path`` with ``stepper_override``."""
        logging.info(f"Loading trained model checkpoint from {self.checkpoint_path}")
        return load_stepper_config_with_override(
            self.checkpoint_path, self.stepper_override
        )

    def get_data_writer(
        self,
        initial_condition_times: npt.NDArray[cftime.datetime],
        timestep: datetime.timedelta,
        variable_metadata: Mapping[str, VariableMetadata],
        coords: Mapping[str, np.ndarray],
    ) -> PairedDataWriter:
        """Build the paired (prediction/target) data writer for this job."""
        # initial_condition_times from data.initial_time already has one entry per
        # sample (n_ic * n_ensemble_per_ic); do not repeat by n_ensemble_per_ic again.
        return self.data_writer.build_paired(
            experiment_dir=self.experiment_dir,
            initial_condition_times=initial_condition_times,
            n_timesteps=self.n_forward_steps,
            timestep=timestep,
            variable_metadata=variable_metadata,
            coords=coords,
            dataset_metadata=DatasetMetadata.from_env(),
        )
def main(yaml_config: str, override_dotlist: Sequence[str] | None = None):
    """
    Run an inference evaluator job from a YAML configuration.

    The configuration is strictly decoded into :class:`InferenceEvaluatorConfig`.
    dacite's strict mode rejects unknown keys. The experiment directory is then
    prepared, and the evaluator runs under a :class:`GlobalTimer` with
    gradient tracking disabled.

    Args:
        yaml_config: YAML configuration passed to ``prepare_config``
            (presumably a file path).
        override_dotlist: Optional dotlist-style overrides applied on top of
            the YAML configuration.

    Returns:
        Whatever ``run_evaluator_from_config`` returns.
    """
    raw_config = prepare_config(yaml_config, override=override_dotlist)
    strict_decoding = dacite.Config(strict=True)
    evaluator_config = dacite.from_dict(
        data_class=InferenceEvaluatorConfig,
        data=raw_config,
        config=strict_decoding,
    )
    prepare_directory(evaluator_config.experiment_dir, raw_config)
    with GlobalTimer(), torch.no_grad():
        return run_evaluator_from_config(evaluator_config)
class _Deriver(DeriverABC):
    """
    :class:`DeriverABC` implementation used for dataset comparison.

    It optionally computes derived variables on a batch and then strips the
    initial-condition timesteps from it.
    """

    def __init__(
        self,
        n_ic_timesteps: int,
        derive_func: Callable[[TensorMapping, TensorMapping], TensorDict],
    ):
        self._derive_func = derive_func
        self._n_ic_timesteps = n_ic_timesteps

    @property
    def n_ic_timesteps(self) -> int:
        """Number of initial-condition timesteps removed from forward data."""
        return self._n_ic_timesteps

    def get_forward_data(
        self, data: BatchData, compute_derived_variables: bool = False
    ) -> BatchData:
        """
        Return ``data`` with its initial condition removed.

        If ``compute_derived_variables`` is True, derived variables are first
        computed with the stored ``derive_func``, using ``data`` itself as the
        forcing data. This step is timed under ``compute_derived_variables``.
        """
        if not compute_derived_variables:
            return data.remove_initial_condition(self._n_ic_timesteps)
        with GlobalTimer.get_instance().context("compute_derived_variables"):
            with_derived = data.compute_derived_variables(
                derive_func=self._derive_func,
                forcing_data=data,
            )
        return with_derived.remove_initial_condition(self._n_ic_timesteps)
def _broadcast_initial_condition_ensemble(data, n_ensemble_per_ic: int) -> None:
    """
    Replace the initial condition of ``data`` in place with one broadcast to
    ``n_ensemble_per_ic`` ensemble members.

    Does nothing when ``n_ensemble_per_ic`` is 1 or less.

    Args:
        data: Object returned by ``get_inference_data`` that exposes
            ``initial_condition`` (a :class:`PrognosticState`).
        n_ensemble_per_ic: Number of ensemble members per initial condition.
    """
    if n_ensemble_per_ic <= 1:
        return
    ic = data.initial_condition.as_batch_data()
    # NOTE(review): assigns a private attribute; no public setter is visible here.
    data._initial_condition = PrognosticState(ic.broadcast_ensemble(n_ensemble_per_ic))


def _get_initial_time(data):
    """
    Return the ``time=0`` selection of ``time`` from the first batch of ``data.loader``.

    Raises:
        ValueError: If ``data.loader`` yields no batches.
    """
    for batch in data.loader:
        return batch.time.isel(time=0)
    # Previously an empty loader left the initial time unbound and failed later
    # with an UnboundLocalError; fail early with an actionable message instead.
    raise ValueError(
        "Inference data loader yielded no batches; cannot determine the "
        "initial time for the aggregator."
    )


def run_evaluator_from_config(config: InferenceEvaluatorConfig):
    """
    Run an inference evaluator job described by ``config``.

    The job proceeds in these stages:

    1. Loads the stepper config and the inference data, broadcasting initial
       conditions when ``n_ensemble_per_ic > 1``.
    2. Loads the stepper and checks dataset compatibility, unless
       ``allow_incompatible_dataset`` is set.
    3. If ``config.validation`` is given, runs a validation pass (logged under
       ``val``) before inference.
    4. Runs inference with the stepper. If ``prediction_loader`` is given, it
       instead compares the prediction dataset against the target data.
    5. Finalizes the writer, flushes diagnostics, and logs timing and
       throughput summaries.

    Raises:
        IncompatibleDatasetInfo: If the inference dataset is incompatible with
            the stepper's training dataset and this is not explicitly allowed.
        ValueError: If the inference data loader yields no batches.
    """
    timer = GlobalTimer.get_instance()
    timer.start_outer("inference")
    with timer.context("initialization"):
        makedirs(config.experiment_dir, exist_ok=True)
        config.configure_logging(log_filename="inference_out.log")
        if fme.using_gpu():
            torch.backends.cudnn.benchmark = True
        stepper_config = config.load_stepper_config()
        logging.info("Initializing data loader")
        window_requirements = stepper_config.get_evaluation_window_data_requirements(
            n_forward_steps=config.forward_steps_in_memory
        )
        initial_condition_requirements = (
            stepper_config.get_prognostic_state_data_requirements()
        )
        data = get_inference_data(
            config=config.loader,
            total_forward_steps=config.n_forward_steps,
            window_requirements=window_requirements,
            initial_condition=initial_condition_requirements,
        )
        _broadcast_initial_condition_ensemble(data, config.n_ensemble_per_ic)
        stepper = config.load_stepper()
        stepper.set_eval()
        if not config.allow_incompatible_dataset:
            try:
                stepper.training_dataset_info.assert_compatible_with(data.dataset_info)
            except IncompatibleDatasetInfo as err:
                raise IncompatibleDatasetInfo(
                    "Inference dataset is not compatible with dataset used for stepper "
                    "training. Set allow_incompatible_dataset to True to ignore this "
                    f"error. The incompatiblity found was: {str(err)}"
                ) from err
        aggregator_config: InferenceEvaluatorAggregatorConfig = config.aggregator
        initial_time = _get_initial_time(data)
        variable_metadata = resolve_variable_metadata(
            dataset_metadata=data.variable_metadata,
            stepper_metadata=stepper.training_variable_metadata,
            stepper_all_names=stepper_config.all_names,
        )
        dataset_info = data.dataset_info.update_variable_metadata(variable_metadata)

    if config.validation is not None:
        # Time validation separately from inference.
        timer.stop_outer("inference")
        timer.start_outer("validation")
        with timer.context("initialization"):
            logging.info("Initializing validation data loader")
            data_requirements = stepper_config.get_evaluation_window_data_requirements(
                n_forward_steps=config.validation.get_n_forward_steps()
            )
            valid_data = get_gridded_data(
                config.validation.loader,
                requirements=data_requirements,
                train=False,
            )
            logging.info("Building validation stepper and aggregator")
            train_stepper_config = config.validation.stepper_training
            train_stepper = TrainStepper(stepper=stepper, config=train_stepper_config)
            validation_aggregator = config.validation.aggregator.build(
                dataset_info=dataset_info,
                loss_scaling=train_stepper.effective_loss_scaling,
                save_diagnostics=True,
                output_dir=os.path.join(config.experiment_dir, "validation"),
                channel_mean_names=stepper.loss_names,
            )
        run_validation(
            train_stepper=train_stepper,
            validation_data=valid_data,
            aggregator=validation_aggregator,
            label="val",
            log_progress=True,
        )
        timer.stop_outer("validation")
        timer.start_outer("inference")

    with timer.context("initialization"):
        aggregator = aggregator_config.build(
            dataset_info=dataset_info,
            n_ic_steps=stepper_config.n_ic_timesteps,
            n_forward_steps=config.n_forward_steps,
            initial_time=initial_time,
            channel_mean_names=stepper.loss_names,
            normalize=stepper.normalizer.normalize,
            output_dir=config.experiment_dir,
            n_ensemble_per_ic=config.n_ensemble_per_ic,
        )
        writer = config.get_data_writer(
            initial_condition_times=data.initial_time.to_numpy(),
            timestep=data.timestep,
            variable_metadata=variable_metadata,
            coords=data.coords,
        )

    logging.info("Starting inference")
    logger = get_record_to_wandb(label="inference")
    if config.prediction_loader is not None:
        prediction_data = get_inference_data(
            config.prediction_loader,
            total_forward_steps=config.n_forward_steps,
            window_requirements=window_requirements,
            initial_condition=initial_condition_requirements,
        )
        _broadcast_initial_condition_ensemble(
            prediction_data, config.n_ensemble_per_ic
        )
        deriver = _Deriver(
            n_ic_timesteps=stepper_config.n_ic_timesteps,
            derive_func=stepper.derive_func,
        )
        run_dataset_comparison(
            aggregator=aggregator,
            prediction_data=prediction_data,
            target_data=data,
            deriver=deriver,
            writer=writer,
            record_logs=logger.log,
        )
    else:
        run_inference(
            predict=stepper.predict_paired,
            data=data,
            aggregator=aggregator,
            writer=writer,
            record_logs=logger.log,
        )

    with timer.context("final_writer_flush"):
        logging.info("Starting final flush of data writer")
        writer.finalize()
    logging.info("Writing reduced metrics to disk in netcdf format.")
    aggregator.flush_diagnostics()
    timer.stop_outer("inference")

    # After ensemble broadcasting there are n_initial_conditions *
    # n_ensemble_per_ic samples, each advanced n_forward_steps times; the
    # max(..., 1) mirrors the broadcast guard (no broadcast when <= 1).
    total_steps = (
        config.n_forward_steps
        * config.loader.n_initial_conditions
        * max(config.n_ensemble_per_ic, 1)
    )
    inference_duration = timer.get_duration("inference")
    wandb_logging_duration = timer.get_duration("inference/wandb_logging")
    total_steps_per_second = total_steps / (inference_duration - wandb_logging_duration)
    timer.log_durations()
    logging.info(
        "Total steps per second (ignoring wandb logging): "
        f"{total_steps_per_second:.2f} steps/second"
    )
    summary_logs = {
        "total_steps_per_second": total_steps_per_second,
        **aggregator.get_summary_logs(),
    }
    logger.log_to_current_step(summary_logs)  # prefix "inference/"
    logger.log_to_current_step(
        timer.get_durations(), label=""
    )  # durations already prefixed