Modules
ACE’s code uses a module registry system to allow different machine learning architectures to plug into the framework.
This is managed by the fme.ace.ModuleSelector configuration class, which is used to select a module type and version.
- class fme.ace.ModuleSelector(type, config, conditional=False, allow_missing_variables=False)[source]
Bases: object
A dataclass containing all the information needed to build a ModuleConfig, including the type of the ModuleConfig and the data needed to build it.
This is helpful as ModuleSelector can be serialized and deserialized without any additional information, whereas to load a ModuleConfig you would need to know the type of the ModuleConfig being loaded.
It is also convenient because ModuleSelector is a single class that can be used to represent any ModuleConfig, whereas ModuleConfig is a protocol that can be implemented by many different classes.
- Parameters:
type (str) – the type of the ModuleConfig
config (Mapping[str, Any]) – data for a ModuleConfig instance of the indicated type
conditional (bool, default: False) – whether to condition the predictions on batch labels.
allow_missing_variables (bool, default: False) – whether the data pipeline is allowed to produce variable masks (for incomplete datasets). When False (default), missing required variables cause an error.
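As a rough illustration of the serialization point above, the four selector fields can be written as a plain mapping and round-tripped through JSON without knowing the concrete ModuleConfig class. This sketch deliberately does not import fme. Pairing the type name 'SphericalFourierNeuralOperatorNet' with the embed_dim and num_layers keys is an assumption based on the type list and builder signature documented on this page.

```python
import json

# Plain-mapping form of the four documented ModuleSelector fields.
# Assumption: the "SphericalFourierNeuralOperatorNet" type accepts the
# embed_dim / num_layers keys documented for the SFNO builder on this page.
selector_fields = {
    "type": "SphericalFourierNeuralOperatorNet",
    "config": {"embed_dim": 256, "num_layers": 12},
    "conditional": False,
    "allow_missing_variables": False,
}

# Only built-in types are involved, so the round trip loses nothing.
restored = json.loads(json.dumps(selector_fields))
print(restored == selector_fields)
```

With fme installed, the same keys would presumably be passed as fme.ace.ModuleSelector(**selector_fields). That call is inferred from the signature above and has not been verified here.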
- allow_missing_variables: bool = False
- build(n_in_channels, n_out_channels, dataset_info)[source]
Build a nn.Module given information about the input and output channels and the dataset.
- Parameters:
- Return type: Module
- Returns: a Module object
- conditional: bool = False
- classmethod get_available_types()[source]
This class method is used to expose all available types of Modules.
- property module_config: ModuleConfig
- classmethod register(type_name)[source]
- registry: ClassVar[Registry[ModuleConfig]] = <fme.core.registry.registry.Registry object>
- type: str
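The register classmethod and registry attribute above suggest a decorator-based registry. The following is a minimal sketch of that general pattern, not fme's implementation: Registry, Selector, ToyMLPConfig, and the "ToyMLP" type name are all hypothetical stand-ins written for this example.

```python
import dataclasses
from typing import Any, Callable, Dict, List, Mapping


class Registry:
    """Hypothetical stand-in: maps a type name to a config class."""

    def __init__(self) -> None:
        self._types: Dict[str, type] = {}

    def register(self, type_name: str) -> Callable[[type], type]:
        # Returns a class decorator that records the class under type_name.
        def decorator(cls: type) -> type:
            if type_name in self._types:
                raise ValueError(f"{type_name!r} is already registered")
            self._types[type_name] = cls
            return cls

        return decorator

    def get(self, type_name: str, config: Mapping[str, Any]) -> Any:
        # Instantiate the registered config class from a plain mapping.
        return self._types[type_name](**config)

    def available(self) -> List[str]:
        return list(self._types)


registry = Registry()


@registry.register("ToyMLP")  # hypothetical type name
@dataclasses.dataclass
class ToyMLPConfig:
    hidden_dim: int = 64
    depth: int = 2


@dataclasses.dataclass
class Selector:
    """Hypothetical selector: a type name plus the data for that type."""

    type: str
    config: Mapping[str, Any]

    @property
    def module_config(self) -> Any:
        return registry.get(self.type, self.config)


selector = Selector(type="ToyMLP", config={"hidden_dim": 128})
print(selector.module_config)
print(registry.available())
```

The selector itself stores only a string and a mapping, which is what lets it be serialized without reference to the concrete config class.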
The following module types are available:
dict_keys(['MLP', 'LandNet', 'NoiseConditionedSFNO', 'AnkurLocalNet', 'LocalNet', 'Samudra', 'FloeNet', 'prebuilt', 'SphericalFourierNeuralOperatorNet', 'SFNO-v0.1.0', 'HEALPixUNet'])
- fme.core.registry.ModuleSelector.get_available_types()
This class method is used to expose all available types of Modules.
The following module builders are available:
- class fme.ace.SphericalFourierNeuralOperatorBuilder(spectral_transform='sht', filter_type='linear', operator_type='diagonal', scale_factor=1, residual_filter_factor=1, embed_dim=256, num_layers=12, hard_thresholding_fraction=1.0, normalization_layer='instance_norm', use_mlp=True, activation_function='gelu', encoder_layers=1, pos_embed=True, big_skip=True, rank=1.0, factorization=None, separable=False, complex_network=True, complex_activation='real', spectral_layers=1, checkpointing=0, data_grid='legendre-gauss')[source]
Bases: ModuleConfig
Configuration for the SFNO architecture used in FourCastNet-SFNO.
- Parameters:
spectral_transform (str)
filter_type (str)
operator_type (str)
scale_factor (int)
residual_filter_factor (int)
embed_dim (int)
num_layers (int)
hard_thresholding_fraction (float)
normalization_layer (str)
use_mlp (bool)
activation_function (str)
encoder_layers (int)
pos_embed (bool)
big_skip (bool)
rank (float)
factorization (str | None)
separable (bool)
complex_network (bool)
complex_activation (str)
spectral_layers (int)
checkpointing (int)
data_grid (Literal['legendre-gauss', 'equiangular'])
- activation_function: str = 'gelu'
- big_skip: bool = True
- build(n_in_channels, n_out_channels, dataset_info)[source]
Build a nn.Module given information about the input and output channels and the dataset.
- checkpointing: int = 0
- complex_activation: str = 'real'
- complex_network: bool = True
- data_grid: Literal['legendre-gauss', 'equiangular'] = 'legendre-gauss'
- embed_dim: int = 256
- encoder_layers: int = 1
- filter_type: str = 'linear'
- hard_thresholding_fraction: float = 1.0
- normalization_layer: str = 'instance_norm'
- num_layers: int = 12
- operator_type: str = 'diagonal'
- pos_embed: bool = True
- rank: float = 1.0
- residual_filter_factor: int = 1
- scale_factor: int = 1
- separable: bool = False
- spectral_layers: int = 1
- spectral_transform: str = 'sht'
- use_mlp: bool = True
- fme.ace.NoiseConditionedSFNO
alias of NoiseConditionedModel
- class fme.ace.SFNO_V0_1_0(spectral_transform='sht', filter_type='linear', operator_type='dhconv', scale_factor=16, embed_dim=256, num_layers=12, repeat_layers=1, hard_thresholding_fraction=1.0, normalization_layer='instance_norm', use_mlp=True, activation_function='gelu', encoder_layers=1, pos_embed='direct', big_skip=True, rank=1.0, factorization=None, separable=False, complex_activation='real', spectral_layers=1, checkpointing=0, data_grid='legendre-gauss')[source]
Bases: ModuleConfig
Configuration for the SFNO architecture in modulus-makani version 0.1.0.
- Parameters:
spectral_transform (str)
filter_type (Literal['linear'])
operator_type (str)
scale_factor (int)
embed_dim (int)
num_layers (int)
repeat_layers (int)
hard_thresholding_fraction (float)
normalization_layer (str)
use_mlp (bool)
activation_function (str)
encoder_layers (int)
pos_embed (Literal['none', 'direct', 'frequency'])
big_skip (bool)
rank (float)
factorization (str | None)
separable (bool)
complex_activation (str)
spectral_layers (int)
checkpointing (int)
data_grid (Literal['legendre-gauss', 'equiangular', 'healpix'])
- activation_function: str = 'gelu'
- big_skip: bool = True
- build(n_in_channels, n_out_channels, dataset_info)[source]
Build a nn.Module given information about the input and output channels and the dataset.
- checkpointing: int = 0
- complex_activation: str = 'real'
- data_grid: Literal['legendre-gauss', 'equiangular', 'healpix'] = 'legendre-gauss'
- embed_dim: int = 256
- encoder_layers: int = 1
- filter_type: Literal['linear'] = 'linear'
- hard_thresholding_fraction: float = 1.0
- normalization_layer: str = 'instance_norm'
- num_layers: int = 12
- operator_type: str = 'dhconv'
- pos_embed: Literal['none', 'direct', 'frequency'] = 'direct'
- rank: float = 1.0
- repeat_layers: int = 1
- scale_factor: int = 16
- separable: bool = False
- spectral_layers: int = 1
- spectral_transform: str = 'sht'
- use_mlp: bool = True
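The two SFNO configurations share most parameter names but differ in a few defaults and fields. The sketch below copies the defaults from the two class signatures on this page into plain dictionaries and compares them. It does not import fme, and the dictionary and variable names are chosen for this example only.

```python
# Defaults copied from the SphericalFourierNeuralOperatorBuilder signature.
builder_defaults = dict(
    spectral_transform="sht", filter_type="linear", operator_type="diagonal",
    scale_factor=1, residual_filter_factor=1, embed_dim=256, num_layers=12,
    hard_thresholding_fraction=1.0, normalization_layer="instance_norm",
    use_mlp=True, activation_function="gelu", encoder_layers=1,
    pos_embed=True, big_skip=True, rank=1.0, factorization=None,
    separable=False, complex_network=True, complex_activation="real",
    spectral_layers=1, checkpointing=0, data_grid="legendre-gauss",
)

# Defaults copied from the SFNO_V0_1_0 signature.
v010_defaults = dict(
    spectral_transform="sht", filter_type="linear", operator_type="dhconv",
    scale_factor=16, embed_dim=256, num_layers=12, repeat_layers=1,
    hard_thresholding_fraction=1.0, normalization_layer="instance_norm",
    use_mlp=True, activation_function="gelu", encoder_layers=1,
    pos_embed="direct", big_skip=True, rank=1.0, factorization=None,
    separable=False, complex_activation="real", spectral_layers=1,
    checkpointing=0, data_grid="legendre-gauss",
)

# Fields present in only one of the two configurations.
only_in_builder = sorted(set(builder_defaults) - set(v010_defaults))
only_in_v010 = sorted(set(v010_defaults) - set(builder_defaults))

# Shared fields whose default values differ, as (builder, v0.1.0) pairs.
changed = {
    key: (builder_defaults[key], v010_defaults[key])
    for key in sorted(set(builder_defaults) & set(v010_defaults))
    if builder_defaults[key] != v010_defaults[key]
}

print(only_in_builder)
print(only_in_v010)
print(changed)
```

A configuration written for one class therefore cannot be assumed to load unchanged into the other: pos_embed, for instance, is a bool in one and a string literal in the other.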
- class fme.ace.HEALPixUNetBuilder(encoder, decoder, enable_nhwc=False, hpx_padding_mode='earth2grid', nside=None)[source]
Bases: ModuleConfig
Configuration for the HEALPix UNet (feed-forward encoder–decoder stack).
Time stepping, multi-step inputs, residual prediction, and rollout live in the stepper, not in this module.
- Parameters:
encoder (UNetEncoderConfig) – UNet encoder configuration.
decoder (UNetDecoderConfig) – UNet decoder configuration.
enable_nhwc (bool, default: False) – Use NHWC tensor layout for child modules.
hpx_padding_mode (Literal['earth2grid', 'karlbauer', 'isolatitude'], default: 'earth2grid') – HEALPix padding backend ("earth2grid", "karlbauer", "isolatitude"). Default "earth2grid".
nside (Optional[Sequence[int]], default: None) – Face height/width per UNet level (shallowest to deepest). Required for isolatitude padding.
- build(n_in_channels, n_out_channels, dataset_info)[source]
Build a HEALPixUNet model.
- decoder: UNetDecoderConfig
- enable_nhwc: bool = False
- encoder: UNetEncoderConfig
- hpx_padding_mode: Literal['earth2grid', 'karlbauer', 'isolatitude'] = 'earth2grid'
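The parameter notes above state that nside is required when hpx_padding_mode is "isolatitude". A pre-flight check for that constraint could be sketched as follows. The check_padding_config helper is hypothetical, as are the example nside values. Whether fme performs an equivalent check, and which exception it raises, is not stated in this reference.

```python
from typing import Optional, Sequence

# The three padding backends listed for hpx_padding_mode on this page.
PADDING_MODES = ("earth2grid", "karlbauer", "isolatitude")


def check_padding_config(
    hpx_padding_mode: str = "earth2grid",
    nside: Optional[Sequence[int]] = None,
) -> None:
    """Hypothetical helper: raise ValueError if the documented
    padding constraints are violated."""
    if hpx_padding_mode not in PADDING_MODES:
        raise ValueError(f"unknown hpx_padding_mode: {hpx_padding_mode!r}")
    if hpx_padding_mode == "isolatitude" and nside is None:
        raise ValueError("nside is required for isolatitude padding")


# The documented defaults pass the check.
check_padding_config()

# Isolatitude padding with one nside per UNet level, shallowest to deepest
# (the values here are illustrative only).
check_padding_config("isolatitude", nside=[64, 32, 16])

# Isolatitude padding without nside is rejected.
try:
    check_padding_config("isolatitude")
except ValueError as err:
    print(err)
```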
- class fme.ace.LandNetBuilder(hidden_dims=<factory>, network_type='MLP', use_positional_embedding=False)[source]
Bases: ModuleConfig
Configuration for the LandNet architecture.
- Parameters:
- build(n_in_channels, n_out_channels, dataset_info)[source]
Build a nn.Module given information about the input and output channels and the dataset.
- network_type: Literal['MLP'] = 'MLP'
- use_positional_embedding: bool = False
- class fme.ace.SamudraBuilder(ch_width=<factory>, n_layers=<factory>, dilation=<factory>, pad='circular', norm='instance', norm_kwargs=<factory>, upscale_factor=4, checkpoint_strategy=None)[source]
Bases: ModuleConfig
Configuration for the M2Lines Samudra architecture.
- Parameters:
- build(n_in_channels, n_out_channels, dataset_info)[source]
Build a nn.Module given information about the input and output channels and the dataset.
- norm: str = 'instance'
- pad: str = 'circular'
- upscale_factor: int = 4