import torch
import torch.jit

from fme.core.typing_ import TensorDict


class DataShapesNotUniform(ValueError):
    """Raised when tensors that are required to share one shape do not."""
class Packer:
    """
    Responsible for packing tensors into a single tensor.
    """

    def __init__(self, names: list[str]):
        # Names of the tensors to pack; forwarded to _pack by pack().
        self.names = names

    def pack(self, tensors: TensorDict, axis=0) -> torch.Tensor:
        """
        Packs tensors into a single tensor, concatenated along a new axis.

        Args:
            tensors: Dict from names to tensors.
            axis: index for new concatenation axis.

        Returns:
            The result of ``_pack(tensors, self.names, axis=axis)``.

        Raises:
            ValueError: when ``tensors`` is empty, so there is no reference
                shape to validate against.
            DataShapesNotUniform: when the tensors in ``tensors`` do not all
                have the same shape.
        """
        if not tensors:
            # Without this guard, next(iter(...)) below would raise a bare
            # StopIteration, which is confusing for callers and is not a
            # ValueError.
            raise ValueError("Cannot pack an empty dict of tensors.")
        # Every tensor in the dict (not only those listed in self.names) is
        # checked against the shape of the first tensor in iteration order.
        shape = next(iter(tensors.values())).shape
        for tensor in tensors.values():
            if tensor.shape != shape:
                raise DataShapesNotUniform(
                    f"Cannot pack tensors of different shapes. "
                    f'Expected "{shape}" got "{tensor.shape}"'
                )
        return _pack(tensors, self.names, axis=axis)