# Source code for fme.ace.aggregator.inference.main

import dataclasses
import datetime
import logging
from collections.abc import Callable, Mapping, Sequence
from typing import Protocol

import torch
import xarray as xr

from fme.ace.data_loading.batch_data import PairedData, PrognosticState
from fme.core.coordinates import LatLonCoordinates
from fme.core.dataset_info import DatasetInfo
from fme.core.diagnostics import get_reduced_diagnostics, write_reduced_diagnostics
from fme.core.generics.aggregator import (
    InferenceAggregatorABC,
    InferenceLog,
    InferenceLogs,
)
from fme.core.gridded_ops import LatLonOperations
from fme.core.typing_ import TensorDict, TensorMapping
from fme.core.wandb import Table, WandB

from ..one_step.reduced import MeanAggregator as OneStepMeanAggregator
from .annual import GlobalMeanAnnualAggregator, PairedGlobalMeanAnnualAggregator
from .enso import (
    EnsoCoefficientEvaluatorAggregator,
    LatLonRegion,
    PairedRegionalIndexAggregator,
    RegionalIndexAggregator,
)
from .histogram import HistogramAggregator
from .reduced import MeanAggregator, SingleTargetMeanAggregator
from .seasonal import SeasonalAggregator
from .spectrum import PairedSphericalPowerSpectrumAggregator
from .time_mean import TimeMeanAggregator, TimeMeanEvaluatorAggregator
from .video import VideoAggregator
from .zonal_mean import ZonalMeanAggregator

# Shared WandB wrapper instance returned by WandB.get_instance().
wandb = WandB.get_instance()
# Run-length threshold: n_timesteps * timestep must exceed this before the
# annual aggregator (evaluator) or the ENSO index (InferenceAggregator) is built.
APPROXIMATELY_TWO_YEARS = datetime.timedelta(days=730)
# Run-length threshold that must be exceeded before the evaluator builds the
# ENSO coefficient aggregator.
SLIGHTLY_LESS_THAN_FIVE_YEARS = datetime.timedelta(days=1800)
# Bounds handed to LatLonRegion as lat_bounds / lon_bounds for the Nino3.4 index
# (presumably degrees north and degrees east -- confirm against LatLonRegion).
NINO34_LAT = (-5, 5)
NINO34_LON = (190, 240)


class _Aggregator(Protocol):
    """
    Structural type for sub-aggregators that consume a single data stream.

    Used as the value type of the aggregator dictionary built in
    ``InferenceAggregator.__init__``.
    """

    # Record one batch of data.
    @torch.no_grad()
    def record_batch(
        self,
        data: TensorMapping,
    ): ...

    # Return the logs for this aggregator. Callers in this module pass the
    # aggregator's dictionary key as ``label`` and merge the result into a
    # dict with ``dict.update``.
    @torch.no_grad()
    def get_logs(self, label: str): ...

    # Return the aggregated results as an xarray dataset.
    @torch.no_grad()
    def get_dataset(self) -> xr.Dataset: ...


class _EvaluatorAggregator(Protocol):
    """
    Structural type for sub-aggregators that compare generated and target data.

    Used as the value type of ``InferenceEvaluatorAggregator._aggregators``.
    """

    # Record one batch. The evaluator passes both the raw and the normalized
    # target/generated data, plus ``i_time_start``, which it sets to the number
    # of timesteps recorded before this batch.
    @torch.no_grad()
    def record_batch(
        self,
        target_data: TensorMapping,
        gen_data: TensorMapping,
        target_data_norm: TensorMapping,
        gen_data_norm: TensorMapping,
        i_time_start: int = 0,
    ): ...

    # Return the logs for this aggregator. Callers in this module pass the
    # aggregator's dictionary key as ``label`` and merge the result into a
    # dict with ``dict.update``.
    @torch.no_grad()
    def get_logs(self, label: str): ...

    # Return the aggregated results as an xarray dataset.
    @torch.no_grad()
    def get_dataset(self) -> xr.Dataset: ...


class _TimeDependentAggregator(Protocol):
    """
    Structural type for single-stream sub-aggregators that also need the time axis.

    The ``record_batch`` signature matches the call made for the "annual" and
    "enso_index" entries in ``InferenceAggregator.record_batch``.
    """

    # Record one batch of data together with its time coordinate.
    @torch.no_grad()
    def record_batch(
        self,
        time: xr.DataArray,
        data: TensorMapping,
    ): ...

    # Return the logs for this aggregator; ``label`` is the aggregator's name.
    @torch.no_grad()
    def get_logs(self, label: str): ...

    # Return the aggregated results as an xarray dataset.
    @torch.no_grad()
    def get_dataset(self) -> xr.Dataset: ...


class _TimeDependentEvaluatorAggregator(Protocol):
    """
    Structural type for paired (target vs. generated) sub-aggregators that
    also need the time axis.

    Used as the value type of
    ``InferenceEvaluatorAggregator._time_dependent_aggregators``.
    """

    # Record one batch of target and generated data together with its time
    # coordinate. Unlike _EvaluatorAggregator, no normalized data is passed.
    @torch.no_grad()
    def record_batch(
        self,
        time: xr.DataArray,
        target_data: TensorMapping,
        gen_data: TensorMapping,
    ): ...

    # Return the logs for this aggregator; ``label`` is the aggregator's name.
    @torch.no_grad()
    def get_logs(self, label: str): ...

    # Return the aggregated results as an xarray dataset.
    @torch.no_grad()
    def get_dataset(self) -> xr.Dataset: ...


@dataclasses.dataclass
class InferenceEvaluatorAggregatorConfig:
    """
    Options selecting what an inference evaluator aggregator records and logs.

    Parameters:
        log_histograms: Whether to log histograms of the targets and
            predictions.
        log_video: Whether to log videos of the state evolution.
        log_extended_video: Whether to log wandb videos of the predictions
            with statistical metrics, only done if log_video is True.
        log_zonal_mean_images: Whether to log zonal-mean images (hovmollers)
            with a time dimension. If greater than 0 zonal-mean images will
            be logged. Defaults to 4096 (2**12) and can be set to a maximum
            of 32768 (2**15) (limited by matplotlib).
        log_seasonal_means: Whether to log seasonal mean metrics and images.
        log_global_mean_time_series: Whether to log global mean time series
            metrics.
        log_global_mean_norm_time_series: Whether to log the normalized
            global mean time series metrics.
        monthly_reference_data: Path to monthly reference data to compare
            against.
        time_mean_reference_data: Path to reference time means to compare
            against.
        log_nino34_index: Whether to log the Nino34 index.
    """

    log_histograms: bool = False
    log_video: bool = False
    log_extended_video: bool = False
    log_zonal_mean_images: bool | int = 4096
    log_seasonal_means: bool = False
    log_global_mean_time_series: bool = True
    log_global_mean_norm_time_series: bool = True
    monthly_reference_data: str | None = None
    time_mean_reference_data: str | None = None
    log_nino34_index: bool = True

    def build(
        self,
        dataset_info: DatasetInfo,
        n_timesteps: int,
        initial_time: xr.DataArray,
        normalize: Callable[[TensorMapping], TensorDict],
        output_dir: str | None = None,
        record_step_20: bool = False,
        channel_mean_names: Sequence[str] | None = None,
        save_diagnostics: bool = True,
    ) -> "InferenceEvaluatorAggregator":
        """
        Open the configured reference datasets and construct the aggregator.

        Raises:
            ValueError: If ``save_diagnostics`` is set without an
                ``output_dir``.
        """
        if output_dir is None and save_diagnostics:
            raise ValueError("Output directory must be set to save diagnostics.")

        # Reference data is optional; each path is only opened when configured.
        monthly_reference_data = (
            xr.open_dataset(self.monthly_reference_data, decode_timedelta=False)
            if self.monthly_reference_data is not None
            else None
        )
        time_mean = (
            xr.open_dataset(self.time_mean_reference_data, decode_timedelta=False)
            if self.time_mean_reference_data is not None
            else None
        )

        return InferenceEvaluatorAggregator(
            dataset_info=dataset_info,
            n_timesteps=n_timesteps,
            initial_time=initial_time,
            normalize=normalize,
            output_dir=output_dir,
            record_step_20=record_step_20,
            channel_mean_names=channel_mean_names,
            save_diagnostics=save_diagnostics,
            log_histograms=self.log_histograms,
            log_video=self.log_video,
            enable_extended_videos=self.log_extended_video,
            log_zonal_mean_images=self.log_zonal_mean_images,
            log_seasonal_means=self.log_seasonal_means,
            log_global_mean_time_series=self.log_global_mean_time_series,
            log_global_mean_norm_time_series=self.log_global_mean_norm_time_series,
            log_nino34_index=self.log_nino34_index,
            monthly_reference_data=monthly_reference_data,
            time_mean_reference_data=time_mean,
        )
class InferenceEvaluatorAggregator(
    InferenceAggregatorABC[PairedData | PrognosticState, PairedData]
):
    """
    Aggregates statistics for inference comparing a generated and target series.

    To use, call `record_batch` on the results of each batch, then call
    `get_summary_logs` to get a dictionary of statistics when you're done.
    """

    def __init__(
        self,
        dataset_info: DatasetInfo,
        n_timesteps: int,
        initial_time: xr.DataArray,
        normalize: Callable[[TensorMapping], TensorDict],
        log_zonal_mean_images: bool | int,
        output_dir: str | None = None,
        record_step_20: bool = False,
        log_video: bool = False,
        enable_extended_videos: bool = False,
        log_seasonal_means: bool = False,
        log_global_mean_time_series: bool = True,
        log_global_mean_norm_time_series: bool = True,
        monthly_reference_data: xr.Dataset | None = None,
        log_histograms: bool = False,
        time_mean_reference_data: xr.Dataset | None = None,
        channel_mean_names: Sequence[str] | None = None,
        log_nino34_index: bool = True,
        save_diagnostics: bool = True,
    ):
        """
        Args:
            dataset_info: Dataset coordinates and metadata.
            n_timesteps: Number of timesteps of inference that will be run.
            initial_time: Initial time for each sample.
            normalize: Normalization function to use.
            log_zonal_mean_images: Whether to log zonal-mean images
                (hovmollers) with a time dimension. The value is also passed
                to the zonal-mean aggregator as its maximum size.
            output_dir: Directory to save diagnostic output.
            record_step_20: Whether to record the mean of the 20th steps.
            log_video: Whether to log videos of the state evolution.
            enable_extended_videos: Whether to log videos of statistical
                metrics of state evolution
            log_seasonal_means: Whether to log seasonal means metrics and
                images.
            log_global_mean_time_series: Whether to log global mean time
                series metrics.
            log_global_mean_norm_time_series: Whether to log the normalized
                global mean time series metrics.
            monthly_reference_data: Reference monthly data for computing
                target stats.
            log_histograms: Whether to aggregate histograms.
            time_mean_reference_data: Reference time means for computing
                bias stats.
            channel_mean_names: Names over which to compute channel means. If
                not provided, all available variables will be used.
            log_nino34_index: Whether to log the Nino34 index.
            save_diagnostics: Whether to save reduced diagnostics to disk.

        Raises:
            ValueError: If ``save_diagnostics`` is set without an
                ``output_dir``.
        """
        if save_diagnostics and output_dir is None:
            raise ValueError("Output directory must be set to save diagnostics")
        self._channel_mean_names = channel_mean_names
        self._aggregators: dict[str, _EvaluatorAggregator] = {}
        self._time_dependent_aggregators: dict[
            str, _TimeDependentEvaluatorAggregator
        ] = {}
        self._save_diagnostics = save_diagnostics
        self._output_dir = output_dir
        timestep = dataset_info.timestep
        horizontal_coordinates = dataset_info.horizontal_coordinates
        self._coords = horizontal_coordinates.coords
        ops = dataset_info.gridded_operations
        self._log_time_series = (
            log_global_mean_time_series or log_global_mean_norm_time_series
        )
        if log_global_mean_time_series:
            self._aggregators["mean"] = MeanAggregator(
                ops,
                target="denorm",
                n_timesteps=n_timesteps,
                variable_metadata=dataset_info.variable_metadata,
            )
        if log_global_mean_norm_time_series:
            self._aggregators["mean_norm"] = MeanAggregator(
                ops,
                target="norm",
                n_timesteps=n_timesteps,
                variable_metadata=dataset_info.variable_metadata,
            )
        self._record_step_20 = record_step_20
        if record_step_20:
            self._aggregators["mean_step_20"] = OneStepMeanAggregator(
                ops,
                target_time=20,
                target="denorm",
            )
            self._aggregators["mean_step_20_norm"] = OneStepMeanAggregator(
                ops,
                target_time=20,
                target="norm",
                include_bias=False,
                include_grad_mag_percent_diff=False,
                channel_mean_names=self._channel_mean_names,
            )
        # The power spectrum is best-effort: grids without an implementation
        # raise NotImplementedError and are skipped with a warning.
        try:
            self._aggregators["power_spectrum"] = (
                PairedSphericalPowerSpectrumAggregator(
                    gridded_operations=ops,
                    report_plot=True,
                )
            )
        except NotImplementedError:
            logging.warning(
                "Power spectrum aggregator not implemented for this grid type, "
                "omitting."
            )
        if log_zonal_mean_images:
            if ops.zonal_mean is None:
                logging.warning(
                    "Zonal mean aggregator not implemented for this grid type, "
                    "omitting."
                )
            else:
                self._aggregators["zonal_mean"] = ZonalMeanAggregator(
                    zonal_mean=ops.zonal_mean,
                    n_timesteps=n_timesteps,
                    variable_metadata=dataset_info.variable_metadata,
                    zonal_mean_max_size=log_zonal_mean_images,
                )
        if isinstance(horizontal_coordinates, LatLonCoordinates):
            if log_video:
                self._aggregators["video"] = VideoAggregator(
                    n_timesteps=n_timesteps,
                    enable_extended_videos=enable_extended_videos,
                    variable_metadata=dataset_info.variable_metadata,
                )
        # NOTE(review): block nesting was reconstructed from flattened source.
        # The time-mean aggregators are taken to be unconditional because they
        # only use the generic `horizontal_coordinates.dims` -- confirm.
        self._aggregators["time_mean"] = TimeMeanEvaluatorAggregator(
            ops,
            horizontal_dims=horizontal_coordinates.dims,
            variable_metadata=dataset_info.variable_metadata,
            reference_means=time_mean_reference_data,
        )
        self._aggregators["time_mean_norm"] = TimeMeanEvaluatorAggregator(
            ops,
            horizontal_dims=horizontal_coordinates.dims,
            target="norm",
            variable_metadata=dataset_info.variable_metadata,
            channel_mean_names=self._channel_mean_names,
        )
        if log_histograms:
            self._aggregators["histogram"] = HistogramAggregator()
        if log_seasonal_means:
            self._time_dependent_aggregators["seasonal"] = SeasonalAggregator(
                ops=ops,
                variable_metadata=dataset_info.variable_metadata,
            )
        if n_timesteps * timestep > APPROXIMATELY_TWO_YEARS:
            self._time_dependent_aggregators["annual"] = (
                PairedGlobalMeanAnnualAggregator(
                    ops=ops,
                    timestep=timestep,
                    variable_metadata=dataset_info.variable_metadata,
                    monthly_reference_data=monthly_reference_data,
                )
            )
            # NOTE(review): block nesting was reconstructed from flattened
            # source. The Nino3.4 index is assumed to share the two-year
            # minimum that InferenceAggregator applies to its own index --
            # confirm against the original file.
            if (
                isinstance(horizontal_coordinates, LatLonCoordinates)
                and isinstance(ops, LatLonOperations)
                and log_nino34_index
            ):
                nino34_region = LatLonRegion(
                    lat_bounds=NINO34_LAT,
                    lon_bounds=NINO34_LON,
                    lat=horizontal_coordinates.lat,
                    lon=horizontal_coordinates.lon,
                )
                self._time_dependent_aggregators["enso_index"] = (
                    PairedRegionalIndexAggregator(
                        target_aggregator=RegionalIndexAggregator(
                            regional_weights=nino34_region.regional_weights,
                            regional_mean=ops.regional_area_weighted_mean,
                        ),
                        prediction_aggregator=RegionalIndexAggregator(
                            regional_weights=nino34_region.regional_weights,
                            regional_mean=ops.regional_area_weighted_mean,
                        ),
                    )
                )
        if n_timesteps * timestep > SLIGHTLY_LESS_THAN_FIVE_YEARS:
            self._time_dependent_aggregators["enso_coefficient"] = (
                EnsoCoefficientEvaluatorAggregator(
                    initial_time,
                    n_timesteps - 1,
                    timestep,
                    gridded_operations=ops,
                    variable_metadata=dataset_info.variable_metadata,
                )
            )
        # The per-step series aggregators are reported incrementally by
        # record_batch, so they are left out of the end-of-run summary.
        self._summary_aggregators = {
            name: agg
            for name, agg in list(self._aggregators.items())
            + list(self._time_dependent_aggregators.items())
            if name not in ["mean", "mean_norm"]
        }
        self._n_timesteps_seen = 0
        self._normalize = normalize

    @property
    def log_time_series(self) -> bool:
        """Whether either global-mean time series (raw or normalized) is logged."""
        return self._log_time_series

    @torch.no_grad()
    def record_batch(
        self,
        data: PairedData,
    ) -> InferenceLogs:
        """
        Record a batch of paired target/prediction data in every sub-aggregator.

        Args:
            data: Batch of paired data to record.

        Returns:
            The time-series logs for the timesteps covered by this batch.

        Raises:
            ValueError: If the batch has no prediction or no target values.
        """
        if len(data.prediction) == 0:
            raise ValueError("No prediction values in data")
        if len(data.target) == 0:
            raise ValueError("No target values in data")
        target_data = data.target
        target_data_norm = self._normalize(target_data)
        gen_data_norm = self._normalize(data.prediction)
        for aggregator in self._aggregators.values():
            aggregator.record_batch(
                target_data=target_data,
                gen_data=data.prediction,
                target_data_norm=target_data_norm,
                gen_data_norm=gen_data_norm,
                i_time_start=self._n_timesteps_seen,
            )
        for time_dependent_aggregator in self._time_dependent_aggregators.values():
            time_dependent_aggregator.record_batch(
                time=data.time,
                target_data=target_data,
                gen_data=data.prediction,
            )
        n_times = data.time.shape[1]
        logs = self._get_inference_logs_slice(
            step_slice=slice(self._n_timesteps_seen, self._n_timesteps_seen + n_times),
        )
        self._n_timesteps_seen += n_times
        return logs

    def record_initial_condition(
        self,
        initial_condition: PairedData | PrognosticState,
    ) -> InferenceLogs:
        """
        Record the initial condition in the global-mean series aggregators.

        When only a prognostic state is given, it is used as both target and
        prediction.

        Args:
            initial_condition: Initial condition, paired or as a prognostic
                state.

        Returns:
            The time-series logs for the initial-condition timesteps.

        Raises:
            RuntimeError: If any timesteps have already been recorded.
        """
        if self._n_timesteps_seen != 0:
            raise RuntimeError(
                "record_initial_condition may only be called once, "
                "before recording any batches"
            )
        if isinstance(initial_condition, PairedData):
            target_data = initial_condition.target
            target_data_norm = self._normalize(target_data)
            gen_data = initial_condition.prediction
            gen_data_norm = self._normalize(gen_data)
            n_times = initial_condition.time.shape[1]
        else:
            batch_data = initial_condition.as_batch_data()
            target_data = batch_data.data
            target_data_norm = self._normalize(target_data)
            gen_data = target_data
            gen_data_norm = target_data_norm
            n_times = batch_data.time.shape[1]
        # Only the series aggregators get the initial condition; either may be
        # absent depending on the logging flags.
        for aggregator_name in ["mean", "mean_norm"]:
            aggregator = self._aggregators.get(aggregator_name)
            if aggregator is not None:
                aggregator.record_batch(
                    target_data=target_data,
                    gen_data=gen_data,
                    target_data_norm=target_data_norm,
                    gen_data_norm=gen_data_norm,
                    i_time_start=0,
                )
        logs = self._get_inference_logs_slice(
            step_slice=slice(self._n_timesteps_seen, self._n_timesteps_seen + n_times),
        )
        self._n_timesteps_seen = n_times
        return logs

    def get_summary_logs(self) -> InferenceLog:
        """Return the merged end-of-run logs of all summary aggregators."""
        logs = {}
        for name, aggregator in self._summary_aggregators.items():
            logging.info(f"Getting summary logs for {name} aggregator")
            logs.update(aggregator.get_logs(label=name))
        if self._record_step_20:
            # we don't provide it so these are NaN always; tolerate the key
            # being absent rather than failing at the end of a run
            logs.pop("mean_step_20/loss", None)
            logs.pop("mean_step_20_norm/loss", None)
        return logs

    @torch.no_grad()
    def _get_logs(self):
        """
        Returns logs as can be reported to WandB.
        """
        logs = {}
        for name, aggregator in self._aggregators.items():
            logs.update(aggregator.get_logs(label=name))
        for name, time_dependent_aggregator in self._time_dependent_aggregators.items():
            logs.update(time_dependent_aggregator.get_logs(label=name))
        return logs

    @torch.no_grad()
    def _get_inference_logs_slice(self, step_slice: slice):
        """
        Returns a subset of the time series for applicable metrics
        for a specific slice of as can be reported to WandB.

        Args:
            step_slice: Timestep slice to determine the time series subset.

        Returns:
            List of log dictionaries built by ``to_inference_logs`` from the
            ``MeanAggregator`` logs for the slice.
        """
        logs = {}
        for name, aggregator in self._aggregators.items():
            if isinstance(aggregator, MeanAggregator):
                logs.update(aggregator.get_logs(label=name, step_slice=step_slice))
        return to_inference_logs(logs)

    @torch.no_grad()
    def flush_diagnostics(self, subdir: str | None = None):
        """
        Write reduced diagnostics to the output directory, if enabled.

        Args:
            subdir: Optional subdirectory passed through to the writer.

        Raises:
            ValueError: If diagnostics are enabled but no output directory
                is set.
        """
        if not self._save_diagnostics:
            return
        reduced_diagnostics = get_reduced_diagnostics(
            sub_aggregators=(self._aggregators | self._time_dependent_aggregators),
            coords=self._coords,
        )
        if self._output_dir is not None:
            write_reduced_diagnostics(
                reduced_diagnostics=reduced_diagnostics,
                output_dir=self._output_dir,
                subdir=subdir,
            )
        else:
            raise ValueError("Output directory not set.")


def to_inference_logs(
    log: Mapping[str, Table | float | int],
) -> list[dict[str, float | int]]:
    """
    Convert a mapping of WandB tables and scalars into per-row log dictionaries.

    Each table row becomes one dictionary, keyed by the table key with its
    final "/"-separated component (the table name) replaced by the column
    name. Scalar values are reported in the last dictionary.

    Args:
        log: Mapping from log key to a WandB table or a scalar value.

    Returns:
        One dictionary per table row, with at least one dictionary so that
        scalar values always have somewhere to go.
    """
    n_rows = max(
        (len(val.data) for val in log.values() if isinstance(val, Table)),
        default=0,
    )
    # need at least one for non-series values
    logs: list[dict[str, float | int]] = [{} for _ in range(max(1, n_rows))]
    for key, val in log.items():
        if not isinstance(val, Table):
            logs[-1][key] = val
            continue
        # Split once per table. rpartition leaves the prefix empty when the key
        # has no "/", whereas slicing with rfind's -1 would drop the key's
        # last character.
        prefix, separator, _ = key.rpartition("/")
        column_keys = [
            f"{prefix}/{col}" if separator else f"{col}" for col in val.columns
        ]
        for i, row in enumerate(val.data):
            for j, column_key in enumerate(column_keys):
                logs[i][column_key] = row[j]
    return logs


def table_to_logs(table: Table) -> list[dict[str, float | int]]:
    """
    Converts a WandB table into a list of dictionaries.

    Args:
        table: Table whose rows are converted.

    Returns:
        One dictionary per row, mapping column name to that row's value.
    """
    return [
        {table.columns[i]: value for i, value in enumerate(row)}
        for row in table.data
    ]
@dataclasses.dataclass
class InferenceAggregatorConfig:
    """
    Options for the single-series inference aggregator.

    Parameters:
        time_mean_reference_data: Path to reference time means to compare
            against.
        log_global_mean_time_series: Whether to log global mean time series
            metrics.
    """

    time_mean_reference_data: str | None = None
    log_global_mean_time_series: bool = True

    def build(
        self,
        dataset_info: DatasetInfo,
        n_timesteps: int,
        output_dir: str,
    ) -> "InferenceAggregator":
        """Open the reference time means, if configured, and construct the aggregator."""
        time_means = (
            None
            if self.time_mean_reference_data is None
            else xr.open_dataset(
                self.time_mean_reference_data,
                decode_timedelta=False,
            )
        )
        return InferenceAggregator(
            dataset_info=dataset_info,
            n_timesteps=n_timesteps,
            output_dir=output_dir,
            time_mean_reference_data=time_means,
            log_global_mean_time_series=self.log_global_mean_time_series,
        )
class InferenceAggregator(
    InferenceAggregatorABC[
        PrognosticState,
        PairedData,
    ]
):
    """
    Collects statistics over one time series of data.

    Call `record_batch` with each batch as it is produced, then
    `get_summary_logs` once the run is finished.
    """

    def __init__(
        self,
        dataset_info: DatasetInfo,
        n_timesteps: int,
        save_diagnostics: bool = True,
        output_dir: str | None = None,
        time_mean_reference_data: xr.Dataset | None = None,
        log_global_mean_time_series: bool = True,
    ):
        """
        Args:
            dataset_info: Dataset metadata; supplies the horizontal
                coordinates, gridded operations, timestep and variable
                metadata used to build the sub-aggregators.
            n_timesteps: Number of timesteps, used to size the global-mean
                series and to decide whether the run is long enough for the
                ENSO index.
            save_diagnostics: Whether to save diagnostics.
            output_dir: Directory to save diagnostic output.
            time_mean_reference_data: Reference time means for computing
                bias stats.
            log_global_mean_time_series: Whether to log global mean time
                series metrics.

        Raises:
            ValueError: If ``save_diagnostics`` is set without an
                ``output_dir``.
        """
        if output_dir is None and save_diagnostics:
            raise ValueError("Output directory must be set to save diagnostics")

        self._log_time_series = log_global_mean_time_series
        horizontal_coordinates = dataset_info.horizontal_coordinates
        self._coords = horizontal_coordinates.coords
        self._save_diagnostics = save_diagnostics
        self._output_dir = output_dir

        ops = dataset_info.gridded_operations
        built: dict[str, _Aggregator] = {}
        if log_global_mean_time_series:
            built["mean"] = SingleTargetMeanAggregator(
                ops,
                n_timesteps=n_timesteps,
            )
        built["time_mean"] = TimeMeanAggregator(
            gridded_operations=ops,
            variable_metadata=dataset_info.variable_metadata,
            reference_means=time_mean_reference_data,
        )
        built["annual"] = GlobalMeanAnnualAggregator(
            ops,
            dataset_info.timestep,
            dataset_info.variable_metadata,
        )

        # The Nino3.4 index needs a lat-lon grid and a run longer than
        # roughly two years.
        on_lat_lon_grid = isinstance(
            horizontal_coordinates, LatLonCoordinates
        ) and isinstance(ops, LatLonOperations)
        if (
            on_lat_lon_grid
            and n_timesteps * dataset_info.timestep > APPROXIMATELY_TWO_YEARS
        ):
            region = LatLonRegion(
                lat_bounds=NINO34_LAT,
                lon_bounds=NINO34_LON,
                lat=horizontal_coordinates.lat,
                lon=horizontal_coordinates.lon,
            )
            built["enso_index"] = RegionalIndexAggregator(
                regional_weights=region.regional_weights,
                regional_mean=ops.regional_area_weighted_mean,
            )

        self._aggregators = built
        # Summary logs leave out the per-step "mean" series.
        summary = {}
        for summary_name in ["time_mean", "annual", "enso_index"]:
            if summary_name in built:
                summary[summary_name] = built[summary_name]
        self._summary_aggregators = summary
        self._time_dependent_aggregator_names = ["annual", "enso_index"]
        self._n_timesteps_seen = 0

    @property
    def log_time_series(self) -> bool:
        """Whether the global mean time series is logged."""
        return self._log_time_series

    @torch.no_grad()
    def record_batch(self, data: PairedData) -> InferenceLogs:
        """
        Record the predictions of one batch in every sub-aggregator.

        Args:
            data: Batch of data to record; only its prediction and time are
                used.

        Returns:
            The time-series logs for the timesteps covered by this batch.

        Raises:
            ValueError: If the batch has no prediction values.
        """
        if len(data.prediction) == 0:
            raise ValueError("data is empty")

        first_step = self._n_timesteps_seen
        for name, aggregator in self._aggregators.items():
            # Time-dependent aggregators take the time axis; the others take
            # the index of the batch's first timestep instead.
            if name in self._time_dependent_aggregator_names:
                aggregator.record_batch(time=data.time, data=data.prediction)
            else:
                aggregator.record_batch(
                    data=data.prediction,
                    i_time_start=first_step,
                )

        end_step = first_step + data.time.shape[1]
        batch_logs = self._get_inference_logs_slice(
            step_slice=slice(first_step, end_step),
        )
        self._n_timesteps_seen = end_step
        return batch_logs

    def record_initial_condition(
        self,
        initial_condition: PrognosticState,
    ) -> InferenceLogs:
        """
        Record the initial condition in the global-mean series, if present.

        Args:
            initial_condition: Prognostic state to record.

        Returns:
            The time-series logs for the initial-condition timesteps.

        Raises:
            RuntimeError: If any timesteps have already been recorded.
        """
        if self._n_timesteps_seen != 0:
            raise RuntimeError(
                "record_initial_condition may only be called once, "
                "before recording any batches"
            )

        initial_batch = initial_condition.as_batch_data()
        if "mean" in self._aggregators:
            self._aggregators["mean"].record_batch(
                data=initial_batch.data,
                i_time_start=0,
            )

        n_initial = initial_batch.time.shape[1]
        start = self._n_timesteps_seen
        initial_logs = self._get_inference_logs_slice(
            step_slice=slice(start, start + n_initial),
        )
        self._n_timesteps_seen = n_initial
        return initial_logs

    def get_summary_logs(self) -> InferenceLog:
        """Return the merged end-of-run logs of all summary aggregators."""
        summary_logs = {}
        for name, aggregator in self._summary_aggregators.items():
            logging.info(f"Getting summary logs for {name} aggregator")
            summary_logs.update(aggregator.get_logs(label=name))
        return summary_logs

    @torch.no_grad()
    def _get_logs(self):
        """
        Merge the logs of every sub-aggregator, as can be reported to WandB.
        """
        merged = {}
        for name, aggregator in self._aggregators.items():
            merged.update(aggregator.get_logs(label=name))
        return merged

    @torch.no_grad()
    def _get_inference_logs(self) -> list[dict[str, float | int]]:
        """
        Return the full logs as a list of dictionaries to report to WandB.

        In inference the wandb step is used as the time step, so logged
        tables are re-organized into one dictionary per row.
        """
        all_logs = self._get_logs()
        return to_inference_logs(all_logs)

    @torch.no_grad()
    def _get_inference_logs_slice(self, step_slice: slice):
        """
        Return the time-series logs restricted to a slice of timesteps.

        Only ``SingleTargetMeanAggregator`` instances contribute.

        Args:
            step_slice: Timestep slice to determine the time series subset.
        """
        sliced = {}
        for name, aggregator in self._aggregators.items():
            if not isinstance(aggregator, SingleTargetMeanAggregator):
                continue
            sliced.update(aggregator.get_logs(label=name, step_slice=step_slice))
        return to_inference_logs(sliced)

    @torch.no_grad()
    def flush_diagnostics(self, subdir: str | None = None):
        """
        Write reduced diagnostics to the output directory, if enabled.

        Args:
            subdir: Optional subdirectory passed through to the writer.

        Raises:
            ValueError: If diagnostics are enabled but no output directory
                is set.
        """
        if not self._save_diagnostics:
            return
        reduced_diagnostics = get_reduced_diagnostics(
            sub_aggregators=self._aggregators,
            coords=self._coords,
        )
        if self._output_dir is None:
            raise ValueError("Output directory is not set.")
        write_reduced_diagnostics(
            reduced_diagnostics=reduced_diagnostics,
            output_dir=self._output_dir,
            subdir=subdir,
        )