Modules
ACE’s code uses a module registry system to allow different machine learning architectures to plug into the framework.
This is managed by the fme.ace.ModuleSelector configuration class, which is used to select a module type and version.
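The registry idea can be sketched in a few lines of plain Python. This is not fme's implementation: `SimpleRegistry`, `ToyMLPConfig`, and the `"ToyMLP"` type name are hypothetical. The sketch only shows how a `register` decorator can map a string type name to a config class, so that the class can later be looked up by name and built from plain config data.

```python
from dataclasses import dataclass
from typing import Callable, Dict, KeysView, Type, TypeVar

T = TypeVar("T")


class SimpleRegistry:
    """Hypothetical registry mapping a type name (str) to a config class."""

    def __init__(self) -> None:
        self._types: Dict[str, type] = {}

    def register(self, type_name: str) -> Callable[[Type[T]], Type[T]]:
        # Return a class decorator that records the class under type_name.
        def decorator(cls: Type[T]) -> Type[T]:
            if type_name in self._types:
                raise ValueError(f"type {type_name!r} is already registered")
            self._types[type_name] = cls
            return cls

        return decorator

    def get(self, type_name: str) -> type:
        return self._types[type_name]

    def available_types(self) -> KeysView[str]:
        return self._types.keys()


registry = SimpleRegistry()


@registry.register("ToyMLP")
@dataclass
class ToyMLPConfig:
    hidden_dim: int = 64


# Look the class up by its registered name and build it from plain data.
config_cls = registry.get("ToyMLP")
config = config_cls(**{"hidden_dim": 128})
print(list(registry.available_types()))  # ['ToyMLP']
print(config)  # ToyMLPConfig(hidden_dim=128)
```

Because the lookup key is a plain string, a configuration file only needs to name the type; the code that defines the class is responsible for registering it.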
- class fme.ace.ModuleSelector(type, config, conditional=False)
Bases: object
A dataclass containing all the information needed to build a ModuleConfig, including the type of the ModuleConfig and the data needed to build it.
This is helpful because a ModuleSelector can be serialized and deserialized without any additional information, whereas loading a ModuleConfig requires knowing which ModuleConfig type is being loaded.
It is also convenient because ModuleSelector is a single class that can represent any ModuleConfig, whereas ModuleConfig is a protocol that can be implemented by many different classes.
- Parameters:
type (str)
config
conditional (bool)
- build(n_in_channels, n_out_channels, dataset_info)
Build a nn.Module given information about the input and output channels and the dataset.
- Parameters: n_in_channels, n_out_channels, dataset_info
- Return type: Module
- Returns: a Module object
- conditional: bool = False
- classmethod get_available_types()
This class method exposes all available Module types.
- classmethod register(type_name)
- registry: ClassVar[Registry[ModuleConfig]] = <fme.core.registry.registry.Registry object>
- type: str
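To illustrate why a selector that carries its own `type` string can be round-tripped without extra information, here is a hypothetical pure-Python stand-in. `SelectorSketch`, `ToySFNOConfig`, and `CONFIG_TYPES` are not part of fme; the sketch mirrors only the three documented fields `type`, `config`, and `conditional`, and assumes the config mapping is passed to the selected class as keyword arguments.

```python
import json
from dataclasses import asdict, dataclass, field
from typing import Any, Dict


@dataclass
class ToySFNOConfig:
    # Hypothetical stand-in for a registered ModuleConfig class.
    embed_dim: int = 256
    num_layers: int = 12


# Hypothetical name-to-class lookup; fme uses its own registry for this.
CONFIG_TYPES = {"ToySFNO": ToySFNOConfig}


@dataclass
class SelectorSketch:
    type: str
    config: Dict[str, Any] = field(default_factory=dict)
    conditional: bool = False

    def build_config(self) -> Any:
        # The stored type name alone says which class to instantiate.
        return CONFIG_TYPES[self.type](**self.config)


selector = SelectorSketch(type="ToySFNO", config={"embed_dim": 384})
text = json.dumps(asdict(selector))  # plain JSON, no class information needed
restored = SelectorSketch(**json.loads(text))
print(restored == selector)  # True
print(restored.build_config())  # ToySFNOConfig(embed_dim=384, num_layers=12)
```

Serializing a `ToySFNOConfig` directly would lose the information about which class to rebuild; the selector keeps that information in its `type` field.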
The following module types are available: 'LandNet', 'Samudra', 'FloeNet', 'prebuilt', 'SphericalFourierNeuralOperatorNet', 'SFNO-v0.1.0', 'NoiseConditionedSFNO', 'HEALPixRecUNet'.
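Type names are given as strings, so a misspelled name only shows up when the selector is built. A small pre-flight check can catch such mistakes earlier. `check_type` below is a hypothetical helper, not an fme API, and it assumes a selector's `type` must match one of the names listed above exactly.

```python
import difflib

# Copied from the list of available module types above.
AVAILABLE_TYPES = [
    "LandNet", "Samudra", "FloeNet", "prebuilt",
    "SphericalFourierNeuralOperatorNet", "SFNO-v0.1.0",
    "NoiseConditionedSFNO", "HEALPixRecUNet",
]


def check_type(selector: dict) -> str:
    """Return the selector's type, or raise with a close-match hint."""
    name = selector["type"]
    if name not in AVAILABLE_TYPES:
        hints = difflib.get_close_matches(name, AVAILABLE_TYPES, n=1)
        hint = f" Did you mean {hints[0]!r}?" if hints else ""
        raise ValueError(f"Unknown module type {name!r}.{hint}")
    return name


selector = {"type": "SphericalFourierNeuralOperatorNet", "config": {"embed_dim": 256}}
print(check_type(selector))  # SphericalFourierNeuralOperatorNet
```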
- fme.core.registry.ModuleSelector.get_available_types()
This class method is used to expose all available types of Modules.
The following module builders are available:
- class fme.ace.SphericalFourierNeuralOperatorBuilder(spectral_transform='sht', filter_type='linear', operator_type='diagonal', scale_factor=1, residual_filter_factor=1, embed_dim=256, num_layers=12, hard_thresholding_fraction=1.0, normalization_layer='instance_norm', use_mlp=True, activation_function='gelu', encoder_layers=1, pos_embed=True, big_skip=True, rank=1.0, factorization=None, separable=False, complex_network=True, complex_activation='real', spectral_layers=1, checkpointing=0, data_grid='legendre-gauss')
Bases: ModuleConfig
Configuration for the SFNO architecture used in FourCastNet-SFNO.
- Parameters:
spectral_transform (str)
filter_type (str)
operator_type (str)
scale_factor (int)
residual_filter_factor (int)
embed_dim (int)
num_layers (int)
hard_thresholding_fraction (float)
normalization_layer (str)
use_mlp (bool)
activation_function (str)
encoder_layers (int)
pos_embed (bool)
big_skip (bool)
rank (float)
factorization (str | None)
separable (bool)
complex_network (bool)
complex_activation (str)
spectral_layers (int)
checkpointing (int)
data_grid (Literal['legendre-gauss', 'equiangular'])
- activation_function: str = 'gelu'
- big_skip: bool = True
- build(n_in_channels, n_out_channels, dataset_info)
Build a nn.Module given information about the input and output channels and the dataset.
- checkpointing: int = 0
- complex_activation: str = 'real'
- complex_network: bool = True
- data_grid: Literal['legendre-gauss', 'equiangular'] = 'legendre-gauss'
- embed_dim: int = 256
- encoder_layers: int = 1
- filter_type: str = 'linear'
- hard_thresholding_fraction: float = 1.0
- normalization_layer: str = 'instance_norm'
- num_layers: int = 12
- operator_type: str = 'diagonal'
- pos_embed: bool = True
- rank: float = 1.0
- residual_filter_factor: int = 1
- scale_factor: int = 1
- separable: bool = False
- spectral_layers: int = 1
- spectral_transform: str = 'sht'
- use_mlp: bool = True
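Since the builder is a dataclass with a default for every field, a configuration only needs to list the fields it changes. The sketch below mirrors a subset of the documented defaults in a local dataclass (`SFNODefaults` is hypothetical, not the fme class). It assumes the config mapping is applied as keyword arguments and that this builder is the one registered under the type name 'SphericalFourierNeuralOperatorNet'; neither is stated explicitly in this reference.

```python
from dataclasses import asdict, dataclass
from typing import Literal, Optional


@dataclass
class SFNODefaults:
    # A subset of the documented SphericalFourierNeuralOperatorBuilder
    # defaults, mirrored locally for illustration only.
    embed_dim: int = 256
    num_layers: int = 12
    filter_type: str = "linear"
    operator_type: str = "diagonal"
    normalization_layer: str = "instance_norm"
    factorization: Optional[str] = None
    data_grid: Literal["legendre-gauss", "equiangular"] = "legendre-gauss"


# Only the overridden fields need to be written out in the selector.
overrides = {"embed_dim": 384, "num_layers": 8, "data_grid": "equiangular"}
selector = {"type": "SphericalFourierNeuralOperatorNet", "config": overrides}

# Fields that are not overridden keep their documented defaults.
effective = asdict(SFNODefaults(**selector["config"]))
print(effective["embed_dim"], effective["operator_type"])  # 384 diagonal
```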
- class fme.ace.NoiseConditionedSFNO(conditional_model, noise_type='gaussian', embed_dim=256)
Bases: Module
- Parameters: conditional_model, noise_type, embed_dim
- forward(x, labels=None)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters:
x (Tensor)
labels (Tensor | None)
- Return type: Tensor
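The note above describes PyTorch's nn.Module: calling the instance goes through `__call__`, which runs registered hooks around `forward`, while calling `forward` directly skips them. The toy class below imitates that dispatch in plain Python so the difference is visible without installing PyTorch. `ToyModule` is hypothetical and simplified: it has no pre-hooks or backward hooks, passes the input directly rather than as a tuple, and ignores hook return values.

```python
from typing import Any, Callable, List

Hook = Callable[[Any, Any, Any], None]


class ToyModule:
    """Pure-Python imitation of how a module's __call__ wraps forward."""

    def __init__(self) -> None:
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: int) -> int:
        return 2 * x

    def __call__(self, x: int) -> int:
        out = self.forward(x)
        for hook in self._forward_hooks:  # hooks run only via __call__
            hook(self, x, out)
        return out


calls: List[int] = []
m = ToyModule()
m.register_forward_hook(lambda mod, inp, out: calls.append(out))

m(3)          # runs forward and then the hook
m.forward(3)  # runs forward only; the hook is skipped
print(calls)  # [6]
```

With a real NoiseConditionedSFNO the same rule applies: call `model(x, labels)` rather than `model.forward(x, labels)` so that any registered hooks run.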
- class fme.ace.SFNO_V0_1_0(spectral_transform='sht', filter_type='linear', operator_type='dhconv', scale_factor=16, embed_dim=256, num_layers=12, repeat_layers=1, hard_thresholding_fraction=1.0, normalization_layer='instance_norm', use_mlp=True, activation_function='gelu', encoder_layers=1, pos_embed='direct', big_skip=True, rank=1.0, factorization=None, separable=False, complex_activation='real', spectral_layers=1, checkpointing=0, data_grid='legendre-gauss')
Bases: ModuleConfig
Configuration for the SFNO architecture in modulus-makani version 0.1.0.
- Parameters:
spectral_transform (str)
filter_type (Literal['linear'])
operator_type (str)
scale_factor (int)
embed_dim (int)
num_layers (int)
repeat_layers (int)
hard_thresholding_fraction (float)
normalization_layer (str)
use_mlp (bool)
activation_function (str)
encoder_layers (int)
pos_embed (Literal['none', 'direct', 'frequency'])
big_skip (bool)
rank (float)
factorization (str | None)
separable (bool)
complex_activation (str)
spectral_layers (int)
checkpointing (int)
data_grid (Literal['legendre-gauss', 'equiangular', 'healpix'])
- activation_function: str = 'gelu'
- big_skip: bool = True
- build(n_in_channels, n_out_channels, dataset_info)
Build a nn.Module given information about the input and output channels and the dataset.
- checkpointing: int = 0
- complex_activation: str = 'real'
- data_grid: Literal['legendre-gauss', 'equiangular', 'healpix'] = 'legendre-gauss'
- embed_dim: int = 256
- encoder_layers: int = 1
- filter_type: Literal['linear'] = 'linear'
- hard_thresholding_fraction: float = 1.0
- normalization_layer: str = 'instance_norm'
- num_layers: int = 12
- operator_type: str = 'dhconv'
- pos_embed: Literal['none', 'direct', 'frequency'] = 'direct'
- rank: float = 1.0
- repeat_layers: int = 1
- scale_factor: int = 16
- separable: bool = False
- spectral_layers: int = 1
- spectral_transform: str = 'sht'
- use_mlp: bool = True
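Unlike SphericalFourierNeuralOperatorBuilder, where `pos_embed` is a bool, SFNO_V0_1_0 takes one of three strings for `pos_embed`, and its `data_grid` additionally accepts 'healpix'. The sketch below checks a config dict against those Literal types using `typing.get_args`. `validate` is a hypothetical pre-flight helper; whether fme itself validates these values at construction time is not stated in this reference.

```python
from typing import List, Literal, get_args

# Literal types copied from the SFNO_V0_1_0 signature above.
PosEmbed = Literal["none", "direct", "frequency"]
DataGrid = Literal["legendre-gauss", "equiangular", "healpix"]
FilterType = Literal["linear"]

CHECKS = (("pos_embed", PosEmbed), ("data_grid", DataGrid), ("filter_type", FilterType))


def validate(config: dict) -> List[str]:
    """Return a list of problems found in an SFNO-v0.1.0 style config dict."""
    problems = []
    for key, literal in CHECKS:
        allowed = get_args(literal)  # e.g. ('none', 'direct', 'frequency')
        if key in config and config[key] not in allowed:
            problems.append(f"{key}={config[key]!r} not in {allowed}")
    return problems


print(validate({"pos_embed": "frequency", "data_grid": "healpix"}))  # []
print(validate({"pos_embed": True}))  # one problem: a bool is not a valid pos_embed here
```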
- class fme.ace.HEALPixRecUNetBuilder(encoder, decoder, presteps=1, input_time_size=0, output_time_size=0, delta_time='6h', reset_cycle='24h', n_constants=2, decoder_input_channels=1, prognostic_variables=7, enable_nhwc=False, enable_healpixpad=False)
Bases: ModuleConfig
Configuration for the HEALPixRecUNet architecture used in DLWP.
- Parameters:
presteps (int, default: 1) – Number of pre-steps.
input_time_size (int, default: 0) – Input time dimension.
output_time_size (int, default: 0) – Output time dimension.
delta_time (str, default: '6h') – Delta time interval.
reset_cycle (str, default: '24h') – Reset cycle interval.
input_channels – Number of input channels, by default 8.
output_channels – Number of output channels, by default 8.
n_constants (int, default: 2) – Number of constant input channels.
decoder_input_channels (int, default: 1) – Number of input channels for the decoder.
enable_nhwc (bool, default: False) – Flag to enable the NHWC data format.
enable_healpixpad (bool, default: False) – Flag to enable HEALPix padding.
encoder (UNetEncoderConfig)
decoder (UNetDecoderConfig)
prognostic_variables (int, default: 7)
- build(n_in_channels, n_out_channels, dataset_info)
Builds the HEALPixRecUNet model.
- decoder: UNetDecoderConfig
- decoder_input_channels: int = 1
- delta_time: str = '6h'
- enable_healpixpad: bool = False
- enable_nhwc: bool = False
- encoder: UNetEncoderConfig
- input_time_size: int = 0
- n_constants: int = 2
- output_time_size: int = 0
- presteps: int = 1
- prognostic_variables: int = 7
- reset_cycle: str = '24h'
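`delta_time` and `reset_cycle` are given as duration strings. Assuming, as the parameter names suggest, that the model advances by `delta_time` per step and its recurrent state is reset once per `reset_cycle`, the defaults give a reset every 24 h / 6 h = 4 steps. The sketch below computes that ratio. `parse_duration` and `steps_per_reset` are hypothetical helpers; how fme itself parses these strings (for example, which unit suffixes it accepts) is not documented here.

```python
import re
from datetime import timedelta

_UNITS = {"m": "minutes", "h": "hours", "d": "days"}


def parse_duration(text: str) -> timedelta:
    """Parse strings like '6h' or '24h' (a minimal, hypothetical parser)."""
    match = re.fullmatch(r"(\d+)([mhd])", text.strip())
    if match is None:
        raise ValueError(f"unsupported duration: {text!r}")
    value, unit = match.groups()
    return timedelta(**{_UNITS[unit]: int(value)})


def steps_per_reset(delta_time: str = "6h", reset_cycle: str = "24h") -> int:
    """Number of delta_time steps in one reset_cycle."""
    dt, cycle = parse_duration(delta_time), parse_duration(reset_cycle)
    if cycle % dt:  # a non-zero remainder means the cycle is not a whole number of steps
        raise ValueError("reset_cycle should be a whole multiple of delta_time")
    return cycle // dt


print(steps_per_reset())  # 4
```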
- class fme.ace.LandNetBuilder(hidden_dims=<factory>, network_type='MLP', use_positional_embedding=False)
Bases: ModuleConfig
Configuration for the LandNet architecture.
- Parameters: hidden_dims, network_type (Literal['MLP']), use_positional_embedding (bool)
- build(n_in_channels, n_out_channels, dataset_info)
Build a nn.Module given information about the input and output channels and the dataset.
- network_type: Literal['MLP'] = 'MLP'
- use_positional_embedding: bool = False
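The `<factory>` placeholder in a signature means the default comes from `dataclasses.field(default_factory=...)`, which Python's dataclasses require for mutable defaults such as lists. Each instance then gets its own fresh list. In the sketch below, `LandNetSketch` is a local stand-in and the `[64, 64]` default is a hypothetical placeholder; the real LandNetBuilder default for `hidden_dims` is not shown in this reference.

```python
from dataclasses import dataclass, field
from typing import List, Literal


@dataclass
class LandNetSketch:
    # Hypothetical default; the actual LandNetBuilder default is not documented here.
    hidden_dims: List[int] = field(default_factory=lambda: [64, 64])
    network_type: Literal["MLP"] = "MLP"
    use_positional_embedding: bool = False


a, b = LandNetSketch(), LandNetSketch()
a.hidden_dims.append(32)  # mutates only a's list
print(a.hidden_dims, b.hidden_dims)  # [64, 64, 32] [64, 64]
```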
- class fme.ace.SamudraBuilder(ch_width=<factory>, n_layers=<factory>, dilation=<factory>, pad='circular', norm='instance', norm_kwargs=<factory>, upscale_factor=4, checkpoint_strategy=None)
Bases: ModuleConfig
Configuration for the M2Lines Samudra architecture.
- Parameters: ch_width, n_layers, dilation, pad (str), norm (str), norm_kwargs, upscale_factor (int), checkpoint_strategy
- build(n_in_channels, n_out_channels, dataset_info)
Build a nn.Module given information about the input and output channels and the dataset.
- norm: str = 'instance'
- pad: str = 'circular'
- upscale_factor: int = 4
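`pad='circular'` presumably refers to wrap-around padding, as in PyTorch's `padding_mode='circular'` for convolutions, which suits a periodic longitude axis. Which axes Samudra pads this way is not stated here. The one-dimensional sketch below (`circular_pad` is a hypothetical helper) shows what wrap-around padding does to a row of values: each edge is filled with values copied from the opposite edge.

```python
from typing import List, Sequence


def circular_pad(row: Sequence[float], width: int) -> List[float]:
    """Wrap-around padding of a 1-D row, e.g. one latitude band along longitude."""
    if not 0 <= width <= len(row):
        raise ValueError("width must be between 0 and len(row)")
    values = list(row)
    left = values[len(values) - width:]  # last `width` values wrap to the front
    right = values[:width]  # first `width` values wrap to the back
    return left + values + right


print(circular_pad([1, 2, 3, 4], 1))  # [4, 1, 2, 3, 4, 1]
```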