import abc
import dataclasses
from collections.abc import Mapping, Sequence
from typing import Literal
import torch
import xarray as xr
from fme.ace.aggregator.plotting import plot_paneled_data
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.distributed import Distributed
from fme.core.ensemble import get_crps
from fme.core.gridded_ops import GriddedOperations
from fme.core.typing_ import EnsembleTensorDict, TensorMapping
from ..inference.build_context import MetricBuildContext, MetricNotSupportedError
from ..inference.data import MetricBuildResult
from .build_context import OneStepBuildContext, OneStepMetricBuildResult
def get_gen_shape(gen_data: TensorMapping):
    """
    Return the shape of the first tensor in ``gen_data``.

    Args:
        gen_data: Mapping of variable names to tensors.

    Returns:
        The ``.shape`` of the tensor stored under the first key in iteration
        order, or ``None`` when ``gen_data`` is empty.
    """
    shapes = (gen_data[variable].shape for variable in gen_data)
    return next(shapes, None)
def get_one_step_ensemble_aggregator(
    gridded_operations: GriddedOperations,
    target_time: int = 1,
    log_mean_maps: bool = True,
    metadata: Mapping[str, VariableMetadata] | None = None,
    target: Literal["norm", "denorm"] = "denorm",
    channel_mean_names: Sequence[str] | None = None,
    report_variables: Sequence[str] | None = None,
) -> "SelectStepEnsembleAggregator":
    """
    Build an ensemble aggregator that records a single global time step.

    Args:
        gridded_operations: Gridded operations forwarded to the inner
            ``_EnsembleAggregator``.
        target_time: Global time index that the wrapper selects from each
            batch.
        log_mean_maps: Whether the inner aggregator logs mean maps.
        metadata: Optional per-variable metadata used for captions.
        target: Compute metrics on "norm" or "denorm" data.
        channel_mean_names: Variables to average into ``channel_mean``.
        report_variables: If set, restrict per-variable output entries.

    Returns:
        A ``SelectStepEnsembleAggregator`` wrapping a freshly built
        ``_EnsembleAggregator``.
    """
    inner_aggregator = _EnsembleAggregator(
        gridded_operations=gridded_operations,
        log_mean_maps=log_mean_maps,
        metadata=metadata,
        target=target,
        channel_mean_names=channel_mean_names,
        report_variables=report_variables,
    )
    return SelectStepEnsembleAggregator(
        aggregator=inner_aggregator, i_target_time=target_time
    )
class ReducedMetric(abc.ABC):
    """
    Interface for a metric accumulated over many batches.

    Subclasses receive data one batch at a time through ``record`` (so the
    full dataset never needs to be in memory at once) and produce the final
    reduced value from ``get``.
    """

    @abc.abstractmethod
    def record(self, target: torch.Tensor, gen: torch.Tensor):
        """
        Fold one batch of target and generated data into the metric state.
        """

    @abc.abstractmethod
    def get(self) -> torch.Tensor:
        """
        Return the metric value accumulated so far.
        """
class CRPSMetric(ReducedMetric):
    """
    Batch-averaged CRPS.

    Each ``record`` call evaluates ``get_crps`` on one batch and averages its
    result over the two leading dimensions. ``get`` returns the mean of
    those per-batch values, so each batch carries equal weight regardless of
    its size.
    """

    def __init__(self):
        self._total = None
        self._n_batches = 0

    def record(self, target: torch.Tensor, gen: torch.Tensor):
        """
        Add one batch's CRPS to the running total.

        Args:
            target: Target data for the batch.
            gen: Generated ensemble data for the batch.
        """
        # NOTE(review): alpha=0.95 is passed straight through to get_crps;
        # its meaning is defined there. The two reduced dims are presumably
        # batch and time once get_crps has collapsed the ensemble — confirm.
        batch_crps = get_crps(gen=gen, target=target, alpha=0.95).mean(dim=(0, 1))
        if self._total is not None:
            self._total += batch_crps
        else:
            self._total = batch_crps
        self._n_batches += 1

    def get(self) -> torch.Tensor:
        """
        Return the mean of the per-batch CRPS values.

        Raises:
            ValueError: If ``record`` has never been called.
        """
        running_total = self._total
        if running_total is None:
            raise ValueError("No batches have been recorded.")
        return running_total / self._n_batches
class EnsembleMeanRMSEMetric(ReducedMetric):
    """
    Computes the ensemble mean RMSE.

    For each batch the ensemble members are averaged (dim 1). The squared
    error against the target is averaged over dims 0, 1 and 2, then
    square-rooted. ``get`` returns the mean of these per-batch RMSE values.
    """

    def __init__(self):
        self._total_rmse = None
        self._n_batches = 0

    def record(self, target: torch.Tensor, gen: torch.Tensor):
        """
        Add one batch's ensemble-mean RMSE to the running total.

        Args:
            target: Target data; broadcast against the ensemble mean.
            gen: Generated data with the ensemble along dim 1.
        """
        member_mean = gen.mean(dim=1, keepdim=True)
        squared_error = (member_mean - target) ** 2
        batch_rmse = squared_error.mean(dim=(0, 1, 2)).sqrt()
        if self._total_rmse is not None:
            self._total_rmse += batch_rmse
        else:
            self._total_rmse = batch_rmse
        self._n_batches += 1

    def get(self) -> torch.Tensor:
        """
        Return the mean of the per-batch RMSE values.

        Raises:
            ValueError: If ``record`` has never been called.
        """
        running_total = self._total_rmse
        if running_total is None:
            raise ValueError("No batches have been recorded.")
        return running_total / self._n_batches
class SSRBiasMetric(ReducedMetric):
    """
    Computes the spread-skill ratio bias (equal to (stdev / rmse) - 1).

    Spread is the square root of the accumulated ensemble variance. Skill is
    the square root of the accumulated ensemble-mean MSE, after removing the
    finite-ensemble contribution ``variance / n_ensemble``. The totals are
    not divided by the batch count because the factor cancels in the ratio.
    """

    def __init__(self):
        self._total_unbiased_mse = None
        self._total_variance = None
        self._n_batches = 0

    def record(self, target: torch.Tensor, gen: torch.Tensor):
        """
        Accumulate one batch of squared error and ensemble variance.

        Args:
            target: Target data; broadcast against the ensemble mean.
            gen: Generated data with the ensemble along dim 1.
        """
        num_ensemble = gen.shape[1]
        ensemble_mean = gen.mean(dim=1, keepdim=True)  # ensemble dim kept as size 1
        # Averaged over dims 0, 1, 2 -> only the trailing dims remain.
        mse = ((ensemble_mean - target) ** 2).mean(dim=(0, 1, 2))
        # var over dim 1 drops the ensemble dim; averaging dims 0, 1 then
        # leaves the same trailing dims as ``mse``.
        variance = gen.var(dim=1, unbiased=True).mean(dim=(0, 1))
        self._add_unbiased_mse(mse, variance, num_ensemble)
        self._add_variance(variance)
        self._n_batches += 1

    def _add_unbiased_mse(
        self, mse: torch.Tensor, variance: torch.Tensor, num_ensemble: int
    ):
        """Add ``mse`` minus the finite-ensemble variance term to the total."""
        if self._total_unbiased_mse is None:
            self._total_unbiased_mse = torch.zeros_like(mse)
        # must remove the component of the MSE that is due to the
        # variance of the generated values
        self._total_unbiased_mse += mse - variance / num_ensemble

    def _add_variance(self, variance: torch.Tensor):
        """Add ``variance`` to the running variance total."""
        if self._total_variance is None:
            self._total_variance = torch.zeros_like(variance)
        self._total_variance += variance

    def get(self) -> torch.Tensor:
        """
        Return ``spread / skill - 1``, or -1 wherever skill is zero.

        Raises:
            ValueError: If ``record`` has never been called.
        """
        if self._total_unbiased_mse is None or self._total_variance is None:
            raise ValueError("No batches have been recorded.")
        spread = self._total_variance.sqrt()
        # Clamp to avoid NaN from sqrt of negative values. The unbiased MSE
        # correction (mse - variance/n_ensemble) can overshoot with small
        # ensembles or few batches, producing negative values at some grid
        # cells that do not indicate spread truly exceeding skill.
        skill = torch.clamp(self._total_unbiased_mse, min=0.0).sqrt()
        ssr_bias = spread / skill - 1
        # When skill is zero (clamped or genuinely perfect), SSR is undefined.
        # Use -1 as the convention (equivalent to zero spread). The fallback
        # is built on the result's device/dtype rather than as a CPU scalar
        # tensor, so it never mixes devices in torch.where.
        return torch.where(skill > 0, ssr_bias, torch.full_like(ssr_bias, -1.0))
class _EnsembleAggregator:
    """
    Aggregator for ensemble-based metrics.

    For every variable present in the first recorded batch, it tracks CRPS,
    spread-skill ratio bias and ensemble-mean RMSE. When logs or a dataset
    are requested, it reduces them across ranks with
    ``Distributed.reduce_mean``.
    """

    def __init__(
        self,
        gridded_operations: GriddedOperations,
        log_mean_maps: bool = True,
        metadata: Mapping[str, VariableMetadata] | None = None,
        target: Literal["norm", "denorm"] = "denorm",
        channel_mean_names: Sequence[str] | None = None,
        report_variables: Sequence[str] | None = None,
    ):
        """
        Args:
            gridded_operations: Gridded operations to use; its
                ``area_weighted_mean`` turns each metric map into a scalar.
            log_mean_maps: Whether to log mean maps.
            metadata: Mapping of variable names to their metadata, used to
                generate logged image captions.
            target: Whether to compute metrics on normalized ("norm") or
                denormalized ("denorm") data. Channel-mean metrics are only
                logged when target is "norm", since averaging metrics across
                variables with different physical units is not meaningful.
            channel_mean_names: Names of variables to include in channel-mean
                metrics. If None and target is "norm", the channel mean is
                computed over all variables present in the data. Names that
                are not present in the data raise KeyError. Ignored when
                target is "denorm".
            report_variables: If set, only per-variable entries for these
                variables will appear in logs and datasets. Aggregate entries
                like ``channel_mean`` are always included. All variables are
                still used for channel-mean computation.
        """
        self._gridded_operations = gridded_operations
        self._n_batches = 0
        self._variable_metrics: dict[str, dict[str, ReducedMetric]] | None = None
        self._dist = Distributed.get_instance()
        self._log_mean_maps = log_mean_maps
        self._metadata = metadata
        # Metrics whose mean maps are plotted with diverging=True.
        self._diverging_metrics = {"ssr_bias"}
        self._target = target
        self._channel_mean_names = channel_mean_names
        self._report_variables = (
            frozenset(report_variables) if report_variables is not None else None
        )

    def _get_variable_metrics(self, gen_data: TensorMapping):
        """
        Return the metric state, creating it lazily from ``gen_data``'s keys.

        Only the keys of the first batch are registered. A later batch with
        an unseen key raises KeyError in ``record_batch``.
        """
        if self._variable_metrics is None:
            self._variable_metrics = {
                "crps": {},
                "ssr_bias": {},
                "ensemble_mean_rmse": {},
            }
            for key in gen_data:
                self._variable_metrics["crps"][key] = CRPSMetric()
                self._variable_metrics["ssr_bias"][key] = SSRBiasMetric()
                self._variable_metrics["ensemble_mean_rmse"][key] = (
                    EnsembleMeanRMSEMetric()
                )
        return self._variable_metrics

    @torch.no_grad()
    def record_batch(
        self,
        target_data: EnsembleTensorDict,
        gen_data: EnsembleTensorDict,
        target_data_norm: EnsembleTensorDict | None = None,
        gen_data_norm: EnsembleTensorDict | None = None,
    ):
        """
        Record a batch of data.

        Args:
            target_data: Target data, of shape [batch, ensemble, time, ...].
            gen_data: Generated data, of shape [batch, ensemble, time, ...].
            target_data_norm: Normalized target data. Required when target is
                "norm".
            gen_data_norm: Normalized generated data. Required when target is
                "norm".

        Raises:
            ValueError: If target is "norm" and normalized data is missing.
        """
        if self._target == "norm":
            if target_data_norm is None or gen_data_norm is None:
                raise ValueError(
                    "target_data_norm and gen_data_norm must be provided "
                    "when target is 'norm'."
                )
            target_data = target_data_norm
            gen_data = gen_data_norm
        variable_metrics = self._get_variable_metrics(gen_data)
        for metric in variable_metrics:
            for name in gen_data:
                variable_metrics[metric][name].record(
                    target=target_data[name],
                    gen=gen_data[name],
                )
        self._n_batches += 1

    def _get_caption(self, name: str) -> str:
        """Build a "<long name> (<units>)" caption for variable ``name``."""
        if self._metadata is not None and name in self._metadata:
            caption_name = self._metadata[name].display_long_name(name)
            units = self._metadata[name].display_units()
        else:
            caption_name, units = name, "unknown_units"
        caption = f"{caption_name} ({units})"
        return caption

    def _resolve_channel_mean_names(self, available: list[str]) -> list[str]:
        """
        Return the variables to average into ``channel_mean``.

        Args:
            available: Variable names recorded for one metric, in recording
                order.

        Raises:
            KeyError: If ``channel_mean_names`` lists a variable that is not
                in ``available``.
        """
        if self._channel_mean_names is None:
            return available
        missing = [n for n in self._channel_mean_names if n not in available]
        if missing:
            raise KeyError(
                f"channel_mean_names contains entries not present "
                f"in the recorded data: {missing}. Available: "
                f"{sorted(available)}."
            )
        return list(self._channel_mean_names)

    def _get_data(self):
        """
        Reduce all metrics into a flat ``{"<metric>/...": value}`` mapping.

        Scalar entries are ``"<metric>/<variable>"``. When mean maps are
        logged, ``"<metric>/mean_map/<variable>"`` entries are added, and for
        target "norm" an aggregate ``"<metric>/channel_mean"`` is added.

        Raises:
            ValueError: If no batches have been recorded.
        """
        if self._variable_metrics is None or self._n_batches == 0:
            raise ValueError("No batches have been recorded.")
        data: dict[str, torch.Tensor] = {}
        # Maps each per-variable data key to the variable it describes, so
        # report_variables filtering does not depend on parsing key segments
        # (a variable named like a metric segment, or containing "/", would
        # otherwise be mis-filtered). Aggregate keys are deliberately absent.
        key_variable: dict[str, str] = {}
        for metric in sorted(self._variable_metrics):
            per_variable = self._variable_metrics[metric]
            for name in sorted(per_variable):
                metric_value = self._dist.reduce_mean(per_variable[name].get())
                scalar_key = f"{metric}/{name}"
                data[scalar_key] = (
                    self._gridded_operations.area_weighted_mean(
                        metric_value, name=name
                    )
                    .cpu()
                    .numpy()
                )
                key_variable[scalar_key] = name
                if self._log_mean_maps:
                    map_key = f"{metric}/mean_map/{name}"
                    data[map_key] = plot_paneled_data(
                        [[metric_value.cpu().numpy()]],
                        diverging=metric in self._diverging_metrics,
                        caption=self._get_caption(name),
                    )
                    key_variable[map_key] = name
            if self._target == "norm":
                names = self._resolve_channel_mean_names(list(per_variable.keys()))
                if names:
                    scalars = [data[f"{metric}/{n}"] for n in names]
                    data[f"{metric}/channel_mean"] = sum(scalars) / len(scalars)
        if self._report_variables is not None:
            reported = self._report_variables
            data = {
                k: v
                for k, v in data.items()
                if k not in key_variable or key_variable[k] in reported
            }
        return data

    @torch.no_grad()
    def get_logs(self, label: str):
        """
        Returns logs as can be reported to WandB.

        Args:
            label: Label to prepend to all log keys.
        """
        return {
            f"{label}/{key}": data for key, data in sorted(self._get_data().items())
        }

    @torch.no_grad()
    def get_dataset(self) -> xr.Dataset:
        """
        Return the reduced metrics as an ``xr.Dataset``.

        Variable names are the ``_get_data`` keys with "/" replaced by "-".
        """
        data = self._get_data()
        data = {key.replace("/", "-"): data[key] for key in data}
        data_vars = {}
        for key, value in data.items():
            data_vars[key] = xr.DataArray(value)
        return xr.Dataset(data_vars=data_vars)
class SelectStepEnsembleAggregator:
    """
    Forward a single global time step of each batch to a wrapped aggregator.

    Batches are expected with time along dim 2. Only the slice at the
    configured global index is passed on. Batches that do not contain that
    index are ignored.
    """

    def __init__(
        self,
        aggregator: _EnsembleAggregator,
        i_target_time: int,
    ):
        """
        Args:
            aggregator: Aggregator to wrap.
            i_target_time: Global time index of the target time step.
        """
        self._aggregator = aggregator
        self._i_target_time = i_target_time

    def record_batch(
        self,
        target_data: EnsembleTensorDict,
        gen_data: EnsembleTensorDict,
        target_data_norm: EnsembleTensorDict | None = None,
        gen_data_norm: EnsembleTensorDict | None = None,
        i_time_start: int = 0,
    ):
        """
        Record a specific global timestep of data.

        Call does nothing if the target time is not in this batch.

        Args:
            target_data: Target data, of shape [batch, ensemble, time, ...].
            gen_data: Generated data, of shape [batch, ensemble, time, ...].
            target_data_norm: Normalized target data, same shape as target_data.
            gen_data_norm: Normalized generated data, same shape as gen_data.
            i_time_start: Global time index of the first time step in the batch.
        """
        # The number of time steps is read from the first target tensor.
        batch_length = next(iter(target_data.values())).shape[2]
        local_index = self._i_target_time - i_time_start
        if not 0 <= local_index < batch_length:
            return
        # A length-1 slice keeps the time dimension in place.
        window = slice(local_index, local_index + 1)

        def take_step(tensors: EnsembleTensorDict) -> EnsembleTensorDict:
            sliced = {name: t[:, :, window, ...] for name, t in tensors.items()}
            return EnsembleTensorDict(sliced)

        def take_step_optional(
            tensors: EnsembleTensorDict | None,
        ) -> EnsembleTensorDict | None:
            return None if tensors is None else take_step(tensors)

        target_step = take_step(target_data)
        gen_step = take_step(gen_data)
        target_norm_step = take_step_optional(target_data_norm)
        gen_norm_step = take_step_optional(gen_data_norm)
        self._aggregator.record_batch(
            target_data=target_step,
            gen_data=gen_step,
            target_data_norm=target_norm_step,
            gen_data_norm=gen_norm_step,
        )

    def get_logs(self, label: str):
        """
        Returns logs as can be reported to WandB.

        Args:
            label: Label to prepend to all log keys.
        """
        return self._aggregator.get_logs(label)

    def get_dataset(self) -> xr.Dataset:
        """Return the wrapped aggregator's dataset."""
        return self._aggregator.get_dataset()
@dataclasses.dataclass
class EnsembleMetricConfig:
    """
    Configuration for an ensemble metric (CRPS, SSR bias, ensemble-mean RMSE)
    at a specific forward step.

    Attributes:
        step: Forward step at which to compute the metric.
        name: Name to use for the logged metric. Defaults to
            ``ensemble_step_{step}`` for ``target="denorm"`` and
            ``ensemble_step_{step}_norm`` for ``target="norm"``.
        log_mean_maps: Whether to log per-variable mean maps.
        enabled: Whether the metric is enabled.
        strict: Whether to raise if the metric cannot be built.
        target: Whether to compute metrics on normalized ("norm") or
            denormalized ("denorm") data. ``channel_mean`` is only logged
            when ``target="norm"``, since averaging metrics across variables
            with different physical units is not meaningful.
        channel_mean_names: Names of variables to include in the channel-mean
            metric. If None, falls back to the aggregator-level value passed
            via the build context, and finally to all variables present in
            the data when that is also None. An explicit empty list is kept
            as-is rather than falling back. Names not present in the data
            raise KeyError. Ignored when ``target="denorm"``.
    """

    step: int = 20
    name: str | None = None
    log_mean_maps: bool = False
    enabled: bool = True
    strict: bool = False
    target: Literal["norm", "denorm"] = "denorm"
    channel_mean_names: list[str] | None = None

    def __post_init__(self):
        if self.name is None:
            base = f"ensemble_step_{self.step}"
            self.name = f"{base}_norm" if self.target == "norm" else base

    def get_name(self) -> str:
        """Return the resolved metric name (always set by ``__post_init__``)."""
        return self.name  # type: ignore[return-value]

    def build(self, ctx: MetricBuildContext) -> MetricBuildResult:
        """
        Build the ensemble aggregator for ``self.step``.

        Raises:
            MetricNotSupportedError: If ``step`` exceeds
                ``ctx.n_forward_steps``.
        """
        if self.step > ctx.n_forward_steps:
            raise MetricNotSupportedError(
                f"ensemble step {self.step} exceeds "
                f"n_forward_steps={ctx.n_forward_steps}"
            )
        if self.target == "norm":
            # Only None falls back to the context value; ``or`` would also
            # discard an explicitly configured empty list.
            channel_mean_names = (
                self.channel_mean_names
                if self.channel_mean_names is not None
                else ctx.channel_mean_names
            )
        else:
            channel_mean_names = None
        return MetricBuildResult(
            ensemble=get_one_step_ensemble_aggregator(
                gridded_operations=ctx.ops,
                target_time=self.step,
                log_mean_maps=self.log_mean_maps,
                metadata=ctx.variable_metadata,
                target=self.target,
                channel_mean_names=channel_mean_names,
            )
        )
@dataclasses.dataclass
class OneStepEnsembleMetricConfig:
    """
    Configuration for the one-step ensemble metric (CRPS, SSR bias,
    ensemble-mean RMSE) at the first forward step.

    Attributes:
        name: Name to use for the logged metric. Defaults to ``ensemble``
            for ``target="denorm"`` and ``ensemble_norm`` for
            ``target="norm"``.
        log_mean_maps: Whether to log per-variable mean maps.
        enabled: Whether the metric is enabled.
        strict: Whether to raise if the metric cannot be built.
        target: Whether to compute metrics on normalized ("norm") or
            denormalized ("denorm") data. ``channel_mean`` is only logged
            when ``target="norm"``, since averaging metrics across variables
            with different physical units is not meaningful.
        channel_mean_names: Names of variables to include in the channel-mean
            metric. If None, falls back to the aggregator-level value passed
            via the build context, and finally to all variables present in
            the data when that is also None. An explicit empty list is kept
            as-is rather than falling back. Names not present in the data
            raise KeyError. Ignored when ``target="denorm"``.
    """

    name: str | None = None
    log_mean_maps: bool = True
    enabled: bool = True
    strict: bool = False
    target: Literal["norm", "denorm"] = "denorm"
    channel_mean_names: list[str] | None = None

    def __post_init__(self):
        if self.name is None:
            self.name = "ensemble_norm" if self.target == "norm" else "ensemble"

    def get_name(self) -> str:
        """Return the resolved metric name (always set by ``__post_init__``)."""
        return self.name  # type: ignore[return-value]

    def build(self, ctx: OneStepBuildContext) -> OneStepMetricBuildResult:
        """Build the ensemble aggregator for global time index 1."""
        if self.target == "norm":
            # Only None falls back to the context value; ``or`` would also
            # discard an explicitly configured empty list.
            channel_mean_names = (
                self.channel_mean_names
                if self.channel_mean_names is not None
                else ctx.channel_mean_names
            )
        else:
            channel_mean_names = None
        return OneStepMetricBuildResult(
            ensemble=get_one_step_ensemble_aggregator(
                gridded_operations=ctx.ops,
                log_mean_maps=self.log_mean_maps,
                target_time=1,
                metadata=ctx.variable_metadata,
                target=self.target,
                channel_mean_names=channel_mean_names,
            )
        )