Source code for fme.core.corrector.ocean

import dataclasses
import datetime
from collections.abc import Mapping
from typing import Any, Literal, Protocol

import torch

from fme.core.atmosphere_data import AtmosphereData
from fme.core.constants import (
    FREEZING_TEMPERATURE_KELVIN,
    LATENT_HEAT_OF_VAPORIZATION,
    SPECIFIC_HEAT_OF_SEA_WATER_CM4,
)
from fme.core.corrector.registry import CorrectorABC, CorrectorConfigABC
from fme.core.corrector.utils import force_positive
from fme.core.dataset_info import DatasetInfo
from fme.core.gridded_ops import GriddedOperations
from fme.core.ocean_data import HasOceanDepthIntegral, OceanData
from fme.core.registry.corrector import CorrectorSelector
from fme.core.typing_ import TensorDict, TensorMapping


@dataclasses.dataclass
class SeaIceFractionConfig:
    """Bound the predicted sea ice fraction and blank out ice-only fields.

    The sea ice fraction is first limited to the range 0-1 so that
    land_fraction + sea_ice_fraction + ocean_fraction = 1 can hold. Once the
    fraction has been corrected, every variable named in
    zero_where_ice_free_names is multiplied by a mask that is 0 wherever the
    corrected sea ice fraction is 0.

    Parameters:
        sea_ice_fraction_name: Key of the sea ice fraction in the generated
            data.
        land_fraction_name: Key of the land fraction in the input data.
        zero_where_ice_free_names: Keys of generated variables to zero out
            where the corrected sea ice fraction is 0.
        remove_negative_ocean_fraction: When True, sea ice fraction is lowered
            wherever 1 - sea_ice_fraction - land_fraction would otherwise be
            negative.
    """

    sea_ice_fraction_name: str
    land_fraction_name: str
    zero_where_ice_free_names: list[str] = dataclasses.field(default_factory=list)
    remove_negative_ocean_fraction: bool = True

    def __call__(
        self, gen_data: TensorMapping, input_data: TensorMapping
    ) -> TensorDict:
        """Return a new dict holding the corrected generated data.

        Parameters:
            gen_data: Generated variables; must contain the sea ice fraction
                and every name in zero_where_ice_free_names.
            input_data: Input variables; the land fraction is read from here
                when remove_negative_ocean_fraction is True.

        Returns:
            A shallow copy of gen_data with the corrected entries replaced.
        """
        ice_key = self.sea_ice_fraction_name
        corrected = dict(gen_data)
        # clamp allocates a fresh tensor, so the in-place add below cannot
        # modify the tensor held by gen_data.
        corrected[ice_key] = corrected[ice_key].clamp(min=0.0, max=1.0)
        if self.remove_negative_ocean_fraction:
            ocean_fraction = (
                1 - corrected[ice_key] - input_data[self.land_fraction_name]
            )
            # Non-positive everywhere: the amount by which ice + land exceeds 1.
            overshoot = ocean_fraction.clamp(max=0)
            corrected[ice_key].add_(overshoot)
        for field_name in self.zero_where_ice_free_names:
            has_ice = corrected[ice_key] > 0.0
            corrected[field_name] = gen_data[field_name] * has_ice
        return corrected


@dataclasses.dataclass
class OceanHeatContentBudgetConfig:
    """Settings for the ocean heat content (OHC) budget correction.

    Parameters:
        method: Which OHC budget correction to apply. Only
            "scaled_temperature" is available; it conserves heat content by
            multiplying the predicted potential temperature by a single
            correction factor that is uniform in the vertical and the
            horizontal.
        constant_unaccounted_heating: Area-weighted global mean
            column-integrated heating, in W/m**2, added to the energy flux
            into the ocean when conserving heat content. Useful for
            compensating heat budget errors in the target data. The same
            extra heating is applied at every time step and grid cell.
    """

    method: Literal["scaled_temperature"]
    constant_unaccounted_heating: float = dataclasses.field(default=0.0)


@dataclasses.dataclass
class SurfaceEnergyFluxCorrectionConfig:
    """Settings for correcting the generated hfds with atmosphere-derived
    surface energy fluxes weighted by ocean_fraction.

    net_flux denotes the net surface energy flux computed from the atmospheric
    forcing variables and the sea surface temperature. Weighting by
    ocean_fraction removes the correction over land and weakens it under
    sea ice.

    Supported methods:
      - "residual_prediction":
        corrected_hfds = gen_hfds + ocean_fraction * net_flux.
        The network output is treated as a residual on top of the
        forcing-derived flux.
      - "prescribed":
        corrected_hfds = net_flux * ocean_fraction
                         + gen_hfds * (1 - ocean_fraction).
        Open-ocean hfds comes from the forcings, while the network
        prediction is kept under sea ice and on land.

    Parameters:
        method: Name of the correction method to apply.
    """

    method: Literal["residual_prediction", "prescribed"]


@CorrectorSelector.register("ocean_corrector")
@dataclasses.dataclass
class OceanCorrectorConfig(CorrectorConfigABC):
    """Configuration for the ocean corrector, registered as "ocean_corrector".

    Parameters:
        force_positive_names: Names of generated variables handed to
            ``force_positive`` by the corrector.
        sea_ice_fraction_correction: Optional sea ice fraction correction;
            skipped when None.
        surface_energy_flux_correction: Optional hfds correction; skipped
            when None.
        ocean_heat_content_correction: Optional ocean heat content budget
            correction; skipped when None.
    """

    force_positive_names: list[str] = dataclasses.field(default_factory=list)
    sea_ice_fraction_correction: SeaIceFractionConfig | None = None
    surface_energy_flux_correction: SurfaceEnergyFluxCorrectionConfig | None = None
    ocean_heat_content_correction: OceanHeatContentBudgetConfig | None = None

    @classmethod
    def remove_deprecated_keys(cls, state: Mapping[str, Any]) -> dict[str, Any]:
        """Translate a legacy config state into the current layout.

        Three legacy forms are handled:
          - a top-level "masking" key is dropped;
          - a boolean "ocean_heat_content_correction" becomes an
            ``OceanHeatContentBudgetConfig(method="scaled_temperature")`` when
            True and None when False;
          - a non-None "sea_ice_thickness_name" inside a dict-valued
            "sea_ice_fraction_correction" is removed and appended to that
            dict's "zero_where_ice_free_names" list.

        Parameters:
            state: Config state, possibly containing deprecated keys. Neither
                it nor its nested containers are modified.

        Returns:
            A new dict with the deprecated keys translated.
        """
        state_copy = dict(state)
        state_copy.pop("masking", None)

        ohc = state_copy.get("ocean_heat_content_correction")
        if isinstance(ohc, bool):
            state_copy["ocean_heat_content_correction"] = (
                OceanHeatContentBudgetConfig(method="scaled_temperature")
                if ohc
                else None
            )

        sif = state_copy.get("sea_ice_fraction_correction")
        if isinstance(sif, dict) and "sea_ice_thickness_name" in sif:
            # dict(state) above is only a shallow copy, so the nested dict and
            # its list are still the caller's objects; copy both before
            # editing to avoid mutating the input state.
            sif = dict(sif)
            thickness_name = sif.pop("sea_ice_thickness_name")
            if thickness_name is not None:
                ice_free_names = list(sif.get("zero_where_ice_free_names", []))
                ice_free_names.append(thickness_name)
                sif["zero_where_ice_free_names"] = ice_free_names
            state_copy["sea_ice_fraction_correction"] = sif
        return state_copy

    def get_corrector(
        self,
        dataset_info: DatasetInfo,
    ) -> "OceanCorrector":
        """Build an ``OceanCorrector`` from this config.

        Parameters:
            dataset_info: Supplies the gridded operations, the ocean vertical
                coordinate and the timestep passed to the corrector.

        Returns:
            The configured corrector.
        """
        return OceanCorrector(
            self,
            dataset_info.gridded_operations,
            dataset_info.ocean_vertical_coordinate,
            dataset_info.timestep,
        )
class OceanCorrector(CorrectorABC):
    """Applies the corrections enabled in an ``OceanCorrectorConfig``.

    Corrections run in a fixed order: force-positive, sea ice fraction,
    surface energy flux (hfds), then ocean heat content conservation. Each
    step is skipped when its config entry is empty or None.
    """

    def __init__(
        self,
        config: OceanCorrectorConfig,
        gridded_operations: GriddedOperations,
        vertical_coordinate: HasOceanDepthIntegral | None,
        timestep: datetime.timedelta,
    ):
        """Store the configuration and the helpers the corrections need.

        Parameters:
            config: Which corrections to apply and how.
            gridded_operations: Provides ``area_weighted_mean`` for the ocean
                heat content correction.
            vertical_coordinate: Vertical coordinate for the ocean heat
                content correction; may be None if that correction is off.
            timestep: Model timestep; converted to seconds for the ocean heat
                content correction.
        """
        self._config = config
        self._gridded_operations = gridded_operations
        self._vertical_coordinate = vertical_coordinate
        self._timestep = timestep

    def __call__(
        self,
        input_data: TensorMapping,
        gen_data: TensorMapping,
        forcing_data: TensorMapping,
    ) -> TensorDict:
        """Apply all configured corrections to the generated data.

        Parameters:
            input_data: Input variables for the step.
            gen_data: Generated variables to correct.
            forcing_data: Forcing variables for the step.

        Returns:
            A dict of the corrected generated variables.

        Raises:
            ValueError: If the ocean heat content correction is enabled but
                no vertical coordinate was provided.
        """
        config = self._config
        if config.force_positive_names:
            gen_data = force_positive(gen_data, config.force_positive_names)
        if config.sea_ice_fraction_correction is not None:
            gen_data = config.sea_ice_fraction_correction(gen_data, input_data)
        if config.surface_energy_flux_correction is not None:
            gen_data = _correct_hfds(
                input_data,
                gen_data,
                forcing_data,
                method=config.surface_energy_flux_correction.method,
            )
        if config.ocean_heat_content_correction is not None:
            if self._vertical_coordinate is None:
                raise ValueError(
                    "Ocean heat content correction is turned on, but no vertical "
                    "coordinate is available."
                )
            gen_data = _force_conserve_ocean_heat_content(
                input_data,
                gen_data,
                forcing_data,
                self._gridded_operations.area_weighted_mean,
                self._vertical_coordinate,
                self._timestep.total_seconds(),
                config.ocean_heat_content_correction.method,
                config.ocean_heat_content_correction.constant_unaccounted_heating,
            )
        return dict(gen_data)


def _compute_ocean_net_surface_energy_flux(
    forcing_data: TensorMapping,
    sst: torch.Tensor,
) -> torch.Tensor:
    """Compute the net surface energy flux into the ocean from atmospheric
    forcing variables and the sea surface temperature.

    This extends the atmosphere net surface energy flux with SST-dependent
    heat transport by precipitation and evaporation.

    Parameters:
        forcing_data: Atmospheric forcing variables, wrapped in
            ``AtmosphereData`` to read the fluxes and precipitation rates.
        sst: Sea surface temperature; its offset from
            FREEZING_TEMPERATURE_KELVIN scales the mass heat flux.

    Returns:
        The sum of the atmosphere net surface energy flux and the mass heat
        flux.
    """
    atmos = AtmosphereData(forcing_data)
    # missing: - calving * LATENT_HEAT_OF_FREEZING
    base_flux = atmos.net_surface_energy_flux
    # Net water mass flux: precipitation terms minus evaporation, the latter
    # obtained from the latent heat flux.
    # missing: + river runoff + calving
    water_mass_flux = (
        atmos.precipitation_rate
        + atmos.frozen_precipitation_rate
        - (atmos.latent_heat_flux / LATENT_HEAT_OF_VAPORIZATION)
    )
    mass_heat_flux = (
        SPECIFIC_HEAT_OF_SEA_WATER_CM4
        * water_mass_flux
        * (sst - FREEZING_TEMPERATURE_KELVIN)
    )
    return base_flux + mass_heat_flux


def _correct_hfds(
    input_data: TensorMapping,
    gen_data: TensorMapping,
    forcing_data: TensorMapping,
    method: Literal["residual_prediction", "prescribed"],
) -> TensorDict:
    """Apply surface energy flux correction to the generated hfds.

    The ocean_fraction naturally zeroes the correction on land and reduces it
    under sea ice.

    Methods:
        residual_prediction: gen_hfds + ocean_fraction * net_flux
        prescribed: net_flux * ocean_fraction + gen_hfds * (1 - ocean_fraction)

    The corrected variable is "hfds" when gen_data contains it, otherwise
    "hfds_total_area"; in the latter case net_flux is additionally multiplied
    by the forcing sea_surface_fraction.

    Parameters:
        input_data: Input variables; ocean_fraction and sea surface
            temperature are read from here.
        gen_data: Generated variables containing "hfds" or "hfds_total_area".
        forcing_data: Forcing variables used to compute net_flux.
        method: Correction method, see above.

    Returns:
        A shallow copy of gen_data with the hfds variable replaced.

    Raises:
        NotImplementedError: If method is not a supported option.
        KeyError: If gen_data has neither "hfds" nor "hfds_total_area".
    """
    # Validate first so an unsupported method fails before any computation.
    if method not in ("residual_prediction", "prescribed"):
        raise NotImplementedError(
            f"Method {method!r} not implemented for surface energy flux correction"
        )
    input_ocean = OceanData(input_data)
    forcing = OceanData(forcing_data)
    ocean_fraction = input_ocean.ocean_fraction
    # NOTE(review): SST and ocean_fraction are taken from input_data here,
    # not from gen_data -- confirm this is the intended source.
    net_flux = _compute_ocean_net_surface_energy_flux(
        forcing_data, input_ocean.sea_surface_temperature
    )
    if "hfds" in gen_data:
        hfds_name = "hfds"
    elif "hfds_total_area" in gen_data:
        hfds_name = "hfds_total_area"
        net_flux = net_flux * forcing.sea_surface_fraction
    else:
        raise KeyError(
            "gen_data must contain 'hfds' or 'hfds_total_area' to apply the "
            "surface energy flux correction"
        )
    gen_hfds = gen_data[hfds_name]
    out = dict(gen_data)
    if method == "residual_prediction":
        out[hfds_name] = net_flux * ocean_fraction + gen_hfds
    else:  # "prescribed"
        out[hfds_name] = net_flux * ocean_fraction + gen_hfds * (1 - ocean_fraction)
    return out


class AreaWeightedMean(Protocol):
    """Callable computing an area-weighted mean of a tensor.

    Called with the tensor, a ``keepdim`` flag and an optional variable
    ``name``.
    """

    def __call__(
        self, data: torch.Tensor, keepdim: bool, name: str | None = None
    ) -> torch.Tensor: ...
def _force_conserve_ocean_heat_content(
    input_data: TensorMapping,
    gen_data: TensorMapping,
    forcing_data: TensorMapping,
    area_weighted_mean: AreaWeightedMean,
    vertical_coordinate: HasOceanDepthIntegral,
    timestep_seconds: float,
    method: Literal["scaled_temperature"] = "scaled_temperature",
    unaccounted_heating: float = 0.0,
) -> TensorDict:
    """Rescale generated ocean temperatures so global heat content follows
    the surface energy budget.

    The expected global-mean heat content is the input global mean plus
    (global-mean net energy flux into the ocean + unaccounted_heating) *
    timestep_seconds. Every ``thetao_{k}`` level is multiplied by the ratio of
    that expected value to the generated global mean. If "sst" is present it
    is scaled by the same ratio about FREEZING_TEMPERATURE_KELVIN.

    The surface heat flux is taken, in order of priority, from the generated
    total-area heat flux, the generated standard heat flux, or the input
    standard heat flux. The forcing geothermal heat flux is added in all
    cases and the result is weighted by the forcing sea_surface_fraction.

    Parameters:
        input_data: Input variables; must allow computing ocean heat content.
        gen_data: Generated variables to correct.
        forcing_data: Forcing variables providing the geothermal heat flux
            and the sea surface fraction.
        area_weighted_mean: Callable computing area-weighted means.
        vertical_coordinate: Vertical coordinate passed to ``OceanData`` for
            the input and generated data.
        timestep_seconds: Timestep length in seconds.
        method: Conservation method; only "scaled_temperature" is supported.
        unaccounted_heating: Constant added to the global-mean energy flux
            before multiplying by the timestep.

    Returns:
        The ``data`` mapping of the ``OceanData`` wrapping gen_data, with the
        scaled temperatures written into it.

    Raises:
        NotImplementedError: If method is not "scaled_temperature".
        ValueError: If "hfds" is in both gen_data and forcing_data, or if the
            input ocean heat content is unavailable.
    """
    if method != "scaled_temperature":
        raise NotImplementedError(
            f"Method {method!r} not implemented for ocean heat content conservation"
        )
    # NOTE(review): this guard inspects forcing_data, while the last-resort
    # hfds lookup below reads from input_data -- confirm which is intended.
    if "hfds" in gen_data and "hfds" in forcing_data:
        raise ValueError(
            "Net downward surface heat flux cannot be present in both gen_data and "
            "forcing_data."
        )
    input_ocean = OceanData(input_data, vertical_coordinate)
    if input_ocean.ocean_heat_content is None:
        raise ValueError(
            "ocean_heat_content is required to force ocean heat content conservation"
        )
    gen = OceanData(gen_data, vertical_coordinate)
    forcing = OceanData(forcing_data)
    global_gen_ocean_heat_content = area_weighted_mean(
        gen.ocean_heat_content,
        keepdim=True,
        name="ocean_heat_content",
    )
    global_input_ocean_heat_content = area_weighted_mean(
        input_ocean.ocean_heat_content,
        keepdim=True,
        name="ocean_heat_content",
    )

    # Every branch below needs both forcing fields, so read them once outside
    # the try blocks: a missing forcing variable then surfaces as its own
    # error instead of being mistaken for a missing heat flux.
    geothermal_heat_flux = forcing.geothermal_heat_flux
    sea_surface_fraction = forcing.sea_surface_fraction
    try:
        # First priority: pre-weighted heat flux in gen_data
        hfds_total_area = gen.net_downward_surface_heat_flux_total_area
    except KeyError:
        try:
            # Second priority: standard heat flux in gen_data
            hfds = gen.net_downward_surface_heat_flux
        except KeyError:
            # Third priority: standard heat flux in input_data
            hfds = input_ocean.net_downward_surface_heat_flux
        net_energy_flux_into_ocean = (
            hfds + geothermal_heat_flux
        ) * sea_surface_fraction
    else:
        # The total-area flux is already area weighted; only the geothermal
        # term still needs the sea surface fraction.
        net_energy_flux_into_ocean = (
            hfds_total_area + geothermal_heat_flux * sea_surface_fraction
        )

    energy_flux_global_mean = area_weighted_mean(
        net_energy_flux_into_ocean,
        keepdim=True,
        name="ocean_heat_content",
    )
    expected_change_ocean_heat_content = (
        energy_flux_global_mean + unaccounted_heating
    ) * timestep_seconds
    heat_content_correction_ratio = (
        global_input_ocean_heat_content + expected_change_ocean_heat_content
    ) / global_gen_ocean_heat_content

    # apply same temperature correction to all vertical layers
    # NOTE(review): gen.data is written in place; whether that aliases the
    # caller's gen_data depends on OceanData -- confirm.
    n_levels = gen.sea_water_potential_temperature.shape[-1]
    for k in range(n_levels):
        name = f"thetao_{k}"
        gen.data[name] = gen.data[name] * heat_content_correction_ratio
    if "sst" in gen.data:
        gen.data["sst"] = (  # assuming sst in Kelvin
            gen.data["sst"] - FREEZING_TEMPERATURE_KELVIN
        ) * heat_content_correction_ratio + FREEZING_TEMPERATURE_KELVIN
    return gen.data