@dataclasses.dataclass
class FrozenParameterConfig:
    """
    Choose which parameters of a model to freeze (``requires_grad = False``).

    Entries in both lists are parameter names that may contain wildcards:
    ``"encoder.*"`` selects every parameter in the encoder, while
    ``"encoder.*.bias"`` selects only the encoder's bias parameters.

    Every parameter is expected to appear in exactly one of the two lists.
    The "missing from both lists" case is reported by ``apply_by_wildcard``
    when ``apply`` runs (that helper is not shown here). A pattern in one list
    that wildcard-matches an entry of the other list is rejected at
    construction time with ``ValueError``.

    Parameters:
        include: parameter names to freeze (set requires_grad = False)
        exclude: parameter names to leave untouched
    """

    include: List[str] = dataclasses.field(default_factory=list)
    exclude: List[str] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        """Reject configurations where include and exclude overlap."""

        def reject_overlap(patterns, candidates):
            # Raise on the first pattern that matches any candidate name.
            for pattern in patterns:
                for candidate in candidates:
                    if wildcard_match(pattern, candidate):
                        raise ValueError(
                            f"Parameter {pattern} is included in both include "
                            f"{self.include} and exclude {self.exclude}"
                        )

        # wildcard_match takes (pattern, name), so matching is checked in
        # both directions: include patterns against exclude entries, then
        # exclude patterns against include entries.
        reject_overlap(self.include, self.exclude)
        reject_overlap(self.exclude, self.include)

    def apply(self, model: nn.Module):
        """Freeze the parameters of ``model`` selected by this config."""
        apply_by_wildcard(model, _freeze_weight, self.include, self.exclude)
def _leading_slice(tensor: torch.Tensor, shape: torch.Size) -> torch.Tensor:
    """
    Return the leading ``shape``-sized slice of ``tensor``.

    Aligns a built-model parameter with a possibly smaller pre-trained weight,
    matching how the pre-trained weights overwrite only the initial slice of
    a larger parameter. If the shapes already agree, or the ranks differ, the
    tensor is returned unchanged. Any remaining mismatch then surfaces in the
    caller as before.
    """
    if tensor.shape == shape or tensor.dim() != len(shape):
        return tensor
    return tensor[tuple(slice(0, size) for size in shape)]


@dataclasses.dataclass
class ParameterInitializationConfig:
    """
    A class which applies custom initialization to module parameters.

    Assumes the module weights have already been randomly initialized.

    Supports overwriting the weights of the built model with weights from a
    pre-trained model. If the built model has larger weights than the
    pre-trained model, only the initial slice of the weights is overwritten.

    Parameters:
        weights_path: path to a SingleModuleStepper checkpoint
            containing weights to load
        exclude_parameters: list of parameter names to exclude from the loaded
            weights. Used for example to keep the random initialization for
            final layer(s) of a model, and only overwrite the weights for
            earlier layers. Takes values like "decoder.2.weight" (wildcards
            are matched with ``wildcard_match``).
        frozen_parameters: configuration for freezing parameters in the built
            model. The default excludes every parameter ("*"), so nothing is
            frozen.
        alpha: L2 regularization coefficient keeping initialized weights
            close to their initial values
        beta: L2 regularization coefficient keeping uninitialized weights
            close to zero
    """

    weights_path: Optional[str] = None
    exclude_parameters: List[str] = dataclasses.field(default_factory=list)
    frozen_parameters: FrozenParameterConfig = dataclasses.field(
        default_factory=lambda: FrozenParameterConfig(exclude=["*"])
    )
    alpha: float = 0.0
    beta: float = 0.0

    def apply(
        self, module: nn.Module, init_weights: bool
    ) -> Tuple[nn.Module, RegularizerFunction]:
        """
        Apply the weight initialization to a module.

        Args:
            module: a nn.Module to initialize
            init_weights: whether to initialize the weight values

        Returns:
            a nn.Module with initialization applied
            a function which returns the regularization loss term

        Raises:
            ValueError: if regularization is enabled and the loaded weights
                contain names that ``module.state_dict()`` does not have.
        """
        # NOTE(review): when init_weights is False the base weights are never
        # loaded, so the returned regularizer is always zero even if alpha or
        # beta is set. Confirm that this is intended, e.g. on training restarts.
        if init_weights and self.weights_path is not None:
            loaded_state_dict = self.get_base_weights()
            if loaded_state_dict is not None:
                overwrite_weights(
                    loaded_state_dict,
                    module,
                    exclude_parameters=self.exclude_parameters,
                )
        else:
            loaded_state_dict = None
        self.frozen_parameters.apply(module)
        device = get_device()

        if loaded_state_dict is None or (self.alpha == 0 and self.beta == 0):

            def regularizer():
                return torch.tensor(0.0, device=device)

            return module, regularizer

        loaded_state_dict = {
            name: value.to(device) for name, value in loaded_state_dict.items()
        }
        from_names = set(loaded_state_dict.keys())
        to_names = set(module.state_dict().keys())
        if not from_names.issubset(to_names):
            missing_parameters = from_names - to_names
            raise ValueError(
                f"Dest module is missing parameters {missing_parameters}, "
                "which is not allowed"
            )
        # Rebinding gives the closure a name that is never None.
        non_optional_state_dict = loaded_state_dict
        # The regularizer runs every step, so the wildcard exclusion test is
        # resolved once here instead of on each call.
        excluded_names = {
            name
            for name in from_names
            if any(
                wildcard_match(pattern, name)
                for pattern in self.exclude_parameters
            )
        }

        def regularizer():
            # NOTE(review): these terms use the (unsquared) L2 norm, although
            # the `/ 2` factor suggests the squared norm was intended. Kept
            # as-is to preserve training behavior.
            loss = torch.tensor(0.0, device=device)
            for name in from_names:
                try:
                    param = module.get_parameter(name)
                except AttributeError:  # non-trainable state data
                    continue
                if name in excluded_names:
                    loss += self.beta / 2 * torch.linalg.norm(
                        param.flatten(), ord=2
                    )
                else:
                    target = non_optional_state_dict[name]
                    # Only the leading slice was initialized from `target`
                    # when the built parameter is larger, so compare just
                    # that slice. This avoids a shape error or a silent
                    # broadcast.
                    aligned = _leading_slice(param, target.shape)
                    loss += self.alpha / 2 * torch.linalg.norm(
                        (aligned - target).flatten(),
                        ord=2,
                    )
            return loss

        return module, regularizer

    def get_base_weights(self) -> Optional[Mapping[str, Any]]:
        """
        If a weights_path is provided, return the model base weights used for
        initialization.

        Returns None when ``weights_path`` is None. Otherwise the checkpoint's
        ``["stepper"]["module"]`` state dict is returned after passing through
        ``strip_leading_module``.
        """
        if self.weights_path is None:
            return None
        # NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints, and consider weights_only=True if checkpoints allow it.
        checkpoint = torch.load(self.weights_path, map_location=get_device())
        return strip_leading_module(checkpoint["stepper"]["module"])