Training Config

The following is an example configuration for running training while evaluating against target data. You can use absolute paths in the config YAML files (we encourage it!), but the example uses paths relative to the directory where you run the command. The example is based on training with our full dataset (containing data from 10 ensemble runs) and assumes you are running in a directory structure like:

.
├── ckpt.tar
├── validation
│   ├── data1.nc  # files can have any name, but must sort into time-sequential order
│   ├── data2.nc  # can have any number of netCDF files
│   └── ...
└── traindata
    ├── ic_0001
    │   ├── data1.nc  # files can have any name, but must sort into time-sequential order
    │   ├── data2.nc  # can have any number of netCDF files
    │   └── ...
    ├── ic_0002
    │   └── ...
    ├── ...
    └── ic_0010
        └── ...

You can modify the example to run on fewer ensemble members by removing entries, or change the data paths as you wish. The .nc files correspond to data files like 2021010100.nc in the Zenodo repository, while ckpt.tar corresponds to a file like ace_ckpt.tar in that repository.
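
Before launching a run, it can help to confirm that the layout on disk matches what the example config expects. The following is a minimal sketch using only the Python standard library. The helper names `expected_data_dirs` and `check_layout` are hypothetical and not part of fme, and the sketch only lists files in sorted order; whether that order is truly time-sequential depends on your file naming.

```python
from pathlib import Path


def expected_data_dirs(n_members: int = 10) -> list[str]:
    """Relative data directories used by the example config (hypothetical helper)."""
    members = [f"traindata/ic_{i:04d}" for i in range(1, n_members + 1)]
    return ["validation"] + members


def check_layout(root: str, n_members: int = 10) -> dict[str, list[str]]:
    """Map each expected directory to its netCDF file names in sorted order."""
    base = Path(root)
    found: dict[str, list[str]] = {}
    for rel in expected_data_dirs(n_members):
        directory = base / rel
        if not directory.is_dir():
            raise FileNotFoundError(f"missing data directory: {directory}")
        # Plain lexicographic sort; fixed-width names such as 2021010100.nc
        # sort into time order this way.
        found[rel] = sorted(p.name for p in directory.glob("*.nc"))
    return found
```

Calling `check_layout(".")` from the run directory raises `FileNotFoundError` for the first missing directory; pass a smaller `n_members` if you removed ensemble members from the config.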

Example YAML Configuration
experiment_dir: train_output
save_checkpoint: true
validate_using_ema: true
max_epochs: 80
n_forward_steps: 1
inference:
  n_forward_steps: 7300  # 5 years
  forward_steps_in_memory: 50
  loader:
    start_indices:
      first: 0
      n_initial_conditions: 4
      interval: 1460  # 1 year
    dataset:
      data_path: validation
    num_data_workers: 2
logging:
  log_to_screen: true
  log_to_wandb: false
  log_to_file: true
  project: fourcastnet
  entity: ai2cm
train_loader:
  batch_size: 4
  num_data_workers: 8
  dataset:
    - data_path: traindata/ic_0001
    - data_path: traindata/ic_0002
    - data_path: traindata/ic_0003
    - data_path: traindata/ic_0004
    - data_path: traindata/ic_0005
    - data_path: traindata/ic_0006
    - data_path: traindata/ic_0007
    - data_path: traindata/ic_0008
    - data_path: traindata/ic_0009
    - data_path: traindata/ic_0010
validation_loader:
  batch_size: 16
  num_data_workers: 2
  dataset:
    - data_path: validation
      subset:
        step: 5
optimization:
  enable_automatic_mixed_precision: false
  lr: 0.0001
  optimizer_type: Adam  # can switch to FusedAdam if using GPU
stepper:
  builder:
    type: SphericalFourierNeuralOperatorNet
    config:
      embed_dim: 384
      filter_type: linear
      hard_thresholding_fraction: 1.0
      use_mlp: true
      normalization_layer: instance_norm
      num_layers: 8
      operator_type: dhconv
      scale_factor: 1
      separable: false
  loss:
    type: MSE
  normalization:
    global_means_path: centering.nc
    global_stds_path: scaling.nc
  ocean:
    surface_temperature_name: surface_temperature
    ocean_fraction_name: ocean_fraction
  corrector:
    conserve_dry_air: true
    moisture_budget_correction: advection_and_precipitation
  in_names:
  - land_fraction
  - ocean_fraction
  - sea_ice_fraction
  - DSWRFtoa
  - HGTsfc
  - PRESsfc
  - surface_temperature
  - air_temperature_0 # _0 denotes the top most layer of the atmosphere
  - air_temperature_1
  - air_temperature_2
  - air_temperature_3
  - air_temperature_4
  - air_temperature_5
  - air_temperature_6
  - air_temperature_7
  - specific_total_water_0
  - specific_total_water_1
  - specific_total_water_2
  - specific_total_water_3
  - specific_total_water_4
  - specific_total_water_5
  - specific_total_water_6
  - specific_total_water_7
  - eastward_wind_0
  - eastward_wind_1
  - eastward_wind_2
  - eastward_wind_3
  - eastward_wind_4
  - eastward_wind_5
  - eastward_wind_6
  - eastward_wind_7
  - northward_wind_0
  - northward_wind_1
  - northward_wind_2
  - northward_wind_3
  - northward_wind_4
  - northward_wind_5
  - northward_wind_6
  - northward_wind_7
  out_names:
  - PRESsfc
  - surface_temperature
  - air_temperature_0
  - air_temperature_1
  - air_temperature_2
  - air_temperature_3
  - air_temperature_4
  - air_temperature_5
  - air_temperature_6
  - air_temperature_7
  - specific_total_water_0
  - specific_total_water_1
  - specific_total_water_2
  - specific_total_water_3
  - specific_total_water_4
  - specific_total_water_5
  - specific_total_water_6
  - specific_total_water_7
  - eastward_wind_0
  - eastward_wind_1
  - eastward_wind_2
  - eastward_wind_3
  - eastward_wind_4
  - eastward_wind_5
  - eastward_wind_6
  - eastward_wind_7
  - northward_wind_0
  - northward_wind_1
  - northward_wind_2
  - northward_wind_3
  - northward_wind_4
  - northward_wind_5
  - northward_wind_6
  - northward_wind_7
  - LHTFLsfc
  - SHTFLsfc
  - PRATEsfc
  - ULWRFsfc
  - ULWRFtoa
  - DLWRFsfc
  - DSWRFsfc
  - USWRFsfc
  - USWRFtoa
  - tendency_of_total_water_path_due_to_advection
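
The step counts in the `inference` block can be checked with a little arithmetic. This sketch assumes that initial conditions are taken at `first`, `first + interval`, and so on; that is a reading of the field names rather than something stated above. It takes the figure of 1460 steps per year from the inline comment. `inference_start_indices` is a hypothetical helper, not an fme function.

```python
def inference_start_indices(first: int, n_initial_conditions: int, interval: int) -> list[int]:
    """Evenly spaced start indices (assumed meaning of start_indices)."""
    return [first + i * interval for i in range(n_initial_conditions)]


STEPS_PER_YEAR = 1460  # from the "# 1 year" comment in the example config

starts = inference_start_indices(first=0, n_initial_conditions=4, interval=1460)
print(starts)  # [0, 1460, 2920, 4380]

# Rollout length in years, matching the "# 5 years" comment.
print(7300 / STEPS_PER_YEAR)  # 5.0

# Hours per step implied by 1460 steps in a 365-day year.
print(365 * 24 / STEPS_PER_YEAR)  # 6.0
```

Under this reading, the last rollout starts at index 4380 and runs for 7300 steps, so the validation data would need to extend at least that far.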

We use the Builder pattern to load this configuration into a multi-level dataclass structure. The configuration is divided into several sub-configurations, each with its own dataclass. The top-level configuration is the fme.ace.TrainConfig class.
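
To illustrate what loading a nested mapping into a multi-level dataclass structure looks like, here is a minimal standard-library sketch. The classes `LoggingSketch` and `TrainSketch` and the function `build` are hypothetical stand-ins with only a few fields; they are not the fme classes, and the real loader may behave differently (the `Slice` documentation on this page mentions dacite).

```python
import dataclasses
from typing import Any, Mapping


@dataclasses.dataclass
class LoggingSketch:
    project: str = "ace"
    log_to_wandb: bool = True


@dataclasses.dataclass
class TrainSketch:
    experiment_dir: str
    max_epochs: int
    logging: LoggingSketch


def build(cls, data: Mapping[str, Any]):
    """Recursively construct dataclass `cls` from a nested mapping."""
    fields = {f.name: f for f in dataclasses.fields(cls)}
    unknown = set(data) - set(fields)
    if unknown:
        raise TypeError(f"unexpected keys for {cls.__name__}: {sorted(unknown)}")
    kwargs = {}
    for name, value in data.items():
        field_type = fields[name].type
        # Nested mappings become nested dataclasses; other values pass through.
        if dataclasses.is_dataclass(field_type) and isinstance(value, Mapping):
            value = build(field_type, value)
        kwargs[name] = value
    # Keys absent from `data` fall back to the dataclass defaults.
    return cls(**kwargs)


config = build(
    TrainSketch,
    {
        "experiment_dir": "train_output",
        "max_epochs": 80,
        "logging": {"project": "fourcastnet", "log_to_wandb": False},
    },
)
print(config.logging.project)  # fourcastnet
```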

class fme.ace.TrainConfig(train_loader, validation_loader, stepper, optimization, logging, max_epochs, save_checkpoint, experiment_dir, inference, n_forward_steps, copy_weights_after_batch=<factory>, ema=<factory>, validate_using_ema=False, checkpoint_save_epochs=None, ema_checkpoint_save_epochs=None, log_train_every_n_batches=100, segment_epochs=None)

Bases: object

Configuration for training a model.

Parameters:
  • train_loader (DataLoaderConfig) – Configuration for the training data loader.

  • validation_loader (DataLoaderConfig) – Configuration for the validation data loader.

  • stepper (Union[SingleModuleStepperConfig, ExistingStepperConfig]) – Configuration for the stepper.

  • optimization (OptimizationConfig) – Configuration for the optimization.

  • logging (LoggingConfig) – Configuration for logging.

  • max_epochs (int) – Total number of epochs to train for.

  • save_checkpoint (bool) – Whether to save checkpoints.

  • experiment_dir (str) – Directory where checkpoints and logs are saved.

  • inference (InlineInferenceConfig) – Configuration for inline inference.

  • n_forward_steps (int) – Number of forward steps to take gradient over.

  • copy_weights_after_batch (CopyWeightsConfig, default: <factory>) – Configuration for copying weights from the base model to the training model after each batch.

  • ema (EMAConfig, default: <factory>) – Configuration for exponential moving average of model weights.

  • validate_using_ema (bool, default: False) – Whether to validate and perform inference using the EMA model.

  • checkpoint_save_epochs (Optional[Slice], default: None) – How often to save epoch-based checkpoints, if save_checkpoint is True. If None, checkpoints are only saved for the most recent epoch (and the best epochs if validate_using_ema == False).

  • ema_checkpoint_save_epochs (Optional[Slice], default: None) – How often to save epoch-based EMA checkpoints, if save_checkpoint is True. If None, EMA checkpoints are only saved for the most recent epoch (and the best epochs if validate_using_ema == True).

  • log_train_every_n_batches (int, default: 100) – How often to log batch_loss during training.

  • segment_epochs (Optional[int], default: None) – Exit after training for at most this many epochs in current job, without exceeding max_epochs. Use this if training must be run in segments, e.g. due to wall clock limit.
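
The interaction between `max_epochs` and `segment_epochs` can be sketched as follows. This is an illustration of the description above, not the library's code; `epochs_for_this_job` and `start_epoch` (the number of epochs already completed when the job resumes) are hypothetical names.

```python
from typing import Optional


def epochs_for_this_job(
    start_epoch: int, max_epochs: int, segment_epochs: Optional[int]
) -> range:
    """Zero-based epochs a single job would train, per the description above."""
    if segment_epochs is None:
        end = max_epochs
    else:
        # Train at most segment_epochs more epochs, never past max_epochs.
        end = min(start_epoch + segment_epochs, max_epochs)
    return range(start_epoch, end)


print(epochs_for_this_job(0, 80, 25))   # range(0, 25)
print(epochs_for_this_job(75, 80, 25))  # range(75, 80)
```

With `max_epochs: 80` and a hypothetical `segment_epochs: 25`, training would take four jobs under this reading: three of 25 epochs and a final one of 5.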

The top-level sub-configurations are:

class fme.ace.DataLoaderConfig(dataset, batch_size, num_data_workers=0, prefetch_factor=None, strict_ensemble=True)

Bases: object

Parameters:
  • dataset (Sequence[XarrayDataConfig]) – A sequence of configurations each defining a dataset to be loaded. This sequence of datasets will be concatenated.

  • batch_size (int) – Number of samples per batch.

  • num_data_workers (int, default: 0) – Number of parallel workers to use for data loading.

  • prefetch_factor (Optional[int], default: None) – How many batches a single data worker will attempt to hold in host memory at a given time.

  • strict_ensemble (bool, default: True) – Whether to enforce that the datasets to be concatenated have the same dimensions and coordinates.

class fme.ace.SingleModuleStepperConfig(builder, in_names, out_names, normalization, parameter_init=<factory>, ocean=None, loss=<factory>, corrector=<factory>, next_step_forcing_names=<factory>, loss_normalization=None, residual_normalization=None)

Bases: object

Configuration for a single module stepper.

Parameters:
class fme.ace.ExistingStepperConfig(checkpoint_path)

Bases: object

Configuration for an existing stepper. This is only designed to point to a serialized stepper checkpoint for loading, e.g., in the case of training resumption.

Parameters:

checkpoint_path (str) – The path to the serialized checkpoint.

class fme.ace.OptimizationConfig(optimizer_type='Adam', lr=0.001, kwargs=<factory>, enable_automatic_mixed_precision=False, scheduler=<factory>)

Bases: object

Configuration for optimization.

Parameters:
  • optimizer_type (Literal['Adam', 'FusedAdam'], default: 'Adam') – The type of optimizer to use.

  • lr (float, default: 0.001) – The learning rate.

  • kwargs (Mapping[str, Any], default: <factory>) – Additional keyword arguments to pass to the optimizer.

  • enable_automatic_mixed_precision (bool, default: False) – Whether to use automatic mixed precision.

  • scheduler (SchedulerConfig, default: <factory>) – The type of scheduler to use. If none is given, no scheduler will be used.

class fme.ace.LoggingConfig(project='ace', entity='ai2cm', log_to_screen=True, log_to_file=True, log_to_wandb=True, log_format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=20)

Bases: object

Configuration for logging.

Parameters:
  • project (str, default: 'ace') – name of the project in Weights & Biases

  • entity (str, default: 'ai2cm') – name of the entity in Weights & Biases

  • log_to_screen (bool, default: True) – whether to log to the screen

  • log_to_file (bool, default: True) – whether to log to a file

  • log_to_wandb (bool, default: True) – whether to log to Weights & Biases

  • log_format (str, default: '%(asctime)s - %(name)s - %(levelname)s - %(message)s') – format of the log messages

  • level (str | int, default: 20) – logging level; the default of 20 corresponds to INFO in Python's logging module

class fme.ace.InlineInferenceConfig(loader, n_forward_steps=2, forward_steps_in_memory=2, epochs=Slice(start=0, stop=None, step=1), aggregator=<factory>)

Bases: object

Parameters:
  • loader (InferenceDataLoaderConfig) – configuration for the data loader used during inference

  • n_forward_steps (int, default: 2) – number of forward steps to take

  • forward_steps_in_memory (int, default: 2) – number of forward steps to take before re-reading data from disk

  • epochs (Slice, default: Slice(start=0, stop=None, step=1)) – epochs on which to run inference, where the first epoch is defined as epoch 0 (unlike in logs which show epochs as starting from 1). By default runs inference every epoch.

  • aggregator (InferenceEvaluatorAggregatorConfig, default: <factory>) – configuration of inline inference aggregator.
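
As a rough illustration of how `n_forward_steps` and `forward_steps_in_memory` relate, the sketch below splits a rollout into consecutive windows, one per read from disk. The helper `rollout_windows` is hypothetical, and the exact windowing used by the library (for example, how the initial condition is counted) is not specified here.

```python
def rollout_windows(n_forward_steps: int, forward_steps_in_memory: int) -> list[tuple[int, int]]:
    """Half-open (start, stop) step ranges, one per chunk held in memory."""
    return [
        (start, min(start + forward_steps_in_memory, n_forward_steps))
        for start in range(0, n_forward_steps, forward_steps_in_memory)
    ]


windows = rollout_windows(n_forward_steps=7300, forward_steps_in_memory=50)
print(len(windows))  # 146
print(windows[0], windows[-1])  # (0, 50) (7250, 7300)
```

A larger `forward_steps_in_memory` means fewer reads at the cost of more memory per window.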

class fme.ace.CopyWeightsConfig(include=<factory>, exclude=<factory>)

Bases: object

Configuration for copying weights from a base model to a target model.

Used during training to overwrite weights after every batch of data, to have the effect of “freezing” the overwritten weights. When the target parameters have longer dimensions than the base model, only the initial slice is overwritten.

This makes it possible to freeze only the part of each weight that was copied from a smaller base weight. It is less efficient than true parameter freezing, but layer freezing is all-or-nothing for each parameter.

All parameters must be covered by either the include or exclude list, but not both.

Parameters:
  • include (List[str], default: <factory>) – list of wildcard patterns to overwrite

  • exclude (List[str], default: <factory>) – list of wildcard patterns to exclude from overwriting
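
The include/exclude rule and the "initial slice" behaviour can be sketched with plain Python lists. This is an illustration under assumptions: it treats the wildcard patterns as shell-style patterns matched with `fnmatch`, uses one-dimensional lists in place of tensors, and the names `select_overwritten` and `overwrite_initial_slice` are hypothetical.

```python
from fnmatch import fnmatchcase


def select_overwritten(names: list[str], include: list[str], exclude: list[str]) -> list[str]:
    """Return parameter names to overwrite; each name must match exactly one list."""
    selected = []
    for name in names:
        in_include = any(fnmatchcase(name, pattern) for pattern in include)
        in_exclude = any(fnmatchcase(name, pattern) for pattern in exclude)
        if in_include == in_exclude:
            raise ValueError(f"{name!r} must match include or exclude, but not both")
        if in_include:
            selected.append(name)
    return selected


def overwrite_initial_slice(target: list[float], base: list[float]) -> list[float]:
    """Copy base over the start of target, leaving the remaining entries as they are."""
    return base + target[len(base):]


names = ["blocks.0.weight", "blocks.1.weight", "head.bias"]
print(select_overwritten(names, include=["blocks.*"], exclude=["head.*"]))
# ['blocks.0.weight', 'blocks.1.weight']
print(overwrite_initial_slice([9.0, 9.0, 9.0], [1.0, 2.0]))  # [1.0, 2.0, 9.0]
```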

class fme.ace.EMAConfig(decay=0.9999)

Bases: object

Configuration for exponential moving average of model weights.

Parameters:

decay (float, default: 0.9999) – decay rate for the moving average
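
For intuition about `decay`, the textbook exponential moving average update is shown below on plain floats. This is the standard form only; the library's implementation may differ in details such as warm-up of the decay rate, and `ema_update` is a hypothetical name.

```python
def ema_update(ema: list[float], weights: list[float], decay: float = 0.9999) -> list[float]:
    """One EMA step: keep `decay` of the old average, blend in the new weights."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema, weights)]


# With decay=0.5 the average moves halfway toward the new weights.
print(ema_update([1.0, 4.0], [0.0, 2.0], decay=0.5))  # [0.5, 3.0]
```

With the default decay of 0.9999, each update moves the average only 0.01% of the way toward the current weights, so the average reflects on the order of 10,000 recent updates.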

class fme.ace.Slice(start=None, stop=None, step=None)

Bases: object

Configuration of a python slice built-in.

Required because slice cannot be initialized directly by dacite.

Parameters:
  • start (Optional[int], default: None) – Start index of the slice.

  • stop (Optional[int], default: None) – Stop index of the slice.

  • step (Optional[int], default: None) – Step of the slice.
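
The idea of a dataclass that stands in for the built-in `slice` can be sketched as follows. `SliceSketch` and its `to_slice` method are hypothetical; the actual `fme.ace.Slice` may expose the conversion differently.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class SliceSketch:
    start: Optional[int] = None
    stop: Optional[int] = None
    step: Optional[int] = None

    def to_slice(self) -> slice:
        """Convert to the Python built-in slice."""
        return slice(self.start, self.stop, self.step)


epochs = list(range(10))  # zero-based epoch numbers

# Every epoch, as in the InlineInferenceConfig default Slice(start=0, stop=None, step=1).
print(epochs[SliceSketch(start=0, stop=None, step=1).to_slice()])
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Every fifth epoch, starting from epoch 4.
print(epochs[SliceSketch(start=4, step=5).to_slice()])  # [4, 9]
```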