Source code for fme.core.gridded_ops

import abc
from collections.abc import Callable
from typing import Any, TypeVar, final

import torch
from torch import nn

from fme.core import metrics
from fme.core.cuhpx.sht import SHT as CuHpxSHT
from fme.core.cuhpx.sht import iSHT as CuHpxiSHT
from fme.core.device import get_device
from fme.core.distributed.distributed import Distributed
from fme.core.hpx.reorder import get_reordering_xy_to_ring
from fme.core.mask_provider import MaskProviderABC, NullMaskProvider
from fme.core.tensors import assert_dict_allclose
from fme.core.typing_ import TensorDict, TensorMapping


class GriddedOperations(abc.ABC):
    """
    Abstract interface for area-weighted reductions and spherical harmonic
    transform accessors on a horizontal grid.

    Subclasses provide the grid-specific primitives. The ``*_dict`` helpers
    are final: they apply the matching single-tensor method to every entry
    of a mapping and forward the entry's key as ``name``.
    """

    def __eq__(self, other) -> bool:
        """
        Compare two instances through their serialized state.

        Returns False for anything that is not a GriddedOperations, and
        False whenever ``assert_dict_allclose`` rejects the two states.
        """
        if not isinstance(other, GriddedOperations):
            return False
        try:
            assert_dict_allclose(self.get_state(), other.get_state())
        except AssertionError:
            equal = False
        else:
            equal = True
        return equal

    def __repr__(self) -> str:
        """Render as ``ClassName(key=value, ...)`` from the init kwargs."""
        rendered_kwargs = ", ".join(
            f"{k}={v}" for k, v in self.get_initialization_kwargs().items()
        )
        return f"{self.__class__.__name__}(" + rendered_kwargs + ")"

    @property
    @abc.abstractmethod
    def zonal_mean(self) -> Callable[[torch.Tensor], torch.Tensor] | None:
        """Callable mapping a tensor to a tensor, or None if unavailable."""
        ...

    @abc.abstractmethod
    def area_weighted_sum(
        self,
        data: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        """Area-weighted sum of ``data``; implemented by subclasses."""
        ...

    @final
    def area_weighted_sum_dict(
        self, data: TensorMapping, keepdim: bool = False
    ) -> TensorDict:
        """Apply :meth:`area_weighted_sum` to every entry of ``data``."""
        return {
            name: self.area_weighted_sum(data=tensor, keepdim=keepdim, name=name)
            for name, tensor in data.items()
        }

    @abc.abstractmethod
    def area_weighted_mean(
        self,
        data: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        """Area-weighted mean of ``data``; implemented by subclasses."""
        ...

    @final
    def area_weighted_mean_dict(
        self, data: TensorMapping, keepdim: bool = False
    ) -> TensorDict:
        """Apply :meth:`area_weighted_mean` to every entry of ``data``."""
        return {
            name: self.area_weighted_mean(data=tensor, keepdim=keepdim, name=name)
            for name, tensor in data.items()
        }

    def area_weighted_mean_bias(
        self,
        truth: torch.Tensor,
        predicted: torch.Tensor,
        name: str | None = None,
    ) -> torch.Tensor:
        """Area-weighted mean of ``predicted - truth``."""
        difference = predicted - truth
        return self.area_weighted_mean(difference, name=name)

    @final
    def area_weighted_mean_bias_dict(
        self, truth: TensorMapping, predicted: TensorMapping
    ) -> TensorDict:
        """
        Apply :meth:`area_weighted_mean_bias` per key of ``truth``.

        Every key of ``truth`` is looked up in ``predicted``.
        """
        return {
            name: self.area_weighted_mean_bias(
                truth=truth_tensor,
                predicted=predicted[name],
                name=name,
            )
            for name, truth_tensor in truth.items()
        }

    def area_weighted_rmse(
        self,
        truth: torch.Tensor,
        predicted: torch.Tensor,
        name: str | None = None,
    ) -> torch.Tensor:
        """Square root of the area-weighted mean of the squared error."""
        squared_error = (predicted - truth) ** 2
        mean_squared_error = self.area_weighted_mean(squared_error, name=name)
        return torch.sqrt(mean_squared_error)

    @final
    def area_weighted_rmse_dict(
        self, truth: TensorMapping, predicted: TensorMapping
    ) -> TensorDict:
        """
        Apply :meth:`area_weighted_rmse` per key of ``truth``.

        Every key of ``truth`` is looked up in ``predicted``.
        """
        return {
            name: self.area_weighted_rmse(
                truth=truth_tensor,
                predicted=predicted[name],
                name=name,
            )
            for name, truth_tensor in truth.items()
        }

    def area_weighted_std(
        self,
        data: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        """
        Area-weighted standard deviation of ``data``.

        Computed as the square root of the area-weighted mean of squared
        deviations from the area-weighted mean.
        """
        # keepdim=True so the mean broadcasts against ``data``.
        center = self.area_weighted_mean(data, keepdim=True, name=name)
        variance = self.area_weighted_mean(
            (data - center) ** 2,
            keepdim=keepdim,
            name=name,
        )
        return variance.sqrt()

    @final
    def area_weighted_std_dict(
        self,
        data: TensorMapping,
        keepdim: bool = False,
    ) -> TensorDict:
        """Apply :meth:`area_weighted_std` to every entry of ``data``."""
        return {
            name: self.area_weighted_std(data=tensor, keepdim=keepdim, name=name)
            for name, tensor in data.items()
        }

    @abc.abstractmethod
    def area_weighted_gradient_magnitude_percent_diff(
        self,
        truth: torch.Tensor,
        predicted: torch.Tensor,
        name: str | None = None,
    ):
        """Gradient-magnitude percent difference; implemented by subclasses."""
        ...

    @final
    def area_weighted_gradient_magnitude_percent_diff_dict(
        self, truth: TensorMapping, predicted: TensorMapping
    ) -> TensorDict:
        """
        Apply :meth:`area_weighted_gradient_magnitude_percent_diff` per key
        of ``truth``.

        Every key of ``truth`` is looked up in ``predicted``.
        """
        return {
            name: self.area_weighted_gradient_magnitude_percent_diff(
                truth=truth_tensor,
                predicted=predicted[name],
                name=name,
            )
            for name, truth_tensor in truth.items()
        }

    @abc.abstractmethod
    def regional_area_weighted_mean(
        self,
        data: torch.Tensor,
        regional_weights: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        """Mean of ``data`` using regional weights; implemented by subclasses."""
        ...

    @final
    def regional_area_weighted_mean_dict(
        self,
        data: TensorMapping,
        regional_weights: torch.Tensor,
        keepdim: bool = False,
    ) -> TensorDict:
        """
        Apply :meth:`regional_area_weighted_mean` to every entry of ``data``
        with the same ``regional_weights``.
        """
        return {
            name: self.regional_area_weighted_mean(
                data=tensor,
                regional_weights=regional_weights,
                keepdim=keepdim,
                name=name,
            )
            for name, tensor in data.items()
        }

    @abc.abstractmethod
    def get_real_sht(
        self,
    ) -> nn.Module:
        """Return a module for the forward transform; see subclasses."""
        ...

    @abc.abstractmethod
    def get_real_isht(
        self,
    ) -> nn.Module:
        """Return a module for the inverse transform; see subclasses."""
        ...

    def get_state(self) -> dict[str, Any]:
        """
        Serialize as ``{"type": <class name>, "state": <init kwargs>}``,
        the format consumed by :meth:`from_state`.
        """
        state: dict[str, Any] = {"type": self.__class__.__name__}
        state["state"] = self.get_initialization_kwargs()
        return state

    @abc.abstractmethod
    def get_initialization_kwargs(self) -> dict[str, Any]:
        """
        Get the keyword arguments needed to initialize the instance.
        """
        ...

    @classmethod
    def from_state(cls, state: dict[str, Any]) -> "GriddedOperations":
        """
        Build the GriddedOperations described by a state dictionary.

        The "type" entry names a (possibly indirect) subclass of
        GriddedOperations and the "state" entry holds the keyword arguments
        passed to that subclass's constructor.

        Args:
            state: A dictionary with a "type" key and a "state" key.

        Returns:
            An instance of the subclass.

        Raises:
            RuntimeError: If called on a subclass instead of
                GriddedOperations itself.
            ValueError: If no subclass is named ``state["type"]``.
        """
        if cls is not GriddedOperations:
            raise RuntimeError(
                "This method should be called on GriddedOperations, "
                "not on its subclasses."
            )
        subclasses = get_all_subclasses(cls)
        requested_type = state["type"]
        # First match wins, in the order get_all_subclasses reports them.
        match = next(
            (
                candidate
                for candidate in subclasses
                if candidate.__name__ == requested_type
            ),
            None,
        )
        if match is None:
            raise ValueError(
                f"Unknown subclass type: {state['type']}, "
                f"available: {[s.__name__ for s in subclasses]}"
            )
        return match(**state["state"])
T = TypeVar("T")


def get_all_subclasses(cls: type[T]) -> list[type[T]]:
    """
    Gets all subclasses of a given class, including their subclasses etc.

    Each direct subclass is followed immediately by its own descendants.
    """
    all_subclasses: list[type[T]] = []
    for subclass in cls.__subclasses__():
        all_subclasses.append(subclass)
        all_subclasses.extend(get_all_subclasses(subclass))
    return all_subclasses


def _mask_area_weights(
    area_weights: torch.Tensor,
    mask_provider: MaskProviderABC,
    name: str | None,
) -> torch.Tensor:
    """
    Multiply the area weights by the mask registered for ``name``.

    The weights are returned unchanged when ``name`` is None or when the
    provider has no mask for ``name``.
    """
    if name is None:
        return area_weights
    mask = mask_provider.get_mask_tensor_for(name)
    if mask is None:
        return area_weights
    return area_weights * mask


class LatLonOperations(GriddedOperations):
    """
    GriddedOperations for data whose last two dimensions are the horizontal
    (lat, lon) axes.

    Reductions are delegated to the ``Distributed`` singleton; each instance
    keeps only its local slice of the area weights for those reductions and
    the full weights for serialization and SHT sizing.
    """

    HORIZONTAL_DIMS = (-2, -1)
    _DEFAULT_GRID = "legendre-gauss"

    def __init__(
        self,
        area_weights: torch.Tensor,
        mask_provider: MaskProviderABC = NullMaskProvider,
        grid: str = _DEFAULT_GRID,
    ):
        """
        Args:
            area_weights: Area weights whose last two dimensions are
                (nlat, nlon). Must be uniform along the last dimension.
            mask_provider: Source of per-variable masks that are multiplied
                into the area weights when a ``name`` is given.
            grid: Grid name forwarded to the SHT / inverse SHT factories.

        Raises:
            ValueError: If the area weights vary along the last dimension.
        """
        self._validate_area_weights(area_weights)
        # Full (global) weights: used for serialization and for nlat/nlon.
        self._cpu_area_global = area_weights.to("cpu")
        dist = Distributed.get_instance()
        nlat, nlon = area_weights.shape[-2], area_weights.shape[-1]
        h_slice, w_slice = dist.get_local_slices((nlat, nlon))
        local_weights = area_weights[..., h_slice, w_slice]
        # Keep a copy per device so reductions never move weights per call.
        self._device_area = local_weights.to(get_device())
        self._cpu_area = local_weights.to("cpu")
        self._device_mask_provider = mask_provider.to(get_device())
        self._cpu_mask_provider = mask_provider.to("cpu")
        self._grid = grid

    def _validate_area_weights(self, area_weights: torch.Tensor) -> None:
        """Check that area weights are longitudinally uniform."""
        if not torch.allclose(area_weights, area_weights[..., :1]):
            raise ValueError(
                "Area weights must be longitudinally uniform, "
                "as assumed for zonal mean."
            )

    @property
    def zonal_mean(self) -> Callable[[torch.Tensor], torch.Tensor]:
        """The ``zonal_mean`` callable of the ``Distributed`` singleton."""
        return Distributed.get_instance().zonal_mean

    def _get_area_weights(
        self,
        data: torch.Tensor,
        name: str | None = None,
        regional_weights: torch.Tensor | None = None,
    ):
        """
        Return the local area weights to use with ``data``.

        The CPU copies are used when ``data`` is on the CPU, otherwise the
        device copies. The mask for ``name`` (if any) is applied, and the
        result is multiplied by ``regional_weights`` when they are given.
        """
        if data.device == torch.device("cpu"):
            area_weights = self._cpu_area
            mask_provider = self._cpu_mask_provider
        else:
            area_weights = self._device_area
            mask_provider = self._device_mask_provider
        area_weights = _mask_area_weights(area_weights, mask_provider, name)
        if regional_weights is None:
            return area_weights
        if regional_weights.device.type != data.device.type:
            regional_weights = regional_weights.to(data.device)
        return regional_weights * area_weights

    def area_weighted_sum(
        self,
        data: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        """
        Area-weighted sum over the horizontal dimensions.

        The local weighted sum is combined through
        ``Distributed.spatial_reduce_sum``.
        """
        area_weights = self._get_area_weights(data, name)
        local_sum = metrics.weighted_sum(
            data, area_weights, dim=self.HORIZONTAL_DIMS, keepdim=keepdim
        )
        return Distributed.get_instance().spatial_reduce_sum(local_sum)

    def area_weighted_mean(
        self,
        data: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        """Area-weighted mean over the horizontal dimensions."""
        area_weights = self._get_area_weights(data, name)
        return Distributed.get_instance().weighted_mean(
            data, area_weights, dim=self.HORIZONTAL_DIMS, keepdim=keepdim
        )

    def regional_area_weighted_mean(
        self,
        data: torch.Tensor,
        regional_weights: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        """
        Mean over the horizontal dimensions, weighted by the product of
        ``regional_weights`` and the (masked) area weights.
        """
        regional_area_weights = self._get_area_weights(data, name, regional_weights)
        return Distributed.get_instance().weighted_mean(
            data, regional_area_weights, dim=self.HORIZONTAL_DIMS, keepdim=keepdim
        )

    def area_weighted_gradient_magnitude_percent_diff(
        self,
        truth: torch.Tensor,
        predicted: torch.Tensor,
        name: str | None = None,
    ):
        """
        Delegate to ``Distributed.gradient_magnitude_percent_diff`` with the
        local (masked) area weights and the global image shape.
        """
        area_weights = self._get_area_weights(truth, name)
        img_shape = (
            self._cpu_area_global.shape[-2],
            self._cpu_area_global.shape[-1],
        )
        return Distributed.get_instance().gradient_magnitude_percent_diff(
            truth,
            predicted,
            weights=area_weights,
            dim=self.HORIZONTAL_DIMS,
            img_shape=img_shape,
        )

    def get_real_sht(self) -> nn.Module:
        """Return the SHT module for the global grid, moved to the device."""
        nlat = self._cpu_area_global.shape[-2]
        nlon = self._cpu_area_global.shape[-1]
        return (
            Distributed.get_instance()
            .get_sht(nlat, nlon, grid=self._grid)
            .to(get_device())
        )

    def get_real_isht(self) -> nn.Module:
        """Return the inverse SHT module for the global grid, on the device."""
        nlat = self._cpu_area_global.shape[-2]
        nlon = self._cpu_area_global.shape[-1]
        return (
            Distributed.get_instance()
            .get_isht(nlat, nlon, grid=self._grid)
            .to(get_device())
        )

    def get_initialization_kwargs(self) -> dict[str, Any]:
        """
        Get the keyword arguments needed to initialize the instance.

        ``grid`` is included only when it differs from the constructor
        default, so states written for the default grid keep their previous
        shape while a non-default grid survives a get_state / from_state
        round trip.

        NOTE(review): the mask provider is not part of the returned kwargs,
        so an instance rebuilt from this state uses the constructor's
        default mask provider. Confirm this is intended.
        """
        kwargs: dict[str, Any] = {"area_weights": self._cpu_area_global}
        if self._grid != self._DEFAULT_GRID:
            kwargs["grid"] = self._grid
        return kwargs


class HEALPixSHT(nn.Module):
    """
    Forward SHT for HEALPix data, accepting ring-ordered or face-ordered
    input and reordering face-ordered input before the transform.
    """

    def __init__(self, nside: int, lmax: int, mmax: int, grid: str):
        super().__init__()
        self.nside = nside
        self.lmax = lmax
        self.mmax = mmax
        self.grid = grid
        self.sht = CuHpxSHT(nside, lmax=lmax, mmax=mmax, grid=grid)
        # Index tensors are kept on both the device and the CPU so forward()
        # can index inputs on either without moving the indices per call.
        self._reordering = get_reordering_xy_to_ring(nside, device=get_device())
        self._reordering_cpu = self._reordering.to("cpu")

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """
        Transform ``data`` given as [..., 1, npix] or [..., 12, ny, nx].

        Raises:
            ValueError: If face-ordered input does not have 12 faces, is
                not square, or its side length differs from ``nside``.
        """
        # NOTE(review): with nside == 1 a face-ordered input [..., 12, 1, 1]
        # also has shape[-2] == 1 and takes the ring-ordering branch.
        if data.shape[-2] == 1:
            # ring ordering, stored as [..., 1, npix]
            return self.sht(data[..., 0, :])
        else:
            # face ordering, stored as [..., 12, ny, nx]
            n_face, ny, nx = data.shape[-3:]
            if n_face != 12:
                raise ValueError(
                    f"Expected 12 faces, got {n_face} in shape {data.shape}"
                )
            if ny != nx:
                raise ValueError(
                    f"Expected square grid, got {ny}x{nx} in shape {data.shape}"
                )
            if ny != self.nside:
                raise ValueError(
                    f"Expected nside {self.nside}, got {ny} in shape {data.shape}"
                )
            # Flatten (face, y, x) into one pixel axis, then permute it.
            data = data.reshape(*data.shape[:-3], 12 * self.nside * self.nside)
            if data.device.type == "cpu":
                data = data[..., self._reordering_cpu]
            else:
                data = data[..., self._reordering]
            return self.sht(data)


class HEALPixInverseSHT(nn.Module):
    """Inverse SHT for HEALPix data."""

    def __init__(self, nside: int, lmax: int, mmax: int, grid: str):
        super().__init__()
        self.nside = nside
        self.lmax = lmax
        self.mmax = mmax
        self.grid = grid
        self.isht = CuHpxiSHT(nside, lmax=lmax, mmax=mmax, grid=grid)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """
        Apply the inverse transform and insert a singleton axis at -2, the
        layout ``HEALPixSHT.forward`` treats as ring ordering.
        """
        return self.isht(data).unsqueeze(-2)


class HEALPixOperations(GriddedOperations):
    """
    GriddedOperations for HEALPix data, reducing over the last three
    dimensions with uniform weights.
    """

    HORIZONTAL_DIMS = (-3, -2, -1)

    def __init__(self, nside: int | None = None):
        """
        Args:
            nside: The nside of the HEALPix grid. nside must be specified
                in order to use the SHT. It is allowed to be None only for
                backwards compatibility.
        """
        self.nside = nside

    @property
    def zonal_mean(self) -> None:
        # not implemented, though we definitely could
        # as HEALPix rings are constant-latitude
        return None

    def area_weighted_sum(
        self,
        data: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        """Unweighted sum over the horizontal dimensions; ``name`` is unused."""
        # For HEALPix, area weights are uniform, so sum is sufficient
        return data.sum(dim=self.HORIZONTAL_DIMS, keepdim=keepdim)

    def area_weighted_mean(
        self,
        data: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        """Unweighted mean over the horizontal dimensions; ``name`` is unused."""
        # For HEALPix, area weights are uniform, so mean is sufficient
        return data.mean(dim=self.HORIZONTAL_DIMS, keepdim=keepdim)

    def area_weighted_gradient_magnitude_percent_diff(
        self,
        truth: torch.Tensor,
        predicted: torch.Tensor,
        name: str | None = None,
    ) -> torch.Tensor:
        """
        Delegate to ``metrics.gradient_magnitude_percent_diff`` with no
        weights; ``name`` is unused.
        """
        return metrics.gradient_magnitude_percent_diff(
            truth, predicted, weights=None, dim=self.HORIZONTAL_DIMS
        )

    def regional_area_weighted_mean(
        self,
        data: torch.Tensor,
        regional_weights: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        """
        Not implemented for HEALPix.

        The second parameter is named ``regional_weights`` to match the
        abstract method, so keyword calls made through
        ``regional_area_weighted_mean_dict`` reach this body.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Regional area weighted mean is not implemented for HEALPix."
        )

    def get_real_sht(self) -> nn.Module:
        """
        Return a HEALPixSHT with ``lmax = mmax = 2 * nside - 1``.

        Raises:
            ValueError: If ``nside`` was not specified.
        """
        if self.nside is None:
            raise ValueError("nside must be specified for SHT.")
        lmax = 2 * self.nside - 1
        return HEALPixSHT(self.nside, lmax=lmax, mmax=lmax, grid="healpix")

    def get_real_isht(self) -> nn.Module:
        """
        Return a HEALPixInverseSHT with ``lmax = mmax = 2 * nside - 1``.

        Raises:
            ValueError: If ``nside`` was not specified.
        """
        if self.nside is None:
            raise ValueError("nside must be specified for SHT.")
        lmax = 2 * self.nside - 1
        return HEALPixInverseSHT(self.nside, lmax=lmax, mmax=lmax, grid="healpix")

    def get_initialization_kwargs(self) -> dict[str, Any]:
        """
        Get the keyword arguments needed to initialize the instance.

        ``nside`` is omitted when it is None.
        """
        if self.nside is None:
            return {}
        return {"nside": self.nside}