Source code for fme.core.ema

"""
Exponential Moving Average (EMA) module

Copied from https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/ema.py
and modified.

MIT License

Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import dataclasses
from typing import Iterable, Iterator, List, Protocol, Tuple

import torch
from torch import nn


class HasNamedParameters(Protocol):
    """
    Structural type for anything exposing a ``named_parameters`` method.

    The signature mirrors ``torch.nn.Module.named_parameters``, so any
    ``nn.Module`` satisfies this protocol, as does any other object that
    yields ``(name, parameter)`` pairs from a method with this signature.
    """

    def named_parameters(
        self, recurse: bool = True
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        """Return an iterator over ``(name, parameter)`` pairs."""


@dataclasses.dataclass
class EMAConfig:
    """
    Configuration for exponential moving average of model weights.

    Attributes:
        decay: decay rate for the moving average
    """

    decay: float = 0.9999

    def build(self, model: HasNamedParameters):
        """
        Create an EMA tracker for the given model from this configuration.

        Args:
            model: The model whose parameters the tracker should follow.

        Returns:
            An ``EMATracker`` using ``self.decay``, always constructed with
            ``faster_decay_at_start=True``.
        """
        tracker = EMATracker(model, decay=self.decay, faster_decay_at_start=True)
        return tracker
class EMATracker:
    """
    Exponential Moving Average (EMA) tracker.

    This tracks the moving average of the parameters of a model, and has
    methods that can be used to temporarily replace the parameters of the
    model with its EMA.
    """

    def __init__(
        self, model: "HasNamedParameters", decay: float, faster_decay_at_start=True
    ):
        """
        Create a new EMA tracker.

        Args:
            model: The model whose parameters should be tracked. Only
                parameters with ``requires_grad=True`` are tracked; their
                current values initialize the moving average.
            decay: The decay rate of the moving average, in [0, 1]. May be a
                float or a 0-d tensor (as stored by `get_state`).
            faster_decay_at_start: Whether to use the number of updates to
                determine the decay rate. If True, the decay rate will be
                min(decay, (1 + num_updates) / (10 + num_updates)).
                If False, the decay rate will be decay.

        Raises:
            ValueError: If decay is not in [0, 1] (NaN included), or if two
                tracked parameter names become identical once '.' characters
                are removed.
        """
        super().__init__()
        # Convert to a plain float first. This accepts both floats and the 0-d
        # tensor stored by get_state, and avoids torch's warning about
        # calling torch.tensor() on an existing tensor.
        decay = float(decay)
        # A negated range check, so that NaN is rejected as well.
        if not 0.0 <= decay <= 1.0:
            raise ValueError("Decay must be between 0 and 1")
        self._module_name_to_ema_name = {}
        self.decay = torch.tensor(decay, dtype=torch.float32)
        self._faster_decay_at_start = faster_decay_at_start
        self.num_updates = torch.tensor(0, dtype=torch.int)
        self._ema_params = {}
        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers.
                # The scheme is kept for compatibility with saved states, but
                # it is not injective ("a.b" and "ab" both map to "ab"), so a
                # collision must not silently overwrite another parameter.
                ema_name = name.replace(".", "")
                if ema_name in self._ema_params:
                    raise ValueError(
                        f"Parameter name {name!r} collides with another "
                        f"parameter after removing '.' (EMA name {ema_name!r})"
                    )
                self._module_name_to_ema_name[name] = ema_name
                self._ema_params[ema_name] = p.clone().detach().data
        self._stored_params: List[torch.Tensor] = []

    def __call__(self, model: "HasNamedParameters"):
        """
        Update the moving average of the parameters.

        Does not mutate the input, only updates the moving average.

        Args:
            model: The model whose parameters should be updated. Should be a
                model specified identically to the one passed when this object
                was instantiated.

        Raises:
            ValueError: If a tracked parameter no longer requires gradient.
        """
        # Rebind rather than using `+=`. An in-place add would also mutate the
        # tensor held by any dict returned from get_state or passed to
        # from_state.
        self.num_updates = self.num_updates + 1
        decay = self.decay
        if self._faster_decay_at_start:
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
        with torch.no_grad():
            module_parameters = dict(model.named_parameters())
            for key, param in module_parameters.items():
                if param.requires_grad:
                    ema_name = self._module_name_to_ema_name[key]
                    # follow the parameter's dtype/device
                    ema = self._ema_params[ema_name].type_as(param)
                    self._ema_params[ema_name] = ema
                    # ema <- ema - (1 - decay) * (ema - param)
                    ema.sub_((1.0 - decay) * (ema - param))
                elif key in self._module_name_to_ema_name:
                    raise ValueError(
                        f"Expected model parameter {key} to require gradient, "
                        "but it does not"
                    )

    def copy_to(self, model: "HasNamedParameters"):
        """
        Copy the averaged parameters to the model, overwriting its values.

        Raises:
            ValueError: If a tracked parameter no longer requires gradient.
        """
        for key, param in dict(model.named_parameters()).items():
            if param.requires_grad:
                param.data.copy_(
                    self._ema_params[self._module_name_to_ema_name[key]].data
                )
            elif key in self._module_name_to_ema_name:
                # Same check as __call__; an `assert` would vanish under -O.
                raise ValueError(
                    f"Expected model parameter {key} to require gradient, "
                    "but it does not"
                )

    def store(self, parameters: Iterable[nn.Parameter]):
        """
        Save the current parameters for restoring later.

        Args:
            parameters: The parameters to be stored for later restoration by
                `restore`
        """
        # detach() so that the saved copies do not record autograd history.
        self._stored_params = [param.detach().clone() for param in parameters]

    def restore(self, parameters: Iterable[nn.Parameter]):
        """
        Restore the parameters stored with the `store` method.

        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.

        Args:
            parameters: The parameters to be updated with the values stored by
                `store`

        Raises:
            ValueError: If the number of parameters differs from the number
                stored (e.g. `store` was never called). This case would
                otherwise restore only a prefix, or nothing at all.
        """
        targets = list(parameters)
        if len(targets) != len(self._stored_params):
            raise ValueError(
                f"Expected {len(self._stored_params)} parameters to restore, "
                f"got {len(targets)}"
            )
        for stored, param in zip(self._stored_params, targets):
            param.data.copy_(stored.data)

    def get_state(self):
        """
        Get the state of the EMA tracker, excluding weights.

        Returns:
            The state of the EMA tracker.
        """
        return {
            "decay": self.decay,
            "num_updates": self.num_updates,
            "faster_decay_at_start": self._faster_decay_at_start,
            # Copy so that edits to the returned dict cannot alter the tracker.
            "module_name_to_ema_name": dict(self._module_name_to_ema_name),
        }

    @classmethod
    def from_state(cls, state, model) -> "EMATracker":
        """
        Create an EMA tracker from a state.

        Args:
            state: The state of the EMA tracker.
            model: The model whose parameters should be tracked, used to
                initialize the EMA weights. Should come from an EMA checkpoint.

        Returns:
            The EMA tracker.
        """
        ema = cls(model, state["decay"], state["faster_decay_at_start"])
        ema.num_updates = state["num_updates"]
        ema._module_name_to_ema_name = dict(state["module_name_to_ema_name"])
        return ema