# Source code for fme.core.corrector.atmosphere

import dataclasses
import datetime
from collections.abc import Callable
from typing import Literal, Protocol

import torch

import fme
from fme.core.atmosphere_data import (
    AtmosphereData,
    HasAtmosphereVerticalIntegral,
    compute_layer_thickness,
)
from fme.core.constants import GRAVITY, SPECIFIC_HEAT_OF_DRY_AIR_CONST_VOLUME
from fme.core.corrector.registry import CorrectorABC, CorrectorConfigABC
from fme.core.corrector.utils import force_positive
from fme.core.dataset_info import DatasetInfo
from fme.core.gridded_ops import GriddedOperations
from fme.core.registry.corrector import CorrectorSelector
from fme.core.typing_ import TensorDict, TensorMapping


@dataclasses.dataclass
class EnergyBudgetConfig:
    """Settings controlling the total energy budget correction.

    Parameters:
        method: Name of the correction scheme. Only "constant_temperature" is
            available; it restores conservation of total energy by applying a
            single air temperature increment that is uniform both vertically
            and horizontally.
        constant_unaccounted_heating: Extra column-integrated heating, in W/m**2,
            added to the energy flux into the atmosphere when conserving total
            energy. Helpful for compensating energy budget errors present in the
            target data. The same value is applied at every time step and in
            every grid cell.
    """

    method: Literal["constant_temperature"]
    constant_unaccounted_heating: float = 0.0


@CorrectorSelector.register("atmosphere_corrector")
@dataclasses.dataclass
class AtmosphereCorrectorConfig(CorrectorConfigABC):
    r"""
    Configuration for the post-step state corrector.

    ``conserve_dry_air`` enforces the constraint that:

    .. math::

        global\_dry\_air = global\_mean(
            ps - sum_k((ak\_diff + bk\_diff \* ps) \* wat_k))

    in the generated data is equal to its value in the input data. This is done
    by adding a globally-constant correction to the surface pressure in each
    column. As per-mass values such as mixing ratios of water are unchanged,
    this can cause changes in total water or energy. Note all global means here
    are area-weighted.

    ``zero_global_mean_moisture_advection`` enforces the constraint that:

    .. math::

        global\_mean(tendency\_of\_total\_water\_path\_due\_to\_advection) = 0

    in the generated data. This is done by adding a globally-constant correction
    to the moisture advection tendency in each column.

    ``moisture_budget_correction`` enforces closure of the moisture budget equation:

    .. math::

        tendency\_of\_total\_water\_path = (evaporation\_rate - precipitation\_rate \\\\
        + tendency\_of\_total\_water\_path\_due\_to\_advection)

    in the generated data, where ``tendency_of_total_water_path`` is the difference
    between the total water path at the current timestep and the previous
    timestep divided by the time difference. This is done by modifying the
    precipitation, evaporation, and/or moisture advection tendency fields as
    described in the ``moisture_budget_correction`` attribute. When
    advection tendency is modified, this budget equation is enforced in each
    column, while when only precipitation or evaporation are modified, only
    the global mean of the budget equation is enforced.

    When enforcing moisture budget closure, we assume the global mean moisture
    advection is zero. Therefore ``zero_global_mean_moisture_advection`` must be
    True if using a ``moisture_budget_correction`` option other than ``None``.

    Parameters:
        conserve_dry_air: If True, force the generated data to conserve dry air
            by subtracting a constant offset from the surface pressure of each
            column. This can cause changes in per-mass values such as
            total water or energy.
        zero_global_mean_moisture_advection: If True, force the generated data to
            have zero global mean moisture advection by subtracting a constant
            offset from the moisture advection tendency of each column.
        moisture_budget_correction: If not None, force the generated data to
            conserve global or column-local moisture by modifying budget fields.
            Options are:

            - ``precipitation``: multiply precipitation by a scale factor
              to close the global moisture budget.
            - ``evaporation``: multiply evaporation by a scale factor
              to close the global moisture budget.
            - ``advection_and_precipitation``: after applying the "precipitation"
              global-mean correction above, recompute the column-integrated
              advective tendency as the budget residual,
              ensuring column budget closure.
            - ``advection_and_evaporation``: after applying the "evaporation"
              global-mean correction above, recompute the column-integrated
              advective tendency as the budget residual,
              ensuring column budget closure.
        force_positive_names: Names of fields that should be forced to be greater
            than or equal to zero. This is useful for fields like precipitation.
        total_energy_budget_correction: If not None, force the generated data to
            conserve an idealized version of total energy using the provided
            configuration.
    """

    conserve_dry_air: bool = False
    zero_global_mean_moisture_advection: bool = False
    moisture_budget_correction: (
        Literal[
            "precipitation",
            "evaporation",
            "advection_and_precipitation",
            "advection_and_evaporation",
        ]
        | None
    ) = None
    force_positive_names: list[str] = dataclasses.field(default_factory=list)
    total_energy_budget_correction: EnergyBudgetConfig | None = None

    def get_corrector(
        self,
        dataset_info: DatasetInfo,
    ) -> "AtmosphereCorrector":
        """Build the corrector described by this configuration.

        Args:
            dataset_info: Source of the gridded operations, the atmosphere
                vertical coordinate and the timestep handed to the corrector.

        Returns:
            An ``AtmosphereCorrector`` configured with this object.
        """
        return AtmosphereCorrector(
            config=self,
            gridded_operations=dataset_info.gridded_operations,
            vertical_coordinate=dataset_info.atmosphere_vertical_coordinate,
            timestep=dataset_info.timestep,
        )
class AtmosphereCorrector(CorrectorABC):
    """Apply the corrections enabled in an ``AtmosphereCorrectorConfig``.

    Corrections run in a fixed order: force-positive, dry air conservation,
    zero global-mean moisture advection, moisture budget closure, and total
    energy budget closure.
    """

    def __init__(
        self,
        config: AtmosphereCorrectorConfig,
        gridded_operations: GriddedOperations,
        vertical_coordinate: HasAtmosphereVerticalIntegral | None,
        timestep: datetime.timedelta,
    ):
        """
        Args:
            config: Selects which corrections are applied.
            gridded_operations: Provides the area-weighted mean used by the
                corrections.
            vertical_coordinate: Vertical coordinate of the atmosphere, or None.
                Required by every correction except force-positive and zero
                global-mean moisture advection.
            timestep: Time between the input and the generated data.
        """
        self._config = config
        self._gridded_operations = gridded_operations
        self._vertical_coordinate = vertical_coordinate
        self._timestep_seconds = timestep.total_seconds()
        # Dry air conservation runs in float64 except on the "mps" device,
        # presumably because that backend lacks float64 support.
        if fme.get_device() == torch.device("mps", 0):
            self._dry_air_precision = torch.float32
        else:
            self._dry_air_precision = torch.float64

    def __call__(
        self,
        input_data: TensorMapping,
        gen_data: TensorMapping,
        forcing_data: TensorMapping,
    ) -> TensorDict:
        """Apply corrections to the generated data.

        Args:
            input_data: The input time step data.
            gen_data: The data generated by the model, to be corrected.
            forcing_data: The forcing data for the same time step as gen_data.

        Returns:
            The corrected data.

        Raises:
            ValueError: If an enabled correction needs the vertical coordinate
                but none is available.
        """
        # Shallow copy so rebinding keys never touches the caller's mapping.
        gen_data = dict(gen_data)
        area_weighted_mean = self._gridded_operations.area_weighted_mean
        if self._config.force_positive_names:
            # do this step before imposing other conservation correctors, since
            # otherwise it could end up creating violations of those constraints.
            gen_data = force_positive(gen_data, self._config.force_positive_names)
        if self._config.conserve_dry_air:
            if self._vertical_coordinate is None:
                raise ValueError(
                    "conserve_dry_air is set to True, but no vertical coordinate is "
                    "available."
                )
            gen_data = _force_conserve_dry_air(
                input_data=input_data,
                gen_data=gen_data,
                area_weighted_mean=area_weighted_mean,
                vertical_coordinate=self._vertical_coordinate,
                precision=self._dry_air_precision,
            )
        if self._config.zero_global_mean_moisture_advection:
            gen_data = _force_zero_global_mean_moisture_advection(
                gen_data=gen_data,
                area_weighted_mean=area_weighted_mean,
            )
        if self._config.moisture_budget_correction is not None:
            if self._vertical_coordinate is None:
                raise ValueError(
                    "Moisture budget correction is turned on, but no vertical "
                    "coordinate is available."
                )
            gen_data = _force_conserve_moisture(
                input_data=input_data,
                gen_data=gen_data,
                area_weighted_mean=area_weighted_mean,
                vertical_coordinate=self._vertical_coordinate,
                timestep_seconds=self._timestep_seconds,
                terms_to_modify=self._config.moisture_budget_correction,
            )
        energy_config = self._config.total_energy_budget_correction
        if energy_config is not None:
            if self._vertical_coordinate is None:
                raise ValueError(
                    "Energy budget correction is turned on, but no vertical coordinate"
                    " is available."
                )
            gen_data = _force_conserve_total_energy(
                input_data=input_data,
                gen_data=gen_data,
                forcing_data=forcing_data,
                area_weighted_mean=area_weighted_mean,
                vertical_coordinate=self._vertical_coordinate,
                timestep_seconds=self._timestep_seconds,
                method=energy_config.method,
                unaccounted_heating=energy_config.constant_unaccounted_heating,
            )
        return gen_data


class AreaWeightedMean(Protocol):
    """Callable computing an area-weighted mean of a tensor.

    ``keepdim`` is passed as True by the callers in this module when the result
    must combine directly with the un-reduced field.
    """

    def __call__(
        self, data: torch.Tensor, keepdim: bool, name: str | None = None
    ) -> torch.Tensor: ...


def _force_conserve_dry_air(
    input_data: TensorMapping,
    gen_data: TensorMapping,
    area_weighted_mean: AreaWeightedMean,
    vertical_coordinate: HasAtmosphereVerticalIntegral,
    precision: torch.dtype = torch.float64,
) -> TensorDict:
    """
    Update the generated data to conserve dry air.

    This is done by adding a constant correction to the dry air pressure of
    each column, and may result in changes in per-mass values such as
    total water or energy.

    We first compute the target dry air pressure by computing the globally
    averaged difference in dry air pressure between the input_data and gen_data,
    and then add this offset to the fully-resolved gen_data dry air pressure.
    We can then solve for the surface pressure corresponding to this new dry air
    pressure.

    We start from the expression for dry air pressure:

        dry_air = ps - sum_k((ak_diff + bk_diff * ps) * wat_k)

    To update the dry air, we compute and update the surface pressure:

        ps = (
            dry_air + sum_k(ak_diff * wat_k)
        ) / (
            1 - sum_k(bk_diff * wat_k)
        )

    Args:
        input_data: The input data.
        gen_data: The generated data to correct.
        area_weighted_mean: Computes an area-weighted mean.
        vertical_coordinate: Supplies the ak and bk coefficients.
        precision: dtype used for the global means and the pressure solve. The
            new surface pressure is cast back to the dtype of the input
            surface pressure.

    Returns:
        The generated data with an updated surface pressure.

    Raises:
        ValueError: If the input has no surface pressure, or if the specific
            total water of the generated data cannot be computed.
    """
    input_state = AtmosphereData(input_data, vertical_coordinate)
    if input_state.surface_pressure is None:
        raise ValueError("surface_pressure is required to force dry air conservation")
    gen = AtmosphereData(gen_data, vertical_coordinate)
    gen_dry_air = gen.surface_pressure_due_to_dry_air
    global_gen_dry_air = area_weighted_mean(gen_dry_air.to(precision), keepdim=True)
    global_target_gen_dry_air = area_weighted_mean(
        input_state.surface_pressure_due_to_dry_air.to(precision),
        keepdim=True,
    )
    error = global_gen_dry_air - global_target_gen_dry_air
    new_gen_dry_air = gen_dry_air.to(precision) - error
    try:
        wat = gen.specific_total_water.to(precision)
    except KeyError as err:
        # Chain the original error so the missing key stays in the traceback.
        raise ValueError("specific_total_water is required for conservation") from err
    ak_diff = vertical_coordinate.get_ak().diff().to(precision)
    bk_diff = vertical_coordinate.get_bk().diff().to(precision)
    new_pressure = (new_gen_dry_air + (ak_diff * wat).sum(-1)) / (
        1 - (bk_diff * wat).sum(-1)
    )
    gen.set_surface_pressure(
        new_pressure.to(dtype=input_state.surface_pressure.dtype)
    )
    return gen.data


def _force_zero_global_mean_moisture_advection(
    gen_data: TensorMapping,
    area_weighted_mean: Callable[[torch.Tensor], torch.Tensor],
) -> TensorDict:
    """
    Update the generated data so advection conserves moisture.

    Does so by adding a constant offset to the moisture advective tendency.

    Args:
        gen_data: The generated data.
        area_weighted_mean: Computes an area-weighted mean,
            removing horizontal dimensions.

    Returns:
        The generated data with the corrected advective tendency.
    """
    gen = AtmosphereData(gen_data)
    mean_moisture_advection = area_weighted_mean(
        gen.tendency_of_total_water_path_due_to_advection,
    )
    # Two trailing singleton dimensions are re-added so the mean broadcasts
    # against the field; this assumes exactly two horizontal dimensions were
    # reduced by area_weighted_mean.
    gen.set_tendency_of_total_water_path_due_to_advection(
        gen.tendency_of_total_water_path_due_to_advection
        - mean_moisture_advection[..., None, None]
    )
    return gen.data


def _force_conserve_moisture(
    input_data: TensorMapping,
    gen_data: TensorMapping,
    area_weighted_mean: AreaWeightedMean,
    vertical_coordinate: HasAtmosphereVerticalIntegral,
    timestep_seconds: float,
    terms_to_modify: Literal[
        "precipitation",
        "evaporation",
        "advection_and_precipitation",
        "advection_and_evaporation",
    ],
) -> TensorDict:
    """
    Update the generated data to conserve moisture.

    Does so while conserving total dry air in each column.

    Assumes the global mean advective tendency of moisture is zero. This assumption
    means any existing global mean advective tendency will be set to zero
    if the advective tendency is re-computed.

    NOTE(review): the text here and under ``terms_to_modify`` mentions modifying
    surface pressure, but the statements in this function only set the
    precipitation rate, evaporation rate and advective tendency. Confirm whether
    the surface pressure wording is still accurate.

    Args:
        input_data: The input data.
        gen_data: The generated data one timestep after the input data.
        area_weighted_mean: Computes an area-weighted mean,
            removing horizontal dimensions.
        vertical_coordinate: The sigma coordinates.
        timestep_seconds: Timestep of the model in seconds.
        terms_to_modify: Which terms to modify, in addition to modifying surface
            pressure to conserve dry air mass. One of:

            - "precipitation": modify precipitation only
            - "evaporation": modify evaporation only
            - "advection_and_precipitation": modify advection and precipitation
            - "advection_and_evaporation": modify advection and evaporation

    Returns:
        The generated data with the selected budget terms modified.
    """
    input_state = AtmosphereData(input_data, vertical_coordinate)
    gen = AtmosphereData(gen_data, vertical_coordinate)

    gen_total_water_path = gen.total_water_path
    twp_total_tendency = (
        gen_total_water_path - input_state.total_water_path
    ) / timestep_seconds
    twp_tendency_global_mean = area_weighted_mean(twp_total_tendency, keepdim=True)
    evaporation_global_mean = area_weighted_mean(gen.evaporation_rate, keepdim=True)
    precipitation_global_mean = area_weighted_mean(
        gen.precipitation_rate, keepdim=True
    )
    if terms_to_modify.endswith("precipitation"):
        # We want to achieve
        #     global_mean(twp_total_tendency) = (
        #         global_mean(evaporation_rate)
        #         - global_mean(precipitation_rate)
        #     )
        # so we modify precipitation_rate to achieve this. Note we have
        # assumed the global mean advection tendency is zero.
        # First, we find the required global-mean precipitation rate
        #     new_global_precip_rate = (
        #         global_mean(evaporation_rate)
        #         - global_mean(twp_total_tendency)
        #     )
        new_precipitation_global_mean = (
            evaporation_global_mean - twp_tendency_global_mean
        )
        # Because scalar multiplication commutes with summation, we can
        # achieve this by multiplying each gridcell's precipitation rate
        # by the ratio of the new global mean to the current global mean.
        #     new_precip_rate = (
        #         new_global_precip_rate / current_global_precip_rate
        #     ) * current_precip_rate
        gen.set_precipitation_rate(
            gen.precipitation_rate
            * (new_precipitation_global_mean / precipitation_global_mean)
        )
    elif terms_to_modify.endswith("evaporation"):
        # Derived similarly as for "precipitation" case.
        new_evaporation_global_mean = (
            twp_tendency_global_mean + precipitation_global_mean
        )
        gen.set_evaporation_rate(
            gen.evaporation_rate
            * (new_evaporation_global_mean / evaporation_global_mean)
        )
    if terms_to_modify.startswith("advection"):
        # Having already corrected the global-mean budget, we recompute
        # advection based on assumption that the columnwise
        # moisture budget closes. Correcting the global mean budget first
        # is important to ensure the resulting advection has zero global mean.
        new_advection = twp_total_tendency - (
            gen.evaporation_rate - gen.precipitation_rate
        )
        gen.set_tendency_of_total_water_path_due_to_advection(new_advection)
    return gen.data


def _force_conserve_total_energy(
    input_data: TensorMapping,
    gen_data: TensorMapping,
    forcing_data: TensorMapping,
    area_weighted_mean: AreaWeightedMean,
    vertical_coordinate: HasAtmosphereVerticalIntegral,
    timestep_seconds: float,
    method: Literal["constant_temperature"] = "constant_temperature",
    unaccounted_heating: float = 0.0,
) -> TensorDict:
    """Apply a correction to the generated data to conserve total energy.

    The global-mean energy path of the generated data is moved to the input
    global-mean energy path plus the global-mean net energy flux into the
    atmosphere (with ``unaccounted_heating`` added) times the timestep. This is
    achieved by adding one temperature increment to every air temperature level.

    The unaccounted heating only enters the target energy; it is not written to
    the output, which contains exactly the keys present in ``gen_data``.

    Args:
        input_data: The input data.
        gen_data: The generated data one timestep after the input data.
        forcing_data: Forcing data supplying the top-of-atmosphere downward
            shortwave flux and the surface height.
        area_weighted_mean: Computes an area-weighted mean.
        vertical_coordinate: Vertical coordinate of the atmosphere.
        timestep_seconds: Timestep of the model in seconds.
        method: Correction method; only "constant_temperature" is implemented.
        unaccounted_heating: Heating added to the global-mean energy flux.

    Returns:
        The generated data with corrected air temperatures.

    Raises:
        NotImplementedError: If ``method`` is not "constant_temperature".
    """
    if method != "constant_temperature":
        raise NotImplementedError(
            f"Method {method} not implemented for total energy conservation"
        )
    input_state = AtmosphereData(input_data, vertical_coordinate)
    forcing = AtmosphereData(forcing_data)
    required_forcing = {
        "DSWRFtoa": forcing.toa_down_sw_radiative_flux,
        "HGTsfc": forcing.surface_height,
    }
    # Merge the needed forcing fields into a copy of the generated data so the
    # energy diagnostics below can be evaluated on a single AtmosphereData.
    atmosphere_data = dict(gen_data)
    atmosphere_data.update(required_forcing)
    gen = AtmosphereData(atmosphere_data, vertical_coordinate)
    gen_energy_path = gen.total_energy_ace2_path
    input_energy_path = input_state.total_energy_ace2_path
    predicted_energy_flux_into_atmosphere = gen.net_energy_flux_into_atmosphere
    gen_energy_path_global_mean = area_weighted_mean(gen_energy_path, keepdim=True)
    input_energy_path_global_mean = area_weighted_mean(
        input_energy_path, keepdim=True
    )
    energy_flux_global_mean = area_weighted_mean(
        predicted_energy_flux_into_atmosphere, keepdim=True
    )
    desired_energy_path_global_mean = (
        input_energy_path_global_mean
        + (energy_flux_global_mean + unaccounted_heating) * timestep_seconds
    )
    energy_correction = desired_energy_path_global_mean - gen_energy_path_global_mean
    energy_to_temperature_factor = _energy_correction_factor(gen, vertical_coordinate)
    # take global mean to impose a spatially uniform temperature correction
    energy_to_temp_factor_gm = area_weighted_mean(
        energy_to_temperature_factor, keepdim=True
    )
    temperature_correction = energy_correction / energy_to_temp_factor_gm
    # apply same temperature correction to all vertical layers
    air_temperature_names = gen.get_all_vertical_level_names("air_temperature")
    for name in air_temperature_names:
        gen.data[name] = gen.data[name] + temperature_correction
    # filter required here because we merged forcing data into gen above
    return {k: v for k, v in gen.data.items() if k in gen_data}


def _energy_correction_factor(
    gen: AtmosphereData, vertical_coordinate: HasAtmosphereVerticalIntegral
) -> torch.Tensor:
    """
    Compute the factor to get a vertically-uniform temperature correction
    that will lead to a desired change in the globally-averaged total energy.

    See https://www.overleaf.com/read/dqjjcvzxnfvn#d525aa.

    Args:
        gen: The generated atmosphere state.
        vertical_coordinate: Provides the interface pressures and the vertical
            integral.

    Returns:
        The vertical integral of the per-layer factor. Dividing an energy
        correction by its global mean gives the temperature correction.
    """
    interface_pressure = vertical_coordinate.interface_pressure(gen.surface_pressure)
    q_times_dlogp = (
        compute_layer_thickness(
            interface_pressure, gen.air_temperature, gen.specific_total_water
        )
        * GRAVITY
        / gen.air_temperature
    )
    # flip -> cumsum -> flip yields a reversed cumulative sum along the last
    # axis: entry k holds the sum over indices k and above.
    cumulative = torch.cumsum(q_times_dlogp.flip(dims=(-1,)), dim=-1).flip(dims=(-1,))
    total_integrand = (
        SPECIFIC_HEAT_OF_DRY_AIR_CONST_VOLUME - 0.5 * q_times_dlogp + cumulative
    )
    correction_factor = vertical_coordinate.vertical_integral(
        total_integrand, gen.surface_pressure
    )
    return correction_factor