"""Source code for fme.ace.aggregator.one_step.snapshot."""

import dataclasses
from collections.abc import Mapping

import numpy as np
import torch
import xarray as xr

from fme.core.dataset.data_typing import VariableMetadata
from fme.core.typing_ import TensorMapping
from fme.core.wandb import Image

from ..plotting import plot_paneled_data
from .build_context import OneStepBuildContext, OneStepMetricBuildResult


class SnapshotAggregator:
    """
    An aggregator that records the first sample of the last batch of data.

    Each call to ``record_batch`` replaces the previously recorded batch;
    ``get_logs`` and ``get_dataset`` are built from the first sample of the
    most recently recorded batch.
    """

    _captions = {
        "full-field": (
            "{name} one step full field for first sample in last batch; "
            "(top) generated and (bottom) target [{units}]"
        ),
        "residual": (
            "{name} one step residual (prediction - previous time) for first sample in "
            "last batch; (top) generated and (bottom) target [{units}]"
        ),
        "error": (
            "{name} one step full field error (generated - target) "
            "for first sample in last batch [{units}]"
        ),
    }

    def __init__(
        self, dims: list[str], metadata: Mapping[str, VariableMetadata] | None = None
    ):
        """
        Args:
            dims: Dimensions of the data, used as the dims of every variable
                in the dataset returned by ``get_dataset``.
            metadata: Mapping of variable names to their metadata. It is used
                to generate logged image captions and dataset attributes.
        """
        self._dims = dims
        if metadata is None:
            self._metadata: Mapping[str, VariableMetadata] = {}
        else:
            self._metadata = metadata

    @torch.no_grad()
    def record_batch(
        self,
        loss: float,
        target_data: TensorMapping,
        gen_data: TensorMapping,
        target_data_norm: TensorMapping,
        gen_data_norm: TensorMapping,
    ):
        """
        Record a batch of data, replacing any previously recorded batch.

        As indexed by ``_get_data``, dimension 0 of each tensor is treated as
        the sample dimension. Dimension 1 is treated as time, with index 0 the
        input time and index 1 the target time.

        Args:
            loss: Loss for the batch. It is stored but not used by this class.
            target_data: Target data, keyed by variable name.
            gen_data: Generated data, keyed by variable name. Every key must
                also be present in ``target_data``.
            target_data_norm: Normalized target data. It is stored but not
                used by this class.
            gen_data_norm: Normalized generated data. It is stored but not
                used by this class.
        """
        self._loss = loss
        self._target_data = target_data
        self._gen_data = gen_data
        self._target_data_norm = target_data_norm
        self._gen_data_norm = gen_data_norm

    def _get_data(
        self,
    ) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray], dict[str, np.ndarray]]:
        """
        Extract the first sample of the recorded batch as numpy arrays.

        Returns:
            A tuple ``(gen, target, prev)`` of dicts keyed by the variable
            names in the recorded generated data:
            - ``gen`` is the generated data at the target time (index 1 of
              dim 1).
            - ``target`` is the target data at the target time.
            - ``prev`` is the target data at the input time (index 0 of dim 1).
        """
        time_dim = 1
        input_time = 0
        target_time = 1

        def _first_sample(tensor: torch.Tensor, time: int) -> np.ndarray:
            # detach() lets .numpy() succeed even when called outside a
            # no_grad context on tensors that require grad.
            return tensor.select(dim=time_dim, index=time)[0].detach().cpu().numpy()

        gen, target, prev = {}, {}, {}
        for name in self._gen_data:
            gen[name] = _first_sample(self._gen_data[name], target_time)
            target[name] = _first_sample(self._target_data[name], target_time)
            prev[name] = _first_sample(self._target_data[name], input_time)
        return gen, target, prev

    @torch.no_grad()
    def get_logs(self, label: str) -> dict[str, Image]:
        """
        Returns logs as can be reported to WandB.

        For each variable, three paneled images are produced: the error
        (generated - target), the full field (generated and target), and
        the residual relative to the input time (generated and target).

        Args:
            label: Label to prepend to all log keys.

        Returns:
            Mapping from ``"{label}/image-{kind}/{name}"`` to the image.
        """
        image_logs = {}
        gen, target, prev = self._get_data()
        for name in gen:
            images = {
                "error": [[gen[name] - target[name]]],
                "full-field": [[gen[name]], [target[name]]],
                "residual": [
                    [gen[name] - prev[name]],
                    [target[name] - prev[name]],
                ],
            }
            for key, data in images.items():
                # Difference fields are plotted with a diverging colormap.
                diverging = key in ("error", "residual")
                caption = self._get_caption(key, name)
                wandb_image = plot_paneled_data(data, diverging, caption=caption)
                image_logs[f"{label}/image-{key}/{name}"] = wandb_image
        return image_logs

    def _get_long_name_and_units(self, name: str) -> tuple[str, str]:
        """Return the display long name and units for ``name``.

        If ``name`` has no metadata, falls back to ``(name, "unknown_units")``.
        """
        if name in self._metadata:
            var_metadata = self._metadata[name]
            return var_metadata.display_long_name(name), var_metadata.display_units()
        return name, "unknown_units"

    def _get_caption(self, key: str, name: str) -> str:
        """Format the caption template ``key`` for variable ``name``."""
        caption_name, units = self._get_long_name_and_units(name)
        return self._captions[key].format(name=caption_name, units=units)

    @torch.no_grad()
    def get_dataset(self) -> xr.Dataset:
        """
        Return the recorded snapshot as an xarray Dataset.

        For each variable ``name``, the dataset contains
        ``error_map-{name}``, ``gen_full_field_map-{name}``,
        ``gen_residual_map-{name}`` and ``target_residual_map-{name}``.
        Each has dims ``self._dims`` and ``long_name``/``units`` attributes.
        """
        gen, target, prev = self._get_data()
        ds = xr.Dataset()
        for name in gen:
            long_name, units = self._get_long_name_and_units(name)
            metadata_attrs = {"long_name": long_name, "units": units}
            fields = {
                f"error_map-{name}": gen[name] - target[name],
                f"gen_full_field_map-{name}": gen[name],
                f"gen_residual_map-{name}": gen[name] - prev[name],
                f"target_residual_map-{name}": target[name] - prev[name],
            }
            for var_name, data in fields.items():
                ds[var_name] = xr.DataArray(
                    data=data, dims=self._dims, attrs=metadata_attrs
                )
        return ds


@dataclasses.dataclass
class OneStepSnapshotMetricConfig:
    """
    Configuration for the one-step snapshot metric.

    Attributes:
        name: Name of the metric, returned by ``get_name``.
        enabled: Whether the metric is enabled. It is not read by this
            class; its interpretation is presumably up to the caller.
        strict: It is not read by this class; its interpretation is
            presumably up to the caller.
    """

    name: str = "snapshot"
    enabled: bool = True
    strict: bool = False

    def get_name(self) -> str:
        """Return the configured metric name."""
        return self.name

    def build(self, ctx: OneStepBuildContext) -> OneStepMetricBuildResult:
        """
        Build a ``SnapshotAggregator`` from the build context.

        Args:
            ctx: Build context supplying the horizontal coordinate dims and
                the variable metadata.

        Returns:
            A build result with the aggregator as the deterministic metric.
        """
        agg = SnapshotAggregator(
            ctx.horizontal_coordinates.dims,
            ctx.variable_metadata,
        )
        return OneStepMetricBuildResult(deterministic=agg)