Training Config
The following is an example configuration for running training while evaluating against target data. You can use absolute paths in the config YAML files (we encourage it!), but the example uses paths relative to the directory from which you run the command. The example is based on training with our full dataset (containing data from 10 ensemble runs) and assumes you are running in a directory structure like:
.
├── ckpt.tar
├── validation
│   ├── data1.nc  # files can have any name, but must sort into time-sequential order
│   ├── data2.nc  # can have any number of netCDF files
│   └── ...
└── traindata
    ├── ic_0001
    │   ├── data1.nc  # files can have any name, but must sort into time-sequential order
    │   ├── data2.nc  # can have any number of netCDF files
    │   └── ...
    ├── ic_0002
    │   └── ...
    ├── ...
    └── ic_0010
        └── ...
You can modify the example to run on fewer ensemble members by removing entries, or change the data paths as you wish.
The .nc files correspond to data files like 2021010100.nc in the Zenodo repository, while ckpt.tar corresponds to a file like ace_ckpt.tar in that repository.
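Since the files in each directory must sort into time-sequential order, it can help to look at the sorted order before training. The sketch below uses only the standard library. The helper name `sorted_nc_files` is hypothetical and not part of fme, the file names other than 2021010100.nc are made up, and plain lexicographic sorting on the file name is an assumption about how the order is determined.

```python
from pathlib import Path
import tempfile


def sorted_nc_files(data_path):
    """Return the netCDF file names in data_path in lexicographic order.

    Hypothetical helper (not part of fme). It illustrates why zero-padded
    timestamp names such as 2021010100.nc sort into time order.
    """
    return sorted(p.name for p in Path(data_path).glob("*.nc"))


# Demonstration on a throwaway directory holding empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    for name in ["2021010200.nc", "2021010100.nc", "2021010112.nc"]:
        (Path(tmp) / name).touch()
    names = sorted_nc_files(tmp)
    print(names)
```

Names without zero padding (for example data10.nc next to data2.nc) would not sort this way, which is one reason to check.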
experiment_dir: train_output
save_checkpoint: true
validate_using_ema: true
max_epochs: 80
n_forward_steps: 1
inference:
  n_forward_steps: 7300  # 5 years
  forward_steps_in_memory: 50
  loader:
    start_indices:
      first: 0
      n_initial_conditions: 4
      interval: 1460  # 1 year
    dataset:
      data_path: validation
    num_data_workers: 2
logging:
  log_to_screen: true
  log_to_wandb: false
  log_to_file: true
  project: fourcastnet
  entity: ai2cm
train_loader:
  batch_size: 4
  num_data_workers: 8
  dataset:
    - data_path: traindata/ic_0001
    - data_path: traindata/ic_0002
    - data_path: traindata/ic_0003
    - data_path: traindata/ic_0004
    - data_path: traindata/ic_0005
    - data_path: traindata/ic_0006
    - data_path: traindata/ic_0007
    - data_path: traindata/ic_0008
    - data_path: traindata/ic_0009
    - data_path: traindata/ic_0010
validation_loader:
  batch_size: 16
  num_data_workers: 2
  dataset:
    - data_path: validation
      subset:
        step: 5
optimization:
  enable_automatic_mixed_precision: false
  lr: 0.0001
  optimizer_type: Adam  # can switch to FusedAdam if using GPU
stepper:
  builder:
    type: SphericalFourierNeuralOperatorNet
    config:
      embed_dim: 384
      filter_type: linear
      hard_thresholding_fraction: 1.0
      use_mlp: true
      normalization_layer: instance_norm
      num_layers: 8
      operator_type: dhconv
      scale_factor: 1
      separable: false
  loss:
    type: MSE
  normalization:
    global_means_path: centering.nc
    global_stds_path: scaling.nc
  ocean:
    surface_temperature_name: surface_temperature
    ocean_fraction_name: ocean_fraction
  corrector:
    conserve_dry_air: true
    moisture_budget_correction: advection_and_precipitation
  in_names:
    - land_fraction
    - ocean_fraction
    - sea_ice_fraction
    - DSWRFtoa
    - HGTsfc
    - PRESsfc
    - surface_temperature
    - air_temperature_0  # _0 denotes the top most layer of the atmosphere
    - air_temperature_1
    - air_temperature_2
    - air_temperature_3
    - air_temperature_4
    - air_temperature_5
    - air_temperature_6
    - air_temperature_7
    - specific_total_water_0
    - specific_total_water_1
    - specific_total_water_2
    - specific_total_water_3
    - specific_total_water_4
    - specific_total_water_5
    - specific_total_water_6
    - specific_total_water_7
    - eastward_wind_0
    - eastward_wind_1
    - eastward_wind_2
    - eastward_wind_3
    - eastward_wind_4
    - eastward_wind_5
    - eastward_wind_6
    - eastward_wind_7
    - northward_wind_0
    - northward_wind_1
    - northward_wind_2
    - northward_wind_3
    - northward_wind_4
    - northward_wind_5
    - northward_wind_6
    - northward_wind_7
  out_names:
    - PRESsfc
    - surface_temperature
    - air_temperature_0
    - air_temperature_1
    - air_temperature_2
    - air_temperature_3
    - air_temperature_4
    - air_temperature_5
    - air_temperature_6
    - air_temperature_7
    - specific_total_water_0
    - specific_total_water_1
    - specific_total_water_2
    - specific_total_water_3
    - specific_total_water_4
    - specific_total_water_5
    - specific_total_water_6
    - specific_total_water_7
    - eastward_wind_0
    - eastward_wind_1
    - eastward_wind_2
    - eastward_wind_3
    - eastward_wind_4
    - eastward_wind_5
    - eastward_wind_6
    - eastward_wind_7
    - northward_wind_0
    - northward_wind_1
    - northward_wind_2
    - northward_wind_3
    - northward_wind_4
    - northward_wind_5
    - northward_wind_6
    - northward_wind_7
    - LHTFLsfc
    - SHTFLsfc
    - PRATEsfc
    - ULWRFsfc
    - ULWRFtoa
    - DLWRFsfc
    - DSWRFsfc
    - USWRFsfc
    - USWRFtoa
    - tendency_of_total_water_path_due_to_advection
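The comments in the inference section (7300 steps for 5 years, 1460 steps for 1 year) imply four steps per day over a 365-day year, that is, a 6-hour time step. The time step is inferred from those comments rather than stated in the config. The sketch below reproduces the arithmetic and lists the start indices that first: 0, n_initial_conditions: 4 and interval: 1460 would give if initial conditions are evenly spaced, which is an assumption about how the loader reads these fields. The function name is hypothetical.

```python
HOURS_PER_STEP = 6   # assumption inferred from the "1460 # 1 year" comment
DAYS_PER_YEAR = 365  # assumption: no leap days

steps_per_year = DAYS_PER_YEAR * 24 // HOURS_PER_STEP
steps_for_5_years = 5 * steps_per_year


def start_indices(first, n_initial_conditions, interval):
    """Evenly spaced start indices (hypothetical helper, not the fme API)."""
    return [first + i * interval for i in range(n_initial_conditions)]


indices = start_indices(first=0, n_initial_conditions=4, interval=1460)
print(steps_per_year, steps_for_5_years, indices)
```

Under these assumptions the four inline inference runs would start one year apart and each cover five years.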
We use the Builder pattern to load this configuration into a multi-level dataclass structure.
The configuration is divided into several sub-configurations, each with its own dataclass.
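To make the idea of a multi-level dataclass structure concrete, here is a minimal, standard-library-only sketch of turning a nested dict (as a YAML parser would produce) into nested dataclasses. The classes `ToyTrainConfig` and `ToyLoggingConfig` and the helper `from_dict` are hypothetical stand-ins. The `Slice` documentation on this page mentions dacite, which suggests that library performs this step in fme, so treat this only as an illustration of the pattern.

```python
import dataclasses


@dataclasses.dataclass
class ToyLoggingConfig:
    log_to_screen: bool = True
    log_to_wandb: bool = True


@dataclasses.dataclass
class ToyTrainConfig:
    experiment_dir: str
    max_epochs: int
    logging: ToyLoggingConfig


def from_dict(cls, data):
    """Recursively build dataclass `cls` from a plain dict."""
    field_types = {f.name: f.type for f in dataclasses.fields(cls)}
    kwargs = {}
    for key, value in data.items():
        field_type = field_types[key]  # raises KeyError for unknown keys
        if dataclasses.is_dataclass(field_type) and isinstance(value, dict):
            value = from_dict(field_type, value)
        kwargs[key] = value
    return cls(**kwargs)


raw = {
    "experiment_dir": "train_output",
    "max_epochs": 80,
    "logging": {"log_to_wandb": False},
}
config = from_dict(ToyTrainConfig, raw)
print(config)
```

Keys missing from the dict fall back to the dataclass defaults, which is why the example config can omit fields such as ema.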
The top-level configuration is the fme.ace.TrainConfig class.
- class fme.ace.TrainConfig(train_loader: fme.core.data_loading.config.DataLoaderConfig, validation_loader: fme.core.data_loading.config.DataLoaderConfig, stepper: fme.core.stepper.SingleModuleStepperConfig | fme.core.stepper.ExistingStepperConfig, optimization: fme.core.optimization.OptimizationConfig, logging: fme.core.logging_utils.LoggingConfig, max_epochs: int, save_checkpoint: bool, experiment_dir: str, inference: fme.ace.train.train_config.InlineInferenceConfig, n_forward_steps: int, copy_weights_after_batch: fme.core.weight_ops.CopyWeightsConfig = <factory>, ema: fme.core.ema.EMAConfig = <factory>, validate_using_ema: bool = False, checkpoint_save_epochs: fme.core.data_loading.config.Slice | None = None, ema_checkpoint_save_epochs: fme.core.data_loading.config.Slice | None = None, log_train_every_n_batches: int = 100, segment_epochs: int | None = None)
Bases: object
Configuration for training a model.
- train_loader
Configuration for the training data loader.
- validation_loader
Configuration for the validation data loader.
- stepper
Configuration for the stepper.
- optimization
Configuration for the optimization.
- logging
Configuration for logging.
- max_epochs
Total number of epochs to train for.
- Type:
int
- save_checkpoint
Whether to save checkpoints.
- Type:
bool
- experiment_dir
Directory where checkpoints and logs are saved.
- Type:
str
- inference
Configuration for inline inference.
- n_forward_steps
Number of forward steps to take the gradient over.
- Type:
int
- copy_weights_after_batch
Configuration for copying weights from the base model to the training model after each batch.
- ema
Configuration for exponential moving average of model weights.
- Type:
fme.core.ema.EMAConfig
- validate_using_ema
Whether to validate and perform inference using the EMA model.
- Type:
bool
- checkpoint_save_epochs
How often to save epoch-based checkpoints, if save_checkpoint is True. If None, checkpoints are only saved for the most recent epoch (and the best epochs if validate_using_ema == False).
- Type:
fme.core.data_loading.config.Slice | None
- ema_checkpoint_save_epochs
How often to save epoch-based EMA checkpoints, if save_checkpoint is True. If None, EMA checkpoints are only saved for the most recent epoch (and the best epochs if validate_using_ema == True).
- Type:
fme.core.data_loading.config.Slice | None
- log_train_every_n_batches
How often to log batch_loss during training.
- Type:
int
- segment_epochs
Exit after training for at most this many epochs in the current job, without exceeding max_epochs. Use this if training must be run in segments, e.g. due to a wall-clock limit.
- Type:
int | None
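One reading of how segment_epochs interacts with max_epochs can be sketched as follows. The helper `segment_end_epoch` is hypothetical and only mirrors the description above; it is not taken from the fme source.

```python
def segment_end_epoch(start_epoch, max_epochs, segment_epochs=None):
    """Epoch count at which the current job stops (hypothetical helper).

    With segment_epochs=None the job runs through to max_epochs.
    """
    if segment_epochs is None:
        return max_epochs
    return min(start_epoch + segment_epochs, max_epochs)


# 80 total epochs, trained in jobs of at most 30 epochs each.
boundaries = []
epoch = 0
while epoch < 80:
    epoch = segment_end_epoch(epoch, max_epochs=80, segment_epochs=30)
    boundaries.append(epoch)
print(boundaries)
```

In this sketch the third job stops after 20 epochs rather than 30, because the total never exceeds max_epochs.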
The top-level sub-configurations are:
- class fme.ace.DataLoaderConfig(dataset: Sequence[XarrayDataConfig], batch_size: int, num_data_workers: int, prefetch_factor: int | None = None, strict_ensemble: bool = True)
Bases: object
- dataset
A sequence of configurations each defining a dataset to be loaded. This sequence of datasets will be concatenated.
- Type:
Sequence[XarrayDataConfig]
- batch_size
Number of samples per batch.
- Type:
int
- num_data_workers
Number of parallel workers to use for data loading.
- Type:
int
- prefetch_factor
How many batches a single data worker will attempt to hold in host memory at a given time.
- Type:
int | None
- strict_ensemble
Whether to enforce that the ensemble members have the same dimensions and coordinates.
- Type:
bool
- class fme.ace.SingleModuleStepperConfig(builder: fme.core.registry.ModuleSelector, in_names: typing.List[str], out_names: typing.List[str], normalization: fme.core.normalizer.NormalizationConfig | fme.core.normalizer.FromStateNormalizer, parameter_init: fme.core.parameter_init.ParameterInitializationConfig = <factory>, ocean: fme.core.ocean.OceanConfig | None = None, loss: fme.core.loss.WeightedMappingLossConfig = <factory>, corrector: fme.core.corrector.CorrectorConfig = <factory>, next_step_forcing_names: typing.List[str] = <factory>, loss_normalization: fme.core.normalizer.NormalizationConfig | fme.core.normalizer.FromStateNormalizer | None = None, residual_normalization: fme.core.normalizer.NormalizationConfig | fme.core.normalizer.FromStateNormalizer | None = None)
Bases: object
Configuration for a single module stepper.
- builder
The module builder.
- in_names
Names of input variables.
- Type:
List[str]
- out_names
Names of output variables.
- Type:
List[str]
- normalization
The normalization configuration.
- parameter_init
The parameter initialization configuration.
- ocean
The ocean configuration.
- Type:
fme.core.ocean.OceanConfig | None
- loss
The loss configuration.
- corrector
The corrector configuration.
- next_step_forcing_names
Names of forcing variables for the next timestep.
- Type:
List[str]
- loss_normalization
The normalization configuration for the loss.
- residual_normalization
Optional alternative to configure loss normalization. If provided, it will be used for all prognostic variables in loss scaling.
- class fme.ace.ExistingStepperConfig(checkpoint_path: str)
Bases: object
Configuration for an existing stepper. This is only designed to point to a serialized stepper checkpoint for loading, e.g., in the case of training resumption.
- checkpoint_path
The path to the serialized checkpoint.
- Type:
str
- class fme.ace.OptimizationConfig(optimizer_type: typing.Literal['Adam', 'FusedAdam'] = 'Adam', lr: float = 0.001, kwargs: typing.Mapping[str, typing.Any] = <factory>, enable_automatic_mixed_precision: bool = False, scheduler: fme.core.scheduler.SchedulerConfig = <factory>)
Bases: object
Configuration for optimization.
- optimizer_type
The type of optimizer to use.
- Type:
Literal['Adam', 'FusedAdam']
- lr
The learning rate.
- Type:
float
- kwargs
Additional keyword arguments to pass to the optimizer.
- Type:
Mapping[str, Any]
- enable_automatic_mixed_precision
Whether to use automatic mixed precision.
- Type:
bool
- scheduler
The type of scheduler to use. If none is given, no scheduler will be used.
- class fme.ace.LoggingConfig(project: str = 'ace', entity: str = 'ai2cm', log_to_screen: bool = True, log_to_file: bool = True, log_to_wandb: bool = True, log_format: str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s', level: str | int = 20)
Bases: object
Configuration for logging.
- project
name of the project in Weights & Biases
- Type:
str
- entity
name of the entity in Weights & Biases
- Type:
str
- log_to_screen
whether to log to the screen
- Type:
bool
- log_to_file
whether to log to a file
- Type:
bool
- log_to_wandb
whether to log to Weights & Biases
- Type:
bool
- log_format
format of the log messages
- Type:
str
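The defaults in the signature map directly onto Python's standard logging module: level 20 is `logging.INFO`, and the format string uses standard `LogRecord` attributes. The sketch below shows what those two defaults produce with the stdlib alone. How `LoggingConfig` applies them internally is not shown on this page, and the logger name `toy_training` is hypothetical.

```python
import io
import logging

# Defaults copied from the LoggingConfig signature.
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
LEVEL = 20  # numerically equal to logging.INFO

# Capture output in memory so the result can be inspected.
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter(LOG_FORMAT))

logger = logging.getLogger("toy_training")  # hypothetical logger name
logger.setLevel(LEVEL)
logger.addHandler(handler)
logger.propagate = False

logger.debug("hidden: DEBUG (10) is below level 20")
logger.info("epoch 1 finished")

output = stream.getvalue()
print(output, end="")
```

Only the INFO message appears, prefixed by a timestamp, the logger name and the level name.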
- class fme.ace.InlineInferenceConfig(loader: fme.core.data_loading.inference.InferenceDataLoaderConfig, n_forward_steps: int = 2, forward_steps_in_memory: int = 2, epochs: fme.core.data_loading.config.Slice = Slice(start=0, stop=None, step=1), aggregator: fme.core.aggregator.inference.main.InferenceEvaluatorAggregatorConfig = <factory>)
Bases: object
- loader
configuration for the data loader used during inference
- n_forward_steps
number of forward steps to take
- Type:
int
- forward_steps_in_memory
number of forward steps to take before re-reading data from disk
- Type:
int
- epochs
epochs on which to run inference, where the first epoch is defined as epoch 0 (unlike in logs which show epochs as starting from 1). By default runs inference every epoch.
- aggregator
configuration of inline inference aggregator.
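The relationship between n_forward_steps and forward_steps_in_memory can be pictured as splitting a long rollout into consecutive windows. The sketch below assumes exactly that and uses the values from the example config (7300 and 50). The helper `step_windows` is hypothetical, and whether the library accepts a shorter final window when the two values do not divide evenly is not stated here.

```python
def step_windows(n_forward_steps, forward_steps_in_memory):
    """Yield (start, stop) step ranges of at most forward_steps_in_memory steps.

    Hypothetical helper illustrating windowed rollouts, not the fme API.
    """
    for start in range(0, n_forward_steps, forward_steps_in_memory):
        yield start, min(start + forward_steps_in_memory, n_forward_steps)


windows = list(step_windows(7300, 50))
print(len(windows), windows[0], windows[-1])
```

A larger forward_steps_in_memory means fewer reads from disk at the cost of holding more steps in memory at once.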
- class fme.ace.CopyWeightsConfig(include: typing.List[str] = <factory>, exclude: typing.List[str] = <factory>)
Bases: object
Configuration for copying weights from a base model to a target model.
Used during training to overwrite weights after every batch of data, to have the effect of “freezing” the overwritten weights. When the target parameters have longer dimensions than the base model, only the initial slice is overwritten.
This makes it possible to freeze only the subset of each weight that comes from a smaller base weight. It is less efficient than true parameter freezing, but layer freezing is all-or-nothing for each parameter.
All parameters must be covered by either the include or exclude list, but not both.
- include
list of wildcard patterns to overwrite
- Type:
List[str]
- exclude
list of wildcard patterns to exclude from overwriting
- Type:
List[str]
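The overwrite behaviour described above can be sketched with plain Python lists standing in for tensors. Everything here is illustrative: the function `copy_weights`, the parameter names, and the use of `fnmatch` for the wildcard patterns are assumptions, not the fme implementation.

```python
from fnmatch import fnmatch


def copy_weights(base, target, include, exclude):
    """Overwrite the initial slice of matching target weights with base weights.

    Hypothetical sketch: 1-D lists stand in for tensors, and fnmatch is an
    assumed interpretation of "wildcard patterns".
    """
    for name in base:
        in_include = any(fnmatch(name, p) for p in include)
        in_exclude = any(fnmatch(name, p) for p in exclude)
        if in_include == in_exclude:
            raise ValueError(f"{name} must match exactly one of include/exclude")
        if in_include:
            n = len(base[name])
            target[name][:n] = base[name]  # only the initial slice is overwritten
    return target


base = {"encoder.weight": [1.0, 2.0], "decoder.weight": [9.0]}
target = {"encoder.weight": [0.5, 0.5, 0.5], "decoder.weight": [0.1, 0.1]}
copy_weights(base, target, include=["encoder.*"], exclude=["decoder.*"])
print(target)
```

Calling this after every batch resets the first two entries of the encoder weight to the base values, while the third entry and the excluded decoder weight remain free to train.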
- class fme.ace.EMAConfig(decay: float = 0.9999)
Bases: object
Configuration for exponential moving average of model weights.
- decay
decay rate for the moving average
- Type:
float
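For orientation, the textbook exponential moving average update is ema = decay * ema + (1 - decay) * weight. The sketch below implements that form on plain lists. The library's implementation may differ (for example by adjusting the decay early in training), so this is only an illustration of what the decay parameter controls, and `ema_update` is a hypothetical name.

```python
def ema_update(ema_weights, weights, decay=0.9999):
    """One textbook EMA step on lists of floats (hypothetical helper)."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


# A small decay is used here so the effect is visible in three steps.
ema = [0.0]
for _ in range(3):
    ema = ema_update(ema, [1.0], decay=0.5)
print(ema)
```

With the default decay of 0.9999 the averaged weights change far more slowly, smoothing over many thousands of updates.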
- class fme.ace.Slice(start: int | None = None, stop: int | None = None, step: int | None = None)
Bases: object
Configuration of a Python slice built-in.
Required because slice cannot be initialized directly by dacite.
- start
Start index of the slice.
- Type:
int | None
- stop
Stop index of the slice.
- Type:
int | None
- step
Step of the slice.
- Type:
int | None
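A dataclass with these three fields converts naturally to the built-in `slice`. The sketch below uses a hypothetical stand-in, `ToySlice`, and applies it to zero-based epoch numbers to pick every tenth epoch. How fme applies `Slice` to epochs beyond what is described for the fields that use it is an assumption here.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class ToySlice:
    """Hypothetical stand-in mirroring the start/stop/step fields above."""

    start: Optional[int] = None
    stop: Optional[int] = None
    step: Optional[int] = None

    def as_slice(self):
        return slice(self.start, self.stop, self.step)


max_epochs = 80
every_10 = ToySlice(start=9, step=10)
selected = list(range(max_epochs))[every_10.as_slice()]
print(selected)
```

Because plain ints and None values are easy to express in YAML, a wrapper like this lets a config file describe a slice without any custom parsing.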