import dataclasses
from collections.abc import Sequence
from typing import Any, Literal
import torch
import xarray as xr
from fme.core.device import get_device
from fme.core.distributed import Distributed
from fme.core.gridded_ops import GriddedOperations
from fme.core.typing_ import TensorMapping
from ..inference.build_context import MetricBuildContext, MetricNotSupportedError
from ..inference.data import InferenceBatchData, MetricBuildResult, SubAggregator
from .build_context import OneStepBuildContext, OneStepMetricBuildResult
from .reduced_metrics import AreaWeightedReducedMetric, ReducedMetric
def get_gen_shape(gen_data: TensorMapping):
    """
    Return the shape of the first tensor in ``gen_data``.

    Only the first entry (in iteration order) is inspected. An empty mapping
    yields ``None``.
    """
    for tensor in gen_data.values():
        # Return on the very first entry; later entries are never visited.
        return tensor.shape
    return None
class MeanAggregator:
    """
    Aggregator for mean-reduced metrics.

    These are metrics such as means which reduce to a single float for each
    batch, and then can be averaged across batches to get a single float for
    the entire dataset. This is important because the aggregator uses the mean
    to combine metrics across batches and processors.
    """

    def __init__(
        self,
        gridded_operations: GriddedOperations,
        target_time: int = 1,
        include_bias: bool = True,
        include_grad_mag_percent_diff: bool = True,
        target: Literal["norm", "denorm"] = "denorm",
        channel_mean_names: Sequence[str] | None = None,
        log_loss: bool = True,
        report_variables: Sequence[str] | None = None,
    ):
        """
        Args:
            gridded_operations: GriddedOperations object for computing metrics.
            target_time: Time index to compute metrics at, where 0 corresponds
                to the first timestep of the initial condition. For example,
                target_time=1 will compute metrics at the first timestep of
                the forward trajectory if there is 1 initial condition step.
            include_bias: Whether to include bias metrics.
            include_grad_mag_percent_diff: Whether to include gradient
                magnitude percent difference metrics.
            target: Whether to compute metrics on normalized ("norm") or
                denormalized ("denorm") data.
            channel_mean_names: Names to include in channel-mean metrics.
                Only forwarded to the RMSE metric. Channel means are emitted
                only when ``target`` is "norm".
            log_loss: Whether to log the mean loss across batches.
            report_variables: If set, only per-variable entries for these
                variables will appear in logs and datasets. Aggregate entries
                like ``loss`` and ``channel_mean`` are always included. All
                variables are still recorded by the underlying metrics.
        """
        device = get_device()
        ops = gridded_operations
        self._gridded_operations = ops
        self._n_batches = 0
        self._loss = torch.tensor(0.0, device=device)
        self._target_time = target_time
        self._target = target
        self._log_loss = log_loss
        self._report_variables = (
            frozenset(report_variables) if report_variables is not None else None
        )
        self._dist = Distributed.get_instance()

        # Insertion order fixes the order in which metrics are recorded and
        # read back, so RMSE always comes first.
        self._variable_metrics: dict[str, ReducedMetric] = {
            "weighted_rmse": AreaWeightedReducedMetric(
                device=device,
                compute_metric=ops.area_weighted_rmse_dict,
                channel_mean_names=channel_mean_names,
            )
        }
        if include_bias:
            self._variable_metrics["weighted_bias"] = AreaWeightedReducedMetric(
                device=device,
                compute_metric=ops.area_weighted_mean_bias_dict,
            )
        if include_grad_mag_percent_diff:
            self._variable_metrics["weighted_grad_mag_percent_diff"] = (
                AreaWeightedReducedMetric(
                    device=device,
                    compute_metric=(
                        ops.area_weighted_gradient_magnitude_percent_diff_dict
                    ),
                )
            )

    @torch.no_grad()
    def record_batch(
        self,
        target_data: TensorMapping,
        gen_data: TensorMapping,
        target_data_norm: TensorMapping | None = None,
        gen_data_norm: TensorMapping | None = None,
        loss: float = float("nan"),
        i_time_start: int = 0,
    ):
        """
        Record one batch, if its time window contains the target time.

        Args:
            target_data: Denormalized target tensors, indexed by variable name.
                Dimension 1 is treated as the time dimension.
            gen_data: Denormalized generated tensors with the same layout.
            target_data_norm: Normalized targets; required if target is "norm".
            gen_data_norm: Normalized generated data; required if target is
                "norm".
            loss: Loss value for this batch, added to the running total.
            i_time_start: Global time index of the first timestep in this
                batch; subtracted from ``target_time`` to get the local index.

        Raises:
            ValueError: If target is "norm" and either normalized mapping is
                missing. No state is modified in that case.
        """
        use_norm = self._target == "norm"
        # Validate before touching any state so a rejected call has no effect.
        if use_norm and (target_data_norm is None or gen_data_norm is None):
            raise ValueError(
                "target_data_norm and gen_data_norm must be provided "
                "if target is 'norm'."
            )

        # NOTE(review): loss is accumulated on every call, but _n_batches only
        # counts calls whose window contains the target time, so the reported
        # mean loss is skewed if any call is skipped - confirm this is
        # intended.
        self._loss += loss

        time_dim = 1
        # The window length is always read from the denormalized gen_data.
        time_len = gen_data[list(gen_data)[0]].shape[time_dim]
        target_time = self._target_time - i_time_start
        if not 0 <= target_time < time_len:
            return

        if use_norm:
            target_data, gen_data = target_data_norm, gen_data_norm

        # Variable names are taken from the generated data; the target mapping
        # must contain at least those names.
        target_snapshot = {}
        gen_snapshot = {}
        for name in gen_data:
            target_snapshot[name] = target_data[name].select(
                dim=time_dim, index=target_time
            )
            gen_snapshot[name] = gen_data[name].select(
                dim=time_dim, index=target_time
            )
        for metric in self._variable_metrics.values():
            metric.record(target=target_snapshot, gen=gen_snapshot)
        # only increment n_batches if we actually recorded a batch
        self._n_batches += 1

    def _get_data(self) -> dict[str, float]:
        """
        Return batch-averaged metrics, reduced across processes, as floats.

        Keys are "loss" (if enabled), "<metric>/<variable>" for each reported
        variable, and "<metric>/channel_mean" when target is "norm".

        Raises:
            ValueError: If no batch has been recorded.
        """
        if self._n_batches == 0:
            raise ValueError("No batches have been recorded.")

        report = self._report_variables
        averaged: dict[str, torch.Tensor] = {}
        if self._log_loss:
            averaged["loss"] = self._loss / self._n_batches
        for name, metric in self._variable_metrics.items():
            for variable, value in metric.get().items():
                # Filter on the variable name itself, not on fragments of the
                # final key, so aggregate entries can never be dropped and
                # variable names containing "/" are still filtered.
                if report is not None and variable not in report:
                    continue
                averaged[f"{name}/{variable}"] = value / self._n_batches
            if self._target == "norm":
                averaged[f"{name}/channel_mean"] = (
                    metric.get_channel_mean() / self._n_batches
                )

        # Reductions are issued in sorted key order, presumably so that every
        # process issues them in the same sequence.
        reduced = {
            key: float(self._dist.reduce_mean(averaged[key].detach()).cpu().numpy())
            for key in sorted(averaged)
        }
        # Hand the results back in insertion order.
        return {key: reduced[key] for key in averaged}

    @torch.no_grad()
    def get_logs(self, label: str):
        """
        Returns logs as can be reported to WandB.

        Args:
            label: Label to prepend to all log keys.
        """
        return {
            f"{label}/{key}": value for key, value in sorted(self._get_data().items())
        }

    @torch.no_grad()
    def get_dataset(self) -> xr.Dataset:
        """
        Return the aggregated metrics as a dataset of scalar variables.

        Variable names are the log keys with "/" replaced by "-".
        """
        data_vars = {
            key.replace("/", "-"): xr.DataArray(value)
            for key, value in self._get_data().items()
        }
        return xr.Dataset(data_vars=data_vars)
class OneStepMeanAdapter:
    """Wraps a MeanAggregator so it can be fed InferenceBatchData batches."""

    def __init__(self, inner: MeanAggregator):
        self._inner = inner

    def record_batch(self, data: InferenceBatchData) -> None:
        """Unpack ``data`` and forward its fields to the wrapped aggregator."""
        # No loss is forwarded, so the wrapped aggregator falls back to its
        # default loss value.
        forwarded = {
            "target_data": data.target,
            "gen_data": data.prediction,
            "target_data_norm": data.target_norm,
            "gen_data_norm": data.prediction_norm,
            "i_time_start": data.i_time_start,
        }
        self._inner.record_batch(**forwarded)

    def get_logs(self, label: str) -> dict[str, Any]:
        """Return the wrapped aggregator's logs, prefixed with ``label``."""
        logs = self._inner.get_logs(label)
        return logs

    def get_dataset(self) -> xr.Dataset:
        """Return the wrapped aggregator's dataset unchanged."""
        dataset = self._inner.get_dataset()
        return dataset
@dataclasses.dataclass
class StepMeanMetricConfig:
    """
    Configuration for mean metrics evaluated at a single forward step.

    Parameters:
        step: Forward step to evaluate, where 1 is the first step after the
            initial condition. ``build`` rejects values larger than the
            context's ``n_forward_steps``.
        variables: If set, only these variables get per-variable entries in
            the aggregator output.
        name: Name of the metric. Defaults to ``mean_step_<step>``, with a
            ``_norm`` suffix when ``target`` is "norm".
        target: Whether metrics are computed on denormalized ("denorm") or
            normalized ("norm") data.
        channel_mean_names: Names used for channel-mean metrics. Only used
            when ``target`` is "norm"; if unset or empty, the build context's
            ``channel_mean_names`` is used instead.
        enabled: Not read by this class; presumably consumed by whatever
            collects metric configs (verify against caller).
        strict: Not read by this class; presumably consumed by whatever
            collects metric configs (verify against caller).
    """

    step: int
    variables: list[str] | None = None
    name: str | None = None
    target: Literal["denorm", "norm"] = "denorm"
    channel_mean_names: list[str] | None = None
    enabled: bool = True
    strict: bool = False

    def __post_init__(self):
        # Derive a default name that is distinct per step and per target.
        if self.name is None:
            base = f"mean_step_{self.step}"
            self.name = f"{base}_norm" if self.target == "norm" else base

    def get_name(self) -> str:
        """Return the metric name; never None once ``__post_init__`` has run."""
        return self.name  # type: ignore[return-value]

    def build(self, ctx: MetricBuildContext) -> MetricBuildResult:
        """
        Build the aggregator for this metric.

        Args:
            ctx: Build context supplying ``ops``, ``n_forward_steps``,
                ``n_ic_steps`` and ``channel_mean_names``.

        Returns:
            A MetricBuildResult wrapping an adapted MeanAggregator.

        Raises:
            MetricNotSupportedError: If ``step`` exceeds
                ``ctx.n_forward_steps``.
        """
        if self.step > ctx.n_forward_steps:
            raise MetricNotSupportedError(
                f"step_mean step {self.step} exceeds "
                f"n_forward_steps={ctx.n_forward_steps}"
            )
        # Convert the 1-based forward step into a time index that counts the
        # initial-condition steps from 0.
        target_time = self.step + ctx.n_ic_steps - 1
        is_norm = self.target == "norm"
        # Channel means are only requested for normalized data.
        channel_mean_names = (
            (self.channel_mean_names or ctx.channel_mean_names) if is_norm else None
        )
        agg: SubAggregator = OneStepMeanAdapter(
            MeanAggregator(
                ctx.ops,
                target_time=target_time,
                target=self.target,
                log_loss=False,
                # Bias and gradient-magnitude metrics are disabled for
                # normalized data.
                include_bias=not is_norm,
                include_grad_mag_percent_diff=not is_norm,
                channel_mean_names=channel_mean_names,
                report_variables=self.variables,
            )
        )
        return MetricBuildResult(aggregator=agg)
@dataclasses.dataclass
class OneStepMeanMetricConfig:
    """
    Configuration for one-step mean metrics.

    Parameters:
        name: Name of the metric. Defaults to "mean" for denormalized data and
            "mean_norm" for normalized data.
        target: Whether metrics are computed on denormalized ("denorm") or
            normalized ("norm") data.
        include_bias: Whether to include bias metrics. Must be False when
            ``target`` is "norm".
        include_grad_mag_percent_diff: Whether to include gradient magnitude
            percent difference metrics. Must be False when ``target`` is
            "norm".
        channel_mean_names: Names used for channel-mean metrics. Only used
            when ``target`` is "norm"; if unset or empty, the build context's
            ``channel_mean_names`` is used instead.
        enabled: Not read by this class; presumably consumed by whatever
            collects metric configs (verify against caller).
        strict: Not read by this class; presumably consumed by whatever
            collects metric configs (verify against caller).
    """

    name: str | None = None
    target: Literal["denorm", "norm"] = "denorm"
    include_bias: bool = True
    include_grad_mag_percent_diff: bool = True
    channel_mean_names: list[str] | None = None
    enabled: bool = True
    strict: bool = False

    def __post_init__(self):
        if self.name is None:
            self.name = "mean" if self.target == "denorm" else "mean_norm"
        # Both optional metrics default to True, so a "norm" config has to
        # switch them off explicitly.
        if self.target == "norm":
            if self.include_bias:
                raise ValueError("include_bias is not supported when target='norm'.")
            if self.include_grad_mag_percent_diff:
                raise ValueError(
                    "include_grad_mag_percent_diff is not supported "
                    "when target='norm'."
                )

    def get_name(self) -> str:
        """Return the metric name; never None once ``__post_init__`` has run."""
        return self.name  # type: ignore[return-value]

    def build(self, ctx: OneStepBuildContext) -> OneStepMetricBuildResult:
        """
        Build the aggregator for this metric.

        Args:
            ctx: Build context supplying ``ops`` and ``channel_mean_names``.

        Returns:
            A OneStepMetricBuildResult whose ``deterministic`` aggregator is a
            MeanAggregator using its default target time and loss logging.
        """
        # Channel means are only requested for normalized data.
        channel_mean_names = (
            (self.channel_mean_names or ctx.channel_mean_names)
            if self.target == "norm"
            else None
        )
        agg = MeanAggregator(
            ctx.ops,
            target=self.target,
            include_bias=self.include_bias,
            include_grad_mag_percent_diff=self.include_grad_mag_percent_diff,
            channel_mean_names=channel_mean_names,
        )
        return OneStepMetricBuildResult(deterministic=agg)