# Source code for fme.ace.aggregator.inference.ipo.ipo_index

import dataclasses
import datetime
import logging
from typing import Any

import cftime
import numpy as np
import torch
import xarray as xr
from matplotlib import pyplot as plt
from scipy import signal

from fme.core.coordinates import LatLonCoordinates
from fme.core.distributed import Distributed
from fme.core.typing_ import TensorDict

from ...plotting import plot_mean_and_samples
from ..build_context import MetricBuildContext, MetricNotSupportedError
from ..data import InferenceBatchData, MetricBuildResult
from ..utils import (
    LatLonRegion,
    UniqueMonths,
    _calculate_sample_average_power_spectrum,
    _compute_sample_mean_std,
    anomalies_from_monthly_climo,
    convert_cftime_to_datetime_coord,
    running_monthly_mean,
)

# Axis indices of the sample and time dimensions in the accumulated
# (sample, time) regional-mean tensors.
SAMPLE_DIM = 0
TIME_DIM = 1

# SST variable names used when no explicit list is configured.
DEFAULT_SST_NAMES = ["sst"]

# Minimum record length, in years, required before the low-pass filtered
# TPI is computed.
MIN_YEARS_FOR_FILTERED_TPI = 80

# Latitude / longitude boxes (degrees, longitudes in the 0-360 convention)
# defining the three TPI regions.
TPI_REGIONS = dict(
    T1=dict(lat_bounds=(25.0, 45.0), lon_bounds=(140.0, 215.0)),
    T2=dict(lat_bounds=(-10.0, 10.0), lon_bounds=(170.0, 270.0)),
    T3=dict(lat_bounds=(-50.0, -15.0), lon_bounds=(150.0, 200.0)),
)


def _nan_aware_regional_mean(data: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Area-weighted mean over the last two (lat, lon) axes, skipping NaNs.

    NaN grid points are removed from both the weighted sum and the
    normalizing weight total, so the mean is taken over valid points only.

    Args:
        data: Tensor of shape (sample, time, lat, lon), may contain NaN.
        weights: Tensor of shape (lat, lon) with regional area weights.

    Returns:
        Tensor of shape (sample, time). Where no valid point carries any
        weight the result is 0 / 0, i.e. NaN.
    """
    is_valid = torch.isnan(data).logical_not()
    # Add leading singleton axes so the weights broadcast over (sample, time).
    area_weights = weights.to(data.device)[None, None]
    valid_weights = area_weights * is_valid
    zero_filled = torch.where(is_valid, data, torch.zeros_like(data))
    spatial_dims = (-2, -1)
    numerator = (zero_filled * valid_weights).sum(dim=spatial_dims)
    denominator = valid_weights.sum(dim=spatial_dims)
    return numerator / denominator


def low_pass_filter(
    data: np.ndarray,
    cutoff_period_yrs: float = 13.0,
    sampling_freq_per_yr: int = 12,
    filter_order: int = 5,
    passband_ripple_db: float = 0.5,
) -> np.ndarray:
    """Apply a Chebyshev Type I low-pass filter (zero-phase) to a 1-d array.

    The filter is designed and applied as cascaded second-order sections
    (SOS). With monthly data and a multi-year cutoff, the normalized cutoff
    frequency is tiny (1/13 cycles/yr vs. a Nyquist of 6 cycles/yr by
    default). SciPy documents the transfer-function ``(b, a)`` form as
    numerically sensitive in that regime, and the SOS form avoids it.

    The filter is run forward and then backward, so there is no phase shift
    and the effective magnitude response is the square of a single pass.

    Args:
        data: Input time series (monthly values).
        cutoff_period_yrs: Cutoff period in years.
        sampling_freq_per_yr: Number of samples per year.
        filter_order: Filter order.
        passband_ripple_db: Passband ripple in dB.

    Returns:
        Filtered array of same shape as input.

    Raises:
        ValueError: Raised by scipy if the normalized cutoff is not strictly
            between 0 and 1, or if ``data`` is too short for the
            forward-backward edge padding.
    """
    nyquist_freq = 0.5 * sampling_freq_per_yr
    cutoff_freq = 1.0 / cutoff_period_yrs
    # scipy expects digital critical frequencies normalized to Nyquist.
    wn = cutoff_freq / nyquist_freq

    sos = signal.cheby1(
        N=filter_order,
        rp=passband_ripple_db,
        Wn=wn,
        btype="low",
        analog=False,
        output="sos",
    )
    return signal.sosfiltfilt(sos, data)


class _IPORegionalAccumulator:
    """Accumulates area-weighted regional means for the three TPI regions.

    Uses NaN-aware weighted mean so that ocean-only SST fields (with NaN
    over land) are handled correctly.

    Per-batch regional means and time coordinates are buffered in lists and
    only concatenated when read (``_raw_means`` / ``_raw_times``). This keeps
    recording linear in the number of batches instead of re-copying the
    whole accumulated history on every batch.
    """

    def __init__(
        self,
        regions: dict[str, LatLonRegion],
        sst_names: list[str] | None = None,
    ):
        """
        Args:
            regions: Mapping from region name to region definition. The TPI
                formula in ``get_tpi_indices`` reads the names "T1", "T2"
                and "T3".
            sst_names: SST variable names to process. Defaults to
                DEFAULT_SST_NAMES.
        """
        self._regions = regions
        self._sst_names = sst_names if sst_names is not None else DEFAULT_SST_NAMES
        # region name -> SST name -> list of per-batch (sample, time) means.
        self._mean_chunks: dict[str, dict[str, list[torch.Tensor]]] = {
            region: {} for region in regions
        }
        # Per-batch time coordinates, in recording order.
        self._time_chunks: list[xr.DataArray] = []
        self._calendar: str | None = None
        # SST names already reported as missing, so each is logged only once.
        self._already_logged: set[str] = set()

    @property
    def _raw_means(self) -> dict[str, TensorDict]:
        """Regional means per region and SST name, concatenated along time."""
        raw_means: dict[str, TensorDict] = {}
        for region_name, by_sst in self._mean_chunks.items():
            raw_means[region_name] = {}
            for sst_name, chunks in by_sst.items():
                if len(chunks) > 1:
                    # Collapse in place so repeated reads don't re-concatenate.
                    chunks[:] = [torch.cat(chunks, dim=TIME_DIM)]
                raw_means[region_name][sst_name] = chunks[0]
        return raw_means

    @property
    def _raw_times(self) -> xr.DataArray | None:
        """All recorded time coordinates concatenated along "time", or None."""
        if not self._time_chunks:
            return None
        if len(self._time_chunks) > 1:
            self._time_chunks = [xr.concat(self._time_chunks, dim="time")]
        return self._time_chunks[0]

    def record_batch(self, data: InferenceBatchData) -> None:
        """Record regional SST means and time coordinates for one batch.

        Only ``data.prediction`` and ``data.time`` are read. SST variables
        absent from ``data.prediction`` are skipped and logged once per name.
        """
        time = data.time
        prediction = data.prediction
        for sst_name in self._sst_names:
            if sst_name not in prediction:
                if sst_name not in self._already_logged:
                    logging.info(
                        f"Variable {sst_name} not found in data. "
                        "Skipping IPO TPI computation for this variable."
                    )
                    self._already_logged.add(sst_name)
                continue
            for region_name, region in self._regions.items():
                regional_avg = _nan_aware_regional_mean(
                    prediction[sst_name], region.regional_weights
                )
                self._mean_chunks[region_name].setdefault(sst_name, []).append(
                    regional_avg
                )
        self._time_chunks.append(time)
        if self._calendar is None and time.dt.calendar is not None:
            self._calendar = time.dt.calendar

    def get_tpi_indices(self) -> xr.Dataset:
        """Compute TPI = T2_anom - 0.5 * (T1_anom + T3_anom) for each SST var.

        Returns monthly TPI (not low-pass filtered). Variables missing from
        any region are omitted. On ranks where the distributed gather yields
        nothing, the variable is also omitted.
        """
        raw_means = self._raw_means
        raw_times = self._raw_times
        indices = {}
        for sst_name in self._sst_names:
            if not all(sst_name in raw_means[r] for r in self._regions):
                continue

            regional_anomalies = {}
            for region_name in self._regions:
                raw = raw_means[region_name][sst_name]
                anom = anomalies_from_monthly_climo(raw, raw_times)
                monthly_mean, unique_months = running_monthly_mean(
                    anom, raw_times, n_months=1
                )
                regional_anomalies[region_name] = monthly_mean

            tpi = regional_anomalies["T2"] - 0.5 * (
                regional_anomalies["T1"] + regional_anomalies["T3"]
            )

            gathered_tpi = self._gather_index(tpi, unique_months)
            if gathered_tpi is not None:
                indices[sst_name] = gathered_tpi
        return xr.Dataset(indices)

    def _gather_index(
        self, index: torch.Tensor, unique_months: UniqueMonths
    ) -> xr.DataArray | None:
        """Gather the index and its (year, month) labels across ranks.

        With a single process, the local tensors are used directly. Returns
        None when any distributed gather returns None (presumably on
        non-root ranks; verify against ``Distributed.gather_irregular``).
        """
        dist = Distributed.get_instance()
        if dist.world_size > 1:
            gathered_index = dist.gather_irregular(index)
            gathered_years = dist.gather_irregular(unique_months.years)
            gathered_months = dist.gather_irregular(unique_months.months)
            if (
                gathered_index is None
                or gathered_years is None
                or gathered_months is None
            ):
                return None
        else:
            gathered_index = [index]
            gathered_years = [unique_months.years]
            gathered_months = [unique_months.months]
        return self._to_data_array(gathered_index, gathered_years, gathered_months)

    def _to_data_array(
        self,
        indices: list[torch.Tensor],
        years: list[torch.Tensor],
        months: list[torch.Tensor],
    ) -> xr.DataArray:
        """Stack gathered (sample, time) index chunks into one DataArray.

        Each month is stamped as a cftime date on day 15 in the recorded
        calendar ("standard" if none was recorded). Chunks are concatenated
        along "sample".
        """
        calendar = self._calendar if self._calendar is not None else "standard"
        index_data_arrays = []
        for index, year, month in zip(indices, years, months):
            time_coord = [
                cftime.datetime(
                    single_year.item(), single_month.item(), 15, calendar=calendar
                )
                for single_year, single_month in zip(year.cpu(), month.cpu())
            ]
            index_data_arrays.append(
                xr.DataArray(
                    data=index.cpu().numpy(),
                    dims=["sample", "time"],
                    coords={"time": time_coord},
                )
            )
        return xr.concat(index_data_arrays, dim="sample")


class PairedIPOIndexAggregator:
    """Paired (target, prediction) aggregator for IPO Tripolar Index.

    Computes the TPI (Henley et al. 2017) from model SST output, applies a
    Chebyshev low-pass filter (13-year cutoff by default, configurable via
    ``cutoff_period_yrs``), and reports:
    - Filtered TPI time series plot (prediction vs target)
    - Power spectrum of unfiltered monthly TPI
    - Scalar std of the filtered prediction TPI, plus a variant computed
      relative to the filtered target TPI
    """

    def __init__(
        self,
        lat: torch.Tensor,
        lon: torch.Tensor,
        cutoff_period_yrs: float = 13.0,
        sst_names: list[str] | None = None,
    ):
        """
        Args:
            lat: Latitude coordinate used to build the TPI region weights.
            lon: Longitude coordinate used to build the TPI region weights.
            cutoff_period_yrs: Low-pass cutoff period in years. The same
                number of years is trimmed from each end of the filtered
                series.
            sst_names: SST variable names to process. Defaults to
                DEFAULT_SST_NAMES.
        """
        regions = {
            name: LatLonRegion(
                lat=lat,
                lon=lon,
                lat_bounds=spec["lat_bounds"],
                lon_bounds=spec["lon_bounds"],
            )
            for name, spec in TPI_REGIONS.items()
        }
        self._sst_names = sst_names if sst_names is not None else DEFAULT_SST_NAMES
        self._target_accumulator = _IPORegionalAccumulator(
            regions, sst_names=self._sst_names
        )
        self._prediction_accumulator = _IPORegionalAccumulator(
            regions, sst_names=self._sst_names
        )
        self._cutoff_period_yrs = cutoff_period_yrs

    def record_batch(self, data: InferenceBatchData) -> None:
        """Record one batch into both the target and prediction accumulators."""
        # The accumulators only read ``prediction``, so route the target
        # through that field for the target accumulator.
        self._target_accumulator.record_batch(data.replace(prediction=data.target))
        self._prediction_accumulator.record_batch(data)

    def get_logs(self, label: str) -> dict[str, Any]:
        """Return figures and scalar metrics for every available SST variable.

        A variable produces outputs only when both prediction and target TPI
        exist and both could be low-pass filtered (see
        ``_apply_filter_to_samples``).

        Args:
            label: Key prefix. When non-empty, keys have the form
                ``"{label}/{sst_name}_ipo_tpi_..."``.

        Returns:
            Mapping from log key to a matplotlib figure or to the value
            returned by ``_compute_sample_mean_std``.
        """
        target_tpi = self._target_accumulator.get_tpi_indices()
        prediction_tpi = self._prediction_accumulator.get_tpi_indices()
        logs: dict[str, Any] = {}

        for sst_name in self._sst_names:
            if sst_name not in prediction_tpi or sst_name not in target_tpi:
                continue
            pred_da = prediction_tpi[sst_name]
            tgt_da = target_tpi[sst_name]
            if pred_da.sizes["time"] < 2:
                continue

            pred_filtered = self._apply_filter_to_samples(pred_da)
            tgt_filtered = self._apply_filter_to_samples(tgt_da)

            if pred_filtered is not None and tgt_filtered is not None:
                fig = self._plot_filtered_tpi(pred_filtered, tgt_filtered)
                logs[f"{sst_name}_ipo_tpi_filtered"] = fig

                logs[f"{sst_name}_ipo_tpi_std"] = _compute_sample_mean_std(
                    pred_filtered
                )
                logs[f"{sst_name}_ipo_tpi_std_norm"] = _compute_sample_mean_std(
                    pred_filtered, tgt_filtered
                )

                fig = self._plot_power_spectrum(pred_da, tgt_da, sst_name)
                logs[f"{sst_name}_ipo_tpi_power_spectrum"] = fig

        if len(label) > 0:
            label = label + "/"
        return {f"{label}{k}": v for k, v in logs.items()}

    def get_dataset(self) -> xr.Dataset:
        """Return monthly (unfiltered) TPI stacked along a "source" dimension.

        "target" comes first and "prediction" second. An empty Dataset is
        returned if either side has no indices.
        """
        prediction_tpi = self._prediction_accumulator.get_tpi_indices()
        target_tpi = self._target_accumulator.get_tpi_indices()
        if len(prediction_tpi) == 0 or len(target_tpi) == 0:
            return xr.Dataset()
        return xr.concat(
            [
                target_tpi.expand_dims({"source": ["target"]}),
                prediction_tpi.expand_dims({"source": ["prediction"]}),
            ],
            dim="source",
        )

    def _apply_filter_to_samples(self, tpi_da: xr.DataArray) -> xr.DataArray | None:
        """Apply low-pass filter to each sample, trimming edge transients.

        Trims one cutoff period from each end to remove filtfilt edge
        artifacts.

        Returns None if any sample is too short. A sample is too short if it
        has fewer than MIN_YEARS_FOR_FILTERED_TPI years of non-NaN monthly
        values, or too few to leave anything after trimming. Also returns
        None if there are no samples.
        """
        trim = int(self._cutoff_period_yrs * 12)
        # Besides the minimum record length, make sure a long cutoff cannot
        # trim the whole series away.
        min_length = max(MIN_YEARS_FOR_FILTERED_TPI * 12, 2 * trim + 1)
        filtered_samples = []
        for sample in range(tpi_da.sizes["sample"]):
            sample_data = tpi_da.isel(sample=sample).dropna("time")
            n_time = sample_data.sizes["time"]
            if n_time < min_length:
                return None
            filtered_values = low_pass_filter(
                sample_data.values, cutoff_period_yrs=self._cutoff_period_yrs
            )
            # Explicit stop index: ``[trim:-trim]`` would be empty for trim=0.
            keep = slice(trim, n_time - trim)
            filtered_samples.append(
                xr.DataArray(
                    filtered_values[keep],
                    coords={"time": sample_data.time[keep]},
                    dims=sample_data.dims,
                )
            )
        if not filtered_samples:
            return None
        return xr.concat(filtered_samples, dim="sample")

    def _plot_filtered_tpi(
        self, prediction: xr.DataArray, target: xr.DataArray
    ) -> plt.Figure:
        """Plot filtered prediction (mean and samples) against the target mean."""
        fig, ax = plt.subplots(1, 1)
        pred_plottable = prediction.assign_coords(
            {"time": convert_cftime_to_datetime_coord(prediction.time)}
        )
        tgt_plottable = target.assign_coords(
            {"time": convert_cftime_to_datetime_coord(target.time)}
        )
        plot_mean_and_samples(
            ax,
            pred_plottable,
            "predicted ensemble mean",
            time_series_dim="time",
        )
        plot_mean_and_samples(
            ax,
            tgt_plottable,
            "target",
            time_series_dim="time",
            color="orange",
            plot_samples=False,
        )
        # Report the configured cutoff rather than a hard-coded 13 years.
        ax.set_title(f"IPO TPI ({self._cutoff_period_yrs:g}-yr low-pass filtered)")
        ax.set_ylabel("K")
        ax.legend()
        fig.tight_layout()
        return fig

    def _plot_power_spectrum(
        self,
        prediction_tpi: xr.DataArray,
        target_tpi: xr.DataArray,
        sst_name: str,
    ) -> plt.Figure:
        """Plot sample-averaged power spectra of the unfiltered monthly TPI.

        Both axes are log-scaled. ``sst_name`` is currently unused.
        """
        pred_freq, pred_power = _calculate_sample_average_power_spectrum(prediction_tpi)
        tgt_freq, tgt_power = _calculate_sample_average_power_spectrum(target_tpi)
        fig, ax = plt.subplots(1, 1)
        ax.plot(pred_freq, pred_power, label="predicted ensemble mean")
        ax.plot(tgt_freq, tgt_power, label="target", color="orange")
        ax.set_title("Power Spectrum of IPO TPI (unfiltered)")
        ax.set_xlabel("Frequency [cycles/year]")
        ax.set_ylabel("Power [K**2]")
        ax.set(xscale="log", yscale="log")
        ax.legend()
        fig.tight_layout()
        return fig


@dataclasses.dataclass
class IpoIndexMetricConfig:
    """Configuration for the paired IPO Tripolar Index (TPI) inference metric.

    ``build`` constructs a ``PairedIPOIndexAggregator`` from the context's
    latitude/longitude coordinates.
    """

    # Name returned by ``get_name``.
    name: str = "ipo_index"
    # Not read in this class; presumably consumed by the metric-building
    # machinery (TODO confirm).
    enabled: bool = True
    # Not read in this class; presumably controls how MetricNotSupportedError
    # is treated by the caller (TODO confirm).
    strict: bool = False

    def get_name(self) -> str:
        """Return the configured metric name."""
        return self.name

    def build(self, ctx: MetricBuildContext) -> MetricBuildResult:
        """Build the IPO TPI aggregator for the given inference context.

        Raises:
            MetricNotSupportedError: If the horizontal coordinates are not
                LatLonCoordinates. Also raised if the run
                (``n_timesteps * timestep``) spans no more than
                MIN_YEARS_FOR_FILTERED_TPI years, counting 365-day years.
        """
        if not isinstance(ctx.horizontal_coordinates, LatLonCoordinates):
            raise MetricNotSupportedError(
                "ipo_index metric requires LatLonCoordinates."
            )
        total_duration = ctx.n_timesteps * ctx.timestep
        # Years are approximated as 365 days, hence the "~" in the message.
        min_days = MIN_YEARS_FOR_FILTERED_TPI * 365
        if total_duration <= datetime.timedelta(days=min_days):
            raise MetricNotSupportedError(
                f"ipo_index metric requires > ~{MIN_YEARS_FOR_FILTERED_TPI} years "
                f"of data, got {total_duration.days} days"
            )
        return MetricBuildResult(
            aggregator=PairedIPOIndexAggregator(
                lat=ctx.horizontal_coordinates.lat,
                lon=ctx.horizontal_coordinates.lon,
            )
        )