Source code for fme.ace.models.healpix.healpix_decoder

# flake8: noqa
# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from typing import List, Literal, Optional, Sequence

import torch as th
import torch.nn as nn

from .healpix_blocks import ConvBlockConfig, UpsamplingBlockConfig


[docs]@dataclasses.dataclass class UNetDecoderConfig: """ Configuration for the UNet Decoder. Parameters: conv_block: Configuration for the convolutional block. up_sampling_block: Configuration for the spatial upsampling block (transpose conv, smoothed interpolate+conv, etc.). output_layer: Configuration for the output layer block. n_channels: Number of channels for each layer, by default (34, 68, 136). n_layers: Number of layers in each block, by default (1, 2, 2). output_channels: Number of output channels, by default 1. dilations: List of dilation rates for the layers, by default None. enable_nhwc: Flag to enable NHWC data format, by default False. hpx_padding_mode: HEALPix padding backend (``"earth2grid"``, ``"karlbauer"``, or ``"isolatitude"``), by default ``"earth2grid"``. nside: Face height/width per decoder level (shallowest to deepest), or ``None`` to omit per-level padding resolution. """ conv_block: ConvBlockConfig up_sampling_block: UpsamplingBlockConfig output_layer: ConvBlockConfig n_channels: List[int] = dataclasses.field(default_factory=lambda: [34, 68, 136]) n_layers: List[int] = dataclasses.field(default_factory=lambda: [1, 2, 2]) output_channels: int = 1 dilations: Optional[list] = None enable_nhwc: bool = False hpx_padding_mode: Literal["earth2grid", "karlbauer", "isolatitude"] = "earth2grid" nside: Optional[Sequence[int]] = None
[docs] def build(self) -> nn.Module: """ Builds the UNet Decoder model. Returns: UNet Decoder model. """ return UNetDecoder( conv_block=self.conv_block, up_sampling_block=self.up_sampling_block, output_layer=self.output_layer, n_channels=self.n_channels, n_layers=self.n_layers, output_channels=self.output_channels, dilations=self.dilations, enable_nhwc=self.enable_nhwc, hpx_padding_mode=self.hpx_padding_mode, nside=self.nside, )
class UNetDecoder(nn.Module): """Generic UNetDecoder that can be applied to arbitrary meshes.""" def __init__( self, conv_block: ConvBlockConfig, up_sampling_block: UpsamplingBlockConfig, output_layer: ConvBlockConfig, n_channels: Sequence = (64, 32, 16), n_layers: Sequence = (1, 2, 2), output_channels: int = 1, dilations: Optional[list] = None, enable_nhwc: bool = False, hpx_padding_mode: Literal[ "earth2grid", "karlbauer", "isolatitude" ] = "earth2grid", nside: Optional[Sequence[int]] = None, ): """ Initialize the UNetDecoder. Args: conv_block: Configuration for the convolutional block. up_sampling_block: Configuration for the upsampling block. output_layer: Configuration for the output layer. n_channels: Sequence specifying the number of channels in each decoder layer. n_layers: Sequence specifying the number of layers in each block. output_channels: Number of output channels. dilations: List of dilations to use for the convolutional blocks. enable_nhwc: If True, use channel last format. hpx_padding_mode: HEALPix padding backend. Default ``"earth2grid"``; also supports ``"karlbauer"`` and ``"isolatitude"``. nside: Face height/width per level (shallowest to deepest). Length must match ``len(n_channels)`` when set. Decoder stage ``n`` (deepest first) uses ``nside[-(n + 1)]``. """ super().__init__() self.channel_dim = 1 if dilations is None: dilations = [1 for _ in range(len(n_channels))] nside_levels: Optional[tuple[int, ...]] = None if nside is not None: nside_levels = tuple(int(v) for v in nside) if len(nside_levels) != len(n_channels): raise ValueError( f"nside length must match decoder levels; got {len(nside_levels)} " f"vs {len(n_channels)}" ) conv_tpl = dataclasses.replace(conv_block) up_tpl = dataclasses.replace(up_sampling_block) up_factor = up_tpl.stride n_levels = len(n_channels) self.decoder = [] for n, curr_channel in enumerate(n_channels): up_sample_module = None level_nside = ( None if nside_levels is None else nside_levels[n_levels - 1 - n] ) if n != 0: if nside_levels is not None: before = nside_levels[n_levels - n] after = level_nside if before * up_factor != after: raise ValueError( f"decoder nside upsample: nside[{n_levels - 1 - n}]={after} " f"must equal nside[{n_levels - n}] * upsample factor " f"({up_factor}), but nside[{n_levels - n}]={before}" ) up_cfg = dataclasses.replace( up_tpl, in_channels=curr_channel, out_channels=curr_channel, enable_nhwc=enable_nhwc, hpx_padding_mode=hpx_padding_mode, nside=None if nside_levels is None else nside_levels[n_levels - n], nside_after=level_nside, ) up_sample_module = up_cfg.build() next_channel = ( n_channels[n + 1] if n < len(n_channels) - 1 else n_channels[-1] ) conv_cfg = dataclasses.replace( conv_tpl, in_channels=curr_channel * 2 if n > 0 else curr_channel, latent_channels=curr_channel, out_channels=next_channel, dilation=dilations[n], n_layers=n_layers[n], enable_nhwc=enable_nhwc, hpx_padding_mode=hpx_padding_mode, nside=level_nside, ) conv_module = conv_cfg.build() self.decoder.append( nn.ModuleDict( { "upsamp": up_sample_module, "conv": conv_module, } ) ) self.decoder = nn.ModuleList(self.decoder) out_nside = None if nside_levels is None else nside_levels[0] out_cfg = dataclasses.replace( output_layer, in_channels=curr_channel, out_channels=output_channels, dilation=dilations[-1], enable_nhwc=enable_nhwc, hpx_padding_mode=hpx_padding_mode, nside=out_nside, ) self.output_layer = out_cfg.build() def forward(self, inputs: Sequence[th.Tensor]) -> th.Tensor: """ Forward pass of the UNetDecoder. Args: inputs: Sequence of tensors, one for each decoder level. Returns: The decoded values. """ x = inputs[-1] for n, layer in enumerate(self.decoder): if layer["upsamp"] is not None: up = layer["upsamp"](x) x = th.cat([up, inputs[-1 - n]], dim=self.channel_dim) x = layer["conv"](x) return self.output_layer(x)