import dataclasses
import logging
import warnings
from collections.abc import Sequence
import torch
from fme.ace.aggregator.loss_metrics import PerStepLossAggregator
from fme.ace.aggregator.one_step.deterministic import (
DeterministicTrainOutput,
OneStepDeterministicAggregator,
)
from fme.ace.aggregator.one_step.ensemble import (
OneStepEnsembleMetricConfig,
get_one_step_ensemble_aggregator,
)
from fme.ace.stepper import TrainOutput
from fme.core.dataset_info import DatasetInfo
from fme.core.generics.aggregator import AggregatorABC, AggregatorSummary
from fme.core.tensors import fold_ensemble_dim, fold_sized_ensemble_dim
from fme.core.typing_ import EnsembleTensorDict, TensorMapping
from .build_context import (
Aggregator,
EnsembleAggregator,
MetricNotSupportedError,
OneStepBuildContext,
OneStepMetricBuildResult,
)
from .map import OneStepMapMetricConfig
from .reduced import OneStepMeanMetricConfig
from .snapshot import OneStepSnapshotMetricConfig
from .spectrum import OneStepSpectrumMetricConfig
# Union of the per-metric configuration types defined for the one-step
# aggregator. `build_one_step_aggregator` takes a list of these and calls
# `get_name()` and `build(ctx)` on each entry.
OneStepMetricConfig = (
OneStepMeanMetricConfig
| OneStepSnapshotMetricConfig
| OneStepMapMetricConfig
| OneStepSpectrumMetricConfig
| OneStepEnsembleMetricConfig
)
class OneStepAggregator(AggregatorABC[TrainOutput]):
    """
    Aggregates statistics for the first timestep.

    Feed each batch to `record_batch`; once all batches are recorded, call
    `get_summary` (or `get_logs`) to obtain the accumulated statistics.

    Deterministic metrics are delegated to the wrapped
    `OneStepDeterministicAggregator`, which receives the output of
    `fold_ensemble_dim` / `fold_sized_ensemble_dim`. The optional ensemble
    aggregators are only fed batches for which `fold_ensemble_dim` reports
    more than one ensemble member.
    """

    def __init__(
        self,
        deterministic_aggregator: OneStepDeterministicAggregator,
        ensemble_aggregators: dict[str, EnsembleAggregator] | None = None,
    ):
        """
        Parameters:
            deterministic_aggregator: Aggregator that receives every batch.
            ensemble_aggregators: Named ensemble aggregators; ``None`` (or an
                empty dict) disables ensemble metrics.
        """
        self._deterministic_aggregator = deterministic_aggregator
        self._ensemble_aggregators: dict[str, EnsembleAggregator] = (
            ensemble_aggregators or {}
        )
        # Flipped to True the first time a batch is passed to the ensemble
        # aggregators; get_summary only queries them once that has happened.
        self._ensemble_recorded = False
        self._per_step_losses = PerStepLossAggregator()

    @torch.no_grad()
    def record_batch(
        self,
        batch: TrainOutput,
    ):
        """Record one batch with every wrapped aggregator."""
        # Only the per-step loss entries of batch.metrics go to the per-step
        # loss aggregator; the deterministic aggregator gets the full mapping.
        loss_prefixes = ("loss_step_", "loss/")
        per_step = {
            name: value
            for name, value in batch.metrics.items()
            if name.startswith(loss_prefixes)
        }
        self._per_step_losses.record(per_step)

        gen_folded, n_ensemble = fold_ensemble_dim(batch.gen_data)
        target_folded = fold_sized_ensemble_dim(batch.target_data, n_ensemble)
        deterministic_batch = DeterministicTrainOutput(
            metrics=batch.metrics,
            gen_data=gen_folded,
            target_data=target_folded,
            normalize=batch.normalize,
        )
        self._deterministic_aggregator.record_batch(deterministic_batch)

        if not (n_ensemble > 1 and self._ensemble_aggregators):
            return

        # Ensemble aggregators receive the unfolded data, both raw and
        # normalized.
        target_data_norm = EnsembleTensorDict(batch.normalize(batch.target_data))
        gen_data_norm = EnsembleTensorDict(batch.normalize(batch.gen_data))
        for aggregator in self._ensemble_aggregators.values():
            aggregator.record_batch(
                target_data=batch.target_data,
                gen_data=batch.gen_data,
                target_data_norm=target_data_norm,
                gen_data_norm=gen_data_norm,
                i_time_start=0,
            )
        self._ensemble_recorded = True

    @torch.no_grad()
    def get_summary(self, label: str) -> AggregatorSummary:
        """
        Merge the logs of all wrapped aggregators into one summary.

        Parameters:
            label: Passed to the deterministic and per-step loss aggregators,
                and used as the ``"{label}/"`` prefix for ensemble log keys.

        Returns:
            Summary holding the merged logs and the deterministic
            aggregator's loss.

        Raises:
            ValueError: If an ensemble log key collides with a key already
                present in the deterministic / per-step loss logs.
        """
        det_summary = self._deterministic_aggregator.get_summary(label)
        logs = dict(det_summary.logs)
        logs.update(self._per_step_losses.get_logs(label))

        if not (self._ensemble_recorded and self._ensemble_aggregators):
            return AggregatorSummary(logs=logs, loss=det_summary.loss)

        # Each ensemble aggregator is asked for logs under its own name, and
        # the caller's label is then prepended to every key.
        stochastic_logs: dict = {
            f"{label}/{key}": value
            for agg_name, aggregator in self._ensemble_aggregators.items()
            for key, value in aggregator.get_logs(label=agg_name).items()
        }
        if logs.keys() & stochastic_logs.keys():
            raise ValueError(
                "Stochastic and deterministic logs have overlapping keys, "
                f"stochastic logs: {stochastic_logs}, "
                f"deterministic logs: {logs}"
            )
        logs.update(stochastic_logs)
        return AggregatorSummary(logs=logs, loss=det_summary.loss)

    def get_logs(self, label: str):
        """Return only the ``logs`` of `get_summary(label)`."""
        return self.get_summary(label).logs

    @torch.no_grad()
    def flush_diagnostics(self, subdir: str | None = None):
        """Forward the flush request to the deterministic aggregator."""
        self._deterministic_aggregator.flush_diagnostics(subdir)
def _validate_no_duplicate_names(metrics: list[OneStepMetricConfig]) -> None:
    """
    Ensure every metric config reports a distinct name.

    Parameters:
        metrics: Metric configs; each must provide ``get_name()``.

    Raises:
        ValueError: If two or more configs share a name. The message lists
            the repeated names in sorted order.
    """
    seen: set[str] = set()
    repeated: set[str] = set()
    for metric in metrics:
        metric_name = metric.get_name()
        # First sighting goes to `seen`; any later sighting marks a duplicate.
        (repeated if metric_name in seen else seen).add(metric_name)
    if not repeated:
        return
    raise ValueError(
        f"Duplicate metric names: {sorted(repeated)}. "
        "Use the 'name' field to disambiguate."
    )
def build_one_step_aggregator(
    metrics: list[OneStepMetricConfig],
    dataset_info: DatasetInfo,
    save_diagnostics: bool = True,
    output_dir: str | None = None,
    loss_scaling: TensorMapping | None = None,
    channel_mean_names: Sequence[str] | None = None,
    raise_on_unsupported: bool = True,
    include_default_ensemble: bool = True,
) -> OneStepAggregator:
    """
    Build a `OneStepAggregator` from a list of metric configurations.

    Parameters:
        metrics: Metric configs to build; their names must be unique.
        dataset_info: Supplies the gridded operations, horizontal coordinates
            and variable metadata used to build the metrics.
        save_diagnostics: Forwarded to `OneStepDeterministicAggregator`.
        output_dir: Forwarded to `OneStepDeterministicAggregator`.
        loss_scaling: Forwarded to `OneStepDeterministicAggregator`.
        channel_mean_names: Forwarded to the metric build context.
        raise_on_unsupported: If True, a `MetricNotSupportedError` raised by
            any metric propagates. If False, it only propagates for metrics
            whose ``strict`` attribute is truthy; other unsupported metrics
            are logged and omitted.
        include_default_ensemble: If True and no configured metric produced
            an ensemble aggregator, add one under the name ``"ensemble"``.

    Returns:
        The assembled aggregator.

    Raises:
        ValueError: If two metrics share a name.
        MetricNotSupportedError: See ``raise_on_unsupported``.
    """
    _validate_no_duplicate_names(metrics)
    ctx = OneStepBuildContext(
        ops=dataset_info.gridded_operations,
        horizontal_coordinates=dataset_info.horizontal_coordinates,
        variable_metadata=dataset_info.variable_metadata,
        channel_mean_names=channel_mean_names,
    )

    built_deterministic: dict[str, Aggregator] = {}
    built_ensemble: dict[str, EnsembleAggregator] = {}
    for metric_config in metrics:
        metric_name = metric_config.get_name()
        try:
            built: OneStepMetricBuildResult = metric_config.build(ctx)
        except MetricNotSupportedError as e:
            if raise_on_unsupported or metric_config.strict:
                raise
            logging.warning(
                f"{metric_name} metric not supported for this configuration, "
                f"omitting: {e}"
            )
            continue
        # A single metric may contribute a deterministic part, an ensemble
        # part, or both.
        if result_det := built.deterministic:
            built_deterministic[metric_name] = result_det
        elif built.deterministic is not None:
            # Falsy-but-present aggregators are still registered.
            built_deterministic[metric_name] = built.deterministic
        if built.ensemble is not None:
            built_ensemble[metric_name] = built.ensemble

    if include_default_ensemble and not built_ensemble:
        built_ensemble["ensemble"] = get_one_step_ensemble_aggregator(
            gridded_operations=ctx.ops,
            target_time=1,
            metadata=ctx.variable_metadata,
        )

    deterministic_aggregator = OneStepDeterministicAggregator(
        aggregators=built_deterministic,
        coords=dataset_info.horizontal_coordinates.coords,
        save_diagnostics=save_diagnostics,
        output_dir=output_dir,
        loss_scaling=loss_scaling,
    )
    return OneStepAggregator(
        deterministic_aggregator=deterministic_aggregator,
        ensemble_aggregators=built_ensemble,
    )
@dataclasses.dataclass
class OneStepAggregatorConfig:
    """
    Configuration for the validation OneStepAggregator.

    Each metric is a named field with its own typed configuration and an
    ``enabled`` flag. Defaults match the standard metric set.

    Metrics whose runtime requirements are not met (e.g. ``power_spectrum``
    on a non-spherical grid) are skipped with a warning when ``strict``
    is ``False`` (the default for built-in metrics).

    Parameters:
        mean_denorm: Mean metrics on denormalized data. Cannot be disabled.
        mean_norm: Mean metrics on normalized data. Cannot be disabled.
        power_spectrum: Spherical power spectrum metrics.
        snapshot: Snapshot image metrics.
        mean_map: Mean map image metrics.
        ensemble_denorm: Ensemble spread metrics on denormalized data.
        ensemble_norm: Ensemble spread metrics on normalized data
            (disabled by default).
    """

    mean_denorm: OneStepMeanMetricConfig = dataclasses.field(
        default_factory=lambda: OneStepMeanMetricConfig(target="denorm")
    )
    mean_norm: OneStepMeanMetricConfig = dataclasses.field(
        default_factory=lambda: OneStepMeanMetricConfig(
            target="norm",
            include_bias=False,
            include_grad_mag_percent_diff=False,
        )
    )
    power_spectrum: OneStepSpectrumMetricConfig = dataclasses.field(
        default_factory=OneStepSpectrumMetricConfig
    )
    snapshot: OneStepSnapshotMetricConfig = dataclasses.field(
        default_factory=OneStepSnapshotMetricConfig
    )
    mean_map: OneStepMapMetricConfig = dataclasses.field(
        default_factory=OneStepMapMetricConfig
    )
    ensemble_denorm: OneStepEnsembleMetricConfig = dataclasses.field(
        default_factory=lambda: OneStepEnsembleMetricConfig(target="denorm")
    )
    ensemble_norm: OneStepEnsembleMetricConfig = dataclasses.field(
        default_factory=lambda: OneStepEnsembleMetricConfig(
            target="norm", enabled=False
        )
    )

    def __post_init__(self):
        """
        Validate the field values.

        Raises:
            ValueError: If ``mean_denorm`` or ``mean_norm`` is disabled, or
                if a ``*_denorm`` / ``*_norm`` field carries a ``target``
                that does not match its field name.
        """
        # The enabled checks run before any target check, so a disabled mean
        # metric is always reported first.
        for field_name in ("mean_denorm", "mean_norm"):
            if not getattr(self, field_name).enabled:
                raise ValueError(f"{field_name} cannot be disabled.")

        # Each of these fields is tied to one data space by its name; the
        # configured target has to agree with it.
        expected_targets = {
            "mean_denorm": "denorm",
            "mean_norm": "norm",
            "ensemble_denorm": "denorm",
            "ensemble_norm": "norm",
        }
        for field_name, expected in expected_targets.items():
            actual = getattr(self, field_name).target
            if actual != expected:
                raise ValueError(
                    f"{field_name}.target must be '{expected}', got '{actual}'"
                )

    def _get_metrics(self) -> list[OneStepMetricConfig]:
        """Return the enabled metric configs, in field declaration order."""
        all_metrics: list[OneStepMetricConfig] = [
            self.mean_denorm,
            self.mean_norm,
            self.power_spectrum,
            self.snapshot,
            self.mean_map,
            self.ensemble_denorm,
            self.ensemble_norm,
        ]
        return [m for m in all_metrics if m.enabled]

    def build(
        self,
        dataset_info: DatasetInfo,
        save_diagnostics: bool = True,
        output_dir: str | None = None,
        loss_scaling: TensorMapping | None = None,
        channel_mean_names: Sequence[str] | None = None,
    ) -> OneStepAggregator:
        """
        Build the aggregator for the enabled metrics.

        All arguments are forwarded unchanged to `build_one_step_aggregator`,
        which is called with ``raise_on_unsupported=False`` and
        ``include_default_ensemble=False``.

        Returns:
            The aggregator returned by `build_one_step_aggregator`.
        """
        return build_one_step_aggregator(
            metrics=self._get_metrics(),
            dataset_info=dataset_info,
            save_diagnostics=save_diagnostics,
            output_dir=output_dir,
            loss_scaling=loss_scaling,
            channel_mean_names=channel_mean_names,
            raise_on_unsupported=False,
            include_default_ensemble=False,
        )
@dataclasses.dataclass
class LegacyFlagOneStepAggregatorConfig:
    """
    Legacy configuration for the validation OneStepAggregator using boolean flags.

    Deprecated: Use OneStepAggregatorConfig with typed metrics instead.
    Instantiating this class emits a ``DeprecationWarning``.

    Parameters:
        log_snapshots: Whether to log snapshot images.
        log_mean_maps: Whether to log mean map images.
    """

    log_snapshots: bool = True
    log_mean_maps: bool = True

    def __post_init__(self):
        """Emit the deprecation warning on construction."""
        # __post_init__ is invoked from the dataclass-generated __init__, so
        # the code that constructs the config sits three frames up.
        # stacklevel=3 attributes the warning to that construction site
        # rather than to the generated __init__.
        warnings.warn(
            "LegacyFlagOneStepAggregatorConfig is deprecated. "
            "Use OneStepAggregatorConfig instead.",
            DeprecationWarning,
            stacklevel=3,
        )

    def _get_metrics(self) -> list[OneStepMetricConfig]:
        """
        Translate the boolean flags into a list of metric configs.

        The two mean metrics and the spectrum metric are always included, as
        is the ensemble metric, which receives ``log_mean_maps``. Snapshot
        and mean-map metrics are added only when their flag is set.
        """
        metrics: list[OneStepMetricConfig] = [
            OneStepMeanMetricConfig(target="denorm"),
            OneStepMeanMetricConfig(
                target="norm",
                include_bias=False,
                include_grad_mag_percent_diff=False,
            ),
            OneStepSpectrumMetricConfig(),
        ]
        if self.log_snapshots:
            metrics.append(OneStepSnapshotMetricConfig())
        if self.log_mean_maps:
            metrics.append(OneStepMapMetricConfig())
        metrics.append(OneStepEnsembleMetricConfig(log_mean_maps=self.log_mean_maps))
        return metrics

    def build(
        self,
        dataset_info: DatasetInfo,
        save_diagnostics: bool = True,
        output_dir: str | None = None,
        loss_scaling: TensorMapping | None = None,
        channel_mean_names: Sequence[str] | None = None,
    ) -> OneStepAggregator:
        """
        Build the aggregator for the metrics selected by the flags.

        All arguments are forwarded unchanged to `build_one_step_aggregator`,
        which is called with ``raise_on_unsupported=False``.

        Returns:
            The aggregator returned by `build_one_step_aggregator`.
        """
        return build_one_step_aggregator(
            metrics=self._get_metrics(),
            dataset_info=dataset_info,
            save_diagnostics=save_diagnostics,
            output_dir=output_dir,
            loss_scaling=loss_scaling,
            channel_mean_names=channel_mean_names,
            raise_on_unsupported=False,
        )