Source code for fme.core.parameter_init

import dataclasses
from typing import Any, Callable, List, Mapping, Optional, Tuple

import torch
from torch import nn

from fme.core.device import get_device
from fme.core.wildcard import apply_by_wildcard, wildcard_match

from .weight_ops import overwrite_weights, strip_leading_module


[docs]@dataclasses.dataclass class FrozenParameterConfig: """ Configuration for freezing parameters in a model. Parameter names can include wildcards, e.g. "encoder.*" will select all parameters in the encoder, while "encoder.*.bias" will select all bias parameters in the encoder. All parameters must be specified in either the include or exclude list, or an exception will be raised. An exception is raised if a parameter is included by both lists. Attributes: include: list of parameter names to freeze (set requires_grad = False) exclude: list of parameter names to ignore """ include: List[str] = dataclasses.field(default_factory=list) exclude: List[str] = dataclasses.field(default_factory=list) def __post_init__(self): for pattern in self.include: if any(wildcard_match(pattern, exclude) for exclude in self.exclude): raise ValueError( f"Parameter {pattern} is included in both include " f"{self.include} and exclude {self.exclude}" ) for pattern in self.exclude: if any(wildcard_match(pattern, include) for include in self.include): raise ValueError( f"Parameter {pattern} is included in both include " f"{self.include} and exclude {self.exclude}" ) def apply(self, model: nn.Module): apply_by_wildcard(model, _freeze_weight, self.include, self.exclude)
def _freeze_weight(module: nn.Module, name: str): try: module.get_parameter(name).requires_grad = False except AttributeError: # non-parameter state pass RegularizerFunction = Callable[[], torch.Tensor]
[docs]@dataclasses.dataclass class ParameterInitializationConfig: """ A class which applies custom initialization to module parameters. Assumes the module weights have already been randomly initialized. Supports overwriting the weights of the built model with weights from a pre-trained model. If the built model has larger weights than the pre-trained model, only the initial slice of the weights is overwritten. Attributes: weight_path: path to a SingleModuleStepper checkpoint containing weights to load exclude_parameters: list of parameter names to exclude from the loaded weights. Used for example to keep the random initialization for final layer(s) of a model, and only overwrite the weights for earlier layers. Takes values like "decoder.2.weight". frozen_parameters: configuration for freezing parameters in the built model alpha: L2 regularization coefficient keeping initialized weights close to their intiial values beta: L2 regularization coefficient keeping uninitialized weights close to zero """ weights_path: Optional[str] = None exclude_parameters: List[str] = dataclasses.field(default_factory=list) frozen_parameters: FrozenParameterConfig = dataclasses.field( default_factory=lambda: FrozenParameterConfig(exclude=["*"]) ) alpha: float = 0.0 beta: float = 0.0
[docs] def apply( self, module: nn.Module, init_weights: bool ) -> Tuple[nn.Module, RegularizerFunction]: """ Apply the weight initialization to a module. Args: module: a nn.Module to initialize init_weights: whether to initialize the weight values Returns: a nn.Module with initialization applied a function which returns the regularization loss term """ if init_weights and self.weights_path is not None: loaded_state_dict = self.get_base_weights() if loaded_state_dict is not None: overwrite_weights( loaded_state_dict, module, exclude_parameters=self.exclude_parameters, ) else: loaded_state_dict = None self.frozen_parameters.apply(module) device = get_device() if loaded_state_dict is None or (self.alpha == 0 and self.beta == 0): def regularizer(): return torch.tensor(0.0, device=device) else: loaded_state_dict = { name: value.to(device) for name, value in loaded_state_dict.items() } from_names = set(loaded_state_dict.keys()) to_names = set(module.state_dict().keys()) if not from_names.issubset(to_names): missing_parameters = from_names - to_names raise ValueError( f"Dest module is missing parameters {missing_parameters}, " "which is not allowed" ) def regularizer(): loss = torch.tensor(0.0, device=device) for name in from_names: try: param = module.get_parameter(name) except AttributeError: # non-trainable state data continue if any( wildcard_match(pattern, name) for pattern in self.exclude_parameters ): loss += ( self.beta / 2 * torch.linalg.norm(param.flatten(), ord=2) ) else: loss += ( self.alpha / 2 * torch.linalg.norm( (param - loaded_state_dict[name]).flatten(), ord=2, ) ) return loss return module, regularizer
[docs] def get_base_weights(self) -> Optional[Mapping[str, Any]]: """ If a weights_path is provided, return the model base weights used for initialization. """ if self.weights_path is not None: checkpoint = torch.load(self.weights_path, map_location=get_device()) return strip_leading_module(checkpoint["stepper"]["module"]) else: return None