Source code for fme.core.corrector.ocean
import dataclasses
import datetime
from types import MappingProxyType
from typing import Any, List, Mapping, Optional
import dacite
from fme.core.coordinates import HybridSigmaPressureCoordinate
from fme.core.corrector.corrector import force_positive
from fme.core.corrector.registry import CorrectorABC, CorrectorConfigProtocol
from fme.core.gridded_ops import GriddedOperations
from fme.core.masking import MaskingConfig
from fme.core.registry.corrector import CorrectorSelector
from fme.core.stacker import Stacker
from fme.core.typing_ import TensorMapping
# Read-only mapping from a logical ocean field name to the list of variable
# name prefixes associated with it. It is handed to the Stacker built by
# OceanCorrector.
OCEAN_FIELD_NAME_PREFIXES = MappingProxyType(
    dict(
        surface_height=["zos"],
        salinity=["so_"],
        potential_temperature=["thetao_"],
        zonal_velocity=["uo_"],
        meridional_velocity=["vo_"],
    )
)
@CorrectorSelector.register("ocean_corrector")
@dataclasses.dataclass
class OceanCorrectorConfig(CorrectorConfigProtocol):
    """
    Configuration for OceanCorrector, registered as "ocean_corrector".

    Parameters:
        masking: Optional masking configuration. When it is not None, the
            built corrector calls ``masking.build()`` and applies the result
            to generated data.
        force_positive_names: Names handed to ``force_positive`` by the built
            corrector. Defaults to an empty list, in which case
            ``force_positive`` is not called.
    """

    masking: Optional[MaskingConfig] = None
    force_positive_names: List[str] = dataclasses.field(default_factory=list)

    def build(
        self,
        gridded_operations: GriddedOperations,
        vertical_coordinate: HybridSigmaPressureCoordinate,
        timestep: datetime.timedelta,
    ):
        """
        Construct an OceanCorrector from this configuration.

        Args:
            gridded_operations: Forwarded unchanged to the corrector.
            vertical_coordinate: Forwarded unchanged to the corrector.
            timestep: Forwarded unchanged to the corrector.

        Returns:
            A new OceanCorrector holding this config and the given arguments.
        """
        corrector = OceanCorrector(
            config=self,
            gridded_operations=gridded_operations,
            vertical_coordinate=vertical_coordinate,
            timestep=timestep,
        )
        return corrector

    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> "OceanCorrectorConfig":
        """
        Rebuild a configuration from a plain mapping.

        Args:
            state: Mapping of field names to values, converted by
                ``dacite.from_dict`` using a ``strict=True`` dacite config.

        Returns:
            The configuration instance produced by dacite.
        """
        strict_config = dacite.Config(strict=True)
        return dacite.from_dict(data_class=cls, data=state, config=strict_config)
class OceanCorrector(CorrectorABC):
    """
    Corrector for generated ocean data.

    When called, it applies the configured masking (if any) and then
    ``force_positive`` (if any names are configured) to the generated data.
    """

    def __init__(
        self,
        config: OceanCorrectorConfig,
        gridded_operations: GriddedOperations,
        vertical_coordinate: HybridSigmaPressureCoordinate,
        timestep: datetime.timedelta,
    ):
        """
        Store the configuration and build the helpers used by ``__call__``.

        Args:
            config: Corrector configuration; its ``masking`` and
                ``force_positive_names`` fields are read by this class.
            gridded_operations: Stored on the instance; not otherwise used
                in this class.
            vertical_coordinate: Stored on the instance; not otherwise used
                in this class.
            timestep: Stored on the instance; not otherwise used in this
                class.
        """
        self._config = config
        self._gridded_operations = gridded_operations
        self._vertical_coordinates = vertical_coordinate
        self._timestep = timestep
        masking_config = config.masking
        # The masking callable only exists when a masking config was supplied.
        self._masking = None if masking_config is None else masking_config.build()
        self._stacker = Stacker(OCEAN_FIELD_NAME_PREFIXES)

    def __call__(
        self,
        input_data: TensorMapping,
        gen_data: TensorMapping,
        forcing_data: TensorMapping,
    ) -> TensorMapping:
        """
        Apply the configured corrections to the generated data.

        Args:
            input_data: Passed to the masking callable together with the
                generated data.
            gen_data: Generated data to correct.
            forcing_data: Accepted but not used by this corrector.

        Returns:
            The generated data after masking and ``force_positive`` have
            been applied; ``gen_data`` itself when neither is configured.
        """
        corrected = gen_data
        apply_masking = self._masking
        if apply_masking is not None:
            corrected = apply_masking(self._stacker, corrected, input_data)
        positive_names = self._config.force_positive_names
        if len(positive_names) > 0:
            corrected = force_positive(corrected, positive_names)
        return corrected