import dataclasses
from collections.abc import Mapping, Sequence
from typing import Literal
import matplotlib.pyplot as plt
import torch
import xarray as xr
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.distributed import Distributed
from fme.core.gridded_ops import GriddedOperations
from fme.core.typing_ import TensorDict, TensorMapping
from fme.core.wandb import Image
from ..plotting import plot_paneled_data
from .build_context import MetricBuildContext
from .data import InferenceBatchData, MetricBuildResult, SubAggregator
@dataclasses.dataclass
class _TargetGenPair:
    """A generated field paired with its target, for computing comparison metrics.

    ``ops`` supplies the area-weighted reductions used by ``rmse`` and
    ``weighted_mean_bias``; ``name`` is forwarded to those reductions.
    """

    name: str
    target: torch.Tensor
    gen: torch.Tensor
    ops: GriddedOperations

    def bias(self) -> torch.Tensor:
        """Return the pointwise difference ``gen - target``."""
        difference = self.gen - self.target
        return difference

    def rmse(self) -> float:
        """Return the area-weighted RMSE of ``gen`` against ``target``."""
        metric = self.ops.area_weighted_rmse(
            predicted=self.gen, truth=self.target, name=self.name
        )
        host_value = metric.cpu().numpy()
        return float(host_value)

    def weighted_mean_bias(self) -> float:
        """Return the area-weighted mean of ``gen - target``."""
        metric = self.ops.area_weighted_mean_bias(
            predicted=self.gen, truth=self.target, name=self.name
        )
        host_value = metric.cpu().numpy()
        return float(host_value)
def get_gen_shape(gen_data: TensorMapping):
    """Return the shape of the first tensor in ``gen_data``.

    Returns ``None`` when ``gen_data`` is empty.
    """
    for tensor in gen_data.values():
        return tensor.shape
    return None
class TimeMeanAggregator:
    """Accumulates a running time-mean of generated data.

    ``record_batch`` adds per-variable sums over the sample and time
    dimensions; ``get_data`` divides by the recorded counts and averages the
    result across distributed ranks. When ``reference_means`` is given,
    ``get_logs`` also reports bias/RMSE against it.
    """

    _image_captions = {
        "bias_map": "{name} time-mean bias (generated - reference) [{units}]",
        "gen_map": "{name} time-mean generated [{units}]",
        "gen_target_map": (
            "{name} time-mean; (top) generated and (bottom) target [{units}]"
        ),
    }

    def __init__(
        self,
        gridded_operations: GriddedOperations,
        target: Literal["norm", "denorm"] = "denorm",
        variable_metadata: Mapping[str, VariableMetadata] | None = None,
        reference_means: xr.Dataset | None = None,
        log_variables: frozenset[str] | None = None,
    ):
        """
        Args:
            gridded_operations: Computes gridded operations.
            target: Whether to compute metrics on the normalized or denormalized data,
                defaults to "denorm".
            variable_metadata: Mapping of variable names to their metadata that will
                be used in generating logged image captions.
            reference_means: Dataset containing reference time-mean values
                for bias computation.
            log_variables: If provided, only include per-variable entries in
                get_logs and get_dataset for these variables. All variables
                are still recorded and available via get_data.
        """
        self._ops = gridded_operations
        self._target = target
        if variable_metadata is None:
            self._variable_metadata: Mapping[str, VariableMetadata] = {}
        else:
            self._variable_metadata = variable_metadata
        # Dictionaries of tensors of shape [n_lat, n_lon] holding running sums
        # over samples and time; get_data divides by the counts below.
        self._data: TensorDict | None = None
        self._n_timesteps = 0
        self._n_samples: int | None = None
        self._reference_means = reference_means
        self._reference_validated = False
        self._log_variables = log_variables

    @staticmethod
    def _add_or_initialize_time_mean(
        maybe_dict: TensorDict | None,
        new_data: TensorMapping,
        ignore_initial: bool = False,
    ) -> TensorDict:
        """Add ``new_data`` summed over time (dim 1) and samples (dim 0).

        Args:
            maybe_dict: Existing running sums, or None to start new ones.
            new_data: Tensors indexed as [sample, time, ...].
            ignore_initial: If True, exclude the first time step from the sum.

        Returns:
            A new dict when ``maybe_dict`` is None; otherwise a shallow copy of
            ``maybe_dict`` whose tensors have been incremented in place.
        """
        sample_dim = 0
        time_dim = 1
        time_slice = slice(1, None) if ignore_initial else slice(0, None)
        if maybe_dict is None:
            return {
                name: tensor[:, time_slice].sum(dim=time_dim).sum(dim=sample_dim)
                for name, tensor in new_data.items()
            }
        d = dict(maybe_dict)
        for name, tensor in new_data.items():
            d[name] += tensor[:, time_slice].sum(dim=time_dim).sum(dim=sample_dim)
        return d

    @torch.no_grad()
    def record_batch(
        self,
        data: InferenceBatchData,
    ):
        """Accumulate one batch of predictions into the running time-mean.

        When ``data.i_time_start == 0`` the first time step of the window is
        excluded and the timestep count is reset to the remaining length;
        otherwise the full window length is added to the count.
        """
        if self._target == "denorm":
            tensor_data = data.prediction
        else:
            tensor_data = data.prediction_norm
        ignore_initial = data.i_time_start == 0
        self._data = self._add_or_initialize_time_mean(
            self._data, tensor_data, ignore_initial
        )
        first_tensor = tensor_data[list(tensor_data)[0]]
        if self._n_samples is None:
            self._n_samples = first_tensor.size(0)
        n_times = first_tensor.size(1)
        if ignore_initial:
            self._n_timesteps = n_times - 1
        else:
            self._n_timesteps += n_times
        if not self._reference_validated:
            # Computing logs once against the reference means presumably surfaces
            # incompatible reference data on the first batch instead of at the end.
            if self._reference_means is not None:
                self.get_logs(label="")
            self._reference_validated = True

    def get_data(self) -> TensorDict:
        """Return per-variable time means, averaged across distributed ranks.

        Raises:
            ValueError: If no timesteps have been recorded.
        """
        if self._n_timesteps == 0 or self._data is None:
            raise ValueError("No data recorded.")
        ret = {}
        dist = Distributed.get_instance()
        names = sorted(list(self._data.keys()))  # sort for rank-consistent order
        for name in names:
            value = self._data[name]
            ret[name] = dist.reduce_mean(value / self._n_timesteps / self._n_samples)
        return ret

    @torch.no_grad()
    def get_logs(
        self, label: str, target_maps: dict[str, torch.Tensor] | None = None
    ) -> dict[str, float | Image]:
        """Return images and scalar metrics for the recorded time means.

        Args:
            label: Prefix added as ``"{label}/"`` to every key; none if empty.
            target_maps: Optional per-variable target maps. A variable present
                here gets a two-panel (generated over target) ``gen_map`` image.

        Returns:
            ``gen_map/{name}`` images for each logged variable and, for variables
            present in the reference means, ``ref_bias/{name}``,
            ``ref_rmse/{name}`` and ``ref_bias_map/{name}`` entries.
        """
        logs: dict[str, float | Image] = {}
        data = self.get_data()
        gen_map_key = "gen_map"
        for name, pred in data.items():
            if self._log_variables is not None and name not in self._log_variables:
                continue
            if target_maps is not None and name in target_maps:
                gen_map_caption_key = "gen_target_map"
                data_panels = [[pred.cpu().numpy()], [target_maps[name].cpu().numpy()]]
            else:
                gen_map_caption_key = gen_map_key
                data_panels = [[pred.cpu().numpy()]]
            logs[f"{gen_map_key}/{name}"] = plot_paneled_data(
                data_panels,
                diverging=False,
                caption=self._get_caption(gen_map_caption_key, name),
            )
            if self._reference_means is not None and name in self._reference_means:
                # Match the generated tensor's dtype: the reference may be float64,
                # which would promote the comparison and is unsupported on some
                # devices (e.g. MPS).
                reference = torch.as_tensor(
                    self._reference_means[name].values,
                    dtype=pred.dtype,
                    device=pred.device,
                )
                pair = _TargetGenPair(
                    name=name,
                    target=reference,
                    gen=pred,
                    ops=self._ops,
                )
                bias_image = plot_paneled_data(
                    [[pair.bias().cpu().numpy()]],
                    diverging=True,
                    caption=self._get_caption("bias_map", name),
                )
                logs[f"ref_bias/{name}"] = pair.weighted_mean_bias()
                logs[f"ref_rmse/{name}"] = pair.rmse()
                logs[f"ref_bias_map/{name}"] = bias_image
        if label:
            return {f"{label}/{key}": value for key, value in logs.items()}
        return logs

    def _get_caption(self, key: str, name: str) -> str:
        """Format the caption template ``key`` with the variable's display metadata."""
        if name in self._variable_metadata:
            caption_name = self._variable_metadata[name].display_long_name(name)
            units = self._variable_metadata[name].display_units()
        else:
            caption_name, units = name, "unknown_units"
        return self._image_captions[key].format(name=caption_name, units=units)

    def get_dataset(self) -> xr.Dataset:
        """Return the time means as ``gen_map-{name}`` (lat, lon) variables."""
        dims = ("lat", "lon")
        data = {}
        for name, pred in self.get_data().items():
            if self._log_variables is not None and name not in self._log_variables:
                continue
            if name in self._variable_metadata:
                long_name = self._variable_metadata[name].display_long_name(name)
                units = self._variable_metadata[name].display_units()
            else:
                long_name = name
                units = "unknown_units"
            gen_metadata = VariableMetadata(long_name=long_name, units=units).as_attrs()
            data[f"gen_map-{name}"] = xr.DataArray(
                pred.cpu(),
                dims=dims,
                attrs=gen_metadata,
            )
        return xr.Dataset(data)
class TimeMeanEvaluatorAggregator:
    """Statistics and images on the time-mean state.

    This aggregator keeps track of the time-mean state, then computes
    statistics and images on that time-mean state when logs are retrieved.
    """

    _image_captions = {
        "bias_map": "{name} time-mean bias (generated - target) [{units}]",
    }

    def __init__(
        self,
        ops: GriddedOperations,
        horizontal_dims: list[str],
        target: Literal["norm", "denorm"] = "denorm",
        variable_metadata: Mapping[str, VariableMetadata] | None = None,
        reference_means: xr.Dataset | None = None,
        channel_mean_names: Sequence[str] | None = None,
        log_variables: frozenset[str] | None = None,
    ):
        """
        Args:
            ops: Computes gridded operations.
            horizontal_dims: Names of the horizontal dimensions.
            target: Whether to compute metrics on the normalized or denormalized data,
                defaults to "denorm".
            variable_metadata: Mapping of variable names to their metadata that will
                be used in generating logged image captions.
            reference_means: Dataset containing reference time-mean values
                for bias computation.
            channel_mean_names: Names of variables whose RMSE will be averaged. If
                not provided, all available variables will be used.
            log_variables: If provided, only include per-variable entries in
                get_logs and get_dataset for these variables. All variables
                are still recorded so that channel_mean is computed correctly.
        """
        self._ops = ops
        self._horizontal_dims = horizontal_dims
        self._target = target
        self._dist = Distributed.get_instance()
        if variable_metadata is None:
            self._variable_metadata: Mapping[str, VariableMetadata] = {}
        else:
            self._variable_metadata = variable_metadata
        self._log_variables = log_variables
        # Separate running time-means for the target and generated data.
        # Only the generated aggregator gets reference means and log filtering.
        self._target_agg = TimeMeanAggregator(
            gridded_operations=ops, target=target, variable_metadata=variable_metadata
        )
        self._gen_agg = TimeMeanAggregator(
            gridded_operations=ops,
            target=target,
            variable_metadata=variable_metadata,
            reference_means=reference_means,
            log_variables=log_variables,
        )
        self._channel_mean_names = channel_mean_names

    @torch.no_grad()
    def record_batch(
        self,
        data: InferenceBatchData,
    ):
        """Record target and generated data of one batch into their aggregators.

        Each sub-aggregator receives a copy of ``data`` whose ``prediction`` and
        ``prediction_norm`` both hold the tensors selected by ``target``.
        """
        if self._target == "norm":
            target_tensor = data.target_norm
            gen_tensor = data.prediction_norm
        else:
            target_tensor = data.target
            gen_tensor = data.prediction
        target_batch = data.replace(
            prediction=target_tensor, prediction_norm=target_tensor
        )
        gen_batch = data.replace(prediction=gen_tensor, prediction_norm=gen_tensor)
        self._target_agg.record_batch(target_batch)
        self._gen_agg.record_batch(gen_batch)

    def _get_target_gen_pairs(self) -> list[_TargetGenPair]:
        """Pair each generated time-mean with the target time-mean of the same name."""
        target_data = self._target_agg.get_data()
        gen_data = self._gen_agg.get_data()
        return [
            _TargetGenPair(
                gen=gen_data[name],
                target=target_data[name],
                name=name,
                ops=self._ops,
            )
            for name in gen_data.keys()
        ]

    @torch.no_grad()
    def get_logs(self, label: str) -> dict[str, float | torch.Tensor | Image]:
        """Return time-mean images and metrics, optionally prefixed by ``label``.

        Per logged variable: ``rmse/{name}``, plus ``bias_map/{name}`` and
        ``bias/{name}`` when ``target == "denorm"``. When ``target == "norm"``
        also ``rmse/channel_mean``, averaged over ``channel_mean_names`` (or all
        recorded variables, including ones excluded by ``log_variables``).

        Raises:
            KeyError: If ``channel_mean_names`` names a variable not recorded.
        """
        logs = self._gen_agg.get_logs("", target_maps=self._target_agg.get_data())
        pairs = self._get_target_gen_pairs()
        bias_map_key = "bias_map"
        rmse_all_channels = {}
        for pair in pairs:
            rmse_all_channels[pair.name] = pair.rmse()
            if self._log_variables is not None and pair.name not in self._log_variables:
                continue
            logs[f"rmse/{pair.name}"] = rmse_all_channels[pair.name]
            if self._target == "denorm":
                # Only render the (expensive) bias figure when it is actually logged.
                bias_image = plot_paneled_data(
                    [[pair.bias().cpu().numpy()]],
                    diverging=True,
                    caption=self._get_caption(bias_map_key, pair.name),
                )
                plt.close("all")
                logs[f"{bias_map_key}/{pair.name}"] = bias_image
                logs[f"bias/{pair.name}"] = pair.weighted_mean_bias()
        if self._target == "norm":
            metric_name = "rmse/channel_mean"
            if self._channel_mean_names is None:
                values_to_average = list(rmse_all_channels.values())
            else:
                missing = [
                    n for n in self._channel_mean_names if n not in rmse_all_channels
                ]
                if missing:
                    raise KeyError(
                        f"channel_mean_names contains entries not present in the "
                        f"recorded data: {missing}. Available: "
                        f"{sorted(rmse_all_channels)}."
                    )
                values_to_average = [
                    rmse_all_channels[name] for name in self._channel_mean_names
                ]
            logs[metric_name] = sum(values_to_average) / len(values_to_average)
        if label:
            return {f"{label}/{key}": value for key, value in logs.items()}
        return logs

    def _get_caption(self, key: str, name: str) -> str:
        """Format the caption template ``key`` with the variable's display metadata."""
        if name in self._variable_metadata:
            caption_name = self._variable_metadata[name].display_long_name(name)
            units = self._variable_metadata[name].display_units()
        else:
            caption_name, units = name, "unknown_units"
        return self._image_captions[key].format(name=caption_name, units=units)

    def get_dataset(self) -> xr.Dataset:
        """Return ``bias_map-{name}`` and ``gen_map-{name}`` arrays per logged variable."""
        data = {}
        for pair in self._get_target_gen_pairs():
            if self._log_variables is not None and pair.name not in self._log_variables:
                continue
            if pair.name in self._variable_metadata:
                long_name = self._variable_metadata[pair.name].display_long_name(
                    pair.name
                )
                units = self._variable_metadata[pair.name].display_units()
            else:
                long_name = pair.name
                units = "unknown_units"
            gen_metadata = VariableMetadata(long_name=long_name, units=units).as_attrs()
            bias_metadata = self._variable_metadata.get(
                pair.name, VariableMetadata(long_name=long_name, units=units)
            ).as_attrs()
            data[f"bias_map-{pair.name}"] = xr.DataArray(
                pair.bias().cpu(),
                dims=self._horizontal_dims,
                attrs=bias_metadata,
            )
            data[f"gen_map-{pair.name}"] = xr.DataArray(
                pair.gen.cpu(),
                dims=self._horizontal_dims,
                attrs=gen_metadata,
            )
        return xr.Dataset(data)
@dataclasses.dataclass
class TimeMeanMetricConfig:
    """Configuration for building a ``TimeMeanEvaluatorAggregator``.

    Attributes:
        variables: If given, only these variables get per-variable log and
            dataset entries; all variables are still recorded.
        name: Metric name; defaults to "time_mean_norm" when ``target`` is
            "norm" and "time_mean" otherwise.
        target: Whether metrics use normalized or denormalized data.
        reference_data: Optional path to a dataset of reference time means,
            opened with ``xr.open_dataset``. If unset and ``target`` is
            "denorm", the build context's reference data is used instead.
        channel_mean_names: Variables averaged into the channel-mean RMSE
            (norm target only); falls back to the build context's names when
            empty or None.
        enabled: Not read by this class's methods; presumably consumed by the
            caller.
        strict: Not read by this class's methods; presumably consumed by the
            caller.
    """

    variables: list[str] | None = None
    name: str | None = None
    target: Literal["denorm", "norm"] = "denorm"
    reference_data: str | None = None
    channel_mean_names: list[str] | None = None
    enabled: bool = True
    strict: bool = False

    def __post_init__(self):
        if self.name is None:
            self.name = "time_mean_norm" if self.target == "norm" else "time_mean"

    def get_name(self) -> str:
        """Return the metric name (always set by ``__post_init__``)."""
        return self.name  # type: ignore[return-value]

    def build(self, ctx: MetricBuildContext) -> MetricBuildResult:
        """Build the time-mean evaluator aggregator from the build context."""
        is_norm = self.target == "norm"
        if self.reference_data is not None:
            ref = xr.open_dataset(self.reference_data, decode_timedelta=False)
        elif not is_norm:
            ref = ctx.time_mean_reference_data
        else:
            # Context reference means are only used for the denormalized target.
            ref = None
        agg: SubAggregator = TimeMeanEvaluatorAggregator(
            ctx.ops,
            horizontal_dims=ctx.horizontal_coordinates.dims,
            target=self.target,
            variable_metadata=ctx.variable_metadata,
            reference_means=ref,
            channel_mean_names=(
                (self.channel_mean_names or ctx.channel_mean_names) if is_norm else None
            ),
            log_variables=(
                frozenset(self.variables) if self.variables is not None else None
            ),
        )
        return MetricBuildResult(aggregator=agg)