# Source code for fme.coupled.inference.inference

import dataclasses
import logging
from collections.abc import Sequence
from typing import Literal

import dacite
import torch
import xarray as xr

import fme
from fme.ace.data_loading.inference import (
    ExplicitIndices,
    InferenceInitialConditionIndices,
    TimestampList,
)
from fme.ace.inference.inference import InitialConditionConfig, get_initial_condition
from fme.core.cli import prepare_config, prepare_directory
from fme.core.cloud import makedirs
from fme.core.derived_variables import get_derived_variable_metadata
from fme.core.generics.inference import get_record_to_wandb, run_inference
from fme.core.logging_utils import LoggingConfig
from fme.core.timing import GlobalTimer
from fme.coupled.aggregator import InferenceAggregatorConfig
from fme.coupled.data_loading.batch_data import CoupledPrognosticState
from fme.coupled.data_loading.getters import get_forcing_data
from fme.coupled.data_loading.gridded_data import InferenceGriddedData
from fme.coupled.data_loading.inference import CoupledForcingDataLoaderConfig
from fme.coupled.inference.data_writer import (
    CoupledDataWriterConfig,
    CoupledPairedDataWriter,
    DatasetMetadata,
)
from fme.coupled.stepper import CoupledStepper, CoupledStepperConfig

from .evaluator import (
    StandaloneComponentCheckpointsConfig,
    _validate_coupled_steps_config,
    load_stepper,
    load_stepper_config,
)

# Union of the index-selection types (all imported from
# fme.ace.data_loading.inference) accepted wherever this module takes a
# ``start_indices`` argument or config field.
StartIndices = InferenceInitialConditionIndices | ExplicitIndices | TimestampList


@dataclasses.dataclass
class ComponentInitialConditionConfig:
    """
    Location of the initial condition dataset for a single component.

    Parameters:
        path: Path to the component initial condition dataset.
        engine: Backend used in xarray.open_dataset call.
    """

    path: str
    engine: Literal["netcdf4", "h5netcdf", "zarr"] = "netcdf4"

    def get_dataset(self, start_indices: StartIndices | None = None) -> xr.Dataset:
        """
        Load the initial condition dataset for this component.

        The work is delegated to ``InitialConditionConfig``, built from this
        config's ``path`` and ``engine`` plus the given ``start_indices``.

        Parameters:
            start_indices: Optional selection forwarded unchanged to
                ``InitialConditionConfig``; ``None`` forwards no selection.

        Returns:
            The dataset returned by ``InitialConditionConfig.get_dataset``.
        """
        return InitialConditionConfig(
            path=self.path,
            engine=self.engine,
            start_indices=start_indices,
        ).get_dataset()
@dataclasses.dataclass
class CoupledInitialConditionConfig:
    """
    Configuration for initial conditions in coupled inference.

    Parameters:
        ocean: Configuration for the ocean initial conditions.
        atmosphere: Configuration for the atmosphere initial conditions.
        start_indices: Indices to use for selecting initial conditions,
            should correspond to the ocean initial condition dataset.
    """

    ocean: ComponentInitialConditionConfig
    atmosphere: ComponentInitialConditionConfig
    start_indices: StartIndices | None = None

    def get_initial_condition(
        self,
        ocean_prognostic_names: Sequence[str],
        atmosphere_prognostic_names: Sequence[str],
        n_ensemble_per_ic: int,
    ) -> CoupledPrognosticState:
        """
        Build the coupled prognostic state used to start inference.

        ``start_indices`` is applied to the ocean dataset only. The atmosphere
        dataset is opened without a selection and then subset with ``.sel``
        to the ocean dataset's values along the ocean's sample dimension.

        Parameters:
            ocean_prognostic_names: Prognostic variable names passed to
                ``get_initial_condition`` for the ocean dataset.
            atmosphere_prognostic_names: Prognostic variable names passed to
                ``get_initial_condition`` for the atmosphere dataset.
            n_ensemble_per_ic: Passed as ``n_ensemble`` for both components.

        Returns:
            A ``CoupledPrognosticState`` holding the ocean and atmosphere
            initial conditions.
        """
        ocean_ds = self.ocean.get_dataset(self.start_indices)
        # time is a required variable but not necessarily a dimension, so the
        # sample dimension is taken from the first dimension of ``time``
        sample_dim = ocean_ds.time.dims[0]
        atmosphere_ds = self.atmosphere.get_dataset()
        atmosphere_ds = atmosphere_ds.sel({sample_dim: ocean_ds[sample_dim]})

        ocean_state = get_initial_condition(
            ds=ocean_ds,
            prognostic_names=ocean_prognostic_names,
            n_ensemble=n_ensemble_per_ic,
        )
        atmosphere_state = get_initial_condition(
            ds=atmosphere_ds,
            prognostic_names=atmosphere_prognostic_names,
            n_ensemble=n_ensemble_per_ic,
        )
        return CoupledPrognosticState(
            ocean_data=ocean_state,
            atmosphere_data=atmosphere_state,
        )
@dataclasses.dataclass
class InferenceConfig:
    """
    Configuration for running inference.

    Parameters:
        experiment_dir: Directory to save results to.
        n_coupled_steps: Number of steps to run the model forward for.
        checkpoint_path: Path to a CoupledStepper training checkpoint to load,
            or a mapping to two separate Stepper training checkpoints.
        logging: configuration for logging.
        initial_condition: Configuration for initial condition data.
        forcing_loader: Configuration for forcing data.
        coupled_steps_in_memory: Number of coupled steps to complete in memory
            at a time, will load one more step for initial condition.
        data_writer: Configuration for data writers.
        aggregator: Configuration for inference aggregator.
        n_ensemble_per_ic: Number of ensemble members per initial condition
    """

    experiment_dir: str
    n_coupled_steps: int
    checkpoint_path: str | StandaloneComponentCheckpointsConfig
    logging: LoggingConfig
    initial_condition: CoupledInitialConditionConfig
    forcing_loader: CoupledForcingDataLoaderConfig
    coupled_steps_in_memory: int = 1
    # The config classes are used directly as factories; this is equivalent to
    # wrapping the no-argument constructor call in a lambda.
    data_writer: CoupledDataWriterConfig = dataclasses.field(
        default_factory=CoupledDataWriterConfig
    )
    aggregator: InferenceAggregatorConfig = dataclasses.field(
        default_factory=InferenceAggregatorConfig
    )
    n_ensemble_per_ic: int = 1

    def __post_init__(self):
        """Validate ``n_coupled_steps`` against ``coupled_steps_in_memory``."""
        _validate_coupled_steps_config(
            self.n_coupled_steps, self.coupled_steps_in_memory
        )

    def configure_logging(self, log_filename: str):
        """
        Configure logging for this run via the ``logging`` sub-config.

        The full configuration, converted with ``dataclasses.asdict``, is
        passed along, and logging is configured as non-resumable.

        Parameters:
            log_filename: Name of the log file, forwarded together with
                ``experiment_dir``.
        """
        config = dataclasses.asdict(self)
        self.logging.configure_logging(
            self.experiment_dir, log_filename, config=config, resumable=False
        )

    def load_stepper(self) -> CoupledStepper:
        """Load the coupled stepper from ``checkpoint_path``."""
        # Inside the method body the bare name resolves to the module-level
        # ``load_stepper`` helper, not to this method.
        return load_stepper(self.checkpoint_path)

    def load_stepper_config(self) -> CoupledStepperConfig:
        """Load the coupled stepper configuration from ``checkpoint_path``."""
        # As above, this calls the module-level ``load_stepper_config``.
        return load_stepper_config(self.checkpoint_path)

    def get_data_writer(
        self,
        data: InferenceGriddedData,
    ) -> CoupledPairedDataWriter:
        """
        Validate the time-coarsening settings and build the paired writer.

        The ocean ``time_coarsen`` config, if set, is validated against the
        coupled step counts. The atmosphere one, if set, is validated against
        the same counts multiplied by ``data.n_inner_steps``.

        Parameters:
            data: Inference data providing timesteps, coordinates, variable
                metadata, initial times and ``n_inner_steps``.

        Returns:
            The writer built by ``self.data_writer.build_paired``.

        Raises:
            ValueError: If either ``time_coarsen`` config fails validation.
                The message names the failing component and the original
                error is attached as ``__cause__``.
        """
        ocean_coarsen = self.data_writer.ocean.time_coarsen
        if ocean_coarsen is not None:
            try:
                ocean_coarsen.validate(
                    self.coupled_steps_in_memory,
                    self.n_coupled_steps,
                )
            except ValueError as err:
                # Chain explicitly so the traceback reports the validation
                # failure as the direct cause of this error.
                raise ValueError(
                    f"Ocean time_coarsen config invalid with error: {err}"
                ) from err

        atmosphere_coarsen = self.data_writer.atmosphere.time_coarsen
        if atmosphere_coarsen is not None:
            try:
                atmosphere_coarsen.validate(
                    self.coupled_steps_in_memory * data.n_inner_steps,
                    self.n_coupled_steps * data.n_inner_steps,
                )
            except ValueError as err:
                raise ValueError(
                    f"Atmosphere time_coarsen config invalid with error: {err}"
                ) from err

        variable_metadata = get_derived_variable_metadata() | data.variable_metadata
        # One metadata object, read from the environment, is shared by both
        # components.
        dataset_metadata = DatasetMetadata.from_env()
        coupled_dataset_metadata = {
            "ocean": dataset_metadata,
            "atmosphere": dataset_metadata,
        }
        return self.data_writer.build_paired(
            experiment_dir=self.experiment_dir,
            initial_condition_times=data.initial_time.to_numpy(),
            n_timesteps_ocean=self.n_coupled_steps,
            n_timesteps_atmosphere=self.n_coupled_steps * data.n_inner_steps,
            ocean_timestep=data.ocean_timestep,
            atmosphere_timestep=data.atmosphere_timestep,
            variable_metadata=variable_metadata,
            coords=data.coords,
            dataset_metadata=coupled_dataset_metadata,
        )
def main(
    yaml_config: str,
    override_dotlist: Sequence[str] | None = None,
):
    """
    Entry point: load the YAML configuration and run coupled inference.

    Parameters:
        yaml_config: Passed to ``prepare_config`` to load the configuration.
        override_dotlist: Optional overrides forwarded to ``prepare_config``.

    Returns:
        Whatever ``run_inference_from_config`` returns.
    """
    raw_config = prepare_config(yaml_config, override=override_dotlist)
    inference_config = dacite.from_dict(
        data_class=InferenceConfig,
        data=raw_config,
        config=dacite.Config(strict=True),
    )
    prepare_directory(inference_config.experiment_dir, raw_config)
    # Gradients are disabled and a global timer is active for the whole run.
    with torch.no_grad(), GlobalTimer():
        return run_inference_from_config(inference_config)


def run_inference_from_config(config: InferenceConfig):
    """
    Run coupled inference as described by ``config``.

    Loads the initial condition, stepper and forcing data, builds the
    aggregator and data writer, runs inference, flushes the outputs and logs
    throughput and timing summaries.

    Parameters:
        config: The inference configuration.
    """
    timer = GlobalTimer.get_instance()
    timer.start_outer("inference")
    timer.start("initialization")

    makedirs(config.experiment_dir, exist_ok=True)
    config.configure_logging(log_filename="inference_out.log")
    if fme.using_gpu():
        torch.backends.cudnn.benchmark = True

    stepper_config = config.load_stepper_config()
    window_requirements = stepper_config.get_forcing_window_data_requirements(
        n_coupled_steps=config.coupled_steps_in_memory
    )

    logging.info("Loading initial condition data")
    ocean_names = stepper_config.ocean.stepper.prognostic_names
    atmosphere_names = stepper_config.atmosphere.stepper.prognostic_names
    initial_condition = config.initial_condition.get_initial_condition(
        ocean_prognostic_names=ocean_names,
        atmosphere_prognostic_names=atmosphere_names,
        n_ensemble_per_ic=config.n_ensemble_per_ic,
    )

    stepper = config.load_stepper()
    stepper.set_eval()

    logging.info("Initializing forcing data loader")
    data = get_forcing_data(
        config=config.forcing_loader,
        total_coupled_steps=config.n_coupled_steps,
        window_requirements=window_requirements,
        initial_condition=initial_condition,
        dataset_info=stepper.training_dataset_info,
    )

    variable_metadata = get_derived_variable_metadata() | data.variable_metadata
    dataset_info = data.dataset_info.update_variable_metadata(variable_metadata)
    # Timestep counts passed to the aggregator include each component's
    # initial-condition timesteps on top of the forward steps.
    n_timesteps_ocean = config.n_coupled_steps + stepper.ocean.n_ic_timesteps
    n_timesteps_atmosphere = (
        config.n_coupled_steps * stepper.n_inner_steps
        + stepper.atmosphere.n_ic_timesteps
    )
    aggregator = config.aggregator.build(
        dataset_info=dataset_info,
        n_timesteps_ocean=n_timesteps_ocean,
        n_timesteps_atmosphere=n_timesteps_atmosphere,
        output_dir=config.experiment_dir,
    )
    writer = config.get_data_writer(data=data)
    timer.stop("initialization")

    logging.info("Starting inference")
    wandb_logger = get_record_to_wandb(label="inference")
    run_inference(
        predict=stepper.predict_paired,
        data=data,
        aggregator=aggregator,
        writer=writer,
        record_logs=wandb_logger.log,
    )

    timer.start("final_writer_flush")
    logging.info("Starting final flush of data writer")
    writer.finalize()
    logging.info("Writing reduced metrics to disk in netcdf format.")
    aggregator.flush_diagnostics()
    timer.stop("final_writer_flush")
    timer.stop_outer("inference")

    # Throughput counts inner steps across all initial conditions and
    # excludes the time attributed to wandb logging.
    total_steps = (
        config.n_coupled_steps * stepper.n_inner_steps
    ) * data.n_initial_conditions
    inference_duration = timer.get_duration("inference")
    wandb_logging_duration = timer.get_duration("inference/wandb_logging")
    compute_duration = inference_duration - wandb_logging_duration
    total_steps_per_second = total_steps / compute_duration
    timer.log_durations()
    logging.info(
        "Total steps per second (ignoring wandb logging): "
        f"{total_steps_per_second:.2f} steps/second"
    )

    # Aggregator summary entries are applied second, so they take precedence
    # on any key collision.
    summary_logs = {"total_steps_per_second": total_steps_per_second}
    summary_logs.update(aggregator.get_summary_logs())
    wandb_logger.log_to_current_step(summary_logs)
    wandb_logger.log_to_current_step(timer.get_durations())