import abc
import dataclasses
from typing import Any, Callable, ClassVar, Mapping, Tuple, TypeVar, Union

import dacite
from torch import nn

from .registry import Registry


@dataclasses.dataclass
class ModuleConfig(abc.ABC):
    """
    Abstract description of a network architecture that can build a nn.Module.

    Concrete subclasses are dataclasses whose fields hold the architecture
    details. It is called a "Config" because in practice it is loaded
    directly from yaml, so the architecture can be specified in a config file.
    """

    @abc.abstractmethod
    def build(
        self,
        n_in_channels: int,
        n_out_channels: int,
        img_shape: Tuple[int, int],
    ) -> nn.Module:
        """
        Construct the nn.Module described by this config.

        Args:
            n_in_channels: number of input channels
            n_out_channels: number of output channels
            img_shape: shape of the last two dimensions of the data,
                e.g. latitude and longitude.

        Returns:
            a nn.Module
        """
        ...

    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> "ModuleConfig":
        """
        Create an instance of this ModuleConfig class from a dictionary
        holding the values of its dataclass fields.

        Args:
            state: mapping of field names to values.

        Returns:
            an instance of ``cls`` populated from ``state``.
        """
        # strict=True asks dacite to reject keys that are not dataclass fields.
        strict_config = dacite.Config(strict=True)
        return dacite.from_dict(data_class=cls, data=state, config=strict_config)


MT = TypeVar("MT", bound=nn.Module)
@dataclasses.dataclass
class ModuleSelector:
    """
    A dataclass containing all the information needed to build a ModuleConfig,
    including the type of the ModuleConfig and the data needed to build it.

    This is helpful as ModuleSelector can be serialized and deserialized
    without any additional information, whereas to load a ModuleConfig you
    would need to know the type of the ModuleConfig being loaded.

    It is also convenient because ModuleSelector is a single class that can be
    used to represent any ModuleConfig, whereas ModuleConfig is an abstract
    interface that can be implemented by many different classes.

    Parameters:
        type: the type of the ModuleConfig
        config: data for a ModuleConfig instance of the indicated type
    """

    type: str
    config: Mapping[str, Any]
    # Class-level registry shared by every selector. As a ClassVar it is not a
    # dataclass field and therefore not an __init__ argument.
    registry: ClassVar[Registry] = Registry()

    def __post_init__(self):
        """
        Validate that the class-level registry has not been replaced.

        Raises:
            ValueError: if ``registry`` is no longer a Registry instance.
        """
        # The hook must be spelled ``__post_init__`` for dataclasses to call
        # it. The check uses isinstance rather than identity against a freshly
        # constructed ``Registry()``, which could never be the same object and
        # would therefore reject every instance.
        if not isinstance(self.registry, Registry):
            raise ValueError("ModuleSelector.registry should not be set manually")

    def build(
        self,
        n_in_channels: int,
        n_out_channels: int,
        img_shape: Tuple[int, int],
    ) -> nn.Module:
        """
        Build a nn.Module given information about the input and output
        channels and the image shape.

        Args:
            n_in_channels: number of input channels
            n_out_channels: number of output channels
            img_shape: last two dimensions of data, corresponding to lat and
                lon when using FourCastNet conventions

        Returns:
            a nn.Module
        """
        # The registry turns the {"type", "config"} state into an object
        # exposing ``build`` (presumably the registered ModuleConfig for
        # ``self.type`` -- confirm against Registry.from_dict).
        instance = self.registry.from_dict(self.get_state())
        return instance.build(
            n_in_channels=n_in_channels,
            n_out_channels=n_out_channels,
            img_shape=img_shape,
        )

    def get_state(self) -> Union[Mapping[str, Any], dict]:
        """
        Get a dictionary containing all the information needed to build a
        ModuleConfig.

        Returns:
            a dict with the keys "type" and "config".
        """
        return {"type": self.type, "config": self.config}

    @classmethod
    def from_state(cls, state: Mapping[str, Any]) -> "ModuleSelector":
        """
        Create a ModuleSelector from a dictionary containing all the
        information needed to build a ModuleConfig.

        Args:
            state: mapping with the keys "type" and "config", as returned by
                get_state.

        Returns:
            a ModuleSelector populated from ``state``.
        """
        # strict=True asks dacite to reject keys that are not dataclass fields.
        return dacite.from_dict(
            data_class=cls, data=state, config=dacite.Config(strict=True)
        )

    @classmethod
    def get_available_types(cls):
        """This class method is used to expose all available types of Modules."""
        # ``registry`` is a class attribute, so no throwaway instance is
        # needed to reach it.
        return cls.registry._types.keys()