importdataclassesfromtypingimportLiteralfromfme.ace.models.makani.sfnonetimport(SphericalFourierNeuralOperatorNetasMakaniSFNO,)fromfme.ace.models.modulus.sfnonetimportSphericalFourierNeuralOperatorNetfromfme.ace.registry.registryimportModuleConfig,ModuleSelectorfromfme.core.dataset_infoimportDatasetInfo# this is based on the call signature of SphericalFourierNeuralOperatorNet at# https://github.com/NVIDIA/modulus/blob/b8e27c5c4ebc409e53adaba9832138743ede2785/modulus/models/sfno/sfnonet.py#L292 # noqa: E501
@ModuleSelector.register("SphericalFourierNeuralOperatorNet")
@dataclasses.dataclass
class SphericalFourierNeuralOperatorBuilder(ModuleConfig):
    """
    Builder config for the modulus SFNO network, as used in FourCastNet-SFNO.

    The fields are based on the call signature of the modulus
    ``SphericalFourierNeuralOperatorNet`` (NVIDIA/modulus commit b8e27c5).
    Rather than forwarding fields one at a time, ``build`` hands this whole
    object to the network as its ``params`` argument.
    """

    spectral_transform: str = "sht"
    filter_type: str = "linear"
    operator_type: str = "diagonal"
    scale_factor: int = 1
    residual_filter_factor: int = 1
    embed_dim: int = 256
    num_layers: int = 12
    hard_thresholding_fraction: float = 1.0
    normalization_layer: str = "instance_norm"
    use_mlp: bool = True
    activation_function: str = "gelu"
    encoder_layers: int = 1
    pos_embed: bool = True
    big_skip: bool = True
    rank: float = 1.0
    factorization: str | None = None
    separable: bool = False
    complex_network: bool = True
    complex_activation: str = "real"
    spectral_layers: int = 1
    checkpointing: int = 0
    # Annotated to allow only these two grid names (not enforced at runtime
    # by the dataclass itself).
    data_grid: Literal["legendre-gauss", "equiangular"] = "legendre-gauss"

    def build(
        self,
        n_in_channels: int,
        n_out_channels: int,
        dataset_info: DatasetInfo,
    ):
        """
        Instantiate the modulus SFNO network for the given channel counts.

        Args:
            n_in_channels: Number of input channels, passed as ``in_chans``.
            n_out_channels: Number of output channels, passed as ``out_chans``.
            dataset_info: Dataset metadata; supplies ``img_shape`` and must
                carry no labels.

        Returns:
            A ``SphericalFourierNeuralOperatorNet`` constructed with
            ``params=self``.

        Raises:
            ValueError: If ``dataset_info.all_labels`` is non-empty.
        """
        labels = dataset_info.all_labels
        if len(labels) > 0:
            raise ValueError(
                "SphericalFourierNeuralOperatorNet does not support labels"
            )
        # The network receives the config object itself; hyperparameters are
        # presumably read off its attributes (TODO confirm against modulus).
        return SphericalFourierNeuralOperatorNet(
            params=self,
            in_chans=n_in_channels,
            out_chans=n_out_channels,
            img_shape=dataset_info.img_shape,
        )
@ModuleSelector.register("SFNO-v0.1.0")
@dataclasses.dataclass
class SFNO_V0_1_0(ModuleConfig):
    """
    Builder config for the SFNO network shipped with modulus-makani 0.1.0.

    Every field except ``data_grid`` is forwarded to the makani network under
    the same keyword name; ``data_grid`` is forwarded as ``model_grid_type``.
    """

    spectral_transform: str = "sht"
    filter_type: Literal["linear"] = "linear"
    operator_type: str = "dhconv"
    scale_factor: int = 16
    embed_dim: int = 256
    num_layers: int = 12
    repeat_layers: int = 1
    hard_thresholding_fraction: float = 1.0
    normalization_layer: str = "instance_norm"
    use_mlp: bool = True
    activation_function: str = "gelu"
    encoder_layers: int = 1
    pos_embed: Literal["none", "direct", "frequency"] = "direct"
    big_skip: bool = True
    rank: float = 1.0
    factorization: str | None = None
    separable: bool = False
    complex_activation: str = "real"
    spectral_layers: int = 1
    checkpointing: int = 0
    # Passed to MakaniSFNO as ``model_grid_type`` in ``build``.
    data_grid: Literal["legendre-gauss", "equiangular", "healpix"] = "legendre-gauss"

    def build(
        self,
        n_in_channels: int,
        n_out_channels: int,
        dataset_info: DatasetInfo,
    ):
        """
        Instantiate the makani SFNO network for the given channel counts.

        Args:
            n_in_channels: Number of input channels, passed as ``inp_chans``.
            n_out_channels: Number of output channels, passed as ``out_chans``.
            dataset_info: Dataset metadata; its ``img_shape`` is used as both
                the input and the output shape, and it must carry no labels.

        Returns:
            The constructed ``MakaniSFNO`` module.

        Raises:
            ValueError: If ``dataset_info.all_labels`` is non-empty.
        """
        shape = dataset_info.img_shape
        if len(dataset_info.all_labels) != 0:
            raise ValueError("SFNO-v0.1.0 does not support labels")
        return MakaniSFNO(
            # Channel counts.
            inp_chans=n_in_channels,
            out_chans=n_out_channels,
            # Input and output share the dataset's shape and grid.
            inp_shape=shape,
            out_shape=shape,
            model_grid_type=self.data_grid,
            # Remaining fields, forwarded under their own names.
            spectral_transform=self.spectral_transform,
            filter_type=self.filter_type,
            operator_type=self.operator_type,
            scale_factor=self.scale_factor,
            embed_dim=self.embed_dim,
            num_layers=self.num_layers,
            repeat_layers=self.repeat_layers,
            hard_thresholding_fraction=self.hard_thresholding_fraction,
            normalization_layer=self.normalization_layer,
            use_mlp=self.use_mlp,
            activation_function=self.activation_function,
            encoder_layers=self.encoder_layers,
            pos_embed=self.pos_embed,
            big_skip=self.big_skip,
            rank=self.rank,
            factorization=self.factorization,
            separable=self.separable,
            complex_activation=self.complex_activation,
            spectral_layers=self.spectral_layers,
            checkpointing=self.checkpointing,
        )