# Source code for fme.ace.aggregator.inference.time_mean

import dataclasses
from collections.abc import Mapping, Sequence
from typing import Literal

import matplotlib.pyplot as plt
import torch
import xarray as xr

from fme.core.dataset.data_typing import VariableMetadata
from fme.core.distributed import Distributed
from fme.core.gridded_ops import GriddedOperations
from fme.core.typing_ import TensorDict, TensorMapping
from fme.core.wandb import Image

from ..plotting import plot_paneled_data
from .build_context import MetricBuildContext
from .data import InferenceBatchData, MetricBuildResult, SubAggregator


@dataclasses.dataclass
class _TargetGenPair:
    """A generated field paired with its target for a single variable.

    Provides the elementwise bias and area-weighted summary statistics
    computed through the supplied ``ops``.
    """

    name: str
    target: torch.Tensor
    gen: torch.Tensor
    ops: GriddedOperations

    @staticmethod
    def _scalar(value: torch.Tensor) -> float:
        # Move to host and convert to a plain Python float for logging.
        host_array = value.cpu().numpy()
        return float(host_array)

    def bias(self):
        """Return the elementwise difference ``gen - target``."""
        return self.gen - self.target

    def rmse(self) -> float:
        """Return the area-weighted RMSE of ``gen`` against ``target``."""
        metric = self.ops.area_weighted_rmse(
            predicted=self.gen,
            truth=self.target,
            name=self.name,
        )
        return self._scalar(metric)

    def weighted_mean_bias(self) -> float:
        """Return the area-weighted mean bias of ``gen`` relative to ``target``."""
        metric = self.ops.area_weighted_mean_bias(
            predicted=self.gen,
            truth=self.target,
            name=self.name,
        )
        return self._scalar(metric)


def get_gen_shape(gen_data: TensorMapping):
    """Return the shape of the first tensor in ``gen_data``.

    Returns None when ``gen_data`` has no entries.
    """
    key_iter = iter(gen_data)
    try:
        first_key = next(key_iter)
    except StopIteration:
        return None
    return gen_data[first_key].shape


class TimeMeanAggregator:
    """Accumulates per-variable time-mean maps over inference batches.

    ``record_batch`` adds each batch's values, summed over the sample and
    time dimensions, to running totals. ``get_data`` divides those totals by
    the number of samples and timesteps and averages across distributed
    ranks. Logs contain an image of each time mean and, when
    ``reference_means`` is given, bias/RMSE statistics against it.
    """

    _image_captions = {
        "bias_map": "{name} time-mean bias (generated - reference) [{units}]",
        "gen_map": "{name} time-mean generated [{units}]",
        "gen_target_map": (
            "{name} time-mean; (top) generated and (bottom) target [{units}]"
        ),
    }

    def __init__(
        self,
        gridded_operations: GriddedOperations,
        target: Literal["norm", "denorm"] = "denorm",
        variable_metadata: Mapping[str, VariableMetadata] | None = None,
        reference_means: xr.Dataset | None = None,
        log_variables: frozenset[str] | None = None,
    ):
        """
        Args:
            gridded_operations: Computes gridded operations.
            target: Whether to compute metrics on the normalized or denormalized
                data, defaults to "denorm".
            variable_metadata: Mapping of variable names to their metadata, used
                in generating logged image captions and dataset attributes.
            reference_means: Dataset containing reference time-mean values
                for bias computation.
            log_variables: If provided, only include per-variable entries in
                get_logs and get_dataset for these variables. All variables
                are still recorded and available via get_data.
        """
        self._ops = gridded_operations
        self._target = target
        if variable_metadata is None:
            self._variable_metadata: Mapping[str, VariableMetadata] = {}
        else:
            self._variable_metadata = variable_metadata
        # Running per-variable sums over samples and timesteps (not yet
        # normalized); get_data divides them by the counts below.
        self._data: TensorDict | None = None
        self._n_timesteps = 0
        self._n_samples: int | None = None
        self._reference_means = reference_means
        self._reference_validated = False
        self._log_variables = log_variables

    @staticmethod
    def _add_or_initialize_time_mean(
        maybe_dict: TensorDict | None,
        new_data: TensorMapping,
        ignore_initial: bool = False,
    ) -> TensorDict:
        """Add the sample- and time-summed values of ``new_data`` to running sums.

        Args:
            maybe_dict: Existing running sums, or None to start new ones.
            new_data: Tensors indexed as [sample, time, ...].
            ignore_initial: If True, drop the first timestep before summing.

        Returns:
            A new dict of running sums. When ``maybe_dict`` is given, its
            tensors are updated in place and shared with the returned dict.
        """
        sample_dim = 0
        time_dim = 1
        if ignore_initial:
            time_slice = slice(1, None)
        else:
            time_slice = slice(0, None)
        if maybe_dict is None:
            d: TensorDict = {
                name: tensor[:, time_slice].sum(dim=time_dim).sum(dim=sample_dim)
                for name, tensor in new_data.items()
            }
        else:
            d = dict(maybe_dict)
            for name, tensor in new_data.items():
                d[name] += tensor[:, time_slice].sum(dim=time_dim).sum(dim=sample_dim)
        return d

    @torch.no_grad()
    def record_batch(
        self,
        data: InferenceBatchData,
    ):
        """Add a batch of predictions to the running time-mean sums.

        When ``data.i_time_start == 0``, the first timestep is excluded from
        both the sums and the timestep count.
        """
        if self._target == "denorm":
            tensor_data = data.prediction
        else:
            tensor_data = data.prediction_norm
        ignore_initial = data.i_time_start == 0
        self._data = self._add_or_initialize_time_mean(
            self._data, tensor_data, ignore_initial
        )
        first_tensor = tensor_data[next(iter(tensor_data))]
        if self._n_samples is None:
            self._n_samples = first_tensor.size(0)
        n_time = first_tensor.size(1)
        # Always accumulate: self._data accumulates across every call, so the
        # count must too or the resulting means would be mis-normalized.
        self._n_timesteps += (n_time - 1) if ignore_initial else n_time
        # get_data raises while no timesteps have been counted (e.g. a first
        # window holding only the initial condition), so defer validation of
        # the reference means until there is data to compare against.
        if not self._reference_validated and self._n_timesteps > 0:
            if self._reference_means is not None:
                self.get_logs(label="")
            self._reference_validated = True

    def get_data(self) -> TensorDict:
        """Return per-variable time means, averaged across distributed ranks.

        Raises:
            ValueError: If no timesteps have been recorded.
        """
        if self._n_timesteps == 0 or self._data is None:
            raise ValueError("No data recorded.")

        ret = {}
        dist = Distributed.get_instance()
        names = sorted(list(self._data.keys()))  # sort for rank-consistent order
        for name in names:
            value = self._data[name]
            gen = dist.reduce_mean(value / self._n_timesteps / self._n_samples)
            ret[name] = gen
        return ret

    @torch.no_grad()
    def get_logs(
        self, label: str, target_maps: dict[str, torch.Tensor] | None = None
    ) -> dict[str, float | Image]:
        """Build logs of time-mean maps and, if available, reference statistics.

        Args:
            label: Prefix prepended as "{label}/" to every key; empty for none.
            target_maps: Optional per-variable target maps. For a variable
                present here, the image shows generated (top) and target
                (bottom) panels.

        Returns:
            Mapping from log key to a scalar value or an image.
        """
        logs: dict[str, float | Image] = {}
        data = self.get_data()
        gen_map_key = "gen_map"
        for name, pred in data.items():
            if self._log_variables is not None and name not in self._log_variables:
                continue
            if target_maps is not None and name in target_maps:
                gen_map_caption_key = "gen_target_map"
                data_panels = [[pred.cpu().numpy()], [target_maps[name].cpu().numpy()]]
            else:
                gen_map_caption_key = gen_map_key
                data_panels = [[pred.cpu().numpy()]]
            prediction_image = plot_paneled_data(
                data_panels,
                diverging=False,
                caption=self._get_caption(gen_map_caption_key, name),
            )
            logs.update(
                {
                    f"{gen_map_key}/{name}": prediction_image,
                }
            )
            if self._reference_means is not None and name in self._reference_means:
                pair = _TargetGenPair(
                    name=name,
                    target=torch.as_tensor(
                        self._reference_means[name].values, device=pred.device
                    ),
                    gen=pred,
                    ops=self._ops,
                )
                bias_image = plot_paneled_data(
                    [[pair.bias().cpu().numpy()]],
                    diverging=True,
                    caption=self._get_caption("bias_map", name),
                )
                logs.update({f"ref_bias/{name}": pair.weighted_mean_bias()})
                logs.update({f"ref_rmse/{name}": pair.rmse()})
                logs.update({f"ref_bias_map/{name}": bias_image})

        if len(label) != 0:
            return {f"{label}/{key}": logs[key] for key in logs}
        return logs

    def _get_caption(self, key: str, name: str) -> str:
        """Format the image caption ``key`` for variable ``name``."""
        if name in self._variable_metadata:
            caption_name = self._variable_metadata[name].display_long_name(name)
            units = self._variable_metadata[name].display_units()
        else:
            caption_name, units = name, "unknown_units"
        caption = self._image_captions[key].format(name=caption_name, units=units)
        return caption

    def get_dataset(self) -> xr.Dataset:
        """Return the logged time-mean maps as an xarray Dataset on (lat, lon)."""
        dims = ("lat", "lon")
        data = {}
        for name, pred in self.get_data().items():
            if self._log_variables is not None and name not in self._log_variables:
                continue
            if name in self._variable_metadata:
                long_name = self._variable_metadata[name].display_long_name(name)
                units = self._variable_metadata[name].display_units()
            else:
                long_name = name
                units = "unknown_units"
            gen_metadata = VariableMetadata(long_name=long_name, units=units).as_attrs()
            data.update(
                {
                    f"gen_map-{name}": xr.DataArray(
                        pred.cpu(),
                        dims=dims,
                        attrs=gen_metadata,
                    ),
                }
            )
        return xr.Dataset(data)


class TimeMeanEvaluatorAggregator:
    """Statistics and images on the time-mean state.

    This aggregator keeps track of the time-mean state, then computes
    statistics and images on that time-mean state when logs are retrieved.
    """

    _image_captions = {
        "bias_map": "{name} time-mean bias (generated - target) [{units}]",
    }

    def __init__(
        self,
        ops: GriddedOperations,
        horizontal_dims: list[str],
        target: Literal["norm", "denorm"] = "denorm",
        variable_metadata: Mapping[str, VariableMetadata] | None = None,
        reference_means: xr.Dataset | None = None,
        channel_mean_names: Sequence[str] | None = None,
        log_variables: frozenset[str] | None = None,
    ):
        """
        Args:
            ops: Computes gridded operations.
            horizontal_dims: Names of the horizontal dimensions.
            target: Whether to compute metrics on the normalized or denormalized
                data, defaults to "denorm".
            variable_metadata: Mapping of variable names to their metadata, used
                in generating logged image captions and dataset attributes.
            reference_means: Dataset containing reference time-mean values
                for bias computation.
            channel_mean_names: Names of variables whose RMSE will be averaged. If
                not provided, all available variables will be used.
            log_variables: If provided, only include per-variable entries in
                get_logs and get_dataset for these variables. All variables
                are still recorded so that channel_mean is computed correctly.
        """
        self._ops = ops
        self._horizontal_dims = horizontal_dims
        self._target = target
        self._dist = Distributed.get_instance()
        if variable_metadata is None:
            self._variable_metadata: Mapping[str, VariableMetadata] = {}
        else:
            self._variable_metadata = variable_metadata
        self._log_variables = log_variables
        # Sub-aggregators accumulating time means of the target and of the
        # generated data respectively.
        self._target_agg = TimeMeanAggregator(
            gridded_operations=ops, target=target, variable_metadata=variable_metadata
        )
        self._gen_agg = TimeMeanAggregator(
            gridded_operations=ops,
            target=target,
            variable_metadata=variable_metadata,
            reference_means=reference_means,
            log_variables=log_variables,
        )
        self._channel_mean_names = channel_mean_names

    @torch.no_grad()
    def record_batch(
        self,
        data: InferenceBatchData,
    ):
        """Record target and generated data for this batch.

        Each sub-aggregator receives a copy of ``data`` whose prediction
        fields are replaced by the target or generated tensors selected by
        ``target``.
        """
        if self._target == "norm":
            target_tensor = data.target_norm
            gen_tensor = data.prediction_norm
        else:
            target_tensor = data.target
            gen_tensor = data.prediction
        target_batch = data.replace(
            prediction=target_tensor, prediction_norm=target_tensor
        )
        gen_batch = data.replace(prediction=gen_tensor, prediction_norm=gen_tensor)
        self._target_agg.record_batch(target_batch)
        self._gen_agg.record_batch(gen_batch)

    def _get_target_gen_pairs(
        self, target_data: TensorMapping | None = None
    ) -> list[_TargetGenPair]:
        """Pair each generated time mean with its target time mean.

        Args:
            target_data: Already-reduced target time means to reuse; if None,
                they are fetched from the target sub-aggregator.
        """
        if target_data is None:
            target_data = self._target_agg.get_data()
        gen_data = self._gen_agg.get_data()
        return [
            _TargetGenPair(
                gen=gen_data[name],
                target=target_data[name],
                name=name,
                ops=self._ops,
            )
            for name in gen_data
        ]

    @torch.no_grad()
    def get_logs(self, label: str) -> dict[str, float | torch.Tensor | Image]:
        """Return time-mean images and RMSE/bias statistics.

        Bias maps and mean bias are only logged for the "denorm" target; the
        channel-mean RMSE is only logged for the "norm" target.

        Raises:
            KeyError: If ``channel_mean_names`` names a variable that was not
                recorded.
        """
        # get_data performs a distributed reduction, so reduce the target
        # means once and reuse them for both the map panels and the pairs.
        target_data = self._target_agg.get_data()
        logs = self._gen_agg.get_logs("", target_maps=target_data)
        bias_map_key = "bias_map"
        rmse_all_channels = {}
        for pred in self._get_target_gen_pairs(target_data):
            # RMSE is needed for every variable so channel_mean is complete.
            rmse_all_channels[pred.name] = pred.rmse()
            if self._log_variables is not None and pred.name not in self._log_variables:
                continue
            logs.update({f"rmse/{pred.name}": rmse_all_channels[pred.name]})
            if self._target == "denorm":
                # Only render the bias image when it is actually logged.
                bias_image = plot_paneled_data(
                    [[pred.bias().cpu().numpy()]],
                    diverging=True,
                    caption=self._get_caption(bias_map_key, pred.name),
                )
                plt.close("all")
                logs.update(
                    {
                        f"{bias_map_key}/{pred.name}": bias_image,
                        f"bias/{pred.name}": pred.weighted_mean_bias(),
                    }
                )
        # Release any figures still open, including those created by the
        # generated-data sub-aggregator's map images.
        plt.close("all")
        if self._target == "norm":
            metric_name = "rmse/channel_mean"
            if self._channel_mean_names is None:
                values_to_average = list(rmse_all_channels.values())
            else:
                missing = [
                    n for n in self._channel_mean_names if n not in rmse_all_channels
                ]
                if missing:
                    raise KeyError(
                        f"channel_mean_names contains entries not present in the "
                        f"recorded data: {missing}. Available: "
                        f"{sorted(rmse_all_channels)}."
                    )
                values_to_average = [
                    rmse_all_channels[name] for name in self._channel_mean_names
                ]
            logs.update({metric_name: sum(values_to_average) / len(values_to_average)})

        if len(label) != 0:
            return {f"{label}/{key}": logs[key] for key in logs}
        return logs

    def _get_caption(self, key: str, name: str) -> str:
        """Format the image caption ``key`` for variable ``name``."""
        if name in self._variable_metadata:
            caption_name = self._variable_metadata[name].display_long_name(name)
            units = self._variable_metadata[name].display_units()
        else:
            caption_name, units = name, "unknown_units"
        caption = self._image_captions[key].format(name=caption_name, units=units)
        return caption

    def get_dataset(self) -> xr.Dataset:
        """Return bias and generated time-mean maps for the logged variables."""
        data = {}
        preds = self._get_target_gen_pairs()
        for pred in preds:
            if self._log_variables is not None and pred.name not in self._log_variables:
                continue
            if pred.name in self._variable_metadata:
                long_name = self._variable_metadata[pred.name].display_long_name(
                    pred.name
                )
                units = self._variable_metadata[pred.name].display_units()
            else:
                long_name = pred.name
                units = "unknown_units"
            gen_metadata = VariableMetadata(long_name=long_name, units=units).as_attrs()
            bias_metadata = self._variable_metadata.get(
                pred.name, VariableMetadata(long_name=long_name, units=units)
            ).as_attrs()
            data.update(
                {
                    f"bias_map-{pred.name}": xr.DataArray(
                        pred.bias().cpu(),
                        dims=self._horizontal_dims,
                        attrs=bias_metadata,
                    ),
                    f"gen_map-{pred.name}": xr.DataArray(
                        pred.gen.cpu(),
                        dims=self._horizontal_dims,
                        attrs=gen_metadata,
                    ),
                }
            )
        return xr.Dataset(data)


@dataclasses.dataclass
class TimeMeanMetricConfig:
    """Configuration for the time-mean inference metric.

    Attributes:
        variables: If provided, only these variables get per-variable entries
            in logs and datasets; all variables are still recorded.
        name: Metric name. Defaults to "time_mean_norm" when ``target`` is
            "norm" and "time_mean" otherwise.
        target: Whether to compute metrics on normalized ("norm") or
            denormalized ("denorm") data.
        reference_data: Optional path to a dataset of reference time means,
            opened with ``xr.open_dataset``. Takes precedence over the build
            context's reference data.
        channel_mean_names: Variables whose RMSE is averaged into the
            channel-mean metric (norm target only). Falls back to the build
            context's names when None or empty.
        enabled: Not read within this class; presumably consumed by the
            caller that decides whether to build the metric — confirm.
        strict: Not read within this class; presumably consumed elsewhere —
            confirm.
    """

    variables: list[str] | None = None
    name: str | None = None
    target: Literal["denorm", "norm"] = "denorm"
    reference_data: str | None = None
    channel_mean_names: list[str] | None = None
    enabled: bool = True
    strict: bool = False

    def __post_init__(self):
        if self.name is None:
            self.name = "time_mean_norm" if self.target == "norm" else "time_mean"

    def get_name(self) -> str:
        """Return the metric name (always set by ``__post_init__``)."""
        return self.name  # type: ignore[return-value]

    def build(self, ctx: MetricBuildContext) -> MetricBuildResult:
        """Build the time-mean evaluator aggregator from the build context."""
        is_norm = self.target == "norm"
        # An explicit reference file wins; otherwise the context's reference
        # means are only used for denormalized metrics.
        if self.reference_data is not None:
            ref = xr.open_dataset(self.reference_data, decode_timedelta=False)
        elif not is_norm:
            ref = ctx.time_mean_reference_data
        else:
            ref = None
        agg: SubAggregator = TimeMeanEvaluatorAggregator(
            ctx.ops,
            horizontal_dims=ctx.horizontal_coordinates.dims,
            target=self.target,
            variable_metadata=ctx.variable_metadata,
            reference_means=ref,
            channel_mean_names=(
                (self.channel_mean_names or ctx.channel_mean_names) if is_norm else None
            ),
            log_variables=(
                frozenset(self.variables) if self.variables is not None else None
            ),
        )
        return MetricBuildResult(aggregator=agg)