@dataclasses.dataclass
class AugmentationConfig:
    """
    Settings controlling random data augmentation.

    Attributes:
        rotate_probability: Probability, between 0.0 and 1.0 inclusive, of
            rotating the sphere by 180 degrees so that the poles swap places.
        additional_directional_names: Extra variable names whose sign is
            flipped when the poles are reversed. These are used in addition
            to the names in RotateModifier.FLIP_NAMES, which are always
            included; this field itself defaults to an empty list.
    """

    rotate_probability: float = 0.0
    additional_directional_names: list[str] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        # A chained comparison evaluates to False for NaN as well as for
        # out-of-range values, so NaN is rejected here too.
        in_range = 0.0 <= self.rotate_probability <= 1.0
        if not in_range:
            raise ValueError(
                f"rotate_probability must be between 0.0 and 1.0, got {self.rotate_probability}"
            )

    def build_modifier(self) -> "BatchModifierABC":
        """
        Create the batch modifier described by this configuration.

        Returns:
            A RotateModifier when rotate_probability is strictly positive,
            otherwise a NullModifier, which returns batches unchanged.
        """
        if not self.rotate_probability > 0.0:
            return NullModifier()
        return RotateModifier(
            rotate_probability=self.rotate_probability,
            additional_directional_names=self.additional_directional_names,
        )
class BatchModifierABC(abc.ABC):
    """Interface for callables that take a BatchData and return a BatchData."""

    @abc.abstractmethod
    def __call__(self, batch: BatchData) -> BatchData:
        """Return a (possibly) modified version of ``batch``."""
        ...


class RotateModifier(BatchModifierABC):
    """
    Modifier that rotates the sphere by 180 degrees so that the poles swap places.

    This is the same as flipping both zonal and meridional axes. It also flips
    the sign of horizontal directional variables, such as horizontal winds in
    specific directions, so that their new values reflect the rotated axes.

    The names of such variables are stored in the `FLIP_NAMES` class variable,
    extended by any `additional_directional_names` passed to the constructor.
    Variables not matching one of these names are not sign-flipped.

    Specifically, the regex pattern r'{name}(_?[0-9]+m?)?$' is matched from
    the start of each variable name, for each name in the combined list. Each
    name is regex-escaped, so it is matched literally. This matches the bare
    name as well as names that end with a suffix like "_0" or "_1", or like
    "10m" or "2m".

    Note that seasons are handled by the fact that solar insolation is a data
    variable, but time is not modified. This means monthly or seasonal
    averages using this data will be affected by the rotation.
    """

    # names of variables whose sign is flipped when the poles are reversed
    FLIP_NAMES = [
        "eastward_wind",
        "northward_wind",
        "UGRD",
        "VGRD",
        "U",
        "V",
    ]

    def __init__(
        self,
        rotate_probability: float,
        additional_directional_names: list[str],
    ):
        """
        Args:
            rotate_probability: Probability that each entry along the leading
                dimension of the batch tensors is rotated.
            additional_directional_names: Names to sign-flip in addition to
                FLIP_NAMES.
        """
        self.rotate_probability = rotate_probability
        # Copy so that later mutation of the caller's list cannot make this
        # attribute disagree with the pattern compiled below.
        self.additional_directional_names = list(additional_directional_names)
        # Escape each name so regex metacharacters in user-supplied names
        # (e.g. "." or "+") are matched literally instead of as regex syntax.
        alternatives = "|".join(
            re.escape(name)
            for name in self.FLIP_NAMES + self.additional_directional_names
        )
        self._pattern = re.compile(r"({})(_?[0-9]+m?)?$".format(alternatives))

    def __call__(self, batch: BatchData) -> BatchData:
        """
        Rotate a random subset of the batch by 180 degrees on the sphere.

        A single random draw per entry of the leading dimension decides whether
        that entry is rotated. The draw is shared by every variable, so all
        variables of an entry are rotated together. Directional variables whose
        names match the flip pattern are also negated.

        Args:
            batch: The batch to augment.

        Returns:
            A new BatchData with the (possibly) rotated data and the original
            time, horizontal_dims and labels. If ``batch.data`` is empty, the
            input batch is returned unchanged.

        Raises:
            NotImplementedError: If ``batch.horizontal_dims`` is not
                ``["lat", "lon"]``.
        """
        if batch.horizontal_dims != ["lat", "lon"]:
            raise NotImplementedError(
                "Horizontal dimensions must be lat and lon to rotate the sphere, got "
                f"{batch.horizontal_dims}"
            )
        if not batch.data:
            # Nothing to rotate; also avoids StopIteration from next(iter(...)).
            return batch
        example_value = next(iter(batch.data.values()))
        # One Bernoulli(rotate_probability) draw per entry of the leading dim.
        apply = (
            torch.rand(example_value.shape[0]).to(example_value.device)
            < self.rotate_probability
        )
        new_data = {}
        for name, value in batch.data.items():
            # Shape the 1-D mask to broadcast against *this* tensor's rank. A
            # mask sized for a higher-rank tensor would otherwise broadcast
            # across extra dimensions and silently produce a wrong shape.
            mask = apply.reshape(-1, *([1] * (value.dim() - 1)))
            # Reverse the trailing two dims, which this code treats as
            # (lat, lon). NOTE(review): an exact pole swap also assumes the
            # latitude grid is symmetric about the equator — confirm.
            rotated = torch.flip(value, dims=[-2, -1])
            if self._pattern.match(name):
                # Horizontal direction vectors point the opposite way after
                # the rotation, so their components change sign.
                rotated = -1 * rotated
            new_data[name] = torch.where(mask, rotated, value)
        return BatchData(
            data=new_data,
            time=batch.time,
            horizontal_dims=batch.horizontal_dims,
            labels=batch.labels,
        )


class NullModifier(BatchModifierABC):
    """Modifier that returns the batch it is given, unchanged."""

    def __call__(self, batch: BatchData) -> BatchData:
        return batch