Source code for fme.ace.aggregator.inference.enso.dynamic_index

import dataclasses
import datetime
import logging
from collections.abc import Callable
from typing import Any

import cftime
import numpy as np
import torch
import xarray as xr
from matplotlib import pyplot as plt

from fme.core.coordinates import LatLonCoordinates
from fme.core.distributed import Distributed
from fme.core.gridded_ops import LatLonOperations
from fme.core.typing_ import TensorDict

from ...plotting import plot_mean_and_samples
from ..build_context import MetricBuildContext, MetricNotSupportedError
from ..data import InferenceBatchData, MetricBuildResult
from ..utils import (
    SAMPLE_DIM,
    TIME_DIM,
    LatLonRegion,
    _calculate_sample_average_power_spectrum,
    _compute_sample_mean_std,
    anomalies_from_monthly_climo,
    compute_psd_band_power,
    convert_cftime_to_datetime_coord,
    running_monthly_mean,
)

SEA_SURFACE_TEMPERATURE_NAMES = ["sst", "surface_temperature", "TS"]


class RegionalIndexAggregator:
    """Aggregator for computing a regional index, in this case a monthly- and area-
    weighted average of a variable over a region.

    Args:
        regional_weights: A tensor of weights for the region.
        regional_mean: A function that computes the regional mean of a variable,
            given weights.
        running_mean_n_months: The number of months to use for the running mean.
    """

    _sample_dim = SAMPLE_DIM
    _time_dim = TIME_DIM
    sea_surface_temperature_names = SEA_SURFACE_TEMPERATURE_NAMES

    def __init__(
        self,
        regional_weights: torch.Tensor,
        regional_mean: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        running_mean_n_months: int = 5,
    ):
        self._regional_weights = regional_weights
        self._regional_mean = regional_mean
        self._running_mean_n_months = running_mean_n_months
        self._raw_indices: TensorDict = {}
        self._raw_index_times: xr.DataArray | None = None
        self._calendar: str | None = None
        self._already_logged: list[str] = []

    def record_batch(self, data: InferenceBatchData) -> None:
        time = data.time
        prediction = data.prediction
        for sst_name in self.sea_surface_temperature_names:
            if sst_name not in prediction:
                if sst_name not in self._already_logged:
                    logging.info(
                        f"Variable {sst_name} not found in data. "
                        "Skipping Nino3.4 index computation for this variable."
                    )
                    self._already_logged.append(sst_name)
                continue
            regional_average = self._regional_mean(
                prediction[sst_name], self._regional_weights
            )
            if sst_name not in self._raw_indices:
                self._raw_indices[sst_name] = regional_average
            else:
                self._raw_indices[sst_name] = torch.cat(
                    [self._raw_indices[sst_name], regional_average],
                    dim=self._time_dim,
                )
        if self._raw_index_times is None:
            self._raw_index_times = time
        else:
            self._raw_index_times = xr.concat([self._raw_index_times, time], dim="time")
        if self._calendar is None and time.dt.calendar is not None:
            self._calendar = time.dt.calendar

    def get_indices(self) -> xr.Dataset:
        indices = {}
        for sst_name in self.sea_surface_temperature_names:
            if sst_name not in self._raw_indices:
                continue
            anomaly_index_values = anomalies_from_monthly_climo(
                self._raw_indices[sst_name], self._raw_index_times
            )
            time_averaged_index, unique_months = running_monthly_mean(
                anomaly_index_values,
                self._raw_index_times,
                n_months=self._running_mean_n_months,
            )
            gathered_index = self._get_gathered_index(
                time_averaged_index, unique_months.years, unique_months.months
            )
            if gathered_index is None:
                continue
            indices[sst_name] = gathered_index
        return xr.Dataset(indices)

    def _get_gathered_index(
        self, index: torch.Tensor, years: torch.Tensor, months: torch.Tensor
    ) -> xr.DataArray | None:
        dist = Distributed.get_instance()
        if dist.world_size > 1:
            gathered_index, gathered_years, gathered_months = (
                dist.gather_irregular(index),
                dist.gather_irregular(years),
                dist.gather_irregular(months),
            )
            if (
                gathered_index is None
                or gathered_years is None
                or gathered_months is None
            ):
                return None
        else:
            gathered_index, gathered_years, gathered_months = [index], [years], [months]
        return self._to_index_data_array(
            gathered_index, gathered_years, gathered_months
        )

    def _to_index_data_array(
        self,
        indices: list[torch.Tensor],
        years: list[torch.Tensor],
        months: list[torch.Tensor],
    ) -> xr.DataArray:
        index_data_arrays = []
        for index, year, month in zip(indices, years, months):
            time_coord = [  # approximate the middle of the month
                cftime.datetime(
                    single_year.item(),
                    single_month.item(),
                    15,
                    calendar=(
                        self._calendar if self._calendar is not None else "standard"
                    ),
                )
                for single_year, single_month in zip(year.cpu(), month.cpu())
            ]
            index_data_arrays.append(
                xr.DataArray(
                    data=index.cpu().numpy(),
                    dims=["sample", "time"],
                    coords={"time": time_coord},
                )
            )
        return xr.concat(index_data_arrays, dim="sample")

    def get_logs(self, label: str) -> dict[str, Any]:
        indices = self.get_indices()
        logs = {}
        for sst_name in self.sea_surface_temperature_names:
            if sst_name in indices and indices[sst_name].sizes["time"] > 1:
                fig, ax = plt.subplots(1, 1)
                index_plottable = indices[sst_name].assign_coords(
                    {"time": convert_cftime_to_datetime_coord(indices[sst_name].time)}
                )
                plot_mean_and_samples(
                    ax, index_plottable, "ensemble_mean", time_series_dim="time"
                )
                ax.set_title("Nino3.4 Index")
                ax.set_ylabel("K")
                ax.legend()
                fig.tight_layout()
                logs[f"{sst_name}_nino34_index"] = fig
                logs[f"{sst_name}_nino34_index_std"] = _compute_sample_mean_std(
                    indices[sst_name]
                )
        for sst_name in self.sea_surface_temperature_names:
            if (
                sst_name in indices
                and indices[sst_name].dropna("time").sizes["time"] > 1
            ):
                freq, power_spectrum = _calculate_sample_average_power_spectrum(
                    indices[sst_name]
                )
                fig, ax = plt.subplots(1, 1)
                ax.plot(freq, power_spectrum, label="predicted ensemble mean")
                ax.set_title("Power Spectrum of Nino3.4 Index")
                ax.set_xlabel("Frequency [cycles/year]")
                ax.set_ylabel("Power [K**2]")
                ax.set(yscale="log")
                ax.legend()
                fig.tight_layout()
                logs[f"{sst_name}_nino34_index_power_spectrum"] = fig
                logs[f"{sst_name}_nino34_index_power_2_5yr"] = compute_psd_band_power(
                    freq, power_spectrum
                )
                logs[f"{sst_name}_nino34_index_power_1_16yr"] = compute_psd_band_power(
                    freq, power_spectrum, period_bounds=(1.0, 16.0)
                )

        if len(label) > 0:
            label = label + "/"
        return {f"{label}{k}": v for k, v in logs.items()}

    def get_dataset(self) -> xr.Dataset:
        indices = self.get_indices()
        if indices is None:
            return xr.Dataset()
        else:
            return indices


class PairedRegionalIndexAggregator:
    """A paired (target, prediction) RegionalIndexAggregator."""

    def __init__(
        self,
        target_aggregator: RegionalIndexAggregator,
        prediction_aggregator: RegionalIndexAggregator,
    ):
        self._prediction_aggregator = prediction_aggregator
        self._target_aggregator = target_aggregator

    def record_batch(
        self,
        data: InferenceBatchData,
    ) -> None:
        target_data = data.replace(prediction=data.target)
        prediction_data = data.replace(prediction=data.prediction)
        self._target_aggregator.record_batch(target_data)
        self._prediction_aggregator.record_batch(prediction_data)

    def get_logs(self, label: str) -> dict[str, Any]:
        target_indices = self._target_aggregator.get_indices()
        prediction_indices = self._prediction_aggregator.get_indices()
        logs = {}
        for sst_name in self._prediction_aggregator.sea_surface_temperature_names:
            if (
                sst_name in prediction_indices
                and prediction_indices[sst_name].sizes["time"] > 1
            ):
                target_indices_plottable = target_indices.assign_coords(
                    {"time": convert_cftime_to_datetime_coord(target_indices.time)}
                )
                prediction_indices_plottable = prediction_indices.assign_coords(
                    {"time": convert_cftime_to_datetime_coord(prediction_indices.time)}
                )

                fig, ax = plt.subplots(1, 1)
                plot_mean_and_samples(
                    ax,
                    prediction_indices_plottable[sst_name],
                    "predicted ensemble mean",
                    time_series_dim="time",
                )
                plot_mean_and_samples(
                    ax,
                    target_indices_plottable[sst_name],
                    "target",
                    time_series_dim="time",
                    color="orange",
                    plot_samples=False,
                )
                ax.set_title("Nino3.4 Index")
                ax.set_ylabel("K")
                ax.legend()
                fig.tight_layout()
                logs[f"{sst_name}_nino34_index"] = fig
                logs[f"{sst_name}_nino34_index_std"] = _compute_sample_mean_std(
                    prediction_indices[sst_name]
                )
                logs[f"{sst_name}_nino34_index_std_norm"] = _compute_sample_mean_std(
                    prediction_indices[sst_name],
                    target_indices[sst_name],
                )
        for sst_name in self._prediction_aggregator.sea_surface_temperature_names:
            if (
                sst_name in prediction_indices
                and prediction_indices[sst_name].notnull().any().item()
            ):
                pred_freq, prediction_power_spectrum = (
                    _calculate_sample_average_power_spectrum(
                        prediction_indices[sst_name]
                    )
                )
                target_freq, target_power_spectrum = (
                    _calculate_sample_average_power_spectrum(target_indices[sst_name])
                )
                fig, ax = plt.subplots(1, 1)
                ax.plot(
                    pred_freq,
                    prediction_power_spectrum,
                    label="predicted ensemble mean",
                )
                ax.plot(
                    target_freq, target_power_spectrum, label="target", color="orange"
                )
                ax.set_title("Power Spectrum of Nino3.4 Index")
                ax.set_xlabel("Frequency [cycles/year]")
                ax.set(yscale="log")
                ax.legend()
                fig.tight_layout()
                logs[f"{sst_name}_nino34_index_power_spectrum"] = fig
                pred_power_2_5 = compute_psd_band_power(
                    pred_freq, prediction_power_spectrum
                )
                target_power_2_5 = compute_psd_band_power(
                    target_freq, target_power_spectrum
                )
                logs[f"{sst_name}_nino34_index_power_2_5yr"] = pred_power_2_5
                if target_power_2_5 != 0 and not np.isnan(target_power_2_5):
                    logs[f"{sst_name}_nino34_index_power_2_5yr_norm"] = (
                        pred_power_2_5 / target_power_2_5
                    )
                pred_power_1_16 = compute_psd_band_power(
                    pred_freq, prediction_power_spectrum, period_bounds=(1.0, 16.0)
                )
                target_power_1_16 = compute_psd_band_power(
                    target_freq, target_power_spectrum, period_bounds=(1.0, 16.0)
                )
                logs[f"{sst_name}_nino34_index_power_1_16yr"] = pred_power_1_16
                if target_power_1_16 != 0 and not np.isnan(target_power_1_16):
                    logs[f"{sst_name}_nino34_index_power_1_16yr_norm"] = (
                        pred_power_1_16 / target_power_1_16
                    )
        if len(label) > 0:
            label = label + "/"

        return {f"{label}{k}": v for k, v in logs.items()}

    def get_dataset(self) -> xr.Dataset:
        prediction = self._prediction_aggregator.get_dataset()
        target = self._target_aggregator.get_dataset()
        if len(prediction) == 0 or len(target) == 0:
            return xr.Dataset()
        else:
            return xr.concat(
                [
                    target.expand_dims({"source": ["target"]}),
                    prediction.expand_dims({"source": ["prediction"]}),
                ],
                dim="source",
            )


NINO34_LAT = (-5, 5)
NINO34_LON = (190, 240)


[docs]@dataclasses.dataclass class EnsoIndexMetricConfig: name: str = "enso_index" enabled: bool = True strict: bool = False def get_name(self) -> str: return self.name def build(self, ctx: MetricBuildContext) -> MetricBuildResult: if not isinstance(ctx.horizontal_coordinates, LatLonCoordinates): raise MetricNotSupportedError( "enso_index metric requires LatLonCoordinates." ) if not isinstance(ctx.ops, LatLonOperations): raise MetricNotSupportedError( "enso_index metric requires LatLonOperations." ) total_duration = ctx.n_timesteps * ctx.timestep if total_duration <= datetime.timedelta(days=730): raise MetricNotSupportedError( f"enso_index metric requires > ~2 years of data, " f"got {total_duration.days} days" ) nino34_region = LatLonRegion( lat_bounds=NINO34_LAT, lon_bounds=NINO34_LON, lat=ctx.horizontal_coordinates.lat, lon=ctx.horizontal_coordinates.lon, ) return MetricBuildResult( aggregator=PairedRegionalIndexAggregator( target_aggregator=RegionalIndexAggregator( regional_weights=nino34_region.regional_weights, regional_mean=ctx.ops.regional_area_weighted_mean, ), prediction_aggregator=RegionalIndexAggregator( regional_weights=nino34_region.regional_weights, regional_mean=ctx.ops.regional_area_weighted_mean, ), ) )