import dataclasses
from collections.abc import Mapping
from typing import NewType

import torch

# Read-only view of named tensors; accept this in signatures that only read.
TensorMapping = Mapping[str, torch.Tensor]

# Concrete, mutable mapping of names to tensors.
TensorDict = dict[str, torch.Tensor]

# A distinct static type over TensorDict; at runtime NewType simply returns
# its argument unchanged, so this only affects type checking.
EnsembleTensorDict = NewType("EnsembleTensorDict", TensorDict)
EnsembleTensorDict.__doc__ = """
A dictionary of tensors with an explicit ensemble (sample) dimension, where
ensemble members represent multiple predictions for the same initial condition.

The ensemble dimension is the second dimension of the tensors,
while the batch dimension is the first.
"""
[docs]@dataclasses.dataclassclassSlice:""" Configuration of a python `slice` built-in. Required because `slice` cannot be initialized directly by dacite. Parameters: start: Start index of the slice. stop: Stop index of the slice. step: Step of the slice. """start:int|None=Nonestop:int|None=Nonestep:int|None=None@propertydefslice(self)->slice:returnslice(self.start,self.stop,self.step)defcontains(self,value:int)->bool:start=self.startifself.startisnotNoneelse0stop=self.stopifself.stopisnotNoneelsefloat("inf")step=self.stepifself.stepisnotNoneelse1returnstart<=value<stopand(value-start)%step==0