Source code for fme.ace.models.healpix.healpix_activations

# flake8: noqa
# Copied from https://github.com/NVIDIA/modulus/commit/89a6091bd21edce7be4e0539cbd91507004faf08
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
from typing import Literal

import torch as th
import torch.nn as nn

from .healpix_layers import HEALPixLayer

# DOWNSAMPLING BLOCKS


class MaxPool(nn.Module):
    """Max-pooling layer for HEALPix (or other) tensor data.

    Builds an ``nn.MaxPool2d`` through ``HEALPixLayer`` so that pooling is
    applied via that geometry wrapper rather than directly.
    """

    def __init__(
        self,
        pooling: int = 2,
        enable_nhwc: bool = False,
        enable_healpixpad: bool = False,
    ):
        """
        Args:
            pooling (int, optional): Kernel size handed to the wrapped
                ``nn.MaxPool2d`` (passed to ``HEALPixLayer`` as ``kernel_size``).
            enable_nhwc (bool, optional): NHWC-format flag, forwarded to
                ``HEALPixLayer``.
            enable_healpixpad (bool, optional): HEALPix-padding flag,
                forwarded to ``HEALPixLayer``.
        """
        super().__init__()
        layer_kwargs = dict(
            kernel_size=pooling,
            enable_nhwc=enable_nhwc,
            enable_healpixpad=enable_healpixpad,
        )
        self.maxpool = HEALPixLayer(layer=nn.MaxPool2d, **layer_kwargs)

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Apply max pooling to ``x``.

        Args:
            x: Tensor to pool.

        Returns:
            The output of the wrapped max-pooling layer.
        """
        pooled = self.maxpool(x)
        return pooled


class AvgPool(nn.Module):
    """Average-pooling layer for HEALPix (or other) tensor data.

    Builds an ``nn.AvgPool2d`` through ``HEALPixLayer`` so that pooling is
    applied via that geometry wrapper rather than directly.
    """

    def __init__(
        self,
        pooling: int = 2,
        enable_nhwc: bool = False,
        enable_healpixpad: bool = False,
    ):
        """
        Args:
            pooling (int, optional): Kernel size handed to the wrapped
                ``nn.AvgPool2d`` (passed to ``HEALPixLayer`` as ``kernel_size``).
            enable_nhwc (bool, optional): NHWC-format flag, forwarded to
                ``HEALPixLayer``.
            enable_healpixpad (bool, optional): HEALPix-padding flag,
                forwarded to ``HEALPixLayer``.
        """
        super().__init__()
        pool_layer = nn.AvgPool2d
        self.avgpool = HEALPixLayer(
            layer=pool_layer,
            kernel_size=pooling,
            enable_nhwc=enable_nhwc,
            enable_healpixpad=enable_healpixpad,
        )

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Apply average pooling to ``x``.

        Args:
            x: Tensor to pool.

        Returns:
            The output of the wrapped average-pooling layer.
        """
        averaged = self.avgpool(x)
        return averaged


@dataclasses.dataclass
class DownsamplingBlockConfig:
    """
    Configuration for the downsampling block. Currently this is either a
    max-pooling or an average-pooling block.

    Parameters:
        block_type: Type of downsampling block, either "MaxPool" or "AvgPool".
        pooling: Pooling kernel size, default is 2.
        enable_nhwc: Flag to enable NHWC data format, default is False.
        enable_healpixpad: Flag to enable HEALPix padding, default is False.
    """

    block_type: Literal["MaxPool", "AvgPool"]
    pooling: int = 2
    enable_nhwc: bool = False
    enable_healpixpad: bool = False

    def build(self) -> nn.Module:
        """
        Builds the downsampling block.

        Returns:
            A ``MaxPool`` or ``AvgPool`` module (selected by ``block_type``)
            configured with ``pooling``, ``enable_nhwc`` and
            ``enable_healpixpad``.

        Raises:
            ValueError: If ``block_type`` is neither "MaxPool" nor "AvgPool".
                ``Literal`` is not enforced at runtime, so this is the only
                validation of ``block_type``.
        """
        if self.block_type == "MaxPool":
            return MaxPool(
                pooling=self.pooling,
                enable_nhwc=self.enable_nhwc,
                enable_healpixpad=self.enable_healpixpad,
            )
        elif self.block_type == "AvgPool":
            return AvgPool(
                pooling=self.pooling,
                enable_nhwc=self.enable_nhwc,
                enable_healpixpad=self.enable_healpixpad,
            )
        else:
            raise ValueError(f"Unsupported block type: {self.block_type}")
@dataclasses.dataclass
class CappedGELUConfig:
    """
    Configuration for the CappedGELU activation function.

    Parameters:
        cap_value: Maximum output value of the GELU function, default is 10.
        enable_nhwc: Flag to enable NHWC data format, default is False.
            Currently not used by ``build``.
        enable_healpixpad: Flag to enable HEALPix padding, default is False.
            Currently not used by ``build``.
    """

    cap_value: int = 10
    enable_nhwc: bool = False
    enable_healpixpad: bool = False

    def build(self) -> nn.Module:
        """
        Builds the CappedGELU activation function.

        Returns:
            A ``CappedGELU`` module whose outputs are clamped at ``cap_value``.
        """
        # enable_nhwc / enable_healpixpad mirror the fields of the other block
        # configs but are deliberately not forwarded: CappedGELU passes extra
        # kwargs to nn.GELU, which has no such options.
        return CappedGELU(cap_value=self.cap_value)
class CappedGELU(nn.Module):
    """
    Implements a GELU with capped maximum value.

    Example
    -------
    >>> capped_gelu_func = CappedGELU()
    >>> input = th.Tensor([[-2,-1],[0,1],[2,3]])
    >>> capped_gelu_func(input)
    tensor([[-0.0455, -0.1587],
            [ 0.0000,  0.8413],
            [ 1.0000,  1.0000]])
    """

    def __init__(self, cap_value=1.0, **kwargs):
        """
        Args:
            cap_value: Maximum that values will be capped at
            **kwargs: Keyword arguments to be passed to the `th.nn.GELU` function
        """
        super().__init__()
        self.add_module("gelu", th.nn.GELU(**kwargs))
        # Kept as a (float32) buffer so it moves with the module across
        # devices and is saved in the state dict.
        self.register_buffer("cap", th.tensor(cap_value, dtype=th.float32))

    def forward(self, inputs):
        """
        Args:
            inputs: Tensor to apply the activation to.

        Returns:
            ``GELU(inputs)`` with every value above ``cap`` replaced by ``cap``.
        """
        x = self.gelu(inputs)
        # Clamp against the 0-dim buffer directly. The previous
        # ``self.cap.item()`` forced a device-to-host sync on every forward
        # pass and breaks graph capture / torch.compile tracing. The 0-dim
        # tensor does not change x's dtype under type promotion, and the
        # gradient w.r.t. x is the same as with a Python scalar cap.
        return th.clamp(x, max=self.cap)