Training Config
The following is an example configuration for running training while evaluating against target data. While you can use absolute paths in the config YAML files (we encourage it!), the example uses paths relative to the directory from which you run the command. The example assumes separate training and validation datasets, e.g., a directory structure like:
.
├── ckpt.tar
├── validation
│   ├── data1.nc  # files can have any name, but must sort into time-sequential order
│   ├── data2.nc  # can have any number of netCDF files
│   └── ...
└── traindata
    ├── ic_0001
    │   ├── data1.nc  # files can have any name, but must sort into time-sequential order
    │   ├── data2.nc  # can have any number of netCDF files
    │   └── ...
    ├── ic_0002
    │   └── ...
    ├── ...
    └── ic_0010
        └── ...
You can flexibly modify the example to run on data organized in a different manner, and change the data paths as you wish.
A 1-year (1940) subsample of single-ensemble-member data is available on the ACE2-ERA5 Hugging Face page.
In that example dataset, the .nc files correspond to files like training_validation_data/training_validation/1940010100.nc, and ckpt.tar corresponds to ace2_era5_ckpt.tar.
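Because the files in each directory must sort into time-sequential order, the naming scheme matters. The short sketch below assumes the ordering is a plain lexicographic (string) sort of file names, which is an assumption rather than a documented detail of the loader; numeric_suffix and is_time_sequential are hypothetical helpers, not part of fme. It shows why zero-padded timestamp names such as 1940010100.nc are safe while unpadded counters are not.

```python
import re


def numeric_suffix(name: str) -> int:
    """Extract the trailing integer from a name like 'data12.nc' (hypothetical helper)."""
    match = re.search(r"(\d+)\.nc$", name)
    if match is None:
        raise ValueError(f"no numeric suffix in {name!r}")
    return int(match.group(1))


def is_time_sequential(names: list[str]) -> bool:
    """True if a string sort of `names` matches their numeric (assumed time) order."""
    return sorted(names) == sorted(names, key=numeric_suffix)


# Zero-padded timestamps (YYYYMMDDHH) sort correctly as strings.
padded = ["1940010106.nc", "1940010100.nc", "1940010200.nc"]
print(is_time_sequential(padded))  # True

# Unpadded counters do not: "data10.nc" sorts before "data2.nc".
unpadded = ["data1.nc", "data2.nc", "data10.nc"]
print(sorted(unpadded))  # ['data1.nc', 'data10.nc', 'data2.nc']
print(is_time_sequential(unpadded))  # False
```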
experiment_dir: train_output
save_checkpoint: true
validate_using_ema: true
max_epochs: 80
n_forward_steps: 1
inference:
  n_forward_steps: 7300  # 5 years
  forward_steps_in_memory: 50
  loader:
    start_indices:
      first: 0
      n_initial_conditions: 4
      interval: 1460  # 1 year
    dataset:
      data_path: validation
    num_data_workers: 4
logging:
  log_to_screen: true
  log_to_wandb: false
  log_to_file: true
  project: ace
  entity: your_wandb_entity
train_loader:
  batch_size: 4
  num_data_workers: 4
  prefetch_factor: 4
  dataset:
    concat:
      - data_path: traindata/ic_0001
      - data_path: traindata/ic_0002
      - data_path: traindata/ic_0003
      - data_path: traindata/ic_0004
      - data_path: traindata/ic_0005
      - data_path: traindata/ic_0006
      - data_path: traindata/ic_0007
      - data_path: traindata/ic_0008
      - data_path: traindata/ic_0009
      - data_path: traindata/ic_0010
validation_loader:
  batch_size: 16
  num_data_workers: 4
  prefetch_factor: 4
  dataset:
    data_path: validation
    subset:
      step: 5
optimization:
  enable_automatic_mixed_precision: false
  lr: 0.0001
  optimizer_type: AdamW
  # when training on GPU, you can also set `kwargs: {fused: true}` for better performance
stepper:
  loss:
    type: MSE
  step:
    type: single_module
    config:
      builder:
        type: SphericalFourierNeuralOperatorNet
        config:
          embed_dim: 384
          filter_type: linear
          hard_thresholding_fraction: 1.0
          use_mlp: true
          normalization_layer: instance_norm
          num_layers: 8
          operator_type: dhconv
          scale_factor: 1
          separable: false
      normalization:
        network:
          global_means_path: centering.nc
          global_stds_path: scaling-full-field.nc
        loss:
          global_means_path: centering.nc
          global_stds_path: scaling-residual.nc
      ocean:
        surface_temperature_name: surface_temperature
        ocean_fraction_name: ocean_fraction
      corrector:
        conserve_dry_air: true
        moisture_budget_correction: advection_and_precipitation
      in_names:
        - land_fraction
        - ocean_fraction
        - sea_ice_fraction
        - DSWRFtoa
        - HGTsfc
        - PRESsfc
        - surface_temperature
        - air_temperature_0  # _0 denotes the topmost layer of the atmosphere
        - air_temperature_1
        - air_temperature_2
        - air_temperature_3
        - air_temperature_4
        - air_temperature_5
        - air_temperature_6
        - air_temperature_7
        - specific_total_water_0
        - specific_total_water_1
        - specific_total_water_2
        - specific_total_water_3
        - specific_total_water_4
        - specific_total_water_5
        - specific_total_water_6
        - specific_total_water_7
        - eastward_wind_0
        - eastward_wind_1
        - eastward_wind_2
        - eastward_wind_3
        - eastward_wind_4
        - eastward_wind_5
        - eastward_wind_6
        - eastward_wind_7
        - northward_wind_0
        - northward_wind_1
        - northward_wind_2
        - northward_wind_3
        - northward_wind_4
        - northward_wind_5
        - northward_wind_6
        - northward_wind_7
      out_names:
        - PRESsfc
        - surface_temperature
        - air_temperature_0
        - air_temperature_1
        - air_temperature_2
        - air_temperature_3
        - air_temperature_4
        - air_temperature_5
        - air_temperature_6
        - air_temperature_7
        - specific_total_water_0
        - specific_total_water_1
        - specific_total_water_2
        - specific_total_water_3
        - specific_total_water_4
        - specific_total_water_5
        - specific_total_water_6
        - specific_total_water_7
        - eastward_wind_0
        - eastward_wind_1
        - eastward_wind_2
        - eastward_wind_3
        - eastward_wind_4
        - eastward_wind_5
        - eastward_wind_6
        - eastward_wind_7
        - northward_wind_0
        - northward_wind_1
        - northward_wind_2
        - northward_wind_3
        - northward_wind_4
        - northward_wind_5
        - northward_wind_6
        - northward_wind_7
        - LHTFLsfc
        - SHTFLsfc
        - PRATEsfc
        - ULWRFsfc
        - ULWRFtoa
        - DLWRFsfc
        - DSWRFsfc
        - USWRFsfc
        - USWRFtoa
        - tendency_of_total_water_path_due_to_advection
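The level-indexed variable lists in this config are repetitive, so a short script can generate them and make each variable's role explicit. This is a minimal sketch that assumes the usual ACE convention: names only in in_names act as forcings, names in both lists are prognostic, and names only in out_names are diagnostic. The helper level_names is hypothetical; the generated lists could be pasted into the YAML.

```python
N_LEVELS = 8  # e.g. air_temperature_0 ... air_temperature_7; _0 is the topmost layer
LEVEL_VARIABLES = ["air_temperature", "specific_total_water", "eastward_wind", "northward_wind"]


def level_names(prefix: str, n_levels: int = N_LEVELS) -> list[str]:
    """Return [f'{prefix}_0', ..., f'{prefix}_{n_levels - 1}'] (hypothetical helper)."""
    return [f"{prefix}_{i}" for i in range(n_levels)]


layered = [name for var in LEVEL_VARIABLES for name in level_names(var)]

in_names = [
    "land_fraction", "ocean_fraction", "sea_ice_fraction", "DSWRFtoa", "HGTsfc",
    "PRESsfc", "surface_temperature",
] + layered

out_names = ["PRESsfc", "surface_temperature"] + layered + [
    "LHTFLsfc", "SHTFLsfc", "PRATEsfc", "ULWRFsfc", "ULWRFtoa",
    "DLWRFsfc", "DSWRFsfc", "USWRFsfc", "USWRFtoa",
    "tendency_of_total_water_path_due_to_advection",
]

# Classify variables by which list(s) they appear in, preserving list order.
forcings = [n for n in in_names if n not in out_names]
prognostic = [n for n in in_names if n in out_names]
diagnostic = [n for n in out_names if n not in in_names]

print(len(in_names), len(out_names))  # 39 44
print(forcings)  # ['land_fraction', 'ocean_fraction', 'sea_ice_fraction', 'DSWRFtoa', 'HGTsfc']
print(len(prognostic), len(diagnostic))  # 34 10
```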
We use the Builder pattern to load this configuration into a multi-level dataclass structure.
The configuration is divided into several sub-configurations, each with its own dataclass.
The top-level configuration is the fme.ace.TrainConfig class.
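To illustrate how a nested YAML mapping becomes a multi-level dataclass structure, here is a minimal, standard-library-only sketch. The two dataclasses are simplified, hypothetical stand-ins for the fme classes (a few fields each), and from_dict is a toy loader, not the one fme uses.

```python
import dataclasses
import typing
from typing import Any, Optional


@dataclasses.dataclass
class LoggingConfig:  # hypothetical, trimmed stand-in
    project: str = "ace"
    log_to_wandb: bool = True


@dataclasses.dataclass
class TrainConfig:  # hypothetical, trimmed stand-in
    experiment_dir: str
    max_epochs: int
    logging: LoggingConfig
    seed: Optional[int] = None


def from_dict(cls: type, data: dict[str, Any]) -> Any:
    """Recursively build dataclass `cls` from a nested dict, rejecting unknown keys."""
    hints = typing.get_type_hints(cls)
    unknown = set(data) - {f.name for f in dataclasses.fields(cls)}
    if unknown:
        raise ValueError(f"unknown keys for {cls.__name__}: {sorted(unknown)}")
    kwargs = {}
    for name, value in data.items():
        field_type = hints[name]
        if dataclasses.is_dataclass(field_type) and isinstance(value, dict):
            value = from_dict(field_type, value)  # descend into the sub-configuration
        kwargs[name] = value
    return cls(**kwargs)  # missing required fields raise TypeError here


# In practice this dict would come from parsing the YAML file (e.g. with PyYAML).
raw = {
    "experiment_dir": "train_output",
    "max_epochs": 80,
    "logging": {"project": "ace", "log_to_wandb": False},
}
config = from_dict(TrainConfig, raw)
print(config.logging.log_to_wandb)  # False
```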
- class fme.ace.TrainConfig(train_loader, validation_loader, stepper, optimization, logging, max_epochs, save_checkpoint, experiment_dir, inference, n_forward_steps=None, seed=None, copy_weights_after_batch=<factory>, ema=<factory>, weather_evaluation=None, validate_using_ema=False, checkpoint_save_epochs=None, ema_checkpoint_save_epochs=None, log_train_every_n_batches=100, checkpoint_every_n_batches=1000, segment_epochs=None, save_per_epoch_diagnostics=False, validation_aggregator=<factory>, evaluate_before_training=False, save_best_inference_epoch_checkpoints=False, resume_results=None)
Bases: object
Configuration for training a model.
- Parameters:
  - train_loader (DataLoaderConfig) – Configuration for the training data loader.
  - validation_loader (DataLoaderConfig) – Configuration for the validation data loader.
  - stepper (StepperConfig) – Configuration for the stepper.
  - optimization (OptimizationConfig) – Configuration for the optimization.
  - logging (LoggingConfig) – Configuration for logging.
  - max_epochs (int) – Total number of epochs to train for.
  - save_checkpoint (bool) – Whether to save checkpoints.
  - experiment_dir (str) – Directory where checkpoints and logs are saved.
  - inference (Optional[InlineInferenceConfig]) – Configuration for inline inference. If None, no inline inference is run, and no “best_inline_inference” checkpoint will be saved.
  - weather_evaluation (Optional[WeatherEvaluationConfig], default: None) – Configuration for weather evaluation. If None, no weather evaluation is run. Weather evaluation is not used to select checkpoints; it only provides metrics.
  - n_forward_steps (Optional[int], default: None) – Number of forward steps during training. Cannot be given at the same time as train_n_forward_steps in StepperConfig.
  - seed (Optional[int], default: None) – Random seed for reproducibility. If set, it is used for all types of randomization, including data shuffling and model initialization. If unset, weight initialization is not reproducible, but data shuffling is.
  - copy_weights_after_batch (list[CopyWeightsConfig], default: <factory>) – Configuration for copying weights from the base model to the training model after each batch.
  - ema (EMAConfig, default: <factory>) – Configuration for the exponential moving average of model weights.
  - validate_using_ema (bool, default: False) – Whether to validate and perform inference using the EMA model.
  - checkpoint_save_epochs (Optional[Slice], default: None) – How often to save epoch-based checkpoints, if save_checkpoint is True. If None, checkpoints are only saved for the most recent epoch (and the best epochs if validate_using_ema == False).
  - ema_checkpoint_save_epochs (Optional[Slice], default: None) – How often to save epoch-based EMA checkpoints, if save_checkpoint is True. If None, EMA checkpoints are only saved for the most recent epoch (and the best epochs if validate_using_ema == True).
  - log_train_every_n_batches (int, default: 100) – How often to log batch_loss during training.
  - checkpoint_every_n_batches (int, default: 1000) – How often to save the latest checkpoint during training. If 0 is given, checkpoints are not saved based on batch progress, only on other factors such as pre-emption or reaching the end of an epoch.
  - segment_epochs (Optional[int], default: None) – Exit after training for at most this many epochs in the current job, without exceeding max_epochs. Use this if training must be run in segments, e.g., due to a wall-clock limit.
  - save_per_epoch_diagnostics (bool, default: False) – Whether to save per-epoch diagnostics from the training, validation and inline inference aggregators.
  - validation_aggregator (OneStepAggregatorConfig, default: <factory>) – Configuration for the validation aggregator.
  - evaluate_before_training (bool, default: False) – Whether to run validation and inline inference before any training is done.
  - save_best_inference_epoch_checkpoints (bool, default: False) – Whether to save a separate checkpoint for each epoch where best_inference_error achieves a new minimum. Checkpoints are saved as best_inference_ckpt_XXXX.tar.
  - resume_results (Optional[ResumeResultsConfig], default: None) – Configuration for resuming a previously stopped or finished training job. When provided and experiment_dir has no training_checkpoints subdirectory, the run is assumed to be a new run resuming a previously completed one, and resume_results.existing_dir is recursively copied to experiment_dir.
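The interaction between segment_epochs and max_epochs can be sketched with a small helper. This is a hypothetical function written under the assumption that each job resumes from the number of epochs already completed; it illustrates the documented rule ("at most this many epochs, without exceeding max_epochs") and is not fme's implementation.

```python
from typing import Optional


def last_epoch_this_job(
    completed_epochs: int, max_epochs: int, segment_epochs: Optional[int]
) -> int:
    """Epoch count at which the current job stops (hypothetical helper).

    The job trains for at most `segment_epochs` epochs but never past
    `max_epochs`; with segment_epochs=None it trains to the end.
    """
    if segment_epochs is None:
        return max_epochs
    return min(max_epochs, completed_epochs + segment_epochs)


# max_epochs: 80 as in the example config, run as a chain of 30-epoch jobs.
stops = []
completed = 0
while completed < 80:
    completed = last_epoch_this_job(completed, max_epochs=80, segment_epochs=30)
    stops.append(completed)
print(stops)  # [30, 60, 80]
```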
The top-level sub-configurations are:
- class fme.ace.DataLoaderConfig(dataset, batch_size, num_data_workers=0, prefetch_factor=None, augmentation=<factory>, sample_with_replacement=None, time_buffer=0)
Bases: object
Configuration for a data loader for training/validation.
- Parameters:
  - dataset (ConcatDatasetConfig | MergeDatasetConfig | XarrayDataConfig | Sequence[XarrayDataConfig]) – Either a single dataset configuration, a sequence of datasets to be concatenated using the keyword concat, or datasets from different sources to be merged using the keyword merge. For backwards compatibility, it can also be a plain sequence of datasets, which will be concatenated. During a merge, if multiple datasets contain the same data variable, the version from the first source is loaded and the other sources are ignored.
  - batch_size (int) – Number of samples per batch.
  - num_data_workers (int, default: 0) – Number of parallel workers to use for data loading.
  - prefetch_factor (Optional[int], default: None) – How many batches a single data worker will attempt to hold in host memory at a given time.
  - augmentation (AugmentationConfig, default: <factory>) – Configuration for data augmentation.
  - sample_with_replacement (Optional[int], default: None) – If provided, the dataset will be sampled randomly with replacement to the given size each period, instead of retrieving each sample once (either shuffled or not).
  - time_buffer (int, default: 0) – How many more contiguous timesteps to load into memory than the number of timesteps required for a single batch. Setting this to a value greater than 0 should improve data loading performance; however, it also decreases the independence of subsequent batches if shuffled batches are desired.
Note
Setting time_buffer to a value greater than 0 results in pre-loading samples of length time_buffer + n_timesteps_required, where n_timesteps_required is the number of timesteps required for training the model (initial condition(s) plus forward step(s)). These pre-loaded samples become windows from which samples of the required length are drawn without replacement. The windows overlap so that no samples are skipped, except those in the last window, which is dropped if incomplete. This improves data loading throughput and reduces the number of reads. The dataset must contain enough pre-loaded samples to produce at least one batch at the configured batch size. Each pre-loaded window contains time_buffer + 1 samples, so independent data is seen only every time_buffer + 1 batches.
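A minimal sketch of the windowing described in the note, assuming timesteps are indexed from 0 and windows are laid out from the start of the dataset. preload_windows and samples_in_window are hypothetical helpers, not fme functions, and the sketch only shows which samples share a window, not how they are shuffled or batched.

```python
def preload_windows(n_timesteps: int, n_required: int, time_buffer: int) -> list[tuple[int, int]]:
    """(start, stop) timestep ranges of each pre-loaded window (hypothetical helper).

    Each window holds time_buffer + n_required timesteps, so it contains
    time_buffer + 1 samples of length n_required. Windows start
    time_buffer + 1 apart (overlapping by n_required - 1 timesteps), so no
    sample start is skipped; an incomplete final window is dropped.
    """
    window_len = time_buffer + n_required
    stride = time_buffer + 1
    windows = []
    start = 0
    while start + window_len <= n_timesteps:
        windows.append((start, start + window_len))
        start += stride
    return windows


def samples_in_window(start: int, stop: int, n_required: int) -> list[tuple[int, int]]:
    """All length-n_required samples contained in one window."""
    return [(s, s + n_required) for s in range(start, stop - n_required + 1)]


# One initial condition plus one forward step gives 2 timesteps per sample.
windows = preload_windows(n_timesteps=20, n_required=2, time_buffer=3)
print(windows)  # [(0, 5), (4, 9), (8, 13), (12, 17)]
print(samples_in_window(0, 5, n_required=2))  # [(0, 2), (1, 3), (2, 4), (3, 5)]
```

Here the trailing timesteps 17 to 19 are never used because the fifth window (16, 21) would be incomplete.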
- class fme.ace.StepperConfig(step, loss=<factory>, optimize_last_step_only=False, n_ensemble=-1, crps_training=False, parameter_init=<factory>, input_masking=None, train_n_forward_steps=None)
Bases: object
Configuration for a stepper.
- Parameters:
  - step (StepSelector) – The step configuration.
  - loss (StepLossConfig, default: <factory>) – The loss configuration.
  - optimize_last_step_only (bool, default: False) – Whether to optimize only the last step.
  - n_ensemble (int, default: -1) – The number of ensemble members evaluated for each training batch member. If left at the default of -1, this is 2 if the loss type is EnsembleLoss and 1 otherwise. Must be 2 for EnsembleLoss to be valid.
  - crps_training (bool, default: False) – Deprecated, kept for backwards compatibility. Use n_ensemble=2 with a CRPS loss instead.
  - parameter_init (ParameterInitializationConfig, default: <factory>) – The parameter initialization configuration.
  - input_masking (Optional[StaticMaskingConfig], default: None) – Configuration for masking step inputs.
  - train_n_forward_steps (UnionType[TimeLengthProbabilities, int, None], default: None) – The number of timesteps to train on and the associated sampling probabilities. By default, the stepper trains on the full number of timesteps present in the training dataset samples. Values must be less than or equal to the number of timesteps present in the training dataset samples.
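The n_ensemble defaulting rule can be written out as a few lines of logic. This is a hypothetical helper that restates the documented behaviour; the literal string "EnsembleLoss" used to identify the loss type is an assumption for illustration, not necessarily the value fme checks.

```python
def resolve_n_ensemble(n_ensemble: int, loss_type: str) -> int:
    """Resolve StepperConfig.n_ensemble as documented (hypothetical helper).

    -1 (the default) selects 2 for an ensemble loss and 1 otherwise;
    an ensemble loss is only valid with exactly 2 members.
    """
    is_ensemble_loss = loss_type == "EnsembleLoss"  # assumed identifier
    if n_ensemble == -1:
        n_ensemble = 2 if is_ensemble_loss else 1
    if is_ensemble_loss and n_ensemble != 2:
        raise ValueError("EnsembleLoss requires n_ensemble == 2")
    return n_ensemble


print(resolve_n_ensemble(-1, "MSE"))  # 1
print(resolve_n_ensemble(-1, "EnsembleLoss"))  # 2
```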
- class fme.ace.StepSelector(type, config)
Bases: StepConfigABC
- class fme.ace.OptimizationConfig(optimizer_type='Adam', lr=0.001, kwargs=<factory>, enable_automatic_mixed_precision=False, scheduler=<factory>, use_gradient_accumulation=False, checkpoint=<factory>)
Bases: object
Configuration for optimization.
- Parameters:
  - optimizer_type (Literal['Adam', 'AdamW', 'FusedAdam'], default: 'Adam') – The type of optimizer to use.
  - lr (float, default: 0.001) – The learning rate.
  - kwargs (Mapping[str, Any], default: <factory>) – Additional keyword arguments to pass to the optimizer.
  - enable_automatic_mixed_precision (bool, default: False) – Whether to use automatic mixed precision.
  - scheduler (SchedulerConfig | SequentialSchedulerConfig, default: <factory>) – The type of scheduler to use. If none is given, no scheduler is used.
  - use_gradient_accumulation (bool, default: False) – Whether to use gradient accumulation. This must be supported by the stepper being optimized, which may accumulate gradients from separate losses to reduce memory consumption. The stepper may choose to accumulate gradients differently when this is enabled, for example by detaching the computational graph between steps. See the documentation of your stepper (e.g., Stepper) for more details.
  - checkpoint (CheckpointConfig, default: <factory>)
- class fme.ace.LoggingConfig(project='ace', entity='ai2cm', log_to_screen=True, log_to_file=True, log_to_wandb=True, log_format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=20, wandb_dir_in_experiment_dir=False)
Bases: object
Configuration for logging.
- Parameters:
  - project (str, default: 'ace') – Name of the project in Weights & Biases.
  - entity (str, default: 'ai2cm') – Name of the entity in Weights & Biases.
  - log_to_screen (bool, default: True) – Whether to log to the screen.
  - log_to_file (bool, default: True) – Whether to log to a file.
  - log_to_wandb (bool, default: True) – Whether to log to Weights & Biases.
  - log_format (str, default: '%(asctime)s - %(name)s - %(levelname)s - %(message)s') – Format of the log messages.
  - wandb_dir_in_experiment_dir (bool, default: False) – Whether to create the wandb_dir in the experiment_dir or in local /tmp (default False).
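The log_format default is a standard Python logging format string, and the level default of 20 equals logging.INFO. A minimal sketch of how such options could map onto the standard library; configure_logging, its log_path argument and the logger name "fme_example" are hypothetical, and fme's own setup (including Weights & Biases logging) is not shown.

```python
import logging
import sys
from typing import Optional

LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"  # LoggingConfig default


def configure_logging(
    log_to_screen: bool = True,
    log_to_file: bool = True,
    log_path: Optional[str] = None,  # hypothetical: where a file log would go
    log_format: str = LOG_FORMAT,
    level: int = logging.INFO,  # == 20, the LoggingConfig default
) -> logging.Logger:
    """Map LoggingConfig-style options onto stdlib logging (hypothetical helper)."""
    logger = logging.getLogger("fme_example")
    logger.setLevel(level)
    logger.handlers.clear()  # avoid duplicate handlers if called twice
    formatter = logging.Formatter(log_format)
    if log_to_screen:
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    if log_to_file and log_path is not None:
        file_handler = logging.FileHandler(log_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger


logger = configure_logging(log_to_file=False)
logger.info("training started")  # printed with timestamp, logger name and level
logger.debug("suppressed: DEBUG (10) is below INFO (20)")
```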
- class fme.ace.InlineInferenceConfig(loader, n_forward_steps=2, forward_steps_in_memory=2, epochs=<factory>, aggregator=<factory>)
Bases: object
- Parameters:
  - loader (InferenceDataLoaderConfig) – Configuration for the data loader used during inference.
  - n_forward_steps (int, default: 2) – Number of forward steps to take.
  - forward_steps_in_memory (int, default: 2) – Number of forward steps to take before re-reading data from disk.
  - epochs (Slice, default: <factory>) – Epochs on which to run inference. By default, inference runs every epoch.
  - aggregator (InferenceEvaluatorAggregatorConfig, default: <factory>) – Configuration of the inline inference aggregator.
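The inference values in the example config can be checked with some quick arithmetic. The config comments state that 1460 steps is 1 year and 7300 steps is 5 years, which implies 4 steps per day (a 6-hour timestep). This sketch assumes that initial condition i starts at first + i * interval and that a rollout needs its initial condition plus n_forward_steps further timesteps; both helpers are hypothetical.

```python
def initial_condition_indices(first: int, n_initial_conditions: int, interval: int) -> list[int]:
    """Assumed start index of each inference initial condition (hypothetical helper)."""
    return [first + i * interval for i in range(n_initial_conditions)]


def memory_chunks(n_forward_steps: int, forward_steps_in_memory: int) -> list[range]:
    """Split a rollout into chunks of at most forward_steps_in_memory steps."""
    return [
        range(start, min(start + forward_steps_in_memory, n_forward_steps))
        for start in range(0, n_forward_steps, forward_steps_in_memory)
    ]


STEPS_PER_YEAR = 4 * 365  # 1460, matching the "# 1 year" comment
n_forward_steps = 5 * STEPS_PER_YEAR  # 7300, matching the "# 5 years" comment

starts = initial_condition_indices(first=0, n_initial_conditions=4, interval=STEPS_PER_YEAR)
chunks = memory_chunks(n_forward_steps, forward_steps_in_memory=50)
print(starts)  # [0, 1460, 2920, 4380]
print(len(chunks))  # 146

# Assumed minimum length of the inference dataset for the last rollout to fit.
min_timesteps = starts[-1] + n_forward_steps + 1
print(min_timesteps)  # 11681
```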
- class fme.ace.CopyWeightsConfig(include=<factory>, exclude=<factory>)
Bases: object
Configuration for copying weights from a base model to a target model.
Used during training to overwrite weights after every batch of data, to have the effect of “freezing” the overwritten weights. When the target parameters have longer dimensions than the base model, only the initial slice is overwritten.
This achieves an effect similar to freezing model parameters, but it can freeze just the subset of each weight that comes from a smaller base weight. It is less efficient than true parameter freezing, but layer freezing is all-or-nothing for each parameter.
All parameters must be covered by either the include or exclude list, but not both.
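A minimal sketch of the two rules above, using nested lists in place of tensors. copy_initial_slice and check_coverage are hypothetical helpers; in particular, treating include and exclude as glob-style patterns is an assumption, and fme's matching rules may differ.

```python
import fnmatch


def copy_initial_slice(base: list[list[float]], target: list[list[float]]) -> None:
    """Overwrite the leading block of a 2-D `target` weight with `base`, in place.

    `target` may be larger than `base` along each dimension; entries outside
    the base's shape keep their trained values.
    """
    for i, row in enumerate(base):
        target[i][: len(row)] = row


def check_coverage(names: list[str], include: list[str], exclude: list[str]) -> None:
    """Require every parameter name to match exactly one of include/exclude.

    Assumes glob-style patterns (hypothetical).
    """
    for name in names:
        included = any(fnmatch.fnmatchcase(name, p) for p in include)
        excluded = any(fnmatch.fnmatchcase(name, p) for p in exclude)
        if included == excluded:
            raise ValueError(f"{name!r} must be covered by exactly one of include/exclude")


base = [[1.0, 2.0], [3.0, 4.0]]
target = [[9.0, 9.0, 9.0], [9.0, 9.0, 9.0], [9.0, 9.0, 9.0]]
copy_initial_slice(base, target)  # in the described scheme, repeated after every batch
print(target)  # [[1.0, 2.0, 9.0], [3.0, 4.0, 9.0], [9.0, 9.0, 9.0]]
check_coverage(["encoder.weight", "head.weight"], include=["encoder.*"], exclude=["head.*"])
```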
- class fme.ace.EMAConfig(decay=0.9999)
Bases: object
Configuration for exponential moving average of model weights.
- Parameters:
  - decay (float, default: 0.9999) – Decay rate for the moving average.
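To give a feel for the default decay of 0.9999, here is the standard exponential moving average update applied elementwise, assuming one update per optimizer step. This is a generic sketch of the technique; fme's actual EMA implementation may add refinements (for example a warm-up of the decay) that are not shown.

```python
def ema_update(ema: list[float], params: list[float], decay: float = 0.9999) -> list[float]:
    """One EMA step: ema <- decay * ema + (1 - decay) * params, elementwise."""
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema, params)]


# Rough averaging horizon: about 1 / (1 - decay) updates.
decay = 0.9999
print(round(1.0 / (1.0 - decay)))  # 10000

# Starting from 0 and tracking a constant parameter of 1.0, the EMA after
# n updates equals 1 - decay**n; after 10000 updates that is about 1 - 1/e.
ema = [0.0]
for _ in range(10000):
    ema = ema_update(ema, [1.0], decay)
print(round(ema[0], 3))  # 0.632
```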