@ModuleSelector.register("Samudra")
@dataclasses.dataclass
class SamudraBuilder(ModuleConfig):
    """
    Configuration for the M2Lines Samudra architecture.

    :meth:`build` forwards every field below, unchanged and under the same
    keyword name, to the ``Samudra`` constructor.

    Attributes:
        ch_width: Passed as ``ch_width`` (presumably the channel width of
            each stage -- confirm against ``Samudra``).
        n_layers: Passed as ``n_layers``.
        dilation: Passed as ``dilation``.
        pad: Passed as ``pad``; defaults to ``"circular"``.
        norm: Passed as ``norm``; defaults to ``"instance"``.
        norm_kwargs: Passed as ``norm_kwargs``. Must not contain
            ``num_features`` or ``normalized_shape``, which is checked in
            ``__post_init__``.
        upscale_factor: Passed as ``upscale_factor``.
        checkpoint_strategy: ``"all"``, ``"simple"`` or ``None``; passed as
            ``checkpoint_strategy``.
    """

    ch_width: list[int] = dataclasses.field(
        default_factory=lambda: [200, 250, 300, 400]
    )
    n_layers: list[int] = dataclasses.field(default_factory=lambda: [1, 1, 1, 1])
    dilation: list[int] = dataclasses.field(default_factory=lambda: [1, 2, 4, 8])
    pad: str = "circular"
    norm: str = "instance"
    norm_kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict)
    upscale_factor: int = 4
    checkpoint_strategy: Literal["all", "simple"] | None = None

    def __post_init__(self):
        # These keys are rejected up front. Presumably Samudra supplies them
        # itself when it creates each normalization layer (TODO confirm).
        # Order matters only for which error is reported first.
        for reserved in ("num_features", "normalized_shape"):
            if reserved in self.norm_kwargs:
                raise ValueError(f"norm_kwargs should not have {reserved}")

    def build(
        self,
        n_in_channels: int,
        n_out_channels: int,
        dataset_info: DatasetInfo,
    ):
        """
        Instantiate a ``Samudra`` module from this configuration.

        Args:
            n_in_channels: Forwarded as ``input_channels``.
            n_out_channels: Forwarded as ``output_channels``.
            dataset_info: Dataset metadata; only ``all_labels`` is inspected.

        Returns:
            The constructed ``Samudra`` module.

        Raises:
            ValueError: If ``dataset_info.all_labels`` is non-empty.
        """
        label_count = len(dataset_info.all_labels)
        if label_count:
            raise ValueError("Samudra does not support labels")

        # Architecture hyperparameters, keyed by Samudra's argument names.
        architecture = dict(
            ch_width=self.ch_width,
            dilation=self.dilation,
            n_layers=self.n_layers,
            pad=self.pad,
            norm=self.norm,
            norm_kwargs=self.norm_kwargs,
            upscale_factor=self.upscale_factor,
            checkpoint_strategy=self.checkpoint_strategy,
        )
        return Samudra(
            input_channels=n_in_channels,
            output_channels=n_out_channels,
            **architecture,
        )
@ModuleSelector.register("FloeNet")
@dataclasses.dataclass
class FloeNetBuilder(ModuleConfig):
    """
    Configuration for the M2Lines FloeNet architecture.

    :meth:`build` realizes this configuration as a ``GraphCast`` module.
    Every field below is forwarded to it as a keyword argument of the
    same name.
    """

    latent_dimension: int = 256
    activation: str = "SiLU"
    meshes: int = 6
    M0: int = 4
    bias: bool = True
    radius_fraction: float = 1.0
    layernorm: bool = True
    processor_steps: int = 4
    residual: bool = True
    is_ocean: bool = True

    def build(
        self,
        n_in_channels: int,
        n_out_channels: int,
        dataset_info: DatasetInfo,
    ):
        """
        Instantiate the ``GraphCast`` module that implements FloeNet.

        Args:
            n_in_channels: Forwarded as ``input_channels``.
            n_out_channels: Forwarded as ``output_channels``.
            dataset_info: Forwarded unchanged as ``dataset_info``.

        Returns:
            The constructed ``GraphCast`` module.

        Raises:
            ImportError: If the module-level ``GRAPHCAST_AVAIL`` flag is
                false, i.e. the optional GraphCast dependencies are missing.
        """
        if not GRAPHCAST_AVAIL:
            raise ImportError(
                "GraphCast dependencies (trimesh, rtree) not available."
            )

        # Model hyperparameters, keyed by GraphCast's argument names.
        hyperparameters = {
            "latent_dimension": self.latent_dimension,
            "activation": self.activation,
            "meshes": self.meshes,
            "M0": self.M0,
            "bias": self.bias,
            "radius_fraction": self.radius_fraction,
            "layernorm": self.layernorm,
            "processor_steps": self.processor_steps,
            "residual": self.residual,
            "is_ocean": self.is_ocean,
        }
        return GraphCast(
            input_channels=n_in_channels,
            output_channels=n_out_channels,
            dataset_info=dataset_info,
            **hyperparameters,
        )