# Source code for fme.ace.inference.data_writer.main

import dataclasses
import datetime
import warnings
from collections.abc import Mapping, Sequence
from pathlib import Path

import numpy as np
import torch
import xarray as xr

from fme.ace.data_loading.batch_data import BatchData, PairedData, PrognosticState
from fme.core.dataset.data_typing import VariableMetadata
from fme.core.generics.writer import WriterABC

from .dataset_metadata import DatasetMetadata
from .histograms import PairedHistogramDataWriter
from .monthly import MonthlyDataWriter, PairedMonthlyDataWriter, months_for_timesteps
from .raw import PairedRawDataWriter, RawDataWriter
from .time_coarsen import PairedTimeCoarsen, TimeCoarsen, TimeCoarsenConfig
from .video import PairedVideoDataWriter

# Union of the sub-writer types that PairedDataWriter stores in its
# ``_writers`` list (each one consumes paired target/prediction data).
PairedSubwriter = (
    PairedRawDataWriter
    | PairedVideoDataWriter
    | PairedHistogramDataWriter
    | PairedTimeCoarsen
    | PairedMonthlyDataWriter
)

# Union of the sub-writer types that the prediction-only DataWriter stores
# in its ``_writers`` list.
Subwriter = MonthlyDataWriter | RawDataWriter | TimeCoarsen


@dataclasses.dataclass
class DataWriterConfig:
    """
    Configuration for inference data writers.

    Parameters:
        log_extended_video_netcdfs: Whether to enable writing of netCDF files
            containing video metrics.
        save_prediction_files: Whether to enable writing of netCDF files
            containing the predictions and target values.
        save_monthly_files: Whether to enable writing of netCDF files
            containing the monthly predictions and target values.
        names: Names of variables to save in the prediction, histogram,
            and monthly netCDF files.
        save_histogram_files: Enable writing of netCDF files containing
            histograms.
        time_coarsen: Configuration for time coarsening of written outputs.
    """

    log_extended_video_netcdfs: bool = False
    save_prediction_files: bool = True
    save_monthly_files: bool = True
    names: Sequence[str] | None = None
    save_histogram_files: bool = False
    time_coarsen: TimeCoarsenConfig | None = None

    def __post_init__(self):
        # Nothing to check unless a variable subset was requested.
        if self.names is None:
            return
        # ``names`` only affects the prediction, monthly and histogram
        # outputs; if none of them is enabled the subset has no effect.
        writes_subsettable_output = (
            self.save_prediction_files
            or self.save_monthly_files
            or self.save_histogram_files
        )
        if not writes_subsettable_output:
            warnings.warn(
                "names provided but all options to "
                "save subsettable output files are False."
            )

    def build_paired(
        self,
        experiment_dir: str,
        n_initial_conditions: int,
        n_timesteps: int,
        timestep: datetime.timedelta,
        variable_metadata: Mapping[str, VariableMetadata],
        coords: Mapping[str, np.ndarray],
        dataset_metadata: DatasetMetadata,
    ) -> "PairedDataWriter":
        """
        Build a writer for paired (target and prediction) data.

        The ``save_*`` / ``log_*`` flags, ``names`` and ``time_coarsen`` of
        this configuration are forwarded to the PairedDataWriter; all other
        arguments are passed through unchanged, with ``experiment_dir`` used
        as the output path.
        """
        return PairedDataWriter(
            # Pass-through arguments.
            path=experiment_dir,
            n_initial_conditions=n_initial_conditions,
            n_timesteps=n_timesteps,
            timestep=timestep,
            variable_metadata=variable_metadata,
            coords=coords,
            dataset_metadata=dataset_metadata,
            # Options taken from this configuration.
            enable_prediction_netcdfs=self.save_prediction_files,
            enable_monthly_netcdfs=self.save_monthly_files,
            enable_video_netcdfs=self.log_extended_video_netcdfs,
            enable_histogram_netcdfs=self.save_histogram_files,
            save_names=self.names,
            time_coarsen=self.time_coarsen,
        )

    def build(
        self,
        experiment_dir: str,
        n_initial_conditions: int,
        n_timesteps: int,
        timestep: datetime.timedelta,
        variable_metadata: Mapping[str, VariableMetadata],
        coords: Mapping[str, np.ndarray],
        dataset_metadata: DatasetMetadata,
    ) -> "DataWriter":
        """
        Build a prediction-only data writer.

        Raises:
            NotImplementedError: If ``save_histogram_files`` or
                ``log_extended_video_netcdfs`` is enabled, since neither
                is supported by the prediction-only writer.
        """
        if self.save_histogram_files:
            raise NotImplementedError(
                "Saving histograms is not supported for prediction-only data writers. "
                "Make sure to set `save_histogram_files=False`."
            )
        if self.log_extended_video_netcdfs:
            raise NotImplementedError(
                "Saving 'extended video' netCDFs is not supported for prediction-only "
                "data writers. Make sure to set `log_extended_video_netcdfs=False`."
            )
        return DataWriter(
            # Pass-through arguments.
            path=experiment_dir,
            n_initial_conditions=n_initial_conditions,
            n_timesteps=n_timesteps,
            timestep=timestep,
            variable_metadata=variable_metadata,
            coords=coords,
            dataset_metadata=dataset_metadata,
            # Options taken from this configuration.
            enable_prediction_netcdfs=self.save_prediction_files,
            enable_monthly_netcdfs=self.save_monthly_files,
            save_names=self.names,
            time_coarsen=self.time_coarsen,
        )
class PairedDataWriter(WriterABC[PrognosticState, PairedData]):
    """
    Writer for paired (target and prediction) inference output.

    Dispatches each appended batch to the enabled sub-writers (raw
    predictions, monthly, video and histogram netCDFs).
    """

    def __init__(
        self,
        path: str,
        n_initial_conditions: int,
        n_timesteps: int,
        variable_metadata: Mapping[str, VariableMetadata],
        coords: Mapping[str, np.ndarray],
        timestep: datetime.timedelta,
        enable_prediction_netcdfs: bool,
        enable_monthly_netcdfs: bool,
        enable_video_netcdfs: bool,
        save_names: Sequence[str] | None,
        enable_histogram_netcdfs: bool,
        dataset_metadata: DatasetMetadata,
        time_coarsen: TimeCoarsenConfig | None = None,
    ):
        """
        Args:
            path: Path to write netCDF file(s).
            n_initial_conditions: Number of ICs/ensemble members to write to the file.
            n_timesteps: Number of timesteps to write to the file.
            variable_metadata: Metadata for each variable to be written to the file.
            coords: Coordinate data to be written to the file.
            timestep: Timestep of the model.
            enable_prediction_netcdfs: Whether to enable writing of netCDF files
                containing the predictions and target values.
            enable_monthly_netcdfs: Whether to enable writing of netCDF files
                containing the monthly predictions and target values.
            enable_video_netcdfs: Whether to enable writing of netCDF files
                containing video metrics.
            save_names: Names of variables to save in the prediction, histogram,
                and monthly netCDF files.
            enable_histogram_netcdfs: Whether to write netCDFs with histogram data.
            dataset_metadata: Metadata for the dataset.
            time_coarsen: Configuration for time coarsening of written outputs.
        """
        self._writers: list[PairedSubwriter] = []
        self.path = path
        self.coords = coords
        self.variable_metadata = variable_metadata
        self.dataset_metadata = dataset_metadata

        # The video and histogram writers are sized by the number of
        # timesteps they will actually receive after coarsening.
        if time_coarsen is not None:
            n_coarsened_timesteps = time_coarsen.n_coarsened_timesteps(n_timesteps)
        else:
            n_coarsened_timesteps = n_timesteps

        def _time_coarsen_builder(data_writer: PairedSubwriter) -> PairedSubwriter:
            # Wrap the writer with time coarsening only when configured.
            if time_coarsen is not None:
                return time_coarsen.build_paired(data_writer)
            return data_writer

        if enable_prediction_netcdfs:
            self._writers.append(
                _time_coarsen_builder(
                    PairedRawDataWriter(
                        path=path,
                        n_initial_conditions=n_initial_conditions,
                        save_names=save_names,
                        variable_metadata=variable_metadata,
                        coords=coords,
                        dataset_metadata=dataset_metadata,
                    )
                )
            )
        if enable_monthly_netcdfs:
            # The monthly writer is not wrapped with time coarsening and
            # receives the un-coarsened number of timesteps.
            self._writers.append(
                PairedMonthlyDataWriter(
                    path=path,
                    n_samples=n_initial_conditions,
                    n_timesteps=n_timesteps,
                    timestep=timestep,
                    save_names=save_names,
                    variable_metadata=variable_metadata,
                    coords=coords,
                    dataset_metadata=dataset_metadata,
                )
            )
        if enable_video_netcdfs:
            self._writers.append(
                _time_coarsen_builder(
                    PairedVideoDataWriter(
                        path=path,
                        n_timesteps=n_coarsened_timesteps,
                        variable_metadata=variable_metadata,
                        coords=coords,
                        dataset_metadata=dataset_metadata,
                    )
                )
            )
        if enable_histogram_netcdfs:
            self._writers.append(
                _time_coarsen_builder(
                    PairedHistogramDataWriter(
                        path=path,
                        n_timesteps=n_coarsened_timesteps,
                        variable_metadata=variable_metadata,
                        save_names=save_names,
                        dataset_metadata=dataset_metadata,
                    )
                )
            )
        self._n_timesteps_seen = 0

    def write(self, data: PrognosticState, filename: str):
        """Eagerly write data to a single netCDF file.

        Args:
            data: the data to be written.
            filename: the filename to use for the netCDF file.
        """
        _write(
            data=data.as_batch_data(),
            path=self.path,
            filename=filename,
            variable_metadata=self.variable_metadata,
            coords=self.coords,
            dataset_metadata=self.dataset_metadata,
        )

    def append_batch(
        self,
        batch: PairedData,
    ):
        """
        Append a batch of data to the file.

        Args:
            batch: Prediction and target data.
        """
        for writer in self._writers:
            writer.append_batch(
                target=dict(batch.reference),
                prediction=dict(batch.prediction),
                start_timestep=self._n_timesteps_seen,
                batch_time=batch.time,
            )
        # Axis 1 of the batch time array is used as the per-batch timestep count.
        self._n_timesteps_seen += batch.time.shape[1]

    def flush(self):
        """
        Flush the data to disk.
        """
        for writer in self._writers:
            writer.flush()


def _write(
    data: BatchData,
    path: str,
    filename: str,
    variable_metadata: Mapping[str, VariableMetadata],
    coords: Mapping[str, np.ndarray],
    dataset_metadata: DatasetMetadata,
):
    """Write provided data to a single netCDF at specified path/filename.

    If the data has only one timestep, the data is squeezed to remove
    the time dimension.

    Args:
        data: Batch data to be written.
        path: Directory to write the netCDF file in.
        filename: filename to use for netCDF.
        variable_metadata: Metadata for each variable to be written to the file.
        coords: Coordinate data to be written to the file.
        dataset_metadata: Metadata for the dataset.
    """
    if data.time.sizes["time"] == 1:
        # Single timestep: drop the "time" dimension from both the dimension
        # names and the tensors, and keep the time as a time-less value.
        time_dim = data.dims.index("time")
        dims_to_write = data.dims[:time_dim] + data.dims[time_dim + 1 :]

        def maybe_squeeze(x: torch.Tensor) -> torch.Tensor:
            return x.squeeze(dim=time_dim)

        time_array = data.time.isel(time=0)
    else:
        dims_to_write = data.dims

        def maybe_squeeze(x):
            return x

        time_array = data.time

    data_arrays = {}
    for name in data.data:
        array = maybe_squeeze(data.data[name]).cpu().numpy()
        data_arrays[name] = xr.DataArray(array, dims=dims_to_write)
        # Only variables with known metadata get descriptive attributes.
        if name in variable_metadata:
            data_arrays[name].attrs = {
                "long_name": variable_metadata[name].long_name,
                "units": variable_metadata[name].units,
            }
    data_arrays["time"] = time_array
    ds = xr.Dataset(data_arrays, coords=coords)
    ds.attrs.update(dataset_metadata.as_flat_str_dict())
    ds.to_netcdf(str(Path(path) / filename))


class DataWriter(WriterABC[PrognosticState, PairedData]):
    """
    Writer for prediction-only inference output.

    Dispatches each appended batch (prediction merged with forcing) to the
    enabled sub-writers (raw predictions and monthly netCDFs).
    """

    def __init__(
        self,
        path: str,
        n_initial_conditions: int,
        n_timesteps: int,
        variable_metadata: Mapping[str, VariableMetadata],
        coords: Mapping[str, np.ndarray],
        timestep: datetime.timedelta,
        enable_prediction_netcdfs: bool,
        enable_monthly_netcdfs: bool,
        save_names: Sequence[str] | None,
        dataset_metadata: DatasetMetadata,
        time_coarsen: TimeCoarsenConfig | None = None,
    ):
        """
        Args:
            path: Directory within which to write netCDF file(s).
            n_initial_conditions: Number of initial conditions / timeseries
                to write to the file.
            n_timesteps: Number of timesteps to write to the file.
            variable_metadata: Metadata for each variable to be written to the file.
            coords: Coordinate data to be written to the file.
            timestep: Timestep of the model.
            enable_prediction_netcdfs: Whether to enable writing of netCDF files
                containing the predictions.
            enable_monthly_netcdfs: Whether to enable writing of netCDF files
                containing the monthly predictions.
            save_names: Names of variables to save in the prediction and
                monthly netCDF files.
            dataset_metadata: Metadata for the dataset.
            time_coarsen: Configuration for time coarsening of raw outputs.
        """
        # Assign all instance state up front so that the early return below
        # still leaves a fully initialized object: append_batch() updates
        # ``_n_timesteps_seen`` and write() reads ``path`` / ``coords`` /
        # metadata, and both would otherwise raise AttributeError.
        self._writers: list[Subwriter] = []
        self.path = path
        self.variable_metadata = variable_metadata
        self.dataset_metadata = dataset_metadata
        self.coords = coords
        self._n_timesteps_seen = 0

        if "face" in coords:
            # TODO: handle writing HEALPix data
            # https://github.com/ai2cm/full-model/issues/1089
            # No sub-writers are registered in this case.
            return

        def _time_coarsen_builder(data_writer: Subwriter) -> Subwriter:
            # Wrap the writer with time coarsening only when configured.
            if time_coarsen is not None:
                return time_coarsen.build(data_writer)
            return data_writer

        if enable_prediction_netcdfs:
            self._writers.append(
                _time_coarsen_builder(
                    RawDataWriter(
                        path=path,
                        label="autoregressive_predictions.nc",
                        n_initial_conditions=n_initial_conditions,
                        save_names=save_names,
                        variable_metadata=variable_metadata,
                        coords=coords,
                        dataset_metadata=dataset_metadata,
                    )
                )
            )
        if enable_monthly_netcdfs:
            # The monthly writer is not wrapped with time coarsening.
            self._writers.append(
                MonthlyDataWriter(
                    path=path,
                    label="predictions",
                    n_samples=n_initial_conditions,
                    n_months=months_for_timesteps(n_timesteps, timestep),
                    save_names=save_names,
                    variable_metadata=variable_metadata,
                    coords=coords,
                    dataset_metadata=dataset_metadata,
                )
            )

    def append_batch(self, batch: PairedData):
        """
        Append prediction data to the file.

        The prognostic data and forcing data are merged before writing.

        Args:
            batch: Paired data to be written.
        """
        # On a name collision the forcing entry replaces the prediction entry.
        merged = {**batch.prediction, **batch.forcing}
        unpaired_batch = BatchData.new_on_device(
            data=merged,
            time=batch.time,
        )
        self._append_batch(unpaired_batch)

    def _append_batch(self, batch: BatchData):
        """Pass the batch to every sub-writer and advance the timestep counter."""
        for writer in self._writers:
            writer.append_batch(
                data=dict(batch.data),
                start_timestep=self._n_timesteps_seen,
                batch_time=batch.time,
            )
        # Axis 1 of the batch time array is used as the per-batch timestep count.
        self._n_timesteps_seen += batch.time.shape[1]

    def flush(self):
        """
        Flush the data to disk.
        """
        for writer in self._writers:
            writer.flush()

    def write(self, data: PrognosticState, filename: str):
        """Eagerly write data to a single netCDF file.

        Args:
            data: the data to be written.
            filename: the filename to use for the netCDF file.
        """
        _write(
            data=data.as_batch_data(),
            path=self.path,
            filename=filename,
            variable_metadata=self.variable_metadata,
            coords=self.coords,
            dataset_metadata=self.dataset_metadata,
        )