Source code for fme.ace.aggregator.inference.spectrum

import dataclasses
import warnings
from collections import defaultdict
from collections.abc import Callable, Mapping

import matplotlib.pyplot as plt
import numpy as np
import torch
import xarray as xr

from fme.core.dataset.data_typing import VariableMetadata
from fme.core.distributed import Distributed
from fme.core.fill import SmoothFloodFill
from fme.core.gridded_ops import GriddedOperations
from fme.core.metrics import spherical_power_spectrum
from fme.core.typing_ import TensorMapping

from .build_context import MetricBuildContext, MetricNotSupportedError, maybe_filter
from .data import InferenceBatchData, MetricBuildResult, SubAggregator


class SphericalPowerSpectrumAggregator:
    """Average the power spectrum over batch and time dimensions."""

    def __init__(
        self,
        gridded_operations: GriddedOperations,
        nan_fill_fn: Callable[[torch.Tensor, str], torch.Tensor] = lambda x, _: x,
        report_plot: bool = True,
        variable_metadata: Mapping[str, VariableMetadata] | None = None,
    ):
        Distributed.get_instance().require_no_spatial_parallelism(
            "SphericalPowerSpectrumAggregator does not support spatial parallelism."
        )
        self._real_sht = gridded_operations.get_real_sht()
        self._power_spectrum: dict[str, torch.Tensor] = {}
        self._counts: dict[str, int] = defaultdict(int)
        self._nan_fill_fn = nan_fill_fn
        self._report_plot = report_plot
        self._variable_metadata: Mapping[str, VariableMetadata] = (
            variable_metadata or {}
        )

    @torch.no_grad()
    def record_batch(self, data: InferenceBatchData):
        self.record_data(data.prediction)

    @torch.no_grad()
    def record_data(self, prediction: TensorMapping):
        for name in prediction:
            batch_size = prediction[name].shape[0]
            time_size = prediction[name].shape[1]
            tensor = self._nan_fill_fn(prediction[name], name)
            power_spectrum = spherical_power_spectrum(tensor, self._real_sht)
            mean_power_spectrum = torch.mean(power_spectrum, dim=(0, 1))
            new_count = batch_size * time_size
            if name not in self._power_spectrum:
                self._power_spectrum[name] = mean_power_spectrum
            else:
                weighted_average = (
                    new_count * mean_power_spectrum
                    + self._counts[name] * self._power_spectrum[name]
                ) / (new_count + self._counts[name])
                self._power_spectrum[name] = weighted_average
            self._counts[name] += new_count

    def get_mean(self) -> dict[str, torch.Tensor]:
        dist = Distributed.get_instance()
        logs = {}
        sorted_names = sorted(list(self._power_spectrum))
        for name in sorted_names:
            _mean_spectrum = self._power_spectrum[name]
            if dist.world_size > 1:
                # assuming same count on all workers
                _mean_spectrum = dist.reduce_mean(_mean_spectrum)
            logs[name] = _mean_spectrum
        return logs

    @torch.no_grad()
    def get_logs(self, label: str) -> dict[str, plt.Figure]:
        logs: dict[str, plt.Figure] = {}
        if self._report_plot:
            for name, spectrum in self.get_mean().items():
                fig = _plot_spectrum_pair(spectrum.cpu(), target=None)
                logs[f"{label}/{name}"] = fig
                plt.close(fig)
        return logs

    @torch.no_grad()
    def get_dataset(self) -> xr.Dataset:
        data_vars = {}
        for name, spectrum in self.get_mean().items():
            spectrum_np = spectrum.cpu().numpy()
            wavenumber = np.arange(len(spectrum_np))
            metadata = self._variable_metadata.get(name)
            if metadata is not None and metadata.long_name is not None:
                long_name = f"spherical power spectrum of {metadata.long_name}"
                units = f"({metadata.units or 'unknown_units'})^2"
            else:
                long_name = f"spherical power spectrum of {name}"
                units = "unknown_units"
            attrs = {"long_name": long_name, "units": units}
            data_vars[name] = xr.DataArray(
                spectrum_np,
                dims=["wavenumber"],
                coords={"wavenumber": wavenumber},
                attrs=attrs,
            )
        return xr.Dataset(data_vars)


class PairedSphericalPowerSpectrumAggregator:
    """Record batches and return plots for paired prediction and target data.

    Parameters:
        gridded_operations: lat/lon grid helpers used to build the SHT.
        report_plot: master plot toggle. When ``False``, no per-variable
            spectrum-pair PNGs are emitted regardless of ``plot_variables``.
        nan_fill_fn: per-variable NaN filler applied before the SHT.
        variable_metadata: optional CF metadata for spectrum-dataset attrs.
        report_directional_bias: when ``True`` (default), emit
            ``positive_norm_bias`` and ``negative_norm_bias`` scalar metrics
            per variable. When ``False``, only ``mean_abs_norm_bias`` and
            ``smallest_scale_norm_bias`` are emitted — useful when the
            directional split is redundant clutter in the W&B run.
        plot_variables: when ``None`` (default), plot spectrum-pair figures
            for every variable (current behaviour). When a list, restrict
            plotting to those variable names — scalar metrics are still
            emitted for every variable regardless. Lets callers keep the
            cheap scalar comparisons cohort-wide while limiting the
            expensive per-variable plot output to a small reference set
            (e.g. upper-level + h500 + precipitation).
    """

    def __init__(
        self,
        gridded_operations: GriddedOperations,
        report_plot: bool,
        nan_fill_fn: Callable[[torch.Tensor, str], torch.Tensor] = lambda x, _: x,
        variable_metadata: Mapping[str, VariableMetadata] | None = None,
        report_directional_bias: bool = True,
        plot_variables: list[str] | None = None,
    ):
        self._gen_aggregator = SphericalPowerSpectrumAggregator(
            gridded_operations,
            nan_fill_fn,
            report_plot=False,
            variable_metadata=variable_metadata,
        )
        self._target_aggregator = SphericalPowerSpectrumAggregator(
            gridded_operations,
            nan_fill_fn,
            report_plot=False,
            variable_metadata=variable_metadata,
        )
        self._report_plot = report_plot
        self._report_directional_bias = report_directional_bias
        self._plot_variables: set[str] | None = (
            None if plot_variables is None else set(plot_variables)
        )

    @torch.no_grad()
    def record_batch(self, data: InferenceBatchData):
        target = data.target if data.has_target else None
        self.record_paired_data(prediction=data.prediction, target=target)

    @torch.no_grad()
    def record_paired_data(
        self,
        prediction: TensorMapping,
        target: TensorMapping | None,
    ):
        self._gen_aggregator.record_data(prediction)
        if target is not None:
            self._target_aggregator.record_data(target)

    @torch.no_grad()
    def get_logs(self, label: str) -> dict[str, plt.Figure | float]:
        logs: dict[str, plt.Figure | float] = {}
        gen_spectrum = self._gen_aggregator.get_mean()
        target_spectrum = self._target_aggregator.get_mean()
        if self._report_plot:
            for name in gen_spectrum:
                if (
                    self._plot_variables is not None
                    and name not in self._plot_variables
                ):
                    continue
                gen_spectrum_cpu = gen_spectrum[name].cpu()
                if name not in target_spectrum:
                    warnings.warn(f"Missing power spectrum target data for {name}")
                    target_spectrum_cpu = None
                else:
                    target_spectrum_cpu = target_spectrum[name].cpu()
                fig = _plot_spectrum_pair(gen_spectrum_cpu, target_spectrum_cpu)
                logs[f"{label}/{name}"] = fig
                plt.close(fig)
        metrics = _get_spectrum_metrics(
            gen_spectrum,
            target_spectrum,
            report_directional_bias=self._report_directional_bias,
        )
        for name, value in metrics.items():
            logs[f"{label}/{name}"] = value
        return logs

    @torch.no_grad()
    def get_dataset(self) -> xr.Dataset:
        gen_ds = self._gen_aggregator.get_dataset()
        target_ds = self._target_aggregator.get_dataset()
        if not gen_ds or not target_ds:
            return xr.Dataset()
        source = ["prediction", "target"]
        ds = xr.concat([gen_ds, target_ds], dim="source")
        return ds.assign_coords(source=source)


def _get_spectrum_metrics(
    gen_spectrum: dict[str, torch.Tensor],
    target_spectrum: dict[str, torch.Tensor],
    report_directional_bias: bool = True,
) -> dict[str, float]:
    """
    Compute metrics for the spectrum.

    Args:
        gen_spectrum: Dictionary of 1-dimensional generated mean power spectra.
        target_spectrum: Dictionary of 1-dimensional target mean power spectra.
        report_directional_bias: when True (default), emit per-variable
            ``positive_norm_bias`` and ``negative_norm_bias`` keys. When
            False, only ``mean_abs_norm_bias`` and ``smallest_scale_norm_bias``
            are emitted — the directional pair is redundant with
            ``mean_abs_norm_bias``.

    Returns:
        Dictionary of metrics.
    """
    metrics = {}
    for name in gen_spectrum:
        if len(gen_spectrum[name].shape) != 1:
            raise ValueError(
                f"Expected 1-dimensional power spectrum for {name}, "
                f"got {gen_spectrum[name].shape}"
            )
        metrics[f"smallest_scale_norm_bias/{name}"] = get_smallest_scale_power_bias(
            gen_spectrum[name], target_spectrum[name]
        )
        positive_bias, negative_bias = get_positive_and_negative_power_bias(
            gen_spectrum[name], target_spectrum[name]
        )
        if report_directional_bias:
            metrics[f"positive_norm_bias/{name}"] = positive_bias
            metrics[f"negative_norm_bias/{name}"] = negative_bias
        metrics[f"mean_abs_norm_bias/{name}"] = abs(positive_bias) + abs(negative_bias)
    return metrics


def get_smallest_scale_power_bias(
    gen_spectrum: torch.Tensor,
    target_spectrum: torch.Tensor,
) -> float:
    return float((gen_spectrum[-1] / target_spectrum[-1] - 1).mean().cpu())


def get_positive_and_negative_power_bias(
    gen_spectrum: torch.Tensor,
    target_spectrum: torch.Tensor,
) -> tuple[float, float]:
    """
    Compute the positive and negative power bias for the spectrum,
    normalized by the target spectrum.
    """
    ratio = gen_spectrum / target_spectrum - 1
    positive_bias = ratio[ratio > 0].sum() / target_spectrum.shape[0]
    negative_bias = ratio[ratio < 0].sum() / target_spectrum.shape[0]
    return float(positive_bias.cpu()), float(negative_bias.cpu())


def _plot_spectrum_pair(
    prediction: torch.Tensor, target: torch.Tensor | None
) -> plt.Figure:
    fig, ax = plt.subplots()
    ax.plot(prediction, "--", label="prediction", color="C1")
    if target is not None:
        ax.plot(target, "-", label="target", color="C0")
    ax.set(yscale="log")
    ax.set(xlabel="wavenumber")
    ax.legend()
    plt.tight_layout()
    return fig


[docs]@dataclasses.dataclass class PowerSpectrumMetricConfig: """ Parameters: variables: when set, filter the aggregator to these variables only (affects every output — scalar metrics and the per-variable spectrum-pair plots). name: log prefix and wandb key prefix. enabled: master toggle for the metric. strict: raise if the metric can't be built (e.g. wrong grid). report_directional_bias: when False, drop the ``positive_norm_bias`` and ``negative_norm_bias`` scalar metrics. ``mean_abs_norm_bias`` is unaffected (and is the directional pair's redundant summary). Defaults to True for backwards compatibility. plot_variables: when set, restrict the per-variable spectrum-pair plot to these variable names — scalar metrics are still emitted for every variable that passed ``variables``. Use to keep the cheap scalar comparisons cohort-wide while limiting the expensive per-variable plot output to a small reference list. Defaults to None (plot everything that passed ``variables``, current behaviour). """ variables: list[str] | None = None name: str = "power_spectrum" enabled: bool = True strict: bool = False report_directional_bias: bool = True plot_variables: list[str] | None = None def get_name(self) -> str: return self.name def build(self, ctx: MetricBuildContext) -> MetricBuildResult: try: agg: SubAggregator = PairedSphericalPowerSpectrumAggregator( gridded_operations=ctx.ops, nan_fill_fn=SmoothFloodFill(num_steps=4), report_plot=True, variable_metadata=ctx.variable_metadata, report_directional_bias=self.report_directional_bias, plot_variables=self.plot_variables, ) except NotImplementedError as e: raise MetricNotSupportedError(str(e)) from e return MetricBuildResult(aggregator=maybe_filter(agg, self.variables))