Source code for fme.ace.train.train_config

import dataclasses
import datetime
import os
from typing import List, Optional, Tuple, Union

import torch

from fme.ace.aggregator import InferenceEvaluatorAggregatorConfig
from fme.ace.data_loading.batch_data import GriddedData, InferenceGriddedData
from fme.ace.data_loading.config import DataLoaderConfig
from fme.ace.data_loading.getters import get_data_loader, get_inference_data
from fme.ace.data_loading.inference import InferenceDataLoaderConfig
from fme.ace.requirements import PrognosticStateDataRequirements
from fme.ace.stepper import (
    ExistingStepperConfig,
    SingleModuleStepper,
    SingleModuleStepperConfig,
)
from fme.core.coordinates import HybridSigmaPressureCoordinate
from fme.core.dataset.requirements import DataRequirements
from fme.core.distributed import Distributed
from fme.core.ema import EMAConfig, EMATracker
from fme.core.generics.trainer import EndOfBatchCallback
from fme.core.gridded_ops import GriddedOperations
from fme.core.logging_utils import LoggingConfig
from fme.core.optimization import Optimization, OptimizationConfig
from fme.core.typing_ import Slice
from fme.core.weight_ops import CopyWeightsConfig


@dataclasses.dataclass
class InlineInferenceConfig:
    """
    Configuration for inference runs performed inline during training.

    Parameters:
        loader: configuration for the data loader used during inference
        n_forward_steps: number of forward steps to take
        forward_steps_in_memory: number of forward steps to take before
            re-reading data from disk
        epochs: epochs on which to run inference, where the first epoch is
            defined as epoch 0 (unlike in logs which show epochs as starting
            from 1). By default runs inference every epoch.
        aggregator: configuration of inline inference aggregator.

    Raises:
        ValueError: from ``__post_init__`` if the number of inference initial
            conditions is not divisible by the distributed world size.
    """

    loader: InferenceDataLoaderConfig
    n_forward_steps: int = 2
    forward_steps_in_memory: int = 2
    # Build the default per instance instead of sharing one class-level Slice
    # object between all configs. NOTE(review): Slice is defined elsewhere; if
    # it is a regular (unhashable) dataclass, a bare instance default is also
    # rejected by dataclasses on Python 3.11+ -- confirm against its definition.
    epochs: Slice = dataclasses.field(
        default_factory=lambda: Slice(start=0, stop=None, step=1)
    )
    aggregator: InferenceEvaluatorAggregatorConfig = dataclasses.field(
        default_factory=lambda: InferenceEvaluatorAggregatorConfig(
            log_global_mean_time_series=False,
            log_global_mean_norm_time_series=False,
        )
    )

    def __post_init__(self):
        """Validate the initial-condition count and disable time-series logs."""
        dist = Distributed.get_instance()
        if self.loader.start_indices.n_initial_conditions % dist.world_size != 0:
            raise ValueError(
                "Number of inference initial conditions must be divisible by the "
                "number of parallel workers, got "
                f"{self.loader.start_indices.n_initial_conditions} and "
                f"{dist.world_size}."
            )
        if (
            self.aggregator.log_global_mean_time_series
            or self.aggregator.log_global_mean_norm_time_series
        ):
            # Both of log_global_mean_time_series and
            # log_global_mean_norm_time_series must be False for inline
            # inference, so any user-supplied True is overridden here.
            self.aggregator.log_global_mean_time_series = False
            self.aggregator.log_global_mean_norm_time_series = False
@dataclasses.dataclass
class TrainConfig:
    """
    Top-level configuration describing a model training run.

    Parameters:
        train_loader: Configuration for the training data loader.
        validation_loader: Configuration for the validation data loader.
        stepper: Configuration for the stepper.
        optimization: Configuration for the optimization.
        logging: Configuration for logging.
        max_epochs: Total number of epochs to train for.
        save_checkpoint: Whether to save checkpoints.
        experiment_dir: Directory where checkpoints and logs are saved.
        inference: Configuration for inline inference.
        n_forward_steps: Number of forward steps to take gradient over.
        copy_weights_after_batch: Configuration for copying weights from the
            base model to the training model after each batch.
        ema: Configuration for exponential moving average of model weights.
        validate_using_ema: Whether to validate and perform inference using
            the EMA model.
        checkpoint_save_epochs: How often to save epoch-based checkpoints,
            if save_checkpoint is True. If None, checkpoints are only saved
            for the most recent epoch
            (and the best epochs if validate_using_ema == False).
        ema_checkpoint_save_epochs: How often to save epoch-based EMA
            checkpoints, if save_checkpoint is True. If None, EMA checkpoints
            are only saved for the most recent epoch
            (and the best epochs if validate_using_ema == True).
        log_train_every_n_batches: How often to log batch_loss during
            training.
        segment_epochs: Exit after training for at most this many epochs
            in current job, without exceeding `max_epochs`. Use this if
            training must be run in segments, e.g. due to wall clock limit.
    """

    # Required settings.
    train_loader: DataLoaderConfig
    validation_loader: DataLoaderConfig
    stepper: Union[SingleModuleStepperConfig, ExistingStepperConfig]
    optimization: OptimizationConfig
    logging: LoggingConfig
    max_epochs: int
    save_checkpoint: bool
    experiment_dir: str
    inference: InlineInferenceConfig
    n_forward_steps: int

    # Optional settings.
    copy_weights_after_batch: CopyWeightsConfig = dataclasses.field(
        default_factory=lambda: CopyWeightsConfig(exclude=["*"])
    )
    ema: EMAConfig = dataclasses.field(default_factory=EMAConfig)
    validate_using_ema: bool = False
    checkpoint_save_epochs: Optional[Slice] = None
    ema_checkpoint_save_epochs: Optional[Slice] = None
    log_train_every_n_batches: int = 100
    segment_epochs: Optional[int] = None

    @property
    def inference_n_forward_steps(self) -> int:
        """Number of forward steps configured for inline inference."""
        return self.inference.n_forward_steps

    @property
    def inference_aggregator(self) -> InferenceEvaluatorAggregatorConfig:
        """Aggregator configuration of the inline inference config."""
        return self.inference.aggregator

    @property
    def checkpoint_dir(self) -> str:
        """
        The directory where checkpoints are saved.
        """
        return os.path.join(self.experiment_dir, "training_checkpoints")

    def clean_wandb(self, experiment_dir: str) -> None:
        """Forward the wandb clean-up request to the logging configuration."""
        self.logging.clean_wandb(experiment_dir=experiment_dir)

    def get_inference_epochs(self) -> List[int]:
        """Return the epochs in [0, max_epochs) selected by inference.epochs."""
        all_epochs = list(range(self.max_epochs))
        selection = self.inference.epochs.slice
        return all_epochs[selection]
class TrainBuilders:
    """Builds the objects used for training from a TrainConfig."""

    def __init__(self, config: TrainConfig):
        self.config = config

    def _get_train_window_data_requirements(self) -> DataRequirements:
        """Window requirements from the stepper for config.n_forward_steps."""
        n_steps = self.config.n_forward_steps
        return self.config.stepper.get_evaluation_window_data_requirements(n_steps)

    def _get_evaluation_window_data_requirements(self) -> DataRequirements:
        """Window requirements for inference.forward_steps_in_memory steps."""
        n_steps = self.config.inference.forward_steps_in_memory
        return self.config.stepper.get_evaluation_window_data_requirements(n_steps)

    def _get_initial_condition_data_requirements(
        self,
    ) -> PrognosticStateDataRequirements:
        """Prognostic-state requirements as reported by the stepper config."""
        stepper_config = self.config.stepper
        return stepper_config.get_prognostic_state_data_requirements()

    def get_train_data(self) -> GriddedData:
        """Build the training data loader (train=True)."""
        return get_data_loader(
            self.config.train_loader,
            requirements=self._get_train_window_data_requirements(),
            train=True,
        )

    def get_validation_data(self) -> GriddedData:
        """Build the validation loader (train=False), using the train window."""
        return get_data_loader(
            self.config.validation_loader,
            requirements=self._get_train_window_data_requirements(),
            train=False,
        )

    def get_evaluation_inference_data(
        self,
    ) -> InferenceGriddedData:
        """Build the data used for inline inference."""
        window = self._get_evaluation_window_data_requirements()
        initial_condition = self._get_initial_condition_data_requirements()
        return get_inference_data(
            config=self.config.inference.loader,
            total_forward_steps=self.config.inference_n_forward_steps,
            window_requirements=window,
            initial_condition=initial_condition,
        )

    def get_optimization(self, modules: torch.nn.ModuleList) -> Optimization:
        """Build the optimization for the modules, given config.max_epochs."""
        optimization_config = self.config.optimization
        return optimization_config.build(modules, self.config.max_epochs)

    def get_stepper(
        self,
        img_shape: Tuple[int, int],
        gridded_operations: GriddedOperations,
        vertical_coordinate: HybridSigmaPressureCoordinate,
        timestep: datetime.timedelta,
    ) -> SingleModuleStepper:
        """Build the stepper by passing every argument to the stepper config."""
        stepper_config = self.config.stepper
        return stepper_config.get_stepper(
            img_shape=img_shape,
            gridded_operations=gridded_operations,
            vertical_coordinate=vertical_coordinate,
            timestep=timestep,
        )

    def get_ema(self, modules) -> EMATracker:
        """Build the EMA tracker for the given modules from config.ema."""
        ema_config = self.config.ema
        return ema_config.build(modules)

    def get_end_of_batch_ops(
        self, modules: List[torch.nn.Module]
    ) -> EndOfBatchCallback:
        """
        Return the callback to invoke at the end of each batch.

        If the stepper config provides no base weights, the callback does
        nothing. Otherwise it applies config.copy_weights_after_batch with
        those weights to the given modules.
        """
        base_weights = self.config.stepper.get_base_weights()
        if base_weights is None:
            return lambda: None

        weight_copier = self.config.copy_weights_after_batch

        def _copy_base_weights():
            return weight_copier.apply(weights=base_weights, modules=modules)

        return _copy_base_weights