# Source code for fme.core.packer

from typing import List

import torch
import torch.jit

from fme.core.typing_ import TensorDict


class DataShapesNotUniform(ValueError):
    """Raised when tensors that must share a single shape turn out not to.

    Subclasses ``ValueError`` so callers handling generic bad-value errors
    also catch it.
    """


class Packer:
    """
    Responsible for packing tensors into a single tensor.
    """

    def __init__(self, names: List[str]):
        # The order of ``names`` fixes each variable's index along the packed
        # axis; both ``pack`` and ``unpack`` rely on this same ordering.
        self.names = names

    def pack(self, tensors: TensorDict, axis=0) -> torch.Tensor:
        """
        Packs tensors into a single tensor, concatenated along a new axis.

        Only the entries named in ``self.names`` are packed, in that order,
        but every entry of ``tensors`` takes part in the shape check.

        Args:
            tensors: Dict from names to tensors.
            axis: index for new concatenation axis.

        Returns:
            A tensor with one more dimension than the inputs, whose ``axis``
            dimension indexes ``self.names``.

        Raises:
            DataShapesNotUniform: when packed tensors do not all have the
                same shape.
        """
        # NOTE: an empty ``tensors`` dict makes ``next`` raise StopIteration.
        shape = next(iter(tensors.values())).shape
        for name in tensors:
            if tensors[name].shape != shape:
                raise DataShapesNotUniform(
                    (
                        f"Cannot pack tensors of different shapes. "
                        # Must be an f-string, otherwise the shapes are not
                        # interpolated into the message.
                        f'Expected "{shape}" got "{tensors[name].shape}"'
                    )
                )
        return _pack(tensors, self.names, axis=axis)

    def unpack(self, tensor: torch.Tensor, axis=0) -> TensorDict:
        """
        Splits a packed tensor back into a dict keyed by ``self.names``.

        Args:
            tensor: packed tensor, e.g. as produced by ``pack``.
            axis: index of the axis along which the names are stacked.

        Returns:
            Dict mapping ``self.names[i]`` to ``tensor.select(axis, i)``.
        """
        return _unpack(tensor, self.names, axis=axis)
@torch.jit.script
def _pack(tensors: TensorDict, names: List[str], axis: int = 0) -> torch.Tensor:
    """Stack ``tensors[n]`` for each ``n`` in ``names`` along a new ``axis``.

    ``torch.stack`` is the dedicated op for this; it is equivalent to
    concatenating the ``unsqueeze(axis)``-ed tensors along ``axis``.
    """
    return torch.stack([tensors[n] for n in names], dim=axis)


@torch.jit.script
def _unpack(tensor: torch.Tensor, names: List[str], axis: int = 0) -> TensorDict:
    """Inverse of ``_pack``: map ``names[i]`` to ``tensor.select(axis, i)``."""
    return {n: tensor.select(axis, index=i) for i, n in enumerate(names)}