Training Config
The following is an example configuration for running training while evaluating against target data. While you can use absolute paths in the config YAML files (we encourage it!), the example uses paths relative to the directory from which you run the command. The example assumes separate training and validation datasets, e.g., a directory structure like:
```text
.
├── ckpt.tar
├── centering.nc            # normalization files referenced by the config below
├── scaling-full-field.nc
├── scaling-residual.nc
├── validation
│   ├── data1.nc  # files can have any name, but must sort into time-sequential order
│   ├── data2.nc  # can have any number of netCDF files
│   └── ...
└── traindata
    ├── ic_0001
    │   ├── data1.nc  # files can have any name, but must sort into time-sequential order
    │   ├── data2.nc  # can have any number of netCDF files
    │   └── ...
    ├── ic_0002
    │   └── ...
    ├── ...
    └── ic_0010
        └── ...
```
You can flexibly modify the example to run on data organized in a different manner, and change the data paths as you wish.
A 1-year (1940), single-ensemble-member data subsample is available on the ACE2-ERA5 Hugging Face page.
In that example dataset, the .nc files correspond to files like training_validation_data/training_validation/1940010100.nc, and ckpt.tar corresponds to ace2_era5_ckpt.tar.
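The time-ordering requirement in the tree above matters because file lists are usually sorted as strings. The sketch below assumes the loader sorts filenames lexicographically (an assumption; the loader's sorting rule is not stated here) and shows why timestamp names such as 1940010100.nc are safe while unpadded counters like data10.nc are not. The helper `sorts_chronologically` is hypothetical.

```python
def sorts_chronologically(filenames, key):
    """Return True if plain string sorting matches the order given by `key`.

    Hypothetical helper: `key` maps a filename to its true position in time.
    """
    return sorted(filenames) == sorted(filenames, key=key)


# Timestamp names (YYYYMMDDHH, as in the 1940 sample) sort correctly as strings.
stamped = ["1940010106.nc", "1940010100.nc", "1940010200.nc"]
stamped_ok = sorts_chronologically(stamped, key=lambda f: int(f.split(".")[0]))

# Unpadded counters do not: as strings, "data10.nc" sorts before "data2.nc".
counters = ["data1.nc", "data2.nc", "data10.nc"]
counters_ok = sorts_chronologically(
    counters, key=lambda f: int(f[len("data"):-len(".nc")])
)

print(stamped_ok, counters_ok)  # True False
print(sorted(counters))  # ['data1.nc', 'data10.nc', 'data2.nc']
```

Zero-padding counters (data01.nc, data02.nc, and so on) sidesteps the problem under this assumption.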
```yaml
experiment_dir: train_output
save_checkpoint: true
validate_using_ema: true
max_epochs: 80
inference:
  n_forward_steps: 7300  # 5 years
  forward_steps_in_memory: 50
  loader:
    start_indices:
      first: 0
      n_initial_conditions: 4
      interval: 1460  # 1 year
    dataset:
      data_path: validation
    num_data_workers: 4
logging:
  log_to_screen: true
  log_to_wandb: false
  log_to_file: true
  project: ace
  entity: your_wandb_entity
train_loader:
  batch_size: 4
  num_data_workers: 4
  prefetch_factor: 4
  dataset:
    concat:
      - data_path: traindata/ic_0001
      - data_path: traindata/ic_0002
      - data_path: traindata/ic_0003
      - data_path: traindata/ic_0004
      - data_path: traindata/ic_0005
      - data_path: traindata/ic_0006
      - data_path: traindata/ic_0007
      - data_path: traindata/ic_0008
      - data_path: traindata/ic_0009
      - data_path: traindata/ic_0010
validation_loader:
  batch_size: 16
  num_data_workers: 4
  prefetch_factor: 4
  dataset:
    data_path: validation
    subset:
      step: 5
optimization:
  enable_automatic_mixed_precision: false
  lr: 0.0001
  optimizer_type: AdamW
  # can also set kwargs: {fused: true} for performance if using GPU
stepper_training:
  n_forward_steps: 1
  loss:
    type: MSE
stepper:
  step:
    type: single_module
    config:
      builder:
        type: SphericalFourierNeuralOperatorNet
        config:
          embed_dim: 384
          filter_type: linear
          hard_thresholding_fraction: 1.0
          use_mlp: true
          normalization_layer: instance_norm
          num_layers: 8
          operator_type: dhconv
          scale_factor: 1
          separable: false
      normalization:
        network:
          global_means_path: centering.nc
          global_stds_path: scaling-full-field.nc
        loss:
          global_means_path: centering.nc
          global_stds_path: scaling-residual.nc
      ocean:
        surface_temperature_name: surface_temperature
        ocean_fraction_name: ocean_fraction
      corrector:
        conserve_dry_air: true
        moisture_budget_correction: advection_and_precipitation
      in_names:
        - land_fraction
        - ocean_fraction
        - sea_ice_fraction
        - DSWRFtoa
        - HGTsfc
        - PRESsfc
        - surface_temperature
        - air_temperature_0  # _0 denotes the topmost layer of the atmosphere
        - air_temperature_1
        - air_temperature_2
        - air_temperature_3
        - air_temperature_4
        - air_temperature_5
        - air_temperature_6
        - air_temperature_7
        - specific_total_water_0
        - specific_total_water_1
        - specific_total_water_2
        - specific_total_water_3
        - specific_total_water_4
        - specific_total_water_5
        - specific_total_water_6
        - specific_total_water_7
        - eastward_wind_0
        - eastward_wind_1
        - eastward_wind_2
        - eastward_wind_3
        - eastward_wind_4
        - eastward_wind_5
        - eastward_wind_6
        - eastward_wind_7
        - northward_wind_0
        - northward_wind_1
        - northward_wind_2
        - northward_wind_3
        - northward_wind_4
        - northward_wind_5
        - northward_wind_6
        - northward_wind_7
      out_names:
        - PRESsfc
        - surface_temperature
        - air_temperature_0
        - air_temperature_1
        - air_temperature_2
        - air_temperature_3
        - air_temperature_4
        - air_temperature_5
        - air_temperature_6
        - air_temperature_7
        - specific_total_water_0
        - specific_total_water_1
        - specific_total_water_2
        - specific_total_water_3
        - specific_total_water_4
        - specific_total_water_5
        - specific_total_water_6
        - specific_total_water_7
        - eastward_wind_0
        - eastward_wind_1
        - eastward_wind_2
        - eastward_wind_3
        - eastward_wind_4
        - eastward_wind_5
        - eastward_wind_6
        - eastward_wind_7
        - northward_wind_0
        - northward_wind_1
        - northward_wind_2
        - northward_wind_3
        - northward_wind_4
        - northward_wind_5
        - northward_wind_6
        - northward_wind_7
        - LHTFLsfc
        - SHTFLsfc
        - PRATEsfc
        - ULWRFsfc
        - ULWRFtoa
        - DLWRFsfc
        - DSWRFsfc
        - USWRFsfc
        - USWRFtoa
        - tendency_of_total_water_path_due_to_advection
```
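The step counts in the inference section follow from the data's time resolution. The sketch below recomputes them assuming 6-hourly data and 365-day years (assumptions consistent with the `# 5 years` and `# 1 year` comments, not stated explicitly). It also estimates how many timesteps the validation data must contain, further assuming that `start_indices` places initial conditions at `first + k * interval` and that each run needs one initial-condition timestep plus `n_forward_steps` more.

```python
HOURS_PER_STEP = 6  # assumed time resolution of the data
STEPS_PER_DAY = 24 // HOURS_PER_STEP
STEPS_PER_YEAR = 365 * STEPS_PER_DAY  # ignores leap days

n_forward_steps = 5 * STEPS_PER_YEAR  # inference.n_forward_steps
# inference.loader.start_indices: first, n_initial_conditions, interval
first, n_initial_conditions, interval = 0, 4, STEPS_PER_YEAR

# Assumed layout: evenly spaced initial conditions starting at `first`.
start_indices = [first + k * interval for k in range(n_initial_conditions)]

# The last run reads its initial condition plus n_forward_steps later timesteps.
required_timesteps = start_indices[-1] + n_forward_steps + 1

print(n_forward_steps, interval)  # 7300 1460
print(start_indices)  # [0, 1460, 2920, 4380]
print(required_timesteps, required_timesteps / STEPS_PER_YEAR)
```

Under these assumptions, the validation directory needs roughly eight years of 6-hourly data for inline inference; lowering `n_initial_conditions` or `n_forward_steps` reduces that requirement.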
We use the Builder pattern to load this configuration into a multi-level dataclass structure.
The configuration is divided into several sub-configurations, each with its own dataclass.
The top-level configuration is the fme.ace.TrainConfig class.
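As a rough illustration of loading a nested YAML mapping into nested dataclasses, the sketch below recursively builds a dataclass from a dict. It is not ACE's loader: `from_dict`, `MiniLoggingConfig`, and `MiniTrainConfig` are hypothetical stand-ins for the real classes documented below, and only the standard library is used.

```python
import dataclasses
from typing import Any, get_type_hints


def from_dict(cls, data: dict[str, Any]):
    """Recursively build dataclass `cls` from a nested dict (hypothetical helper)."""
    field_names = {f.name for f in dataclasses.fields(cls)}
    unknown = set(data) - field_names
    if unknown:
        # Reject typos instead of silently ignoring them.
        raise ValueError(f"unknown keys for {cls.__name__}: {sorted(unknown)}")
    hints = get_type_hints(cls)
    kwargs = {}
    for name, value in data.items():
        field_type = hints[name]
        # Nested mappings become nested dataclasses; other values pass through.
        if dataclasses.is_dataclass(field_type) and isinstance(value, dict):
            value = from_dict(field_type, value)
        kwargs[name] = value
    return cls(**kwargs)  # omitted keys fall back to the dataclass defaults


@dataclasses.dataclass
class MiniLoggingConfig:
    project: str = "ace"
    log_to_wandb: bool = True


@dataclasses.dataclass
class MiniTrainConfig:
    experiment_dir: str
    max_epochs: int
    logging: MiniLoggingConfig = dataclasses.field(default_factory=MiniLoggingConfig)


config = from_dict(
    MiniTrainConfig,
    {"experiment_dir": "train_output", "max_epochs": 80, "logging": {"log_to_wandb": False}},
)
print(config)
```

Each YAML section maps onto one dataclass, so unspecified options keep their documented defaults while typos in key names fail early.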
- class fme.ace.TrainConfig(train_loader, validation_loader, stepper, optimization, logging, max_epochs, save_checkpoint, experiment_dir, inference, stepper_training=<factory>, train_aggregator=<factory>, seed=None, copy_weights_after_batch=<factory>, ema=<factory>, additional_inference=<factory>, validate_using_ema=False, checkpoint_save_epochs=None, ema_checkpoint_save_epochs=None, log_train_every_n_batches=100, train_evaluation_samples=1000, checkpoint_every_n_batches=1000, segment_epochs=None, save_per_epoch_diagnostics=False, validation_aggregator=<factory>, evaluate_before_training=False, save_best_inference_epoch_checkpoints=False, lr_tuning=None, resume_results=None)
  Bases: object
  Configuration for training a model.
  Parameters:
  - train_loader (DataLoaderConfig) – Configuration for the training data loader.
  - validation_loader (DataLoaderConfig) – Configuration for the validation data loader.
  - stepper (StepperConfig | CheckpointStepperConfig) – Configuration for the stepper.
  - optimization (OptimizationConfig) – Configuration for the optimization.
  - logging (LoggingConfig) – Configuration for logging.
  - max_epochs (int) – Total number of epochs to train for.
  - save_checkpoint (bool) – Whether to save checkpoints. If false, no checkpoints are saved regardless of other checkpoint configuration settings. If true, checkpoints are saved at the end of the training loop, after evaluation, and on catching a termination signal.
  - experiment_dir (str) – Directory where checkpoints and logs are saved. Currently, this must be a local directory.
  - inference (Optional[InlineInferenceConfig]) – Configuration for inline inference. If None, no inline inference is run, and no "best_inline_inference" checkpoint is saved.
  - additional_inference (list[AdditionalInferenceConfig], default: <factory>) – Configurations for additional inference runs. Each entry has a name (used as the wandb log prefix) and a config. These runs provide metrics but are not used to select checkpoints.
  - stepper_training (TrainStepperConfig, default: <factory>) – Training-specific configuration including loss, ensemble settings, parameter initialization, and forward step scheduling.
  - train_aggregator (TrainAggregatorConfig, default: <factory>) – Configuration for the train aggregator.
  - seed (Optional[int], default: None) – Random seed for reproducibility. If set, it is used for all types of randomization, including data shuffling and model initialization. If unset, weight initialization is not reproducible but data shuffling is.
  - copy_weights_after_batch (list[CopyWeightsConfig], default: <factory>) – Configuration for copying weights from the base model to the training model after each batch.
  - ema (EMAConfig, default: <factory>) – Configuration for exponential moving average of model weights.
  - validate_using_ema (bool, default: False) – Whether to validate and perform inference using the EMA model.
  - checkpoint_save_epochs (Optional[Slice], default: None) – How often to save epoch-based checkpoints, if save_checkpoint is True. If None, checkpoints are only saved for the most recent epoch (and the best epochs if validate_using_ema == False).
  - ema_checkpoint_save_epochs (Optional[Slice], default: None) – How often to save epoch-based EMA checkpoints, if save_checkpoint is True. If None, EMA checkpoints are only saved for the most recent epoch (and the best epochs if validate_using_ema == True).
  - log_train_every_n_batches (int, default: 100) – How often to log batch_loss during training.
  - train_evaluation_samples (int, default: 1000) – Number of samples to evaluate on after training on each epoch. Any remainder left after dividing by the batch size is discarded.
  - checkpoint_every_n_batches (int, default: 1000) – How often to save the latest checkpoint during training. If 0, checkpoints are not saved based on batch progress, only on other triggers such as preemption or the end of an epoch.
  - segment_epochs (Optional[int], default: None) – Exit after training for at most this many epochs in the current job, without exceeding max_epochs. Use this if training must be run in segments, e.g., due to a wall-clock limit.
  - save_per_epoch_diagnostics (bool, default: False) – Whether to save per-epoch diagnostics from the training, validation, and inline inference aggregators.
  - validation_aggregator (OneStepAggregatorConfig, default: <factory>) – Configuration for the validation aggregator.
  - evaluate_before_training (bool, default: False) – Whether to run validation and inline inference before any training is done.
  - save_best_inference_epoch_checkpoints (bool, default: False) – Whether to save a separate checkpoint for each epoch where best_inference_error reaches a new minimum. Checkpoints are saved as best_inference_ckpt_XXXX.tar.
  - resume_results (Optional[ResumeResultsConfig], default: None) – Configuration for resuming a previously stopped or finished training job. When provided and experiment_dir has no training_checkpoints subdirectory, the run is assumed to be a new job resuming a previously completed run, and resume_results.existing_dir is recursively copied to experiment_dir.
  - lr_tuning (LRTuningConfig | None, default: None)
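Two of the parameters above describe simple arithmetic. The sketch below reads the `segment_epochs` and `train_evaluation_samples` descriptions literally; the helper names are hypothetical and the real trainer's bookkeeping may differ in details not documented here.

```python
def epochs_this_job(completed_epochs, max_epochs, segment_epochs=None):
    """Epochs to train in the current job (hypothetical helper).

    Never exceeds max_epochs overall; segment_epochs caps a single job.
    """
    remaining = max(max_epochs - completed_epochs, 0)
    if segment_epochs is None:
        return remaining
    return min(segment_epochs, remaining)


def train_evaluation_batches(train_evaluation_samples, batch_size):
    """Full batches used for post-epoch evaluation; the remainder is dropped."""
    return train_evaluation_samples // batch_size


# max_epochs=80 split into 30-epoch segments: jobs run 30, 30, then 20 epochs.
print([epochs_this_job(c, 80, segment_epochs=30) for c in (0, 30, 60, 80)])  # [30, 30, 20, 0]
# 1000 samples at batch size 16: 62 batches (992 samples), 8 samples discarded.
print(train_evaluation_batches(1000, 16))  # 62
```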
The top-level sub-configurations are:
- class fme.ace.DataLoaderConfig(dataset, batch_size, num_data_workers=0, prefetch_factor=None, augmentation=<factory>, sample_with_replacement=None, time_buffer=0)
  Bases: object
  Configuration for a data loader for training/validation.
  Parameters:
  - dataset (ConcatDatasetConfig | MergeDatasetConfig | XarrayDataConfig) – Either a single dataset configuration, a sequence of datasets to be concatenated using the keyword concat, or datasets from different sources to be merged using the keyword merge.
  - batch_size (int) – Number of samples per batch.
  - num_data_workers (int, default: 0) – Number of parallel workers to use for data loading.
  - prefetch_factor (Optional[int], default: None) – How many batches a single data worker will attempt to hold in host memory at a given time.
  - augmentation (AugmentationConfig, default: <factory>) – Configuration for data augmentation.
  - sample_with_replacement (Optional[int], default: None) – If provided, the dataset is sampled randomly with replacement to the given size each period, instead of retrieving each sample once (either shuffled or not).
  - time_buffer (int, default: 0) – How many more contiguous timesteps to load into memory than a single batch requires. Setting this above 0 should improve data loading performance; however, it also reduces the independence of subsequent batches when shuffled batches are desired.
  Note
  Setting time_buffer to a value greater than 0 results in pre-loading samples of length time_buffer + n_timesteps_required, where n_timesteps_required is the number of timesteps required for training the model (initial condition(s) plus forward step(s)). These pre-loaded samples become a window from which samples of the required length are drawn without replacement. The windows overlap just enough that no samples are skipped, with the exception of the last window, which is dropped if incomplete. This improves data loading throughput and reduces the number of reads. The dataset must contain enough pre-loaded samples to produce at least one batch at the configured batch size. Independent data is seen every time_buffer + 1 batches, since time_buffer + 1 is the number of samples in each pre-loaded window.
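The windowing described in the note can be sketched with index arithmetic alone. This is one reading of the note, not the loader's code: each window holds `time_buffer + n_required` timesteps, consecutive windows start `time_buffer + 1` apart (so they overlap by `n_required - 1` timesteps and no sample start is skipped), and an incomplete final window is dropped. `window_starts` and `sample_starts` are hypothetical helpers.

```python
def window_starts(n_times, n_required, time_buffer):
    """Start index of each pre-loaded window (hypothetical helper)."""
    window_len = time_buffer + n_required
    stride = time_buffer + 1  # consecutive windows overlap by n_required - 1 timesteps
    starts = []
    start = 0
    while start + window_len <= n_times:  # an incomplete last window is dropped
        starts.append(start)
        start += stride
    return starts


def sample_starts(window_start, time_buffer):
    """Start index of every sample that can be drawn from one window."""
    return [window_start + offset for offset in range(time_buffer + 1)]


# 20 timesteps, samples of 2 timesteps (1 initial condition + 1 forward step).
starts = window_starts(n_times=20, n_required=2, time_buffer=3)
print(starts)  # [0, 4, 8, 12]
print([sample_starts(s, time_buffer=3) for s in starts])
```

Each window yields `time_buffer + 1` samples, drawn without replacement in whatever order the loader chooses, which is why independent data appears only every `time_buffer + 1` batches.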
- class fme.ace.StepperConfig(step, input_masking=None, derived_forcings=<factory>)
  Bases: object
  Configuration for a stepper.
  Parameters:
  - step (StepSelector) – The step configuration.
  - input_masking (Optional[StaticMaskingConfig], default: None) – Config for masking step inputs.
  - derived_forcings (DerivedForcingsConfig, default: <factory>) – Configuration for deriving forcing variables.
- class fme.ace.TrainStepperConfig(loss=<factory>, optimize_last_step_only=False, n_ensemble=-1, n_forward_steps=None, parameter_init=<factory>)
  Bases: object
  Configuration for training-specific aspects of a stepper.
  Parameters:
  - loss (StepLossConfig, default: <factory>) – The loss configuration.
  - optimize_last_step_only (bool, default: False) – Whether to optimize only the last step.
  - n_ensemble (int, default: -1) – The number of ensemble members evaluated for each training batch member. If left at -1, it defaults to 2 when the loss type is EnsembleLoss and to 1 otherwise. Must be 2 for EnsembleLoss to be valid.
  - n_forward_steps (TimeLengthProbabilities | int | TimeLengthSchedule | None, default: None) – The number of timesteps to train on and associated sampling probabilities. By default, the stepper trains on the full number of timesteps present in the training dataset samples. Values must be less than or equal to that number.
  - parameter_init (ParameterInitializationConfig, default: <factory>) – The parameter initialization configuration for fine-tuning.
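The `n_ensemble` rule above can be written as a small resolution step. This is a sketch of the documented behavior only: `resolve_n_ensemble` is hypothetical, and the string "EnsembleLoss" is taken from the parameter description; the exact value used in a YAML `loss.type` field may differ.

```python
def resolve_n_ensemble(n_ensemble, loss_type):
    """Resolve the -1 sentinel as described for n_ensemble (hypothetical helper)."""
    if n_ensemble == -1:
        n_ensemble = 2 if loss_type == "EnsembleLoss" else 1
    if loss_type == "EnsembleLoss" and n_ensemble != 2:
        raise ValueError("EnsembleLoss requires n_ensemble == 2")
    return n_ensemble


print(resolve_n_ensemble(-1, "MSE"))  # 1
print(resolve_n_ensemble(-1, "EnsembleLoss"))  # 2
```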
- class fme.ace.StepSelector(type, config)
  Bases: StepConfigABC
- class fme.ace.OptimizationConfig(optimizer_type='Adam', lr=0.001, kwargs=<factory>, enable_automatic_mixed_precision=False, scheduler=<factory>, use_gradient_accumulation=False, checkpoint=<factory>, resume_optimizer_ckpt_path=None)
  Bases: object
  Configuration for optimization.
  Parameters:
  - optimizer_type (Literal['Adam', 'AdamW', 'FusedAdam'], default: 'Adam') – The type of optimizer to use.
  - lr (float, default: 0.001) – The learning rate.
  - kwargs (Mapping[str, Any], default: <factory>) – Additional keyword arguments to pass to the optimizer.
  - enable_automatic_mixed_precision (bool, default: False) – Whether to use automatic mixed precision.
  - scheduler (SchedulerConfig | SequentialSchedulerConfig, default: <factory>) – The type of scheduler to use. If none is given, no scheduler is used.
  - use_gradient_accumulation (bool, default: False) – Whether to use gradient accumulation. This must be supported by the stepper being optimized, which may accumulate gradients from separate losses to reduce memory consumption. The stepper may accumulate gradients differently when this is enabled, such as by detaching the computational graph between steps. See the documentation of your stepper (e.g., Stepper) for more details.
  - resume_optimizer_ckpt_path (Optional[str], default: None) – Optional path to a training checkpoint (ckpt.tar) whose per-parameter optimizer running state (e.g., Adam moment estimates) and grad scaler state should be loaded into the freshly built Optimization for fine-tuning. The current config’s per-group hyperparameters (lr, weight_decay, betas, …) and scheduler are kept; only the running state is transferred. Intended for non-resuming jobs; when resuming after preemption, the Trainer overrides this state via Optimization.load_state.
  - checkpoint (CheckpointConfig)
- class fme.ace.LoggingConfig(project='ace', entity='ai2cm', log_to_screen=True, log_to_file=True, log_to_wandb=True, metrics_log_dir=None, log_format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=20, wandb_dir_in_experiment_dir=False)
  Bases: object
  Configuration for logging.
  Parameters:
  - project (str, default: 'ace') – Name of the project in Weights & Biases.
  - entity (str, default: 'ai2cm') – Name of the entity in Weights & Biases.
  - log_to_screen (bool, default: True) – Whether to log to the screen.
  - log_to_file (bool, default: True) – Whether to log to a file.
  - log_to_wandb (bool, default: True) – Whether to log to Weights & Biases.
  - metrics_log_dir (Optional[str], default: None) – Directory to write scalar metrics to disk as JSONL. If None, disk metric logging is disabled.
  - log_format (str, default: '%(asctime)s - %(name)s - %(levelname)s - %(message)s') – Format of the log messages.
  - wandb_dir_in_experiment_dir (bool, default: False) – Whether to create the wandb_dir in the experiment_dir or in local /tmp (default False).
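The default `log_format` is a standard Python `logging` format string, and the default `level=20` is the numeric value of `logging.INFO`. The sketch below uses only the standard library to show what a record rendered with that format looks like; it does not claim that ACE configures its handlers this way, and the logger name "ace" and message are made up for illustration.

```python
import logging

# Default LoggingConfig.log_format.
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

# LoggingConfig.level defaults to 20, which is logging.INFO.
assert logging.INFO == 20
print(logging.getLevelName(20))  # INFO

formatter = logging.Formatter(LOG_FORMAT)
record = logging.LogRecord(
    name="ace",  # illustrative logger name
    level=20,
    pathname="example.py",
    lineno=0,
    msg="epoch %d done",
    args=(1,),
    exc_info=None,
)
line = formatter.format(record)
print(line)  # a timestamp followed by " - ace - INFO - epoch 1 done"
```

Lower levels such as 10 (`logging.DEBUG`) would show more messages; higher ones such as 30 (`logging.WARNING`) would show fewer.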
- class fme.ace.InlineInferenceConfig(loader, n_forward_steps, forward_steps_in_memory, n_ensemble_per_ic=1, epochs=<factory>, aggregator=<factory>)
  Bases: object
  Parameters:
  - loader (InferenceDataLoaderConfig) – Configuration for the data loader used during inference.
  - n_forward_steps (int) – Number of forward steps to take.
  - forward_steps_in_memory (int) – Number of forward steps to take before re-reading data from disk.
  - n_ensemble_per_ic (int, default: 1) – Number of ensemble members to run from each initial condition.
  - epochs (Slice, default: <factory>) – Epochs on which to run inference. By default, inference runs every epoch.
  - aggregator (InferenceEvaluatorAggregatorConfig, default: <factory>) – Configuration of the inline inference aggregator.
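`forward_steps_in_memory` splits a long inline inference run into chunks between disk reads. The sketch below only illustrates that chunking arithmetic for the example values (7300 steps, 50 per chunk); `forward_step_windows` is a hypothetical helper, and it assumes a shorter final chunk when the step count is not divisible, which the real implementation may handle differently.

```python
def forward_step_windows(n_forward_steps, forward_steps_in_memory):
    """Yield (start, stop) forward-step ranges, one per in-memory chunk."""
    for start in range(0, n_forward_steps, forward_steps_in_memory):
        yield start, min(start + forward_steps_in_memory, n_forward_steps)


windows = list(forward_step_windows(7300, 50))
print(len(windows), windows[0], windows[-1])  # 146 (0, 50) (7250, 7300)
print(list(forward_step_windows(10, 4)))  # [(0, 4), (4, 8), (8, 10)]
```

Larger chunks mean fewer reads but more data held in memory at once.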
- class fme.ace.CopyWeightsConfig(include=<factory>, exclude=None)
  Bases: object
  Configuration for copying weights from a base model to a target model.
  Used during training to overwrite weights after every batch of data, which has the effect of “freezing” the overwritten weights. When the target parameters have larger dimensions than the corresponding base parameters, only the initial slice is overwritten.
  This achieves an effect similar to freezing model parameters, but can freeze just the subset of each weight that comes from a smaller base weight. It is less efficient than true parameter freezing, but true freezing is all-or-nothing for each parameter.
  Parameters:
  - include (list[str], default: <factory>) – List of wildcard patterns to overwrite. If given, only these parameters are overwritten.
  - exclude (Optional[list[str]], default: None) – List of wildcard patterns to exclude from overwriting. If given, all parameters except these are overwritten. Cannot be given together with include.
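The include/exclude selection and the partial overwrite can be sketched as follows. Several assumptions apply: wildcard patterns are treated as shell-style globs via Python's `fnmatch`, parameter names like `encoder.weight` are made up, nested lists stand in for tensors, and the hypothetical helper simply requires exactly one of `include` or `exclude` rather than reproducing the real default.

```python
from fnmatch import fnmatch


def selected_for_overwrite(name, include=None, exclude=None):
    """Decide whether parameter `name` is overwritten (hypothetical helper)."""
    if include is not None and exclude is not None:
        raise ValueError("include and exclude cannot both be given")
    if exclude is not None:
        return not any(fnmatch(name, pattern) for pattern in exclude)
    if include is not None:
        return any(fnmatch(name, pattern) for pattern in include)
    raise ValueError("give either include or exclude")


def copy_initial_slice(base, target):
    """Overwrite the leading slice of nested-list `target` with a smaller `base`."""
    if not isinstance(base, list):
        return base  # scalar leaf: take the base value
    for i, base_item in enumerate(base):
        target[i] = copy_initial_slice(base_item, target[i])
    return target


print(selected_for_overwrite("encoder.weight", include=["encoder.*"]))  # True
print(selected_for_overwrite("decoder.bias", exclude=["decoder.*"]))  # False
print(copy_initial_slice([[1, 2]], [[0, 0, 0], [0, 0, 0]]))  # [[1, 2, 0], [0, 0, 0]]
```

Entries outside the base's shape keep their trained values, which is what lets only part of a larger weight act as frozen.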
- class fme.ace.EMAConfig(decay=0.9999, resume_ema_ckpt_path=None)
  Bases: object
  Configuration for exponential moving average of model weights.
  Parameters:
  - decay (float, default: 0.9999) – Decay rate for the moving average.
  - resume_ema_ckpt_path (Optional[str], default: None) – Optional path to a training checkpoint (ckpt.tar) whose EMA running state (averaged weights and update counter) should be loaded into the freshly built EMATracker for fine-tuning. The current config’s decay is kept; only the running state is transferred. Intended for non-resuming jobs; when resuming after preemption, the Trainer overrides this state via EMATracker.from_state.
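For intuition about `decay`, the sketch below applies the textbook EMA update `ema = decay * ema + (1 - decay) * weights` to plain Python lists, alongside an update counter like the one mentioned above. It uses `decay=0.5` so the arithmetic is easy to follow; `ema_update` is hypothetical, and ACE's EMATracker may differ in details not documented here.

```python
def ema_update(ema_weights, model_weights, decay):
    """One EMA step: ema <- decay * ema + (1 - decay) * weights."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, model_weights)]


ema = [1.0, 2.0]  # start from the current model weights
n_updates = 0  # the update counter stored with the averaged weights
for weights in ([3.0, 4.0], [3.0, 4.0]):  # model weights after two training steps
    ema = ema_update(ema, weights, decay=0.5)
    n_updates += 1

print(ema, n_updates)  # [2.5, 3.5] 2
```

With the default `decay=0.9999`, each update moves the average only 0.01% of the way toward the current weights, so it reflects roughly the last 1 / (1 - decay) = 10,000 updates.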