Source code for fme.ace.aggregator.inference.seasonal

import dataclasses
import logging
from collections.abc import Mapping
from typing import Any, cast

import numpy as np
import torch
import xarray as xr

from fme.ace.aggregator.plotting import plot_paneled_data
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.device import get_device
from fme.core.distributed import Distributed
from fme.core.gridded_ops import GriddedOperations
from fme.core.typing_ import TensorMapping
from fme.core.wandb import Image

from .build_context import MetricBuildContext, maybe_filter
from .data import InferenceBatchData, MetricBuildResult, SubAggregator


class SeasonalAggregator:
    """Aggregator for seasonal anomaly plots, seasonal bias plots and RMSE.

    Each recorded batch is grouped by the season of its ``valid_time``
    coordinate and accumulated as running sums, together with a ``counts``
    variable holding the number of (sample, time) points per season. Seasonal
    means are formed from these sums in :meth:`get_logs`, after summing across
    processes when running distributed.
    """

    def __init__(
        self,
        ops: GriddedOperations,
        variable_metadata: Mapping[str, VariableMetadata] | None = None,
    ):
        """
        Args:
            ops: Gridded operations; ``ops.area_weighted_mean`` is used to
                compute the area-weighted mean of the squared seasonal bias.
            variable_metadata: Optional per-variable metadata, used only to
                build plot captions (long name and units).
        """
        self._area_weighted_mean = ops.area_weighted_mean
        self._variable_metadata = variable_metadata
        # Running per-season sums (plus "counts"); None until the first batch.
        self._target_dataset: xr.Dataset | None = None
        self._gen_dataset: xr.Dataset | None = None

    @staticmethod
    def _accumulate(total: xr.Dataset | None, batch: xr.Dataset) -> xr.Dataset:
        """Add the per-season sums of ``batch`` to the running ``total``."""
        # valid_time has dims (sample, time), so grouping by its season stacks
        # both dims into "stacked_sample_time". Summing over it pools every
        # sample and time step of the season; "counts" records how many points
        # were summed so that means can be taken later.
        seasonal_sum = batch.groupby(batch.valid_time.dt.season).sum(
            dim="stacked_sample_time", skipna=False
        )
        if total is None:
            return seasonal_sum
        return _add_dataarray(total, seasonal_sum)

    @staticmethod
    def _has_all_seasons(dataset: xr.Dataset) -> bool:
        """Whether every season is present with at least one recorded point.

        ``_add_dataarray`` zero-fills seasons missing from a batch, so a season
        can appear in the ``season`` coordinate with a count of zero; the
        coordinate length alone therefore does not prove a season was recorded.
        """
        if len(dataset.season) < len(ALL_SEASONS):
            return False
        return bool((dataset["counts"] > 0).all())

    @torch.no_grad()
    def record_batch(
        self,
        data: InferenceBatchData,
    ):
        """Record a batch of data for computing seasonal statistics.

        Args:
            data: Batch providing ``target`` and ``prediction`` tensors and
                their ``time`` coordinate. Tensors are moved to CPU, grouped
                by season and added to the running sums.
        """
        time = data.time
        target_ds = _to_dataset(
            {name: value.cpu() for name, value in data.target.items()}, time
        )
        gen_ds = _to_dataset(
            {name: value.cpu() for name, value in data.prediction.items()}, time
        )
        self._target_dataset = self._accumulate(self._target_dataset, target_ds)
        self._gen_dataset = self._accumulate(self._gen_dataset, gen_ds)

    @torch.no_grad()
    def get_logs(self, label: str) -> dict[str, Any]:
        """Compute seasonal anomaly/bias plots and time-mean RMSE metrics.

        When running with more than one process, the accumulated sums are
        first reduced across processes, and only the root process returns logs.

        Args:
            label: Prefix for every log key; "/" is appended when non-empty.

        Returns:
            Mapping from log key to image or float. It is empty on non-root
            processes, or when some season has no recorded data.

        Raises:
            ValueError: If no batch has been recorded.
            RuntimeError: If the reduced data is unexpectedly missing on root.
        """
        if self._target_dataset is None or self._gen_dataset is None:
            raise ValueError("No data has been recorded yet.")
        dist = Distributed.get_instance()
        if dist.world_size > 1:
            target = _reduce_datasets(dist, self._target_dataset)
            gen = _reduce_datasets(dist, self._gen_dataset)
        else:
            target = self._target_dataset
            gen = self._gen_dataset
        if not dist.is_root():
            return {}
        if target is None or gen is None:
            raise RuntimeError("we are on root but no data was collected")

        # Seasonal metrics are undefined unless every season has data. Checking
        # counts (not only the coordinate length) avoids 0/0 = NaN means for
        # zero-filled seasons.
        if not (self._has_all_seasons(gen) and self._has_all_seasons(target)):
            return {}

        # Convert sums to means; "counts" itself becomes 1 and is skipped below.
        target = cast(xr.Dataset, target / target["counts"])  # type: ignore
        gen = cast(xr.Dataset, gen / gen["counts"])  # type: ignore
        bias = gen - target
        plots: dict[str, Image] = {}
        metric_logs: dict[str, float] = {}
        panel_seasons = ("DJF", "MAM", "JJA", "SON")

        for name in gen.data_vars:
            if name == "counts":
                continue

            if self._variable_metadata is not None and name in self._variable_metadata:
                long_name = self._variable_metadata[name].display_long_name(name)
                units = self._variable_metadata[name].display_units()
                caption_name = f"{long_name} ({units})"
            else:
                caption_name = name

            # Anomalies are taken relative to the target's mean over seasons,
            # for both target and prediction.
            target_mean_pattern = target[name].mean(dim="season")
            gen_anomaly = gen[name] - target_mean_pattern
            target_anomaly = target[name] - target_mean_pattern
            r2 = get_r2(gen_anomaly, target_anomaly)

            image = plot_paneled_data(
                [
                    [target_anomaly.sel(season=s).values for s in panel_seasons],
                    [gen_anomaly.sel(season=s).values for s in panel_seasons],
                ],
                diverging=True,
                caption=(
                    f"Seasonal time-mean anomaly of {caption_name} for target (top) "
                    f"and gen (bottom) starting with DJF, R2={r2:.4f}. "
                    "Time-mean of target is subtracted from predictions and target."
                ),
            )
            plots[f"anomaly/{name}"] = image

            image_err = plot_paneled_data(
                [
                    [
                        bias[name].sel(season="DJF").values,
                        bias[name].sel(season="MAM").values,
                    ],
                    [
                        bias[name].sel(season="JJA").values,
                        bias[name].sel(season="SON").values,
                    ],
                ],
                diverging=True,
                caption=(
                    f"Seasonal bias of {caption_name} for DJF (Upper-Left), "
                    "MAM (UR), JJA (LL), and SON (LR). "
                    f"Seasonal anomaly R2={r2:.4f} (excludes time-mean of target)."
                ),
            )
            plots[f"bias/{name}"] = image_err

            mse_tensor = self._area_weighted_mean(
                torch.as_tensor(bias[name].values ** 2),
                name=name,
            )
            # mse_tensor is indexed in the same season order as bias[name].
            for i, season in enumerate(bias[name].season.values):
                rmse = float(mse_tensor[i].sqrt().numpy())
                metric_logs[f"time-mean-rmse/{name}-{season}"] = rmse
            rmse = float(
                # must compute area mean and then mean across seasons
                # before sqrt, so we can't use metrics.root_mean_squared_error
                mse_tensor.mean().sqrt().numpy()
            )
            metric_logs[f"time-mean-rmse/{name}"] = rmse

        prefix = f"{label}/" if len(label) > 0 else ""
        logs: dict[str, Image | float] = {}
        logs.update({f"{prefix}{key}": image for key, image in plots.items()})
        logs.update({f"{prefix}{key}": value for key, value in metric_logs.items()})
        return logs

    def get_dataset(self) -> xr.Dataset:
        """Return an empty dataset; dataset output is not supported here."""
        logging.debug(
            "get_dataset not implemented for SeasonalAggregator. "
            "Returning an empty dataset."
        )
        return xr.Dataset()


# Canonical season order (DJF first); per-season accumulators are reindexed
# onto these labels when some seasons are missing.
ALL_SEASONS = np.array(("DJF", "MAM", "JJA", "SON"))


def _add_dataarray(
    da1: xr.Dataset | xr.DataArray, da2: xr.Dataset | xr.DataArray
) -> xr.Dataset | xr.DataArray:
    """
    Add two per-season accumulators, treating missing seasons as zero.

    Both operands are reindexed onto ``ALL_SEASONS``, with missing seasons
    filled with 0, before adding. The result therefore always holds all four
    seasons in the canonical DJF, MAM, JJA, SON order, regardless of which
    seasons each operand contained or the order they were in. In this module
    it is called with the per-season Datasets built in
    ``SeasonalAggregator.record_batch``.
    """
    da1 = da1.reindex(season=ALL_SEASONS, fill_value=0)
    da2 = da2.reindex(season=ALL_SEASONS, fill_value=0)
    return da1 + da2


def get_r2(da: xr.DataArray, target: xr.DataArray) -> float:
    """Return the coefficient of determination (R2) of ``da`` against ``target``.

    ``target`` is the reference. R2 = 1 - SS_res / SS_tot, where SS_res is the
    sum of squared differences ``da - target`` and SS_tot is the sum of squared
    deviations of ``target`` from its mean. All elements are pooled into a
    single score.
    """
    reference = target.values
    total_sum_of_squares = np.sum((reference - np.mean(reference)) ** 2)
    residual = (da - target).values
    residual_sum_of_squares = np.sum(residual ** 2)
    return float(1 - residual_sum_of_squares / total_sum_of_squares)


def _reduce_datasets(dist: Distributed, dataset: xr.Dataset) -> xr.Dataset | None:
    """
    Add the per-season dataset across all processes.

    The variables are packed into one tensor and summed positionally, so every
    process must present the same seasons in the same order. Each process's
    dataset is therefore first reindexed onto ``ALL_SEASONS``, with missing
    seasons filled with zeros. Without this step, processes that observed
    different seasons, or that hold them in a different order, would produce
    mismatched shapes or sum different seasons together.

    Seasons whose total count across processes is zero are dropped from the
    result, so the length of its ``season`` coordinate reflects only seasons
    that were actually recorded.

    Requires all dataset variables other than "counts" to have the same shape,
    with dims ("season", "lat", "lon").

    Args:
        dist: Distributed context whose ``reduce_sum`` performs the summation.
        dataset: Per-season sums on this process, including a "counts" variable.

    Returns:
        Dataset of summed variables and "counts", restricted to seasons with a
        non-zero total count.

    Raises:
        ValueError: If the non-"counts" variables do not all share one shape.
    """
    # Give every process an identical season axis before the positional sum.
    dataset = dataset.reindex(season=ALL_SEASONS, fill_value=0)
    # collect all data into one torch.Tensor for gathering, sort for determinism
    names = sorted(list(dataset.data_vars))
    # 'counts' must be present in the data, but we don't want to pack it with the others
    names.remove("counts")
    for name in names:
        if dataset[name].shape != dataset[names[0]].shape:
            raise ValueError(
                f"Variable {name} has shape {dataset[name].shape} "
                f"which is not equal to {dataset[names[0]].shape}"
            )
    tensor = torch.stack(
        [torch.as_tensor(dataset[name].values) for name in names],
        dim=0,
    ).to(get_device())
    reduced = dist.reduce_sum(tensor).cpu()
    reduced_counts = dist.reduce_sum(
        torch.as_tensor(dataset["counts"].values).to(get_device())
    ).cpu()
    dataset_out = xr.Dataset(
        {
            name: (["season", "lat", "lon"], reduced[i].numpy())
            for i, name in enumerate(names)
        },
        coords=dataset.coords,
    )
    counts = reduced_counts.numpy()
    dataset_out["counts"] = xr.DataArray(counts, dims=["season"])
    # Drop seasons that no process observed, so callers can detect incomplete
    # seasonal coverage from the season coordinate's length.
    return dataset_out.isel(season=np.flatnonzero(counts > 0))


@torch.no_grad()
def _to_dataset(data: TensorMapping, time: xr.DataArray) -> xr.Dataset:
    """Convert a mapping of tensors to an xarray dataset.

    Each tensor is given dims ("sample", "time", "lat", "lon"). A "counts"
    variable of ones with dims ("sample", "time") is added, so summing it over
    a group gives the number of points in that group.

    Args:
        data: Mapping from variable name to tensor. The tensor is assumed to be
            4-D and ordered (sample, time, lat, lon) — the dims are assigned by
            name here, not checked.
        time: Valid time of each (sample, time) point; attached as the
            ``valid_time`` coordinate.

    Returns:
        Dataset with one variable per entry of ``data`` plus "counts".

    Raises:
        ValueError: If ``time`` does not have dims ("sample", "time").
    """
    # Explicit check instead of ``assert`` so it still runs under ``python -O``;
    # the dims must be consistent with the names hard-coded below.
    if time.dims != ("sample", "time"):
        raise ValueError(
            f"Expected time to have dims ('sample', 'time'), got {time.dims}"
        )
    data_vars: dict[str, tuple[list[str], Any]] = {
        name: (["sample", "time", "lat", "lon"], tensor)
        for name, tensor in data.items()
    }
    data_vars["counts"] = (["sample", "time"], np.ones(shape=time.shape))
    return xr.Dataset(data_vars, coords={"valid_time": time})


@dataclasses.dataclass
class SeasonalMetricConfig:
    """Configuration for the seasonal inference metric.

    Attributes:
        variables: Variable names passed to ``maybe_filter`` when wrapping the
            built aggregator.
        name: Metric name returned by :meth:`get_name`.
        enabled: Not read by this class; presumably consulted by the caller
            deciding which metrics to build — verify.
        strict: Not read by this class; presumably consumed elsewhere — verify.
    """

    variables: list[str] | None = None
    name: str = "seasonal"
    enabled: bool = False
    strict: bool = True

    def get_name(self) -> str:
        """Return the configured metric name."""
        return self.name

    def build(self, ctx: MetricBuildContext) -> MetricBuildResult:
        """Build a ``SeasonalAggregator`` from the build context.

        The aggregator is created from ``ctx.ops`` and ``ctx.variable_metadata``
        and wrapped with ``maybe_filter`` using ``self.variables``.
        """
        agg: SubAggregator = SeasonalAggregator(
            ops=ctx.ops,
            variable_metadata=ctx.variable_metadata,
        )
        return MetricBuildResult(aggregator=maybe_filter(agg, self.variables))