# Source code for fme.ace.inference.data_writer.time_coarsen

import dataclasses
from typing import Literal, Protocol

import torch
import xarray as xr

# Name of the time dimension on the xarray time labels; used as the key when
# coarsening ``batch_time`` in ``coarsen_batch``.
TIME_DIM_NAME = "time"
# Axis index of time in the data tensors; used as the ``dimension`` argument
# of ``Tensor.unfold`` in ``_coarsen_tensor_dict``.
TIME_DIM = 1  # sample, time, lat, lon


class _PairedDataWriter(Protocol):
    """Structural type for writers that consume target/prediction pairs.

    Any object exposing these three methods can be wrapped by the paired
    time-coarsening writer in this module.
    """

    def append_batch(
        self,
        target: dict[str, torch.Tensor],
        prediction: dict[str, torch.Tensor],
        batch_time: xr.DataArray,
    ):
        """Accept one batch of target and prediction tensors plus their times."""
        ...

    def flush(self):
        """Flush hook; the exact semantics are up to the implementing writer."""
        ...

    def finalize(self):
        """Finalize hook; the exact semantics are up to the implementing writer."""
        ...


class _DataWriter(Protocol):
    """Structural type for writers that consume a single dict of tensors.

    Any object exposing these three methods can be wrapped by the unpaired
    time-coarsening writer in this module.
    """

    def append_batch(
        self,
        data: dict[str, torch.Tensor],
        batch_time: xr.DataArray,
    ):
        """Accept one batch of named tensors plus their times."""
        ...

    def flush(self):
        """Flush hook; the exact semantics are up to the implementing writer."""
        ...

    def finalize(self):
        """Finalize hook; the exact semantics are up to the implementing writer."""
        ...


@dataclasses.dataclass
class TimeCoarsenConfig:
    """
    Config for inference data time coarsening.

    Args:
        coarsen_factor: Factor by which to coarsen in time, an integer 1 or
            greater. The resulting time labels will be coarsened to the mean
            of the original labels.
        method: Method to use for coarsening, currently only "block_mean" is
            supported.

    Raises:
        ValueError: On construction, if ``coarsen_factor`` is less than 1.
    """

    coarsen_factor: int
    method: Literal["block_mean"] = "block_mean"

    def __post_init__(self):
        if self.coarsen_factor < 1:
            raise ValueError(
                f"coarsen_factor must be 1 or greater, got {self.coarsen_factor}"
            )

    def build_paired(self, data_writer: "_PairedDataWriter") -> "PairedTimeCoarsen":
        """Wrap a paired (target/prediction) writer so its input is coarsened."""
        return PairedTimeCoarsen(
            data_writer=data_writer,
            coarsen_factor=self.coarsen_factor,
        )

    def build(self, data_writer: "_DataWriter") -> "TimeCoarsen":
        """Wrap a single-input writer so its input is coarsened."""
        return TimeCoarsen(
            data_writer=data_writer,
            coarsen_factor=self.coarsen_factor,
        )

    def n_coarsened_timesteps(self, n_timesteps: int) -> int:
        """Return the number of timesteps left after coarsening.

        Assumes initial condition is NOT in n_timesteps. The result is the
        floor of ``n_timesteps / coarsen_factor``.
        """
        return n_timesteps // self.coarsen_factor

    def validate(self, forward_steps_in_memory: int, n_forward_steps: int):
        """Check that both step counts are whole multiples of the factor.

        Args:
            forward_steps_in_memory: Step count that must be divisible by
                ``coarsen_factor``.
            n_forward_steps: Step count that must be divisible by
                ``coarsen_factor``.

        Raises:
            ValueError: If either argument is not divisible by
                ``coarsen_factor``.
        """
        if forward_steps_in_memory % self.coarsen_factor != 0:
            raise ValueError(
                "forward_steps_in_memory must be divisible by "
                f"time_coarsen.coarsen_factor. Got {forward_steps_in_memory} "
                f"and {self.coarsen_factor}."
            )
        if n_forward_steps % self.coarsen_factor != 0:
            raise ValueError(
                "n_forward_steps must be divisible by "
                f"time_coarsen.coarsen_factor. Got {n_forward_steps} "
                f"and {self.coarsen_factor}."
            )
@dataclasses.dataclass
class MonthlyCoarsenConfig:
    """
    Config for inference data monthly coarsening.

    Args:
        method: Method to use for coarsening, currently only "monthly_mean" is
            supported.
    """

    method: Literal["monthly_mean"] = "monthly_mean"

    def validate(self, forward_steps_in_memory: int, n_forward_steps: int):
        """Accept any step counts; present for parity with other configs."""
        # monthly coarsening works with any combination of steps
        pass


class PairedTimeCoarsen:
    """Wraps a data writer and coarsens its arguments in time before passing
    them on."""

    def __init__(
        self,
        data_writer: _PairedDataWriter,
        coarsen_factor: int,
    ):
        self._data_writer: _PairedDataWriter = data_writer
        self._coarsen_factor: int = coarsen_factor

    def append_batch(
        self,
        target: dict[str, torch.Tensor],
        prediction: dict[str, torch.Tensor],
        batch_time: xr.DataArray,
    ):
        """Coarsen target, prediction and times, then forward to the writer."""
        target_coarsened, batch_time_coarsened = coarsen_batch(
            target, batch_time, self._coarsen_factor
        )
        # The time labels were already coarsened alongside the target, so only
        # the prediction tensors need coarsening here.
        prediction_coarsened = _coarsen_tensor_dict(
            prediction, self._coarsen_factor
        )
        self._data_writer.append_batch(
            target_coarsened,
            prediction_coarsened,
            batch_time_coarsened,
        )

    def flush(self):
        """Delegate to the wrapped writer's ``flush``."""
        self._data_writer.flush()

    def finalize(self):
        """Delegate to the wrapped writer's ``finalize``."""
        self._data_writer.finalize()


class TimeCoarsen:
    """Wraps a data writer and coarsens its arguments in time before passing
    them on."""

    def __init__(
        self,
        data_writer: _DataWriter,
        coarsen_factor: int,
    ):
        self._data_writer: _DataWriter = data_writer
        self._coarsen_factor: int = coarsen_factor

    def append_batch(
        self,
        data: dict[str, torch.Tensor],
        batch_time: xr.DataArray,
    ):
        """Coarsen the data and times, then forward them to the writer."""
        data_coarsened, batch_time_coarsened = coarsen_batch(
            data, batch_time, self._coarsen_factor
        )
        self._data_writer.append_batch(
            data_coarsened,
            batch_time_coarsened,
        )

    def flush(self):
        """Delegate to the wrapped writer's ``flush``."""
        self._data_writer.flush()

    def finalize(self):
        """Delegate to the wrapped writer's ``finalize``."""
        self._data_writer.finalize()


def coarsen_batch(
    data: dict[str, torch.Tensor],
    batch_time: xr.DataArray,
    coarsen_factor: int,
) -> tuple[dict[str, torch.Tensor], xr.DataArray]:
    """Block-average a batch of tensors and its time labels along time.

    Args:
        data: Named tensors; each is coarsened along axis ``TIME_DIM``.
        batch_time: Time labels; coarsened along the ``TIME_DIM_NAME``
            dimension by taking the mean of each window.
        coarsen_factor: Number of consecutive timesteps per window.

    Returns:
        A tuple of the coarsened tensor dict and the coarsened time labels.
    """
    data_coarsened = _coarsen_tensor_dict(data, coarsen_factor)
    batch_time_coarsened = batch_time.coarsen(
        {TIME_DIM_NAME: coarsen_factor}
    ).mean()
    return data_coarsened, batch_time_coarsened


def _coarsen_tensor_dict(
    tensor_dict: dict[str, torch.Tensor], coarsen_factor: int
) -> dict[str, torch.Tensor]:
    """Coarsen each tensor along the time axis (``TIME_DIM``) by a given factor.

    Each tensor is cut into non-overlapping windows of ``coarsen_factor``
    timesteps and every window is replaced by its mean.
    """
    # unfold appends the window as a new trailing dimension, so averaging over
    # dim=-1 collapses each window into a single timestep.
    return {
        name: tensor.unfold(
            dimension=TIME_DIM, size=coarsen_factor, step=coarsen_factor
        ).mean(dim=-1)
        for name, tensor in tensor_dict.items()
    }