# flake8: noqa
# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import dataclasses
from typing import Literal, Optional, Tuple, Union, cast
import torch as th
import torch.nn as nn
from .healpix_activations import CappedGELUConfig
from .healpix_layers import HEALPixLayer
# RECURRENT BLOCKS
[docs]@dataclasses.dataclass
class RecurrentBlockConfig:
"""
Configuration for the recurrent block.
Parameters:
in_channels: Number of input channels, default is 3.
kernel_size: Size of the kernel, default is 1.
enable_nhwc: Flag to enable NHWC data format, default is False.
enable_healpixpad: Flag to enable HEALPix padding, default is False.
block_type: Type of recurrent block, either "ConvGRUBlock" or "ConvLSTMBlock",
default is "ConvGRUBlock".
"""
in_channels: int = 3
kernel_size: int = 1
enable_nhwc: bool = False
enable_healpixpad: bool = False
block_type: Literal["ConvGRUBlock", "ConvLSTMBlock"] = "ConvGRUBlock"
[docs] def build(self) -> nn.Module:
"""
Builds the recurrent block model.
Returns:
Recurrent block.
"""
if self.block_type == "ConvGRUBlock":
return ConvGRUBlock(
in_channels=self.in_channels,
kernel_size=self.kernel_size,
enable_nhwc=self.enable_nhwc,
enable_healpixpad=self.enable_healpixpad,
)
elif self.block_type == "ConvLSTMBlock":
return ConvLSTMBlock(
in_channels=self.in_channels,
kernel_size=self.kernel_size,
enable_nhwc=self.enable_nhwc,
enable_healpixpad=self.enable_healpixpad,
)
else:
raise ValueError(f"Unsupported block type: {self.block_type}")
[docs]@dataclasses.dataclass
class ConvBlockConfig:
"""
Configuration for the convolutional block.
Parameters:
in_channels: Number of input channels, default is 3.
out_channels: Number of output channels, default is 1.
kernel_size: Size of the kernel, default is 3.
dilation: Dilation rate, default is 1.
n_layers: Number of layers, default is 1.
upsampling: Upsampling factor for TransposedConvUpsample, default is 2.
upscale_factor: Upscale factor for ConvNeXtBlock and SymmetricConvNeXtBlock,
default is 4.
latent_channels: Number of latent channels, default is None.
activation: Activation configuration, default is None.
enable_nhwc: Flag to enable NHWC data format, default is False.
enable_healpixpad: Flag to enable HEALPix padding, default is False.
block_type: Type of block, default is "BasicConvBlock".
"""
in_channels: int = 3
out_channels: int = 1
kernel_size: int = 3
dilation: int = 1
n_layers: int = 1
upsampling: int = 2
upscale_factor: int = 4
latent_channels: Optional[int] = None
activation: Optional[CappedGELUConfig] = None
enable_nhwc: bool = False
enable_healpixpad: bool = False
block_type: Literal[
"BasicConvBlock",
"ConvNeXtBlock",
"SymmetricConvNeXtBlock",
"TransposedConvUpsample",
] = "BasicConvBlock"
[docs] def build(self) -> nn.Module:
"""
Builds the convolutional block model.
Returns:
Convolutional block model.
"""
if self.block_type == "BasicConvBlock":
return BasicConvBlock(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=self.kernel_size,
dilation=self.dilation,
n_layers=self.n_layers,
latent_channels=self.latent_channels,
activation=self.activation,
enable_nhwc=self.enable_nhwc,
enable_healpixpad=self.enable_healpixpad,
)
elif self.block_type == "ConvNeXtBlock":
if self.latent_channels is None:
self.latent_channels = 1
return ConvNeXtBlock(
in_channels=self.in_channels,
latent_channels=cast(int, self.latent_channels),
out_channels=self.out_channels,
kernel_size=self.kernel_size,
dilation=self.dilation,
upscale_factor=self.upscale_factor,
activation=self.activation,
enable_nhwc=self.enable_nhwc,
enable_healpixpad=self.enable_healpixpad,
)
elif self.block_type == "SymmetricConvNeXtBlock":
if self.latent_channels is None:
self.latent_channels = 1
return SymmetricConvNeXtBlock(
in_channels=self.in_channels,
latent_channels=cast(int, self.latent_channels),
out_channels=self.out_channels,
kernel_size=self.kernel_size,
dilation=self.dilation,
upscale_factor=self.upscale_factor,
activation=self.activation,
enable_nhwc=self.enable_nhwc,
enable_healpixpad=self.enable_healpixpad,
)
elif self.block_type == "TransposedConvUpsample":
return TransposedConvUpsample(
in_channels=self.in_channels,
out_channels=self.out_channels,
upsampling=self.upsampling,
activation=self.activation,
enable_nhwc=self.enable_nhwc,
enable_healpixpad=self.enable_healpixpad,
)
else:
raise ValueError(f"Unsupported block type: {self.block_type}")
class ConvGRUBlock(nn.Module):
"""Class that implements a Convolutional GRU.
Code modified from:
https://github.com/happyjin/ConvGRU-pytorch/blob/master/convGRU.py
"""
def __init__(
self,
in_channels=3,
kernel_size=1,
enable_nhwc=False,
enable_healpixpad=False,
):
"""
Args:
in_channels: The number of input channels.
kernel_size: Size of the convolutional kernel.
enable_nhwc: Enable nhwc format, passed to wrapper.
enable_healpixpad: If HEALPixPadding should be enabled, passed to wrapper.
"""
super().__init__()
self.channels = in_channels
self.conv_gates = HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=in_channels + self.channels,
out_channels=2 * self.channels, # for update_gate, reset_gate respectively
kernel_size=kernel_size,
padding="same",
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
self.conv_can = HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=in_channels + self.channels,
out_channels=self.channels, # for candidate neural memory
kernel_size=kernel_size,
padding="same",
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
self.h = th.zeros(1, 1, 1, 1)
def forward(self, inputs):
"""Forward pass of the ConvGRUBlock.
Args:
inputs: Input to the forward pass.
Returns:
th.Tensor: Result of the forward pass.
"""
if inputs.shape != self.h.shape:
self.h = th.zeros_like(inputs)
combined = th.cat([inputs, self.h], dim=1)
combined_conv = self.conv_gates(combined)
gamma, beta = th.split(combined_conv, self.channels, dim=1)
reset_gate = th.sigmoid(gamma)
update_gate = th.sigmoid(beta)
combined = th.cat([inputs, reset_gate * self.h], dim=1)
cc_cnm = self.conv_can(combined)
cnm = th.tanh(cc_cnm)
h_next = (1 - update_gate) * self.h + update_gate * cnm
self.h = h_next
return inputs + h_next
def reset(self):
"""Reset the update gates."""
self.h = th.zeros_like(self.h)
class ConvLSTMBlock(nn.Module):
"""Convolutional LSTM block."""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 1,
latent_channels: int = 1,
kernel_size: int = 3,
downscale_factor: int = 4,
upscale_factor: int = 4,
n_layers: int = 1,
latent_conv_size: int = 3, # Add latent_conv_size parameter
dilation: int = 1,
activation: nn.Module = None,
enable_nhwc: bool = False,
enable_healpixpad: bool = False,
):
"""
Args:
in_channels: The number of input channels.
out_channels: The number of output channels.
latent_channels: Number of latent channels.
kernel_size: Size of the convolutional kernel.
downscale_factor: Downscale factor.
upscale_factor: Upscale factor.
n_layers: Number of layers.
latent_conv_size: Size of latent convolution.
dilation: Spacing between kernel points.
activation: Activation function.
enable_nhwc: Enable nhwc format.
enable_healpixpad: If HEALPixPadding should be enabled.
"""
super().__init__()
# Instantiate 1x1 conv to increase/decrease channel depth if necessary
# Skip connection for output
if in_channels == out_channels:
self.skip_module = lambda x: x # Identity-function required in forward pass
else:
self.skip_module = HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=in_channels,
out_channels=in_channels, # out channels describes the space of the output of conv here; but we have the output of LSTM which is the input layer size
kernel_size=1,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
# Convolution block
convblock = []
# 3x3 convolution increasing channels
convblock.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=in_channels
* 2, # accounts for the h layer, which is concatenated before convolution runs
out_channels=int(latent_channels * upscale_factor),
kernel_size=kernel_size,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock.append(activation)
# 3x3 convolution maintaining increased channels
convblock.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=int(latent_channels * upscale_factor),
out_channels=int(latent_channels * upscale_factor),
kernel_size=kernel_size,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock.append(activation)
# Now for the LSTM bit
self.channels = in_channels
self.lstm_gates = HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=latent_channels * upscale_factor,
out_channels=self.channels
* 4, # for input_gate, forget_gate, cell_gate, output_gate respectively (LSTM)
kernel_size=kernel_size,
padding="same",
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
self.h = th.zeros(1, 1, 1, 1)
self.c = th.zeros(1, 1, 1, 1)
self.convblock = nn.Sequential(*convblock)
def forward(self, inputs):
"""Forward pass of the ConvLSTMBlock.
Args:
x: Inputs to the forward pass.
Returns:
th.Tensor: Result of the forward pass.
"""
if inputs.shape != self.h.shape:
self.h = th.zeros_like(inputs)
self.c = th.zeros_like(inputs)
combined = th.cat([inputs, self.h], dim=1)
conv_outputs = self.convblock(combined)
lstm_gates = self.lstm_gates(conv_outputs)
# Split the combined_conv into input_gate, forget_gate, cell_gate, output_gate
i, f, c_hat, o = th.split(lstm_gates, self.channels, dim=1)
input_gate = th.sigmoid(i)
forget_gate = th.sigmoid(f)
cell_gate = th.tanh(c_hat)
output_gate = th.sigmoid(o)
self.c = forget_gate * self.c + input_gate * cell_gate
self.h = output_gate * th.tanh(self.c)
skip_connection = self.skip_module(inputs)
return skip_connection + self.h
def reset(self):
self.h = th.zeros_like(self.h)
self.c = th.zeros_like(self.c)
# CONV BLOCKS
class BasicConvBlock(nn.Module):
"""Convolution block consisting of n subsequent convolutions and activations."""
def __init__(
self,
in_channels=3,
out_channels=1,
kernel_size=3,
dilation=1,
n_layers=1,
latent_channels=None,
activation=None,
enable_nhwc=False,
enable_healpixpad=False,
):
"""
Args:
in_channels: The number of input channels.
out_channels: The number of output channels.
kernel_size: Size of the convolutional kernel.
dilation: Spacing between kernel points, passed to nn.Conv2d.
n_layers: Number of convolutional layers.
latent_channels: Number of latent channels.
activation: ModuleConfig for activation function to use.
enable_nhwc: Enable nhwc format, passed to wrapper.
enable_healpixpad:: If HEALPixPadding should be enabled, passed to wrapper.
"""
super().__init__()
if latent_channels is None:
latent_channels = max(in_channels, out_channels)
convblock = []
for n in range(n_layers):
convblock.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=in_channels if n == 0 else latent_channels,
out_channels=out_channels if n == n_layers - 1 else latent_channels,
kernel_size=kernel_size,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock.append(activation.build())
self.convblock = nn.Sequential(*convblock)
def forward(self, x):
"""Forward pass of the BasicConvBlock.
Args:
x: Inputs to the forward pass.
Returns:
th.Tensor: Result of the forward pass.
"""
return self.convblock(x)
class ConvNeXtBlock(nn.Module):
"""A modified ConvNeXt network block as described in the paper
"A ConvNet for the 21st Century" (https://arxiv.org/pdf/2201.03545.pdf).
This block consists of a series of convolutional layers with optional activation functions,
and a residual connection.
Parameters:
skip_module: A module to align the input and output channels for the residual connection.
convblock: A sequential container of convolutional layers with optional activation functions.
"""
def __init__(
self,
in_channels: int = 3,
latent_channels: int = 1,
out_channels: int = 1,
kernel_size: int = 3,
dilation: int = 1,
upscale_factor: int = 4,
activation: Optional[CappedGELUConfig] = None,
enable_nhwc: bool = False,
enable_healpixpad: bool = False,
):
"""
Initializes a ConvNeXtBlock instance with specified parameters.
Args:
in_channels: Number of input channels.
latent_channels: Number of latent channels used in the block.
out_channels: Number of output channels.
kernel_size: Size of the convolutional kernels.
dilation: Dilation rate for convolutions.
upscale_factor: Factor by which to upscale the number of latent channels.
activation: Configuration for the activation function used between layers.
enable_nhwc: Whether to enable NHWC format.
enable_healpixpad: Whether to enable HEALPixPadding.
"""
super().__init__()
# Instantiate 1x1 conv to increase/decrease channel depth if necessary
if in_channels == out_channels:
self.skip_module = lambda x: x # Identity-function required in forward pass
else:
self.skip_module = HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
# Convolution block
convblock = []
# 3x3 convolution increasing channels
convblock.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=in_channels,
out_channels=int(latent_channels * upscale_factor),
kernel_size=kernel_size,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock.append(activation.build())
# 3x3 convolution maintaining increased channels
convblock.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=int(latent_channels * upscale_factor),
out_channels=int(latent_channels * upscale_factor),
kernel_size=kernel_size,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock.append(activation.build())
# Linear postprocessing
convblock.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=int(latent_channels * upscale_factor),
out_channels=out_channels,
kernel_size=1,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
self.convblock = nn.Sequential(*convblock)
def forward(self, x):
"""Forward pass of the ConvNeXtBlock.
Args:
x: Input tensor.
Returns:
The result of the forward pass.
"""
return self.skip_module(x) + self.convblock(x)
class DoubleConvNeXtBlock(nn.Module):
"""A variant of the ConvNeXt block that includes two sequential ConvNeXt blocks within a single module.
Parameters:
skip_module1: A module to align the input and intermediate channels for the first residual connection.
skip_module2: A module to align the intermediate and output channels for the second residual connection.
convblock1: A sequential container of convolutional layers for the first ConvNeXt block.
convblock2: A sequential container of convolutional layers for the second ConvNeXt block.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 1,
kernel_size: int = 3,
dilation: int = 1,
upscale_factor: int = 4,
latent_channels: int = 1,
activation: Optional[CappedGELUConfig] = None,
enable_nhwc: bool = False,
enable_healpixpad: bool = False,
):
"""
Initializes a DoubleConvNeXtBlock instance with specified parameters.
Args:
in_channels: Number of input channels (default is 3).
out_channels: Number of output channels (default is 1).
kernel_size: Size of the convolutional kernels (default is 3).
dilation: Dilation rate for convolutions (default is 1).
upscale_factor: Factor by which to upscale the number of latent channels (default is 4).
latent_channels: Number of latent channels used in the block (default is 1).
activation: Configuration for the activation function used between layers (default is None).
enable_nhwc: Whether to enable NHWC format (default is False).
enable_healpixpad: Whether to enable HEALPixPadding (default is False).
"""
super().__init__()
if in_channels == int(latent_channels):
self.skip_module1 = (
lambda x: x
) # Identity-function required in forward pass
else:
self.skip_module1 = HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=in_channels,
out_channels=int(latent_channels),
kernel_size=1,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
if out_channels == int(latent_channels):
self.skip_module2 = (
lambda x: x
) # Identity-function required in forward pass
else:
self.skip_module2 = HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=int(latent_channels),
out_channels=out_channels,
kernel_size=1,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
# 1st ConvNeXt block, the output of this one remains internal
convblock1 = []
# 3x3 convolution establishing latent channels channels
convblock1.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=in_channels,
out_channels=int(latent_channels),
kernel_size=kernel_size,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock1.append(activation.build())
# 1x1 convolution establishing increased channels
convblock1.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=int(latent_channels),
out_channels=int(latent_channels * upscale_factor),
kernel_size=1,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock1.append(activation.build())
# 1x1 convolution returning to latent channels
convblock1.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=int(latent_channels * upscale_factor),
out_channels=int(latent_channels),
kernel_size=1,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock1.append(activation.build())
self.convblock1 = nn.Sequential(*convblock1)
# 2nd ConNeXt block, takes the output of the first convnext block
convblock2 = []
# 3x3 convolution establishing latent channels channels
convblock2.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=int(latent_channels),
out_channels=int(latent_channels),
kernel_size=kernel_size,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock2.append(activation.build())
# 1x1 convolution establishing increased channels
convblock2.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=int(latent_channels),
out_channels=int(latent_channels * upscale_factor),
kernel_size=1,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock2.append(activation.build())
# 1x1 convolution reducing to output channels
convblock2.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=int(latent_channels * upscale_factor),
out_channels=out_channels,
kernel_size=1,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock2.append(activation.build())
self.convblock2 = nn.Sequential(*convblock2)
def forward(self, x):
"""Forward pass of the DoubleConvNextBlock
Args:
x: inputs to the forward pass
Returns:
result of the forward pass
"""
# internal convnext result
x1 = self.skip_module1(x) + self.convblock1(x)
# return second convnext result
return self.skip_module2(x1) + self.convblock2(x1)
class SymmetricConvNeXtBlock(nn.Module):
"""A symmetric variant of the ConvNeXt block, with convolutional layers mirrored
around a central axis for symmetric feature extraction.
Parameters:
skip_module1: A module to align the input and intermediate channels for the first residual connection.
skip_module2: A module to align the intermediate and output channels for the second residual connection.
convblock1: A sequential container of convolutional layers for the symmetric ConvNeXt block.
"""
def __init__(
self,
in_channels: int = 3,
latent_channels: int = 1,
out_channels: int = 1,
kernel_size: int = 3,
dilation: int = 1,
upscale_factor: int = 4,
activation: Optional[CappedGELUConfig] = None,
enable_nhwc: bool = False,
enable_healpixpad: bool = False,
):
"""
Initializes a SymmetricConvNeXtBlock instance with specified parameters.
Args:
in_channels: Number of input channels (default is 3).
out_channels: Number of output channels (default is 1).
kernel_size: Size of the convolutional kernels (default is 3).
dilation: Dilation rate for convolutions (default is 1).
upscale_factor: Upscale factor.
latent_channels: Number of latent channels used in the block (default is 1).
activation: Configuration for the activation function used between layers (default is None).
enable_nhwc: Whether to enable NHWC format (default is False).
enable_healpixpad: Whether to enable HEALPixPadding (default is False).
"""
if in_channels == int(latent_channels):
self.skip_module = lambda x: x # Identity-function required in forward pass
else:
self.skip_module = HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
# 1st ConvNeXt block, the output of this one remains internal
convblock = []
# 3x3 convolution establishing latent channels channels
convblock.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=in_channels,
out_channels=int(latent_channels),
kernel_size=kernel_size,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock.append(activation.build())
# 1x1 convolution establishing increased channels
convblock.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=int(latent_channels),
out_channels=int(latent_channels * upscale_factor),
kernel_size=1,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock.append(activation.build())
# 1x1 convolution returning to latent channels
convblock.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=int(latent_channels * upscale_factor),
out_channels=int(latent_channels),
kernel_size=1,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock.append(activation.build())
# 3x3 convolution from latent channels to latent channels
convblock.append(
HEALPixLayer(
layer=th.nn.Conv2d,
in_channels=int(latent_channels),
out_channels=out_channels, # int(latent_channels),
kernel_size=kernel_size,
dilation=dilation,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
convblock.append(activation.build())
self.convblock = nn.Sequential(*convblock)
def forward(self, x):
"""Forward pass of the SymmetricConvNextBlock
Args:
x: inputs to the forward pass
Returns:
result of the forward pass
"""
# residual connection with reshaped inpute and output of conv block
return self.skip_module(x) + self.convblock(x)
class TransposedConvUpsample(nn.Module):
"""Wrapper for upsampling with a transposed convolution using HEALPix or other tensor data.
This class wraps the `nn.ConvTranspose2d` class to handle tensor data with
HEALPix or other geometry layers.
"""
def __init__(
self,
in_channels: int = 3,
out_channels: int = 1,
upsampling: int = 2,
activation: Optional[CappedGELUConfig] = None,
enable_nhwc: bool = False,
enable_healpixpad: bool = False,
):
"""
Args:
in_channels: The number of input channels.
out_channels: The number of output channels.
upsampling: Size used for upsampling.
activation: ModuleConfig for the activation function used in upsampling.
enable_nhwc: Enable nhwc format, passed to wrapper.
enable_healpixpad: If HEALPixPadding should be enabled, passed to wrapper.
"""
super().__init__()
upsampler = []
# Upsample transpose conv
upsampler.append(
HEALPixLayer(
layer=nn.ConvTranspose2d,
in_channels=in_channels,
out_channels=out_channels,
kernel_size=upsampling,
stride=upsampling,
padding=0,
enable_nhwc=enable_nhwc,
enable_healpixpad=enable_healpixpad,
)
)
if activation is not None:
upsampler.append(activation.build())
self.upsampler = nn.Sequential(*upsampler)
def forward(self, x):
"""Forward pass of the TransposedConvUpsample layer.
Args:
x: The values to upsample.
Returns:
th.Tensor: The upsampled values.
"""
return self.upsampler(x)
# Helpers
class Interpolate(nn.Module):
"""Helper class for interpolation.
This class handles interpolation, storing scale factor and mode for
`nn.functional.interpolate`.
"""
def __init__(self, scale_factor: Union[int, Tuple], mode: str = "nearest"):
"""
Args:
scale_factor: Multiplier for spatial size, passed to `nn.functional.interpolate`.
mode, : Interpolation mode used for upsampling, passed to `nn.functional.interpolate`.
"""
super().__init__()
self.interp = nn.functional.interpolate
self.scale_factor = scale_factor
self.mode = mode
def forward(self, inputs):
"""Forward pass of the Interpolate layer.
Args:
inputs: Inputs to interpolate.
Returns:
th.Tensor: The interpolated values.
"""
return self.interp(inputs, scale_factor=self.scale_factor, mode=self.mode)