Training Config

The following is an example configuration for running training while evaluating against target data. While you can use absolute paths in the config YAML files (we encourage it!), the example uses paths relative to the directory from which you run the command. The example is based on training with our full dataset (containing data from 10 ensemble runs) and assumes you are running in a directory structure like:

.
├── ckpt.tar
├── validation
│   ├── data1.nc  # files can have any name, but must sort into time-sequential order
│   ├── data2.nc  # can have any number of netCDF files
│   └── ...
└── traindata
    ├── ic_0001
    │   ├── data1.nc  # files can have any name, but must sort into time-sequential order
    │   ├── data2.nc  # can have any number of netCDF files
    │   └── ...
    ├── ic_0002
    │   └── ...
    ├── ...
    └── ic_0010
        └── ...

You can modify the example to run on fewer ensemble members by removing entries, or change the data paths as you wish. The .nc files correspond to data files like 2021010100.nc in the Zenodo repository, while ckpt.tar corresponds to a file like ace_ckpt.tar in that repository.
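
Because the ensemble directories follow a regular naming pattern, the train_loader dataset entries can be generated rather than typed by hand. The following is a minimal sketch, assuming the ic_NNNN directory names from the tree above; the helper name ensemble_datasets is hypothetical and not part of fme.

```python
def ensemble_datasets(n_members: int, root: str = "traindata") -> list[dict]:
    """Build one dataset entry per ensemble member, numbered from 1.

    Hypothetical helper: it only reproduces the path pattern used in the
    example directory tree (ic_0001, ic_0002, ...).
    """
    return [{"data_path": f"{root}/ic_{i:04d}"} for i in range(1, n_members + 1)]


# Full dataset: ten ensemble members.
datasets = ensemble_datasets(10)

# A reduced run on the first three members only.
small = ensemble_datasets(3)

print(datasets[0]["data_path"], "...", datasets[-1]["data_path"])
```

The resulting list has the same shape as the dataset sequence under train_loader, so it could be written into the YAML with any YAML library before launching training.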

Example YAML Configuration
experiment_dir: train_output
save_checkpoint: true
validate_using_ema: true
max_epochs: 80
n_forward_steps: 1
inference:
  n_forward_steps: 7300  # 5 years
  forward_steps_in_memory: 50
  loader:
    start_indices:
      first: 0
      n_initial_conditions: 4
      interval: 1460  # 1 year
    dataset:
      data_path: validation
    num_data_workers: 2
logging:
  log_to_screen: true
  log_to_wandb: false
  log_to_file: true
  project: fourcastnet
  entity: ai2cm
train_loader:
  batch_size: 4
  num_data_workers: 8
  dataset:
    - data_path: traindata/ic_0001
    - data_path: traindata/ic_0002
    - data_path: traindata/ic_0003
    - data_path: traindata/ic_0004
    - data_path: traindata/ic_0005
    - data_path: traindata/ic_0006
    - data_path: traindata/ic_0007
    - data_path: traindata/ic_0008
    - data_path: traindata/ic_0009
    - data_path: traindata/ic_0010
validation_loader:
  batch_size: 16
  num_data_workers: 2
  dataset:
    - data_path: validation
      subset:
        step: 5
optimization:
  enable_automatic_mixed_precision: false
  lr: 0.0001
  optimizer_type: Adam  # can switch to FusedAdam if using GPU
stepper:
  builder:
    type: SphericalFourierNeuralOperatorNet
    config:
      embed_dim: 384
      filter_type: linear
      hard_thresholding_fraction: 1.0
      use_mlp: true
      normalization_layer: instance_norm
      num_layers: 8
      operator_type: dhconv
      scale_factor: 1
      separable: false
  loss:
    type: MSE
  normalization:
    global_means_path: centering.nc
    global_stds_path: scaling.nc
  ocean:
    surface_temperature_name: surface_temperature
    ocean_fraction_name: ocean_fraction
  corrector:
    conserve_dry_air: true
    moisture_budget_correction: advection_and_precipitation
  in_names:
  - land_fraction
  - ocean_fraction
  - sea_ice_fraction
  - DSWRFtoa
  - HGTsfc
  - PRESsfc
  - surface_temperature
  - air_temperature_0 # _0 denotes the top most layer of the atmosphere
  - air_temperature_1
  - air_temperature_2
  - air_temperature_3
  - air_temperature_4
  - air_temperature_5
  - air_temperature_6
  - air_temperature_7
  - specific_total_water_0
  - specific_total_water_1
  - specific_total_water_2
  - specific_total_water_3
  - specific_total_water_4
  - specific_total_water_5
  - specific_total_water_6
  - specific_total_water_7
  - eastward_wind_0
  - eastward_wind_1
  - eastward_wind_2
  - eastward_wind_3
  - eastward_wind_4
  - eastward_wind_5
  - eastward_wind_6
  - eastward_wind_7
  - northward_wind_0
  - northward_wind_1
  - northward_wind_2
  - northward_wind_3
  - northward_wind_4
  - northward_wind_5
  - northward_wind_6
  - northward_wind_7
  out_names:
  - PRESsfc
  - surface_temperature
  - air_temperature_0
  - air_temperature_1
  - air_temperature_2
  - air_temperature_3
  - air_temperature_4
  - air_temperature_5
  - air_temperature_6
  - air_temperature_7
  - specific_total_water_0
  - specific_total_water_1
  - specific_total_water_2
  - specific_total_water_3
  - specific_total_water_4
  - specific_total_water_5
  - specific_total_water_6
  - specific_total_water_7
  - eastward_wind_0
  - eastward_wind_1
  - eastward_wind_2
  - eastward_wind_3
  - eastward_wind_4
  - eastward_wind_5
  - eastward_wind_6
  - eastward_wind_7
  - northward_wind_0
  - northward_wind_1
  - northward_wind_2
  - northward_wind_3
  - northward_wind_4
  - northward_wind_5
  - northward_wind_6
  - northward_wind_7
  - LHTFLsfc
  - SHTFLsfc
  - PRATEsfc
  - ULWRFsfc
  - ULWRFtoa
  - DLWRFsfc
  - DSWRFsfc
  - USWRFsfc
  - USWRFtoa
  - tendency_of_total_water_path_due_to_advection
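
The comments on the inference settings imply a 6-hourly time step, since 1460 steps span one 365-day year. Under that assumption (it is inferred from the comments, not stated by the configuration), the sketch below works out where each inline-inference run starts and how long it lasts. The variable names mirror the YAML keys; the expansion of start_indices into a list is illustrative and may not match how fme does it internally.

```python
HOURS_PER_STEP = 6  # assumption inferred from the "1460  # 1 year" comment

# Values copied from the inference section of the example configuration.
first = 0
n_initial_conditions = 4
interval = 1460
n_forward_steps = 7300

# One start index per initial condition, spaced `interval` steps apart.
start_indices = [first + i * interval for i in range(n_initial_conditions)]

# Length of each inference run in days and in 365-day years.
run_length_days = n_forward_steps * HOURS_PER_STEP / 24
run_length_years = run_length_days / 365

# Index of the last step reached by the final run, if each run must fit
# entirely inside the validation data (an assumption of this sketch).
last_index_needed = start_indices[-1] + n_forward_steps

print(start_indices, run_length_years, last_index_needed)
```

With these values the four runs start one year apart and each covers five years, so under the stated assumptions the validation data would need to span about eight years.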

We use the Builder pattern to load this configuration into a multi-level dataclass structure. The configuration is divided into several sub-configurations, each with its own dataclass. The top-level configuration is the fme.ace.TrainConfig class.
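
To illustrate what loading a nested configuration into dataclasses involves, here is a minimal standard-library sketch. It is not the fme loader: the classes TrainSketch and LoggingSketch and the function build are hypothetical stand-ins that cover only a few of the fields documented below.

```python
import dataclasses
from dataclasses import dataclass, field


@dataclass
class LoggingSketch:
    # Hypothetical stand-in for a logging sub-configuration.
    log_to_screen: bool = True
    log_to_wandb: bool = True


@dataclass
class TrainSketch:
    # Hypothetical stand-in for a top-level training configuration.
    experiment_dir: str
    max_epochs: int
    logging: LoggingSketch = field(default_factory=LoggingSketch)


def build(cls, data: dict):
    """Recursively construct dataclass `cls` from a nested dict."""
    known = {f.name: f for f in dataclasses.fields(cls)}
    unknown = set(data) - set(known)
    if unknown:
        raise ValueError(f"unknown keys for {cls.__name__}: {sorted(unknown)}")
    kwargs = {}
    for name, value in data.items():
        field_type = known[name].type
        # Nested dicts become nested dataclasses; other values pass through.
        if dataclasses.is_dataclass(field_type) and isinstance(value, dict):
            value = build(field_type, value)
        kwargs[name] = value
    return cls(**kwargs)


# A dict like the one a YAML parser would return for part of the example.
raw = {
    "experiment_dir": "train_output",
    "max_epochs": 80,
    "logging": {"log_to_wandb": False},
}
config = build(TrainSketch, raw)
print(config)
```

Keys omitted from the dict fall back to the dataclass defaults, which is why the example YAML only needs to list the fields that have no default or that it wants to override.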

class fme.ace.TrainConfig(train_loader: fme.core.data_loading.config.DataLoaderConfig, validation_loader: fme.core.data_loading.config.DataLoaderConfig, stepper: fme.core.stepper.SingleModuleStepperConfig | fme.core.stepper.ExistingStepperConfig, optimization: fme.core.optimization.OptimizationConfig, logging: fme.core.logging_utils.LoggingConfig, max_epochs: int, save_checkpoint: bool, experiment_dir: str, inference: fme.ace.train.train_config.InlineInferenceConfig, n_forward_steps: int, copy_weights_after_batch: fme.core.weight_ops.CopyWeightsConfig = <factory>, ema: fme.core.ema.EMAConfig = <factory>, validate_using_ema: bool = False, checkpoint_save_epochs: fme.core.data_loading.config.Slice | None = None, ema_checkpoint_save_epochs: fme.core.data_loading.config.Slice | None = None, log_train_every_n_batches: int = 100, segment_epochs: int | None = None)

Bases: object

Configuration for training a model.

train_loader

Configuration for the training data loader.

Type:

fme.core.data_loading.config.DataLoaderConfig

validation_loader

Configuration for the validation data loader.

Type:

fme.core.data_loading.config.DataLoaderConfig

stepper

Configuration for the stepper.

Type:

fme.core.stepper.SingleModuleStepperConfig | fme.core.stepper.ExistingStepperConfig

optimization

Configuration for the optimization.

Type:

fme.core.optimization.OptimizationConfig

logging

Configuration for logging.

Type:

fme.core.logging_utils.LoggingConfig

max_epochs

Total number of epochs to train for.

Type:

int

save_checkpoint

Whether to save checkpoints.

Type:

bool

experiment_dir

Directory where checkpoints and logs are saved.

Type:

str

inference

Configuration for inline inference.

Type:

fme.ace.train.train_config.InlineInferenceConfig

n_forward_steps

Number of forward steps over which to take the gradient.

Type:

int

copy_weights_after_batch

Configuration for copying weights from the base model to the training model after each batch.

Type:

fme.core.weight_ops.CopyWeightsConfig

ema

Configuration for exponential moving average of model weights.

Type:

fme.core.ema.EMAConfig

validate_using_ema

Whether to validate and perform inference using the EMA model.

Type:

bool

checkpoint_save_epochs

How often to save epoch-based checkpoints, if save_checkpoint is True. If None, checkpoints are only saved for the most recent epoch (and the best epochs if validate_using_ema == False).

Type:

fme.core.data_loading.config.Slice | None

ema_checkpoint_save_epochs

How often to save epoch-based EMA checkpoints, if save_checkpoint is True. If None, EMA checkpoints are only saved for the most recent epoch (and the best epochs if validate_using_ema == True).

Type:

fme.core.data_loading.config.Slice | None

log_train_every_n_batches

How often to log batch_loss during training.

Type:

int

segment_epochs

Exit after training for at most this many epochs in the current job, without exceeding max_epochs. Use this if training must be run in segments, e.g. due to a wall-clock limit.

Type:

int | None
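
The interaction between segment_epochs and max_epochs can be sketched as follows. The function segment_end_epoch is hypothetical and only illustrates the rule described above: a job stops after segment_epochs epochs or at max_epochs, whichever comes first.

```python
from typing import Optional


def segment_end_epoch(
    start_epoch: int, max_epochs: int, segment_epochs: Optional[int]
) -> int:
    """Return the epoch at which the current job would stop.

    Hypothetical helper: with no segment limit the job runs to max_epochs,
    otherwise it stops after at most segment_epochs further epochs.
    """
    if segment_epochs is None:
        return max_epochs
    return min(start_epoch + segment_epochs, max_epochs)


# With max_epochs=80 and segment_epochs=30, three jobs are needed.
boundaries = []
epoch = 0
while epoch < 80:
    epoch = segment_end_epoch(epoch, max_epochs=80, segment_epochs=30)
    boundaries.append(epoch)

print(boundaries)
```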

The top-level sub-configurations are:

class fme.ace.DataLoaderConfig(dataset: Sequence[XarrayDataConfig], batch_size: int, num_data_workers: int, prefetch_factor: int | None = None, strict_ensemble: bool = True)

Bases: object

dataset

A sequence of configurations each defining a dataset to be loaded. This sequence of datasets will be concatenated.

Type:

Sequence[fme.core.data_loading.config.XarrayDataConfig]

batch_size

Number of samples per batch.

Type:

int

num_data_workers

Number of parallel workers to use for data loading.

Type:

int

prefetch_factor

How many batches a single data worker will attempt to hold in host memory at a given time.

Type:

int | None

strict_ensemble

Whether to enforce that the ensemble members have the same dimensions and coordinates.

Type:

bool

class fme.ace.SingleModuleStepperConfig(builder: fme.core.registry.ModuleSelector, in_names: List[str], out_names: List[str], normalization: fme.core.normalizer.NormalizationConfig | fme.core.normalizer.FromStateNormalizer, parameter_init: fme.core.parameter_init.ParameterInitializationConfig = <factory>, ocean: fme.core.ocean.OceanConfig | None = None, loss: fme.core.loss.WeightedMappingLossConfig = <factory>, corrector: fme.core.corrector.CorrectorConfig = <factory>, next_step_forcing_names: List[str] = <factory>, loss_normalization: fme.core.normalizer.NormalizationConfig | fme.core.normalizer.FromStateNormalizer | None = None, residual_normalization: fme.core.normalizer.NormalizationConfig | fme.core.normalizer.FromStateNormalizer | None = None)

Bases: object

Configuration for a single module stepper.

builder

The module builder.

Type:

fme.core.registry.ModuleSelector

in_names

Names of input variables.

Type:

List[str]

out_names

Names of output variables.

Type:

List[str]

normalization

The normalization configuration.

Type:

fme.core.normalizer.NormalizationConfig | fme.core.normalizer.FromStateNormalizer

parameter_init

The parameter initialization configuration.

Type:

fme.core.parameter_init.ParameterInitializationConfig

ocean

The ocean configuration.

Type:

fme.core.ocean.OceanConfig | None

loss

The loss configuration.

Type:

fme.core.loss.WeightedMappingLossConfig

corrector

The corrector configuration.

Type:

fme.core.corrector.CorrectorConfig

next_step_forcing_names

Names of forcing variables for the next timestep.

Type:

List[str]

loss_normalization

The normalization configuration for the loss.

Type:

fme.core.normalizer.NormalizationConfig | fme.core.normalizer.FromStateNormalizer | None

residual_normalization

Optional alternative to configure loss normalization. If provided, it will be used for all prognostic variables in loss scaling.

Type:

fme.core.normalizer.NormalizationConfig | fme.core.normalizer.FromStateNormalizer | None
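
The in_names and out_names lists together determine the role each variable plays. As a hedged illustration using the conventional terms (forcing for input-only variables, prognostic for variables that are both input and output, diagnostic for output-only variables), the sketch below classifies a shortened version of the example lists. The classification logic is an assumption for illustration, not fme's code.

```python
# Shortened versions of the lists in the example configuration.
in_names = ["land_fraction", "DSWRFtoa", "PRESsfc", "air_temperature_0"]
out_names = ["PRESsfc", "air_temperature_0", "PRATEsfc", "ULWRFtoa"]

# Input only: supplied to the model at every step, never predicted.
forcing = [name for name in in_names if name not in out_names]

# Input and output: predicted, then fed back in as input for the next step.
prognostic = [name for name in in_names if name in out_names]

# Output only: predicted but not fed back in.
diagnostic = [name for name in out_names if name not in in_names]

print(forcing, prognostic, diagnostic)
```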

class fme.ace.ExistingStepperConfig(checkpoint_path: str)

Bases: object

Configuration for an existing stepper. This is only designed to point to a serialized stepper checkpoint for loading, e.g., in the case of training resumption.

checkpoint_path

The path to the serialized checkpoint.

Type:

str

class fme.ace.OptimizationConfig(optimizer_type: Literal['Adam', 'FusedAdam'] = 'Adam', lr: float = 0.001, kwargs: Mapping[str, Any] = <factory>, enable_automatic_mixed_precision: bool = False, scheduler: fme.core.scheduler.SchedulerConfig = <factory>)

Bases: object

Configuration for optimization.

optimizer_type

The type of optimizer to use.

Type:

Literal['Adam', 'FusedAdam']

lr

The learning rate.

Type:

float

kwargs

Additional keyword arguments to pass to the optimizer.

Type:

Mapping[str, Any]

enable_automatic_mixed_precision

Whether to use automatic mixed precision.

Type:

bool

scheduler

The type of scheduler to use. If none is given, no scheduler will be used.

Type:

fme.core.scheduler.SchedulerConfig

class fme.ace.LoggingConfig(project: str = 'ace', entity: str = 'ai2cm', log_to_screen: bool = True, log_to_file: bool = True, log_to_wandb: bool = True, log_format: str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s', level: str | int = 20)

Bases: object

Configuration for logging.

project

name of the project in Weights & Biases

Type:

str

entity

name of the entity in Weights & Biases

Type:

str

log_to_screen

whether to log to the screen

Type:

bool

log_to_file

whether to log to a file

Type:

bool

log_to_wandb

whether to log to Weights & Biases

Type:

bool

log_format

format of the log messages

Type:

str
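
The default log_format and level values are ordinary Python logging settings, and the numeric level 20 is the value of logging.INFO. The sketch below applies them with the standard library only; how fme itself wires these settings into its loggers is not shown here, and the logger name "sketch" is arbitrary.

```python
import logging

# Defaults taken from the LoggingConfig signature.
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
LEVEL = 20  # same value as logging.INFO

logger = logging.getLogger("sketch")
handler = logging.StreamHandler()  # writes to the screen (stderr)
handler.setFormatter(logging.Formatter(LOG_FORMAT))
logger.addHandler(handler)
logger.setLevel(LEVEL)

logger.info("this message is emitted")
logger.debug("this message is filtered out at level 20")
```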

class fme.ace.InlineInferenceConfig(loader: fme.core.data_loading.inference.InferenceDataLoaderConfig, n_forward_steps: int = 2, forward_steps_in_memory: int = 2, epochs: fme.core.data_loading.config.Slice = Slice(start=0, stop=None, step=1), aggregator: fme.core.aggregator.inference.main.InferenceEvaluatorAggregatorConfig = <factory>)

Bases: object

loader

configuration for the data loader used during inference

Type:

fme.core.data_loading.inference.InferenceDataLoaderConfig

n_forward_steps

number of forward steps to take

Type:

int

forward_steps_in_memory

number of forward steps to take before re-reading data from disk

Type:

int

epochs

Epochs on which to run inference, where the first epoch is defined as epoch 0 (unlike in the logs, which number epochs starting from 1). By default, inference runs every epoch.

Type:

fme.core.data_loading.config.Slice

aggregator

Configuration of the inline inference aggregator.

Type:

fme.core.aggregator.inference.main.InferenceEvaluatorAggregatorConfig

class fme.ace.CopyWeightsConfig(include: List[str] = <factory>, exclude: List[str] = <factory>)

Bases: object

Configuration for copying weights from a base model to a target model.

Used during training to overwrite weights after every batch of data, to have the effect of “freezing” the overwritten weights. When the target parameters have longer dimensions than the base model, only the initial slice is overwritten.

This makes it possible to freeze only part of a parameter, namely the subset of each weight that comes from a smaller base weight. It is less efficient than true parameter freezing, but true freezing is all-or-nothing for each parameter.

All parameters must be covered by either the include or exclude list, but not both.

include

list of wildcard patterns to overwrite

Type:

List[str]

exclude

list of wildcard patterns to exclude from overwriting

Type:

List[str]
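
The overwrite-the-initial-slice behaviour can be sketched with plain Python lists standing in for weight tensors. This is an illustration under stated assumptions, not the fme implementation: the function copy_weights, the use of fnmatch for wildcard matching, and the parameter names are all hypothetical.

```python
from fnmatch import fnmatch


def matches(name, patterns):
    """True if the parameter name matches any wildcard pattern."""
    return any(fnmatch(name, pattern) for pattern in patterns)


def copy_weights(base, target, include, exclude):
    """Overwrite the leading entries of included target weights in place."""
    for name in target:
        included = matches(name, include)
        excluded = matches(name, exclude)
        # Each parameter must be covered by exactly one of the two lists.
        if included == excluded:
            raise ValueError(f"{name!r} must match exactly one of include/exclude")
        if included:
            n = len(base[name])
            target[name][:n] = base[name]  # only the initial slice changes


# The target's encoder weight is longer than the base model's.
base = {"encoder.weight": [1.0, 2.0], "decoder.weight": [9.0, 9.0]}
target = {"encoder.weight": [0.1, 0.2, 0.3], "decoder.weight": [0.5, 0.6]}

copy_weights(base, target, include=["encoder.*"], exclude=["decoder.*"])
print(target)
```

Calling this after every batch would undo any optimizer update to the copied entries, which gives the "frozen" effect described above while the trailing entries remain trainable.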

class fme.ace.EMAConfig(decay: float = 0.9999)

Bases: object

Configuration for exponential moving average of model weights.

decay

decay rate for the moving average

Type:

float
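
For reference, the textbook exponential moving average update is sketched below, with plain floats standing in for model weights. The function ema_update is hypothetical, and fme's implementation may differ in detail (for example in how the average is initialized or warmed up).

```python
def ema_update(ema_weights, weights, decay=0.9999):
    """Blend the current weights into the running average.

    Textbook EMA rule: ema <- decay * ema + (1 - decay) * weight.
    """
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


# A small decay is used here so the effect is visible after one step.
ema = ema_update([1.0, 2.0], [0.0, 4.0], decay=0.9)
print(ema)

# With the default decay, each step moves the average by only 0.01% of
# the gap, so the average reflects roughly the last 1 / (1 - decay) steps.
effective_window = 1.0 / (1.0 - 0.9999)
```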

class fme.ace.Slice(start: int | None = None, stop: int | None = None, step: int | None = None)

Bases: object

Configuration of a Python slice built-in.

Required because slice cannot be initialized directly by dacite.

start

Start index of the slice.

Type:

int | None

stop

Stop index of the slice.

Type:

int | None

step

Step of the slice.

Type:

int | None
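
The three fields map directly onto the arguments of Python's built-in slice. The sketch below uses a hypothetical SliceSketch dataclass with the same fields to show which zero-based indices a given configuration would select; the to_slice method is an assumption of this sketch and not necessarily part of the fme class.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SliceSketch:
    # Hypothetical mirror of the documented fields.
    start: Optional[int] = None
    stop: Optional[int] = None
    step: Optional[int] = None

    def to_slice(self) -> slice:
        """Convert to the equivalent built-in slice object."""
        return slice(self.start, self.stop, self.step)


epochs = list(range(12))  # zero-based epoch indices 0..11

# Every fifth index, as in "subset: step: 5" in the example configuration.
every_fifth = epochs[SliceSketch(step=5).to_slice()]

# The documented default for inline inference epochs selects everything.
every_epoch = epochs[SliceSketch(start=0, stop=None, step=1).to_slice()]

print(every_fifth)
```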