# Source code for fme.ace.train.train_config

import dataclasses
import os
import warnings
from collections.abc import Callable, Mapping, Sequence
from typing import Any

import torch

from fme.ace.aggregator import (
    InferenceEvaluatorAggregatorConfig,
    OneStepAggregatorConfig,
)
from fme.ace.aggregator.inference.main import InferenceEvaluatorAggregator
from fme.ace.data_loading.config import DataLoaderConfig
from fme.ace.data_loading.getters import get_gridded_data, get_inference_data
from fme.ace.data_loading.gridded_data import (
    ErrorInferenceData,
    GriddedData,
    InferenceGriddedData,
)
from fme.ace.data_loading.inference import InferenceDataLoaderConfig
from fme.ace.requirements import (
    DataRequirements,
    NullDataRequirements,
    PrognosticStateDataRequirements,
)
from fme.ace.stepper import ExistingStepperConfig, SingleModuleStepperConfig, Stepper
from fme.ace.stepper.single_module import StepperConfig
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.dataset_info import DatasetInfo
from fme.core.distributed import Distributed
from fme.core.ema import EMAConfig, EMATracker
from fme.core.generics.trainer import EndOfBatchCallback, EndOfEpochCallback
from fme.core.logging_utils import LoggingConfig
from fme.core.optimization import Optimization, OptimizationConfig
from fme.core.typing_ import Slice, TensorDict, TensorMapping
from fme.core.weight_ops import CopyWeightsConfig


@dataclasses.dataclass
class WeatherEvaluationConfig:
    """
    Configuration of the weather evaluation run during training.

    Parameters:
        loader: Data loader configuration for the weather evaluation data.
        n_forward_steps: Number of forward steps to take.
        forward_steps_in_memory: Number of forward steps to take before
            re-reading data from disk.
        epochs: Epochs on which weather evaluation is run. The default
            runs weather evaluation every epoch.
        aggregator: Configuration of the weather evaluation aggregator. Its
            ``log_global_mean_time_series`` and
            ``log_global_mean_norm_time_series`` flags are reset to False
            in ``__post_init__`` if either one is set.
    """

    loader: InferenceDataLoaderConfig
    n_forward_steps: int = 2
    forward_steps_in_memory: int = 2
    epochs: Slice = dataclasses.field(default_factory=lambda: Slice())
    aggregator: InferenceEvaluatorAggregatorConfig = dataclasses.field(
        default_factory=lambda: InferenceEvaluatorAggregatorConfig(
            log_global_mean_time_series=False, log_global_mean_norm_time_series=False
        )
    )

    def __post_init__(self):
        """
        Validate the number of initial conditions and disable the global-mean
        time series logging of the aggregator.

        Raises:
            ValueError: If the number of initial conditions is not divisible
                by the distributed world size.
        """
        world_size = Distributed.get_instance().world_size
        n_initial_conditions = self.loader.start_indices.n_initial_conditions
        if n_initial_conditions % world_size != 0:
            raise ValueError(
                "Number of inference initial conditions must be divisible by the "
                "number of parallel workers, got "
                f"{n_initial_conditions} and "
                f"{world_size}."
            )
        aggregator_config = self.aggregator
        if (
            aggregator_config.log_global_mean_time_series
            or aggregator_config.log_global_mean_norm_time_series
        ):
            # Neither time series may be logged when running inline, so a
            # request for either one is overridden rather than rejected.
            aggregator_config.log_global_mean_time_series = False
            aggregator_config.log_global_mean_norm_time_series = False

    def get_inference_data(
        self,
        window_requirements: DataRequirements,
        initial_condition: PrognosticStateDataRequirements,
    ) -> InferenceGriddedData:
        """
        Load the data used for weather evaluation.

        Parameters:
            window_requirements: Data requirements of each window.
            initial_condition: Data requirements of the initial condition.

        Returns:
            The data returned by the module-level ``get_inference_data``
            for this loader and ``n_forward_steps`` total forward steps.
        """
        inference_data = get_inference_data(
            config=self.loader,
            total_forward_steps=self.n_forward_steps,
            window_requirements=window_requirements,
            initial_condition=initial_condition,
        )
        return inference_data


@dataclasses.dataclass
class InlineInferenceConfig:
    """
    Configuration of the inline inference run during training.

    Parameters:
        loader: Data loader configuration for the inference data.
        n_forward_steps: Number of forward steps to take.
        forward_steps_in_memory: Number of forward steps to take before
            re-reading data from disk.
        epochs: Epochs on which inference is run. The default runs
            inference every epoch.
        aggregator: Configuration of the inline inference aggregator. Its
            ``log_global_mean_time_series`` and
            ``log_global_mean_norm_time_series`` flags are reset to False
            in ``__post_init__`` if either one is set.
    """

    loader: InferenceDataLoaderConfig
    n_forward_steps: int = 2
    forward_steps_in_memory: int = 2
    epochs: Slice = dataclasses.field(default_factory=lambda: Slice())
    aggregator: InferenceEvaluatorAggregatorConfig = dataclasses.field(
        default_factory=lambda: InferenceEvaluatorAggregatorConfig(
            log_global_mean_time_series=False, log_global_mean_norm_time_series=False
        )
    )

    def __post_init__(self):
        """
        Validate the number of initial conditions and disable the global-mean
        time series logging of the aggregator.

        Raises:
            ValueError: If the number of initial conditions is not divisible
                by the distributed world size.
        """
        world_size = Distributed.get_instance().world_size
        n_initial_conditions = self.loader.start_indices.n_initial_conditions
        if n_initial_conditions % world_size != 0:
            raise ValueError(
                "Number of inference initial conditions must be divisible by the "
                "number of parallel workers, got "
                f"{n_initial_conditions} and "
                f"{world_size}."
            )
        aggregator_config = self.aggregator
        if (
            aggregator_config.log_global_mean_time_series
            or aggregator_config.log_global_mean_norm_time_series
        ):
            # Both flags must be False for inline inference, so a request
            # for either time series is overridden rather than rejected.
            aggregator_config.log_global_mean_time_series = False
            aggregator_config.log_global_mean_norm_time_series = False

    def get_inference_data(
        self,
        window_requirements: DataRequirements,
        initial_condition: PrognosticStateDataRequirements,
    ) -> InferenceGriddedData:
        """
        Load the data used for inline inference.

        Parameters:
            window_requirements: Data requirements of each window.
            initial_condition: Data requirements of the initial condition.

        Returns:
            The data returned by the module-level ``get_inference_data``
            for this loader and ``n_forward_steps`` total forward steps.
        """
        inference_data = get_inference_data(
            config=self.loader,
            total_forward_steps=self.n_forward_steps,
            window_requirements=window_requirements,
            initial_condition=initial_condition,
        )
        return inference_data


@dataclasses.dataclass
class TrainConfig:
    """
    Configuration for training a model.

    Parameters:
        train_loader: Configuration for the training data loader.
        validation_loader: Configuration for the validation data loader.
        stepper: Configuration for the stepper. SingleModuleStepperConfig is
            deprecated and will be removed in a future version. Use
            StepperConfig instead.
        optimization: Configuration for the optimization.
        logging: Configuration for logging.
        max_epochs: Total number of epochs to train for.
        save_checkpoint: Whether to save checkpoints.
        experiment_dir: Directory where checkpoints and logs are saved.
        inference: Configuration for inline inference. If None, no inline
            inference is run, and no "best_inline_inference" checkpoint will
            be saved.
        weather_evaluation: Configuration for weather evaluation. If None,
            no weather evaluation is run. Weather evaluation is not used to
            select checkpoints, but is used to provide metrics.
        n_forward_steps: Number of forward steps to take gradient over.
        copy_weights_after_batch: Configuration for copying weights from the
            base model to the training model after each batch.
        ema: Configuration for exponential moving average of model weights.
        validate_using_ema: Whether to validate and perform inference using
            the EMA model.
        checkpoint_save_epochs: How often to save epoch-based checkpoints,
            if save_checkpoint is True. If None, checkpoints are only saved
            for the most recent epoch (and the best epochs if
            validate_using_ema == False).
        ema_checkpoint_save_epochs: How often to save epoch-based EMA
            checkpoints, if save_checkpoint is True. If None, EMA checkpoints
            are only saved for the most recent epoch (and the best epochs if
            validate_using_ema == True).
        log_train_every_n_batches: How often to log batch_loss during
            training.
        segment_epochs: Exit after training for at most this many epochs in
            current job, without exceeding `max_epochs`. Use this if training
            must be run in segments, e.g. due to wall clock limit.
        save_per_epoch_diagnostics: Whether to save per-epoch diagnostics
            from training, validation and inline inference aggregators.
        validation_aggregator: Configuration for the validation aggregator.
        evaluate_before_training: Whether to run validation and inline
            inference before any training is done.
    """

    # Field order defines the generated ``__init__`` signature; keep it stable.
    train_loader: DataLoaderConfig
    validation_loader: DataLoaderConfig
    stepper: SingleModuleStepperConfig | ExistingStepperConfig | StepperConfig
    optimization: OptimizationConfig
    logging: LoggingConfig
    max_epochs: int
    save_checkpoint: bool
    experiment_dir: str
    inference: InlineInferenceConfig | None
    n_forward_steps: int
    copy_weights_after_batch: list[CopyWeightsConfig] = dataclasses.field(
        default_factory=list
    )
    ema: EMAConfig = dataclasses.field(default_factory=lambda: EMAConfig())
    weather_evaluation: WeatherEvaluationConfig | None = None
    validate_using_ema: bool = False
    checkpoint_save_epochs: Slice | None = None
    ema_checkpoint_save_epochs: Slice | None = None
    log_train_every_n_batches: int = 100
    segment_epochs: int | None = None
    save_per_epoch_diagnostics: bool = False
    validation_aggregator: OneStepAggregatorConfig = dataclasses.field(
        default_factory=lambda: OneStepAggregatorConfig()
    )
    evaluate_before_training: bool = False

    def __post_init__(self):
        """Emit a DeprecationWarning when the deprecated stepper config is used."""
        uses_deprecated_stepper = isinstance(self.stepper, SingleModuleStepperConfig)
        if uses_deprecated_stepper:
            message = (
                "SingleModuleStepperConfig is deprecated. Use StepperConfig instead."
            )
            warnings.warn(message, DeprecationWarning)

    @property
    def inference_n_forward_steps(self) -> int:
        """
        The number of forward steps of inline inference, or 0 when no inline
        inference is configured.
        """
        return 0 if self.inference is None else self.inference.n_forward_steps

    @property
    def inference_aggregator(self) -> InferenceEvaluatorAggregatorConfig | None:
        """
        The aggregator configuration of inline inference, or None when no
        inline inference is configured.
        """
        return None if self.inference is None else self.inference.aggregator

    @property
    def checkpoint_dir(self) -> str:
        """
        The directory where checkpoints are saved.
        """
        return os.path.join(self.experiment_dir, "training_checkpoints")

    @property
    def output_dir(self) -> str:
        """
        The directory where output files are saved.
        """
        return os.path.join(self.experiment_dir, "output")

    def get_inference_epochs(self) -> list[int]:
        """
        List the epochs on which inline inference is run.

        Returns:
            An empty list when no inline inference is configured. Otherwise
            the candidate epochs, which run through ``max_epochs`` and start
            at 0 if ``evaluate_before_training`` is set or at 1 if not,
            indexed by the slice of ``inference.epochs``.
        """
        if self.inference is None:
            return []
        first_epoch = 0 if self.evaluate_before_training else 1
        candidate_epochs = list(range(first_epoch, self.max_epochs + 1))
        selection = self.inference.epochs.slice
        return candidate_epochs[selection]


class TrainBuilders:
    """
    Builds the objects needed for training from a TrainConfig.

    Parameters:
        config: The training configuration to build from.
    """

    def __init__(self, config: TrainConfig):
        self.config = config

    def _get_train_window_data_requirements(self) -> DataRequirements:
        """
        Data requirements of a window spanning ``config.n_forward_steps``
        forward steps, as reported by the stepper configuration.
        """
        return self.config.stepper.get_evaluation_window_data_requirements(
            self.config.n_forward_steps
        )

    def _get_evaluation_window_data_requirements(self) -> DataRequirements:
        """
        Data requirements of an inline inference window, or
        NullDataRequirements when no inline inference is configured.
        """
        if self.config.inference is None:
            return NullDataRequirements
        return self.config.stepper.get_evaluation_window_data_requirements(
            self.config.inference.forward_steps_in_memory
        )

    def _get_weather_evaluation_window_data_requirements(
        self, weather_evaluation: WeatherEvaluationConfig
    ) -> DataRequirements:
        """
        Data requirements of a weather evaluation window, sized by the
        weather evaluation's own ``forward_steps_in_memory``.
        """
        return self.config.stepper.get_evaluation_window_data_requirements(
            weather_evaluation.forward_steps_in_memory
        )

    def _get_initial_condition_data_requirements(
        self,
    ) -> PrognosticStateDataRequirements:
        """Prognostic state data requirements reported by the stepper config."""
        return self.config.stepper.get_prognostic_state_data_requirements()

    def get_train_data(self) -> GriddedData:
        """Load the training data from ``config.train_loader``."""
        data_requirements = self._get_train_window_data_requirements()
        return get_gridded_data(
            self.config.train_loader,
            requirements=data_requirements,
            train=True,
        )

    def get_validation_data(self) -> GriddedData:
        """
        Load the validation data from ``config.validation_loader``, using the
        same window requirements as the training data.
        """
        data_requirements = self._get_train_window_data_requirements()
        return get_gridded_data(
            self.config.validation_loader,
            requirements=data_requirements,
            train=False,
        )

    def get_evaluation_inference_data(
        self,
    ) -> InferenceGriddedData:
        """
        Load the inline inference data.

        Returns:
            The data of ``config.inference``, or an ErrorInferenceData
            instance when no inline inference is configured.
        """
        if self.config.inference is None:
            return ErrorInferenceData()  # type: ignore
        return self.config.inference.get_inference_data(
            window_requirements=self._get_evaluation_window_data_requirements(),
            initial_condition=self._get_initial_condition_data_requirements(),
        )

    def get_optimization(self, modules: torch.nn.ModuleList) -> Optimization:
        """Build the optimization for ``modules`` and ``config.max_epochs``."""
        return self.config.optimization.build(modules, self.config.max_epochs)

    def get_stepper(
        self,
        dataset_info: DatasetInfo,
    ) -> Stepper:
        """Build the stepper from ``config.stepper`` for ``dataset_info``."""
        return self.config.stepper.get_stepper(
            dataset_info=dataset_info,
        )

    def get_ema(self, modules) -> EMATracker:
        """Build the EMA tracker of ``modules`` from ``config.ema``."""
        return self.config.ema.build(modules)

    def get_end_of_batch_ops(
        self,
        modules: list[torch.nn.Module],
        base_weights: list[Mapping[str, Any]] | None,
    ) -> EndOfBatchCallback:
        """
        Build the callback to run after each training batch.

        Parameters:
            modules: Modules being trained.
            base_weights: Base weights to copy from. If None, the returned
                callback does nothing.

        Returns:
            A callback taking no arguments which applies each entry of
            ``config.copy_weights_after_batch`` to its paired module.
        """
        if base_weights is None:
            return lambda: None

        def copy_after_batch():
            # zip() stops at the shorter sequence, so modules without a
            # matching copy configuration are left untouched.
            # NOTE(review): every call receives the full base_weights list
            # together with a single module; confirm CopyWeightsConfig.apply
            # pairs the intended weights with modules after the first.
            for module, copy_config in zip(
                modules, self.config.copy_weights_after_batch
            ):
                copy_config.apply(weights=base_weights, modules=[module])

        return copy_after_batch

    def get_end_of_epoch_callback(
        self,
        inference_one_epoch: Callable[
            [InferenceGriddedData, InferenceEvaluatorAggregator, str, int],
            Mapping[str, Any],
        ],
        normalize: Callable[[TensorMapping], TensorDict],
        output_dir: str,
        variable_metadata: Mapping[str, VariableMetadata],
        channel_mean_names: Sequence[str],
        save_diagnostics: bool,
        n_ic_timesteps: int,
    ) -> EndOfEpochCallback:
        """
        Build the callback which runs weather evaluation at the end of an
        epoch.

        Parameters:
            inference_one_epoch: Called as
                ``inference_one_epoch(data, aggregator, "weather_eval", epoch)``
                on the epochs selected by the weather evaluation config; its
                result is returned by the callback.
            normalize: Passed to the aggregator build.
            output_dir: Passed to the aggregator build.
            variable_metadata: Used to update the variable metadata of the
                dataset info of the weather evaluation data.
            channel_mean_names: Passed to the aggregator build.
            save_diagnostics: Passed to the aggregator build.
            n_ic_timesteps: Added to the number of forward steps to give the
                number of timesteps of the aggregator.

        Returns:
            A callback mapping an epoch to the result of
            ``inference_one_epoch``, or to an empty dict on epochs which are
            not selected. When no weather evaluation is configured the
            callback always returns an empty dict.
        """
        weather_evaluation = self.config.weather_evaluation
        if weather_evaluation is None:
            return lambda epoch: {}

        # Size the data windows from the weather evaluation config itself.
        # The inline inference requirements are the null requirements when
        # config.inference is None, and would otherwise impose the inline
        # inference window length on the weather evaluation data.
        window_requirements = self._get_weather_evaluation_window_data_requirements(
            weather_evaluation
        )
        data = weather_evaluation.get_inference_data(
            window_requirements=window_requirements,
            initial_condition=self._get_initial_condition_data_requirements(),
        )
        dataset_info = data.dataset_info.update_variable_metadata(variable_metadata)
        n_forward_steps = weather_evaluation.n_forward_steps
        aggregator = weather_evaluation.aggregator.build(
            dataset_info=dataset_info,
            n_timesteps=n_forward_steps + n_ic_timesteps,
            initial_time=data.initial_time,
            normalize=normalize,
            output_dir=output_dir,
            record_step_20=n_forward_steps >= 20,
            channel_mean_names=channel_mean_names,
            save_diagnostics=save_diagnostics,
        )

        def end_of_epoch_ops(epoch: int) -> Mapping[str, Any]:
            if weather_evaluation.epochs.contains(epoch):
                return inference_one_epoch(
                    data,
                    aggregator,
                    "weather_eval",
                    epoch,
                )
            return {}

        return end_of_epoch_ops