Modules

ACE’s code uses a module registry system to allow different machine learning architectures to plug into the framework. This is managed by the fme.ace.ModuleSelector configuration class, which is used to select a module type and version.

class fme.ace.ModuleSelector(type, config, conditional=False)

Bases: object

A dataclass containing all the information needed to build a ModuleConfig, including the type of the ModuleConfig and the data needed to build it.

This is helpful as ModuleSelector can be serialized and deserialized without any additional information, whereas to load a ModuleConfig you would need to know the type of the ModuleConfig being loaded.

It is also convenient because ModuleSelector is a single class that can be used to represent any ModuleConfig, whereas ModuleConfig is a protocol that can be implemented by many different classes.

Parameters:
  • type (str) – the type of the ModuleConfig

  • config (Mapping[str, Any]) – data for a ModuleConfig instance of the indicated type

  • conditional (bool, default: False) – whether to condition the predictions on batch labels.
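Because a ModuleSelector carries its own type name, its data can be written to and read back from a plain config file without any extra bookkeeping. The sketch below shows that round trip with JSON. It assumes "SphericalFourierNeuralOperatorNet" is the name the SFNO builder is registered under (it appears in the registry listing on this page), and the config keys are fields of that builder. The final construction step is left as a comment because it needs fme installed.

```python
import json

# Serialized form of a module selection: the registered type name plus the
# keyword data for that type's ModuleConfig (assumed registration name).
selector_data = {
    "type": "SphericalFourierNeuralOperatorNet",
    "config": {"embed_dim": 384, "num_layers": 8},
    "conditional": False,
}

# The type travels with the data, so a reader does not need to know in
# advance which ModuleConfig class it is loading.
text = json.dumps(selector_data)
restored = json.loads(text)

# With fme installed, the restored data matches the documented signature
# ModuleSelector(type, config, conditional=False):
#   from fme.ace import ModuleSelector
#   selector = ModuleSelector(**restored)
```

The same structure applies to YAML configs, since only strings, numbers, booleans and nested mappings are involved.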

build(n_in_channels, n_out_channels, dataset_info)

Build a nn.Module given information about the input and output channels and the dataset.

Parameters:
  • n_in_channels (int) – number of input channels

  • n_out_channels (int) – number of output channels

  • dataset_info (DatasetInfo) – Information about the dataset, including img_shape (shape of last two dimensions of data, e.g. latitude and longitude), horizontal coordinates, vertical coordinate, etc.

Return type:

Module

Returns:

a Module object

conditional: bool = False
config: Mapping[str, Any]
classmethod get_available_types()

Return the names of all registered module types.

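classmethod register(type_name)

Return a class decorator that registers a ModuleConfig implementation in the registry under type_name, making it selectable with ModuleSelector(type=type_name, ...).

Parameters:
  • type_name (str) – name under which the ModuleConfig class is registered

Return type:

Callable[[Type[ModuleConfig]], Type[ModuleConfig]]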

registry: ClassVar[Registry[ModuleConfig]] = <fme.core.registry.registry.Registry object>
type: str
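To make the relationship between register, registry and build concrete, here is a minimal, self-contained sketch of the same pattern. ToyRegistry, ToySelector and ToyLinearConfig are hypothetical names invented for illustration; this is not fme's implementation, and a real ModuleConfig.build also takes dataset_info and returns an nn.Module rather than a tuple.

```python
import dataclasses
from typing import Any, Callable, ClassVar, Dict, Mapping


class ToyRegistry:
    """Hypothetical stand-in for a type-name -> config-class registry."""

    def __init__(self) -> None:
        self._types: Dict[str, type] = {}

    def register(self, type_name: str) -> Callable[[type], type]:
        # Returns a class decorator, mirroring the Callable[[Type], Type]
        # return type documented for ModuleSelector.register.
        def decorator(cls: type) -> type:
            if type_name in self._types:
                raise ValueError(f"type {type_name!r} is already registered")
            self._types[type_name] = cls
            return cls

        return decorator

    def from_dict(self, type_name: str, config: Mapping[str, Any]) -> Any:
        # Look up the class by name and instantiate it from the config data.
        if type_name not in self._types:
            raise ValueError(
                f"unknown type {type_name!r}; available: {sorted(self._types)}"
            )
        return self._types[type_name](**config)

    def available(self):
        return self._types.keys()


@dataclasses.dataclass
class ToySelector:
    type: str
    config: Mapping[str, Any]
    # ClassVar fields are shared by all instances and ignored by dataclass.
    registry: ClassVar[ToyRegistry] = ToyRegistry()

    @classmethod
    def register(cls, type_name: str) -> Callable[[type], type]:
        return cls.registry.register(type_name)

    @classmethod
    def get_available_types(cls):
        return cls.registry.available()

    def build(self, n_in_channels: int, n_out_channels: int) -> Any:
        # Resolve the named config class, then delegate construction to it.
        module_config = self.registry.from_dict(self.type, self.config)
        return module_config.build(n_in_channels, n_out_channels)


@ToySelector.register("ToyLinear")
@dataclasses.dataclass
class ToyLinearConfig:
    bias: bool = True

    def build(self, n_in_channels: int, n_out_channels: int) -> tuple:
        # A tuple stands in for an nn.Module to avoid a torch dependency.
        return ("linear", n_in_channels, n_out_channels, self.bias)
```

Usage: `ToySelector(type="ToyLinear", config={"bias": False}).build(3, 5)` resolves the name through the registry and returns `("linear", 3, 5, False)`. Keeping the registry as a class variable means any module that imports the decorated config class makes it selectable, which is what allows new architectures to plug in without modifying the selector.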

The following module types are available:

  • LandNet

  • Samudra

  • FloeNet

  • prebuilt

  • SphericalFourierNeuralOperatorNet

  • SFNO-v0.1.0

  • NoiseConditionedSFNO

  • HEALPixRecUNet

The following module builders are available:

class fme.ace.SphericalFourierNeuralOperatorBuilder(spectral_transform='sht', filter_type='linear', operator_type='diagonal', scale_factor=1, residual_filter_factor=1, embed_dim=256, num_layers=12, hard_thresholding_fraction=1.0, normalization_layer='instance_norm', use_mlp=True, activation_function='gelu', encoder_layers=1, pos_embed=True, big_skip=True, rank=1.0, factorization=None, separable=False, complex_network=True, complex_activation='real', spectral_layers=1, checkpointing=0, data_grid='legendre-gauss')

Bases: ModuleConfig

Configuration for the SFNO architecture used in FourCastNet-SFNO.

Parameters:
  • spectral_transform (str) –

  • filter_type (str) –

  • operator_type (str) –

  • scale_factor (int) –

  • residual_filter_factor (int) –

  • embed_dim (int) –

  • num_layers (int) –

  • hard_thresholding_fraction (float) –

  • normalization_layer (str) –

  • use_mlp (bool) –

  • activation_function (str) –

  • encoder_layers (int) –

  • pos_embed (bool) –

  • big_skip (bool) –

  • rank (float) –

  • factorization (str | None) –

  • separable (bool) –

  • complex_network (bool) –

  • complex_activation (str) –

  • spectral_layers (int) –

  • checkpointing (int) –

  • data_grid (Literal['legendre-gauss', 'equiangular']) –

activation_function: str = 'gelu'
big_skip: bool = True
build(n_in_channels, n_out_channels, dataset_info)

Build a nn.Module given information about the input and output channels and the dataset.

Parameters:
  • n_in_channels (int) – number of input channels

  • n_out_channels (int) – number of output channels

  • dataset_info (DatasetInfo) – Information about the dataset, including img_shape, horizontal coordinates, vertical coordinate, etc.

Returns:

a nn.Module

checkpointing: int = 0
complex_activation: str = 'real'
complex_network: bool = True
data_grid: Literal['legendre-gauss', 'equiangular'] = 'legendre-gauss'
embed_dim: int = 256
encoder_layers: int = 1
factorization: Optional[str] = None
filter_type: str = 'linear'
hard_thresholding_fraction: float = 1.0
normalization_layer: str = 'instance_norm'
num_layers: int = 12
operator_type: str = 'diagonal'
pos_embed: bool = True
rank: float = 1.0
residual_filter_factor: int = 1
scale_factor: int = 1
separable: bool = False
spectral_layers: int = 1
spectral_transform: str = 'sht'
use_mlp: bool = True
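Since every field above has a default, a selector's config mapping only needs the keys being changed. The sketch below illustrates this with a hypothetical dataclass, SFNOConfigSubset, that mirrors four of the documented fields and defaults; it assumes the config mapping is applied to the builder as keyword arguments, which may be looser than fme's actual loading path. The unknown_keys helper is likewise hypothetical and shows how a misspelled key can be caught before building.

```python
import dataclasses


@dataclasses.dataclass
class SFNOConfigSubset:
    """Hypothetical mirror of a few SphericalFourierNeuralOperatorBuilder
    fields, using the defaults documented above."""

    embed_dim: int = 256
    num_layers: int = 12
    operator_type: str = "diagonal"
    data_grid: str = "legendre-gauss"


def unknown_keys(config: dict) -> list:
    # Names in the config mapping that are not fields of the dataclass.
    names = {f.name for f in dataclasses.fields(SFNOConfigSubset)}
    return sorted(set(config) - names)


overrides = {"embed_dim": 384, "data_grid": "equiangular"}
assert not unknown_keys(overrides)

# Keys left out of the mapping keep their documented defaults.
cfg = SFNOConfigSubset(**overrides)
print(cfg)
# SFNOConfigSubset(embed_dim=384, num_layers=12, operator_type='diagonal', data_grid='equiangular')
```

Passing a misspelled key such as `embed_dims` directly to a dataclass constructor raises `TypeError`, so checking names first mainly yields a clearer error message.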
class fme.ace.NoiseConditionedSFNO(conditional_model, noise_type='gaussian', embed_dim=256)

Bases: Module

Parameters:
  • conditional_model (SphericalFourierNeuralOperatorNet) –

  • noise_type (Literal['isotropic', 'gaussian']) –

  • embed_dim (int) –

forward(x, labels=None)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:
  • x (Tensor) –

  • labels (Tensor | None) –

Return type:

Tensor

class fme.ace.SFNO_V0_1_0(spectral_transform='sht', filter_type='linear', operator_type='dhconv', scale_factor=16, embed_dim=256, num_layers=12, repeat_layers=1, hard_thresholding_fraction=1.0, normalization_layer='instance_norm', use_mlp=True, activation_function='gelu', encoder_layers=1, pos_embed='direct', big_skip=True, rank=1.0, factorization=None, separable=False, complex_activation='real', spectral_layers=1, checkpointing=0, data_grid='legendre-gauss')

Bases: ModuleConfig

Configuration for the SFNO architecture in modulus-makani version 0.1.0.

Parameters:
  • spectral_transform (str) –

  • filter_type (Literal['linear']) –

  • operator_type (str) –

  • scale_factor (int) –

  • embed_dim (int) –

  • num_layers (int) –

  • repeat_layers (int) –

  • hard_thresholding_fraction (float) –

  • normalization_layer (str) –

  • use_mlp (bool) –

  • activation_function (str) –

  • encoder_layers (int) –

  • pos_embed (Literal['none', 'direct', 'frequency']) –

  • big_skip (bool) –

  • rank (float) –

  • factorization (str | None) –

  • separable (bool) –

  • complex_activation (str) –

  • spectral_layers (int) –

  • checkpointing (int) –

  • data_grid (Literal['legendre-gauss', 'equiangular', 'healpix']) –

activation_function: str = 'gelu'
big_skip: bool = True
build(n_in_channels, n_out_channels, dataset_info)

Build a nn.Module given information about the input and output channels and the dataset.

Parameters:
  • n_in_channels (int) – number of input channels

  • n_out_channels (int) – number of output channels

  • dataset_info (DatasetInfo) – Information about the dataset, including img_shape, horizontal coordinates, vertical coordinate, etc.

Returns:

a nn.Module

checkpointing: int = 0
complex_activation: str = 'real'
data_grid: Literal['legendre-gauss', 'equiangular', 'healpix'] = 'legendre-gauss'
embed_dim: int = 256
encoder_layers: int = 1
factorization: Optional[str] = None
filter_type: Literal['linear'] = 'linear'
hard_thresholding_fraction: float = 1.0
normalization_layer: str = 'instance_norm'
num_layers: int = 12
operator_type: str = 'dhconv'
pos_embed: Literal['none', 'direct', 'frequency'] = 'direct'
rank: float = 1.0
repeat_layers: int = 1
scale_factor: int = 16
separable: bool = False
spectral_layers: int = 1
spectral_transform: str = 'sht'
use_mlp: bool = True
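Unlike SphericalFourierNeuralOperatorBuilder, SFNO_V0_1_0 accepts 'healpix' for data_grid, and its pos_embed is a string choice rather than a boolean. A small sketch of checking such Literal-typed options before building: the allowed values are copied from the signature above, while check_choice and the config dict are hypothetical illustrations, not part of fme.

```python
from typing import Literal, get_args

# Allowed values copied from the SFNO_V0_1_0 signature.
PosEmbed = Literal["none", "direct", "frequency"]
DataGrid = Literal["legendre-gauss", "equiangular", "healpix"]


def check_choice(name: str, value: str, allowed) -> None:
    # get_args on a Literal type returns the tuple of permitted values.
    choices = get_args(allowed)
    if value not in choices:
        raise ValueError(f"{name}={value!r} is not one of {choices}")


# Hypothetical config overrides for an "SFNO-v0.1.0" selector.
config = {"pos_embed": "frequency", "data_grid": "healpix", "scale_factor": 16}
check_choice("pos_embed", config["pos_embed"], PosEmbed)
check_choice("data_grid", config["data_grid"], DataGrid)
```

A value such as `pos_embed=True`, which is valid for the other SFNO builder, would be rejected here.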
class fme.ace.HEALPixRecUNetBuilder(encoder, decoder, presteps=1, input_time_size=0, output_time_size=0, delta_time='6h', reset_cycle='24h', n_constants=2, decoder_input_channels=1, prognostic_variables=7, enable_nhwc=False, enable_healpixpad=False)

Bases: ModuleConfig

Configuration for the HEALPixRecUNet architecture used in DLWP.

Parameters:
  • presteps (int, default: 1) – Number of pre-steps.

  • input_time_size (int, default: 0) – Input time dimension.

  • output_time_size (int, default: 0) – Output time dimension.

  • delta_time (str, default: '6h') – Delta time interval.

  • reset_cycle (str, default: '24h') – Reset cycle interval.

  • input_channels, output_channels – Not constructor arguments; the channel counts are supplied to build() as n_in_channels and n_out_channels.

  • n_constants (int, default: 2) – Number of constant input channels.

  • decoder_input_channels (int, default: 1) – Number of input channels for the decoder.

  • enable_nhwc (bool, default: False) – Flag to enable NHWC data format.

  • enable_healpixpad (bool, default: False) – Flag to enable HEALPix padding.

  • encoder (UNetEncoderConfig) –

  • decoder (UNetDecoderConfig) –

  • prognostic_variables (int) –

build(n_in_channels, n_out_channels, dataset_info)

Builds the HEALPixRecUNet model.

Parameters:
  • n_in_channels (int) – Number of input channels.

  • n_out_channels (int) – Number of output channels.

  • dataset_info (DatasetInfo) – Information about the dataset.

Return type:

Module

Returns:

HEALPixRecUNet model.

decoder: UNetDecoderConfig
decoder_input_channels: int = 1
delta_time: str = '6h'
enable_healpixpad: bool = False
enable_nhwc: bool = False
encoder: UNetEncoderConfig
input_time_size: int = 0
n_constants: int = 2
output_time_size: int = 0
presteps: int = 1
prognostic_variables: int = 7
reset_cycle: str = '24h'
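delta_time and reset_cycle are given as duration strings. Assuming the recurrent state is reset every reset_cycle / delta_time model steps (an interpretation, not confirmed by this page), the defaults give a reset every 4 steps. The sketch below computes that ratio with the standard library; parse_interval and steps_per_reset are hypothetical helpers, and the actual model may parse these strings differently (for example with pandas) and accept more unit suffixes.

```python
import re
from datetime import timedelta

_UNITS = {"m": "minutes", "h": "hours", "d": "days"}


def parse_interval(text: str) -> timedelta:
    """Parse strings such as '6h' or '24h' (hypothetical helper; m/h/d only)."""
    match = re.fullmatch(r"(\d+)([mhd])", text.strip())
    if match is None:
        raise ValueError(f"unsupported interval {text!r}")
    value, unit = match.groups()
    return timedelta(**{_UNITS[unit]: int(value)})


def steps_per_reset(delta_time: str = "6h", reset_cycle: str = "24h") -> int:
    # Assumption: the recurrent state is reset every reset_cycle / delta_time
    # steps, so reset_cycle should be a whole multiple of delta_time.
    dt = parse_interval(delta_time)
    cycle = parse_interval(reset_cycle)
    if cycle % dt:
        raise ValueError("reset_cycle is not a multiple of delta_time")
    return cycle // dt
```

With the documented defaults, `steps_per_reset()` returns 4; `steps_per_reset("6h", "1d")` gives the same result because both strings describe 24 hours.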
class fme.ace.LandNetBuilder(hidden_dims=<factory>, network_type='MLP', use_positional_embedding=False)

Bases: ModuleConfig

Configuration for the LandNet architecture.

Parameters:
  • hidden_dims (list[int]) –

  • network_type (Literal['MLP']) –

  • use_positional_embedding (bool) –

build(n_in_channels, n_out_channels, dataset_info)

Build a nn.Module given information about the input and output channels and the dataset.

Parameters:
  • n_in_channels (int) – number of input channels

  • n_out_channels (int) – number of output channels

  • dataset_info (DatasetInfo) – Information about the dataset, including img_shape, horizontal coordinates, vertical coordinate, etc.

Returns:

a nn.Module

hidden_dims: list[int]
network_type: Literal['MLP'] = 'MLP'
use_positional_embedding: bool = False
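For the 'MLP' network type, hidden_dims lists the widths of the hidden layers. As a rough sizing aid, the sketch below lists the (in, out) shape of each linear layer and the resulting weight-plus-bias count, assuming the MLP maps n_in_channels through each entry of hidden_dims to n_out_channels with biased fully connected layers. That layout is an assumption: the real network may differ, for example by adding input channels when use_positional_embedding=True. Both helper names are hypothetical.

```python
from typing import List, Tuple


def mlp_layer_shapes(
    n_in: int, hidden_dims: List[int], n_out: int
) -> List[Tuple[int, int]]:
    """(in_features, out_features) for each linear layer, assuming
    n_in -> hidden_dims[0] -> ... -> hidden_dims[-1] -> n_out."""
    widths = [n_in, *hidden_dims, n_out]
    return list(zip(widths[:-1], widths[1:]))


def mlp_param_count(n_in: int, hidden_dims: List[int], n_out: int) -> int:
    # Weights (i * o) plus biases (o) for every linear layer.
    return sum(i * o + o for i, o in mlp_layer_shapes(n_in, hidden_dims, n_out))
```

For example, `mlp_layer_shapes(4, [8, 8], 2)` gives `[(4, 8), (8, 8), (8, 2)]`, and `mlp_param_count(4, [8, 8], 2)` gives 40 + 72 + 18 = 130.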
class fme.ace.SamudraBuilder(ch_width=<factory>, n_layers=<factory>, dilation=<factory>, pad='circular', norm='instance', norm_kwargs=<factory>, upscale_factor=4, checkpoint_strategy=None)

Bases: ModuleConfig

Configuration for the M2Lines Samudra architecture.

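Parameters:
  • ch_width (list[int]) –

  • n_layers (list[int]) –

  • dilation (list[int]) –

  • pad (str) –

  • norm (str) –

  • norm_kwargs (Mapping[str, Any]) –

  • upscale_factor (int) –

  • checkpoint_strategy (Literal['all', 'simple'] | None) –
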
build(n_in_channels, n_out_channels, dataset_info)

Build a nn.Module given information about the input and output channels and the dataset.

Parameters:
  • n_in_channels (int) – number of input channels

  • n_out_channels (int) – number of output channels

  • dataset_info (DatasetInfo) – Information about the dataset, including img_shape, horizontal coordinates, vertical coordinate, etc.

Returns:

a nn.Module

ch_width: list[int]
checkpoint_strategy: Optional[Literal['all', 'simple']] = None
dilation: list[int]
n_layers: list[int]
norm: str = 'instance'
norm_kwargs: Mapping[str, Any]
pad: str = 'circular'
upscale_factor: int = 4
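The <factory> entries in the LandNetBuilder and SamudraBuilder signatures mark dataclass fields whose defaults are mutable (lists and mappings). Python dataclasses reject a plain list or dict default, so such fields use a default_factory that builds a fresh object for each instance. The sketch below shows the effect with a hypothetical SamudraLikeConfig; its field names follow SamudraBuilder, but the list values are placeholders, not SamudraBuilder's actual defaults.

```python
import dataclasses
from typing import Any, List, Mapping, Optional


@dataclasses.dataclass
class SamudraLikeConfig:
    """Hypothetical mirror of SamudraBuilder's field layout; the list
    values are placeholders, not the real defaults."""

    # Mutable defaults are created per instance by a factory, which
    # autodoc renders as "<factory>" in the class signature.
    ch_width: List[int] = dataclasses.field(default_factory=lambda: [32, 64])
    n_layers: List[int] = dataclasses.field(default_factory=lambda: [1, 1])
    dilation: List[int] = dataclasses.field(default_factory=lambda: [1, 2])
    norm_kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)
    pad: str = "circular"
    norm: str = "instance"
    upscale_factor: int = 4
    checkpoint_strategy: Optional[str] = None


a = SamudraLikeConfig()
b = SamudraLikeConfig()
a.ch_width.append(128)
# b is unaffected because each instance received its own list.
print(b.ch_width)  # [32, 64]
```

Without the factory, all configs would share one list, and editing one selector's widths would silently change every other config built from the defaults.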