import contextlib
import dataclasses
import itertools
import warnings
from collections.abc import Callable, Iterable, Mapping
from typing import Any, Literal, TypeAlias
import numpy as np
import torch
from torch import nn
from fme.core.device import get_device
from fme.core.generics.optimization import OptimizationABC
from fme.core.scheduler import SchedulerConfig, SequentialSchedulerConfig
from fme.core.typing_ import TensorDict, TensorMapping
class Checkpoint:
    """
    Wraps a module so that calling it goes through
    ``torch.utils.checkpoint.checkpoint``.
    """

    def __init__(self, kwargs: Mapping[str, Any]):
        """
        Args:
            kwargs: Extra keyword arguments forwarded to
                ``torch.utils.checkpoint.checkpoint`` on every call.
        """
        self._kwargs = kwargs

    def __call__(self, module: nn.Module):
        """
        Return a callable that runs ``module`` under activation checkpointing.

        Args:
            module: The module whose forward call should be checkpointed.

        Returns:
            A function taking positional arguments only and forwarding them
            to ``module`` through ``torch.utils.checkpoint.checkpoint``.
        """

        def checkpointed(*inputs):
            # use_reentrant=False is always passed explicitly; the stored
            # kwargs are read at call time and merged in after it.
            run_checkpoint = torch.utils.checkpoint.checkpoint
            return run_checkpoint(
                module,
                *inputs,
                use_reentrant=False,
                **self._kwargs,
            )

        return checkpointed
class NoCheckpoint:
    """Pass-through counterpart of ``Checkpoint``: leaves the module untouched."""

    def __call__(self, module: nn.Module):
        """
        Return ``module`` itself, without any activation checkpointing.

        Args:
            module: The module to (not) wrap.

        Returns:
            The very same ``module`` object.
        """
        unwrapped = module
        return unwrapped
@dataclasses.dataclass
class CheckpointConfig:
    """
    Configuration for activation checkpointing.

    Activation checkpointing lowers memory consumption during training at the
    price of extra computation, because activations are recomputed in the
    backward pass instead of being stored.

    Parameters:
        after_n_forward_steps: Number of forward steps to generate before
            activation checkpointing is applied. Checkpointing is only used
            when this number is less than the number of forward steps in the
            optimization. The default (infinity) never enables it.
        kwargs: Keyword arguments to pass to torch.utils.checkpoint.checkpoint.
            Note that use_reentrant=False is always explicitly passed
            as is recommended by the docs.
    """

    after_n_forward_steps: float = np.inf
    kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)

    def build(self, step: int) -> Checkpoint | NoCheckpoint:
        """
        Builds a checkpoint function for the given step.

        Args:
            step: The current zero-indexed step number.

        Returns:
            A ``Checkpoint`` once ``step`` has reached
            ``after_n_forward_steps``, otherwise a ``NoCheckpoint``.
        """
        # Keep the comparison as ">=" (not a negated "<") so that a NaN
        # threshold still resolves to "no checkpointing".
        threshold_reached = step >= self.after_n_forward_steps
        return Checkpoint(self.kwargs) if threshold_reached else NoCheckpoint()
class Optimization(OptimizationABC):
    """
    Bundles a torch optimizer, its learning-rate scheduler and an optional
    gradient scaler behind the ``OptimizationABC`` interface.

    Losses are collected with ``accumulate_loss`` and applied with
    ``step_weights``. With gradient accumulation enabled, a backward pass is
    run for every accumulated loss; otherwise a single backward pass over the
    summed loss is run inside ``step_weights``.
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        optimizer_type: Literal[
            "Adam",
            "FusedAdam",
            "AdamW",
        ],
        lr: float,
        max_epochs: int,
        scheduler: SchedulerConfig | SequentialSchedulerConfig,
        enable_automatic_mixed_precision: bool,
        kwargs: Mapping[str, Any],
        use_gradient_accumulation: bool = False,
        get_checkpoint: Callable[
            [int], Checkpoint | NoCheckpoint
        ] = lambda _: NoCheckpoint(),
    ):
        """
        Args:
            parameters: The parameters to optimize.
            optimizer_type: Which optimizer to build. ``"FusedAdam"`` builds
                ``torch.optim.AdamW`` with ``fused=True``.
            lr: The learning rate passed to the optimizer.
            max_epochs: Passed, together with the optimizer, to
                ``scheduler.build``.
            scheduler: Configuration used to build the scheduler.
            enable_automatic_mixed_precision: If True, a CUDA ``GradScaler``
                is created and ``autocast`` enables bfloat16 autocasting.
            kwargs: Additional keyword arguments for the optimizer
                constructor.
            use_gradient_accumulation: If True, ``accumulate_loss`` runs the
                backward pass immediately for each loss.
            get_checkpoint: Maps a step index to the wrapper applied to a
                module by ``checkpoint``.

        Raises:
            ValueError: If ``optimizer_type`` is not one of the known types.
        """
        if optimizer_type == "FusedAdam":
            self.optimizer = torch.optim.AdamW(parameters, lr=lr, fused=True, **kwargs)
        elif optimizer_type == "Adam":
            self.optimizer = torch.optim.Adam(parameters, lr=lr, **kwargs)
        elif optimizer_type == "AdamW":
            self.optimizer = torch.optim.AdamW(parameters, lr=lr, **kwargs)
        else:
            raise ValueError(f"Unknown optimizer type: {optimizer_type}")
        if enable_automatic_mixed_precision:
            self.gscaler: torch.amp.GradScaler | None = torch.amp.GradScaler("cuda")
        else:
            self.gscaler = None
        self.scheduler = scheduler.build(self.optimizer, max_epochs)
        self._accumulated_loss = torch.tensor(0.0, device=get_device())
        self._use_gradient_accumulation = use_gradient_accumulation
        self._get_checkpoint = get_checkpoint

    def checkpoint(self, module: nn.Module, step: int) -> nn.Module:
        """
        Apply the checkpoint wrapper selected for ``step`` to ``module``.

        Args:
            module: The module to wrap.
            step: Step index handed to the ``get_checkpoint`` callable.

        Returns:
            Whatever the selected wrapper returns for ``module``.
        """
        return self._get_checkpoint(step)(module)

    @contextlib.contextmanager
    def autocast(self):
        """
        Context manager for the forward pass.

        CUDA autocast with bfloat16 is enabled exactly when a gradient scaler
        exists (i.e. automatic mixed precision was requested); otherwise
        autocast is entered in disabled mode.
        """
        enabled = self.gscaler is not None
        dtype = torch.bfloat16 if enabled else None
        with torch.amp.autocast("cuda", enabled=enabled, dtype=dtype):
            yield

    @property
    def learning_rate(self) -> float:
        """The learning rate of the optimizer's first parameter group."""
        return self.optimizer.param_groups[0]["lr"]

    def set_mode(self, modules: nn.ModuleList):
        """
        Sets the mode of the module to train.
        """
        for m in modules:
            m.train()

    def step_scheduler(
        self,
        valid_loss: float | None = None,
        is_iteration: bool = False,
    ):
        """
        Step the scheduler.

        Args:
            valid_loss: The validation loss. Used in schedulers which change the
                learning rate based on whether the validation loss is decreasing.
                If None, this indicates the call is from within a training iteration
                rather than at the end of an epoch.
            is_iteration: Whether the step is called from a training iteration or at
                the end of an epoch. Default is epoch.
        """
        if self.scheduler.should_step(is_iteration):
            try:
                if valid_loss is not None:
                    self.scheduler.step(metrics=valid_loss)
                else:
                    self.scheduler.step()
            except TypeError:
                # Some schedulers don't accept metrics argument
                self.scheduler.step()

    def detach_if_using_gradient_accumulation(self, state: TensorMapping) -> TensorDict:
        """
        Return a new dict with the entries of ``state``; the tensors are
        detached only when gradient accumulation is enabled.
        """
        if self._use_gradient_accumulation:
            return {k: v.detach() for k, v in state.items()}
        return dict(state)

    def accumulate_loss(self, loss: torch.Tensor):
        """
        Add ``loss`` to the running total.

        With gradient accumulation enabled, the backward pass for ``loss`` is
        run right away.

        Raises:
            ValueError: If ``loss`` is NaN.
        """
        self._validate_loss(loss)
        self._accumulated_loss += loss
        if self._use_gradient_accumulation:
            self._backward(loss)

    def get_accumulated_loss(self) -> torch.Tensor:
        """Return the loss accumulated since the last ``step_weights`` call."""
        return self._accumulated_loss

    def _backward(self, loss: torch.Tensor):
        """Run the backward pass, scaling the loss first if a scaler exists."""
        if self.gscaler is not None:
            self.gscaler.scale(loss).backward()
        else:
            loss.backward()

    def _step_weights(self):
        """Step the optimizer, going through the grad scaler if one exists."""
        if self.gscaler is not None:
            self.gscaler.step(self.optimizer)
        else:
            self.optimizer.step()

    def step_weights(self):
        """
        Apply the accumulated loss to the weights and reset the accumulator.

        Without gradient accumulation the backward pass over the summed loss
        happens here; with it, gradients were already produced in
        ``accumulate_loss``.
        """
        if not self._use_gradient_accumulation:
            self._backward(self._accumulated_loss)
        self._step_weights()
        self.optimizer.zero_grad()
        if self.gscaler is not None:
            self.gscaler.update()
        self._accumulated_loss = torch.tensor(0.0, device=get_device())

    def set_learning_rate(self, lr: float):
        """Set ``lr`` on every parameter group of the optimizer."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

    def load_optimizer_state_for_finetuning(self, state: dict):
        """Load per-parameter optimizer running state and grad scaler state from a
        checkpoint for fine-tuning.

        Restores per-parameter optimizer state (e.g. Adam moment estimates) and,
        if available, the grad scaler state. The freshly-built optimizer's
        per-group hyperparameters (``lr``, ``weight_decay``, ``betas``, ``eps``,
        ...) from the current finetune config are authoritative; any per-group
        hyperparameters from the checkpoint (including optimizer-type-specific
        flags like ``fused``/``amsgrad`` and scheduler-injected keys like
        ``initial_lr``) are discarded. Scheduler state is not restored, so the
        configured schedule starts from scratch.

        Args:
            state: The optimization state dict as saved by ``get_state()``,
                containing at least ``"optimizer_state_dict"``.

        Raises:
            ValueError: If the checkpoint's parameter groups are not
                structurally compatible with the freshly-built optimizer
                (e.g. different group count or per-group parameter count).
        """
        # Snapshot the current hyperparameters before load_state_dict
        # replaces the param groups with the checkpoint's.
        fresh_hparams = [
            {k: v for k, v in g.items() if k != "params"}
            for g in self.optimizer.param_groups
        ]
        try:
            self.optimizer.load_state_dict(state["optimizer_state_dict"])
        except ValueError as e:
            raise ValueError(
                "Failed to load optimizer state for fine-tuning: parameter "
                "groups in the checkpoint are incompatible with the "
                "freshly-built optimizer (e.g. group count or per-group "
                "parameter count mismatch). This typically indicates the "
                "model architecture or trainable-parameter set changed "
                "between the source checkpoint and the current run. "
                f"Underlying error: {e}"
            ) from e
        # Drop every checkpoint-provided hyperparameter, then restore the
        # fresh ones so the current config stays authoritative.
        for group, hparams in zip(self.optimizer.param_groups, fresh_hparams):
            for k in list(group.keys()):
                if k != "params":
                    del group[k]
            group.update(hparams)
        if self.gscaler is not None and state.get("gscaler_state_dict") is not None:
            self.gscaler.load_state_dict(state["gscaler_state_dict"])

    def get_state(self):
        """
        Returns state as a serializable data structure.

        The ``"gscaler_state_dict"`` entry is None when no grad scaler is in
        use.
        """
        state = {
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict(),
            "gscaler_state_dict": (
                self.gscaler.state_dict() if self.gscaler is not None else None
            ),
        }
        return state

    def load_state(self, state):
        """
        Loads state from a serializable data structure.

        Optimizer and scheduler state are always restored. The grad scaler
        state is restored only when this instance has a grad scaler and the
        saved state carries one.
        """
        self.optimizer.load_state_dict(state["optimizer_state_dict"])
        self.scheduler.load_state_dict(state["scheduler_state_dict"])
        # get_state() stores None here when mixed precision was disabled at
        # save time; handing that to GradScaler.load_state_dict would raise,
        # so skip it and keep the freshly initialised scaler instead.
        gscaler_state = state.get("gscaler_state_dict")
        if self.gscaler is not None and gscaler_state is not None:
            self.gscaler.load_state_dict(gscaler_state)

    def _validate_loss(self, loss: torch.Tensor):
        """
        Raise if the loss is NaN.

        Raises:
            ValueError: If ``loss`` is NaN-valued.
        """
        with torch.no_grad():
            if torch.isnan(loss):
                raise ValueError("Loss is NaN-valued during training.")
@dataclasses.dataclass
class OptimizationConfig:
    """
    Configuration for optimization.

    Parameters:
        optimizer_type: The type of optimizer to use.
        lr: The learning rate.
        kwargs: Additional keyword arguments to pass to the optimizer.
        enable_automatic_mixed_precision: Whether to use automatic mixed
            precision.
        scheduler: The type of scheduler to use. If none is given, no scheduler
            will be used.
        use_gradient_accumulation: Whether to use gradient accumulation. This must be
            supported by the stepper being optimized, which may accumulate gradients
            from separate losses to reduce memory consumption. The stepper may choose
            to accumulate gradients differently when this is enabled, such as by
            detaching the computational graph between steps. See the documentation of
            your stepper (e.g. Stepper) for more details.
        checkpoint: Activation checkpointing configuration; its ``build``
            method is handed to the ``Optimization`` as the per-step
            checkpoint factory.
        resume_optimizer_ckpt_path: Optional path to a training checkpoint
            (``ckpt.tar``) whose per-parameter optimizer running state (e.g.
            Adam moment estimates) and grad scaler state should be loaded into
            the freshly-built ``Optimization`` for fine-tuning. The current
            config's per-group hyperparameters (``lr``, ``weight_decay``,
            ``betas``, ...) and scheduler are kept; only the running state is
            transferred. Intended for non-resuming jobs; preemption resume in
            the Trainer overrides this state via ``Optimization.load_state``.
    """

    optimizer_type: Literal["Adam", "AdamW", "FusedAdam"] = "Adam"
    lr: float = 0.001
    kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)
    enable_automatic_mixed_precision: bool = False
    scheduler: SchedulerConfig | SequentialSchedulerConfig = dataclasses.field(
        default_factory=lambda: SchedulerConfig()
    )
    use_gradient_accumulation: bool = False
    checkpoint: CheckpointConfig = dataclasses.field(
        default_factory=lambda: CheckpointConfig()
    )
    resume_optimizer_ckpt_path: str | None = None

    def __post_init__(self):
        """Emit a deprecation warning when the legacy "FusedAdam" type is chosen."""
        if self.optimizer_type == "FusedAdam":
            warnings.warn(
                "FusedAdam is deprecated. Use AdamW with fused=True in kwargs instead.",
                DeprecationWarning,
            )

    @property
    def has_lr_schedule(self) -> bool:
        """Whether a learning rate scheduler is configured."""
        # A sequential config always counts as a schedule; a plain config
        # only does when its type is set.
        return (
            isinstance(self.scheduler, SequentialSchedulerConfig)
            or self.scheduler.type is not None
        )

    def build(self, modules: torch.nn.ModuleList, max_epochs: int) -> Optimization:
        """
        Build an ``Optimization`` over the parameters of all ``modules``.

        Args:
            modules: Modules whose parameters are optimized together.
            max_epochs: Forwarded to the ``Optimization`` constructor.

        Returns:
            The built ``Optimization``; if ``resume_optimizer_ckpt_path`` is
            set, fine-tuning optimizer state from that checkpoint has been
            loaded into it.
        """
        all_parameters = itertools.chain.from_iterable(
            module.parameters() for module in modules
        )
        built = Optimization(
            parameters=all_parameters,
            optimizer_type=self.optimizer_type,
            lr=self.lr,
            max_epochs=max_epochs,
            scheduler=self.scheduler,
            enable_automatic_mixed_precision=self.enable_automatic_mixed_precision,
            kwargs=self.kwargs,
            use_gradient_accumulation=self.use_gradient_accumulation,
            get_checkpoint=self.checkpoint.build,
        )
        ckpt_path = self.resume_optimizer_ckpt_path
        if ckpt_path is not None:
            _load_finetune_optimization_state(built, ckpt_path)
        return built

    def get_state(self) -> Mapping[str, Any]:
        """Return this configuration as a dict via ``dataclasses.asdict``."""
        return dataclasses.asdict(self)

    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> "OptimizationConfig":
        """Construct a configuration by passing ``state`` as keyword arguments."""
        # NOTE(review): dataclasses.asdict converts nested dataclasses (e.g.
        # the ``checkpoint`` field) into plain dicts, and they are passed
        # through here unchanged -- confirm callers rebuild nested configs.
        return cls(**state)
NestedTensor: TypeAlias = (
"torch.Tensor | dict[str, NestedTensor] | list[NestedTensor] | tuple[NestedTensor]"
)
def _tensors_to_device(obj: NestedTensor, device: torch.device):
"""Recursively move all tensors in a nested dict/list to *device*."""
if isinstance(obj, torch.Tensor):
return obj.to(device)
elif isinstance(obj, dict):
return {k: _tensors_to_device(v, device) for k, v in obj.items()}
elif isinstance(obj, list | tuple):
return type(obj)(_tensors_to_device(v, device) for v in obj)
return obj
def _load_finetune_optimization_state(optimization: Optimization, checkpoint_path: str):
    """Load optimizer (and optionally grad scaler) state for fine-tuning.

    Reads the ``"optimization"`` entry of the checkpoint and hands it to
    ``optimization.load_optimizer_state_for_finetuning``.

    The checkpoint is loaded on CPU and only its optimization entry is then
    moved to the training device, so everything else stored in the file
    never leaves the CPU.

    Args:
        optimization: The freshly-built optimization to load state into.
        checkpoint_path: Path of the checkpoint file to read.

    Raises:
        ValueError: If the checkpoint has no ``"optimization"`` entry.
    """
    # NOTE: weights_only=False unpickles arbitrary Python objects, so only
    # checkpoints from trusted sources should be passed here.
    loaded = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    has_optimization = "optimization" in loaded
    if not has_optimization:
        raise ValueError(
            f"Checkpoint at {checkpoint_path} does not contain optimization "
            "state. Only checkpoints saved with include_optimization=True "
            "(i.e. ckpt.tar) support fine-tune optimization loading."
        )
    cpu_state = loaded["optimization"]
    # Drop the reference to the rest of the checkpoint before copying the
    # optimization state over to the training device.
    del loaded
    device_state = _tensors_to_device(cpu_state, get_device())
    optimization.load_optimizer_state_for_finetuning(device_state)
class NullOptimization(OptimizationABC):
    """
    No-op implementation of ``OptimizationABC``.

    It keeps a running sum of the losses it is given but never runs a
    backward pass, steps an optimizer or steps a scheduler, and it puts
    modules into eval mode.
    """

    def __init__(self):
        self._accumulated_loss = torch.tensor(0.0, device=get_device())

    @contextlib.contextmanager
    def autocast(self):
        """Context manager that does nothing."""
        yield

    @property
    def learning_rate(self) -> float:
        """Always NaN, as there is no optimizer."""
        return float("nan")

    def set_learning_rate(self, lr: float):
        """Ignore the requested learning rate."""
        return None

    def checkpoint(self, module: nn.Module, step: int) -> nn.Module:
        """Return ``module`` unchanged, whatever the step."""
        return module

    def set_mode(self, modules: nn.ModuleList):
        """
        Sets the mode of the module to eval.
        """
        for module in modules:
            module.eval()

    def step_scheduler(
        self, valid_loss: float | None = None, is_iteration: bool = False
    ):
        """Do nothing; there is no scheduler to step."""
        return None

    def detach_if_using_gradient_accumulation(self, state: TensorMapping) -> TensorDict:
        """Return a new dict with the entries of ``state``, without detaching."""
        return dict(state)

    def accumulate_loss(self, loss: torch.Tensor):
        """Add ``loss`` to the running total."""
        self._accumulated_loss += loss

    def get_accumulated_loss(self) -> torch.Tensor:
        """Return the loss accumulated since the last ``step_weights`` call."""
        return self._accumulated_loss

    def step_weights(self):
        """Reset the accumulated loss to zero; no weights are updated."""
        self._accumulated_loss = torch.tensor(0.0, device=get_device())

    def get_state(self):
        """Return an empty state dict."""
        return {}

    def load_state(self, state):
        """Ignore ``state``; there is nothing to restore."""
        return None