import dataclasses
from collections.abc import Mapping
import numpy as np
import torch
import xarray as xr
from fme.core.coordinates import LatLonCoordinates
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.distributed import Distributed
from fme.core.typing_ import TensorDict, TensorMapping
from fme.core.wandb import WandB
from .build_context import MetricBuildContext, MetricNotSupportedError, maybe_filter
from .data import InferenceBatchData, MetricBuildResult, SubAggregator
# Module-level handle to the project's WandB wrapper; _make_video uses it to
# construct wandb.Video objects.
wandb = WandB.get_instance()
def _get_gen_shape(gen_data: TensorMapping):
for name in gen_data:
return gen_data[name].shape
raise ValueError("No data in gen_data")
@dataclasses.dataclass
class _ErrorData:
rmse: TensorDict
min_err: TensorDict
max_err: TensorDict
class _ErrorVideoData:
    """
    Record batches of video data and compute statistics on the error
    (prediction - target).

    Three statistics are tracked per variable:
    - the mean squared error over samples, averaged over batches (reported as
      RMSE);
    - the element-wise minimum error over all recorded samples;
    - the element-wise maximum error over all recorded samples.

    Accumulators are double-precision CPU tensors of shape (n_timesteps, ...),
    created lazily from the first recorded batch.
    """

    def __init__(self, n_timesteps: int):
        self._mse_data: TensorDict | None = None
        self._min_err_data: TensorDict | None = None
        self._max_err_data: TensorDict | None = None
        self._n_timesteps = n_timesteps
        # number of batches that have contributed to each timestep
        self._n_batches = torch.zeros([n_timesteps], dtype=torch.int32).cpu()
        self._dist = Distributed.get_instance()

    @torch.no_grad()
    def record_batch(
        self,
        target_data: TensorMapping,
        gen_data: TensorMapping,
        i_time_start: int,
    ):
        """
        Record a batch of data.

        Args:
            target_data: Dict of tensors of shape (n_samples, n_timesteps, ...)
            gen_data: Dict of tensors of shape (n_samples, n_timesteps, ...)
            i_time_start: Index of the first timestep in the batch.
        """
        if self._mse_data is None:
            self._mse_data = _initialize_video_from_batch(gen_data, self._n_timesteps)
        if self._min_err_data is None:
            # start at +inf so the first recorded error always becomes the min
            self._min_err_data = _initialize_video_from_batch(
                gen_data, self._n_timesteps, fill_value=np.inf
            )
        if self._max_err_data is None:
            # start at -inf so the first recorded error always becomes the max
            self._max_err_data = _initialize_video_from_batch(
                gen_data, self._n_timesteps, fill_value=-np.inf
            )
        window_steps = next(iter(target_data.values())).shape[1]
        time_slice = slice(i_time_start, i_time_start + window_steps)
        for name, gen_tensor in gen_data.items():
            error_tensor = (gen_tensor - target_data[name]).cpu()
            # accumulate this batch's MSE (mean over samples); get() divides
            # by the number of batches recorded at each timestep
            self._mse_data[name][time_slice, ...] += torch.mean(
                torch.square(error_tensor), dim=0
            )
            self._min_err_data[name][time_slice, ...] = torch.minimum(
                self._min_err_data[name][time_slice, ...],
                error_tensor.min(dim=0).values,
            )
            self._max_err_data[name][time_slice, ...] = torch.maximum(
                self._max_err_data[name][time_slice, ...],
                error_tensor.max(dim=0).values,
            )
        self._n_batches[time_slice] += 1

    @torch.no_grad()
    def get(
        self,
    ) -> _ErrorData:
        """
        Return RMSE, minimum error and maximum error per variable, reduced
        across distributed ranks.

        Timesteps that were never recorded have a zero batch count, so their
        RMSE is NaN and their min/max stay at +inf/-inf.

        Raises:
            RuntimeError: If no batch has been recorded.
        """
        if (
            self._mse_data is None
            or self._min_err_data is None
            or self._max_err_data is None
        ):
            raise RuntimeError("No data recorded")
        rmse_data = {}
        for name in sorted(self._mse_data):
            # Divide by the batch count along the leading (time) dimension.
            # This assumes 3-D (time, lat, lon) accumulators, as the sibling
            # mean/variance accumulators do.
            mse = self._mse_data[name] / self._n_batches[:, None, None]
            mse = self._dist.reduce_mean(mse)
            rmse_data[name] = torch.sqrt(mse)
        min_err_data = {
            name: self._dist.reduce_min(self._min_err_data[name])
            for name in sorted(self._min_err_data)
        }
        max_err_data = {
            name: self._dist.reduce_max(self._max_err_data[name])
            for name in sorted(self._max_err_data)
        }
        return _ErrorData(rmse_data, min_err_data, max_err_data)
class _MeanVideoData:
    """
    Accumulate batches of video data and report, per timestep, the mean of
    both the target and the generated data.

    Running sums of per-batch sample means are kept as double-precision CPU
    tensors of shape (n_timesteps, ...), created lazily from the first batch.
    """

    def __init__(self, n_timesteps: int):
        self._target_data: TensorDict | None = None
        self._gen_data: TensorDict | None = None
        self._n_timesteps = n_timesteps
        # count of batches that have contributed to each timestep
        self._n_batches = torch.zeros([n_timesteps], dtype=torch.int32).cpu()
        self._dist = Distributed.get_instance()

    @torch.no_grad()
    def record_batch(
        self,
        target_data: TensorMapping,
        gen_data: TensorMapping,
        i_time_start: int,
    ):
        """
        Add one batch to the running sums.

        Args:
            target_data: Dict of tensors of shape (n_samples, n_timesteps, ...)
            gen_data: Dict of tensors of shape (n_samples, n_timesteps, ...)
            i_time_start: Index of the first timestep in the batch.
        """
        if self._target_data is None:
            self._target_data = _initialize_video_from_batch(
                target_data, self._n_timesteps
            )
        if self._gen_data is None:
            self._gen_data = _initialize_video_from_batch(gen_data, self._n_timesteps)
        n_window = next(iter(target_data.values())).shape[1]
        window = slice(i_time_start, i_time_start + n_window)
        # target first, then generated data
        for totals, batch in (
            (self._target_data, target_data),
            (self._gen_data, gen_data),
        ):
            for name, tensor in batch.items():
                totals[name][window, ...] += tensor.mean(dim=0).cpu()
        self._n_batches[window] += 1

    @torch.no_grad()
    def get(self) -> tuple[TensorDict, TensorDict]:
        """
        Return (gen_means, target_means), each averaged over recorded batches
        and reduced across distributed ranks.

        Raises:
            RuntimeError: If no batch has been recorded.
        """
        if self._gen_data is None or self._target_data is None:
            raise RuntimeError("No data recorded")
        counts = self._n_batches[:, None, None]

        def finalize(totals):
            # sorted order keeps the distributed reductions aligned across ranks
            return {
                name: self._dist.reduce_mean(totals[name] / counts)
                for name in sorted(totals)
            }

        target_means = finalize(self._target_data)
        gen_means = finalize(self._gen_data)
        return gen_means, target_means
class _VarianceVideoData:
    """
    Record batches of video data and compute, per timestep and grid point,
    the variance of the target and the generated data.

    The variance is computed as E[X^2] - E[X]^2 from running sums of per-batch
    first and second moments. These are kept as double-precision CPU tensors
    of shape (n_timesteps, ...), created lazily from the first batch.
    """

    def __init__(self, n_timesteps: int):
        self._target_means: TensorDict | None = None
        self._gen_means: TensorDict | None = None
        self._target_squares: TensorDict | None = None
        self._gen_squares: TensorDict | None = None
        self._n_timesteps = n_timesteps
        # number of batches that have contributed to each timestep
        self._n_batches = torch.zeros([n_timesteps], dtype=torch.int32).cpu()
        self._dist = Distributed.get_instance()

    @torch.no_grad()
    def record_batch(
        self,
        target_data: TensorMapping,
        gen_data: TensorMapping,
        i_time_start: int,
    ):
        """
        Record a batch of data.

        Args:
            target_data: Dict of tensors of shape (n_samples, n_timesteps, ...)
            gen_data: Dict of tensors of shape (n_samples, n_timesteps, ...)
            i_time_start: Index of the first timestep in the batch.
        """
        if self._target_means is None:
            self._target_means = _initialize_video_from_batch(
                target_data, self._n_timesteps
            )
        if self._gen_means is None:
            self._gen_means = _initialize_video_from_batch(gen_data, self._n_timesteps)
        if self._target_squares is None:
            self._target_squares = _initialize_video_from_batch(
                target_data, self._n_timesteps
            )
        if self._gen_squares is None:
            self._gen_squares = _initialize_video_from_batch(
                gen_data, self._n_timesteps
            )
        window_steps = next(iter(target_data.values())).shape[1]
        time_slice = slice(i_time_start, i_time_start + window_steps)
        for name, tensor in target_data.items():
            self._target_means[name][time_slice, ...] += tensor.mean(dim=0).cpu()
            self._target_squares[name][time_slice, ...] += (tensor**2).mean(dim=0).cpu()
        for name, tensor in gen_data.items():
            self._gen_means[name][time_slice, ...] += tensor.mean(dim=0).cpu()
            self._gen_squares[name][time_slice, ...] += (tensor**2).mean(dim=0).cpu()
        self._n_batches[time_slice] += 1

    @torch.no_grad()
    def get(self) -> tuple[TensorDict, TensorDict]:
        """
        Return (gen_variance, target_variance) per variable, reduced across
        distributed ranks and clamped to be non-negative.

        Raises:
            RuntimeError: If no batch has been recorded.
        """
        if (
            self._gen_means is None
            or self._target_means is None
            or self._gen_squares is None
            or self._target_squares is None
        ):
            raise RuntimeError("No data recorded")
        # target first, then gen: keeps the order of distributed reductions
        target_data = self._compute_variance(self._target_means, self._target_squares)
        gen_data = self._compute_variance(self._gen_means, self._gen_squares)
        return gen_data, target_data

    def _compute_variance(self, means: TensorDict, squares: TensorDict) -> TensorDict:
        """Turn accumulated first and second moments into a variance per variable."""
        counts = self._n_batches[:, None, None]
        variance = {}
        for name in sorted(means):
            mean = self._dist.reduce_mean(means[name] / counts)
            square = self._dist.reduce_mean(squares[name] / counts)
            # E[X^2] - E[X]^2 can come out slightly negative through
            # floating-point cancellation; a variance is never negative.
            variance[name] = torch.clamp(square - mean**2, min=0.0)
        return variance
def _initialize_video_from_batch(
batch: TensorMapping, n_timesteps: int, fill_value: float = 0.0
):
"""
Initialize a video of the same shape as the batch, but with all valeus equal
to fill_value and with n_timesteps timesteps.
"""
video = {}
for name, value in batch.items():
shape = list(value.shape[1:])
shape[0] = n_timesteps
video[name] = torch.zeros(shape, dtype=torch.double).cpu()
video[name][:, ...] = fill_value
return video
@dataclasses.dataclass
class _MaybePairedVideoData:
caption: str
gen: torch.Tensor
units: str | None
long_name: str | None
target: torch.Tensor | None = None
def make_video(self):
return _make_video(
caption=self.caption,
gen=self.gen,
target=self.target,
)
class VideoAggregator:
    """Videos of state evolution."""

    def __init__(
        self,
        n_timesteps: int,
        enable_extended_videos: bool,
        variable_metadata: Mapping[str, VariableMetadata] | None = None,
    ):
        """
        Args:
            n_timesteps: Number of timesteps of inference that will be run.
            enable_extended_videos: Whether to log videos of statistical
                metrics of state evolution (bias, RMSE, min/max error and
                prediction variance relative to target variance).
            variable_metadata: Mapping of variable names to their metadata that
                will be used in generating logged video captions.
        """
        if variable_metadata is None:
            self._variable_metadata: Mapping[str, VariableMetadata] = {}
        else:
            self._variable_metadata = variable_metadata
        self._mean_data = _MeanVideoData(n_timesteps=n_timesteps)
        if enable_extended_videos:
            self._error_data: _ErrorVideoData | None = _ErrorVideoData(
                n_timesteps=n_timesteps
            )
            self._variance_data: _VarianceVideoData | None = _VarianceVideoData(
                n_timesteps=n_timesteps
            )
            self._enable_extended_videos = True
        else:
            self._error_data = None
            self._variance_data = None
            self._enable_extended_videos = False

    @torch.no_grad()
    def record_batch(
        self,
        data: InferenceBatchData,
    ):
        """
        Record a batch in every enabled accumulator.

        Args:
            data: Batch whose ``target`` and ``prediction`` are recorded
                starting at timestep ``data.i_time_start``.
        """
        self._mean_data.record_batch(
            target_data=data.target,
            gen_data=data.prediction,
            i_time_start=data.i_time_start,
        )
        if self._error_data is not None:
            self._error_data.record_batch(
                target_data=data.target,
                gen_data=data.prediction,
                i_time_start=data.i_time_start,
            )
        if self._variance_data is not None:
            self._variance_data.record_batch(
                target_data=data.target,
                gen_data=data.prediction,
                i_time_start=data.i_time_start,
            )

    @torch.no_grad()
    def get_logs(self, label: str):
        """
        Returns logs as can be reported to WandB.

        Args:
            label: Label to prepend to all log keys.
        """
        data = self._get_data()
        videos = {}
        for sub_label, d in data.items():
            videos[f"{label}/{sub_label}"] = d.make_video()
        return videos

    @torch.no_grad()
    def _get_data(self) -> Mapping[str, _MaybePairedVideoData]:
        """
        Collect the video data to be logged, keyed by sub-label.

        Ensemble-mean videos are keyed by the variable name. Extended
        statistics are keyed as "<statistic>/<variable name>".
        """
        gen_data, target_data = self._mean_data.get()
        video_data = {}

        def get_units(name: str) -> str | None:
            if name in self._variable_metadata:
                return self._variable_metadata[name].units
            return None

        def get_long_name(name: str) -> str:
            # Fall back to the variable name (also when the metadata long_name
            # is empty) so that logged long_name attributes never read "None".
            if name in self._variable_metadata:
                long_name = self._variable_metadata[name].long_name
                if long_name:
                    return long_name
            return name

        for name in gen_data:
            long_name = get_long_name(name)
            video_data[name] = _MaybePairedVideoData(
                caption=self._get_caption(name),
                gen=gen_data[name],
                target=target_data[name],
                units=get_units(name),
                long_name=f"ensemble mean of {long_name}",
            )
            if self._enable_extended_videos:
                video_data[f"bias/{name}"] = _MaybePairedVideoData(
                    caption=f"prediction - target for {name}",
                    gen=gen_data[name] - target_data[name],
                    units=get_units(name),
                    long_name=f"bias of {long_name}",
                )
        if self._error_data is not None:
            error_data = self._error_data.get()
            for name in error_data.rmse:
                video_data[f"rmse/{name}"] = _MaybePairedVideoData(
                    caption=f"RMSE over ensemble for {name}",
                    gen=error_data.rmse[name],
                    units=get_units(name),
                    long_name=f"root mean squared error of {get_long_name(name)}",
                )
            for name in error_data.min_err:
                video_data[f"min_err/{name}"] = _MaybePairedVideoData(
                    caption=f"Min across ensemble members of min error for {name}",
                    gen=error_data.min_err[name],
                    units=get_units(name),
                    long_name=(
                        f"min error of {get_long_name(name)} across ensemble members"
                    ),
                )
            for name in error_data.max_err:
                video_data[f"max_err/{name}"] = _MaybePairedVideoData(
                    caption=f"Max across ensemble members of max error for {name}",
                    gen=error_data.max_err[name],
                    units=get_units(name),
                    long_name=(
                        f"max error of {get_long_name(name)} across ensemble members"
                    ),
                )
        if self._variance_data is not None:
            gen_var, target_var = self._variance_data.get()
            for name in gen_var:
                video_data[f"gen_var/{name}"] = _MaybePairedVideoData(
                    caption=(
                        f"Variance of gen data for {name} "
                        "as fraction of target variance"
                    ),
                    gen=gen_var[name] / target_var[name],
                    units="",
                    long_name=(
                        f"prediction variance of {get_long_name(name)} "
                        "as fraction of target variance"
                    ),
                )
        return video_data

    @torch.no_grad()
    def get_dataset(self) -> xr.Dataset:
        """
        Return video data as an xarray Dataset.

        Paired (gen, target) data gets a leading "source" dimension of length 2
        (index 0 = gen, 1 = target); unpaired data has dims
        ("timestep", "lat", "lon").
        """
        data = self._get_data()
        video_data = {}
        for label, d in data.items():
            # strip leading/trailing slashes and make "stat/name" a valid
            # variable name "stat_name"
            label = label.strip("/").replace("/", "_")
            attrs = {}
            if d.units is not None:
                attrs["units"] = d.units
            if d.long_name is not None:
                attrs["long_name"] = d.long_name
            if d.target is not None:
                video_data[label] = xr.DataArray(
                    data=np.concatenate(
                        [d.gen.cpu().numpy()[None, :], d.target.cpu().numpy()[None, :]],
                        axis=0,
                    ),
                    dims=("source", "timestep", "lat", "lon"),
                    attrs=attrs,
                )
            else:
                video_data[label] = xr.DataArray(
                    data=d.gen.cpu().numpy(),
                    dims=("timestep", "lat", "lon"),
                    attrs=attrs,
                )
        return xr.Dataset(video_data)

    def _get_caption(self, name: str) -> str:
        """Caption for the paired prediction/target ensemble-mean video."""
        caption = (
            "Autoregressive (left) prediction and (right) target for {name} [{units}]"
        )
        if name in self._variable_metadata:
            caption_name = self._variable_metadata[name].display_long_name(name)
            units = self._variable_metadata[name].display_units("unknown units")
        else:
            caption_name, units = name, "unknown units"
        return caption.format(name=caption_name, units=units)
def _make_video(
    caption: str,
    gen: torch.Tensor,
    target: torch.Tensor | None = None,
):
    """
    Render gen (and optionally target) as a grayscale GIF for WandB.

    Frames are taken along the leading dimension of ``gen``. Its shape is
    presumably (n_timesteps, lat, lon); TODO confirm.

    When ``target`` is given, each frame shows gen on the left and target on
    the right, separated by a 10-pixel-wide gap, and the color scale comes from
    the target alone. Otherwise the color scale spans gen's own range.

    Args:
        caption: Caption text; the color-scale vmin/vmax is appended to it.
        gen: Generated data.
        target: Optional target data with the same shape as ``gen``.

    Returns:
        A ``wandb.Video`` at 4 fps in GIF format.
    """
    # insert a singleton channel axis after the frame axis
    gen_frames = np.expand_dims(gen.cpu().numpy(), axis=1)
    if target is None:
        frames = gen_frames
        vmin = np.nanmin(frames)
        vmax = np.nanmax(frames)
    else:
        target_frames = np.expand_dims(target.cpu().numpy(), axis=1)
        # NOTE(review): the gap is zero in data units and is normalized along
        # with the data below, so its brightness depends on the color scale.
        gap = np.zeros([gen_frames.shape[0], 1, gen_frames.shape[2], 10])
        frames = np.concatenate([gen_frames, gap, target_frames], axis=-1)
        vmin = np.nanmin(target_frames)
        vmax = np.nanmax(target_frames)
    # map onto 0-255 brightness, clip out-of-range values, then blank NaNs
    frames = np.clip(255 * (frames - vmin) / (vmax - vmin), 0, 255)
    frames[np.isnan(frames)] = 0
    # replicate the grayscale channel to three RGB channels to work around an
    # imageio error
    frames = frames.repeat(3, axis=1)
    # flip along the second-to-last axis (presumably latitude, so that north
    # is at the top of the image — confirm)
    return wandb.Video(
        np.flip(frames, axis=-2),
        caption=f"{caption}; vmin={vmin:.4g}, vmax={vmax:.4g}",
        fps=4,
        format="gif",
    )
@dataclasses.dataclass
class VideoMetricConfig:
    """
    Configuration for the video metric.

    Attributes:
        variables: Variable names passed to ``maybe_filter`` when building the
            aggregator. Presumably ``None`` means no filtering; confirm against
            ``maybe_filter``.
        name: Name returned by ``get_name``.
        enable_extended_videos: Passed to ``VideoAggregator`` to additionally
            log bias, error and variance videos.
        enabled: NOTE(review): not read by this class; presumably consumed by
            the caller. Confirm.
        strict: NOTE(review): not read by this class; presumably consumed by
            the caller. Confirm.
    """

    variables: list[str] | None = None
    name: str = "video"
    enable_extended_videos: bool = False
    enabled: bool = False
    strict: bool = True

    def get_name(self) -> str:
        return self.name

    def build(self, ctx: MetricBuildContext) -> MetricBuildResult:
        """
        Build the video aggregator for the given context.

        Raises:
            MetricNotSupportedError: If the context does not use
                ``LatLonCoordinates``.
        """
        if not isinstance(ctx.horizontal_coordinates, LatLonCoordinates):
            raise MetricNotSupportedError("Video metric requires LatLonCoordinates.")
        agg: SubAggregator = VideoAggregator(
            n_timesteps=ctx.n_timesteps,
            enable_extended_videos=self.enable_extended_videos,
            variable_metadata=ctx.variable_metadata,
        )
        return MetricBuildResult(aggregator=maybe_filter(agg, self.variables))