import abc
from collections.abc import Callable
from typing import Any, TypeVar, final
import torch
from torch import nn
from fme.core import metrics
from fme.core.cuhpx.sht import SHT as CuHpxSHT
from fme.core.cuhpx.sht import iSHT as CuHpxiSHT
from fme.core.device import get_device
from fme.core.distributed.distributed import Distributed
from fme.core.hpx.reorder import get_reordering_xy_to_ring
from fme.core.mask_provider import MaskProviderABC, NullMaskProvider
from fme.core.tensors import assert_dict_allclose
from fme.core.typing_ import TensorDict, TensorMapping
class GriddedOperations(abc.ABC):
    """
    Abstract interface for area-weighted reductions and spherical harmonic
    transforms on a particular horizontal grid.

    Subclasses implement the grid-specific primitives (``area_weighted_sum``,
    ``area_weighted_mean``, ``regional_area_weighted_mean``, the gradient
    metric, and the real SHT / inverse SHT). The ``*_dict`` helpers are final:
    they apply a primitive to every entry of a mapping, forwarding the mapping
    key as ``name`` so subclasses can look up per-variable masks.

    Instances are serialized with ``get_state`` and rebuilt with
    ``GriddedOperations.from_state``.
    """

    def __eq__(self, other) -> bool:
        # Equality is defined on the serialized state, compared with
        # assert_dict_allclose (presumably tolerant for tensor-valued kwargs).
        # Defining __eq__ without __hash__ makes instances unhashable.
        if not isinstance(other, GriddedOperations):
            return False
        try:
            assert_dict_allclose(self.get_state(), other.get_state())
        except AssertionError:
            return False
        return True

    def __repr__(self) -> str:
        kwargs = ", ".join(
            f"{k}={v}" for k, v in self.get_initialization_kwargs().items()
        )
        return f"{self.__class__.__name__}({kwargs})"

    @property
    @abc.abstractmethod
    def zonal_mean(self) -> Callable[[torch.Tensor], torch.Tensor] | None:
        """
        A callable computing the zonal mean of a tensor, or None if the grid
        does not provide one.
        """
        ...

    @abc.abstractmethod
    def area_weighted_sum(
        self,
        data: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        """
        Area-weighted sum of ``data`` over the grid's horizontal dimensions.

        Args:
            data: Tensor whose trailing dimensions are the horizontal grid.
            keepdim: Whether to keep the reduced dimensions.
            name: Optional variable name, used by subclasses that support
                per-variable masks.
        """
        ...

    @final
    def area_weighted_sum_dict(
        self, data: TensorMapping, keepdim: bool = False
    ) -> TensorDict:
        """Apply ``area_weighted_sum`` to every variable in ``data``."""
        return {
            name: self.area_weighted_sum(data=data[name], keepdim=keepdim, name=name)
            for name in data
        }

    @abc.abstractmethod
    def area_weighted_mean(
        self,
        data: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        """
        Area-weighted mean of ``data`` over the grid's horizontal dimensions.

        Args:
            data: Tensor whose trailing dimensions are the horizontal grid.
            keepdim: Whether to keep the reduced dimensions.
            name: Optional variable name, used by subclasses that support
                per-variable masks.
        """
        ...

    @final
    def area_weighted_mean_dict(
        self, data: TensorMapping, keepdim: bool = False
    ) -> TensorDict:
        """Apply ``area_weighted_mean`` to every variable in ``data``."""
        return {
            name: self.area_weighted_mean(data=data[name], keepdim=keepdim, name=name)
            for name in data
        }

    def area_weighted_mean_bias(
        self,
        truth: torch.Tensor,
        predicted: torch.Tensor,
        name: str | None = None,
    ) -> torch.Tensor:
        """Area-weighted mean of ``predicted - truth``."""
        return self.area_weighted_mean(predicted - truth, name=name)

    @final
    def area_weighted_mean_bias_dict(
        self, truth: TensorMapping, predicted: TensorMapping
    ) -> TensorDict:
        """
        Apply ``area_weighted_mean_bias`` to every variable in ``truth``.
        Every key of ``truth`` must also be present in ``predicted``.
        """
        return {
            name: self.area_weighted_mean_bias(
                truth=truth[name], predicted=predicted[name], name=name
            )
            for name in truth
        }

    def area_weighted_rmse(
        self,
        truth: torch.Tensor,
        predicted: torch.Tensor,
        name: str | None = None,
    ) -> torch.Tensor:
        """Square root of the area-weighted mean of ``(predicted - truth) ** 2``."""
        return torch.sqrt(self.area_weighted_mean((predicted - truth) ** 2, name=name))

    @final
    def area_weighted_rmse_dict(
        self, truth: TensorMapping, predicted: TensorMapping
    ) -> TensorDict:
        """
        Apply ``area_weighted_rmse`` to every variable in ``truth``.
        Every key of ``truth`` must also be present in ``predicted``.
        """
        return {
            name: self.area_weighted_rmse(
                truth=truth[name], predicted=predicted[name], name=name
            )
            for name in truth
        }

    def area_weighted_std(
        self,
        data: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        """
        Area-weighted standard deviation of ``data`` about its area-weighted
        mean (population form: no Bessel correction).
        """
        # keepdim=True on the inner mean so it broadcasts against ``data``.
        anomaly = data - self.area_weighted_mean(data, keepdim=True, name=name)
        return self.area_weighted_mean(
            anomaly**2,
            keepdim=keepdim,
            name=name,
        ).sqrt()

    @final
    def area_weighted_std_dict(
        self,
        data: TensorMapping,
        keepdim: bool = False,
    ) -> TensorDict:
        """Apply ``area_weighted_std`` to every variable in ``data``."""
        return {
            name: self.area_weighted_std(data=data[name], keepdim=keepdim, name=name)
            for name in data
        }

    @abc.abstractmethod
    def area_weighted_gradient_magnitude_percent_diff(
        self,
        truth: torch.Tensor,
        predicted: torch.Tensor,
        name: str | None = None,
    ) -> torch.Tensor:
        """
        Area-weighted percent difference in gradient magnitude between
        ``predicted`` and ``truth``; the exact definition is subclass-specific.
        """
        ...

    @final
    def area_weighted_gradient_magnitude_percent_diff_dict(
        self, truth: TensorMapping, predicted: TensorMapping
    ) -> TensorDict:
        """
        Apply ``area_weighted_gradient_magnitude_percent_diff`` to every
        variable in ``truth``.
        """
        return {
            name: self.area_weighted_gradient_magnitude_percent_diff(
                truth=truth[name], predicted=predicted[name], name=name
            )
            for name in truth
        }

    @abc.abstractmethod
    def regional_area_weighted_mean(
        self,
        data: torch.Tensor,
        regional_weights: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        """
        Area-weighted mean of ``data`` with the area weights additionally
        multiplied by ``regional_weights``.
        """
        ...

    @final
    def regional_area_weighted_mean_dict(
        self,
        data: TensorMapping,
        regional_weights: torch.Tensor,
        keepdim: bool = False,
    ) -> TensorDict:
        """
        Apply ``regional_area_weighted_mean`` to every variable in ``data``,
        using the same ``regional_weights`` for all of them.
        """
        return {
            name: self.regional_area_weighted_mean(
                data=data[name],
                regional_weights=regional_weights,
                keepdim=keepdim,
                name=name,
            )
            for name in data
        }

    @abc.abstractmethod
    def get_real_sht(
        self,
    ) -> nn.Module:
        """Return a module computing the real spherical harmonic transform."""
        ...

    @abc.abstractmethod
    def get_real_isht(
        self,
    ) -> nn.Module:
        """Return a module computing the inverse real spherical harmonic transform."""
        ...

    def get_state(self) -> dict[str, Any]:
        """
        Serializable description of this instance, consumable by
        ``GriddedOperations.from_state``.
        """
        return {
            "type": self.__class__.__name__,
            "state": self.get_initialization_kwargs(),
        }

    @abc.abstractmethod
    def get_initialization_kwargs(self) -> dict[str, Any]:
        """
        Get the keyword arguments needed to initialize the instance.
        """
        ...

    @classmethod
    def from_state(cls, state: dict[str, Any]) -> "GriddedOperations":
        """
        Given a dictionary with a "type" key and a "state" key, return
        the GriddedOperations it describes.

        The "type" key should be the name of a subclass of GriddedOperations,
        and the "state" key should be a dictionary specific to
        that subclass.

        Args:
            state: A dictionary with a "type" key and a "state" key.

        Returns:
            An instance of the subclass.

        Raises:
            RuntimeError: If called on a subclass instead of GriddedOperations.
            ValueError: If no subclass has the requested name.
        """
        if cls is not GriddedOperations:
            raise RuntimeError(
                "This method should be called on GriddedOperations, "
                "not on its subclasses."
            )
        subclasses = get_all_subclasses(cls)
        for subclass in subclasses:
            if subclass.__name__ == state["type"]:
                return subclass(**state["state"])
        raise ValueError(
            f"Unknown subclass type: {state['type']}, "
            f"available: {[s.__name__ for s in subclasses]}"
        )
T = TypeVar("T")


def get_all_subclasses(cls: type[T]) -> list[type[T]]:
    """
    Gets all subclasses of a given class, including their subclasses etc.

    Subclasses are listed in depth-first pre-order. Each subclass appears
    exactly once, even when reachable along several paths (diamond
    inheritance).
    """
    seen: set[type] = set()
    all_subclasses: list[type[T]] = []

    def _visit(parent: type) -> None:
        # Depth-first pre-order walk that skips already-visited classes.
        # A class's whole subtree is walked on its first visit, so skipping
        # later visits only removes duplicates without reordering.
        for subclass in parent.__subclasses__():
            if subclass in seen:
                continue
            seen.add(subclass)
            all_subclasses.append(subclass)
            _visit(subclass)

    _visit(cls)
    return all_subclasses
def _mask_area_weights(
    area_weights: torch.Tensor,
    mask_provider: MaskProviderABC,
    name: str | None,
) -> torch.Tensor:
    """
    Multiply ``area_weights`` by the mask registered for ``name``, if any.

    The mask provider is only queried when ``name`` is given. When there is no
    name, or the provider has no mask for it, the weights are returned
    unchanged (the same tensor object, not a copy).
    """
    mask = None if name is None else mask_provider.get_mask_tensor_for(name)
    return area_weights if mask is None else area_weights * mask
class LatLonOperations(GriddedOperations):
    """
    Gridded operations on a latitude-longitude grid.

    The full (global) area weights are kept on CPU for serialization and for
    global-shape queries. The portion of the grid selected by
    ``Distributed.get_local_slices`` is kept both on CPU and on the compute
    device; each reduction uses the copy matching the input tensor's device
    and delegates the reduction itself to ``Distributed``.

    Args:
        area_weights: Area weights for the global grid; the last two
            dimensions are (lat, lon). They must be constant along longitude.
        mask_provider: Source of per-variable masks applied to the area
            weights. It is not part of the serialized state.
        grid: Grid name passed to the spherical harmonic transforms. Only a
            non-default value is recorded in the serialized state.
    """

    HORIZONTAL_DIMS = (-2, -1)
    _DEFAULT_GRID = "legendre-gauss"

    def __init__(
        self,
        area_weights: torch.Tensor,
        mask_provider: MaskProviderABC = NullMaskProvider,
        grid: str = _DEFAULT_GRID,
    ):
        self._validate_area_weights(area_weights)
        self._cpu_area_global = area_weights.to("cpu", copy=True)
        dist = Distributed.get_instance()
        nlat, nlon = area_weights.shape[-2], area_weights.shape[-1]
        # Restrict the weights to the slice of the global grid handled locally.
        h_slice, w_slice = dist.get_local_slices((nlat, nlon))
        local_weights = self._cpu_area_global[..., h_slice, w_slice]
        self._device_area = local_weights.to(get_device(), copy=True)
        self._cpu_area = local_weights.to("cpu", copy=True)
        self._device_mask_provider = mask_provider.to(get_device())
        self._cpu_mask_provider = mask_provider.to("cpu")
        self._grid = grid

    def _validate_area_weights(self, area_weights: torch.Tensor) -> None:
        """Check that area weights are longitudinally uniform."""
        if not torch.allclose(area_weights, area_weights[..., :1]):
            raise ValueError(
                "Area weights must be longitudinally uniform, "
                "as assumed for zonal mean."
            )

    def _global_shape(self) -> tuple[int, int]:
        """(nlat, nlon) of the full, un-partitioned grid."""
        return (self._cpu_area_global.shape[-2], self._cpu_area_global.shape[-1])

    @property
    def zonal_mean(self) -> Callable[[torch.Tensor], torch.Tensor]:
        return Distributed.get_instance().zonal_mean

    def _get_area_weights(
        self,
        data: torch.Tensor,
        name: str | None = None,
        regional_weights: torch.Tensor | None = None,
    ):
        """
        Local area weights on the device of ``data``, masked for ``name`` and
        optionally multiplied by ``regional_weights``.
        """
        if data.device == torch.device("cpu"):
            area_weights = self._cpu_area
            mask_provider = self._cpu_mask_provider
        else:
            area_weights = self._device_area
            mask_provider = self._device_mask_provider
        area_weights = _mask_area_weights(area_weights, mask_provider, name)
        if regional_weights is None:
            return area_weights
        if regional_weights.device.type != data.device.type:
            regional_weights = regional_weights.to(data.device)
        return regional_weights * area_weights

    def area_weighted_sum(
        self,
        data: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        area_weights = self._get_area_weights(data, name)
        local_sum = metrics.weighted_sum(
            data, area_weights, dim=self.HORIZONTAL_DIMS, keepdim=keepdim
        )
        # Combine the local partial sums across the spatial decomposition.
        return Distributed.get_instance().spatial_reduce_sum(local_sum)

    def area_weighted_mean(
        self,
        data: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        area_weights = self._get_area_weights(data, name)
        return Distributed.get_instance().weighted_mean(
            data, area_weights, dim=self.HORIZONTAL_DIMS, keepdim=keepdim
        )

    def regional_area_weighted_mean(
        self,
        data: torch.Tensor,
        regional_weights: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        regional_area_weights = self._get_area_weights(data, name, regional_weights)
        return Distributed.get_instance().weighted_mean(
            data, regional_area_weights, dim=self.HORIZONTAL_DIMS, keepdim=keepdim
        )

    def area_weighted_gradient_magnitude_percent_diff(
        self,
        truth: torch.Tensor,
        predicted: torch.Tensor,
        name: str | None = None,
    ):
        area_weights = self._get_area_weights(truth, name)
        return Distributed.get_instance().gradient_magnitude_percent_diff(
            truth,
            predicted,
            weights=area_weights,
            dim=self.HORIZONTAL_DIMS,
            img_shape=self._global_shape(),
        )

    def get_real_sht(self) -> nn.Module:
        nlat, nlon = self._global_shape()
        return (
            Distributed.get_instance()
            .get_sht(nlat, nlon, grid=self._grid)
            .to(get_device())
        )

    def get_real_isht(self) -> nn.Module:
        nlat, nlon = self._global_shape()
        return (
            Distributed.get_instance()
            .get_isht(nlat, nlon, grid=self._grid)
            .to(get_device())
        )

    def get_initialization_kwargs(self) -> dict[str, Any]:
        kwargs: dict[str, Any] = {"area_weights": self._cpu_area_global}
        # Record the grid so from_state rebuilds the same SHT grid. A default
        # grid is omitted, so states for the default grid are unchanged and
        # remain loadable by code that does not know about this key.
        if self._grid != self._DEFAULT_GRID:
            kwargs["grid"] = self._grid
        return kwargs
class HEALPixSHT(nn.Module):
    """
    Real spherical harmonic transform for HEALPix data.

    Accepts ring-ordered data stored as ``[..., 1, npix]`` or face-ordered
    data stored as ``[..., 12, nside, nside]``. Face-ordered data is flattened
    and permuted into ring order before the transform.
    """

    def __init__(self, nside: int, lmax: int, mmax: int, grid: str):
        super().__init__()
        self.nside = nside
        self.lmax = lmax
        self.mmax = mmax
        self.grid = grid
        self.sht = CuHpxSHT(nside, lmax=lmax, mmax=mmax, grid=grid)
        # Permutation from flattened face (xy) order to ring order, kept on
        # both the compute device and CPU so forward() can index either kind
        # of tensor. These are plain attributes (not registered buffers), so
        # nn.Module.to() does not move them.
        self._reordering = get_reordering_xy_to_ring(nside, device=get_device())
        self._reordering_cpu = self._reordering.to("cpu")

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        """
        Args:
            data: Ring-ordered ``[..., 1, npix]`` or face-ordered
                ``[..., 12, nside, nside]`` tensor.

        Returns:
            The cuhpx SHT of the ring-ordered pixels.

        Raises:
            ValueError: If non-ring input does not consist of 12 square faces
                of side ``nside``.
        """
        face_shape = (12, self.nside, self.nside)
        # When nside == 1, face-ordered data [..., 12, 1, 1] also has a
        # singleton second-to-last dimension; rule it out before treating the
        # input as ring-ordered.
        is_face_ordered = tuple(data.shape[-3:]) == face_shape
        if data.shape[-2] == 1 and not is_face_ordered:
            # ring ordering, stored as [..., 1, npix]
            return self.sht(data[..., 0, :])
        # face ordering, stored as [..., 12, ny, nx]
        n_face, ny, nx = data.shape[-3:]
        if n_face != 12:
            raise ValueError(
                f"Expected 12 faces, got {n_face} in shape {data.shape}"
            )
        if ny != nx:
            raise ValueError(
                f"Expected square grid, got {ny}x{nx} in shape {data.shape}"
            )
        if ny != self.nside:
            raise ValueError(
                f"Expected nside {self.nside}, got {ny} in shape {data.shape}"
            )
        data = data.reshape(*data.shape[:-3], 12 * self.nside * self.nside)
        if data.device.type == "cpu":
            reordering = self._reordering_cpu
        else:
            reordering = self._reordering
        return self.sht(data[..., reordering])
class HEALPixInverseSHT(nn.Module):
    """
    Inverse real spherical harmonic transform onto a HEALPix grid.

    The result of the wrapped cuhpx inverse transform gets a singleton
    dimension inserted before its last axis. This yields the ``[..., 1, npix]``
    layout that HEALPixSHT treats as ring ordering.
    """

    def __init__(self, nside: int, lmax: int, mmax: int, grid: str):
        super().__init__()
        self.nside, self.lmax, self.mmax, self.grid = nside, lmax, mmax, grid
        self.isht = CuHpxiSHT(nside, lmax=lmax, mmax=mmax, grid=grid)

    def forward(self, data: torch.Tensor) -> torch.Tensor:
        pixels = self.isht(data)
        # Equivalent to pixels.unsqueeze(-2).
        return pixels[..., None, :]
class HEALPixOperations(GriddedOperations):
    """
    Gridded operations on a HEALPix grid.

    Area weights are treated as uniform, so area-weighted reductions are
    unweighted sums/means over ``HORIZONTAL_DIMS``. Per-variable masks are not
    applied: the ``name`` arguments only satisfy the GriddedOperations
    interface and are ignored.
    """

    HORIZONTAL_DIMS = (-3, -2, -1)

    def __init__(self, nside: int | None = None):
        """
        Args:
            nside: The nside of the HEALPix grid. nside must be specified in order to
                use the SHT. It is allowed to be None only for backwards compatibility.
        """
        self.nside = nside

    @property
    def zonal_mean(self) -> None:
        # Not implemented, though it could be, since HEALPix rings are
        # constant-latitude.
        return None

    def area_weighted_sum(
        self,
        data: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        # For HEALPix, area weights are uniform, so sum is sufficient
        return data.sum(dim=self.HORIZONTAL_DIMS, keepdim=keepdim)

    def area_weighted_mean(
        self,
        data: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        # For HEALPix, area weights are uniform, so mean is sufficient
        return data.mean(dim=self.HORIZONTAL_DIMS, keepdim=keepdim)

    def area_weighted_gradient_magnitude_percent_diff(
        self,
        truth: torch.Tensor,
        predicted: torch.Tensor,
        name: str | None = None,
    ) -> torch.Tensor:
        return metrics.gradient_magnitude_percent_diff(
            truth, predicted, weights=None, dim=self.HORIZONTAL_DIMS
        )

    def regional_area_weighted_mean(
        self,
        data: torch.Tensor,
        regional_weights: torch.Tensor,
        keepdim: bool = False,
        name: str | None = None,
    ) -> torch.Tensor:
        # The parameter is named ``regional_weights`` to match the abstract
        # signature. GriddedOperations.regional_area_weighted_mean_dict passes
        # it by keyword, so such calls must reach this NotImplementedError
        # rather than fail with a TypeError.
        raise NotImplementedError(
            "Regional area weighted mean is not implemented for HEALPix."
        )

    def _require_nside(self) -> int:
        """Return nside, raising ValueError if it was not specified."""
        if self.nside is None:
            raise ValueError("nside must be specified for SHT.")
        return self.nside

    def get_real_sht(self) -> nn.Module:
        nside = self._require_nside()
        lmax = 2 * nside - 1
        return HEALPixSHT(nside, lmax=lmax, mmax=lmax, grid="healpix")

    def get_real_isht(self) -> nn.Module:
        nside = self._require_nside()
        lmax = 2 * nside - 1
        return HEALPixInverseSHT(nside, lmax=lmax, mmax=lmax, grid="healpix")

    def get_initialization_kwargs(self) -> dict[str, Any]:
        # nside is omitted when unset so legacy states stay unchanged.
        if self.nside is None:
            return {}
        return {"nside": self.nside}