import collections
import dataclasses
from collections.abc import Callable
from typing import Literal, Protocol, runtime_checkable

import torch

from fme.core.name_and_prefix_matcher import NameAndPrefixMatcher
from fme.core.typing_ import TensorDict, TensorMapping


def replace_on_mask(
    original: torch.Tensor,
    replacement: torch.Tensor,
    mask: torch.Tensor,
    mask_value: int,
) -> torch.Tensor:
    """Take values from ``replacement`` wherever the mask equals ``mask_value``.

    The mask is rounded to the nearest integer before the comparison, so
    non-integer mask values are snapped to the closest integer first.

    Args:
        original: Tensor supplying values outside the selected region.
        replacement: Tensor supplying values inside the selected region.
        mask: Mask tensor, compared against ``mask_value`` after rounding.
        mask_value: Mask value identifying the region to be replaced.

    Returns:
        A new tensor combining ``replacement`` (selected region) and
        ``original`` (everywhere else).
    """
    selected = torch.round(mask).to(int) == mask_value
    return torch.where(selected, replacement, original)


@runtime_checkable
class HasGetSpatialMask(Protocol):
    """Structural interface for objects that provide per-variable spatial masks."""

    def build_output_spatial_masker(self) -> Callable[[TensorMapping], TensorDict]:
        """Return a callable that maps a TensorMapping to a masked TensorDict."""
        ...

    def get_mask_tensor_for(self, name: str) -> torch.Tensor | None:
        """Get the mask for a specific variable name."""
        ...

    def to(self, device: str) -> "HasGetSpatialMask":
        """Return this mask provider placed on ``device``."""
        ...
@dataclasses.dataclass
class StaticSpatialMaskingConfig:
    """
    Replace static spatially masked regions with a fill value.

    Parameters:
        mask_value: Value of the mask variable in masked regions. Either 0 or 1.
        fill_value: A numeric fill value to use in masked regions (where the
            mask equals ``mask_value``). Can also be "mean", in which case
            the normalizer means are used as channel-specific fill values.
        exclude_names_and_prefixes: Names (2D variables) and prefixes (3D
            variables) to exclude when applying the mask.
    """

    mask_value: int
    fill_value: Literal["mean"] | float = 0.0
    exclude_names_and_prefixes: list[str] | None = None

    def __post_init__(self):
        if self.mask_value not in [0, 1]:
            raise ValueError(
                f"mask_value must be either 0 or 1, but got {self.mask_value}"
            )
        # Any string other than "mean" would otherwise silently be treated
        # as the "mean" option by build().
        if isinstance(self.fill_value, str) and self.fill_value != "mean":
            raise ValueError(
                f'fill_value must be a number or "mean", but got {self.fill_value!r}'
            )

    def build(self, mask: HasGetSpatialMask, means: TensorMapping | None = None):
        """
        Build StaticSpatialMasking.

        Args:
            mask: Provider of per-variable mask tensors.
            means: Per-variable fill values. Required when ``fill_value`` is
                "mean"; ignored when ``fill_value`` is a number.

        Returns:
            A StaticSpatialMasking configured from this config.

        Raises:
            ValueError: If ``fill_value`` is "mean" and ``means`` is None.
        """
        exclude = NameAndPrefixMatcher(self.exclude_names_and_prefixes)
        # Accept ints as well as floats: ``fill_value=0`` (e.g. from YAML)
        # must not be mistaken for the "mean" option.
        if isinstance(self.fill_value, (int, float)):
            constant_fill = float(self.fill_value)
            return StaticSpatialMasking(
                mask_value=self.mask_value,
                fill_value=collections.defaultdict(
                    lambda: torch.as_tensor(constant_fill)
                ),
                mask=mask,
                exclude=exclude,
            )
        if means is None:
            raise ValueError(
                "means mapping required by build unless configured "
                "fill_value is a number."
            )
        return StaticSpatialMasking(
            mask_value=self.mask_value,
            fill_value=means,
            mask=mask,
            exclude=exclude,
        )
class StaticSpatialMasking:
    """Replace values in statically masked regions with per-variable fill values.

    Args:
        mask_value: Mask value identifying the region to fill.
        fill_value: Either a single number used for every variable, or a
            mapping from variable name to fill value.
        mask: Provider of per-variable mask tensors.
        exclude: Matcher for variable names/prefixes that are never masked.
            If None, a fresh ``NameAndPrefixMatcher()`` is used (presumably
            matching nothing — verify against NameAndPrefixMatcher).
    """

    def __init__(
        self,
        mask_value: int,
        fill_value: float | TensorMapping,
        mask: HasGetSpatialMask,
        exclude: NameAndPrefixMatcher | None = None,
    ):
        # Accept ints too; otherwise an int would be stored as the "mapping"
        # and indexing it in __call__ would raise TypeError.
        if isinstance(fill_value, (int, float)):
            constant_fill = float(fill_value)
            fill_mapping: TensorMapping = collections.defaultdict(
                lambda: torch.as_tensor(constant_fill)
            )
        else:
            fill_mapping = fill_value
        self._fill_mapping = fill_mapping
        self._mask_value = mask_value
        self._mask = mask
        # Create the default per instance instead of sharing one matcher
        # object evaluated once at function-definition time.
        self._exclude = exclude if exclude is not None else NameAndPrefixMatcher()

    def _masks(self, name: str) -> bool:
        """Return True if ``name`` is not excluded from masking."""
        return not self._exclude.match(name)

    def __call__(self, data: TensorMapping) -> TensorDict:
        """
        Apply masking to the data for standard names recognized by a stacker.

        Excluded variables, and variables for which the mask provider
        returns no mask, are passed through unchanged.

        Args:
            data: The data to mask.

        Returns:
            A new dict in which the masked region of each applicable variable
            is replaced by that variable's fill value. ``data`` itself is not
            modified.

        Raises:
            KeyError: If a fill_value mapping was given and it has no entry
                for a variable that gets masked.
        """
        data_: TensorDict = dict(data)
        for name, tensor in data.items():
            if not self._masks(name):
                continue
            mask = self._mask.get_mask_tensor_for(name)
            if mask is None:
                continue
            try:
                fill_value = self._fill_mapping[name]
            except KeyError as err:
                raise KeyError(
                    "StaticSpatialMasking was initialized with a fill_value mapping "
                    f"but the mapping is missing key '{name}'."
                ) from err
            fill = torch.full_like(tensor, fill_value)
            # Broadcast the mask to the data's shape before selecting.
            data_[name] = replace_on_mask(
                original=tensor,
                replacement=fill,
                mask=mask.expand(fill.shape),
                mask_value=self._mask_value,
            )
        return data_


class NullSpatialMasking:
    """No-op masking that returns the input unchanged, as a new dict."""

    def __call__(self, data: TensorMapping) -> TensorDict:
        return dict(data)