import abc
import dataclasses
from typing import Any, Callable, ClassVar, Mapping, Tuple, Type, TypeVar

import dacite
import numpy as np
import torch

from fme.core.registry.registry import Registry


@dataclasses.dataclass
class PerturbationConfig(abc.ABC):
    """
    Abstract base class for perturbation configurations.

    Concrete subclasses are dataclasses holding the perturbation's parameters
    and implement :meth:`apply_perturbation`.
    """

    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> "PerturbationConfig":
        """
        Create an instance of this PerturbationConfig class from a dictionary
        containing all the information needed to build it.

        Parameters:
            state: Mapping of this dataclass's field names to values. The
                conversion uses dacite in strict mode, so keys that do not
                correspond to a field are rejected.

        Returns:
            An instance of ``cls``.
        """
        return dacite.from_dict(
            data_class=cls, data=state, config=dacite.Config(strict=True)
        )

    @abc.abstractmethod
    def apply_perturbation(
        self,
        data: torch.Tensor,
        lat: torch.Tensor,
        lon: torch.Tensor,
        ocean_fraction: torch.Tensor,
    ) -> None:
        """
        Apply the perturbation to ``data``.

        Implementations return ``None`` and modify ``data`` in place.

        Parameters:
            data: Tensor to perturb.
            lat: Latitudes associated with ``data``, in degrees.
            lon: Longitudes associated with ``data``, in degrees.
            ocean_fraction: Ocean fraction field, used by implementations to
                restrict where the perturbation is applied.
        """
        ...


# Type variable for decorators that take and return a PerturbationConfig subclass.
PT = TypeVar("PT", bound=Type[PerturbationConfig])
@dataclasses.dataclass
class PerturbationSelector:
    """
    Names a registered PerturbationConfig type and carries the data to build it.

    Parameters:
        type: Name under which a PerturbationConfig subclass was registered
            with :meth:`register`.
        config: Data passed to the registry (together with ``type``) to build
            that PerturbationConfig.
    """

    type: str
    config: Mapping[str, Any]
    # Shared by all instances; populated by the ``register`` decorator.
    registry: ClassVar[Registry] = Registry()

    def __post_init__(self):
        # ``registry`` is a ClassVar, so dataclasses excludes it from
        # ``__init__`` and every instance must resolve it to the shared
        # class-level registry. Reject a per-instance attribute shadowing it.
        # Comparing against a freshly constructed ``Registry()`` cannot work
        # here: a new object is never identical to the shared one.
        if "registry" in vars(self):
            raise ValueError(
                "PerturbationSelector.registry should not be set manually"
            )

    @classmethod
    def register(cls, type_name: str) -> Callable[[PT], PT]:
        """
        Return a class decorator that registers a PerturbationConfig subclass
        under ``type_name``.
        """
        return cls.registry.register(type_name)

    def build(self) -> PerturbationConfig:
        """Build the selected PerturbationConfig from :meth:`get_state`."""
        return self.registry.from_dict(self.get_state())

    def get_state(self) -> Mapping[str, Any]:
        """
        Get a dictionary containing all the information needed to build a
        PerturbationConfig.
        """
        return {"type": self.type, "config": self.config}

    @classmethod
    def get_available_types(cls):
        """This class method is used to expose all available types of Perturbations."""
        # ``registry`` is a ClassVar, so no instance is needed to reach it.
        return cls.registry._types.keys()
@dataclasses.dataclass
class SSTPerturbation:
    """
    Sea surface temperature (SST) perturbation settings.

    The perturbations are applied to both the initial condition and the
    forcing data; applying them to only one of the two is not currently
    supported.

    Parameters:
        sst: One perturbation selector per SST perturbation to apply.
    """

    sst: list[PerturbationSelector]

    def __post_init__(self):
        # Build every selected config once, when this object is constructed.
        built: list[PerturbationConfig] = []
        for selector in self.sst:
            built.append(selector.build())
        self.perturbations = built
@PerturbationSelector.register("constant")
@dataclasses.dataclass
class ConstantConfig(PerturbationConfig):
    """
    Configuration for a constant perturbation.

    Parameters:
        amplitude: Value added to ``data`` wherever the mask returned by
            ``_get_ocean_mask(ocean_fraction)`` is true.
    """

    amplitude: float = 1.0

    def apply_perturbation(
        self,
        data: torch.Tensor,
        lat: torch.Tensor,
        lon: torch.Tensor,
        ocean_fraction: torch.Tensor,
    ):
        # ``lat`` and ``lon`` are not needed: the shift is uniform in space.
        is_ocean = _get_ocean_mask(ocean_fraction)
        data[is_ocean] += self.amplitude  # type: ignore
@PerturbationSelector.register("greens_function")
@dataclasses.dataclass
class GreensFunctionConfig(PerturbationConfig):
    """
    Configuration for a single sinusoidal patch of a Green's function perturbation.
    See equation 1 in Bloch-Johnson, J., et al. (2024).

    Inside the patch the perturbation is
    ``amplitude * cos^2(pi/2 * dlat / (lat_width/2)) * cos^2(pi/2 * dlon / (lon_width/2))``,
    which falls to zero at the patch edges; it is applied only at ocean points.

    Parameters:
        amplitude: The amplitude of the perturbation,
            maximum is reached at (lat_center, lon_center).
        lat_center: The latitude at the center of the patch in degrees.
        lon_center: The longitude at the center of the patch in degrees.
        lat_width: latitudinal width of the patch in degrees.
        lon_width: longitudinal width of the patch in degrees.
    """

    amplitude: float = 1.0
    lat_center: float = 0.0
    lon_center: float = 0.0
    lat_width: float = 10.0
    lon_width: float = 10.0

    def __post_init__(self):
        # Radian versions of the patch geometry, used by the cos^2 profile.
        self._lat_center_rad = np.deg2rad(self.lat_center)
        self._lon_center_rad = np.deg2rad(self.lon_center)
        self._lat_width_rad = np.deg2rad(self.lat_width)
        self._lon_width_rad = np.deg2rad(self.lon_width)

    def _wrap_longitude_discontinuity(
        self,
        lon: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Assume longitude is in the range [0, 360) degrees. If the patch crosses
        the discontinuity at 0/360 degrees, shift the longitude accordingly.

        Returns:
            lon_in_patch: Boolean tensor, true where ``lon`` lies strictly
                inside the patch's longitude range.
            lon_shifted: ``lon`` remapped so the patch is contiguous around
                ``lon_center``: to [-180, 180) when the patch extends below 0,
                to [180, 540) when it extends past 360, otherwise unchanged.
        """
        lon_min = self.lon_center - self.lon_width / 2.0
        lon_max = self.lon_center + self.lon_width / 2.0
        if lon_min < 0:
            # Patch extends below 0: represent longitudes in [-180, 180).
            lon_shifted = ((lon + 180) % 360) - 180
            lon_in_patch = (lon_shifted > lon_min) & (lon_shifted < lon_max)
        elif lon_max > 360:
            # Patch extends past 360: the in-patch region is the union of the
            # part just below 360 and the wrapped part just above 0.
            lon_in_patch = (lon > lon_min) | (lon < lon_max % 360)
            # Represent longitudes in [180, 540) so points just past 0 map to
            # just past 360, next to lon_center.
            lon_shifted = ((lon + 180) % 360) + 180
        else:
            lon_in_patch = (lon > lon_min) & (lon < lon_max)
            lon_shifted = lon
        return lon_in_patch, lon_shifted

    def apply_perturbation(
        self,
        data: torch.Tensor,
        lat: torch.Tensor,
        lon: torch.Tensor,
        ocean_fraction: torch.Tensor,
    ):
        # NOTE(review): lat and lon are combined elementwise below, so they are
        # presumably 2D grids broadcastable to data.shape — confirm with callers.
        lat_in_patch = torch.abs(lat - self.lat_center) < self.lat_width / 2.0
        lon_in_patch, lon_shifted = self._wrap_longitude_discontinuity(lon)
        ocean_mask = _get_ocean_mask(ocean_fraction)
        lat_profile = (
            torch.cos(
                torch.pi
                / 2
                * (lat.deg2rad() - self._lat_center_rad)
                / (self._lat_width_rad / 2.0)
            )
            ** 2
        )
        lon_profile = (
            torch.cos(
                torch.pi
                / 2
                * (lon_shifted.deg2rad() - self._lon_center_rad)
                / (self._lon_width_rad / 2.0)
            )
            ** 2
        )
        perturbation = (self.amplitude * (lat_profile * lon_profile)).expand(
            data.shape
        )
        # Combine the patch and ocean masks once and reuse the result for
        # both the read and the in-place update.
        apply_mask = (lat_in_patch & lon_in_patch).expand(data.shape) & ocean_mask
        data[apply_mask] += perturbation[apply_mask]