Source code for fme.ace.registry.m2lines

import dataclasses
from collections.abc import Mapping
from typing import Any

from fme.ace.models.ocean.m2lines.samudra import Samudra
from fme.ace.registry.registry import ModuleConfig, ModuleSelector


[docs]@ModuleSelector.register("Samudra") @dataclasses.dataclass class SamudraBuilder(ModuleConfig): """ Configuration for the M2Lines Samudra architecture. """ ch_width: list[int] = dataclasses.field( default_factory=lambda: [200, 250, 300, 400] ) n_layers: list[int] = dataclasses.field(default_factory=lambda: [1, 1, 1, 1]) dilation: list[int] = dataclasses.field(default_factory=lambda: [1, 2, 4, 8]) pad: str = "circular" norm: str = "instance" norm_kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict) upscale_factor: int = 4 def __post_init__(self): if "num_features" in self.norm_kwargs: raise ValueError("norm_kwargs should not have num_features") if "normalized_shape" in self.norm_kwargs: raise ValueError("norm_kwargs should not have normalized_shape")
[docs] def build( self, n_in_channels: int, n_out_channels: int, img_shape: tuple[int, int], ): return Samudra( input_channels=n_in_channels, output_channels=n_out_channels, ch_width=self.ch_width, dilation=self.dilation, n_layers=self.n_layers, pad=self.pad, norm=self.norm, norm_kwargs=self.norm_kwargs, upscale_factor=self.upscale_factor, )