# Source code for fme.core.ema

"""
Exponential Moving Average (EMA) module.

Copied from https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/ema.py
and modified.

MIT License

Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import contextlib
import dataclasses
import logging
from collections.abc import Iterable, Iterator
from typing import Protocol

import torch
from torch import nn

from fme.core.device import get_device


class HasNamedParameters(Protocol):
    """
    Structural type for any object that can enumerate its parameters the way
    ``torch.nn.Module`` does (``named_parameters`` and ``parameters``).
    """

    def named_parameters(
        self, recurse: bool = True
    ) -> Iterator[tuple[str, nn.Parameter]]:
        """Iterate over ``(name, parameter)`` pairs."""
        ...

    def parameters(self) -> Iterator[nn.Parameter]:
        """Iterate over the parameters themselves."""
        ...


@dataclasses.dataclass
class EMAConfig:
    """
    Settings for keeping an exponential moving average of model weights.

    Parameters:
        decay: decay rate for the moving average
        resume_ema_ckpt_path: Optional path to a training checkpoint
            (``ckpt.tar``) whose EMA running state (averaged weights and
            update counter) should be loaded into the freshly-built
            ``EMATracker`` for fine-tuning. The current config's ``decay``
            is kept; only the running state is transferred. Intended for
            non-resuming jobs; preemption resume in the Trainer overrides
            this state via ``EMATracker.from_state``.
    """

    decay: float = 0.9999
    resume_ema_ckpt_path: str | None = None

    def build(self, model: HasNamedParameters):
        """
        Create an ``EMATracker`` for ``model`` with warm-up decay enabled.

        If ``resume_ema_ckpt_path`` is set, the tracker's running state is
        then overwritten from that checkpoint.
        """
        tracker = EMATracker(model, decay=self.decay, faster_decay_at_start=True)
        ckpt_path = self.resume_ema_ckpt_path
        if ckpt_path is None:
            return tracker
        _load_finetune_ema_state(tracker, ckpt_path)
        return tracker
class EMATracker:
    """
    Exponential Moving Average (EMA) tracker.

    This tracks the moving average of the parameters of a model, and has
    methods that can be used to temporarily replace the parameters of the
    model with its EMA.
    """

    def __init__(
        self, model: HasNamedParameters, decay: float, faster_decay_at_start=True
    ):
        """
        Create a new EMA tracker.

        Args:
            model: The model whose parameters should be tracked. Only
                parameters with ``requires_grad=True`` are tracked.
            decay: The decay rate of the moving average.
            faster_decay_at_start: Whether to use the number of updates to
                determine the decay rate. If True, the decay rate will be
                min(decay, (1 + num_updates) / (10 + num_updates)). If False,
                the decay rate will be decay.

        Raises:
            ValueError: If ``decay`` is not within ``[0, 1]`` (including NaN).
        """
        super().__init__()
        # Written as a single chained comparison so NaN (for which every
        # comparison is False) is rejected too.
        if not (0.0 <= decay <= 1.0):
            raise ValueError("Decay must be between 0 and 1")
        self._module_name_to_ema_name: dict[str, str] = {}
        self.decay = torch.tensor(decay, dtype=torch.float32).to(get_device())
        self._faster_decay_at_start = faster_decay_at_start
        self.num_updates = torch.tensor(0, dtype=torch.int).to(get_device())
        self._ema_params: dict[str, torch.Tensor] = {}
        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                ema_name = name.replace(".", "")
                self._module_name_to_ema_name[name] = ema_name
                self._ema_params[ema_name] = p.clone().detach().data
        # Detached snapshots of parameter values taken by `store`.
        self._stored_params: list[torch.Tensor] = []
        self._in_context = False

    def __call__(self, model: HasNamedParameters):
        """
        Update the moving average of the parameters.

        Does not mutate the input, only updates the moving average.

        Args:
            model: The model whose parameters should be updated. Should be a
                model specified identically to the one passed when this object
                was instantiated.

        Raises:
            ValueError: If a parameter that was tracked at construction no
                longer requires gradient.
        """
        decay = self.decay
        self.num_updates += 1
        if self._faster_decay_at_start:
            # For small num_updates, (1 + n) / (10 + n) is well below `decay`,
            # so early updates move the average more strongly toward the model.
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
        with torch.no_grad():
            for key, param in model.named_parameters():
                if param.requires_grad:
                    ema_name = self._module_name_to_ema_name[key]
                    # Cast the average to the parameter's tensor type
                    # (`type_as`) before the in-place update.
                    self._ema_params[ema_name] = self._ema_params[ema_name].type_as(
                        param
                    )
                    # ema -= (1 - decay) * (ema - param)
                    # which equals decay * ema + (1 - decay) * param.
                    self._ema_params[ema_name].sub_(
                        (1.0 - decay) * (self._ema_params[ema_name] - param)
                    )
                elif key in self._module_name_to_ema_name:
                    raise ValueError(
                        f"Expected model parameter {key} to require gradient, "
                        "but it does not"
                    )

    @contextlib.contextmanager
    def applied_params(self, model: HasNamedParameters) -> Iterator[None]:
        """
        Temporarily replace the model's parameters with their EMA values.

        The model's original parameter values are restored on exit, including
        when the body raises.

        Args:
            model: The model whose parameters should be swapped.

        Raises:
            RuntimeError: If another ``applied_params`` context of this tracker
                is already active.
        """
        # The nesting check must come before `store`. Storing first would
        # overwrite the outer context's saved (original) parameters with the
        # EMA values it applied. The outer context's restore would then leave
        # EMA weights in the model.
        if self._in_context:
            raise RuntimeError("Cannot nest EMA contexts")
        self.store(parameters=model.parameters())
        self._in_context = True
        try:
            # Inside the try so a failure while copying still restores the
            # original parameters and clears the context flag.
            self.copy_to(model)
            yield
        finally:
            self.restore(parameters=model.parameters())
            self._in_context = False

    def copy_to(self, model: HasNamedParameters) -> None:
        """
        Copy the averaged parameters to the model, overwriting its values.

        Args:
            model: The model to write the averaged parameters into.

        Raises:
            ValueError: If a parameter that was tracked at construction no
                longer requires gradient.
        """
        for key, param in model.named_parameters():
            if param.requires_grad:
                param.data.copy_(
                    self._ema_params[self._module_name_to_ema_name[key]].data
                )
            elif key in self._module_name_to_ema_name:
                # Explicit raise rather than `assert`, which is stripped
                # under `python -O`. Matches the check in `__call__`.
                raise ValueError(
                    f"Expected model parameter {key} to require gradient, "
                    "but it does not"
                )

    def store(self, parameters: Iterable[nn.Parameter]):
        """
        Save the current parameters for restoring later.

        Args:
            parameters: The parameters to be stored for later restoration by
                `restore`
        """
        # detach() so the snapshots are plain tensors that are not recorded
        # in the autograd graph.
        self._stored_params = [param.detach().clone() for param in parameters]

    def restore(self, parameters: Iterable[nn.Parameter]):
        """
        Restore the parameters stored with the `store` method.

        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        Args:
            parameters: The parameters to be updated with the values stored by
                `store`
        """
        for c_param, param in zip(self._stored_params, parameters):
            param.data.copy_(c_param.data)

    def get_state(self):
        """
        Get the state of the EMA tracker.

        Returns:
            The state of the EMA tracker. Tensors and the name mapping are
            copies, so later updates do not alter the returned dict.
        """
        return {
            "decay": self.decay.clone(),
            "num_updates": self.num_updates.clone(),
            "faster_decay_at_start": self._faster_decay_at_start,
            "module_name_to_ema_name": dict(self._module_name_to_ema_name),
            "ema_params": {
                name: param.clone().detach() for name, param in self._ema_params.items()
            },
        }

    def load_ema_state_for_finetuning(self, state: dict):
        """Load EMA running state from a checkpoint for fine-tuning.

        Restores the averaged parameter weights and update counter from a
        previously saved EMA state. The current tracker's ``decay`` and
        ``faster_decay_at_start`` (set at construction from the current
        config) are preserved; only the running state is transferred.

        Args:
            state: The EMA state dict as saved by ``get_state()``, containing
                at least ``"ema_params"``, ``"num_updates"``, and
                ``"module_name_to_ema_name"``.

        Raises:
            ValueError: If the state does not contain ``"ema_params"`` (e.g.
                from a checkpoint saved without ``include_optimization=True``).
        """
        if "ema_params" not in state:
            raise ValueError(
                "EMA state does not contain ema_params. Only ckpt.tar "
                "checkpoints (saved with include_optimization=True) "
                "contain the full EMA state needed for fine-tuning."
            )
        device = get_device()
        self.num_updates = state["num_updates"].to(device, copy=True)
        # Copy so the tracker does not share a mutable dict with the caller.
        self._module_name_to_ema_name = dict(state["module_name_to_ema_name"])
        self._ema_params = {
            name: param.to(device, copy=True)
            for name, param in state["ema_params"].items()
        }

    @classmethod
    def from_state(cls, state, model) -> "EMATracker":
        """
        Create an EMA tracker from a state.

        Args:
            state: The state of the EMA tracker.
            model: The model whose parameters should be tracked, used to
                initialize the EMA weights.

        Returns:
            The EMA tracker. If ``state`` has no ``"ema_params"``, the EMA
            weights are those initialized from ``model`` and a warning is
            logged.
        """
        device = get_device()
        ema = cls(model, float(state["decay"]), state["faster_decay_at_start"])
        ema.num_updates = state["num_updates"].to(device, copy=True)
        # Copy so the tracker does not share a mutable dict with the caller.
        ema._module_name_to_ema_name = dict(state["module_name_to_ema_name"])
        if "ema_params" in state:
            ema._ema_params = {
                name: param.to(device, copy=True)
                for name, param in state["ema_params"].items()
            }
        else:
            logging.warning("EMA params not found in state and will not be restored.")
        return ema


def _load_finetune_ema_state(ema: EMATracker, checkpoint_path: str):
    """Load EMA running state from a training checkpoint for fine-tuning.

    Only loads the EMA averaged weights and update counter from the
    checkpoint. The current tracker's decay and faster_decay_at_start are
    preserved from the current config. The checkpoint is loaded on CPU so
    that only the EMA state (not model weights, optimizer, etc.) is
    transferred to the training device.

    Args:
        ema: The tracker to load the running state into.
        checkpoint_path: Path to the training checkpoint.

    Raises:
        ValueError: If the checkpoint has no ``"ema"`` entry, or the EMA entry
            lacks ``"ema_params"``.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects. Only
    # point resume_ema_ckpt_path at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
    if "ema" not in checkpoint:
        raise ValueError(
            f"Checkpoint at {checkpoint_path} does not contain EMA state. "
            "Only training checkpoints (ckpt.tar) contain EMA state."
        )
    ema_state = checkpoint["ema"]
    # Drop the reference to the rest of the checkpoint before copying EMA
    # tensors to the device.
    del checkpoint
    ema.load_ema_state_for_finetuning(ema_state)