Evaluator Config

The following is an example configuration for running inference while evaluating against target data. While you can use absolute paths in the config YAML files (we encourage it!), the example uses paths relative to the directory from which you run the command. The example assumes you are running in a directory structure like:

.
├── ckpt.tar
└── validation
    ├── data1.nc  # files can have any name, but must sort into time-sequential order
    ├── data2.nc  # can have any number of netCDF files
    └── ...

The .nc files correspond to data files like training_validation_data/training_validation/1940010100.nc in the ACE2-ERA5 Hugging Face page, while ckpt.tar corresponds to a file like ace2_era5_ckpt.tar in that repository.
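The requirement that files sort into time-sequential order can be checked before launching a run. The sketch below assumes files are ordered by a plain lexicographic sort of their names, which is an assumption and not confirmed on this page; the helper and the file names are hypothetical.

```python
# Assumption: files are ordered by a plain lexicographic sort of their names.
# The helper and all file names below are hypothetical.

def sorts_chronologically(names_in_time_order):
    """Return True if sorting the names reproduces the intended time order."""
    return sorted(names_in_time_order) == list(names_in_time_order)

# Zero-padded timestamps sort correctly.
padded = ["1940010100.nc", "1940020100.nc", "1940110100.nc"]
# Unpadded counters do not: "data10.nc" sorts before "data2.nc".
unpadded = ["data1.nc", "data2.nc", "data10.nc"]

print(sorts_chronologically(padded))    # True
print(sorts_chronologically(unpadded))  # False
```

Under that assumption, zero-padding any counter or timestamp in the file name is the simplest way to keep the order stable as the number of files grows.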

Example YAML Configuration
experiment_dir: evaluator_output
n_forward_steps: 400  # 100 days
forward_steps_in_memory: 50
checkpoint_path: ckpt.tar
logging:
  log_to_screen: true
  log_to_wandb: false
  log_to_file: true
  project: ace
  entity: your_wandb_entity
loader:
  dataset:
    data_path: validation
  start_indices:
    first: 0
    n_initial_conditions: 1
  num_data_workers: 4
data_writer:
  save_prediction_files: false
  save_monthly_files: false
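The "# 100 days" comment on n_forward_steps implies a 6-hour model timestep (400 steps x 6 h = 2400 h). The sketch below makes that arithmetic explicit and shows how many in-memory chunks result from forward_steps_in_memory. The 6-hour timestep is inferred from the comment, and both helper names are hypothetical.

```python
import math

TIMESTEP_HOURS = 6  # assumption inferred from "400 steps = 100 days"

def steps_for_days(days, timestep_hours=TIMESTEP_HOURS):
    """Number of forward steps needed to cover `days` of simulated time."""
    return days * 24 // timestep_hours

def n_windows(n_forward_steps, forward_steps_in_memory):
    """Number of in-memory chunks needed to cover the full rollout."""
    return math.ceil(n_forward_steps / forward_steps_in_memory)

n_forward_steps = steps_for_days(100)
print(n_forward_steps)                 # 400
print(n_windows(n_forward_steps, 50))  # 8
```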

We use the Builder pattern to load this configuration into a multi-level dataclass structure. The configuration is divided into several sub-configurations, each with its own dataclass. The top-level configuration is the fme.ace.InferenceEvaluatorConfig class.
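To illustrate what a multi-level dataclass structure means in practice, here is a minimal standard-library sketch that turns a nested dictionary (such as one parsed from YAML) into nested dataclasses. The classes LoggingSketch and EvaluatorSketch and the build function are hypothetical, simplified stand-ins; this is not how fme itself loads the configuration.

```python
from dataclasses import dataclass, field

# Hypothetical, simplified stand-ins for the real configuration classes.
@dataclass
class LoggingSketch:
    log_to_screen: bool = True
    log_to_wandb: bool = True

@dataclass
class EvaluatorSketch:
    experiment_dir: str
    n_forward_steps: int
    logging: LoggingSketch = field(default_factory=LoggingSketch)

def build(raw: dict) -> EvaluatorSketch:
    """Turn a nested dict (e.g. parsed from YAML) into nested dataclasses."""
    raw = dict(raw)  # shallow copy so the caller's dict is not mutated
    logging_cfg = LoggingSketch(**raw.pop("logging", {}))
    return EvaluatorSketch(logging=logging_cfg, **raw)

config = build({
    "experiment_dir": "evaluator_output",
    "n_forward_steps": 400,
    "logging": {"log_to_wandb": False},
})
print(config.logging.log_to_wandb)   # False
print(config.logging.log_to_screen)  # True (default retained)
```

Each sub-section of the YAML maps to one dataclass, so unspecified fields fall back to that dataclass's defaults.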

class fme.ace.InferenceEvaluatorConfig(experiment_dir, n_forward_steps, checkpoint_path, logging, loader, forward_steps_in_memory, prediction_loader=None, data_writer=<factory>, aggregator=<factory>, stepper_override=None, allow_incompatible_dataset=False, validation=None, n_ensemble_per_ic=1)[source]

Bases: object

Configuration for running inference including comparison to reference data.

Parameters:
  • experiment_dir (str) –

    Directory to save results to. This can be a local directory, like /results, or a remote directory prefixed with a protocol recognized by fsspec, like gs://bucket/results.

    Note

    While most types of output can be written to a remote experiment_dir, there are some limitations:

    • To write raw or time-coarsened data, the zarr writer must be used. See the files parameter of the fme.ace.DataWriterConfig for more details on how this can be configured. Note that monthly coarsened data cannot currently be written to zarr, and hence cannot be written to a remote directory, since it uses a different code path than uniformly coarsened data.

    • Piping logging output to a file in the experiment_dir is not supported. To silence the warning related to this, set log_to_file to False in the fme.ace.LoggingConfig.

    There are no restrictions on the types of output that can be written to a local experiment_dir.

  • n_forward_steps (int) – Number of steps to run the model forward for.

  • checkpoint_path (str) – Path to stepper checkpoint to load.

  • logging (LoggingConfig) – Configuration for logging.

  • loader (InferenceDataLoaderConfig) – Configuration for data to be used as initial conditions, forcing, and target in inference.

  • prediction_loader (Optional[InferenceDataLoaderConfig], default: None) – Configuration for prediction data to evaluate. If given, the model is not run; the provided predictions are evaluated instead. The model checkpoint is still used to determine inputs and outputs.

  • forward_steps_in_memory (int) – Number of forward steps to complete in memory at a time. One additional step is loaded for the initial condition.

  • data_writer (DataWriterConfig, default: <factory>) – Configuration for data writers.

  • aggregator (InferenceEvaluatorAggregatorConfig | LegacyFlagInferenceEvaluatorAggregatorConfig, default: <factory>) – Configuration for inference evaluator aggregator.

  • stepper_override (Optional[StepperOverrideConfig], default: None) – Configuration for overriding select stepper configuration options at inference time (optional).

  • allow_incompatible_dataset (bool, default: False) – If True, allow the forcing dataset used for inference to be incompatible with the dataset used for stepper training. This should be used with caution, as it may allow the stepper to make scientifically invalid predictions, but it can allow running inference with incorrectly formatted or missing grid information.

  • validation (Optional[ValidationConfig], default: None) – Optional configuration for running a one-step validation loop before inference. When provided, validation runs first and produces metrics prefixed with val/ (e.g. val/mean/weighted_rmse), mirroring the validation done at the end of each training epoch.

  • n_ensemble_per_ic (int, default: 1) – Number of ensemble members per initial condition. Useful for weather inference with a stochastic model. The default, n_ensemble_per_ic = 1, gives the standard single-member inference behavior.
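The limitations listed in the note on experiment_dir can be expressed as a small pre-flight check. The sketch below is hypothetical: it treats any path containing "://" as remote (a simplification of what fsspec recognizes), and it assumes that save_monthly_files is the option that produces the monthly coarsened output mentioned in the note.

```python
def remote_dir_problems(experiment_dir, log_to_file, save_monthly_files):
    """Hypothetical pre-flight check for a remote experiment_dir."""
    # Simplification: treat any "protocol://" prefix as remote.
    is_remote = "://" in experiment_dir
    problems = []
    if is_remote and log_to_file:
        problems.append("log_to_file is not supported for a remote experiment_dir")
    if is_remote and save_monthly_files:
        # Assumption: save_monthly_files produces the monthly coarsened output.
        problems.append("monthly data cannot be written to a remote experiment_dir")
    return problems

print(remote_dir_problems("/results", True, True))                  # []
print(len(remote_dir_problems("gs://bucket/results", True, True)))  # 2
```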

The sub-configurations are:

class fme.ace.LoggingConfig(project='ace', entity='ai2cm', log_to_screen=True, log_to_file=True, log_to_wandb=True, metrics_log_dir=None, log_format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=20, wandb_dir_in_experiment_dir=False)[source]

Bases: object

Configuration for logging.

Parameters:
  • project (str, default: 'ace') – Name of the project in Weights & Biases.

  • entity (str, default: 'ai2cm') – Name of the entity in Weights & Biases.

  • log_to_screen (bool, default: True) – Whether to log to the screen.

  • log_to_file (bool, default: True) – Whether to log to a file.

  • log_to_wandb (bool, default: True) – Whether to log to Weights & Biases.

  • metrics_log_dir (Optional[str], default: None) – Directory to write scalar metrics to disk as JSONL. If None, disk metric logging is disabled.

  • log_format (str, default: '%(asctime)s - %(name)s - %(levelname)s - %(message)s') – Format of the log messages.

  • level (str | int, default: 20) – Sets the logging level. The default, 20, corresponds to logging.INFO in Python's standard logging module.

  • wandb_dir_in_experiment_dir (bool, default: False) – Whether to create the wandb_dir in the experiment_dir or in local /tmp (default False).
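The default log_format and level values are standard Python logging settings. As a rough illustration of their effect, the sketch below applies them with the standard library directly; the logger name is hypothetical, and this is not a claim about how LoggingConfig configures logging internally.

```python
import logging

LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
LEVEL = 20  # numeric value of logging.INFO

logging.basicConfig(format=LOG_FORMAT, level=LEVEL)
logger = logging.getLogger("evaluator_sketch")  # hypothetical logger name
logger.info("shown: INFO is at the configured level")
logger.debug("hidden: DEBUG (10) is below the configured level")

print(logging.getLevelName(LEVEL))  # INFO
```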

class fme.ace.InferenceDataLoaderConfig(dataset, start_indices, num_data_workers=0, perturbations=None, persistence_names=None)[source]

Bases: object

Configuration for inference data.

This is like the DataLoaderConfig class, but with some additional constraints. During inference, we have only one batch, so the number of samples directly determines the size of that batch.

Parameters:
  • dataset (XarrayDataConfig | MergeNoConcatDatasetConfig) – Configuration to define the dataset.

  • start_indices (InferenceInitialConditionIndices | ExplicitIndices | TimestampList) – Configuration of the indices for initial conditions during inference. This can be a list of timestamps, a list of integer indices, or a slice configuration of the integer indices. Values following the initial condition will still come from the full dataset.

  • num_data_workers (int, default: 0) – Number of parallel workers to use for data loading.

  • perturbations (Optional[SSTPerturbation], default: None) – Configuration for SST perturbations.

  • persistence_names (Optional[Sequence[str]], default: None) – Names of variables for which all returned values will be the same as the initial condition. When evaluating initial condition predictability, set this to forcing variables that should not be updated during inference (e.g. surface temperature).
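The effect described for persistence_names can be pictured as follows. This is a hypothetical illustration on plain Python lists, not fme code; the function name, variable names, and values are made up.

```python
def apply_persistence(series, persistence_names):
    """Hold the named variables at their initial-condition value for all times.

    `series` maps variable name -> list of values over time, where index 0
    is the initial condition. Hypothetical illustration only.
    """
    out = {}
    for name, values in series.items():
        if name in persistence_names:
            out[name] = [values[0]] * len(values)
        else:
            out[name] = list(values)
    return out

forcing = {
    "surface_temperature": [288.0, 289.5, 290.1],
    "insolation": [0.0, 400.0, 800.0],
}
held = apply_persistence(forcing, ["surface_temperature"])
print(held["surface_temperature"])  # [288.0, 288.0, 288.0]
print(held["insolation"])           # [0.0, 400.0, 800.0]
```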

class fme.ace.InferenceInitialConditionIndices(n_initial_conditions, first=0, interval=1)[source]

Bases: object

Configuration of the indices for initial conditions during inference.

Parameters:
  • n_initial_conditions (int) – Number of initial conditions to use.

  • first (int, default: 0) – Index of the first initial condition.

  • interval (int, default: 1) – Interval between initial conditions.
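A plausible reading of these three parameters is an evenly spaced sequence of start indices. The helper below is a sketch under that assumption and is not taken from the fme source.

```python
def initial_condition_indices(n_initial_conditions, first=0, interval=1):
    """Evenly spaced start indices (assumed meaning of the three parameters)."""
    return [first + i * interval for i in range(n_initial_conditions)]

print(initial_condition_indices(1))                        # [0]
print(initial_condition_indices(4, first=10, interval=5))  # [10, 15, 20, 25]
```

With the example YAML at the top of this page (first: 0, n_initial_conditions: 1), this reading gives a single initial condition at index 0.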

class fme.ace.ExplicitIndices(list)[source]

Bases: object

Configure indices by providing them explicitly.

Parameters:
  • list (Sequence[int]) – List of integer indices.

class fme.ace.TimestampList(times, timestamp_format='%Y-%m-%dT%H:%M:%S')[source]

Bases: object

Configuration for a list of timestamps.

Parameters:
  • times (Sequence[str]) – List of timestamps.

  • timestamp_format (str, default: '%Y-%m-%dT%H:%M:%S') – Format of the timestamps.
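The default timestamp_format uses standard strptime directives. The sketch below checks a couple of hypothetical timestamps against it using Python's datetime module; whether fme parses with datetime or with a calendar-aware library is not stated on this page.

```python
from datetime import datetime

TIMESTAMP_FORMAT = "%Y-%m-%dT%H:%M:%S"  # the documented default

times = ["2020-01-01T00:00:00", "2020-01-01T06:00:00"]  # hypothetical values
parsed = [datetime.strptime(t, TIMESTAMP_FORMAT) for t in times]
print(parsed[1] - parsed[0])  # 6:00:00
```

A string that does not match the format, such as "2020-01-01 00:00", makes strptime raise ValueError, which is a quick way to validate a list before using it.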

class fme.ace.XarrayDataConfig(data_path, file_pattern='*.nc', n_repeats=1, engine='netcdf4', spatial_dimensions='latlon', subset=<factory>, infer_timestep=True, dtype='float32', overwrite=<factory>, fill_nans=None, isel=<factory>, labels=None)[source]

Bases: DatasetConfigABC

Parameters:
  • data_path (str) – Path to the data.

  • file_pattern (str, default: '*.nc') – Glob pattern to match files in the data_path.

  • n_repeats (int, default: 1) – Number of times to repeat the dataset (in time). It is up to the user to ensure that repeating the input dataset yields data that is reasonably continuous across repetitions.

  • engine (Literal['netcdf4', 'h5netcdf', 'zarr'], default: 'netcdf4') – Backend used in xarray.open_dataset call.

  • spatial_dimensions (Literal['healpix', 'latlon'], default: 'latlon') – Specifies the spatial dimensions for the grid, default is lat/lon. If ‘latlon’, it is assumed that the last two dimensions are latitude and longitude, respectively. If ‘healpix’, it is assumed that the last three dimensions are face, height, and width, respectively.

  • subset (Slice | TimeSlice | RepeatedInterval, default: <factory>) – Slice defining a subset of the XarrayDataset to load. This can either be a Slice of integer indices or a TimeSlice of timestamps. This feature is applied directly to the dataset samples. For example, if the file(s) have the time coordinate (t0, t1, t2, t3) and requirements.n_timesteps=2, then subset=Slice(stop=2) will provide two samples: (t0, t1), (t1, t2).

  • infer_timestep (bool, default: True) – Whether to infer the timestep from the provided data. This should be set to True (the default) for ACE training. It may be useful to toggle this to False for applications like downscaling, which do not depend on the timestep of the data and therefore lack the additional requirement that the data be ordered and evenly spaced in time. It must be set to True if n_repeats > 1 in order to be able to infer the full time coordinate.

  • dtype (Optional[str], default: 'float32') – Data type to cast the data to. If None, no casting is done. It is required that ‘torch.{dtype}’ is a valid dtype.

  • overwrite (OverwriteConfig, default: <factory>) – Optional OverwriteConfig to overwrite loaded field values.

  • fill_nans (Optional[FillNaNsConfig], default: None) – Optional FillNaNsConfig to fill NaNs with a constant value.

  • isel (Mapping[str, Slice | int], default: <factory>) – Optional xarray isel arguments to be passed to the dataset. Will raise ValueError if time is included here, since the subset argument is used specifically for selecting times. Horizontal dimensions are also not currently supported.

  • labels (Optional[list[str]], default: None) – Optional list of labels to be returned with the data.
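The subset example above (four times, n_timesteps=2, Slice(stop=2)) can be reproduced with plain Python lists. The samples helper is hypothetical and only mimics the described behavior: windows are formed first, and the slice is then applied to the samples rather than to the raw times.

```python
def samples(times, n_timesteps, stop=None):
    """Sliding windows of length n_timesteps, then sliced like Slice(stop=stop)."""
    windows = [
        tuple(times[i : i + n_timesteps])
        for i in range(len(times) - n_timesteps + 1)
    ]
    return windows[:stop]

print(samples(["t0", "t1", "t2", "t3"], n_timesteps=2, stop=2))
# [('t0', 't1'), ('t1', 't2')]
```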

Examples

If data is stored in a directory with multiple netCDF files which can be concatenated along the time dimension, use:

>>> fme.ace.XarrayDataConfig(data_path="/some/directory", file_pattern="*.nc") 

If data is stored in a single zarr store at /some/directory/dataset.zarr, use:

>>> fme.ace.XarrayDataConfig(
...     data_path="/some/directory",
...     file_pattern="dataset.zarr",
...     engine="zarr"
... )

class fme.ace.DataWriterConfig(save_prediction_files=True, save_monthly_files=True, names=None, time_coarsen=None, files=None)[source]

Bases: object

Configuration for inference data writers.

Parameters:
  • save_prediction_files (bool, default: True) – Whether to enable writing of netCDF files containing the predictions and target values.

  • save_monthly_files (bool, default: True) – Whether to enable writing of netCDF files containing the monthly predictions and target values.

  • names (Optional[Sequence[str]], default: None) – Names of variables to save in the prediction and monthly netCDF files.

  • time_coarsen (Optional[TimeCoarsenConfig], default: None) – Configuration for time coarsening of written outputs to the raw data writer.

  • files (Optional[list[FileWriterConfig]], default: None) – Configuration for a sequence of individual data writers. Each data writer must have a unique label to avoid filename collisions.

class fme.ace.InferenceEvaluatorAggregatorConfig(mean_denorm=<factory>, mean_norm=<factory>, step_means=<factory>, ensembles=<factory>, power_spectrum=<factory>, zonal_mean=<factory>, time_mean_denorm=<factory>, time_mean_norm=<factory>, video=<factory>, histogram=<factory>, seasonal=<factory>, annual=<factory>, enso_index=<factory>, enso_coefficient=<factory>, ipo_index=<factory>, monthly_reference_data=None, time_mean_reference_data=None)[source]

Bases: object

Configuration for inference evaluator aggregator.

Each metric is a named field with its own typed configuration and an enabled flag. Defaults match the standard metric set: metrics that are always desired are enabled, while optional ones (histogram, video, seasonal) are disabled.

Metrics whose runtime requirements are not met (e.g. enso_index on a non-lat/lon grid) are skipped with a warning when strict is False (the default for built-in metrics), or raise an error when strict is True (the default for user-enabled metrics like histogram, video, seasonal).
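The skip-or-raise behavior controlled by strict can be sketched as follows. The function build_metric and the ValueError exception type are hypothetical; only the warn-versus-raise distinction comes from the description above.

```python
import warnings

def build_metric(name, requirements_met, strict):
    """Hypothetical sketch of the skip-or-raise behavior controlled by strict."""
    if requirements_met:
        return f"{name} built"
    if strict:
        # Assumption: the real exception type may differ.
        raise ValueError(f"metric {name!r} cannot be built")
    warnings.warn(f"skipping metric {name!r}: requirements not met")
    return None

# Non-strict: the metric is skipped and a warning is emitted.
print(build_metric("enso_index", requirements_met=False, strict=False))  # None

# Strict: the same situation raises instead.
try:
    build_metric("histogram", requirements_met=False, strict=True)
except ValueError as err:
    print(err)  # metric 'histogram' cannot be built
```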

Parameters:
  • mean_denorm (MeanMetricConfig, default: <factory>) – Global-mean time-series metrics on denormalized data.

  • mean_norm (MeanMetricConfig, default: <factory>) – Global-mean time-series metrics on normalized data.

  • step_means (list[StepMeanMetricConfig], default: <factory>) – Per-step snapshot metrics. Defaults to step-20 denorm and norm.

  • ensembles (list[EnsembleMetricConfig], default: <factory>) – Ensemble spread metrics. Defaults to step-20. Silently skipped when n_ensemble <= 1.

  • power_spectrum (PowerSpectrumMetricConfig, default: <factory>) – Spherical power spectrum metrics.

  • zonal_mean (ZonalMeanMetricConfig, default: <factory>) – Zonal-mean image metrics.

  • time_mean_denorm (TimeMeanMetricConfig, default: <factory>) – Time-mean metrics on denormalized data.

  • time_mean_norm (TimeMeanMetricConfig, default: <factory>) – Time-mean metrics on normalized data.

  • video (VideoMetricConfig, default: <factory>) – Video (animated map) metrics. Disabled by default.

  • histogram (HistogramMetricConfig, default: <factory>) – Distribution histogram metrics. Disabled by default.

  • seasonal (SeasonalMetricConfig, default: <factory>) – Seasonal-mean metrics. Disabled by default.

  • annual (AnnualMetricConfig, default: <factory>) – Annual-mean metrics.

  • enso_index (EnsoIndexMetricConfig, default: <factory>) – ENSO index metrics.

  • enso_coefficient (EnsoCoefficientMetricConfig, default: <factory>) – ENSO regression coefficient metrics.

  • ipo_index (IpoIndexMetricConfig, default: <factory>) – Interdecadal Pacific Oscillation index metrics.

  • monthly_reference_data (Optional[str], default: None) – Path to monthly reference data to compare against.

  • time_mean_reference_data (Optional[str], default: None) – Path to reference time means to compare against.

Default aggregator configuration

The default aggregator configuration (InferenceEvaluatorAggregatorConfig with no arguments) produces the following YAML. You only need to include fields you want to override.

Default Aggregator Configuration
aggregator:
  mean_denorm:
    variables: null
    name: mean
    target: denorm
    enabled: true
    strict: false
  mean_norm:
    variables: null
    name: mean_norm
    target: norm
    enabled: true
    strict: false
  step_means:
  - step: 20
    variables: null
    name: mean_step_20
    target: denorm
    channel_mean_names: null
    enabled: true
    strict: false
  - step: 20
    variables: null
    name: mean_step_20_norm
    target: norm
    channel_mean_names: null
    enabled: true
    strict: false
  ensembles:
  - step: 20
    name: ensemble_step_20
    log_mean_maps: false
    enabled: true
    strict: false
    target: denorm
    channel_mean_names: null
  power_spectrum:
    variables: null
    name: power_spectrum
    enabled: true
    strict: false
    report_directional_bias: true
    plot_variables: null
  zonal_mean:
    variables: null
    name: zonal_mean
    zonal_mean_max_size: 4096
    enabled: true
    strict: false
  time_mean_denorm:
    variables: null
    name: time_mean
    target: denorm
    reference_data: null
    channel_mean_names: null
    enabled: true
    strict: false
  time_mean_norm:
    variables: null
    name: time_mean_norm
    target: norm
    reference_data: null
    channel_mean_names: null
    enabled: true
    strict: false
  video:
    variables: null
    name: video
    enable_extended_videos: false
    enabled: false
    strict: true
  histogram:
    variables: null
    name: histogram
    enabled: false
    strict: true
    percentile_variables: null
  seasonal:
    variables: null
    name: seasonal
    enabled: false
    strict: true
  annual:
    variables: null
    name: annual
    reference_data: null
    enabled: true
    strict: false
  enso_index:
    name: enso_index
    enabled: true
    strict: false
  enso_coefficient:
    name: enso_coefficient
    enabled: true
    strict: false
  ipo_index:
    name: ipo_index
    enabled: true
    strict: false
  monthly_reference_data: null
  time_mean_reference_data: null

Metric configurations

The named fields on InferenceEvaluatorAggregatorConfig accept the following typed entries. Each entry has an enabled field that can be set to true or false to enable or disable a metric. Most metrics default to enabled: true; however, video, histogram, and seasonal default to enabled: false as shown in the default aggregator configuration above.
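Because only overridden fields need to appear in the YAML, the effective configuration is the defaults with the user's entries laid on top. The sketch below shows that idea with plain dictionaries. It is an illustration rather than fme's loading mechanism, and the variable name PRATEsfc is used only as a hypothetical example.

```python
import copy

# A small excerpt of the default aggregator configuration.
DEFAULTS = {
    "histogram": {
        "variables": None,
        "name": "histogram",
        "enabled": False,
        "strict": True,
        "percentile_variables": None,
    },
    "video": {
        "variables": None,
        "name": "video",
        "enable_extended_videos": False,
        "enabled": False,
        "strict": True,
    },
}

def with_overrides(defaults, overrides):
    """Recursively overlay user-supplied fields on top of the defaults."""
    merged = copy.deepcopy(defaults)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = with_overrides(merged[key], value)
        else:
            merged[key] = value
    return merged

# Hypothetical user config: enable histograms, limit percentile metrics.
user = {"histogram": {"enabled": True, "percentile_variables": ["PRATEsfc"]}}
effective = with_overrides(DEFAULTS, user)
print(effective["histogram"]["enabled"])  # True
print(effective["histogram"]["strict"])   # True (default retained)
print(effective["video"]["enabled"])      # False (untouched)
```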

class fme.ace.MeanMetricConfig(variables=None, name=None, target='denorm', enabled=True, strict=False)[source]
Parameters:
  • variables (list[str] | None) –

  • name (str | None) –

  • target (Literal['denorm', 'norm']) –

  • enabled (bool) –

  • strict (bool) –

class fme.ace.StepMeanMetricConfig(step, variables=None, name=None, target='denorm', channel_mean_names=None, enabled=True, strict=False)[source]
Parameters:
  • step (int) –

  • variables (list[str] | None) –

  • name (str | None) –

  • target (Literal['denorm', 'norm']) –

  • channel_mean_names (list[str] | None) –

  • enabled (bool) –

  • strict (bool) –

class fme.ace.PowerSpectrumMetricConfig(variables=None, name='power_spectrum', enabled=True, strict=False, report_directional_bias=True, plot_variables=None)[source]
Parameters:
  • variables (Optional[list[str]], default: None) – when set, filter the aggregator to these variables only (affects every output — scalar metrics and the per-variable spectrum-pair plots).

  • name (str, default: 'power_spectrum') – log prefix and wandb key prefix.

  • enabled (bool, default: True) – master toggle for the metric.

  • strict (bool, default: False) – raise if the metric can’t be built (e.g. wrong grid).

  • report_directional_bias (bool, default: True) – when False, drop the positive_norm_bias and negative_norm_bias scalar metrics. mean_abs_norm_bias is unaffected (and is the directional pair’s redundant summary). Defaults to True for backwards compatibility.

  • plot_variables (Optional[list[str]], default: None) – when set, restrict the per-variable spectrum-pair plot to these variable names — scalar metrics are still emitted for every variable that passed variables. Use to keep the cheap scalar comparisons cohort-wide while limiting the expensive per-variable plot output to a small reference list. Defaults to None (plot everything that passed variables, current behaviour).
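The interaction between variables and plot_variables amounts to two nested filters. The helper below is a hypothetical sketch of that selection logic, with made-up variable names.

```python
def select(all_variables, variables=None, plot_variables=None):
    """Return (variables with scalar metrics, variables that also get a plot)."""
    kept = [v for v in all_variables if variables is None or v in variables]
    plotted = [v for v in kept if plot_variables is None or v in plot_variables]
    return kept, plotted

names = ["a", "b", "c"]  # hypothetical variable names
print(select(names))
# (['a', 'b', 'c'], ['a', 'b', 'c'])
print(select(names, variables=["a", "b"], plot_variables=["a"]))
# (['a', 'b'], ['a'])
```

A name listed in plot_variables but excluded by variables is never plotted in this sketch, since the plot filter only sees variables that passed the first filter.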

class fme.ace.ZonalMeanMetricConfig(variables=None, name='zonal_mean', zonal_mean_max_size=4096, enabled=True, strict=False)[source]
Parameters:
  • variables (list[str] | None) –

  • name (str) –

  • zonal_mean_max_size (int) –

  • enabled (bool) –

  • strict (bool) –

class fme.ace.TimeMeanMetricConfig(variables=None, name=None, target='denorm', reference_data=None, channel_mean_names=None, enabled=True, strict=False)[source]
Parameters:
  • variables (list[str] | None) –

  • name (str | None) –

  • target (Literal['denorm', 'norm']) –

  • reference_data (str | None) –

  • channel_mean_names (list[str] | None) –

  • enabled (bool) –

  • strict (bool) –

class fme.ace.HistogramMetricConfig(variables=None, name='histogram', enabled=False, strict=True, percentile_variables=None)[source]
Parameters:
  • variables (Optional[list[str]], default: None) – when set, filter the aggregator to these variables only — no histogram (plot or percentile) is emitted for variables not in this list.

  • name (str, default: 'histogram') – log prefix and wandb key prefix.

  • enabled (bool, default: False) – master toggle for the metric.

  • strict (bool, default: True) – raise if the metric can’t be built.

  • percentile_variables (Optional[list[str]], default: None) – when set, only these variables get the 99.9999th-percentile (and any other configured percentile) scalar metrics emitted. The histogram plot is still emitted for every variable that passed variables. Defaults to None (emit percentile keys for every variable that passed variables — current behaviour). Use to restrict the noisy tail-percentile keys to a small list (e.g. precipitation only) while keeping the histogram plot cohort-wide.

class fme.ace.VideoMetricConfig(variables=None, name='video', enable_extended_videos=False, enabled=False, strict=True)[source]
Parameters:
  • variables (list[str] | None) –

  • name (str) –

  • enable_extended_videos (bool) –

  • enabled (bool) –

  • strict (bool) –

class fme.ace.SeasonalMetricConfig(variables=None, name='seasonal', enabled=False, strict=True)[source]
Parameters:
  • variables (list[str] | None) –

  • name (str) –

  • enabled (bool) –

  • strict (bool) –

class fme.ace.AnnualMetricConfig(variables=None, name='annual', reference_data=None, enabled=True, strict=False)[source]
Parameters:
  • variables (list[str] | None) –

  • name (str) –

  • reference_data (str | None) –

  • enabled (bool) –

  • strict (bool) –

class fme.ace.EnsoIndexMetricConfig(name='enso_index', enabled=True, strict=False)[source]
Parameters:
  • name (str) –

  • enabled (bool) –

  • strict (bool) –

class fme.ace.EnsoCoefficientMetricConfig(name='enso_coefficient', enabled=True, strict=False)[source]
Parameters:
  • name (str) –

  • enabled (bool) –

  • strict (bool) –

class fme.ace.IpoIndexMetricConfig(name='ipo_index', enabled=True, strict=False)[source]
Parameters:
  • name (str) –

  • enabled (bool) –

  • strict (bool) –

class fme.ace.EnsembleMetricConfig(step=20, name=None, log_mean_maps=False, enabled=True, strict=False, target='denorm', channel_mean_names=None)[source]

Configuration for an ensemble metric (CRPS, SSR bias, ensemble-mean RMSE) at a specific forward step.

Parameters:
  • step (int) – Forward step at which to compute the metric.

  • name (str | None) – Name to use for the logged metric. Defaults to ensemble_step_{step} for target="denorm" and ensemble_step_{step}_norm for target="norm".

  • log_mean_maps (bool) – Whether to log per-variable mean maps.

  • enabled (bool) – Whether the metric is enabled.

  • strict (bool) – Whether to raise if the metric cannot be built.

  • target (Literal['norm', 'denorm']) – Whether to compute metrics on normalized (“norm”) or denormalized (“denorm”) data. channel_mean is only logged when target="norm", since averaging metrics across variables with different physical units is not meaningful.

  • channel_mean_names (list[str] | None) – Names of variables to include in the channel-mean metric. If None, falls back to the aggregator-level value passed via the build context, and finally to all variables present in the data when that is also None. Names not present in the data raise KeyError. Ignored when target="denorm".
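The default naming rule documented for the name field of EnsembleMetricConfig can be written out as a small function. The function name is hypothetical; the naming pattern itself is the one stated in the field description.

```python
def default_ensemble_metric_name(step, target="denorm"):
    """Default metric name, following the rule documented for the name field."""
    suffix = "_norm" if target == "norm" else ""
    return f"ensemble_step_{step}{suffix}"

print(default_ensemble_metric_name(20))          # ensemble_step_20
print(default_ensemble_metric_name(20, "norm"))  # ensemble_step_20_norm
```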

class fme.ace.StepperOverrideConfig(ocean='keep', multi_call='keep', derived_forcings='keep', prescribed_prognostic_names='keep')[source]

Bases: object

Configuration for overriding stepper configuration options.

The default value for each parameter is "keep", which denotes that the serialized stepper’s configuration will not be modified when loaded. Passing other values will override the configuration of the loaded stepper.

Parameters:
  • ocean (Union[Literal['keep'], OceanConfig, None], default: 'keep') – Ocean configuration to override that used in producing a serialized stepper.

  • multi_call (Union[Literal['keep'], MultiCallConfig, None], default: 'keep') – MultiCall configuration to override that used in producing a serialized stepper.

  • derived_forcings (Union[Literal['keep'], DerivedForcingsConfig], default: 'keep') – Derived forcings configuration to override that used in producing a serialized stepper.

  • prescribed_prognostic_names (Union[Literal['keep'], list[str]], default: 'keep') – List of prognostic variable names to overwrite from forcing at each step during inference.
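The "keep" default acts as a sentinel value: it lets None remain a legitimate override for the parameters whose type allows it (ocean and multi_call). The sketch below illustrates the pattern with a hypothetical resolve function and placeholder values; it is not the fme implementation.

```python
KEEP = "keep"

def resolve(serialized_value, override=KEEP):
    """Return the serialized value unless an override other than "keep" is given.

    Using a string sentinel instead of None means that None itself can be
    passed as an override where the parameter type allows it.
    """
    return serialized_value if override == KEEP else override

print(resolve("serialized ocean config"))        # serialized ocean config
print(resolve("serialized ocean config", None))  # None
print(resolve(["a"], ["b", "c"]))                # ['b', 'c']
```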