@dataclasses.dataclass
class SlabOceanConfig:
    """
    Configuration for a slab ocean model.

    Parameters:
        mixed_layer_depth_name: Name of the mixed layer depth field.
        q_flux_name: Name of the q-flux field (convergence of ocean heat
            transport).
    """

    mixed_layer_depth_name: str
    q_flux_name: str

    @property
    def names(self) -> list[str]:
        """Field names this slab model reads: mixed layer depth, then q-flux."""
        ordered = (self.mixed_layer_depth_name, self.q_flux_name)
        return list(ordered)
@dataclasses.dataclass
class OceanConfig:
    """
    Configuration for determining sea surface temperature from an ocean model.

    Parameters:
        surface_temperature_name: Name of the sea surface temperature field.
        ocean_fraction_name: Name of the ocean fraction field.
        interpolate: If True, interpolate between ML-predicted surface
            temperature and ocean-predicted surface temperature according to
            ocean_fraction. If False, only use ocean-predicted surface
            temperature where ocean_fraction>=0.5.
        slab: If provided, use a slab ocean model to predict surface
            temperature.
    """

    surface_temperature_name: str
    ocean_fraction_name: str
    interpolate: bool = False
    slab: SlabOceanConfig | None = None

    def build(
        self,
        in_names: list[str],
        out_names: list[str],
        timestep: datetime.timedelta,
    ) -> "Ocean":
        """
        Build the Ocean described by this configuration.

        Args:
            in_names: Names of the model's input variables.
            out_names: Names of the model's output variables.
            timestep: Timestep of the model.

        Returns:
            An Ocean constructed from this config and timestep.

        Raises:
            ValueError: If surface_temperature_name is missing from either
                in_names or out_names.
        """
        name = self.surface_temperature_name
        if name not in in_names or name not in out_names:
            raise ValueError(
                "To use a surface ocean model, the surface temperature must be present"
                f" in in_names and out_names, but {name} is not."
            )
        return Ocean(config=self, timestep=timestep)

    @property
    def forcing_names(self) -> list[str]:
        """
        Unique names of the variables this ocean needs from the forcing data.

        Always includes the ocean fraction. It also includes the prescribed
        surface temperature (prescribed ocean) or the slab fields (slab ocean).
        """
        names = [self.ocean_fraction_name]
        if self.slab is None:
            names.append(self.surface_temperature_name)
        else:
            names.extend(self.slab.names)
        # dict.fromkeys deduplicates while keeping first-seen order; a set's
        # iteration order for str depends on the per-process hash seed, which
        # made this list's order vary from run to run.
        return list(dict.fromkeys(names))
class Ocean:
    """Overwrite sea surface temperature with that predicted from some ocean model."""

    def __init__(self, config: OceanConfig, timestep: datetime.timedelta):
        """
        Args:
            config: Configuration for the surface ocean model.
            timestep: Timestep of the model.
        """
        self.surface_temperature_name = config.surface_temperature_name
        self.ocean_fraction_name = config.ocean_fraction_name
        # The prescriber writes the ocean-model temperature into the generated
        # data, using the ocean fraction field as the mask.
        self.prescriber = Prescriber(
            prescribed_name=config.surface_temperature_name,
            mask_name=config.ocean_fraction_name,
            mask_value=1,
            interpolate=config.interpolate,
        )
        self._forcing_names = config.forcing_names
        slab_config = config.slab
        if slab_config is not None:
            self.type = "slab"
            self.mixed_layer_depth_name = slab_config.mixed_layer_depth_name
            self.q_flux_name = slab_config.q_flux_name
        else:
            self.type = "prescribed"
        self.timestep = timestep

    def __call__(
        self,
        input_data: TensorMapping,
        gen_data: TensorMapping,
        target_data: TensorMapping,
    ) -> TensorDict:
        """
        Args:
            input_data: Denormalized input data for current step.
            gen_data: Denormalized output data for current step.
            target_data: Denormalized data that includes mask and forcing data.
                Assumed to correspond to the same time step as gen_data.

        Returns:
            gen_data with sea surface temperature overwritten by ocean model.
        """
        ocean_temperature = self._next_step_temperature(
            input_data, gen_data, target_data
        )
        return self.prescriber(
            target_data,
            gen_data,
            {self.surface_temperature_name: ocean_temperature},
        )

    def _next_step_temperature(
        self,
        input_data: TensorMapping,
        gen_data: TensorMapping,
        target_data: TensorMapping,
    ) -> torch.Tensor:
        """Ocean-model surface temperature for the step described by target_data.

        Prescribed: read directly from target_data. Slab: step the input
        temperature forward one timestep using the mixed layer tendency.
        """
        if self.type == "prescribed":
            return target_data[self.surface_temperature_name]
        if self.type == "slab":
            net_flux = AtmosphereData(
                gen_data
            ).net_surface_energy_flux_without_frozen_precip
            tendency = mixed_layer_temperature_tendency(
                net_flux,
                target_data[self.q_flux_name],
                target_data[self.mixed_layer_depth_name],
            )
            # Forward-Euler step: tendency is in K/s, so scale by the timestep
            # expressed in seconds.
            return (
                input_data[self.surface_temperature_name]
                + tendency * self.timestep.total_seconds()
            )
        raise NotImplementedError(f"Ocean type={self.type} is not implemented")

    @property
    def forcing_names(self) -> list[str]:
        """These are the variables required from the forcing data."""
        return self._forcing_names


def mixed_layer_temperature_tendency(
    f_net: torch.Tensor,
    q_flux: torch.Tensor,
    depth: torch.Tensor,
    density=DENSITY_OF_WATER,
    specific_heat=SPECIFIC_HEAT_OF_WATER,
) -> torch.Tensor:
    """
    Args:
        f_net: Net surface energy flux in W/m^2.
        q_flux: Convergence of ocean heat transport in W/m^2.
        depth: Mixed layer depth in m.
        density (optional): Density of water in kg/m^3.
        specific_heat (optional): Specific heat of water in J/kg/K.

    Returns:
        Temperature tendency of mixed layer in K/s.
    """
    total_heating = f_net + q_flux
    # Heat capacity of the mixed-layer column per unit area, in J/m^2/K.
    column_heat_capacity = density * depth * specific_heat
    return total_heating / column_heat_capacity