# Source code for fme.ace.aggregator.one_step.map

import dataclasses
from collections.abc import Mapping

import torch
import xarray as xr

from fme.core.dataset.data_typing import VariableMetadata
from fme.core.distributed import Distributed
from fme.core.typing_ import TensorDict, TensorMapping
from fme.core.wandb import Image

from ..plotting import plot_paneled_data
from .build_context import OneStepBuildContext, OneStepMetricBuildResult


class MapAggregator:
    """
    An aggregator that records the average over batches as a function of lat
    and lon.

    Each call to ``record_batch`` reduces every variable to a single map (one
    index along the time dimension, averaged over the leading dimension) and
    adds it to a running per-variable sum. ``get_logs`` and ``get_dataset``
    divide those sums by the number of recorded batches to obtain mean maps.
    """

    _captions = {
        "full-field": (
            "{name} one step mean full field; "
            "(top) generated and (bottom) target [{units}]"
        ),
        "error": (
            "{name} one step mean full field error (generated - target) [{units}]"
        ),
    }

    def __init__(
        self, dims: list[str], metadata: Mapping[str, VariableMetadata] | None = None
    ):
        """
        Args:
            dims: Dimensions of the data.
            metadata: Mapping of variable names to their metadata that will
                be used in generating logged image captions.
        """
        self._dims = dims
        self._metadata: Mapping[str, VariableMetadata] = (
            {} if metadata is None else metadata
        )
        self._n_batches = 0
        # Loss passed to the most recent record_batch call; None until then.
        self._loss: float | None = None
        # Running per-variable sums of the per-batch mean maps.
        self._target_data: TensorDict = {}
        self._gen_data: TensorDict = {}

    @staticmethod
    def _accumulate(
        totals: TensorDict,
        batch: TensorMapping,
        time_dim: int,
        time_index: int,
    ) -> None:
        """Add the per-batch mean map of every variable in ``batch`` to ``totals``.

        Args:
            totals: Running sums, updated in place. Variables seen for the
                first time are inserted.
            batch: Data for the current batch, keyed by variable name.
            time_dim: Dimension along which ``time_index`` is selected.
            time_index: Index along ``time_dim`` to keep.
        """
        for name, tensor in batch.items():
            # select() drops time_dim; the mean then collapses the leading
            # dimension (presumably the batch/sample dimension) to one map.
            batch_mean = tensor.select(dim=time_dim, index=time_index).mean(dim=0)
            if name in totals:
                totals[name] += batch_mean
            else:
                totals[name] = batch_mean

    @torch.no_grad()
    def record_batch(
        self,
        loss: float,
        target_data: TensorMapping,
        gen_data: TensorMapping,
        target_data_norm: TensorMapping,
        gen_data_norm: TensorMapping,
    ):
        """
        Accumulate one batch of target and generated data.

        Args:
            loss: Loss for the batch; stored but not otherwise used by this
                class.
            target_data: Target data keyed by variable name.
            gen_data: Generated data keyed by variable name.
            target_data_norm: Not used by this aggregator.
            gen_data_norm: Not used by this aggregator.
        """
        time_dim = 1
        # note that we are only using the first timestep
        # see https://github.com/ai2cm/full-model/issues/1005
        # NOTE(review): this selects index 1 along the time dimension;
        # presumably index 0 holds the initial condition - confirm.
        target_time = 1
        self._loss = loss
        self._accumulate(self._target_data, target_data, time_dim, target_time)
        self._accumulate(self._gen_data, gen_data, time_dim, target_time)
        self._n_batches += 1

    def _mean_maps(self, dist, sums: TensorMapping) -> dict:
        """Reduce each running sum with ``dist.reduce_mean`` and return the
        per-batch mean of each variable as a numpy array.

        Variables are processed in sorted name order so the sequence of
        ``reduce_mean`` calls does not depend on insertion order.
        """
        return {
            name: (dist.reduce_mean(sums[name]) / self._n_batches).cpu().numpy()
            for name in sorted(sums)
        }

    def _get_data(self) -> tuple[TensorMapping, TensorMapping]:
        """
        Returns the (generated, target) mean maps keyed by variable name.

        The values are numpy arrays (the result of ``.cpu().numpy()``), even
        though the return annotation is a tensor mapping.
        """
        dist = Distributed.get_instance()
        # Generated data is reduced before target data, as in the original
        # ordering of the reduce_mean calls.
        gen = self._mean_maps(dist, self._gen_data)
        target = self._mean_maps(dist, self._target_data)
        return gen, target

    @torch.no_grad()
    def get_logs(self, label: str) -> dict[str, Image]:
        """
        Returns logs as can be reported to WandB.

        Args:
            label: Label to prepend to all log keys.
        """
        gen, target = self._get_data()
        image_logs = {}
        for name in gen:
            image_logs[f"{label}/image-error/{name}"] = plot_paneled_data(
                [[(gen[name] - target[name])]],
                diverging=True,
                caption=self._get_caption("error", name),
            )
            image_logs[f"{label}/image-full-field/{name}"] = plot_paneled_data(
                [
                    [gen[name]],
                    [target[name]],
                ],
                diverging=False,
                caption=self._get_caption("full-field", name),
            )
        return image_logs

    def _get_display_name_and_units(self, name: str) -> tuple[str, str]:
        """Return the display name and units for ``name``.

        Falls back to the variable name itself and ``"unknown_units"`` when
        no metadata is available for the variable.
        """
        if name in self._metadata:
            variable_metadata = self._metadata[name]
            return (
                variable_metadata.display_long_name(name),
                variable_metadata.display_units(),
            )
        return name, "unknown_units"

    def _get_caption(self, key: str, name: str) -> str:
        """Format the caption template ``key`` ("full-field" or "error") for
        the variable ``name``."""
        caption_name, units = self._get_display_name_and_units(name)
        return self._captions[key].format(name=caption_name, units=units)

    def get_dataset(self) -> xr.Dataset:
        """
        Returns a dataset holding, for every generated variable, the mean
        generated map (``gen_map-{name}``) and the mean generated-minus-target
        map (``bias_map-{name}``), both on dimensions ``dims``.
        """
        gen, target = self._get_data()
        ds = xr.Dataset()
        for name in gen:
            long_name, units = self._get_display_name_and_units(name)
            metadata_attrs = {"long_name": long_name, "units": units}
            ds[f"gen_map-{name}"] = xr.DataArray(
                data=gen[name], dims=self._dims, attrs=metadata_attrs
            )
            ds[f"bias_map-{name}"] = xr.DataArray(
                data=gen[name] - target[name],
                dims=self._dims,
                attrs=metadata_attrs,
            )
        return ds


@dataclasses.dataclass
class OneStepMapMetricConfig:
    """
    Configuration for the one-step mean map metric.

    ``build`` constructs a ``MapAggregator`` from the build context and wraps
    it in a ``OneStepMetricBuildResult``.
    """

    # Name reported by get_name().
    name: str = "mean_map"
    # Not read within this class; presumably consumed by whatever assembles
    # the metrics - confirm against the caller.
    enabled: bool = True
    # Not read within this class; presumably consumed by the caller - confirm.
    strict: bool = False

    def get_name(self) -> str:
        """Return the configured metric name."""
        return self.name

    def build(self, ctx: OneStepBuildContext) -> OneStepMetricBuildResult:
        """
        Build the map aggregator for this metric.

        Args:
            ctx: Build context supplying the horizontal coordinate dimensions
                and the variable metadata passed to the aggregator.

        Returns:
            A build result carrying the aggregator as its ``deterministic``
            member.
        """
        agg = MapAggregator(
            ctx.horizontal_coordinates.dims,
            ctx.variable_metadata,
        )
        return OneStepMetricBuildResult(deterministic=agg)