import dataclasses
from collections import defaultdict
from collections.abc import Mapping
from typing import Literal, Protocol
import numpy as np
import torch
import xarray as xr
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.device import get_device
from fme.core.distributed import Distributed
from fme.core.gridded_ops import GriddedOperations
from fme.core.typing_ import TensorDict, TensorMapping
from fme.core.wandb import Table, WandB
from .build_context import MetricBuildContext, maybe_filter
from .data import InferenceBatchData, MetricBuildResult
@dataclasses.dataclass
class _SeriesData:
metric_name: str
var_name: str
data: np.ndarray
def get_wandb_key(self) -> str:
return f"{self.metric_name}/{self.var_name}"
def get_xarray_key(self) -> str:
return f"{self.metric_name}-{self.var_name}"
def get_gen_shape(gen_data: TensorMapping):
    """Return the shape of the first entry in ``gen_data``.

    Args:
        gen_data: Mapping from variable name to tensor.

    Returns:
        The ``.shape`` of the value stored under the first key in iteration
        order, or None when the mapping is empty.
    """
    names = iter(gen_data)
    try:
        first_name = next(names)
    except StopIteration:
        # Empty mapping: there is no tensor to take a shape from.
        return None
    return gen_data[first_name].shape
class MeanMetric(Protocol):
    """Structural interface for metrics fed with paired target/generated batches."""

    def record(
        self, target: TensorMapping, gen: TensorMapping, i_time_start: int
    ) -> None:
        """
        Update the metric with one batch of data.

        Args:
            target: Target data, keyed by variable name.
            gen: Generated data, keyed by variable name.
            i_time_start: Index of the first timestep covered by the batch.
        """
        ...

    def get(self) -> TensorDict:
        """
        Get the accumulated metric value per variable.

        NOTE(review): this was documented as a total that is not divided by
        the number of recorded batches, but AreaWeightedReducedMetric.get in
        this file returns the total divided by its per-timestep batch count —
        confirm which contract callers rely on.
        """
        ...
class SingleTargetMeanMetric(Protocol):
    """Structural interface for metrics fed with a single data source per batch."""

    def record(self, tensors: TensorMapping, i_time_start: int) -> None:
        """
        Update the metric with one batch of data.

        Args:
            tensors: Data for the batch, keyed by variable name.
            i_time_start: Index of the first timestep covered by the batch.
        """
        ...

    def get(self) -> TensorDict:
        """
        Get the accumulated metric value per variable.

        NOTE(review): this was documented as a total that is not divided by
        the number of recorded batches, but
        AreaWeightedSingleTargetReducedMetric.get in this file returns the
        total divided by its per-timestep batch count — confirm which contract
        callers rely on.
        """
        ...
class AreaWeightedFunction(Protocol):
    """
    Callable interface for an area-weighted metric that compares true and
    predicted values.

    Implementations receive both mappings (variable name -> tensor) and
    return one tensor per variable.
    """

    def __call__(
        self,
        truth: TensorMapping,
        predicted: TensorMapping,
    ) -> TensorDict:
        ...
class AreaWeightedSingleTargetFunction(Protocol):
    """
    Callable interface for an area-weighted metric computed on a single set
    of values.

    Implementations receive one mapping (variable name -> tensor) and return
    one tensor per variable.
    """

    def __call__(
        self,
        tensors: TensorMapping,
    ) -> TensorDict:
        ...
def compute_metric_on(
    source: Literal["gen", "target"], metric: AreaWeightedSingleTargetFunction
) -> AreaWeightedFunction:
    """Turns a single-target metric function
    (computed on only the generated or target data) into a function that takes in
    both the generated and target data as arguments, as required for the APIs
    which call generic metric functions.

    Args:
        source: Which of the two inputs the wrapped metric is evaluated on:
            "gen" selects the ``predicted`` argument, "target" selects the
            ``truth`` argument.
        metric: The single-input metric function to wrap.

    Returns:
        A function with the ``(truth, predicted)`` signature that forwards
        only the selected input to ``metric``.

    Raises:
        ValueError: If ``source`` is neither "gen" nor "target".
    """
    # Validate once, here, so that an unrecognised source fails loudly instead
    # of producing a wrapper that silently returns None on every call.
    if source not in ("gen", "target"):
        raise ValueError(f"source must be 'gen' or 'target', got {source!r}")
    use_generated = source == "gen"

    def metric_wrapper(
        truth: TensorMapping,
        predicted: TensorMapping,
    ) -> TensorDict:
        return metric(predicted if use_generated else truth)

    return metric_wrapper
class AreaWeightedReducedMetric:
    """
    A wrapper around an area-weighted metric function.

    For every recorded batch the metric function is evaluated, its result is
    averaged over the leading (batch) dimension and added to a running
    per-variable series of length ``n_timesteps``. A per-timestep counter of
    contributing batches is kept so that ``get`` can return the mean.
    """

    def __init__(
        self,
        device: torch.device,
        compute_metric: AreaWeightedFunction,
        n_timesteps: int,
    ):
        """
        Args:
            device: Device on which the running totals and the batch counter
                are allocated.
            compute_metric: Metric function, called as
                ``compute_metric(truth=target, predicted=gen)``.
            n_timesteps: Length of the accumulated series.
        """
        self._compute_metric = compute_metric
        self._total: TensorDict = {}
        # Allocate the counter on the same device as the totals: get() divides
        # one by the other, which fails if they live on different devices.
        self._n_batches = torch.zeros(n_timesteps, dtype=torch.int32, device=device)
        self._device = device
        self._n_timesteps = n_timesteps

    def record(
        self, target: TensorMapping, gen: TensorMapping, i_time_start: int
    ) -> None:
        """Add a batch of data to the metric.

        Args:
            target: Target data. Dictionary mapping variable names to tensors of shape
                [batch, time, height, width].
            gen: Generated data. Dictionary mapping variable names to tensors of shape
                [batch, time, height, width].
            i_time_start: The index of the first timestep in the batch.

        Raises:
            RuntimeError: If a variable has different shapes in ``target``
                and ``gen``.
        """
        time_dim = 1
        for name in target:
            if target[name].shape != gen[name].shape:
                raise RuntimeError(
                    "Tensors in target and gen must have the same shape, but got "
                    f"{target[name].shape} and {gen[name].shape} "
                    f"for the tensor '{name}'."
                )
        # The window of the series this batch contributes to.
        time_dim_len = next(iter(gen.values())).shape[time_dim]
        time_slice = slice(i_time_start, i_time_start + time_dim_len)

        # Update totals for each variable
        new_values = self._compute_metric(truth=target, predicted=gen)
        for name, tensor in new_values.items():
            if name not in self._total:
                self._total[name] = torch.zeros(
                    [self._n_timesteps], dtype=tensor.dtype, device=self._device
                )
            # Collapse the leading dimension (presumably batch, leaving one
            # value per timestep — TODO confirm against the metric functions).
            self._total[name][time_slice] += tensor.mean(dim=0)

        # One more batch has contributed to each timestep in the window.
        self._n_batches[time_slice] += 1

    def get(self) -> TensorDict:
        """Returns the mean metric across recorded batches for each variable.

        Before anything has been recorded, a defaultdict yielding a scalar
        NaN tensor for any key is returned instead.
        """
        if not self._total:
            # no batches recorded yet
            return defaultdict(lambda: torch.tensor(torch.nan))
        return {name: tensor / self._n_batches for name, tensor in self._total.items()}
class MeanAggregator:
    """
    Aggregates per-timestep, area-weighted metrics comparing generated data
    against target data.

    One AreaWeightedReducedMetric is kept per metric name; every recorded
    batch is forwarded to all of them. Results can be exported as a WandB
    table (``get_logs``) or an xarray dataset (``get_dataset``).
    """

    def __init__(
        self,
        gridded_operations: GriddedOperations,
        target: Literal["norm", "denorm"],
        n_timesteps: int,
        variable_metadata: Mapping[str, VariableMetadata] | None = None,
    ):
        """
        Args:
            gridded_operations: Provides the area-weighted metric functions
                the individual metrics delegate to.
            target: "norm" records the normalized fields of each batch;
                otherwise the denormalized fields are recorded. The
                gradient-magnitude metric is only registered for "denorm".
            n_timesteps: Length of each accumulated series.
            variable_metadata: Optional per-variable metadata used for the
                attributes of the variables returned by ``get_dataset``.
        """
        self._gridded_operations = gridded_operations
        self._target = target
        self._n_timesteps = n_timesteps
        self._dist = Distributed.get_instance()
        if variable_metadata is None:
            self._variable_metadata: Mapping[str, VariableMetadata] = {}
        else:
            self._variable_metadata = variable_metadata

        # Store one metric object per metric type (e.g., rmse, bias).
        # Insertion order determines the order of the exported series.
        self._variable_metrics: dict[str, MeanMetric] = {}
        device = get_device()
        self._variable_metrics["weighted_rmse"] = AreaWeightedReducedMetric(
            device=device,
            compute_metric=self._gridded_operations.area_weighted_rmse_dict,
            n_timesteps=self._n_timesteps,
        )
        if self._target == "denorm":
            self._variable_metrics["weighted_grad_mag_percent_diff"] = (
                AreaWeightedReducedMetric(
                    device=device,
                    compute_metric=self._gridded_operations.area_weighted_gradient_magnitude_percent_diff_dict,  # noqa: E501
                    n_timesteps=self._n_timesteps,
                )
            )
        self._variable_metrics["weighted_mean_gen"] = AreaWeightedReducedMetric(
            device=device,
            compute_metric=compute_metric_on(
                source="gen",
                metric=(
                    lambda tensors: self._gridded_operations.area_weighted_mean_dict(
                        tensors
                    )
                ),
            ),
            n_timesteps=self._n_timesteps,
        )
        self._variable_metrics["weighted_mean_target"] = AreaWeightedReducedMetric(
            device=device,
            compute_metric=compute_metric_on(
                source="target",
                metric=(
                    lambda tensors: self._gridded_operations.area_weighted_mean_dict(
                        tensors
                    )
                ),
            ),
            n_timesteps=self._n_timesteps,
        )
        self._variable_metrics["weighted_bias"] = AreaWeightedReducedMetric(
            device=device,
            compute_metric=self._gridded_operations.area_weighted_mean_bias_dict,
            n_timesteps=self._n_timesteps,
        )
        self._variable_metrics["weighted_std_gen"] = AreaWeightedReducedMetric(
            device=device,
            compute_metric=compute_metric_on(
                source="gen",
                metric=(
                    lambda tensors: self._gridded_operations.area_weighted_std_dict(
                        tensors
                    )
                ),
            ),
            n_timesteps=self._n_timesteps,
        )
        self._n_batches = 0

    @torch.no_grad()
    def record_batch(
        self,
        data: InferenceBatchData,
    ):
        """Forward one batch to every registered metric.

        Args:
            data: Batch to record. Its normalized or denormalized target and
                prediction fields are used depending on the configured
                target, together with its ``i_time_start``.
        """
        if self._target == "norm":
            target_data = data.target_norm
            gen_data = data.prediction_norm
        else:
            target_data = data.target
            gen_data = data.prediction
        for metric in self._variable_metrics.values():
            metric.record(
                target=target_data,
                gen=gen_data,
                i_time_start=data.i_time_start,
            )
        self._n_batches += 1

    def _get_series_data(self, step_slice: slice | None = None) -> list[_SeriesData]:
        """Converts internally stored variable_metrics to a list.

        Args:
            step_slice: Optional slice applied to every series before it is
                reduced and converted to NumPy.

        Returns:
            One entry per (metric, variable) pair, ordered by metric
            registration order and then by variable name.

        Raises:
            ValueError: If no batches have been recorded.
        """
        if self._n_batches == 0:
            raise ValueError("No batches have been recorded.")
        data: list[_SeriesData] = []
        for name, metric in self._variable_metrics.items():
            metric_results = metric.get()  # TensorDict: {var_name: metric_series}
            for key in sorted(metric_results):
                arr = metric_results[key].detach()
                if step_slice is not None:
                    arr = arr[step_slice]
                datum = _SeriesData(
                    metric_name=name,
                    var_name=key,
                    # reduce_mean presumably averages across distributed
                    # processes — TODO confirm against Distributed.
                    data=self._dist.reduce_mean(arr).cpu().numpy(),
                )
                data.append(datum)
        return data

    @torch.no_grad()
    def get_logs(self, label: str, step_slice: slice | None = None):
        """
        Returns logs as can be reported to WandB.

        Args:
            label: Label to prepend to all log keys.
            step_slice: Slice of forecast steps to log.

        Returns:
            A dictionary with the single key "<label>/series" holding the
            table of all series.
        """
        series_data: dict[str, np.ndarray] = {
            datum.get_wandb_key(): datum.data
            for datum in self._get_series_data(step_slice)
        }
        if step_slice is None:
            init_step = 0
        else:
            # slice.start may be None (e.g. slice(None, 10)) or negative;
            # resolve it against the series length so the first row always
            # gets a concrete forecast step.
            init_step = step_slice.indices(self._n_timesteps)[0]
        return {f"{label}/series": data_to_table(series_data, init_step)}

    @torch.no_grad()
    def get_dataset(self) -> xr.Dataset:
        """
        Returns a dataset representation of the logs.

        Each series becomes a variable named "<metric>-<variable>" along the
        "forecast_step" dimension. Variables without an entry in the
        configured metadata get a placeholder with "unknown_units".
        """
        data_vars = {}
        for datum in self._get_series_data():
            metadata = self._variable_metadata.get(
                datum.var_name, VariableMetadata("unknown_units", datum.var_name)
            )
            data_vars[datum.get_xarray_key()] = xr.DataArray(
                datum.data, dims=["forecast_step"], attrs=metadata.as_attrs()
            )
        if data_vars:
            n_forecast_steps = len(next(iter(data_vars.values())))
        else:
            # No series at all: still return a well-formed, empty coordinate.
            n_forecast_steps = 0
        coords = {"forecast_step": np.arange(n_forecast_steps)}
        return xr.Dataset(data_vars=data_vars, coords=coords)
def data_to_table(data: dict[str, np.ndarray], init_step: int = 0) -> Table:
    """
    Build a wandb Table from a dictionary of 1-dimensional timeseries.

    The table has a leading "forecast_step" column followed by one column per
    key of ``data`` in sorted order. The number of rows is taken from the
    series stored under the first sorted key.

    Args:
        data: dictionary of timeseries data.
        init_step: initial step corresponding to the first row's "forecast_step"
    """
    column_names = sorted(data)
    wandb = WandB.get_instance()
    table = wandb.Table(columns=["forecast_step", *column_names])
    if not column_names:
        # Nothing to add: return the table with only its header.
        return table
    n_rows = len(data[column_names[0]])
    for offset in range(n_rows):
        table.add_data(
            init_step + offset,
            *(data[name][offset] for name in column_names),
        )
    return table
class AreaWeightedSingleTargetReducedMetric:
    """
    A wrapper around an area-weighted metric function on a single data source.

    For every recorded batch the metric function is evaluated, its result is
    averaged over the leading (batch) dimension and added to a running
    per-variable series of length ``n_timesteps``. A per-timestep counter of
    contributing batches is kept so that ``get`` can return the mean.
    """

    def __init__(
        self,
        device: torch.device,
        compute_metric: AreaWeightedSingleTargetFunction,
        n_timesteps: int,
    ):
        """
        Args:
            device: Device on which the running totals and the batch counter
                are allocated.
            compute_metric: Metric function, called with the recorded tensors
                as its only argument.
            n_timesteps: Length of the accumulated series.
        """
        self._compute_metric = compute_metric
        self._total: TensorDict = {}
        # Allocate the counter on the same device as the totals: get() divides
        # one by the other, which fails if they live on different devices.
        self._n_batches = torch.zeros(n_timesteps, dtype=torch.int32, device=device)
        self._device = device
        self._n_timesteps = n_timesteps

    def record(self, tensors: TensorMapping, i_time_start: int) -> None:
        """Add a batch of data to the metric.

        Args:
            tensors: Dictionary mapping variable names to tensors of shape
                [batch, time, height, width].
            i_time_start: The index of the first timestep in the batch.
        """
        time_dim = 1
        # The window of the series this batch contributes to.
        time_dim_len = next(iter(tensors.values())).shape[time_dim]
        time_slice = slice(i_time_start, i_time_start + time_dim_len)

        # Update totals for each variable
        new_values = self._compute_metric(tensors)
        for name, tensor in new_values.items():
            if name not in self._total:
                self._total[name] = torch.zeros(
                    [self._n_timesteps], dtype=tensor.dtype, device=self._device
                )
            # Collapse the leading dimension (presumably batch, leaving one
            # value per timestep — TODO confirm against the metric functions).
            self._total[name][time_slice] += tensor.mean(dim=0)

        # One more batch has contributed to each timestep in the window.
        self._n_batches[time_slice] += 1

    def get(self) -> TensorDict:
        """Returns the mean metric across recorded batches for each variable.

        Before anything has been recorded, a defaultdict yielding a scalar
        NaN tensor for any key is returned instead.
        """
        if not self._total:
            # no batches recorded yet
            return defaultdict(lambda: torch.tensor(torch.nan))
        return {name: tensor / self._n_batches for name, tensor in self._total.items()}
class SingleTargetMeanAggregator:
    """
    Aggregates per-timestep, area-weighted metrics of the generated data only.

    One AreaWeightedSingleTargetReducedMetric is kept per metric name; every
    recorded batch's prediction is forwarded to all of them. Results can be
    exported as a WandB table (``get_logs``) or an xarray dataset
    (``get_dataset``).
    """

    def __init__(
        self,
        gridded_operations: GriddedOperations,
        n_timesteps: int,
        variable_metadata: Mapping[str, VariableMetadata] | None = None,
    ):
        """
        Args:
            gridded_operations: Provides the area-weighted metric functions
                the individual metrics delegate to.
            n_timesteps: Length of each accumulated series.
            variable_metadata: Optional per-variable metadata used for the
                attributes of the variables returned by ``get_dataset``.
        """
        self._ops = gridded_operations
        self._n_timesteps = n_timesteps
        self._dist = Distributed.get_instance()
        if variable_metadata is None:
            self._variable_metadata: Mapping[str, VariableMetadata] = {}
        else:
            self._variable_metadata = variable_metadata

        # One metric object per metric name; insertion order determines the
        # order of the exported series.
        self._variable_metrics: dict[str, SingleTargetMeanMetric] = {}
        device = get_device()
        self._variable_metrics["weighted_mean_gen"] = (
            AreaWeightedSingleTargetReducedMetric(
                device=device,
                compute_metric=(
                    lambda tensors: self._ops.area_weighted_mean_dict(tensors)
                ),
                n_timesteps=self._n_timesteps,
            )
        )
        self._variable_metrics["weighted_std_gen"] = (
            AreaWeightedSingleTargetReducedMetric(
                device=device,
                compute_metric=(
                    lambda tensors: self._ops.area_weighted_std_dict(tensors)
                ),
                n_timesteps=self._n_timesteps,
            )
        )
        self._n_batches = 0

    @torch.no_grad()
    def record_batch(
        self,
        data: InferenceBatchData,
    ):
        """Forward the prediction of one batch to every registered metric.

        Args:
            data: Batch to record; only its ``prediction`` and
                ``i_time_start`` are used.
        """
        for metric in self._variable_metrics.values():
            metric.record(
                tensors=data.prediction,
                i_time_start=data.i_time_start,
            )
        self._n_batches += 1

    def _get_series_data(self, step_slice: slice | None = None) -> list[_SeriesData]:
        """Converts internally stored variable_metrics to a list.

        Args:
            step_slice: Optional slice applied to every series before it is
                reduced and converted to NumPy.

        Returns:
            One entry per (metric, variable) pair, ordered by metric
            registration order and then by variable name.

        Raises:
            ValueError: If no batches have been recorded.
        """
        if self._n_batches == 0:
            raise ValueError("No batches have been recorded.")
        data: list[_SeriesData] = []
        for name, metric in self._variable_metrics.items():
            metric_results = metric.get()  # TensorDict: {var_name: metric_series}
            for key in sorted(metric_results):
                arr = metric_results[key].detach()
                if step_slice is not None:
                    arr = arr[step_slice]
                datum = _SeriesData(
                    metric_name=name,
                    var_name=key,
                    # reduce_mean presumably averages across distributed
                    # processes — TODO confirm against Distributed.
                    data=self._dist.reduce_mean(arr).cpu().numpy(),
                )
                data.append(datum)
        return data

    @torch.no_grad()
    def get_logs(self, label: str, step_slice: slice | None = None):
        """
        Returns logs as can be reported to WandB.

        Args:
            label: Label to prepend to all log keys.
            step_slice: Slice of forecast steps to log.

        Returns:
            A dictionary with the single key "<label>/series" holding the
            table of all series.
        """
        series_data: dict[str, np.ndarray] = {
            datum.get_wandb_key(): datum.data
            for datum in self._get_series_data(step_slice)
        }
        if step_slice is None:
            init_step = 0
        else:
            # slice.start may be None (e.g. slice(None, 10)) or negative;
            # resolve it against the series length so the first row always
            # gets a concrete forecast step.
            init_step = step_slice.indices(self._n_timesteps)[0]
        return {f"{label}/series": data_to_table(series_data, init_step)}

    @torch.no_grad()
    def get_dataset(self) -> xr.Dataset:
        """
        Returns a dataset representation of the logs.

        Each series becomes a variable named "<metric>-<variable>" along the
        "forecast_step" dimension. Variables without an entry in the
        configured metadata get a placeholder with "unknown_units".
        """
        data_vars = {}
        for datum in self._get_series_data():
            metadata = self._variable_metadata.get(
                datum.var_name, VariableMetadata("unknown_units", datum.var_name)
            )
            data_vars[datum.get_xarray_key()] = xr.DataArray(
                datum.data, dims=["forecast_step"], attrs=metadata.as_attrs()
            )
        if data_vars:
            n_forecast_steps = len(next(iter(data_vars.values())))
        else:
            # No series at all: return an empty coordinate instead of failing
            # on next() of an empty iterator, matching MeanAggregator.
            n_forecast_steps = 0
        coords = {"forecast_step": np.arange(n_forecast_steps)}
        return xr.Dataset(data_vars=data_vars, coords=coords)
@dataclasses.dataclass
class MeanMetricConfig:
    """
    Configuration from which a MeanAggregator-based metric is built.

    Parameters:
        variables: Passed to ``maybe_filter`` together with the aggregator
            when building (presumably None means "no filtering" — TODO
            confirm against ``maybe_filter``).
        name: Name of the metric. When None it defaults to "mean_norm" if
            ``target`` is "norm" and to "mean" otherwise.
        target: Forwarded to MeanAggregator as its ``target``.
        enabled: Not read by this class; presumably consumed by the code
            that collects metric configs — TODO confirm.
        strict: Not read by this class; presumably consumed by the code
            that collects metric configs — TODO confirm.
    """

    variables: list[str] | None = None
    name: str | None = None
    target: Literal["denorm", "norm"] = "denorm"
    enabled: bool = True
    strict: bool = False

    def __post_init__(self):
        # Fill in the default name once, so get_name() can return it as-is.
        if self.name is None:
            self.name = "mean_norm" if self.target == "norm" else "mean"

    def get_name(self) -> str:
        """Return the metric name (explicit, or the default chosen at init)."""
        return self.name  # type: ignore[return-value]

    def build(self, ctx: MetricBuildContext) -> MetricBuildResult:
        """Build the aggregator described by this config.

        Args:
            ctx: Build context supplying ``ops``, ``n_timesteps`` and
                ``variable_metadata`` for the aggregator.

        Returns:
            A MetricBuildResult whose ``aggregator`` is the (possibly
            variable-filtered) MeanAggregator and whose ``time_series`` is
            the unfiltered MeanAggregator itself.
        """
        agg = MeanAggregator(
            ctx.ops,
            target=self.target,
            n_timesteps=ctx.n_timesteps,
            variable_metadata=ctx.variable_metadata,
        )
        return MetricBuildResult(
            aggregator=maybe_filter(agg, self.variables), time_series=agg
        )