Source code for fme.core.normalizer

import dataclasses
from typing import Dict, Iterable, List, Mapping, Optional, Union

import netCDF4
import numpy as np
import torch
import torch.jit

from fme.core.device import move_tensordict_to_device
from fme.core.typing_ import TensorDict, TensorMapping


[docs]@dataclasses.dataclass class NormalizationConfig: """ Configuration for normalizing data. Either global_means_path and global_stds_path or explicit means and stds must be provided. Parameters: global_means_path: Path to a netCDF file containing global means. global_stds_path: Path to a netCDF file containing global stds. means: Mapping from variable names to means. stds: Mapping from variable names to stds. """ global_means_path: Optional[str] = None global_stds_path: Optional[str] = None means: Mapping[str, float] = dataclasses.field(default_factory=dict) stds: Mapping[str, float] = dataclasses.field(default_factory=dict) def __post_init__(self): using_path = ( self.global_means_path is not None and self.global_stds_path is not None ) using_explicit = len(self.means) > 0 and len(self.stds) > 0 if using_path and using_explicit: raise ValueError( "Cannot use both global_means_path and global_stds_path " "and explicit means and stds." ) if not (using_path or using_explicit): raise ValueError( "Must use either global_means_path and global_stds_path " "or explicit means and stds." ) def build(self, names: List[str]): using_path = ( self.global_means_path is not None and self.global_stds_path is not None ) if using_path: return get_normalizer( global_means_path=self.global_means_path, global_stds_path=self.global_stds_path, names=names, ) else: means = {k: torch.tensor(self.means[k]) for k in names} stds = {k: torch.tensor(self.stds[k]) for k in names} return StandardNormalizer(means=means, stds=stds)
[docs]class StandardNormalizer: """ Responsible for normalizing tensors. """ def __init__( self, means: TensorDict, stds: TensorDict, ): self.means = move_tensordict_to_device(means) self.stds = move_tensordict_to_device(stds) self._names = set(means).intersection(stds) def normalize(self, tensors: TensorMapping) -> TensorDict: filtered_tensors = {k: v for k, v in tensors.items() if k in self._names} return _normalize(filtered_tensors, means=self.means, stds=self.stds) def denormalize(self, tensors: TensorMapping) -> TensorDict: filtered_tensors = {k: v for k, v in tensors.items() if k in self._names} return _denormalize(filtered_tensors, means=self.means, stds=self.stds)
[docs] def get_state(self): """ Returns state as a serializable data structure. """ return { "means": {k: float(v.cpu().numpy().item()) for k, v in self.means.items()}, "stds": {k: float(v.cpu().numpy().item()) for k, v in self.stds.items()}, }
[docs] @classmethod def from_state(cls, state) -> "StandardNormalizer": """ Loads state from a serializable data structure. """ means = { k: torch.tensor(v, dtype=torch.float) for k, v in state["means"].items() } stds = {k: torch.tensor(v, dtype=torch.float) for k, v in state["stds"].items()} return cls(means=means, stds=stds)
@torch.jit.script def _normalize( tensors: TensorDict, means: TensorDict, stds: TensorDict, ) -> TensorDict: return {k: (t - means[k]) / stds[k] for k, t in tensors.items()} @torch.jit.script def _denormalize( tensors: TensorDict, means: TensorDict, stds: TensorDict, ) -> TensorDict: return {k: t * stds[k] + means[k] for k, t in tensors.items()} def get_normalizer( global_means_path, global_stds_path, names: List[str] ) -> StandardNormalizer: means = load_dict_from_netcdf( global_means_path, names, defaults={"x": 0.0, "y": 0.0, "z": 0.0} ) means = {k: torch.as_tensor(v, dtype=torch.float) for k, v in means.items()} stds = load_dict_from_netcdf( global_stds_path, names, defaults={"x": 1.0, "y": 1.0, "z": 1.0} ) stds = {k: torch.as_tensor(v, dtype=torch.float) for k, v in stds.items()} return StandardNormalizer(means=means, stds=stds) def load_dict_from_netcdf( path: str, names: Iterable[str], defaults: Mapping[str, Union[float, np.ndarray]] ) -> Dict[str, Union[float, np.ndarray]]: """ Load a dictionary of variables from a netCDF file. Args: path: Path to the netCDF file. names: List of variable names to load. defaults: Dictionary of default values for each variable, if not found in the netCDF file. """ ds = netCDF4.Dataset(path) ds.set_auto_mask(False) result = {} for c in names: if c in ds.variables: result[c] = ds.variables[c][:] elif c in defaults: result[c] = defaults[c] else: raise ValueError(f"Variable {c} not found in {path}") ds.close() return result