Source code for fme.ace.models.healpix.healpix_blocks

# flake8: noqa
# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import math
from typing import Literal, Optional, Sequence, Tuple, Union, cast

import torch as th
import torch.nn as nn

from .healpix_activations import CappedGELUConfig
from .healpix_layers import HEALPixLayer


def _healpix_layer_kwargs(
    enable_nhwc: bool,
    hpx_padding_mode: Literal["earth2grid", "karlbauer", "isolatitude"] = "earth2grid",
    nside: Optional[int] = None,
) -> dict:
    """
    Assemble the keyword arguments shared by every ``HEALPixLayer`` in this module.

    Args:
        enable_nhwc: Whether the wrapped layer should use channels-last memory format.
        hpx_padding_mode: Name of the HEALPix padding backend.
        nside: Native face height/width. The ``"nside"`` key is emitted only
            when this is not ``None`` (``0`` is still emitted).

    Returns:
        A new dict with ``enable_nhwc`` and ``hpx_padding_mode`` entries, plus
        an ``nside`` entry when one was supplied.
    """
    # Only forward ``nside`` when explicitly given, leaving the unset case to
    # whatever HEALPixLayer does on its own.
    optional_entries = {} if nside is None else {"nside": nside}
    return dict(
        enable_nhwc=enable_nhwc,
        hpx_padding_mode=hpx_padding_mode,
        **optional_entries,
    )


# --- Configuration dataclasses ---


[docs]@dataclasses.dataclass class DownsamplingBlockConfig: """ Configuration for the downsampling block (pooling or dealiased strided blur). Parameters: block_type: One of ``"MaxPool"``, ``"AvgPool"``, or ``"DealiasedDownsample"``. pooling: Pooling size for pool blocks. stride: Spatial stride for ``DealiasedDownsample`` (power of two). enable_nhwc: Use channels-last memory format. hpx_padding_mode: HEALPix padding backend passed to child modules. nside: Native face height/width for HEALPix padding. in_channels: Input channels for ``DealiasedDownsample`` (set by encoder before build). resample_filter: 1D filter weights for dealiased blur stages. """ block_type: Literal["MaxPool", "AvgPool", "DealiasedDownsample"] pooling: int = 2 enable_nhwc: bool = False hpx_padding_mode: Literal["earth2grid", "karlbauer", "isolatitude"] = "earth2grid" nside: Optional[int] = None in_channels: Optional[int] = None resample_filter: Sequence[float] = dataclasses.field( default_factory=lambda: [1.0, 2.0, 1.0] ) stride: int = 2 def downsample_spatial_factor(self) -> int: if self.block_type in ("MaxPool", "AvgPool"): return self.pooling if self.block_type == "DealiasedDownsample": return self.stride raise ValueError(f"Unsupported block type: {self.block_type}") def build(self) -> nn.Module: if self.block_type == "MaxPool": return MaxPool( pooling=self.pooling, enable_nhwc=self.enable_nhwc, hpx_padding_mode=self.hpx_padding_mode, nside=self.nside, ) if self.block_type == "AvgPool": return AvgPool( pooling=self.pooling, enable_nhwc=self.enable_nhwc, hpx_padding_mode=self.hpx_padding_mode, nside=self.nside, ) if self.block_type == "DealiasedDownsample": if self.in_channels is None: raise ValueError( "DealiasedDownsample requires in_channels " "(set by UNetEncoder before build)" ) return DealiasedDownsample( in_channels=self.in_channels, resample_filter=self.resample_filter, stride=self.stride, enable_nhwc=self.enable_nhwc, hpx_padding_mode=self.hpx_padding_mode, nside=self.nside, ) raise 
ValueError(f"Unsupported block type: {self.block_type}")
[docs]@dataclasses.dataclass class UpsamplingBlockConfig: """ Configuration for HEALPix upsampling (transpose conv, interpolate+conv, pure ``nn.Upsample``, etc.). ``block_type`` ``"Interpolate"`` uses ``stride`` as ``nn.Upsample`` ``scale_factor`` and ``upsample_mode`` as the interpolation mode. Parameters: block_type: Upsampling implementation to build. in_channels: Input channel count for conv-based upsamplers. out_channels: Output channel count for conv-based upsamplers. stride: Upsampling scale factor (also used as ``nn.Upsample`` scale when applicable). kernel_size: Convolution kernel size for ``SmoothedInterpolateConv``. dilation: Convolution dilation for ``SmoothedInterpolateConv``. upsample_mode: Interpolation mode for smoothed / pure interpolate paths. activation: Optional ``CappedGELUConfig`` for transpose-conv upsampling. enable_nhwc: Use channels-last memory format. hpx_padding_mode: HEALPix padding backend passed to child modules. nside: Native face height/width for HEALPix padding at upsample input. nside_after: Face height/width after upsampling (``SmoothedInterpolateConv`` only). align_corners: Passed to ``nn.Upsample`` when ``block_type`` is ``"Interpolate"``. scale_factor: Alias for ``stride`` when set in config. mode: Alias for ``upsample_mode`` when set in config. 
""" block_type: Literal[ "TransposedConvUpsample", "SmoothedInterpolateConv", "Interpolate", ] in_channels: int = 3 out_channels: int = 1 stride: int = 2 kernel_size: int = 3 dilation: int = 1 upsample_mode: str = "nearest" activation: Optional[CappedGELUConfig] = None enable_nhwc: bool = False hpx_padding_mode: Literal["earth2grid", "karlbauer", "isolatitude"] = "earth2grid" nside: Optional[int] = None nside_after: Optional[int] = None align_corners: bool = False scale_factor: Optional[int] = None mode: Optional[str] = None def __post_init__(self) -> None: if self.scale_factor is not None: self.stride = self.scale_factor if self.mode is not None: self.upsample_mode = self.mode def build(self) -> nn.Module: if self.block_type == "TransposedConvUpsample": return TransposedConvUpsample( in_channels=self.in_channels, out_channels=self.out_channels, upsampling=self.stride, activation=self.activation, enable_nhwc=self.enable_nhwc, hpx_padding_mode=self.hpx_padding_mode, nside=self.nside, ) if self.block_type == "SmoothedInterpolateConv": return SmoothedInterpolateConv( in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=self.kernel_size, dilation=self.dilation, scale_factor=self.stride, mode=self.upsample_mode, activation=self.activation.build() if self.activation else None, enable_nhwc=self.enable_nhwc, hpx_padding_mode=self.hpx_padding_mode, nside=self.nside, nside_after=self.nside_after, ) if self.block_type == "Interpolate": if self.align_corners is False: return nn.Upsample( scale_factor=self.stride, mode=self.upsample_mode, ) return nn.Upsample( scale_factor=self.stride, mode=self.upsample_mode, align_corners=self.align_corners, ) raise ValueError(f"Unsupported block type: {self.block_type}")
@dataclasses.dataclass
class ConvBlockConfig:
    """
    Configuration for convolutional residual / ConvNeXt style blocks (no spatial resample).

    Parameters:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        dilation: Convolution dilation.
        n_layers: Number of repeated layers, used by ``BasicConvBlock`` and
            ``Multi_SymmetricConvNeXtBlock``.
        upscale_factor: Channel upscale factor inside ConvNeXt blocks.
        latent_channels: Latent channel width. When ``None``, ``BasicConvBlock``
            uses ``max(in_channels, out_channels)`` and the ConvNeXt variants
            use 1. ``build`` does not write the resolved value back to the config.
        activation: Optional ``CappedGELUConfig`` between layers.
        enable_nhwc: Use channels-last memory format.
        hpx_padding_mode: HEALPix padding backend passed to child modules.
        nside: Native face height/width for HEALPix padding.
        block_type: Which block implementation to build.
    """

    in_channels: int = 3
    out_channels: int = 1
    kernel_size: int = 3
    dilation: int = 1
    n_layers: int = 1
    upscale_factor: int = 4
    latent_channels: Optional[int] = None
    activation: Optional[CappedGELUConfig] = None
    enable_nhwc: bool = False
    hpx_padding_mode: Literal["earth2grid", "karlbauer", "isolatitude"] = "earth2grid"
    nside: Optional[int] = None
    block_type: Literal[
        "BasicConvBlock",
        "ConvNeXtBlock",
        "SymmetricConvNeXtBlock",
        "Multi_SymmetricConvNeXtBlock",
    ] = "BasicConvBlock"

    def build(self) -> nn.Module:
        """
        Instantiate the block selected by ``block_type``.

        Returns:
            The constructed block module.

        Raises:
            ValueError: If ``block_type`` is not one of the supported values.
        """
        shared = dict(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            activation=self.activation,
            enable_nhwc=self.enable_nhwc,
            hpx_padding_mode=self.hpx_padding_mode,
            nside=self.nside,
        )
        if self.block_type == "BasicConvBlock":
            # BasicConvBlock resolves ``latent_channels=None`` itself.
            return BasicConvBlock(
                n_layers=self.n_layers,
                latent_channels=self.latent_channels,
                **shared,
            )

        # The ConvNeXt variants need a concrete latent width. Resolve it locally
        # instead of assigning to ``self.latent_channels``. Mutating the config
        # made a later build of the same config (e.g. as BasicConvBlock) see 1
        # instead of None.
        latent_channels = 1 if self.latent_channels is None else self.latent_channels
        convnext_kwargs = dict(
            latent_channels=latent_channels,
            upscale_factor=self.upscale_factor,
            **shared,
        )
        if self.block_type == "ConvNeXtBlock":
            return ConvNeXtBlock(**convnext_kwargs)
        if self.block_type == "SymmetricConvNeXtBlock":
            return SymmetricConvNeXtBlock(**convnext_kwargs)
        if self.block_type == "Multi_SymmetricConvNeXtBlock":
            return Multi_SymmetricConvNeXtBlock(
                n_layers=self.n_layers, **convnext_kwargs
            )
        raise ValueError(f"Unsupported block type: {self.block_type}")
# --- Downsampling modules ---


class MaxPool(nn.Module):
    """Wrapper for applying Max Pooling with HEALPix or other tensor data."""

    def __init__(
        self,
        pooling: int = 2,
        enable_nhwc: bool = False,
        hpx_padding_mode: Literal[
            "earth2grid", "karlbauer", "isolatitude"
        ] = "earth2grid",
        nside: Optional[int] = None,
    ):
        """
        Args:
            pooling: ``MaxPool2d`` kernel size (and stride).
            enable_nhwc: Use channels-last memory format.
            hpx_padding_mode: HEALPix padding backend passed to ``HEALPixLayer``.
            nside: Native face height/width for HEALPix padding.
        """
        super().__init__()
        self.maxpool = HEALPixLayer(
            layer=nn.MaxPool2d,
            kernel_size=pooling,
            **_healpix_layer_kwargs(enable_nhwc, hpx_padding_mode, nside),
        )

    def forward(self, x: th.Tensor) -> th.Tensor:
        """
        Args:
            x: Input tensor ``[N * 12, C, H, W]``.

        Returns:
            Pooled tensor, spatially reduced by the pooling factor.
        """
        return self.maxpool(x)


class AvgPool(nn.Module):
    """Wrapper for applying Average Pooling with HEALPix or other tensor data."""

    def __init__(
        self,
        pooling: int = 2,
        enable_nhwc: bool = False,
        hpx_padding_mode: Literal[
            "earth2grid", "karlbauer", "isolatitude"
        ] = "earth2grid",
        nside: Optional[int] = None,
    ):
        """
        Args:
            pooling: ``AvgPool2d`` kernel size (and stride).
            enable_nhwc: Use channels-last memory format.
            hpx_padding_mode: HEALPix padding backend passed to ``HEALPixLayer``.
            nside: Native face height/width for HEALPix padding.
        """
        super().__init__()
        self.avgpool = HEALPixLayer(
            layer=nn.AvgPool2d,
            kernel_size=pooling,
            **_healpix_layer_kwargs(enable_nhwc, hpx_padding_mode, nside),
        )

    def forward(self, x: th.Tensor) -> th.Tensor:
        """
        Args:
            x: Input tensor ``[N * 12, C, H, W]``.

        Returns:
            Pooled tensor, spatially reduced by the pooling factor.
        """
        return self.avgpool(x)


class DealiasBlurConv2d(nn.Module):
    """Depthwise blur with a fixed, normalized kernel applied via functional ``conv2d``.

    The kernel is stored as a registered buffer named ``weight`` (not a parameter).
    """

    @staticmethod
    def _normalized_depthwise_blur_weights(
        resample_filter: Sequence[float], in_channels: int
    ) -> th.Tensor:
        """
        Build the depthwise blur kernel.

        Args:
            resample_filter: 1D filter taps of length ``m``.
            in_channels: Number of depthwise groups (one kernel copy per channel).

        Returns:
            Float32 tensor ``[in_channels, 1, m, m]`` holding the outer product
            of the filter with itself, normalized to sum to one.

        Raises:
            ValueError: If the filter is not one-dimensional.
        """
        f = th.as_tensor(list(resample_filter), dtype=th.float32)
        if f.ndim != 1:
            raise ValueError("resample_filter must be 1D")
        m = int(f.numel())
        f2d = f[:, None] * f[None, :]
        f2d = f2d / f2d.sum()
        # ``expand`` only creates a view; ``clone`` gives the buffer its own storage.
        return f2d.unsqueeze(0).unsqueeze(0).expand(in_channels, 1, m, m).clone()

    def __init__(
        self,
        in_channels: int,
        stride: int = 1,
        resample_filter: Sequence[float] | None = None,
        **kwargs,
    ):
        """
        Args:
            in_channels: Number of input channels (depthwise groups).
            stride: Stride of the depthwise blur convolution.
            resample_filter: 1D separable filter weights used to build the 2D
                kernel; defaults to ``[1.0, 2.0, 1.0]``.
            **kwargs: Accepted for API compatibility; not used. For example,
                ``DealiasedDownsample`` supplies conv-style arguments such as
                ``kernel_size`` and ``groups``.

        Raises:
            ValueError: If ``resample_filter`` is empty or sums to zero.
        """
        super().__init__()
        if resample_filter is None:
            resample_filter = [1.0, 2.0, 1.0]
        filt = tuple(float(x) for x in resample_filter)
        if len(filt) < 1:
            raise ValueError("resample_filter must be non-empty")
        if sum(filt) == 0:
            raise ValueError("resample_filter must not sum to zero")
        self.in_channels = in_channels
        self.stride = stride
        self.register_buffer(
            "weight",
            self._normalized_depthwise_blur_weights(filt, in_channels),
        )

    def forward(self, x: th.Tensor) -> th.Tensor:
        """
        Args:
            x: Input tensor ``[N, C, H, W]``.

        Returns:
            Depthwise-blurred tensor, strided-downsampled when ``stride > 1``.
        """
        return th.nn.functional.conv2d(
            x,
            # Follow the input's device/dtype (e.g. under reduced precision).
            self.weight.to(device=x.device, dtype=x.dtype),
            bias=None,
            stride=self.stride,
            padding=0,
            groups=self.in_channels,
        )


class DealiasedDownsample(nn.Module):
    """De-aliased downsampling via fixed depthwise blur stages (stride power of 2)."""

    def __init__(
        self,
        in_channels: int = 3,
        resample_filter: Sequence[float] | None = None,
        stride: int = 2,
        enable_nhwc: bool = False,
        hpx_padding_mode: Literal[
            "earth2grid", "karlbauer", "isolatitude"
        ] = "earth2grid",
        nside: Optional[int] = None,
    ):
        """
        Args:
            in_channels: Number of input channels.
            resample_filter: 1D filter weights for each blur stage.
            stride: Total downsampling factor (must be a power of two). It is
                realised as ``log2(stride)`` stride-2 blur stages; ``stride=1``
                yields an empty (identity) stack.
            enable_nhwc: Use channels-last memory format.
            hpx_padding_mode: HEALPix padding backend passed to ``HEALPixLayer``.
            nside: Native face height/width for HEALPix padding.

        Raises:
            ValueError: If the filter is empty or sums to zero, or if
                ``stride`` is not a positive power of two.
        """
        super().__init__()
        if resample_filter is None:
            resample_filter = [1.0, 2.0, 1.0]
        filt = tuple(float(x) for x in resample_filter)
        m = len(filt)
        if m < 1:
            raise ValueError("resample_filter must be non-empty")
        if sum(filt) == 0:
            raise ValueError("resample_filter must not sum to zero")
        if stride < 1 or (math.log2(stride) % 1) != 0:
            raise ValueError("stride must be a positive power of 2")
        n_layers = int(math.log2(stride))
        healpix_kwargs = _healpix_layer_kwargs(
            enable_nhwc,
            hpx_padding_mode,
            nside,
        )
        pool_layers = []
        for _ in range(n_layers):
            pool_layers.append(
                HEALPixLayer(
                    layer=DealiasBlurConv2d,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    kernel_size=m,
                    stride=2,
                    padding=0,
                    groups=in_channels,
                    bias=False,
                    dilation=1,
                    resample_filter=filt,
                    **healpix_kwargs,
                )
            )
        self.pool = nn.Sequential(*pool_layers)

    def forward(self, x: th.Tensor) -> th.Tensor:
        """
        Args:
            x: Input tensor ``[N * 12, C, H, W]``.

        Returns:
            Dealiased downsampled tensor.
        """
        return self.pool(x)


# --- Upsampling modules ---


class TransposedConvUpsample(nn.Module):
    """Wrapper for upsampling with a transposed convolution using HEALPix or other tensor data.

    This class wraps the `nn.ConvTranspose2d` class to handle tensor data with
    HEALPix or other geometry layers.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 1,
        upsampling: int = 2,
        activation: Optional[CappedGELUConfig] = None,
        enable_nhwc: bool = False,
        hpx_padding_mode: Literal[
            "earth2grid", "karlbauer", "isolatitude"
        ] = "earth2grid",
        nside: Optional[int] = None,
    ):
        """
        Args:
            in_channels: The number of input channels.
            out_channels: The number of output channels.
            upsampling: Kernel size and stride of the transposed convolution
                (i.e. the upsampling factor).
            activation: Optional activation config; its ``build()`` result is
                appended after the transposed convolution.
            enable_nhwc: Enable nhwc format, passed to wrapper.
            hpx_padding_mode: HEALPix padding backend passed to wrapper.
            nside: Native face height/width for HEALPix padding.
        """
        super().__init__()
        upsampler = [
            HEALPixLayer(
                layer=nn.ConvTranspose2d,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=upsampling,
                stride=upsampling,
                padding=0,
                **_healpix_layer_kwargs(
                    enable_nhwc,
                    hpx_padding_mode,
                    nside,
                ),
            )
        ]
        if activation is not None:
            upsampler.append(activation.build())
        self.upsampler = nn.Sequential(*upsampler)

    def forward(self, x):
        """Forward pass of the TransposedConvUpsample layer.

        Args:
            x: The values to upsample.

        Returns:
            th.Tensor: The upsampled values.
        """
        return self.upsampler(x)


class SmoothedInterpolate(nn.Module):
    """Interpolate, then apply a four-neighbour averaging smoother.

    The smoother kernel has ones at the four edge-adjacent positions of a
    3x3 stencil (zero centre) and the result is divided by 4.
    """

    def __init__(
        self,
        in_channels: int = 3,
        scale_factor: int = 2,
        mode: str = "nearest",
        trim_size: int = 0,
    ):
        """
        Args:
            in_channels: Number of channels for the depthwise smoother.
            scale_factor: Interpolation scale factor.
            mode: Interpolation mode passed to ``F.interpolate``.
            trim_size: Border pixels to crop after smoothing (removes edge artifacts).
        """
        super().__init__()
        self.in_channels = in_channels
        self.scale_factor = scale_factor
        self.mode = mode
        self.trim_size = trim_size
        self.interp = th.nn.functional.interpolate
        smoother_kernel = th.tensor(
            [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
        )
        smoother_kernel = smoother_kernel.unsqueeze(0).unsqueeze(0)
        smoother_kernel = smoother_kernel.repeat((in_channels, 1, 1, 1))
        self.register_buffer("smoother_kernel", smoother_kernel)

    def forward(self, x: th.Tensor) -> th.Tensor:
        """
        Args:
            x: Input tensor ``[N, C, H, W]``.

        Returns:
            Upsampled and smoothed tensor, optionally trimmed. The unpadded 3x3
            smoother shrinks each spatial side by 2 before trimming.
        """
        x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode)
        # Cast the kernel to the input's device/dtype, as DealiasBlurConv2d
        # does, so a dtype mismatch between input and buffer cannot break conv2d.
        kernel = self.smoother_kernel.to(device=x.device, dtype=x.dtype)
        x = (
            th.nn.functional.conv2d(
                x,
                kernel,
                padding=0,
                groups=self.in_channels,
            )
            / 4
        )
        if self.trim_size > 0:
            x = x[
                ...,
                self.trim_size : -self.trim_size,
                self.trim_size : -self.trim_size,
            ]
        return x


class SmoothedInterpolateConv(nn.Module):
    """Interpolate with seam padding, smoothing, then Conv2d on HEALPix data."""

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        kernel_size: int = 3,
        dilation: int = 1,
        scale_factor: int = 2,
        mode: str = "nearest",
        activation: Optional[nn.Module] = None,
        enable_nhwc: bool = False,
        hpx_padding_mode: Literal[
            "earth2grid", "karlbauer", "isolatitude"
        ] = "earth2grid",
        nside: Optional[int] = None,
        nside_after: Optional[int] = None,
    ):
        """
        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Convolution kernel size after interpolation.
            dilation: Convolution dilation (must be 1 for HEALPix resize).
            scale_factor: Interpolation scale factor.
            mode: Interpolation mode for the smoothed upsample step.
            activation: Optional activation module appended after the conv.
            enable_nhwc: Use channels-last memory format.
            hpx_padding_mode: HEALPix padding backend passed to ``HEALPixLayer``.
            nside: Face height/width before upsampling (isolatitude gather indices).
            nside_after: Face height/width after upsampling for the conv step.
                It is required when ``nside`` is set and ``hpx_padding_mode``
                is ``"isolatitude"``. For the other modes it defaults to
                ``nside * scale_factor`` when omitted.

        Raises:
            ValueError: If ``dilation > 1``, if ``nside_after`` is missing in
                isolatitude mode, or if ``nside_after != nside * scale_factor``.
        """
        super().__init__()
        if dilation > 1:
            raise ValueError(
                f"dilation > 1 is not supported for HEALPix resize convolutions, got {dilation}"
            )
        if nside is not None and nside_after is None:
            if hpx_padding_mode == "isolatitude":
                raise ValueError(
                    "SmoothedInterpolateConv requires nside_after when nside is set "
                    'and hpx_padding_mode="isolatitude"'
                )
            # Derive the post-upsampling face size. The previous fallback of
            # ``nside`` always failed the consistency check below whenever
            # scale_factor != 1.
            nside_after = nside * scale_factor
        if (
            nside is not None
            and nside_after is not None
            and nside_after != nside * scale_factor
        ):
            raise ValueError(
                f"nside_after ({nside_after}) must equal nside ({nside}) * "
                f"scale_factor ({scale_factor})"
            )
        trim_size = 1
        healpix_kwargs = _healpix_layer_kwargs(
            enable_nhwc,
            hpx_padding_mode,
            nside,
        )
        healpix_kwargs_after = _healpix_layer_kwargs(
            enable_nhwc,
            hpx_padding_mode,
            nside_after,
        )
        block = [
            HEALPixLayer(
                layer=SmoothedInterpolate,
                in_channels=in_channels,
                scale_factor=scale_factor,
                mode=mode,
                trim_size=trim_size,
                **healpix_kwargs,
            ),
            HEALPixLayer(
                layer=nn.Conv2d,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                dilation=dilation,
                **healpix_kwargs_after,
            ),
        ]
        if activation is not None:
            block.append(activation)
        self.block = nn.Sequential(*block)

    def forward(self, x: th.Tensor) -> th.Tensor:
        """
        Args:
            x: Input tensor ``[N * 12, C, H, W]``.

        Returns:
            Upsampled convolved tensor.
        """
        return self.block(x)


# --- Convolution stack modules ---


class BasicConvBlock(nn.Module):
    """Convolution block consisting of n subsequent convolutions and activations."""

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 1,
        kernel_size: int = 3,
        dilation: int = 1,
        n_layers: int = 1,
        latent_channels: Optional[int] = None,
        activation: Optional[CappedGELUConfig] = None,
        enable_nhwc: bool = False,
        hpx_padding_mode: Literal[
            "earth2grid", "karlbauer", "isolatitude"
        ] = "earth2grid",
        nside: Optional[int] = None,
    ):
        """
        Args:
            in_channels: The number of input channels.
            out_channels: The number of output channels.
            kernel_size: Size of the convolutional kernel.
            dilation: Spacing between kernel points, passed to nn.Conv2d.
            n_layers: Number of convolutional layers.
            latent_channels: Width of intermediate layers; defaults to
                ``max(in_channels, out_channels)``.
            activation: Optional activation config; built after every conv.
            enable_nhwc: Enable nhwc format, passed to wrapper.
            hpx_padding_mode: HEALPix padding backend passed to wrapper.
            nside: Native face height/width for HEALPix padding.
        """
        super().__init__()
        if latent_channels is None:
            latent_channels = max(in_channels, out_channels)
        convblock = []
        for n in range(n_layers):
            convblock.append(
                HEALPixLayer(
                    layer=th.nn.Conv2d,
                    in_channels=in_channels if n == 0 else latent_channels,
                    out_channels=out_channels if n == n_layers - 1 else latent_channels,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    **_healpix_layer_kwargs(
                        enable_nhwc,
                        hpx_padding_mode,
                        nside,
                    ),
                )
            )
            if activation is not None:
                convblock.append(activation.build())
        self.convblock = nn.Sequential(*convblock)

    def forward(self, x):
        """Forward pass of the BasicConvBlock.

        Args:
            x: Inputs to the forward pass.

        Returns:
            th.Tensor: Result of the forward pass.
        """
        return self.convblock(x)


class ConvNeXtBlock(nn.Module):
    """A modified ConvNeXt network block as described in the paper "A ConvNet
    for the 21st Century" (https://arxiv.org/pdf/2201.03545.pdf).

    This block consists of a series of convolutional layers with optional
    activation functions, and a residual connection.

    Parameters:
        skip_module: Identity when ``in_channels == out_channels``, otherwise a
            1x1 convolution aligning input and output channels for the residual.
        convblock: A sequential container of convolutional layers with optional
            activation functions.
    """

    def __init__(
        self,
        in_channels: int = 3,
        latent_channels: int = 1,
        out_channels: int = 1,
        kernel_size: int = 3,
        dilation: int = 1,
        upscale_factor: int = 4,
        activation: Optional[CappedGELUConfig] = None,
        enable_nhwc: bool = False,
        hpx_padding_mode: Literal[
            "earth2grid", "karlbauer", "isolatitude"
        ] = "earth2grid",
        nside: Optional[int] = None,
    ):
        """
        Initializes a ConvNeXtBlock instance with specified parameters.

        Args:
            in_channels: Number of input channels.
            latent_channels: Number of latent channels used in the block.
            out_channels: Number of output channels.
            kernel_size: Size of the convolutional kernels.
            dilation: Dilation rate for convolutions.
            upscale_factor: Factor by which to upscale the number of latent channels.
            activation: Configuration for the activation function used between layers.
            enable_nhwc: Whether to enable NHWC format.
            hpx_padding_mode: HEALPix padding backend passed to wrapper.
            nside: Native face height/width for HEALPix padding.
        """
        super().__init__()
        healpix_kwargs = _healpix_layer_kwargs(
            enable_nhwc,
            hpx_padding_mode,
            nside,
        )
        hidden_channels = int(latent_channels * upscale_factor)

        # Instantiate 1x1 conv to increase/decrease channel depth if necessary
        if in_channels == out_channels:
            self.skip_module = lambda x: x  # Identity-function required in forward pass
        else:
            self.skip_module = HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                **healpix_kwargs,
            )

        convblock = []
        # kernel_size x kernel_size convolution increasing channels
        convblock.append(
            HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=in_channels,
                out_channels=hidden_channels,
                kernel_size=kernel_size,
                dilation=dilation,
                **healpix_kwargs,
            )
        )
        if activation is not None:
            convblock.append(activation.build())
        # kernel_size x kernel_size convolution maintaining increased channels
        convblock.append(
            HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=hidden_channels,
                out_channels=hidden_channels,
                kernel_size=kernel_size,
                dilation=dilation,
                **healpix_kwargs,
            )
        )
        if activation is not None:
            convblock.append(activation.build())
        # Linear (1x1) postprocessing to the output channels
        convblock.append(
            HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=hidden_channels,
                out_channels=out_channels,
                kernel_size=1,
                **healpix_kwargs,
            )
        )
        self.convblock = nn.Sequential(*convblock)

    def forward(self, x):
        """Forward pass of the ConvNeXtBlock.

        Args:
            x: Input tensor.

        Returns:
            The result of the forward pass.
        """
        return self.skip_module(x) + self.convblock(x)


class DoubleConvNeXtBlock(nn.Module):
    """A variant of the ConvNeXt block that includes two sequential ConvNeXt
    blocks within a single module.

    Parameters:
        skip_module1: A module to align the input and intermediate channels for
            the first residual connection.
        skip_module2: A module to align the intermediate and output channels for
            the second residual connection.
        convblock1: A sequential container of convolutional layers for the first
            ConvNeXt block.
        convblock2: A sequential container of convolutional layers for the second
            ConvNeXt block.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 1,
        kernel_size: int = 3,
        dilation: int = 1,
        upscale_factor: int = 4,
        latent_channels: int = 1,
        activation: Optional[CappedGELUConfig] = None,
        enable_nhwc: bool = False,
        hpx_padding_mode: Literal[
            "earth2grid", "karlbauer", "isolatitude"
        ] = "earth2grid",
        nside: Optional[int] = None,
    ):
        """
        Initializes a DoubleConvNeXtBlock instance with specified parameters.

        Args:
            in_channels: Number of input channels (default is 3).
            out_channels: Number of output channels (default is 1).
            kernel_size: Size of the convolutional kernels (default is 3).
            dilation: Dilation rate for convolutions (default is 1).
            upscale_factor: Factor by which to upscale the number of latent channels (default is 4).
            latent_channels: Number of latent channels used in the block (default is 1).
            activation: Configuration for the activation function used between layers (default is None).
            enable_nhwc: Whether to enable NHWC format (default is False).
            hpx_padding_mode: HEALPix padding backend passed to wrapper.
            nside: Native face height/width for HEALPix padding.
        """
        super().__init__()
        healpix_kwargs = _healpix_layer_kwargs(
            enable_nhwc,
            hpx_padding_mode,
            nside,
        )
        latent = int(latent_channels)
        hidden = int(latent_channels * upscale_factor)

        if in_channels == latent:
            self.skip_module1 = lambda x: x  # Identity-function required in forward pass
        else:
            self.skip_module1 = HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=in_channels,
                out_channels=latent,
                kernel_size=1,
                **healpix_kwargs,
            )
        if out_channels == latent:
            self.skip_module2 = lambda x: x  # Identity-function required in forward pass
        else:
            self.skip_module2 = HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=latent,
                out_channels=out_channels,
                kernel_size=1,
                **healpix_kwargs,
            )

        # 1st ConvNeXt block; its output remains internal
        convblock1 = []
        # kernel_size x kernel_size convolution establishing latent channels
        convblock1.append(
            HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=in_channels,
                out_channels=latent,
                kernel_size=kernel_size,
                dilation=dilation,
                **healpix_kwargs,
            )
        )
        if activation is not None:
            convblock1.append(activation.build())
        # 1x1 convolution establishing increased channels
        convblock1.append(
            HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=latent,
                out_channels=hidden,
                kernel_size=1,
                dilation=dilation,
                **healpix_kwargs,
            )
        )
        if activation is not None:
            convblock1.append(activation.build())
        # 1x1 convolution returning to latent channels
        convblock1.append(
            HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=hidden,
                out_channels=latent,
                kernel_size=1,
                dilation=dilation,
                **healpix_kwargs,
            )
        )
        if activation is not None:
            convblock1.append(activation.build())
        self.convblock1 = nn.Sequential(*convblock1)

        # 2nd ConvNeXt block, takes the output of the first ConvNeXt block
        convblock2 = []
        # kernel_size x kernel_size convolution keeping latent channels
        convblock2.append(
            HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=latent,
                out_channels=latent,
                kernel_size=kernel_size,
                dilation=dilation,
                **healpix_kwargs,
            )
        )
        if activation is not None:
            convblock2.append(activation.build())
        # 1x1 convolution establishing increased channels
        convblock2.append(
            HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=latent,
                out_channels=hidden,
                kernel_size=1,
                dilation=dilation,
                **healpix_kwargs,
            )
        )
        if activation is not None:
            convblock2.append(activation.build())
        # 1x1 convolution reducing to output channels
        convblock2.append(
            HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=hidden,
                out_channels=out_channels,
                kernel_size=1,
                dilation=dilation,
                **healpix_kwargs,
            )
        )
        if activation is not None:
            convblock2.append(activation.build())
        self.convblock2 = nn.Sequential(*convblock2)

    def forward(self, x):
        """Forward pass of the DoubleConvNextBlock

        Args:
            x: inputs to the forward pass

        Returns:
            result of the forward pass
        """
        # internal convnext result
        x1 = self.skip_module1(x) + self.convblock1(x)
        # return second convnext result
        return self.skip_module2(x1) + self.convblock2(x1)


class SymmetricConvNeXtBlock(nn.Module):
    """A symmetric variant of the ConvNeXt block.

    The conv stack is laid out symmetrically:
    kernel_size conv (in -> latent), 1x1 (latent -> latent * upscale_factor),
    1x1 (back to latent), kernel_size conv (latent -> out).

    Parameters:
        skip_module: Identity when ``in_channels == latent_channels == out_channels``,
            otherwise a 1x1 convolution mapping input to output channels for
            the residual connection.
        convblock: A sequential container of the symmetric convolutional layers.
    """

    def __init__(
        self,
        in_channels: int = 3,
        latent_channels: int = 1,
        out_channels: int = 1,
        kernel_size: int = 3,
        dilation: int = 1,
        upscale_factor: int = 4,
        activation: Optional[CappedGELUConfig] = None,
        enable_nhwc: bool = False,
        hpx_padding_mode: Literal[
            "earth2grid", "karlbauer", "isolatitude"
        ] = "earth2grid",
        nside: Optional[int] = None,
    ):
        """
        Initializes a SymmetricConvNeXtBlock instance with specified parameters.

        Args:
            in_channels: Number of input channels (default is 3).
            latent_channels: Number of latent channels used in the block (default is 1).
            out_channels: Number of output channels (default is 1).
            kernel_size: Size of the convolutional kernels (default is 3).
            dilation: Dilation rate for convolutions (default is 1).
            upscale_factor: Upscale factor for the 1x1 expansion layer.
            activation: Configuration for the activation function used between layers (default is None).
            enable_nhwc: Whether to enable NHWC format (default is False).
            hpx_padding_mode: HEALPix padding backend passed to wrapper.
            nside: Native face height/width for HEALPix padding.
        """
        super().__init__()
        healpix_kwargs = _healpix_layer_kwargs(
            enable_nhwc,
            hpx_padding_mode,
            nside,
        )
        latent = int(latent_channels)
        hidden = int(latent_channels * upscale_factor)

        # The skip path must map in_channels -> out_channels. Previously the
        # identity was chosen whenever in_channels == latent_channels, which
        # mismatched the conv branch when out_channels differed. The latent
        # condition is kept so every previously valid configuration keeps the
        # same parameters.
        if in_channels == latent and in_channels == out_channels:
            self.skip_module = lambda x: x  # Identity-function required in forward pass
        else:
            self.skip_module = HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                **healpix_kwargs,
            )

        convblock = []
        # kernel_size x kernel_size convolution establishing latent channels
        convblock.append(
            HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=in_channels,
                out_channels=latent,
                kernel_size=kernel_size,
                dilation=dilation,
                **healpix_kwargs,
            )
        )
        if activation is not None:
            convblock.append(activation.build())
        # 1x1 convolution establishing increased channels
        convblock.append(
            HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=latent,
                out_channels=hidden,
                kernel_size=1,
                dilation=dilation,
                **healpix_kwargs,
            )
        )
        if activation is not None:
            convblock.append(activation.build())
        # 1x1 convolution returning to latent channels
        convblock.append(
            HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=hidden,
                out_channels=latent,
                kernel_size=1,
                dilation=dilation,
                **healpix_kwargs,
            )
        )
        if activation is not None:
            convblock.append(activation.build())
        # kernel_size x kernel_size convolution from latent to output channels
        convblock.append(
            HEALPixLayer(
                layer=th.nn.Conv2d,
                in_channels=latent,
                out_channels=out_channels,
                kernel_size=kernel_size,
                dilation=dilation,
                **healpix_kwargs,
            )
        )
        if activation is not None:
            convblock.append(activation.build())
        self.convblock = nn.Sequential(*convblock)

    def forward(self, x):
        """Forward pass of the SymmetricConvNextBlock

        Args:
            x: inputs to the forward pass

        Returns:
            result of the forward pass
        """
        # residual connection: (possibly reshaped) input plus conv block output
        return self.skip_module(x) + self.convblock(x)


class Multi_SymmetricConvNeXtBlock(nn.Module):
    """Serial wrapper of ``SymmetricConvNeXtBlock`` repeated ``n_layers`` times."""

    def __init__(
        self,
        in_channels: int = 3,
        latent_channels: int = 1,
        out_channels: int = 1,
        kernel_size: int = 3,
        dilation: int = 1,
        upscale_factor: int = 4,
        n_layers: int = 1,
        activation: Optional[CappedGELUConfig] = None,
        enable_nhwc: bool = False,
        hpx_padding_mode: Literal[
            "earth2grid", "karlbauer", "isolatitude"
        ] = "earth2grid",
        nside: Optional[int] = None,
    ):
        """
        Args:
            in_channels: Number of input channels (first block only).
            latent_channels: Latent channel width inside each symmetric block.
            out_channels: Number of output channels for every block.
            kernel_size: Convolution kernel size.
            dilation: Convolution dilation.
            upscale_factor: Channel upscale factor inside each block.
            n_layers: Number of stacked ``SymmetricConvNeXtBlock`` modules.
            activation: Optional ``CappedGELUConfig`` between layers.
            enable_nhwc: Use channels-last memory format.
            hpx_padding_mode: HEALPix padding backend passed to child blocks.
            nside: Native face height/width for HEALPix padding.
        """
        super().__init__()
        self.blocks = nn.ModuleList()
        for i in range(n_layers):
            # Only the first block sees the original input width.
            curr_in_channels = in_channels if i == 0 else out_channels
            self.blocks.append(
                SymmetricConvNeXtBlock(
                    in_channels=curr_in_channels,
                    latent_channels=latent_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    upscale_factor=upscale_factor,
                    activation=activation,
                    enable_nhwc=enable_nhwc,
                    hpx_padding_mode=hpx_padding_mode,
                    nside=nside,
                )
            )

    def forward(self, x):
        """
        Args:
            x: Input tensor.

        Returns:
            Output after ``n_layers`` symmetric ConvNeXt blocks.
        """
        out = x
        for block in self.blocks:
            out = block(out)
        return out


# --- Utilities ---


class Interpolate(nn.Module):
    """Helper class for interpolation.

    This class handles interpolation, storing scale factor and mode for
    `nn.functional.interpolate`.
    """

    def __init__(self, scale_factor: Union[int, Tuple], mode: str = "nearest"):
        """
        Args:
            scale_factor: Multiplier for spatial size, passed to `nn.functional.interpolate`.
            mode: Interpolation mode used for upsampling, passed to `nn.functional.interpolate`.
        """
        super().__init__()
        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, inputs):
        """Forward pass of the Interpolate layer.

        Args:
            inputs: Inputs to interpolate.

        Returns:
            th.Tensor: The interpolated values.
        """
        return self.interp(inputs, scale_factor=self.scale_factor, mode=self.mode)