importdataclassesfromcollections.abcimportMappingfromtypingimportNewTypeimporttorchTensorMapping=Mapping[str,torch.Tensor]TensorDict=dict[str,torch.Tensor]EnsembleTensorDict=NewType("EnsembleTensorDict",TensorDict)EnsembleTensorDict.__doc__="""A dictionary of tensors with an explicit ensemble (sample) dimension, whereensemble members represent multiple predictions for the same initial condition.The ensemble dimension is the second dimension of the tensors,while the batch dimension is the first."""def_shift_bound(value:int|None,shift:int,out_of_bounds_value:int|None)->int|None:""" Shifts a bounding value of a slice relative to a starting index where positive shift will shift the bound left (decrease the index), and negative shift will shift the bound right (increase the index). When shifting left, if the shifted value is less than 0, it is considered out of bounds and replaced with `out_of_bounds_value`. If the value is None, it remains None. Negative initial bound values are not supported. """ifvalueisNone:returnNoneelifvalue<0:raiseValueError("Negative slice bounds as an initial value are not supported")shifted=value-shiftifshifted<0:returnout_of_bounds_valuereturnshifted
[docs]@dataclasses.dataclassclassSlice:""" Configuration of a python `slice` built-in. Required because `slice` cannot be initialized directly by dacite. Parameters: start: Start index of the slice. stop: Stop index of the slice. step: Step of the slice. """start:int|None=Nonestop:int|None=Nonestep:int|None=None@propertydefslice(self)->slice:returnslice(self.start,self.stop,self.step)defcontains(self,value:int)->bool:start=self.startifself.startisnotNoneelse0stop=self.stopifself.stopisnotNoneelsefloat("inf")step=self.stepifself.stepisnotNoneelse1returnstart<=value<stopand(value-start)%step==0
[docs]@classmethoddefshift_left(cls,original:"Slice",start_index:int)->"Slice":""" Shift the slice relative to the start index of a group of data to capture requested correct quantities while still respecting batches. E.g., If slice is (0, 10, 1) and start_index is 5, the new slice would be (None, 5, 1). Raises: ValueError: If trying to shift negative valued slice object, since that is not defined without knowing the total sequence length. """new_start=_shift_bound(original.start,start_index,None)new_stop=_shift_bound(original.stop,start_index,0)returncls(start=new_start,stop=new_stop,step=original.step)