# Source code for fme.ace.train.train_config

import dataclasses
import os
from collections.abc import Callable, Mapping, Sequence
from typing import Any

import torch

from fme.ace.aggregator import (
    InferenceEvaluatorAggregatorConfig,
    OneStepAggregatorConfig,
)
from fme.ace.aggregator.inference.main import InferenceEvaluatorAggregator
from fme.ace.data_loading.config import DataLoaderConfig
from fme.ace.data_loading.getters import get_gridded_data, get_inference_data
from fme.ace.data_loading.gridded_data import (
    ErrorInferenceData,
    GriddedData,
    InferenceGriddedData,
)
from fme.ace.data_loading.inference import InferenceDataLoaderConfig
from fme.ace.requirements import (
    DataRequirements,
    NullDataRequirements,
    PrognosticStateDataRequirements,
)
from fme.ace.stepper import Stepper
from fme.ace.stepper.single_module import StepperConfig
from fme.core.cli import ResumeResultsConfig
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.dataset_info import DatasetInfo
from fme.core.distributed import Distributed
from fme.core.ema import EMAConfig, EMATracker
from fme.core.generics.trainer import EndOfBatchCallback, EndOfEpochCallback
from fme.core.logging_utils import LoggingConfig
from fme.core.optimization import Optimization, OptimizationConfig
from fme.core.rand import set_seed
from fme.core.typing_ import Slice, TensorDict, TensorMapping
from fme.core.weight_ops import CopyWeightsConfig


@dataclasses.dataclass
class WeatherEvaluationConfig:
    """
    Configuration for weather evaluation performed periodically during training.

    Weather evaluation only produces metrics; it is not used to select
    checkpoints.

    Parameters:
        loader: configuration for the data loader used during weather evaluation.
            Its number of initial conditions must be divisible by the number
            of parallel workers, otherwise construction raises ValueError.
        n_forward_steps: number of forward steps to take
        forward_steps_in_memory: number of forward steps to take before
            re-reading data from disk
        epochs: epochs on which to run weather evaluation. By default runs
            weather evaluation every epoch.
        aggregator: configuration of weather evaluation aggregator. If either
            ``log_global_mean_time_series`` or
            ``log_global_mean_norm_time_series`` is set, both are reset to
            False after construction.
    """

    loader: InferenceDataLoaderConfig
    n_forward_steps: int = 2
    forward_steps_in_memory: int = 2
    epochs: Slice = dataclasses.field(default_factory=Slice)
    aggregator: InferenceEvaluatorAggregatorConfig = dataclasses.field(
        default_factory=lambda: InferenceEvaluatorAggregatorConfig(
            log_global_mean_norm_time_series=False,
            log_global_mean_time_series=False,
        )
    )

    def __post_init__(self):
        world_size = Distributed.get_instance().world_size
        n_ics = self.loader.start_indices.n_initial_conditions
        # Initial conditions are split evenly across workers.
        if n_ics % world_size != 0:
            raise ValueError(
                "Number of inference initial conditions must be divisible by the "
                f"number of parallel workers, got {n_ics} and {world_size}."
            )
        agg = self.aggregator
        # Both global-mean time-series flags must be False for evaluation run
        # during training; they are forced off together instead of erroring.
        if agg.log_global_mean_time_series or agg.log_global_mean_norm_time_series:
            agg.log_global_mean_time_series = False
            agg.log_global_mean_norm_time_series = False

    def get_inference_data(
        self,
        window_requirements: DataRequirements,
        initial_condition: PrognosticStateDataRequirements,
    ) -> InferenceGriddedData:
        """
        Build the dataset used for weather evaluation.

        Args:
            window_requirements: requirements for each in-memory data window.
            initial_condition: requirements for the prognostic initial state.

        Returns:
            Inference data covering ``n_forward_steps`` forward steps.
        """
        return get_inference_data(
            config=self.loader,
            initial_condition=initial_condition,
            total_forward_steps=self.n_forward_steps,
            window_requirements=window_requirements,
        )


@dataclasses.dataclass
class InlineInferenceConfig:
    """
    Configuration for inline inference performed periodically during training.

    Parameters:
        loader: configuration for the data loader used during inference. Its
            number of initial conditions must be divisible by the number of
            parallel workers, otherwise construction raises ValueError.
        n_forward_steps: number of forward steps to take
        forward_steps_in_memory: number of forward steps to take before
            re-reading data from disk
        epochs: epochs on which to run inference. By default runs inference
            every epoch.
        aggregator: configuration of inline inference aggregator. If either
            ``log_global_mean_time_series`` or
            ``log_global_mean_norm_time_series`` is set, both are reset to
            False after construction.
    """

    loader: InferenceDataLoaderConfig
    n_forward_steps: int = 2
    forward_steps_in_memory: int = 2
    epochs: Slice = dataclasses.field(default_factory=Slice)
    aggregator: InferenceEvaluatorAggregatorConfig = dataclasses.field(
        default_factory=lambda: InferenceEvaluatorAggregatorConfig(
            log_global_mean_norm_time_series=False,
            log_global_mean_time_series=False,
        )
    )

    def __post_init__(self):
        world_size = Distributed.get_instance().world_size
        n_ics = self.loader.start_indices.n_initial_conditions
        # Initial conditions are split evenly across workers.
        if n_ics % world_size != 0:
            raise ValueError(
                "Number of inference initial conditions must be divisible by the "
                f"number of parallel workers, got {n_ics} and {world_size}."
            )
        agg = self.aggregator
        # Both global-mean time-series flags must be False for inline
        # inference; they are forced off together instead of erroring.
        if agg.log_global_mean_time_series or agg.log_global_mean_norm_time_series:
            agg.log_global_mean_time_series = False
            agg.log_global_mean_norm_time_series = False

    def get_inference_data(
        self,
        window_requirements: DataRequirements,
        initial_condition: PrognosticStateDataRequirements,
    ) -> InferenceGriddedData:
        """
        Build the dataset used for inline inference.

        Args:
            window_requirements: requirements for each in-memory data window.
            initial_condition: requirements for the prognostic initial state.

        Returns:
            Inference data covering ``n_forward_steps`` forward steps.
        """
        return get_inference_data(
            config=self.loader,
            initial_condition=initial_condition,
            total_forward_steps=self.n_forward_steps,
            window_requirements=window_requirements,
        )
@dataclasses.dataclass
class TrainConfig:
    """
    Configuration for training a model.

    Parameters:
        train_loader: Configuration for the training data loader.
        validation_loader: Configuration for the validation data loader.
        stepper: Configuration for the stepper.
        optimization: Configuration for the optimization.
        logging: Configuration for logging.
        max_epochs: Total number of epochs to train for.
        save_checkpoint: Whether to save checkpoints.
        experiment_dir: Directory where checkpoints and logs are saved.
        inference: Configuration for inline inference. If None, no inline
            inference is run, and no "best_inline_inference" checkpoint will
            be saved.
        n_forward_steps: Number of forward steps during training. Cannot be
            given at the same time as train_n_forward_steps in StepperConfig.
        seed: Random seed for reproducibility. If set, is used for all types
            of randomization, including data shuffling and model
            initialization. If unset, weight initialization is not
            reproducible but data shuffling is.
        copy_weights_after_batch: Configuration for copying weights from the
            base model to the training model after each batch.
        ema: Configuration for exponential moving average of model weights.
        weather_evaluation: Configuration for weather evaluation. If None, no
            weather evaluation is run. Weather evaluation is not used to
            select checkpoints, but is used to provide metrics.
        validate_using_ema: Whether to validate and perform inference using
            the EMA model.
        checkpoint_save_epochs: How often to save epoch-based checkpoints, if
            save_checkpoint is True. If None, checkpoints are only saved for
            the most recent epoch (and the best epochs if
            validate_using_ema == False).
        ema_checkpoint_save_epochs: How often to save epoch-based EMA
            checkpoints, if save_checkpoint is True. If None, EMA checkpoints
            are only saved for the most recent epoch (and the best epochs if
            validate_using_ema == True).
        log_train_every_n_batches: How often to log batch_loss during
            training.
        checkpoint_every_n_batches: How often to save latest checkpoint during
            training. If 0 is given, checkpoints will not be saved based on
            batch progress, only other factors like pre-emption or being at
            the end of an epoch.
        segment_epochs: Exit after training for at most this many epochs in
            current job, without exceeding `max_epochs`. Use this if training
            must be run in segments, e.g. due to wall clock limit.
        save_per_epoch_diagnostics: Whether to save per-epoch diagnostics from
            training, validation and inline inference aggregators.
        validation_aggregator: Configuration for the validation aggregator.
        evaluate_before_training: Whether to run validation and inline
            inference before any training is done.
        save_best_inference_epoch_checkpoints: Whether to save a separate
            checkpoint for each epoch where best_inference_error achieves a
            new minimum. Checkpoints are saved as best_inference_ckpt_XXXX.tar.
        resume_results: Configuration for resuming a previously stopped or
            finished training job. When provided and experiment_dir has no
            training_checkpoints subdirectory, then it is assumed that this is
            a new run to resume a previously completed run and
            resume_results.existing_dir is recursively copied to
            experiment_dir.
    """

    # Required fields (no defaults) must stay ahead of defaulted ones.
    train_loader: DataLoaderConfig
    validation_loader: DataLoaderConfig
    stepper: StepperConfig
    optimization: OptimizationConfig
    logging: LoggingConfig
    max_epochs: int
    save_checkpoint: bool
    experiment_dir: str
    inference: InlineInferenceConfig | None
    # Optional fields.
    n_forward_steps: int | None = None
    seed: int | None = None
    copy_weights_after_batch: list[CopyWeightsConfig] = dataclasses.field(
        default_factory=list
    )
    ema: EMAConfig = dataclasses.field(default_factory=EMAConfig)
    weather_evaluation: WeatherEvaluationConfig | None = None
    validate_using_ema: bool = False
    checkpoint_save_epochs: Slice | None = None
    ema_checkpoint_save_epochs: Slice | None = None
    log_train_every_n_batches: int = 100
    checkpoint_every_n_batches: int = 1000
    segment_epochs: int | None = None
    save_per_epoch_diagnostics: bool = False
    validation_aggregator: OneStepAggregatorConfig = dataclasses.field(
        default_factory=OneStepAggregatorConfig
    )
    evaluate_before_training: bool = False
    save_best_inference_epoch_checkpoints: bool = False
    resume_results: ResumeResultsConfig | None = None

    def __post_init__(self):
        # The training step count may be specified in only one place.
        specified_twice = (
            isinstance(self.stepper, StepperConfig)
            and self.stepper.train_n_forward_steps is not None
            and self.n_forward_steps is not None
        )
        if specified_twice:
            raise ValueError(
                "stepper.train_n_forward_steps may not be given at the same time as "
                "n_forward_steps at the top level"
            )

    def set_random_seed(self):
        """Seed all random number generators when ``seed`` is configured."""
        if self.seed is None:
            return
        set_seed(self.seed)

    @property
    def inference_n_forward_steps(self) -> int:
        """Inline inference forward steps, or 0 when inline inference is off."""
        return 0 if self.inference is None else self.inference.n_forward_steps

    @property
    def inference_aggregator(self) -> InferenceEvaluatorAggregatorConfig | None:
        """Inline inference aggregator config, or None when inference is off."""
        return None if self.inference is None else self.inference.aggregator

    @property
    def checkpoint_dir(self) -> str:
        """
        The directory where checkpoints are saved.
        """
        return os.path.join(self.experiment_dir, "training_checkpoints")

    @property
    def output_dir(self) -> str:
        """
        The directory where output files are saved.
        """
        return os.path.join(self.experiment_dir, "output")

    def get_inference_epochs(self) -> list[int]:
        """
        Epochs on which inline inference runs.

        Candidate epochs run from 1 (or 0 when ``evaluate_before_training``)
        through ``max_epochs`` inclusive, filtered by ``inference.epochs``.
        Returns an empty list when inline inference is disabled.
        """
        if self.inference is None:
            return []
        first_epoch = 1
        if self.evaluate_before_training:
            first_epoch = 0
        candidate_epochs = list(range(first_epoch, self.max_epochs + 1))
        return candidate_epochs[self.inference.epochs.slice]
class TrainBuilders:
    """
    Builds the objects used by the training loop from a ``TrainConfig``:
    training/validation/inference data, the stepper, optimization, the EMA
    tracker, and the end-of-batch and end-of-epoch callbacks.
    """

    def __init__(self, config: TrainConfig):
        self.config = config

    def _get_train_window_data_requirements(self) -> DataRequirements:
        """Window requirements shared by the training and validation data."""
        return self.config.stepper.get_train_window_data_requirements(
            default_n_forward_steps=self.config.n_forward_steps
        )

    def _get_evaluation_window_data_requirements(self) -> DataRequirements:
        """
        Window requirements for inline inference, sized by
        ``inference.forward_steps_in_memory``. Returns ``NullDataRequirements``
        when inline inference is disabled.
        """
        if self.config.inference is None:
            return NullDataRequirements
        return self.config.stepper.get_evaluation_window_data_requirements(
            self.config.inference.forward_steps_in_memory
        )

    def _get_weather_evaluation_window_data_requirements(self) -> DataRequirements:
        """
        Window requirements for weather evaluation, sized by
        ``weather_evaluation.forward_steps_in_memory``. Returns
        ``NullDataRequirements`` when weather evaluation is disabled.

        This is kept separate from the inline-inference requirements. Weather
        evaluation therefore works when inline inference is disabled, and it
        honours its own ``forward_steps_in_memory`` setting.
        """
        if self.config.weather_evaluation is None:
            return NullDataRequirements
        return self.config.stepper.get_evaluation_window_data_requirements(
            self.config.weather_evaluation.forward_steps_in_memory
        )

    def _get_initial_condition_data_requirements(
        self,
    ) -> PrognosticStateDataRequirements:
        """Requirements for the prognostic state used as initial condition."""
        return self.config.stepper.get_prognostic_state_data_requirements()

    def get_train_data(self) -> GriddedData:
        """Build the training dataset (``train=True``)."""
        data_requirements = self._get_train_window_data_requirements()
        return get_gridded_data(
            self.config.train_loader,
            requirements=data_requirements,
            train=True,
        )

    def get_validation_data(self) -> GriddedData:
        """Build the validation dataset using the training window requirements."""
        data_requirements = self._get_train_window_data_requirements()
        return get_gridded_data(
            self.config.validation_loader,
            requirements=data_requirements,
            train=False,
        )

    def get_evaluation_inference_data(
        self,
    ) -> InferenceGriddedData:
        """
        Build the inline inference dataset. When inline inference is disabled,
        returns an ``ErrorInferenceData`` placeholder instead.
        """
        if self.config.inference is None:
            return ErrorInferenceData()  # type: ignore
        return self.config.inference.get_inference_data(
            window_requirements=self._get_evaluation_window_data_requirements(),
            initial_condition=self._get_initial_condition_data_requirements(),
        )

    def get_optimization(self, modules: torch.nn.ModuleList) -> Optimization:
        """Build the optimization for ``modules`` over ``max_epochs`` epochs."""
        return self.config.optimization.build(modules, self.config.max_epochs)

    def get_stepper(
        self,
        dataset_info: DatasetInfo,
    ) -> Stepper:
        """Build the stepper from the stepper config."""
        return self.config.stepper.get_stepper(
            dataset_info=dataset_info,
        )

    def get_ema(self, modules) -> EMATracker:
        """Build the EMA tracker for ``modules``."""
        return self.config.ema.build(modules)

    def get_end_of_batch_ops(
        self,
        modules: list[torch.nn.Module],
        base_weights: list[Mapping[str, Any]] | None,
    ) -> EndOfBatchCallback:
        """
        Build the callback run after each training batch.

        When ``base_weights`` is given, the callback applies each entry of
        ``config.copy_weights_after_batch`` to the module at the same position
        in ``modules``. Otherwise the callback does nothing.
        """
        if base_weights is None:
            return lambda: None

        def copy_after_batch():
            # NOTE(review): zip() stops at the shorter of ``modules`` and
            # ``copy_weights_after_batch``. Each call also receives the full
            # ``base_weights`` list together with a single module. Confirm
            # that this pairing matches CopyWeightsConfig.apply.
            for module, copy_config in zip(
                modules, self.config.copy_weights_after_batch
            ):
                copy_config.apply(weights=base_weights, modules=[module])

        return copy_after_batch

    def get_end_of_epoch_callback(
        self,
        inference_one_epoch: Callable[
            [InferenceGriddedData, InferenceEvaluatorAggregator, str, int],
            Mapping[str, Any],
        ],
        normalize: Callable[[TensorMapping], TensorDict],
        output_dir: str,
        variable_metadata: Mapping[str, VariableMetadata],
        channel_mean_names: Sequence[str] | None,
        save_diagnostics: bool,
        n_ic_timesteps: int,
    ) -> EndOfEpochCallback:
        """
        Build the callback run at the end of each epoch.

        When weather evaluation is configured, its data and aggregator are
        built once, here. The returned callback then runs
        ``inference_one_epoch`` under the ``"weather_eval"`` label on epochs
        selected by ``weather_evaluation.epochs``. On other epochs, or when
        weather evaluation is disabled, the callback returns an empty mapping.
        """
        weather_config = self.config.weather_evaluation
        if weather_config is None:
            return lambda epoch: {}

        data = weather_config.get_inference_data(
            window_requirements=self._get_weather_evaluation_window_data_requirements(),
            initial_condition=self._get_initial_condition_data_requirements(),
        )
        dataset_info = data.dataset_info.update_variable_metadata(variable_metadata)
        aggregator = weather_config.aggregator.build(
            dataset_info=dataset_info,
            n_timesteps=weather_config.n_forward_steps + n_ic_timesteps,
            initial_time=data.initial_time,
            normalize=normalize,
            output_dir=output_dir,
            record_step_20=weather_config.n_forward_steps >= 20,
            channel_mean_names=channel_mean_names,
            save_diagnostics=save_diagnostics,
        )

        def end_of_epoch_ops(epoch: int) -> Mapping[str, Any]:
            if weather_config.epochs.contains(epoch):
                return inference_one_epoch(
                    data,
                    aggregator,
                    "weather_eval",
                    epoch,
                )
            return {}

        return end_of_epoch_ops