import dataclasses
import math
from collections.abc import Callable
from typing import Literal
import torch
from fme.ace.registry.registry import ModuleConfig, ModuleSelector
from fme.core.dataset_info import DatasetInfo
from fme.core.models.conditional_sfno.sfnonet import (
Context,
ContextConfig,
get_lat_lon_sfnonet,
)
from fme.core.models.conditional_sfno.sfnonet import (
SphericalFourierNeuralOperatorNet as ConditionalSFNO,
)
def isotropic_noise(
    leading_shape: tuple[int, ...],
    lmax: int,
    mmax: int,
    isht: Callable[[torch.Tensor], torch.Tensor],
    device: torch.device,
) -> torch.Tensor:
    """
    Draw random spherical-harmonic coefficients and map them through ``isht``.

    Complex coefficients of shape ``(*leading_shape, lmax, mmax)`` are sampled
    from independent normal distributions and passed to ``isht``. The scaling
    is intended to give an isotropic field with unit variance at every point.

    Args:
        leading_shape: Shape of the dimensions that precede the two
            spectral axes (e.g. batch and channel).
        lmax: Length of the ℓ (degree) axis expected by ``isht``.
        mmax: Length of the m (order) axis expected by ``isht``.
        isht: Callable taking the complex coefficient tensor and returning
            the synthesized field (an inverse spherical harmonic transform).
        device: Device on which the random coefficients are created.

    Returns:
        Whatever ``isht`` returns for the sampled coefficient tensor.
    """
    spectral_shape = tuple(leading_shape) + (lmax, mmax)

    # Two independent N(0, 1) draws; real part first so RNG use is stable.
    re_part = torch.randn(spectral_shape, dtype=torch.float32, device=device)
    im_part = torch.randn(spectral_shape, dtype=torch.float32, device=device)

    # The m = 0 coefficients are purely real.
    im_part[..., 0] = 0.0

    # For m > 0, Re and Im are each N(0, 1/2) so that E|a_{lm}|^2 = 1.
    root_two = math.sqrt(2.0)
    re_part[..., 1:].div_(root_two)
    im_part[..., 1:].div_(root_two)

    # Global factor meant to make Var[T(theta, phi)] = 1
    # (Unsöld theorem with L = lmax).
    normalization = math.sqrt(4.0 * math.pi) / lmax
    coefficients = torch.complex(re_part, im_part) * normalization
    return isht(coefficients)
class NoiseConditionedSFNO(torch.nn.Module):
    """
    Wrap a conditional SFNO so that freshly sampled noise is supplied as
    conditioning context on every forward pass.

    Args:
        conditional_model: The conditional SFNO that is called with the
            input and the generated context. When ``noise_type`` is
            "isotropic", its ``itrans_up`` attribute is used as the inverse
            transform and is read for ``lmax`` and ``mmax``.
        noise_type: "isotropic" samples noise through ``isotropic_noise``;
            "gaussian" samples independent standard-normal values for every
            grid point.
        embed_dim: Number of noise channels generated per sample.
    """

    def __init__(
        self,
        conditional_model: ConditionalSFNO,
        noise_type: Literal["isotropic", "gaussian"] = "gaussian",
        embed_dim: int = 256,
    ):
        super().__init__()
        self.conditional_model = conditional_model
        self.embed_dim = embed_dim
        self.noise_type = noise_type

    def forward(
        self, x: torch.Tensor, labels: torch.Tensor | None = None
    ) -> torch.Tensor:
        """
        Sample noise and run the wrapped conditional model.

        Args:
            x: Input tensor. All dimensions before the last three are treated
                as leading (batch-like) dimensions; the last two give the
                spatial size used for gaussian noise.
            labels: Optional labels forwarded unchanged in the context.

        Returns:
            The output of ``conditional_model`` for ``x`` and the context.

        Raises:
            ValueError: If ``noise_type`` is not "isotropic" or "gaussian".
        """
        leading_shape = tuple(x.shape[:-3])
        if self.noise_type == "isotropic":
            inverse_sht = self.conditional_model.itrans_up
            noise = isotropic_noise(
                leading_shape + (self.embed_dim,),
                inverse_sht.lmax,
                inverse_sht.mmax,
                inverse_sht,
                device=x.device,
            )
            # isotropic_noise samples in float32 regardless of the input;
            # cast so both noise types reach the model in x's dtype.
            # This is a no-op for float32 inputs.
            noise = noise.to(dtype=x.dtype)
        elif self.noise_type == "gaussian":
            noise = torch.randn(
                [*leading_shape, self.embed_dim, *x.shape[-2:]],
                device=x.device,
                dtype=x.dtype,
            )
        else:
            raise ValueError(f"Invalid noise type: {self.noise_type}")
        # No scalar conditioning is used: pass a zero-width embedding.
        embedding_scalar = torch.zeros(
            [*leading_shape, 0], device=x.device, dtype=x.dtype
        )
        return self.conditional_model(
            x, Context(embedding_scalar=embedding_scalar, labels=labels, noise=noise)
        )
# this is based on the call signature of SphericalFourierNeuralOperatorNet at
# https://github.com/NVIDIA/modulus/blob/b8e27c5c4ebc409e53adaba9832138743ede2785/modulus/models/sfno/sfnonet.py#L292 # noqa: E501
@ModuleSelector.register("NoiseConditionedSFNO")
@dataclasses.dataclass
class NoiseConditionedSFNOBuilder(ModuleConfig):
    """
    Configuration for a noise-conditioned SFNO model.

    Random noise is supplied to the network as the conditioning input of its
    conditional layer normalization.

    Attributes:
        spectral_transform: Type of spherical transform to use.
            Kept for backwards compatibility.
        filter_type: Type of filter to use.
        operator_type: Type of operator to use.
        residual_filter_factor: Factor by which to downsample the residual.
        embed_dim: Dimension of the embedding.
        noise_embed_dim: Dimension of the noise embedding.
        noise_type: Type of noise to use for conditioning.
        global_layer_norm: Whether to reduce along the spatial domain when
            applying layer normalization.
        num_layers: Number of blocks (SFNO and MLP) in the model.
        use_mlp: Whether to use a MLP in the model.
        mlp_ratio: Ratio of the MLP hidden dimension to the embedding
            dimension.
        activation_function: Activation function to use.
        encoder_layers: Number of encoder layers in the model.
        pos_embed: Whether to use a position embedding.
        big_skip: Whether to use a big skip connection in the model.
        rank: Rank of the model.
        factorization: Factorization to use.
        separable: Whether to use a separable filter.
        complex_network: Whether to use a complex network.
        complex_activation: Activation function to use.
        spectral_layers: Number of spectral layers in the model.
        checkpointing: Whether to use checkpointing.
        data_grid: Grid type of the data. Healpix is not supported because of
            assumptions about the number of spatial dimensions.
        filter_residual: Whether to filter residual connections through a
            SHT round-trip. These will always be filtered if
            residual_filter_factor is not 1.
        filter_output: Whether to filter the output of the model through a
            SHT round-trip.
        local_blocks: List of block indices to use discrete-continuous
            convolution (DISCO) blocks, which apply local filters. See
            Ocampo et al. (2022) https://arxiv.org/abs/2209.13603 for
            more details.
        normalize_big_skip: Whether to normalize the big_skip connection.
        affine_norms: Whether to use element-wise affine parameters in the
            normalization layers.
        filter_num_groups: Number of groups to use in grouped convolutions
            for the spectral filter.
    """

    # Field order and defaults define the generated __init__; keep them stable.
    spectral_transform: Literal["sht"] = "sht"
    filter_type: str = "non-linear"
    operator_type: str = "diagonal"
    residual_filter_factor: int = 1
    embed_dim: int = 256
    noise_embed_dim: int = 256
    noise_type: Literal["isotropic", "gaussian"] = "gaussian"
    global_layer_norm: bool = False
    num_layers: int = 12
    use_mlp: bool = True
    mlp_ratio: float = 2.0
    activation_function: str = "gelu"
    encoder_layers: int = 1
    pos_embed: bool = True
    big_skip: bool = True
    rank: float = 1.0
    factorization: str | None = None
    separable: bool = False
    complex_network: bool = True
    complex_activation: str = "real"
    spectral_layers: int = 1
    checkpointing: int = 0
    # healpix not supported due to assumptions about number of spatial dims
    data_grid: Literal["legendre-gauss", "equiangular"] = "legendre-gauss"
    filter_residual: bool = False
    filter_output: bool = False
    local_blocks: list[int] | None = None
    normalize_big_skip: bool = False
    affine_norms: bool = False
    filter_num_groups: int = 1

    def build(
        self,
        n_in_channels: int,
        n_out_channels: int,
        dataset_info: DatasetInfo,
    ):
        """
        Construct the noise-conditioned model described by this config.

        Args:
            n_in_channels: Number of input channels of the network.
            n_out_channels: Number of output channels of the network.
            dataset_info: Supplies the image shape and the label set; the
                number of labels sets the label embedding size.

        Returns:
            A NoiseConditionedSFNO wrapping the constructed SFNO network.
        """
        img_shape = dataset_info.img_shape
        # Scalar conditioning is disabled; only labels and noise are embedded.
        context_config = ContextConfig(
            embed_dim_scalar=0,
            embed_dim_labels=len(dataset_info.all_labels),
            embed_dim_noise=self.noise_embed_dim,
        )
        # This config object itself is passed as the network parameters.
        backbone = get_lat_lon_sfnonet(
            params=self,
            in_chans=n_in_channels,
            out_chans=n_out_channels,
            img_shape=img_shape,
            context_config=context_config,
        )
        return NoiseConditionedSFNO(
            backbone,
            noise_type=self.noise_type,
            embed_dim=self.noise_embed_dim,
        )