import collections
import dataclasses
import re
from collections.abc import Callable
from typing import Literal, Protocol, runtime_checkable

import torch

from fme.core.typing_ import TensorDict, TensorMapping


def replace_on_mask(
    original: torch.Tensor,
    replacement: torch.Tensor,
    mask: torch.Tensor,
    mask_value: int,
):
    """Take ``replacement`` where the mask equals ``mask_value``, else ``original``.

    The mask is rounded and cast to integer before the comparison, so mask
    values that are only approximately equal to ``mask_value`` still match.

    Args:
        original: The original data tensor.
        replacement: The replacement data tensor.
        mask: The mask tensor.
        mask_value: The value of the mask variable in the region to be replaced.

    Returns:
        A tensor holding ``replacement`` values where the rounded mask equals
        ``mask_value`` and ``original`` values everywhere else.
    """
    in_region = torch.round(mask).to(int) == mask_value
    return torch.where(in_region, replacement, original)


@runtime_checkable
class HasGetMaskTensorFor(Protocol):
    """Structural interface for objects that provide per-variable mask tensors."""

    def build_output_masker(self) -> Callable[[TensorMapping], TensorDict]:
        """Return a callable that maps a tensor mapping to a tensor dict."""
        ...

    def get_mask_tensor_for(self, name: str) -> torch.Tensor | None:
        """Get the mask for a specific variable name, or None if there is none."""
        ...

    def to(self, device: str) -> "HasGetMaskTensorFor":
        """Return a mask provider for ``device`` (details are up to the implementer)."""
        ...
@dataclasses.dataclass
class StaticMaskingConfig:
    """
    Replace static masked regions with a fill value.

    Parameters:
        mask_value: Value of the mask variable in masked regions. Either 0 or 1.
        fill_value: A numeric fill value written into the masked regions, i.e.
            where the mask equals ``mask_value``. Can also be "mean", in which
            case the ``means`` mapping passed to ``build`` (e.g. normalizer
            means) supplies channel-specific fill values.
        exclude_names_and_prefixes: Names (2D variables) and prefixes (3D variables)
            to exclude when applying the mask.
    """

    mask_value: int
    fill_value: Literal["mean"] | float = 0.0
    exclude_names_and_prefixes: list[str] | None = None

    def __post_init__(self):
        """Validate that ``mask_value`` is 0 or 1.

        Raises:
            ValueError: If ``mask_value`` is neither 0 nor 1.
        """
        if self.mask_value not in [0, 1]:
            raise ValueError(
                f"mask_value must be either 0 or 1, but got {self.mask_value}"
            )

    def build(
        self,
        mask: HasGetMaskTensorFor,
        means: TensorMapping | None = None,
    ):
        """
        Build StaticMasking.

        Args:
            mask: Object providing the mask tensor for each variable name.
            means: Per-variable fill tensors. Only used, and then required,
                when ``fill_value`` is not numeric.

        Returns:
            A ``StaticMasking`` configured from this object.

        Raises:
            ValueError: If ``fill_value`` is not numeric and ``means`` is None.
        """
        fill = self.fill_value
        # Accept ints (e.g. a literal ``0``) as well as floats; a plain float
        # check would send them down the "mean" path. bool is an int subclass
        # but is not a meaningful fill value, so it is not treated as numeric.
        if isinstance(fill, (int, float)) and not isinstance(fill, bool):
            # Snapshot the value so later edits to the config cannot change
            # the fill used by an already-built StaticMasking.
            constant = float(fill)
            return StaticMasking(
                mask_value=self.mask_value,
                fill_value=collections.defaultdict(
                    lambda: torch.as_tensor(constant)
                ),
                mask=mask,
                exclude_names_and_prefixes=self.exclude_names_and_prefixes,
            )
        if means is None:
            raise ValueError(
                "fill_values mapping required by build unless configured "
                "fill_value is a float."
            )
        return StaticMasking(
            mask_value=self.mask_value,
            fill_value=means,
            mask=mask,
            exclude_names_and_prefixes=self.exclude_names_and_prefixes,
        )
class StaticMasking:
    """Fill statically masked regions of named tensors with per-variable values."""

    def __init__(
        self,
        mask_value: int,
        fill_value: float | TensorMapping,
        mask: HasGetMaskTensorFor,
        exclude_names_and_prefixes: list[str] | None = None,
    ):
        """
        Args:
            mask_value: Value of the mask variable in the regions to be filled.
            fill_value: Either a single number used for every variable, or a
                mapping from variable name to fill tensor.
            mask: Object providing the mask tensor for each variable name.
            exclude_names_and_prefixes: Names and prefixes of variables that
                must not be masked.
        """
        # Accept ints (e.g. a literal ``0``) as well as floats; a bare int
        # stored as the mapping would fail on subscripting in __call__.
        if isinstance(fill_value, (int, float)) and not isinstance(fill_value, bool):
            constant = float(fill_value)
            fill_mapping: TensorMapping = collections.defaultdict(
                lambda: torch.as_tensor(constant)
            )
        else:
            fill_mapping = fill_value
        self._fill_mapping = fill_mapping
        self._mask_value = mask_value
        self._mask = mask
        self._exclude_regex = self._build_regex(exclude_names_and_prefixes)

    def _build_regex(self, names_and_prefixes: list[str] | None) -> str | None:
        """Build one alternation pattern matching every excluded variable name.

        Returns None when there is nothing to exclude (None or an empty list).
        """
        if not names_and_prefixes:
            return None
        patterns = []
        for name in names_and_prefixes:
            if name.endswith("_"):
                # Prefix form "foo_": matches "foo_<digits>".
                patterns.append(rf"^{name}\d+$")
            elif not re.match(r".+_\d+$", name):
                # Plain name: matches itself and any "name_<digits>" variant.
                patterns.append(f"^{name}$")
                patterns.append(rf"^{name}_\d+$")
            else:
                # Already of the form "name_<digits>": exact match only.
                patterns.append(rf"^{name}$")
        return r"|".join(patterns)

    def _masks(self, name: str):
        """Return True when ``name`` is not excluded from masking."""
        exclude = self._exclude_regex and re.match(self._exclude_regex, name)
        return not exclude

    def __call__(self, data: TensorMapping) -> TensorDict:
        """
        Apply masking to every variable that is not excluded and has a mask.

        Args:
            data: The data to mask.

        Returns:
            A new dict with the same keys as ``data``. Masked variables are
            replaced by filled tensors; all other entries are passed through.

        Raises:
            KeyError: If the fill mapping has no entry for a variable that
                is to be masked.
        """
        data_: TensorDict = {**data}
        # Only values of existing keys are reassigned below, so iterating the
        # dict while updating it is safe.
        for name, tensor in data_.items():
            if not self._masks(name):
                continue
            mask = self._mask.get_mask_tensor_for(name)
            if mask is None:
                continue
            try:
                fill_value = self._fill_mapping[name]
            except KeyError as err:
                raise KeyError(
                    "StaticMasking was initialized with a fill_value mapping "
                    f"but the mapping is missing key '{name}'."
                ) from err
            fill = torch.full_like(tensor, fill_value)
            mask = mask.expand(fill.shape)
            data_[name] = replace_on_mask(
                original=tensor,
                replacement=fill,
                mask=mask,
                mask_value=self._mask_value,
            )
        return data_


class NullMasking:
    """Masking that changes nothing; it returns a shallow dict copy of the input."""

    def __call__(self, data: TensorMapping) -> TensorDict:
        return dict(data)