Source code for fme.ace.aggregator.inference.main

import dataclasses
import datetime
import logging
import warnings
from collections.abc import Callable, Mapping, Sequence

import numpy as np
import torch
import xarray as xr

from fme.ace.data_loading.batch_data import PairedData, PrognosticState
from fme.core.coordinates import HorizontalCoordinates, LatLonCoordinates
from fme.core.dataset_info import DatasetInfo
from fme.core.diagnostics import get_reduced_diagnostics, write_reduced_diagnostics
from fme.core.fill import SmoothFloodFill
from fme.core.generics.aggregator import (
    InferenceAggregatorABC,
    InferenceLog,
    InferenceLogs,
)
from fme.core.gridded_ops import GriddedOperations, LatLonOperations
from fme.core.tensors import unfold_ensemble_dim
from fme.core.typing_ import TensorDict, TensorMapping
from fme.core.wandb import Table, WandB

from ..one_step.ensemble import EnsembleMetricConfig, SelectStepEnsembleAggregator
from ..one_step.reduced import StepMeanMetricConfig
from .annual import AnnualMetricConfig, GlobalMeanAnnualAggregator
from .build_context import MetricBuildContext, MetricNotSupportedError
from .data import InferenceBatchData, MetricBuildResult, SubAggregator, TimeSeriesLogs
from .enso import RegionalIndexAggregator
from .enso.dynamic_index import EnsoIndexMetricConfig
from .enso.enso_coefficient import EnsoCoefficientMetricConfig
from .histogram import HistogramMetricConfig
from .ipo.ipo_index import MIN_YEARS_FOR_FILTERED_TPI, IpoIndexMetricConfig
from .reduced import MeanMetricConfig, SingleTargetMeanAggregator
from .seasonal import SeasonalMetricConfig
from .spectrum import PowerSpectrumMetricConfig, SphericalPowerSpectrumAggregator
from .time_mean import TimeMeanAggregator, TimeMeanMetricConfig
from .utils import LatLonRegion
from .video import VideoMetricConfig
from .zonal_mean import ZonalMeanMetricConfig

# W&B handle obtained once at import time via ``WandB.get_instance()``.
wandb = WandB.get_instance()

# Simulated-duration thresholds. The builders below compare
# ``n_timesteps * timestep`` against these to decide whether
# duration-dependent metrics (annual means, ENSO index / coefficient,
# IPO index) are included.
APPROXIMATELY_TWO_YEARS = datetime.timedelta(days=2 * 365)
SLIGHTLY_LESS_THAN_FIVE_YEARS = datetime.timedelta(days=1800)
APPROXIMATELY_EIGHTY_YEARS = datetime.timedelta(
    days=365 * MIN_YEARS_FOR_FILTERED_TPI
)

# Nino 3.4 region bounds handed to ``LatLonRegion`` as ``lat_bounds`` and
# ``lon_bounds`` (presumably degrees north / degrees east).
NINO34_LAT = (-5, 5)
NINO34_LON = (190, 240)

# Union of every metric configuration type accepted by
# ``build_inference_evaluator_aggregator``.
# NOTE(review): member order is kept deliberately; config deserializers may
# try union members in order.
MetricConfig = (
    MeanMetricConfig
    | StepMeanMetricConfig
    | PowerSpectrumMetricConfig
    | ZonalMeanMetricConfig
    | VideoMetricConfig
    | TimeMeanMetricConfig
    | HistogramMetricConfig
    | SeasonalMetricConfig
    | AnnualMetricConfig
    | EnsoIndexMetricConfig
    | EnsoCoefficientMetricConfig
    | EnsembleMetricConfig
    | IpoIndexMetricConfig
)


def _validate_no_duplicate_names(metrics: list[MetricConfig]) -> None:
    """
    Ensure every metric configuration resolves to a distinct name.

    Args:
        metrics: Metric configurations whose ``get_name()`` results are checked.

    Raises:
        ValueError: If two or more metrics share a name. The message lists the
            repeated names in sorted order.
    """
    all_names = [metric.get_name() for metric in metrics]
    unique: set[str] = set()
    repeated: set[str] = set()
    for candidate in all_names:
        if candidate in unique:
            repeated.add(candidate)
        else:
            unique.add(candidate)
    if not repeated:
        return
    raise ValueError(
        f"Duplicate metric names: {sorted(repeated)}. "
        "Use the 'name' field to disambiguate."
    )


def build_inference_evaluator_aggregator(
    metrics: list[MetricConfig],
    dataset_info: DatasetInfo,
    n_ic_steps: int,
    n_forward_steps: int,
    initial_time: xr.DataArray,
    normalize: Callable[[TensorMapping], TensorDict],
    monthly_reference_data: str | None = None,
    time_mean_reference_data: str | None = None,
    output_dir: str | None = None,
    channel_mean_names: Sequence[str] | None = None,
    save_diagnostics: bool = True,
    n_ensemble_per_ic: int = 1,
    enable_time_series: bool = True,
    raise_on_unsupported: bool = True,
) -> "InferenceEvaluatorAggregator":
    """
    Build an ``InferenceEvaluatorAggregator`` from a list of metric configs.

    Every metric is built against one shared ``MetricBuildContext``. The
    resulting ``MetricBuildResult`` may contribute a sub-aggregator, a
    time-series view and/or an ensemble aggregator. Each contribution is
    registered under ``metric.get_name()``.

    Args:
        metrics: Metric configurations to build; names must be unique.
        dataset_info: Supplies gridded operations, horizontal coordinates,
            timestep and variable metadata for the build context.
        n_ic_steps: Number of initial-condition timesteps.
        n_forward_steps: Number of forward steps. The build context receives
            ``n_ic_steps + n_forward_steps`` as its total timestep count.
        initial_time: Forwarded to the build context.
        normalize: Function used to normalize predictions and targets.
        monthly_reference_data: Optional path opened with ``xr.open_dataset``
            and forwarded to the build context.
        time_mean_reference_data: Optional path opened with ``xr.open_dataset``
            and forwarded to the build context.
        output_dir: Directory diagnostics are written to.
        channel_mean_names: Forwarded to the build context.
        save_diagnostics: Whether diagnostics are saved (requires
            ``output_dir``).
        n_ensemble_per_ic: Ensemble members per initial condition.
        enable_time_series: When False, ``MeanMetricConfig`` entries are
            dropped before building.
        raise_on_unsupported: When False, a metric raising
            ``MetricNotSupportedError`` is skipped with a warning unless the
            metric itself is ``strict``.

    Raises:
        ValueError: On duplicate metric names, or when ``save_diagnostics``
            is set without ``output_dir``.
        MetricNotSupportedError: When a metric is unsupported and either
            ``raise_on_unsupported`` or the metric's ``strict`` flag is set.
    """
    _validate_no_duplicate_names(metrics)
    if save_diagnostics and output_dir is None:
        raise ValueError("Output directory must be set to save diagnostics.")

    def _open_reference(path: str | None) -> xr.Dataset | None:
        # Reference datasets are optional; only open the ones provided.
        if path is None:
            return None
        return xr.open_dataset(path, decode_timedelta=False)

    monthly_ref = _open_reference(monthly_reference_data)
    time_mean_ref = _open_reference(time_mean_reference_data)

    ctx = MetricBuildContext(
        ops=dataset_info.gridded_operations,
        horizontal_coordinates=dataset_info.horizontal_coordinates,
        n_timesteps=n_ic_steps + n_forward_steps,
        n_ic_steps=n_ic_steps,
        timestep=dataset_info.timestep,
        variable_metadata=dataset_info.variable_metadata,
        channel_mean_names=channel_mean_names,
        monthly_reference_data=monthly_ref,
        time_mean_reference_data=time_mean_ref,
        initial_time=initial_time,
    )

    # Global-mean time series are the only metrics gated by this flag.
    selected = [
        metric
        for metric in metrics
        if enable_time_series or not isinstance(metric, MeanMetricConfig)
    ]

    aggregators: dict[str, SubAggregator] = {}
    time_series_aggregators: dict[str, TimeSeriesLogs] = {}
    ensemble_aggregators: dict[str, SelectStepEnsembleAggregator] = {}

    for metric in selected:
        metric_name = metric.get_name()
        try:
            result: MetricBuildResult = metric.build(ctx)
        except MetricNotSupportedError as err:
            if raise_on_unsupported or metric.strict:
                raise
            logging.warning(
                f"{metric_name} metric not supported for this configuration, "
                f"omitting: {err}"
            )
            continue

        if result.aggregator is not None:
            aggregators[metric_name] = result.aggregator
        if result.time_series is not None:
            time_series_aggregators[metric_name] = result.time_series
        if result.ensemble is not None:
            ensemble_aggregators[metric_name] = result.ensemble

    return InferenceEvaluatorAggregator(
        aggregators=aggregators,
        time_series_aggregators=time_series_aggregators,
        coords=dataset_info.horizontal_coordinates.coords,
        n_ic_steps=n_ic_steps,
        normalize=normalize,
        save_diagnostics=save_diagnostics,
        output_dir=output_dir,
        n_ensemble_per_ic=n_ensemble_per_ic,
        ensemble_aggregators=ensemble_aggregators,
    )


@dataclasses.dataclass
class InferenceEvaluatorAggregatorConfig:
    """
    Configuration for inference evaluator aggregator.

    Every metric has its own typed configuration field carrying an
    ``enabled`` flag. The defaults reproduce the standard metric set: the
    always-wanted metrics are enabled, and the optional ones (``histogram``,
    ``video``, ``seasonal``) are disabled.

    If a metric's runtime requirements are not satisfied (for instance
    ``enso_index`` on a grid that is not lat/lon), the outcome depends on its
    ``strict`` flag:

    - ``strict`` is ``False`` (the default for built-in metrics): the metric
      is skipped and a warning is logged.
    - ``strict`` is ``True`` (the default for user-enabled metrics such as
      ``histogram``, ``video``, ``seasonal``): an error is raised.

    Parameters:
        mean_denorm: Global-mean time-series metrics on denormalized data.
        mean_norm: Global-mean time-series metrics on normalized data.
        step_means: Per-step snapshot metrics. Defaults to step-20 denorm
            and norm.
        ensembles: Ensemble spread metrics. Defaults to step-20. Silently
            skipped when ``n_ensemble <= 1``.
        power_spectrum: Spherical power spectrum metrics.
        zonal_mean: Zonal-mean image metrics.
        time_mean_denorm: Time-mean metrics on denormalized data.
        time_mean_norm: Time-mean metrics on normalized data.
        video: Video (animated map) metrics. Disabled by default.
        histogram: Distribution histogram metrics. Disabled by default.
        seasonal: Seasonal-mean metrics. Disabled by default.
        annual: Annual-mean metrics.
        enso_index: ENSO index metrics.
        enso_coefficient: ENSO regression coefficient metrics.
        ipo_index: Interdecadal Pacific Oscillation index metrics.
        monthly_reference_data: Path to monthly reference data to compare
            against.
        time_mean_reference_data: Path to reference time means to compare
            against.
    """

    mean_denorm: MeanMetricConfig = dataclasses.field(
        default_factory=lambda: MeanMetricConfig(target="denorm")
    )
    mean_norm: MeanMetricConfig = dataclasses.field(
        default_factory=lambda: MeanMetricConfig(target="norm")
    )
    step_means: list[StepMeanMetricConfig] = dataclasses.field(
        default_factory=lambda: [
            StepMeanMetricConfig(step=20, target="denorm"),
            StepMeanMetricConfig(step=20, target="norm"),
        ]
    )
    ensembles: list[EnsembleMetricConfig] = dataclasses.field(
        default_factory=lambda: [EnsembleMetricConfig(step=20)]
    )
    power_spectrum: PowerSpectrumMetricConfig = dataclasses.field(
        default_factory=PowerSpectrumMetricConfig
    )
    zonal_mean: ZonalMeanMetricConfig = dataclasses.field(
        default_factory=ZonalMeanMetricConfig
    )
    time_mean_denorm: TimeMeanMetricConfig = dataclasses.field(
        default_factory=lambda: TimeMeanMetricConfig(target="denorm")
    )
    time_mean_norm: TimeMeanMetricConfig = dataclasses.field(
        default_factory=lambda: TimeMeanMetricConfig(target="norm")
    )
    video: VideoMetricConfig = dataclasses.field(default_factory=VideoMetricConfig)
    histogram: HistogramMetricConfig = dataclasses.field(
        default_factory=HistogramMetricConfig
    )
    seasonal: SeasonalMetricConfig = dataclasses.field(
        default_factory=SeasonalMetricConfig
    )
    annual: AnnualMetricConfig = dataclasses.field(default_factory=AnnualMetricConfig)
    enso_index: EnsoIndexMetricConfig = dataclasses.field(
        default_factory=EnsoIndexMetricConfig
    )
    enso_coefficient: EnsoCoefficientMetricConfig = dataclasses.field(
        default_factory=EnsoCoefficientMetricConfig
    )
    ipo_index: IpoIndexMetricConfig = dataclasses.field(
        default_factory=IpoIndexMetricConfig
    )
    monthly_reference_data: str | None = None
    time_mean_reference_data: str | None = None

    def __post_init__(self):
        # Each field pinned to a target must keep that target; checked in
        # declaration order so the first mismatch is the one reported.
        expected_targets = (
            ("mean_denorm", "denorm"),
            ("mean_norm", "norm"),
            ("time_mean_denorm", "denorm"),
            ("time_mean_norm", "norm"),
        )
        for field_name, expected in expected_targets:
            actual = getattr(self, field_name).target
            if actual != expected:
                raise ValueError(
                    f"{field_name}.target must be '{expected}', got '{actual}'"
                )

    def _get_metrics(self) -> list[MetricConfig]:
        """Collect every configured metric, in a fixed order, keeping the enabled ones."""
        candidates: list[MetricConfig] = [self.mean_denorm, self.mean_norm]
        candidates += self.step_means
        candidates += self.ensembles
        candidates += [
            self.power_spectrum,
            self.zonal_mean,
            self.time_mean_denorm,
            self.time_mean_norm,
            self.video,
            self.histogram,
            self.seasonal,
            self.annual,
            self.enso_index,
            self.enso_coefficient,
            self.ipo_index,
        ]
        return [metric for metric in candidates if metric.enabled]

    def build(
        self,
        dataset_info: DatasetInfo,
        n_ic_steps: int,
        n_forward_steps: int,
        initial_time: xr.DataArray,
        normalize: Callable[[TensorMapping], TensorDict],
        output_dir: str | None = None,
        channel_mean_names: Sequence[str] | None = None,
        save_diagnostics: bool = True,
        n_ensemble_per_ic: int = 1,
        enable_time_series: bool = True,
    ) -> "InferenceEvaluatorAggregator":
        """
        Build the aggregator from the enabled metrics.

        Unsupported metrics are skipped with a warning unless they are
        ``strict`` (``raise_on_unsupported=False`` is passed to the builder).
        """
        enabled_metrics = self._get_metrics()
        return build_inference_evaluator_aggregator(
            metrics=enabled_metrics,
            dataset_info=dataset_info,
            n_ic_steps=n_ic_steps,
            n_forward_steps=n_forward_steps,
            initial_time=initial_time,
            normalize=normalize,
            monthly_reference_data=self.monthly_reference_data,
            time_mean_reference_data=self.time_mean_reference_data,
            output_dir=output_dir,
            channel_mean_names=channel_mean_names,
            save_diagnostics=save_diagnostics,
            n_ensemble_per_ic=n_ensemble_per_ic,
            enable_time_series=enable_time_series,
            raise_on_unsupported=False,
        )
[docs]@dataclasses.dataclass class StepMeanEntry: """ Configuration for logging mean metrics at a particular step. Attributes: step: Number of forward steps after which to log mean metrics. For example, step=20 will log mean metrics at the 20th forward step (i.e. time index n_ic_steps + 19). name: Name to use for the logged metrics. If None, will use "mean_step_{step}". """ step: int name: str | None = None def get_name(self): return self.name or f"mean_step_{self.step}" def validate(self, n_forward_steps: int): if self.step > n_forward_steps: raise ValueError( f"Step {self.step} is " f"greater than n_forward_steps {n_forward_steps}. " "Please ensure that all steps in log_step_means are less than or " "equal to " "n_forward_steps. If your run is less than 20 steps, you must pass " "a custom log_step_means configuration to override the default " "(e.g. log_step_means: [])." )
@dataclasses.dataclass
class LegacyFlagInferenceEvaluatorAggregatorConfig:
    """
    Legacy configuration for inference evaluator aggregator using boolean flags.

    Deprecated: Use InferenceEvaluatorAggregatorConfig instead.

    Parameters:
        log_histograms: Add histogram metrics.
        log_video: Add video metrics.
        log_extended_video: Forwarded as ``enable_extended_videos`` to the
            video metric (only used when ``log_video`` is set).
        log_zonal_mean_images: If truthy, add zonal-mean image metrics with
            this value forwarded as ``zonal_mean_max_size``.
        log_seasonal_means: Add seasonal-mean metrics.
        log_global_mean_time_series: Add global-mean time series on
            denormalized data.
        log_global_mean_norm_time_series: Add global-mean time series on
            normalized data.
        monthly_reference_data: Path to monthly reference data.
        time_mean_reference_data: Path to reference time means.
        log_nino34_index: On lat/lon grids, add the ENSO (Nino3.4) index, and
            the ENSO coefficient for runs longer than
            ``SLIGHTLY_LESS_THAN_FIVE_YEARS``.
        log_ipo_index: On lat/lon coordinates, add the IPO index for runs
            longer than ``APPROXIMATELY_EIGHTY_YEARS``.
        log_step_means: For each entry, add denormalized and normalized
            step-mean metrics (plus an ensemble metric when
            ``n_ensemble_per_ic > 1``).
    """

    log_histograms: bool = False
    log_video: bool = False
    log_extended_video: bool = False
    log_zonal_mean_images: bool | int = 4096
    log_seasonal_means: bool = False
    log_global_mean_time_series: bool = True
    log_global_mean_norm_time_series: bool = True
    monthly_reference_data: str | None = None
    time_mean_reference_data: str | None = None
    log_nino34_index: bool = True
    log_ipo_index: bool = True
    log_step_means: list[StepMeanEntry] = dataclasses.field(
        default_factory=lambda: [StepMeanEntry(step=20)]
    )

    def __post_init__(self):
        # stacklevel=3: frame 1 is this method, frame 2 is the
        # dataclass-generated __init__, frame 3 is the code that constructed
        # the config, which is where the deprecation should be attributed.
        warnings.warn(
            "LegacyFlagInferenceEvaluatorAggregatorConfig is deprecated. "
            "Use InferenceEvaluatorAggregatorConfig instead.",
            DeprecationWarning,
            stacklevel=3,
        )

    def _get_metrics(
        self,
        n_timesteps: int,
        timestep: datetime.timedelta,
        horizontal_coordinates: HorizontalCoordinates,
        ops: GriddedOperations,
        n_ensemble_per_ic: int = 1,
    ) -> list[MetricConfig]:
        """Translate the boolean flags into a list of metric configurations."""
        metrics: list[MetricConfig] = []
        if self.log_global_mean_time_series:
            metrics.append(MeanMetricConfig(target="denorm"))
        if self.log_global_mean_norm_time_series:
            metrics.append(MeanMetricConfig(target="norm"))
        for entry in self.log_step_means:
            name = entry.get_name()
            metrics.append(
                StepMeanMetricConfig(step=entry.step, name=name, target="denorm")
            )
            metrics.append(
                StepMeanMetricConfig(
                    step=entry.step, name=name + "_norm", target="norm"
                )
            )
            if n_ensemble_per_ic > 1:
                metrics.append(EnsembleMetricConfig(step=entry.step))
        metrics.append(PowerSpectrumMetricConfig())
        if self.log_zonal_mean_images:
            # NOTE(review): a bare ``True`` is forwarded unchanged as the max
            # size; confirm how ZonalMeanMetricConfig interprets it.
            metrics.append(
                ZonalMeanMetricConfig(zonal_mean_max_size=self.log_zonal_mean_images)
            )
        if self.log_video:
            metrics.append(
                VideoMetricConfig(enable_extended_videos=self.log_extended_video)
            )
        metrics.append(TimeMeanMetricConfig(target="denorm"))
        metrics.append(TimeMeanMetricConfig(target="norm"))
        if self.log_histograms:
            metrics.append(HistogramMetricConfig())
        if self.log_seasonal_means:
            metrics.append(SeasonalMetricConfig())
        if n_timesteps * timestep > APPROXIMATELY_TWO_YEARS:
            metrics.append(AnnualMetricConfig())
        if (
            self.log_nino34_index
            and isinstance(horizontal_coordinates, LatLonCoordinates)
            and isinstance(ops, LatLonOperations)
        ):
            metrics.append(EnsoIndexMetricConfig())
            # The ENSO coefficient shares the Nino3.4 / lat-lon requirement
            # and additionally needs a multi-year run.
            if n_timesteps * timestep > SLIGHTLY_LESS_THAN_FIVE_YEARS:
                metrics.append(EnsoCoefficientMetricConfig())
        if (
            self.log_ipo_index
            and n_timesteps * timestep > APPROXIMATELY_EIGHTY_YEARS
            and isinstance(horizontal_coordinates, LatLonCoordinates)
        ):
            metrics.append(IpoIndexMetricConfig())
        return metrics

    def build(
        self,
        dataset_info: DatasetInfo,
        n_ic_steps: int,
        n_forward_steps: int,
        initial_time: xr.DataArray,
        normalize: Callable[[TensorMapping], TensorDict],
        output_dir: str | None = None,
        channel_mean_names: Sequence[str] | None = None,
        save_diagnostics: bool = True,
        n_ensemble_per_ic: int = 1,
        enable_time_series: bool = True,
    ) -> "InferenceEvaluatorAggregator":
        """
        Build the aggregator from the flag-derived metrics.

        ``raise_on_unsupported`` keeps the builder default (True), because
        ``_get_metrics`` already pre-filters by grid type and run length.
        """
        n_timesteps = n_ic_steps + n_forward_steps
        metrics = self._get_metrics(
            n_timesteps=n_timesteps,
            timestep=dataset_info.timestep,
            horizontal_coordinates=dataset_info.horizontal_coordinates,
            ops=dataset_info.gridded_operations,
            n_ensemble_per_ic=n_ensemble_per_ic,
        )
        return build_inference_evaluator_aggregator(
            metrics=metrics,
            dataset_info=dataset_info,
            n_ic_steps=n_ic_steps,
            n_forward_steps=n_forward_steps,
            initial_time=initial_time,
            normalize=normalize,
            monthly_reference_data=self.monthly_reference_data,
            time_mean_reference_data=self.time_mean_reference_data,
            output_dir=output_dir,
            channel_mean_names=channel_mean_names,
            save_diagnostics=save_diagnostics,
            n_ensemble_per_ic=n_ensemble_per_ic,
            enable_time_series=enable_time_series,
        )
class InferenceEvaluatorAggregator(
    InferenceAggregatorABC[PairedData | PrognosticState, PairedData]
):
    """
    Aggregates statistics for inference comparing a generated and target series.

    To use, call `record_initial_condition` once, then `record_batch` on the
    results of each batch. When you're done, call `get_summary_logs` (or
    `get_logs`) to get a dictionary of statistics.
    """

    def __init__(
        self,
        aggregators: dict[str, SubAggregator],
        time_series_aggregators: dict[str, TimeSeriesLogs],
        coords: Mapping[str, np.ndarray],
        n_ic_steps: int,
        normalize: Callable[[TensorMapping], TensorDict],
        save_diagnostics: bool = True,
        output_dir: str | None = None,
        n_ensemble_per_ic: int = 1,
        ensemble_aggregators: dict[str, SelectStepEnsembleAggregator] | None = None,
    ):
        """
        Args:
            aggregators: All sub-aggregators, keyed by metric name.
            time_series_aggregators: Subset of metrics that also provide
                per-step time-series logs. These names are excluded from
                summary logs.
            coords: Coordinates passed to reduced-diagnostics output.
            n_ic_steps: Number of initial-condition timesteps expected by
                `record_initial_condition`.
            normalize: Function used to normalize prediction and target data.
            save_diagnostics: Whether `flush_diagnostics` writes to disk.
            output_dir: Diagnostics output directory. Required when
                ``save_diagnostics`` is True.
            n_ensemble_per_ic: Ensemble members per initial condition. The
                ensemble aggregators are only used when this is > 1.
            ensemble_aggregators: Ensemble aggregators keyed by metric name.

        Raises:
            ValueError: If ``save_diagnostics`` is set without ``output_dir``.
        """
        if save_diagnostics and output_dir is None:
            raise ValueError("Output directory must be set to save diagnostics")
        self._aggregators = aggregators
        self._time_series_aggregators = time_series_aggregators
        self.n_ensemble_per_ic = n_ensemble_per_ic
        self._ensemble_aggregators = ensemble_aggregators or {}
        summary_aggregators: dict[str, SubAggregator | SelectStepEnsembleAggregator] = {
            name: agg
            for name, agg in aggregators.items()
            if name not in time_series_aggregators
        }
        if n_ensemble_per_ic > 1:
            summary_aggregators.update(self._ensemble_aggregators)
        self._summary_aggregators = summary_aggregators
        self._coords = coords
        self.n_ic_steps = n_ic_steps
        self._normalize = normalize
        self._save_diagnostics = save_diagnostics
        self._output_dir = output_dir
        self._log_time_series = len(time_series_aggregators) > 0
        self._n_timesteps_seen = 0

    @property
    def log_time_series(self) -> bool:
        """Whether any time-series aggregators are configured."""
        return self._log_time_series

    @torch.no_grad()
    def record_batch(
        self,
        data: PairedData,
    ) -> InferenceLogs:
        """
        Record a batch of paired prediction/target data.

        Every sub-aggregator receives the batch. When ``n_ensemble_per_ic > 1``,
        the ensemble aggregators also receive ensemble-unfolded copies.

        Returns:
            Time-series logs for the timesteps covered by this batch.

        Raises:
            ValueError: If the prediction or target mapping is empty.
        """
        if len(data.prediction) == 0:
            raise ValueError("No prediction values in data")
        if len(data.target) == 0:
            raise ValueError("No target values in data")
        target_data = data.target
        batch = InferenceBatchData(
            prediction=data.prediction,
            prediction_norm=self._normalize(data.prediction),
            target=target_data,
            target_norm=self._normalize(target_data),
            time=data.time,
            i_time_start=self._n_timesteps_seen,
        )
        for aggregator in self._aggregators.values():
            aggregator.record_batch(batch)
        if self.n_ensemble_per_ic > 1:
            unfolded_target_data, unfolded_prediction_data = (
                data.as_ensemble_tensor_dicts(data.n_ensemble)
            )
            unfolded_target_data_norm = unfold_ensemble_dim(
                TensorDict(batch.target_norm), data.n_ensemble
            )
            unfolded_prediction_data_norm = unfold_ensemble_dim(
                TensorDict(batch.prediction_norm), data.n_ensemble
            )
            for ensemble_aggregator in self._ensemble_aggregators.values():
                ensemble_aggregator.record_batch(
                    target_data=unfolded_target_data,
                    gen_data=unfolded_prediction_data,
                    target_data_norm=unfolded_target_data_norm,
                    gen_data_norm=unfolded_prediction_data_norm,
                    i_time_start=self._n_timesteps_seen,
                )
        # Assumes time is indexed as (batch, time, ...), as implied by shape[1].
        n_times = data.time.shape[1]
        logs = self._get_inference_logs_slice(
            step_slice=slice(self._n_timesteps_seen, self._n_timesteps_seen + n_times),
        )
        self._n_timesteps_seen += n_times
        return logs

    def record_initial_condition(
        self,
        initial_condition: PairedData | PrognosticState,
    ) -> InferenceLogs:
        """
        Record the initial condition into the time-series aggregators only.

        A ``PrognosticState`` has no separate prediction, so its data is used
        as both target and prediction.

        Returns:
            Time-series logs for the initial-condition timesteps.

        Raises:
            RuntimeError: If called after any batch has been recorded.
            ValueError: If the number of timesteps is not ``n_ic_steps``.
        """
        if self._n_timesteps_seen != 0:
            raise RuntimeError(
                "record_initial_condition may only be called once, "
                "before recording any batches"
            )
        if isinstance(initial_condition, PairedData):
            target_data = initial_condition.target
            gen_data = initial_condition.prediction
            time = initial_condition.time
        else:
            batch_data = initial_condition.as_batch_data()
            target_data = batch_data.data
            gen_data = target_data
            time = batch_data.time
        n_times = time.shape[1]
        if n_times != self.n_ic_steps:
            raise ValueError(
                f"Expected {self.n_ic_steps} initial condition steps, but got {n_times}"
            )
        batch = InferenceBatchData(
            prediction=gen_data,
            prediction_norm=self._normalize(gen_data),
            target=target_data,
            target_norm=self._normalize(target_data),
            time=time,
            i_time_start=0,
        )
        for name in self._time_series_aggregators:
            self._aggregators[name].record_batch(batch)
        logs = self._get_inference_logs_slice(
            step_slice=slice(self._n_timesteps_seen, self._n_timesteps_seen + n_times),
        )
        self._n_timesteps_seen = n_times
        return logs

    def get_summary_logs(self) -> InferenceLog:
        """Collect logs from all non-time-series (and, if enabled, ensemble) aggregators."""
        logs: InferenceLog = {}
        for name, aggregator in self._summary_aggregators.items():
            logging.info(f"Getting summary logs for {name} aggregator")
            logs.update(aggregator.get_logs(label=name))
        return logs

    @torch.no_grad()
    def _get_logs(self):
        """Returns logs as can be reported to WandB."""
        logs: InferenceLog = {}
        for name, aggregator in self._aggregators.items():
            logs.update(aggregator.get_logs(label=name))
        if self.n_ensemble_per_ic > 1:
            for name, ensemble_aggregator in self._ensemble_aggregators.items():
                logs.update(ensemble_aggregator.get_logs(label=name))
        return logs

    @torch.no_grad()
    def _get_inference_logs_slice(self, step_slice: slice):
        """
        Return time-series logs for a slice of timesteps, as reported to WandB.

        Only the time-series aggregators contribute. Their outputs are
        converted by `to_inference_logs` into one dictionary per table row.

        Args:
            step_slice: Timestep slice selecting the subset of the time series.

        Returns:
            List of per-row log dictionaries.
        """
        logs = {}
        for name, aggregator in self._time_series_aggregators.items():
            logs.update(aggregator.get_logs(label=name, step_slice=step_slice))
        return to_inference_logs(logs)

    @torch.no_grad()
    def flush_diagnostics(self, subdir: str | None = None):
        """
        Write reduced diagnostics from all sub-aggregators, if enabled.

        Raises:
            ValueError: If saving is enabled but no output directory is set.
        """
        if self._save_diagnostics:
            reduced_diagnostics = get_reduced_diagnostics(
                sub_aggregators=self._aggregators,
                coords=self._coords,
            )
            if self._output_dir is not None:
                write_reduced_diagnostics(
                    reduced_diagnostics=reduced_diagnostics,
                    output_dir=self._output_dir,
                    subdir=subdir,
                )
            else:
                raise ValueError("Output directory not set.")


def to_inference_logs(
    log: Mapping[str, Table | float | int],
) -> list[dict[str, float | int]]:
    """
    Unpack a log mapping that may contain WandB tables into per-row dicts.

    Each table row ``i`` is written into ``logs[i]`` under
    ``"<key without its last '/' segment>/<column>"``. Non-table (scalar)
    values are placed in the last dictionary. At least one dictionary is
    always returned.

    Args:
        log: Mapping of log keys to WandB tables or scalars.

    Returns:
        One dictionary per row of the longest table (minimum one).
    """
    n_rows = max(
        (len(val.data) for val in log.values() if isinstance(val, Table)),
        default=0,
    )
    logs: list[dict[str, float | int]] = [{} for _ in range(max(1, n_rows))]
    for key, val in log.items():
        if isinstance(val, Table):
            # Strip the trailing table name from the key. rsplit keeps the
            # whole key when it contains no "/"; slicing at rfind("/") would
            # slice at -1 there and silently drop the key's last character.
            prefix = key.rsplit("/", 1)[0]
            for i, row in enumerate(val.data):
                for j, col in enumerate(val.columns):
                    logs[i][f"{prefix}/{col}"] = row[j]
        else:
            logs[-1][key] = val
    return logs


def table_to_logs(table: Table) -> list[dict[str, float | int]]:
    """Converts a WandB table into a list of dictionaries, one per row."""
    return [
        {table.columns[i]: cell for i, cell in enumerate(row)} for row in table.data
    ]
@dataclasses.dataclass
class InferenceAggregatorConfig:
    """
    Configuration for inference aggregator.

    Parameters:
        time_mean_reference_data: Path to reference time means to compare against.
        log_global_mean_time_series: Whether to log global mean time series metrics.
    """

    time_mean_reference_data: str | None = None
    log_global_mean_time_series: bool = True

    def build(
        self,
        dataset_info: DatasetInfo,
        n_timesteps: int,
        output_dir: str | None = None,
        save_diagnostics: bool = True,
    ) -> "InferenceAggregator":
        """
        Assemble an ``InferenceAggregator`` for a single (unpaired) series.

        Aggregators are registered in this order:

        - ``mean``: only if ``log_global_mean_time_series`` is set; also used
          as the time-series source.
        - ``time_mean`` and ``annual``: always.
        - ``power_spectrum``: when the grid supports it.
        - ``enso_index``: on lat/lon grids for runs longer than
          ``APPROXIMATELY_TWO_YEARS``.
        """
        time_means = (
            None
            if self.time_mean_reference_data is None
            else xr.open_dataset(
                self.time_mean_reference_data,
                decode_timedelta=False,
            )
        )
        coordinates = dataset_info.horizontal_coordinates
        ops = dataset_info.gridded_operations

        aggregators: dict[str, SubAggregator] = {}
        time_series_aggregators: dict[str, TimeSeriesLogs] = {}

        if self.log_global_mean_time_series:
            global_mean = SingleTargetMeanAggregator(ops, n_timesteps=n_timesteps)
            aggregators["mean"] = global_mean
            time_series_aggregators["mean"] = global_mean

        aggregators["time_mean"] = TimeMeanAggregator(
            gridded_operations=ops,
            variable_metadata=dataset_info.variable_metadata,
            reference_means=time_means,
        )
        aggregators["annual"] = GlobalMeanAnnualAggregator(
            ops,
            dataset_info.timestep,
            dataset_info.variable_metadata,
        )

        # Not every grid type implements a spherical power spectrum.
        try:
            aggregators["power_spectrum"] = SphericalPowerSpectrumAggregator(
                gridded_operations=ops,
                nan_fill_fn=SmoothFloodFill(num_steps=4),
                report_plot=True,
                variable_metadata=dataset_info.variable_metadata,
            )
        except NotImplementedError:
            logging.warning(
                "Power spectrum aggregator not implemented for this grid type, "
                "omitting."
            )

        # Short-circuit order matters: the run length is only computed once
        # the grid is known to be lat/lon.
        wants_enso = (
            isinstance(coordinates, LatLonCoordinates)
            and isinstance(ops, LatLonOperations)
            and n_timesteps * dataset_info.timestep > APPROXIMATELY_TWO_YEARS
        )
        if wants_enso:
            nino34 = LatLonRegion(
                lat_bounds=NINO34_LAT,
                lon_bounds=NINO34_LON,
                lat=coordinates.lat,
                lon=coordinates.lon,
            )
            aggregators["enso_index"] = RegionalIndexAggregator(
                regional_weights=nino34.regional_weights,
                regional_mean=ops.regional_area_weighted_mean,
            )

        return InferenceAggregator(
            aggregators=aggregators,
            time_series_aggregators=time_series_aggregators,
            coords=coordinates.coords,
            save_diagnostics=save_diagnostics,
            output_dir=output_dir,
        )
class InferenceAggregator(
    InferenceAggregatorABC[
        PrognosticState,
        PairedData,
    ]
):
    """
    Aggregates statistics on a single timeseries of data.

    Usage: call `record_batch` for every batch of results. When finished, call
    `get_logs` to obtain a dictionary of statistics.
    """

    def __init__(
        self,
        aggregators: dict[str, SubAggregator],
        time_series_aggregators: dict[str, TimeSeriesLogs],
        coords: Mapping[str, np.ndarray],
        save_diagnostics: bool = True,
        output_dir: str | None = None,
    ):
        """
        Args:
            aggregators: All sub-aggregators, keyed by metric name.
            time_series_aggregators: Subset of metrics that also provide
                per-step time-series logs. These names are excluded from
                summary logs.
            coords: Coordinates passed to reduced-diagnostics output.
            save_diagnostics: Whether `flush_diagnostics` writes to disk.
            output_dir: Diagnostics output directory. Required when
                ``save_diagnostics`` is True.

        Raises:
            ValueError: If ``save_diagnostics`` is set without ``output_dir``.
        """
        if save_diagnostics and output_dir is None:
            raise ValueError("Output directory must be set to save diagnostics")
        self._aggregators = aggregators
        self._time_series_aggregators = time_series_aggregators
        self._summary_aggregators = {
            metric_name: agg
            for metric_name, agg in aggregators.items()
            if metric_name not in time_series_aggregators
        }
        self._coords = coords
        self._save_diagnostics = save_diagnostics
        self._output_dir = output_dir
        self._log_time_series = len(time_series_aggregators) > 0
        self._n_timesteps_seen = 0

    @property
    def log_time_series(self) -> bool:
        """True when at least one time-series aggregator is configured."""
        return self._log_time_series

    @torch.no_grad()
    def record_batch(self, data: PairedData) -> InferenceLogs:
        """
        Record a batch of data.

        Args:
            data: Batch of data to record; only its prediction and time are used.

        Returns:
            Time-series logs for the timesteps covered by this batch.

        Raises:
            ValueError: If the prediction mapping is empty.
        """
        if len(data.prediction) == 0:
            raise ValueError("data is empty")
        start = self._n_timesteps_seen
        batch = InferenceBatchData(
            prediction=data.prediction,
            time=data.time,
            i_time_start=start,
        )
        for aggregator in self._aggregators.values():
            aggregator.record_batch(batch)
        stop = start + data.time.shape[1]
        logs = self._get_inference_logs_slice(step_slice=slice(start, stop))
        self._n_timesteps_seen = stop
        return logs

    def record_initial_condition(
        self,
        initial_condition: PrognosticState,
    ) -> InferenceLogs:
        """
        Record the initial condition into the time-series aggregators only.

        Returns:
            Time-series logs for the initial-condition timesteps.

        Raises:
            RuntimeError: If called after any batch has been recorded.
        """
        if self._n_timesteps_seen != 0:
            raise RuntimeError(
                "record_initial_condition may only be called once, "
                "before recording any batches"
            )
        ic_data = initial_condition.as_batch_data()
        batch = InferenceBatchData(
            prediction=ic_data.data,
            time=ic_data.time,
            i_time_start=0,
        )
        for metric_name in self._time_series_aggregators:
            self._aggregators[metric_name].record_batch(batch)
        n_ic_times = ic_data.time.shape[1]
        # Nothing has been seen yet (checked above), so the slice starts at 0.
        logs = self._get_inference_logs_slice(step_slice=slice(0, n_ic_times))
        self._n_timesteps_seen = n_ic_times
        return logs

    def get_summary_logs(self) -> InferenceLog:
        """Collect logs from every aggregator that is not a time series."""
        summary = {}
        for metric_name, aggregator in self._summary_aggregators.items():
            logging.info(f"Getting summary logs for {metric_name} aggregator")
            summary.update(aggregator.get_logs(label=metric_name))
        return summary

    @torch.no_grad()
    def _get_logs(self):
        """Returns logs from every sub-aggregator, as can be reported to WandB."""
        merged = {}
        for metric_name, aggregator in self._aggregators.items():
            merged.update(aggregator.get_logs(label=metric_name))
        return merged

    @torch.no_grad()
    def _get_inference_logs(self) -> list[dict[str, float | int]]:
        """
        Return all logs as a list of per-row dictionaries for WandB.

        Inference uses the WandB step as the timestep, so tabular logs are
        unpacked into one dictionary per row.
        """
        return to_inference_logs(self._get_logs())

    @torch.no_grad()
    def _get_inference_logs_slice(self, step_slice: slice):
        """
        Return time-series logs for a slice of timesteps, as reported to WandB.

        Args:
            step_slice: Timestep slice selecting the subset of the time series.

        Returns:
            List of per-row log dictionaries.
        """
        sliced = {}
        for metric_name, aggregator in self._time_series_aggregators.items():
            sliced.update(aggregator.get_logs(label=metric_name, step_slice=step_slice))
        return to_inference_logs(sliced)

    @torch.no_grad()
    def flush_diagnostics(self, subdir: str | None = None):
        """
        Write reduced diagnostics from all sub-aggregators, if enabled.

        Raises:
            ValueError: If saving is enabled but no output directory is set.
        """
        if not self._save_diagnostics:
            return
        reduced_diagnostics = get_reduced_diagnostics(
            sub_aggregators=self._aggregators,
            coords=self._coords,
        )
        if self._output_dir is None:
            raise ValueError("Output directory is not set.")
        write_reduced_diagnostics(
            reduced_diagnostics=reduced_diagnostics,
            output_dir=self._output_dir,
            subdir=subdir,
        )