import dataclasses
import datetime
from collections.abc import Callable, Mapping
from functools import partial
from typing import Any, Optional
import numpy as np
import torch
import xarray as xr
from matplotlib.figure import Figure
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.device import get_device
from fme.core.distributed import Distributed
from fme.core.gridded_ops import GriddedOperations
from fme.core.typing_ import TensorMapping
from ..plotting import plot_mean_and_samples
from .build_context import MetricBuildContext, MetricNotSupportedError, maybe_filter
from .data import InferenceBatchData, MetricBuildResult, SubAggregator
class PairedGlobalMeanAnnualAggregator:
def __init__(
self,
ops: GriddedOperations,
timestep: datetime.timedelta,
variable_metadata: Mapping[str, VariableMetadata] | None = None,
monthly_reference_data: xr.Dataset | None = None,
):
self._area_weighted_mean = ops.area_weighted_mean
self.timestep = timestep
self.variable_metadata = variable_metadata or {}
self._target_aggregator = GlobalMeanAnnualAggregator(
ops, timestep, variable_metadata
)
self._gen_aggregator = GlobalMeanAnnualAggregator(
ops, timestep, variable_metadata
)
self._monthly_reference_data = monthly_reference_data
self._variable_reference_data: dict[str, VariableReferenceData] = {}
def _get_reference(self, name: str) -> Optional["VariableReferenceData"]:
if self._monthly_reference_data is None:
return None
if name not in self._variable_reference_data:
if name not in self._monthly_reference_data:
return None
area_weighted_mean = partial(self._area_weighted_mean, name=name)
self._variable_reference_data[name] = process_monthly_reference(
self._monthly_reference_data, area_weighted_mean, name
)
return self._variable_reference_data[name]
@torch.no_grad()
def record_batch(
self,
data: InferenceBatchData,
):
"""Record a batch of data for computing time variability statistics."""
target_data = data.replace(prediction=data.target)
gen_data = data.replace(prediction=data.prediction)
self._target_aggregator.record_batch(target_data)
self._gen_aggregator.record_batch(gen_data)
def _get_gathered_means(self) -> tuple[xr.Dataset, xr.Dataset] | None:
"""
Gather the mean target and generated data across all processes.
Returns:
A tuple of the target and generated datasets, or None if this is not the
root rank.
"""
target = self._target_aggregator.get_gathered_means()
gen = self._gen_aggregator.get_gathered_means()
if target is None or gen is None:
return None
return target, gen
@torch.no_grad()
def get_logs(self, label: str) -> dict[str, Any]:
gathered = self._get_gathered_means()
if gathered is None: # not the root rank
return {}
target, gen = gathered
plots = {}
metrics = {}
for name in gen.data_vars.keys():
if name == "counts":
continue
if name in self.variable_metadata:
long_name = self.variable_metadata[name].display_long_name(name)
units = self.variable_metadata[name].display_units("unknown units")
else:
long_name = name
units = "unknown units"
fig = Figure() # create directly for cleanup when it leaves scope
ax = fig.add_subplot(1, 1, 1) # Add an axes to the figure
ref = self._get_reference(name)
if ref is not None:
if ref.mean.sizes["year"] > 1:
# dataarray.plot() does not work for singleton dimensions
ref.mean.plot(ax=ax, x="year", label="ref_mean", color="black")
ref.min.plot(
ax=ax, x="year", label="ref_min", color="grey", linestyle="--"
)
ref.max.plot(
ax=ax, x="year", label="ref_max", color="grey", linestyle="--"
)
if gen.sizes["year"] > 1:
target_ensemble_mean = target[name].mean("sample")
gen_ensemble_mean = gen[name].mean("sample")
# compute R2 values
if ref is not None:
r2_target = get_r2(target_ensemble_mean, ref.mean)
r2_gen = get_r2(gen_ensemble_mean, ref.mean)
metrics[f"r2/{name}_target"] = r2_target
metrics[f"r2/{name}_gen"] = r2_gen
target_label = f"target R2: {r2_target:.2f}"
gen_label = f"gen R2: {r2_gen:.2f}"
else:
target_label = "target"
gen_label = "gen"
plot_mean_and_samples(
ax, target[name], target_label, color="orange", plot_samples=False
)
plot_mean_and_samples(ax, gen[name], gen_label)
ax.set_title(f"{name}")
ax.set_ylabel(f"{long_name} [{units}]")
ax.legend()
fig.tight_layout()
plots[name] = fig
if len(label) > 0:
label = label + "/"
logs = {}
logs.update({f"{label}{name}": plots[name] for name in plots.keys()})
logs.update({f"{label}{name}": metrics[name] for name in metrics.keys()})
return logs
def get_dataset(self) -> xr.Dataset:
gathered = self._get_gathered_means()
if gathered is None:
return xr.Dataset()
target, gen = gathered
return xr.concat(
[
target.expand_dims({"source": ["target"]}),
gen.expand_dims({"source": ["prediction"]}),
],
dim="source",
)
class GlobalMeanAnnualAggregator:
def __init__(
self,
ops: GriddedOperations,
timestep: datetime.timedelta,
variable_metadata: Mapping[str, VariableMetadata] | None = None,
):
self._area_weighted_mean_dict = ops.area_weighted_mean_dict
self.timestep = timestep
self.variable_metadata = variable_metadata or {}
self._datasets: list[xr.Dataset] | None = None
@torch.no_grad()
def record_batch(self, data: InferenceBatchData):
"""Record a batch of data for computing time variability statistics."""
time = data.time
data_area_mean = {
name: tensor.cpu()
for name, tensor in self._area_weighted_mean_dict(data.prediction).items()
}
ds = to_dataset(data_area_mean, time)
# must keep a separate dataset for each sample to avoid averaging across
# samples when we groupby year
if self._datasets is None:
self._datasets = []
for i_sample in range(ds.sizes["sample"]):
self._datasets.append(
ds.isel(sample=i_sample)
.groupby(ds["valid_time"].isel(sample=i_sample).dt.year)
.sum()
)
else:
for i_sample in range(ds.sizes["sample"]):
self._datasets[i_sample] = _add_dataarray(
self._datasets[i_sample],
ds.isel(sample=i_sample)
.groupby(ds["valid_time"].isel(sample=i_sample).dt.year)
.sum(),
)
def get_gathered_means(self) -> xr.Dataset | None:
"""
Gather the mean data across all processes.
Returns:
The mean dataset, or None if this is not the root rank.
"""
if self._datasets is None:
raise ValueError("No data has been recorded yet.")
dist = Distributed.get_instance()
data = xr.concat(self._datasets, dim="sample", join="outer")
if dist.world_size > 1:
data = _gather_sample_datasets(dist, data)
if data is None:
return None # we are not root rank
# filter out data with insufficient samples
min_samples = _get_min_samples(self.timestep)
data = data.where(data["counts"] > min_samples, drop=True)
data = data / data["counts"]
# ensure the 'year' coordinate has no jumps, filling in with NaNs as needed
if data.sizes["year"] > 0:
min_year = data["year"].min()
max_year = data["year"].max()
years = np.arange(min_year, max_year + 1, dtype=data.year.dtype)
data = data.reindex(year=years)
return data
@torch.no_grad()
def get_logs(self, label: str) -> dict[str, Any]:
ds = self.get_gathered_means()
if ds is None: # not the root rank
return {}
plots = {}
for name in ds.data_vars.keys():
if name == "counts":
continue
if name in self.variable_metadata:
long_name = self.variable_metadata[name].display_long_name(name)
units = self.variable_metadata[name].display_units("unknown units")
else:
long_name = name
units = "unknown units"
fig = Figure() # create directly for cleanup when it leaves scope
ax = fig.add_subplot(1, 1, 1) # Add an axes to the figure
if ds.sizes["year"] > 1:
plot_mean_and_samples(ax, ds[name], "ensemble mean")
ax.set_title(f"{name}")
ax.set_ylabel(f"{long_name} [{units}]")
ax.legend()
fig.tight_layout()
plots[name] = fig
if len(label) > 0:
label = label + "/"
logs = {f"{label}{name}": plot for name, plot in plots.items()}
return logs
def get_dataset(self) -> xr.Dataset:
gathered = self.get_gathered_means()
if gathered is None:
return xr.Dataset()
return gathered
@dataclasses.dataclass
class VariableReferenceData:
mean: xr.DataArray
min: xr.DataArray
max: xr.DataArray
def process_monthly_reference(
monthly_reference_data: xr.Dataset,
area_weighted_mean: Callable[[torch.Tensor], torch.Tensor],
name: str,
) -> VariableReferenceData:
ref_global_mean = xr.DataArray(
area_weighted_mean(torch.as_tensor(monthly_reference_data[name].values)),
dims=monthly_reference_data[name].dims[:-2],
coords={"time": monthly_reference_data[name].coords["time"]},
)
valid_time_0 = monthly_reference_data.valid_time.isel(sample=0)
for i in range(1, monthly_reference_data.sizes["sample"]):
valid_time_i = monthly_reference_data.valid_time.isel(sample=i)
if not valid_time_0.equals(valid_time_i):
raise ValueError("All time axes must be the same")
# need to select just one time axis so we don't lose sample dim
ref_annual_coarsened = (ref_global_mean * monthly_reference_data["counts"]).groupby(
valid_time_0.dt.year
).sum() / monthly_reference_data["counts"].groupby(valid_time_0.dt.year).sum()
return VariableReferenceData(
mean=ref_annual_coarsened.mean("sample"),
min=ref_annual_coarsened.min("sample"),
max=ref_annual_coarsened.max("sample"),
)
def _add_dataarray(da1: xr.DataArray, da2: xr.DataArray):
"""
Perform dataarray addition, assuming any missing year indices
have zero values.
"""
union_index = np.union1d(da1.year.values, da2.year.values)
da1 = da1.reindex(year=union_index, fill_value=0)
da2 = da2.reindex(year=union_index, fill_value=0)
return da1 + da2
def get_r2(da: xr.DataArray, reference: xr.DataArray) -> float:
"""Compute the R2 value of the target compared to the reference."""
ref_data = reference.sel(year=da.year)
SS_ref = np.sum((ref_data.values - np.mean(ref_data.values)) ** 2)
SS_pred = np.sum((da - ref_data).values ** 2)
return float(1 - SS_pred / SS_ref)
def _gather_sample_datasets(
dist: Distributed, dataset: xr.Dataset
) -> xr.Dataset | None:
"""
Gather the dataset across all processes, concatenating on the sample dimension.
Assumes all dataset variables have the same dimensions and shape, and that the
first dimension is "sample".
"""
# collect all data into one torch.Tensor for gathering, sort for determinism
names = sorted(list(dataset.data_vars))
tensor = torch.cat(
[torch.asarray(np.expand_dims(dataset[name].values, axis=0)) for name in names],
dim=0,
).to(get_device())
years = torch.asarray(dataset.year.values).to(get_device())
gathered_tensors = dist.gather_irregular(tensor)
gathered_years = dist.gather_irregular(years)
if gathered_tensors is None or gathered_years is None:
return None
datasets = []
for tensor, years in zip(gathered_tensors, gathered_years):
single_rank_dataset = xr.Dataset(
{
name: (["sample", "year"], tensor[i].cpu())
for i, name in enumerate(names)
},
coords={"year": years.cpu()},
)
datasets.append(single_rank_dataset)
# concat ranks along sample dim
dataset_out = xr.concat(datasets, dim="sample")
return dataset_out
@torch.no_grad()
def to_dataset(data: TensorMapping, time: xr.DataArray) -> xr.Dataset:
"""Convert a dictionary of data to an xarray dataset."""
assert time.dims == ("sample", "time") # must be consistent with this module
data_vars = {}
for name, tensor in data.items():
data_vars[name] = (["sample", "time"], tensor)
data_vars["counts"] = (
["sample", "time"],
np.ones(shape=time.shape, dtype=np.float32),
)
return xr.Dataset(data_vars, coords={"valid_time": time})
def _get_min_samples(timestep: datetime.timedelta) -> int:
steps_per_day = datetime.timedelta(days=1) // timestep
return 362 * steps_per_day
[docs]@dataclasses.dataclass
class AnnualMetricConfig:
variables: list[str] | None = None
name: str = "annual"
reference_data: str | None = None
enabled: bool = True
strict: bool = False
def get_name(self) -> str:
return self.name
def build(self, ctx: MetricBuildContext) -> MetricBuildResult:
total_duration = ctx.n_timesteps * ctx.timestep
if total_duration <= datetime.timedelta(days=730):
raise MetricNotSupportedError(
f"annual metric requires > ~2 years of data, "
f"got {total_duration.days} days"
)
if self.reference_data is not None:
ref = xr.open_dataset(self.reference_data, decode_timedelta=False)
else:
ref = ctx.monthly_reference_data
agg: SubAggregator = PairedGlobalMeanAnnualAggregator(
ops=ctx.ops,
timestep=ctx.timestep,
variable_metadata=ctx.variable_metadata,
monthly_reference_data=ref,
)
return MetricBuildResult(aggregator=maybe_filter(agg, self.variables))