Source code for fme.ace.train.train_config

import dataclasses
import logging
import os
from typing import Any, Dict, Optional, Union

from fme.core.aggregator import InferenceEvaluatorAggregatorConfig
from fme.core.data_loading.config import DataLoaderConfig, Slice
from fme.core.data_loading.inference import InferenceDataLoaderConfig
from fme.core.dicts import to_flat_dict
from fme.core.distributed import Distributed
from fme.core.ema import EMAConfig
from fme.core.logging_utils import LoggingConfig
from fme.core.optimization import OptimizationConfig
from fme.core.stepper import ExistingStepperConfig, SingleModuleStepperConfig
from fme.core.weight_ops import CopyWeightsConfig


[docs]@dataclasses.dataclass class InlineInferenceConfig: """ Attributes: loader: configuration for the data loader used during inference n_forward_steps: number of forward steps to take forward_steps_in_memory: number of forward steps to take before re-reading data from disk epochs: epochs on which to run inference, where the first epoch is defined as epoch 0 (unlike in logs which show epochs as starting from 1). By default runs inference every epoch. aggregator: configuration of inline inference aggregator. """ loader: InferenceDataLoaderConfig n_forward_steps: int = 2 forward_steps_in_memory: int = 2 epochs: Slice = Slice(start=0, stop=None, step=1) aggregator: InferenceEvaluatorAggregatorConfig = dataclasses.field( default_factory=lambda: InferenceEvaluatorAggregatorConfig() ) def __post_init__(self): dist = Distributed.get_instance() if self.loader.start_indices.n_initial_conditions % dist.world_size != 0: raise ValueError( "Number of inference initial conditions must be divisible by the " "number of parallel workers, got " f"{self.loader.start_indices.n_initial_conditions} and " f"{dist.world_size}." )
[docs]@dataclasses.dataclass class TrainConfig: """ Configuration for training a model. Attributes: train_loader: Configuration for the training data loader. validation_loader: Configuration for the validation data loader. stepper: Configuration for the stepper. optimization: Configuration for the optimization. logging: Configuration for logging. max_epochs: Total number of epochs to train for. save_checkpoint: Whether to save checkpoints. experiment_dir: Directory where checkpoints and logs are saved. inference: Configuration for inline inference. n_forward_steps: Number of forward steps to take gradient over. copy_weights_after_batch: Configuration for copying weights from the base model to the training model after each batch. ema: Configuration for exponential moving average of model weights. validate_using_ema: Whether to validate and perform inference using the EMA model. checkpoint_save_epochs: How often to save epoch-based checkpoints, if save_checkpoint is True. If None, checkpoints are only saved for the most recent epoch (and the best epochs if validate_using_ema == False). ema_checkpoint_save_epochs: How often to save epoch-based EMA checkpoints, if save_checkpoint is True. If None, EMA checkpoints are only saved for the most recent epoch (and the best epochs if validate_using_ema == True). log_train_every_n_batches: How often to log batch_loss during training. segment_epochs: Exit after training for at most this many epochs in current job, without exceeding `max_epochs`. Use this if training must be run in segments, e.g. due to wall clock limit. """ train_loader: DataLoaderConfig validation_loader: DataLoaderConfig stepper: Union[SingleModuleStepperConfig, ExistingStepperConfig] optimization: OptimizationConfig logging: LoggingConfig max_epochs: int save_checkpoint: bool experiment_dir: str inference: InlineInferenceConfig n_forward_steps: int copy_weights_after_batch: CopyWeightsConfig = dataclasses.field( default_factory=lambda: CopyWeightsConfig(exclude=["*"]) ) ema: EMAConfig = dataclasses.field(default_factory=lambda: EMAConfig()) validate_using_ema: bool = False checkpoint_save_epochs: Optional[Slice] = None ema_checkpoint_save_epochs: Optional[Slice] = None log_train_every_n_batches: int = 100 segment_epochs: Optional[int] = None @property def checkpoint_dir(self) -> str: """ The directory where checkpoints are saved. """ return os.path.join(self.experiment_dir, "training_checkpoints") @property def latest_checkpoint_path(self) -> str: return os.path.join(self.checkpoint_dir, "ckpt.tar") @property def best_checkpoint_path(self) -> str: return os.path.join(self.checkpoint_dir, "best_ckpt.tar") @property def best_inference_checkpoint_path(self) -> str: return os.path.join(self.checkpoint_dir, "best_inference_ckpt.tar") @property def ema_checkpoint_path(self) -> str: return os.path.join(self.checkpoint_dir, "ema_ckpt.tar") def epoch_checkpoint_path(self, epoch: int) -> str: return os.path.join(self.checkpoint_dir, f"ckpt_{epoch:04d}.tar") def ema_epoch_checkpoint_path(self, epoch: int) -> str: return os.path.join(self.checkpoint_dir, f"ema_ckpt_{epoch:04d}.tar") def epoch_checkpoint_enabled(self, epoch: int) -> bool: return epoch_checkpoint_enabled( epoch, self.max_epochs, self.checkpoint_save_epochs ) def ema_epoch_checkpoint_enabled(self, epoch: int) -> bool: return epoch_checkpoint_enabled( epoch, self.max_epochs, self.ema_checkpoint_save_epochs ) @property def resuming(self) -> bool: checkpoint_file_exists = os.path.isfile(self.latest_checkpoint_path) resuming = True if checkpoint_file_exists else False return resuming def configure_logging(self, log_filename: str): self.logging.configure_logging(self.experiment_dir, log_filename) def configure_wandb(self, env_vars: Optional[Dict[str, Any]] = None, **kwargs): config = to_flat_dict(dataclasses.asdict(self)) self.logging.configure_wandb(config=config, env_vars=env_vars, **kwargs) def log(self): logging.info("------------------ Configuration ------------------") logging.info(str(self)) logging.info("---------------------------------------------------") def clean_wandb(self): self.logging.clean_wandb(experiment_dir=self.experiment_dir)
def epoch_checkpoint_enabled( epoch: int, max_epochs: int, save_epochs: Optional[Slice] ) -> bool: if save_epochs is None: return False return epoch in range(max_epochs)[save_epochs.slice]