# Source code for fme.ace.models.healpix.healpix_decoder

# flake8: noqa
# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import dataclasses
from typing import List, Optional, Sequence

import torch as th
import torch.nn as nn

from .healpix_blocks import ConvBlockConfig, RecurrentBlockConfig


@dataclasses.dataclass
class UNetDecoderConfig:
    """
    Configuration container describing how to build a UNetDecoder.

    Parameters:
        conv_block: Template config for the per-level convolutional blocks.
        up_sampling_block: Template config for the up-sampling blocks.
        output_layer: Template config for the final output layer.
        recurrent_block: Optional template config for a recurrent block run
            after each convolutional block; None disables recurrence.
        n_channels: Channel count for each decoder level, by default
            [34, 68, 136].
        n_layers: Number of layers in each convolutional block, by default
            [1, 2, 2].
        output_channels: Number of channels produced by the output layer,
            by default 1.
        dilations: Per-level dilation rates, by default None (the decoder
            then uses a dilation of 1 at every level).
        enable_nhwc: Whether to use the channels-last (NHWC) layout, by
            default False.
        enable_healpixpad: Whether to use HEALPix padding, by default False.
    """

    # NOTE: field order defines the positional order of the generated
    # __init__, so it must stay exactly as declared here.
    conv_block: ConvBlockConfig
    up_sampling_block: ConvBlockConfig
    output_layer: ConvBlockConfig
    recurrent_block: Optional[RecurrentBlockConfig] = None
    # Mutable defaults go through default_factory so instances never share
    # the same list object.
    n_channels: List[int] = dataclasses.field(default_factory=lambda: [34, 68, 136])
    n_layers: List[int] = dataclasses.field(default_factory=lambda: [1, 2, 2])
    output_channels: int = 1
    dilations: Optional[list] = None
    enable_nhwc: bool = False
    enable_healpixpad: bool = False

    def build(self) -> nn.Module:
        """
        Construct a UNetDecoder from the values held by this config.

        Returns:
            The newly constructed UNetDecoder module.
        """
        decoder = UNetDecoder(
            conv_block=self.conv_block,
            up_sampling_block=self.up_sampling_block,
            output_layer=self.output_layer,
            recurrent_block=self.recurrent_block,
            n_channels=self.n_channels,
            n_layers=self.n_layers,
            output_channels=self.output_channels,
            dilations=self.dilations,
            enable_nhwc=self.enable_nhwc,
            enable_healpixpad=self.enable_healpixpad,
        )
        return decoder
class UNetDecoder(nn.Module):
    """Generic UNetDecoder that can be applied to arbitrary meshes."""

    def __init__(
        self,
        conv_block: ConvBlockConfig,
        up_sampling_block: ConvBlockConfig,
        output_layer: ConvBlockConfig,
        recurrent_block: Optional[RecurrentBlockConfig] = None,
        n_channels: Sequence = (64, 32, 16),
        n_layers: Sequence = (1, 2, 2),
        output_channels: int = 1,
        dilations: Optional[list] = None,
        enable_nhwc: bool = False,
        enable_healpixpad: bool = False,
    ):
        """
        Initialize the UNetDecoder.

        The block configurations passed in are treated as templates. Each is
        shallow-copied before its per-level fields (channels, dilation, layer
        count, layout flags) are filled in. The caller's config objects are
        therefore left unchanged.

        Args:
            conv_block: Configuration for the convolutional block.
            up_sampling_block: Configuration for the upsampling block.
            output_layer: Configuration for the output layer.
            recurrent_block: Configuration for the recurrent block. If None,
                recurrent blocks are not used.
            n_channels: Sequence specifying the number of channels in each
                decoder level. Must be non-empty.
            n_layers: Sequence specifying the number of layers in each block.
                Must have at least len(n_channels) entries.
            output_channels: Number of output channels.
            dilations: List of dilations to use for the convolutional blocks.
                Defaults to 1 for every level. If given, must have at least
                len(n_channels) entries. Its last entry is also used for the
                output layer.
            enable_nhwc: If True, use channel last format.
            enable_healpixpad: If True, use the healpixpad library if installed.

        Raises:
            ValueError: If n_channels is empty, or if n_layers or dilations
                has fewer entries than n_channels.
        """
        super().__init__()
        self.channel_dim = 1

        n_levels = len(n_channels)
        if n_levels == 0:
            raise ValueError("n_channels must contain at least one entry")
        if len(n_layers) < n_levels:
            raise ValueError(
                f"n_layers has {len(n_layers)} entries but n_channels has "
                f"{n_levels}; one layer count is required per decoder level"
            )
        if dilations is None:
            dilations = [1] * n_levels
        elif len(dilations) < n_levels:
            raise ValueError(
                f"dilations has {len(dilations)} entries but n_channels has "
                f"{n_levels}; one dilation is required per decoder level"
            )

        # Work on private shallow copies: the loop below assigns per-level
        # attributes, and doing that on the caller's objects would leak the
        # last level's settings back into (possibly shared) configs.
        conv_block = copy.copy(conv_block)
        up_sampling_block = copy.copy(up_sampling_block)
        output_layer = copy.copy(output_layer)
        if recurrent_block is not None:
            recurrent_block = copy.copy(recurrent_block)

        decoder = []
        for n, curr_channel in enumerate(n_channels):
            up_sample_module = None
            if n != 0:
                # Upsampling keeps the channel count; the skip connection
                # concatenated in forward() is why the conv block below
                # expects twice as many input channels at these levels.
                up_sampling_block.in_channels = curr_channel
                up_sampling_block.out_channels = curr_channel
                up_sampling_block.enable_nhwc = enable_nhwc
                up_sampling_block.enable_healpixpad = enable_healpixpad
                up_sample_module = up_sampling_block.build()

            # Each conv block maps to the next level's width; the last level
            # keeps its own width.
            next_channel = n_channels[n + 1] if n < n_levels - 1 else n_channels[-1]

            conv_block.in_channels = curr_channel * 2 if n > 0 else curr_channel
            conv_block.latent_channels = curr_channel
            conv_block.out_channels = next_channel
            conv_block.dilation = dilations[n]
            conv_block.n_layers = n_layers[n]
            conv_block.enable_nhwc = enable_nhwc
            conv_block.enable_healpixpad = enable_healpixpad
            conv_module = conv_block.build()

            rec_module = None
            if recurrent_block is not None:
                recurrent_block.in_channels = next_channel
                recurrent_block.enable_healpixpad = enable_healpixpad
                rec_module = recurrent_block.build()

            decoder.append(
                nn.ModuleDict(
                    {
                        "upsamp": up_sample_module,
                        "conv": conv_module,
                        "recurrent": rec_module,
                    }
                )
            )
        self.decoder = nn.ModuleList(decoder)

        # The final conv block emits n_channels[-1] channels (next_channel on
        # the last level), which is what the output layer consumes.
        output_layer.in_channels = n_channels[-1]
        output_layer.out_channels = output_channels
        output_layer.dilation = dilations[-1]
        output_layer.enable_nhwc = enable_nhwc
        output_layer.enable_healpixpad = enable_healpixpad
        self.output_layer = output_layer.build()

    def forward(self, inputs):
        """
        Forward pass of the UNetDecoder.

        Args:
            inputs: Indexable sequence of feature tensors. Decoding starts from
                inputs[-1]. At every level n > 0 the upsampled features are
                concatenated with inputs[-1 - n] along ``self.channel_dim``
                before the level's conv block.

        Returns:
            The decoded values produced by the output layer.
        """
        x = inputs[-1]
        for n, layer in enumerate(self.decoder):
            if layer["upsamp"] is not None:
                up = layer["upsamp"](x)
                x = th.cat([up, inputs[-1 - n]], dim=self.channel_dim)
            x = layer["conv"](x)
            if layer["recurrent"] is not None:
                x = layer["recurrent"](x)
        return self.output_layer(x)

    def reset(self):
        """Reset the state of every recurrent block in the decoder."""
        for layer in self.decoder:
            if layer["recurrent"] is not None:
                layer["recurrent"].reset()