# Source code for fme.ace.aggregator.one_step.main

import dataclasses
from collections.abc import Sequence

import torch

from fme.ace.aggregator.one_step.deterministic import (
    DeterministicTrainOutput,
    OneStepDeterministicAggregator,
)
from fme.ace.aggregator.one_step.ensemble import get_one_step_ensemble_aggregator
from fme.ace.stepper import TrainOutput
from fme.core.dataset_info import DatasetInfo
from fme.core.generics.aggregator import AggregatorABC
from fme.core.tensors import fold_ensemble_dim, fold_sized_ensemble_dim
from fme.core.typing_ import TensorMapping


class OneStepAggregator(AggregatorABC[TrainOutput]):
    """
    Aggregates statistics for the first timestep.

    Every recorded batch is passed to a deterministic aggregator after its
    ensemble dimension has been folded away. Batches that carry more than one
    ensemble member are also passed, unfolded, to an ensemble aggregator. The
    ensemble aggregator's logs are merged into the result of `get_logs` only
    if at least one such batch was recorded.

    To use, call `record_batch` on the results of each batch, then call
    `get_logs` to get a dictionary of statistics when you're done.
    """

    def __init__(
        self,
        dataset_info: DatasetInfo,
        save_diagnostics: bool = True,
        output_dir: str | None = None,
        loss_scaling: TensorMapping | None = None,
        log_snapshots: bool = True,
        log_mean_maps: bool = True,
        channel_mean_names: Sequence[str] | None = None,
    ):
        """
        Args:
            dataset_info: Dataset coordinates and metadata.
            save_diagnostics: Whether to save diagnostics.
            output_dir: Directory to write diagnostics to.
            loss_scaling: Dictionary of variables and their scaling factors
                used in loss computation.
            log_snapshots: Whether to include snapshots in diagnostics.
            log_mean_maps: Whether to include mean maps in diagnostics.
            channel_mean_names: Names to include in channel-mean metrics.
        """
        self._deterministic_aggregator = OneStepDeterministicAggregator(
            dataset_info=dataset_info,
            save_diagnostics=save_diagnostics,
            output_dir=output_dir,
            loss_scaling=loss_scaling,
            log_snapshots=log_snapshots,
            log_mean_maps=log_mean_maps,
            channel_mean_names=channel_mean_names,
        )
        self._ensemble_aggregator = get_one_step_ensemble_aggregator(
            gridded_operations=dataset_info.gridded_operations,
            log_mean_maps=log_mean_maps,
            target_time=1,
            metadata=dataset_info.variable_metadata,
        )
        # Set once any batch with more than one ensemble member is recorded;
        # gates whether ensemble logs are included in `get_logs`.
        self._ensemble_recorded = False

    @torch.no_grad()
    def record_batch(
        self,
        batch: TrainOutput,
    ):
        """
        Record the generated and target data of one batch.

        Args:
            batch: Step output whose `gen_data` may carry an ensemble
                dimension. The deterministic aggregator always receives
                the ensemble-folded data. The ensemble aggregator receives
                the unfolded data, but only when there is more than one
                ensemble member.
        """
        # NOTE(review): fold_ensemble_dim presumably merges the ensemble axis
        # into the batch axis, and fold_sized_ensemble_dim presumably expands
        # the target to n_ensemble members so it lines up with gen_data.
        # Confirm against fme.core.tensors.
        folded_gen_data, n_ensemble = fold_ensemble_dim(batch.gen_data)
        folded_target_data = fold_sized_ensemble_dim(batch.target_data, n_ensemble)
        self._deterministic_aggregator.record_batch(
            DeterministicTrainOutput(
                metrics=batch.metrics,
                gen_data=folded_gen_data,
                target_data=folded_target_data,
                normalize=batch.normalize,
            )
        )
        if n_ensemble > 1:
            self._ensemble_aggregator.record_batch(
                target_data=batch.target_data,
                gen_data=batch.gen_data,
                i_time_start=0,
            )
            self._ensemble_recorded = True

    @torch.no_grad()
    def get_logs(self, label: str):
        """
        Returns logs as can be reported to WandB.

        Args:
            label: Label to prepend to all log keys.

        Returns:
            The deterministic logs, merged with the ensemble logs if any
            multi-member batch was recorded.

        Raises:
            ValueError: If the deterministic and ensemble logs share any key.
        """
        deterministic_logs = self._deterministic_aggregator.get_logs(label)
        if not self._ensemble_recorded:
            return deterministic_logs
        stochastic_logs = self._ensemble_aggregator.get_logs(label)
        overlapping_keys = deterministic_logs.keys() & stochastic_logs.keys()
        if overlapping_keys:
            # List only the colliding keys. Dumping whole log dicts (which may
            # contain image/plot objects) makes the message huge and unreadable.
            raise ValueError(
                "Stochastic and deterministic logs have overlapping keys: "
                f"{sorted(overlapping_keys, key=str)}"
            )
        return {**deterministic_logs, **stochastic_logs}

    @torch.no_grad()
    def flush_diagnostics(self, subdir: str | None = None):
        """
        Flush diagnostics accumulated by the deterministic aggregator.

        The ensemble aggregator is not flushed here. Its statistics are
        exposed only through `get_logs`.

        Args:
            subdir: Optional subdirectory, forwarded unchanged to the
                deterministic aggregator.
        """
        self._deterministic_aggregator.flush_diagnostics(subdir)


@dataclasses.dataclass
class OneStepAggregatorConfig:
    """
    Configuration for the validation OneStepAggregator.

    Attributes:
        log_snapshots: Whether to log snapshot images.
        log_mean_maps: Whether to log mean map images.
    """

    log_snapshots: bool = True
    log_mean_maps: bool = True

    def build(
        self,
        dataset_info: DatasetInfo,
        save_diagnostics: bool = True,
        output_dir: str | None = None,
        loss_scaling: TensorMapping | None = None,
        channel_mean_names: Sequence[str] | None = None,
    ):
        """
        Construct a OneStepAggregator from this configuration.

        Args:
            dataset_info: Dataset coordinates and metadata.
            save_diagnostics: Whether to save diagnostics.
            output_dir: Directory to write diagnostics to.
            loss_scaling: Dictionary of variables and their scaling factors
                used in loss computation.
            channel_mean_names: Names to include in channel-mean metrics.

        Returns:
            A new OneStepAggregator whose snapshot and mean-map logging flags
            are taken from this config.
        """
        # The logging switches are the only settings owned by the config;
        # everything else is supplied by the caller at build time.
        config_options = {
            "log_snapshots": self.log_snapshots,
            "log_mean_maps": self.log_mean_maps,
        }
        return OneStepAggregator(
            dataset_info=dataset_info,
            save_diagnostics=save_diagnostics,
            output_dir=output_dir,
            loss_scaling=loss_scaling,
            channel_mean_names=channel_mean_names,
            **config_options,
        )