Source code for fme.core.gridded_ops

import abc
from typing import Any, Dict, List, Type, TypeVar

import torch

from fme.core import metrics
from fme.core.device import get_device


class GriddedOperations(abc.ABC):
    """
    Abstract interface for reductions and error metrics over the horizontal
    dimensions of gridded data.

    Concrete subclasses provide ``area_weighted_mean``,
    ``area_weighted_gradient_magnitude_percent_diff`` and
    ``get_initialization_kwargs``. The bias, RMSE and standard-deviation
    helpers below are all built on top of ``area_weighted_mean``. Instances
    can be serialized with ``to_state`` and rebuilt with ``from_state``.
    """

    @abc.abstractmethod
    def area_weighted_mean(
        self, data: torch.Tensor, keepdim: bool = False
    ) -> torch.Tensor:
        """
        Reduce ``data`` over the horizontal dimensions with an area-weighted
        mean; ``keepdim`` keeps the reduced dimensions in the result.
        """
        ...

    def area_weighted_mean_bias(
        self, truth: torch.Tensor, predicted: torch.Tensor
    ) -> torch.Tensor:
        """Return the area-weighted mean of ``predicted - truth``."""
        error = predicted - truth
        return self.area_weighted_mean(error)

    def area_weighted_rmse(
        self, truth: torch.Tensor, predicted: torch.Tensor
    ) -> torch.Tensor:
        """Return the square root of the area-weighted mean squared error."""
        squared_error = (predicted - truth) ** 2
        mean_squared_error = self.area_weighted_mean(squared_error)
        return torch.sqrt(mean_squared_error)

    def area_weighted_std(self, data: torch.Tensor, keepdim: bool = False):
        """
        Return the area-weighted (population) standard deviation of ``data``.

        The horizontal mean is computed with ``keepdim=True`` so it broadcasts
        back against ``data``; ``keepdim`` only affects the final reduction.
        """
        horizontal_mean = self.area_weighted_mean(data, keepdim=True)
        squared_anomaly = (data - horizontal_mean) ** 2
        variance = self.area_weighted_mean(squared_anomaly, keepdim=keepdim)
        return variance.sqrt()

    @abc.abstractmethod
    def area_weighted_gradient_magnitude_percent_diff(
        self, truth: torch.Tensor, predicted: torch.Tensor
    ):
        """Compare gradient magnitudes of ``predicted`` and ``truth``."""
        ...

    def to_state(self) -> Dict[str, Any]:
        """
        Serialize this instance as ``{"type": <class name>, "state": <kwargs>}``,
        the format accepted by ``GriddedOperations.from_state``.
        """
        type_name = self.__class__.__name__
        init_kwargs = self.get_initialization_kwargs()
        return {"type": type_name, "state": init_kwargs}

    @abc.abstractmethod
    def get_initialization_kwargs(self) -> Dict[str, Any]:
        """
        Get the keyword arguments needed to initialize the instance.
        """
        ...

    @classmethod
    def from_state(cls, state: Dict[str, Any]) -> "GriddedOperations":
        """
        Rebuild a GriddedOperations instance from its serialized form.

        ``state["type"]`` names a (possibly indirect) subclass of
        GriddedOperations, and ``state["state"]`` holds the keyword arguments
        passed to that subclass's constructor.

        Args:
            state: A dictionary with a "type" key and a "state" key.

        Returns:
            An instance of the named subclass.

        Raises:
            RuntimeError: If called on a subclass rather than on
                GriddedOperations itself.
            ValueError: If no subclass has the requested name.
        """
        if cls is not GriddedOperations:
            raise RuntimeError(
                "This method should be called on GriddedOperations, "
                "not on its subclasses."
            )
        candidates = get_all_subclasses(cls)
        requested = state["type"]
        # The first subclass (in discovery order) whose name matches wins.
        matched = next(
            (candidate for candidate in candidates if candidate.__name__ == requested),
            None,
        )
        if matched is None:
            raise ValueError(
                f"Unknown subclass type: {state['type']}, "
                f"available: {[s.__name__ for s in candidates]}"
            )
        return matched(**state["state"])
T = TypeVar("T")


def get_all_subclasses(cls: Type[T]) -> List[Type[T]]:
    """
    Gets all subclasses of a given class, including their subclasses etc.

    Classes are listed depth-first, each followed by its own descendants.
    A class that inherits from several classes in the hierarchy is listed
    once for every parent that reaches it.
    """
    all_subclasses: List[Type[T]] = []
    for subclass in cls.__subclasses__():
        all_subclasses.append(subclass)
        all_subclasses.extend(get_all_subclasses(subclass))
    return all_subclasses


class LatLonOperations(GriddedOperations):
    """
    Gridded operations for data on a latitude-longitude grid.

    Reductions are weighted by the ``area_weights`` given at construction.
    The weights are cached both on the default device (``get_device()``)
    and on the CPU.
    """

    # Horizontal reductions run over the last two dimensions (presumably
    # latitude and longitude -- TODO confirm against callers).
    HORIZONTAL_DIMS = (-2, -1)

    def __init__(self, area_weights: torch.Tensor):
        """
        Args:
            area_weights: Per-cell weights used in the horizontal reductions.
                NOTE(review): presumably shaped like, or broadcastable to,
                the trailing two dimensions of the data -- confirm.
        """
        # Keep one copy on the default device and one on the CPU so that
        # data on either needs no per-call transfer of the weights.
        self._device_area = area_weights.to(get_device())
        self._cpu_area = area_weights.to("cpu")

    def _area_weights_for(self, device: torch.device) -> torch.Tensor:
        """Return the area weights on ``device``, reusing a cached copy when possible."""
        if device.type == "cpu":
            return self._cpu_area
        if device == self._device_area.device:
            return self._device_area
        # The data lives on a non-CPU device other than the default one
        # (e.g. a second GPU); move the weights there rather than mixing
        # devices in a single operation.
        return self._device_area.to(device)

    def area_weighted_mean(
        self, data: torch.Tensor, keepdim: bool = False
    ) -> torch.Tensor:
        """
        Area-weighted mean of ``data`` over ``HORIZONTAL_DIMS``.

        Args:
            data: Tensor whose trailing two dimensions are the horizontal grid.
            keepdim: Passed through to ``metrics.weighted_mean``.

        Returns:
            The result of ``metrics.weighted_mean`` using weights placed on
            ``data``'s device.
        """
        return metrics.weighted_mean(
            data,
            self._area_weights_for(data.device),
            dim=self.HORIZONTAL_DIMS,
            keepdim=keepdim,
        )

    def area_weighted_gradient_magnitude_percent_diff(
        self, truth: torch.Tensor, predicted: torch.Tensor
    ):
        """
        Delegate to ``metrics.gradient_magnitude_percent_diff`` using the area
        weights placed on ``predicted``'s device.
        """
        return metrics.gradient_magnitude_percent_diff(
            truth,
            predicted,
            weights=self._area_weights_for(predicted.device),
            dim=self.HORIZONTAL_DIMS,
        )

    def get_initialization_kwargs(self) -> Dict[str, Any]:
        """Return the ``__init__`` kwargs; the weights are provided on the CPU."""
        return {"area_weights": self._cpu_area}


class HEALPixOperations(GriddedOperations):
    """
    Gridded operations for data on a HEALPix grid.

    HEALPix is an equal-area pixelization, so the "area-weighted" mean is
    computed here as a plain, unweighted mean over the horizontal dimensions.
    """

    # Horizontal reductions run over the last three dimensions (presumably
    # the face index and the two in-face pixel axes -- TODO confirm).
    HORIZONTAL_DIMS = (-3, -2, -1)

    def area_weighted_mean(
        self, data: torch.Tensor, keepdim: bool = False
    ) -> torch.Tensor:
        """Unweighted mean of ``data`` over ``HORIZONTAL_DIMS``."""
        return data.mean(dim=self.HORIZONTAL_DIMS, keepdim=keepdim)

    def area_weighted_gradient_magnitude_percent_diff(
        self, truth: torch.Tensor, predicted: torch.Tensor
    ):
        """Delegate to ``metrics.gradient_magnitude_percent_diff`` without weights."""
        return metrics.gradient_magnitude_percent_diff(
            truth, predicted, weights=None, dim=self.HORIZONTAL_DIMS
        )

    def get_initialization_kwargs(self) -> Dict[str, Any]:
        """HEALPixOperations takes no constructor arguments."""
        return {}