Source code for mfai.pytorch.namedtensor

"""
A class-based NamedTensor implementation for PyTorch, inspired by PyTorch's experimental named tensors.
"""

from collections.abc import Iterable
from copy import deepcopy
from dataclasses import dataclass
from itertools import chain
from typing import Any, Sequence, Union

import einops
import torch
from tabulate import tabulate
from torch import Tensor


@dataclass(slots=True)
class TensorWrapper:
    """
    Wrapper around a torch tensor.

    We use this separate dataclass so that lightning's introspection can see our
    batch size and move our tensors to the right device; otherwise we get this
    error/warning: "Trying to infer the `batch_size` from an ambiguous collection ...".
    """

    tensor: Tensor


class NamedTensor(TensorWrapper):
    """
    NamedTensor is a wrapper around a torch tensor adding several attributes:

    * a 'names' attribute with the names of the tensor's dimensions
      (like https://pytorch.org/docs/stable/named_tensor.html).
      Torch's named tensors are still experimental and subject to change.

    * a 'feature_names' attribute containing the names of the features
      along the last dimension of the tensor.

    NamedTensors can be concatenated along the last dimension using the | operator:
    nt3 = nt1 | nt2
    """

    SPATIAL_DIM_NAMES = ("lat", "lon", "ngrid")

    def __init__(
        self,
        tensor: Tensor,
        names: Sequence[str],
        feature_names: Sequence[str],
        feature_dim_name: str = "features",
    ):
        if len(tensor.shape) != len(names):
            raise ValueError(
                f"Number of names ({len(names)}) must match number of dimensions ({len(tensor.shape)})"
            )
        if tensor.shape[names.index(feature_dim_name)] != len(feature_names):
            raise ValueError(
                f"Number of feature names ({len(feature_names)}:{feature_names}) must match "
                f"number of features ({tensor.shape[names.index(feature_dim_name)]}) in the supplied tensor"
            )

        super().__init__(tensor)
        self.names: list[str] = list(names)
        # build lookup table for fast indexing
        self.feature_names_to_idx = {
            feature_name: idx for idx, feature_name in enumerate(feature_names)
        }
        self.feature_names: list[str] = list(feature_names)
        self.feature_dim_name = feature_dim_name

    @property
    def ndims(self) -> int:
        """
        Number of dimensions of the tensor.
        """
        return len(self.names)

    @property
    def num_spatial_dims(self) -> int:
        """
        Number of spatial dimensions of the tensor.
        """
        return len([x for x in self.names if x in self.SPATIAL_DIM_NAMES])

    @property
    def feature_dim_idx(self) -> int:
        """
        Index of the features dimension.
        """
        return self.names.index(self.feature_dim_name)

    def __str__(self) -> str:
        head = "--- NamedTensor ---\n"
        head += f"Names: {self.names}\nTensor Shape: {self.tensor.shape}\nFeatures:\n"
        table = [
            [feature, self[feature].min(), self[feature].max()]
            for feature in self.feature_names
        ]
        headers = ["Feature name", "Min", "Max"]
        table_string = str(tabulate(table, headers=headers, tablefmt="simple_outline"))
        return head + table_string

    def __or__(self, other: Union["NamedTensor", None]) -> "NamedTensor":
        """
        Concatenate two NamedTensors along the features dimension.
        """
        if other is None:
            return self
        if not isinstance(other, NamedTensor):
            raise ValueError("Can only concatenate NamedTensor with NamedTensor")

        # check feature names are distinct between the two tensors
        if set(self.feature_names) & set(other.feature_names):
            raise ValueError(
                f"Feature names must be distinct between the two tensors for "
                f"unambiguous concat, self:{self.feature_names} other:{other.feature_names}"
            )
        if self.names != other.names:
            raise ValueError(
                f"NamedTensors must have the same dimension names to concatenate, self:{self.names} other:{other.names}"
            )
        try:
            return NamedTensor(
                torch.cat([self.tensor, other.tensor], dim=self.feature_dim_idx),
                self.names.copy(),
                self.feature_names + other.feature_names,
            )
        except Exception as e:
            raise ValueError(f"Error while concatenating {self} and {other}") from e

    def __ror__(self, other: Union["NamedTensor", None]) -> "NamedTensor":
        return self.__or__(other)
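
    # Illustrative usage sketch (not part of the original module): constructing
    # two NamedTensors over the same (batch, lat, lon, features) grid and
    # concatenating them along the features dimension with the | operator.
    # Shapes and feature names below are assumptions for the example.
    #
    # >>> nt1 = NamedTensor(
    # ...     torch.rand(4, 64, 64, 2),
    # ...     names=["batch", "lat", "lon", "features"],
    # ...     feature_names=["u", "v"],
    # ... )
    # >>> nt2 = NamedTensor(
    # ...     torch.rand(4, 64, 64, 1),
    # ...     names=["batch", "lat", "lon", "features"],
    # ...     feature_names=["t2m"],
    # ... )
    # >>> nt3 = nt1 | nt2
    # >>> nt3.tensor.shape
    # torch.Size([4, 64, 64, 3])
    # >>> nt3.feature_names
    # ['u', 'v', 't2m']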

    @staticmethod
    def stack(
        nts: Sequence["NamedTensor"], dim_name: str, dim: int = 0
    ) -> "NamedTensor":
        """
        Stack a list of NamedTensors along a new dimension.
        """
        if len(nts) == 0:
            raise ValueError("Cannot stack an empty list of NamedTensors")
        if len(nts) == 1:
            return nts[0].clone()
        else:
            # Check feature names are identical between the n named tensors
            if not all(nt.feature_names == nts[0].feature_names for nt in nts):
                raise ValueError(
                    "NamedTensors must have the same feature names to stack"
                )
            # Check that all named tensors have the same dim names
            if not all(nt.names == nts[0].names for nt in nts):
                raise ValueError(
                    "NamedTensors must have the same dimension names to stack"
                )

            # define new list of dim names, with new dim name inserted at dim
            names = nts[0].names.copy()
            names.insert(dim, dim_name)
            new_tensor = torch.stack([nt.tensor for nt in nts], dim=dim)
            return NamedTensor(new_tensor, names, nts[0].feature_names.copy())

    @staticmethod
    def concat(nts: Sequence["NamedTensor"]) -> "NamedTensor":
        """
        Safely concat a list of NamedTensors along the last dimension in one shot.
        """
        if len(nts) == 0:
            raise ValueError("Cannot concatenate an empty list of NamedTensors")
        if len(nts) == 1:
            return nts[0].clone()
        else:
            # Check feature names are distinct between the n named tensors
            feature_names: set[str] = set()
            for nt in nts:
                if feature_names & set(nt.feature_names):
                    raise ValueError(
                        f"Feature names must be distinct between the named tensors to concat\n"
                        f"Found duplicates: {feature_names & set(nt.feature_names)}"
                    )
                feature_names |= set(nt.feature_names)

            # Check that all named tensors have the same names
            if not all(nt.names == nts[0].names for nt in nts[1:]):
                raise ValueError(
                    "NamedTensors must have the same dimension names to concatenate"
                )
            # Check that all named tensors have the same feature dimension name
            if not all(
                nt.feature_dim_name == nts[0].feature_dim_name for nt in nts[1:]
            ):
                raise ValueError(
                    "NamedTensors must have the same feature dimension name to concatenate"
                )
            # Concat in one shot
            return NamedTensor(
                torch.cat([nt.tensor for nt in nts], dim=nts[0].feature_dim_idx),
                nts[0].names.copy(),
                list(chain.from_iterable(nt.feature_names for nt in nts)),
                feature_dim_name=nts[0].feature_dim_name,
            )
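
    # Illustrative sketch of the two static constructors above (assumed shapes):
    # stack inserts a brand-new dimension, while concat merges along the features
    # axis and requires pairwise-distinct feature names.
    #
    # >>> members = [
    # ...     NamedTensor(torch.rand(64, 64, 2), ["lat", "lon", "features"], ["u", "v"])
    # ...     for _ in range(3)
    # ... ]
    # >>> stacked = NamedTensor.stack(members, dim_name="member", dim=0)
    # >>> stacked.names
    # ['member', 'lat', 'lon', 'features']
    #
    # >>> parts = [
    # ...     NamedTensor(torch.rand(64, 64, 1), ["lat", "lon", "features"], [name])
    # ...     for name in ("u", "v", "t2m")
    # ... ]
    # >>> NamedTensor.concat(parts).feature_names
    # ['u', 'v', 't2m']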

    def dim_index(self, dim_name: str) -> int:
        """
        Return the index of a dimension given its name.
        """
        return self.names.index(dim_name)

    def clone(self) -> "NamedTensor":
        """Clone (with a deepcopy) the NamedTensor."""
        return NamedTensor(
            tensor=deepcopy(self.tensor).to(self.tensor.device),
            names=self.names.copy(),
            feature_names=self.feature_names.copy(),
        )

    def __getitem__(self, feature_name: str) -> Tensor:
        """
        Get one feature from the features dimension of the tensor by name.
        The returned tensor has the same number of dimensions as the original tensor.
        """
        try:
            return self.tensor.select(
                self.names.index(self.feature_dim_name),
                self.feature_names_to_idx[feature_name],
            ).unsqueeze(self.names.index(self.feature_dim_name))
        except KeyError:
            raise ValueError(
                f"Feature {feature_name} not found in {self.feature_names}"
            )
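
    # Sketch of feature access by name (assumed shapes and feature names):
    # __getitem__ keeps the number of dimensions, the features axis becomes size 1.
    #
    # >>> nt = NamedTensor(
    # ...     torch.rand(4, 64, 64, 2), ["batch", "lat", "lon", "features"], ["u", "v"]
    # ... )
    # >>> nt["u"].shape
    # torch.Size([4, 64, 64, 1])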

    def type_(self, new_type: str | torch.dtype) -> None:
        """
        Modify the type of the underlying torch tensor by calling torch's .type method.

        In-place operation for this class: the internal tensor is replaced by the new one.
        """
        self.tensor = self.tensor.type(new_type)

    def flatten_(
        self, flatten_dim_name: str, start_dim: int = 0, end_dim: int = -1
    ) -> None:
        """
        Flatten the underlying tensor from start_dim to end_dim.
        Deletes the flattened dimension names and inserts the new one.
        """
        self.tensor = torch.flatten(self.tensor, start_dim, end_dim)

        # Remove the flattened dimensions from the names
        # and insert the replacing one
        end_dim = len(self.names) if end_dim == -1 else end_dim
        self.names = (
            self.names[:start_dim] + [flatten_dim_name] + self.names[end_dim + 1 :]
        )

    def unflatten_(
        self, dim: int, unflattened_size: torch.Size, unflatten_dim_name: Sequence[str]
    ) -> None:
        """
        Unflatten the dimension dim of the underlying tensor,
        inserting the unflattened_size dimensions (named unflatten_dim_name) instead.
        """
        self.tensor = self.tensor.unflatten(dim, unflattened_size)
        self.names = self.names[:dim] + [*unflatten_dim_name] + self.names[dim + 1 :]
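
    # Sketch of a flatten/unflatten round trip (assumed shapes): "lat" and "lon"
    # are merged into a single "ngrid" dimension, then restored.
    #
    # >>> nt = NamedTensor(
    # ...     torch.rand(4, 64, 32, 2), ["batch", "lat", "lon", "features"], ["u", "v"]
    # ... )
    # >>> nt.flatten_("ngrid", start_dim=1, end_dim=2)
    # >>> nt.names, tuple(nt.tensor.shape)
    # (['batch', 'ngrid', 'features'], (4, 2048, 2))
    # >>> nt.unflatten_(1, torch.Size((64, 32)), ["lat", "lon"])
    # >>> nt.names
    # ['batch', 'lat', 'lon', 'features']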

    def squeeze_(self, dim_name: Union[Sequence[str], str]) -> None:
        """
        Squeeze the underlying tensor along the dimension(s) given its/their name(s).
        """
        if isinstance(dim_name, str):
            dim_name = [dim_name]
        dim_indices = [self.names.index(name) for name in dim_name]
        self.tensor = torch.squeeze(self.tensor, dim=dim_indices)
        for name in dim_name:
            self.names.remove(name)

    def unsqueeze_(self, dim_name: str, dim_index: int) -> None:
        """
        Insert a new dimension dim_name of size 1 at dim_index.
        """
        self.tensor = torch.unsqueeze(self.tensor, dim_index)
        self.names.insert(dim_index, dim_name)
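
    # Sketch of the in-place squeeze/unsqueeze helpers (assumed shapes): a new
    # size-1 "member" dimension is inserted at index 1, then removed again.
    #
    # >>> nt = NamedTensor(
    # ...     torch.rand(4, 64, 64, 2), ["batch", "lat", "lon", "features"], ["u", "v"]
    # ... )
    # >>> nt.unsqueeze_("member", dim_index=1)
    # >>> nt.names
    # ['batch', 'member', 'lat', 'lon', 'features']
    # >>> nt.squeeze_("member")
    # >>> nt.names
    # ['batch', 'lat', 'lon', 'features']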

    def select_dim(self, dim_name: str, index: int) -> "NamedTensor":
        """
        Return the tensor indexed along the dimension dim_name with the index index.
        The given dimension is removed from the tensor.
        See https://pytorch.org/docs/stable/generated/torch.select.html.
        """
        if dim_name == self.feature_dim_name:
            raise ValueError(
                "Impossible to select the feature dimension of a NamedTensor."
            )
        return NamedTensor(
            self.tensor.select(self.names.index(dim_name), index),
            self.names[: self.names.index(dim_name)]
            + self.names[self.names.index(dim_name) + 1 :],
            self.feature_names,
            feature_dim_name=self.feature_dim_name,
        )

    def select_tensor_dim(self, dim_name: str, index: int) -> Tensor:
        """
        Same as select_dim but returns a Tensor.
        Allows the selection of the feature dimension.
        """
        return self.tensor.select(self.names.index(dim_name), index)

    def index_select_dim(self, dim_name: str, indices: Sequence[int]) -> "NamedTensor":
        """
        Return the tensor indexed along the dimension dim_name with the indices tensor.
        The returned tensor has the same number of dimensions as the original tensor (input).
        The dim-th dimension has the same size as the length of indices; the other
        dimensions have the same size as in the original tensor.
        See https://pytorch.org/docs/stable/generated/torch.index_select.html.
        """
        return NamedTensor(
            self.tensor.index_select(
                self.names.index(dim_name),
                Tensor(indices).type(torch.int64).to(self.device),
            ),
            self.names,
            (
                self.feature_names
                if dim_name != self.feature_dim_name
                else [self.feature_names[i] for i in indices]
            ),
            feature_dim_name=self.feature_dim_name,
        )
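
    # Sketch of the selection helpers (assumed shapes): select_dim drops the
    # indexed dimension, index_select_dim keeps it and re-derives the feature
    # names when the features dimension is targeted.
    #
    # >>> nt = NamedTensor(
    # ...     torch.rand(4, 64, 64, 2), ["batch", "lat", "lon", "features"], ["u", "v"]
    # ... )
    # >>> nt.select_dim("batch", 0).names
    # ['lat', 'lon', 'features']
    # >>> nt.index_select_dim("features", [1]).feature_names
    # ['v']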

    def index_select_tensor_dim(
        self, dim_name: str, indices: Sequence[int]
    ) -> torch.Tensor:
        """
        Same as index_select_dim but returns a Tensor.
        """
        return self.tensor.index_select(
            self.names.index(dim_name),
            Tensor(indices).type(torch.int64).to(self.device),
        )

    def dim_size(self, dim_name: str) -> int:
        """
        Return the size of a dimension given its name.
        """
        try:
            return self.tensor.size(self.names.index(dim_name))
        except ValueError as ve:
            raise ValueError(f"Dimension {dim_name} not found in {self.names}") from ve

    @property
    def spatial_dim_idx(self) -> list[int]:
        """
        Return the indices of the spatial dimensions in the tensor.
        """
        return sorted(
            self.names.index(name)
            for name in set(self.SPATIAL_DIM_NAMES).intersection(set(self.names))
        )

    def unsqueeze_and_expand_from_(self, other: "NamedTensor") -> None:
        """
        Unsqueeze and expand the tensor to have the same number of spatial dimensions
        as another NamedTensor.
        Injects new dimensions where the missing names are.
        """
        missing_names = set(other.names) - set(self.names)
        missing_names &= set(self.SPATIAL_DIM_NAMES)

        if missing_names:
            index_to_unsqueeze = [
                (name, other.names.index(name)) for name in missing_names
            ]
            for name, idx in sorted(index_to_unsqueeze, key=lambda x: x[1]):
                self.tensor = torch.unsqueeze(self.tensor, idx)
                self.names.insert(idx, name)

            expander = []
            for name in self.names:
                expander.append(other.dim_size(name) if name in missing_names else -1)

            self.tensor = self.tensor.expand(*expander)
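
    # Sketch of spatial broadcasting (assumed shapes): a tensor without spatial
    # dimensions is unsqueezed and expanded to match another tensor's lat/lon grid.
    #
    # >>> grid = NamedTensor(
    # ...     torch.rand(4, 64, 32, 2), ["batch", "lat", "lon", "features"], ["u", "v"]
    # ... )
    # >>> flat = NamedTensor(torch.rand(4, 3), ["batch", "features"], ["a", "b", "c"])
    # >>> flat.unsqueeze_and_expand_from_(grid)
    # >>> flat.names, tuple(flat.tensor.shape)
    # (['batch', 'lat', 'lon', 'features'], (4, 64, 32, 3))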

    def iter_dim(self, dim_name: str) -> Iterable["NamedTensor"]:
        """
        Iterate over the tensor along a given dimension, yielding NamedTensors.
        """
        for i in range(self.dim_size(dim_name)):
            yield self.select_dim(dim_name, i)

    def iter_tensor_dim(self, dim_name: str) -> Iterable[Tensor]:
        """
        Iterate over the tensor along a given dimension, yielding raw Tensors.
        """
        for i in range(self.dim_size(dim_name)):
            yield self.select_tensor_dim(dim_name, i)

    def rearrange_(self, einops_str: str) -> None:
        """
        Rearrange in place the underlying tensor dimensions using einops syntax.
        For now only supports re-ordering of dimensions.
        """
        old_dims_str, new_dims_str = einops_str.split("->")
        old_dims = old_dims_str.split(" ")[:-1]
        new_dims = new_dims_str.split(" ")[1:]

        # check that the number of dims and dim names match
        if not set(self.names) == set(old_dims) == set(new_dims):
            raise ValueError(
                f"Dimensions in rearrange_ {old_dims} do not match tensor dimensions {self.names}"
            )
        self.tensor = einops.rearrange(self.tensor, einops_str)
        self.names = new_dims
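
    # Sketch of an in-place dimension permutation via einops (assumed shapes);
    # only re-ordering of the existing dimension names is supported.
    #
    # >>> nt = NamedTensor(
    # ...     torch.rand(4, 64, 32, 2), ["batch", "lat", "lon", "features"], ["u", "v"]
    # ... )
    # >>> nt.rearrange_("batch lat lon features -> batch features lat lon")
    # >>> nt.names, tuple(nt.tensor.shape)
    # (['batch', 'features', 'lat', 'lon'], (4, 2, 64, 32))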

    @staticmethod
    def new_like(tensor: Tensor, other: "NamedTensor") -> "NamedTensor":
        """
        Create a new NamedTensor with the same names and feature names as another NamedTensor.
        """
        return NamedTensor(tensor, other.names.copy(), other.feature_names.copy())

    @staticmethod
    def expand_to_batch_like(tensor: Tensor, other: "NamedTensor") -> "NamedTensor":
        """
        Create a new NamedTensor with the same names and feature names as another
        NamedTensor, with an extra first dimension called 'batch', using the supplied tensor.
        The supplied 'batched' tensor must have one more dimension than other.
        """
        names = ["batch"] + other.names
        if tensor.dim() != len(names):
            raise ValueError(
                f"Tensor dim {tensor.dim()} must match number of names {len(names)} with extra batch dim"
            )
        return NamedTensor(tensor, names, other.feature_names.copy())
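
    # Sketch of the batch-expansion constructor (assumed shapes): a raw tensor
    # with one extra leading dimension borrows names and feature names from a
    # template NamedTensor.
    #
    # >>> template = NamedTensor(torch.rand(64, 64, 2), ["lat", "lon", "features"], ["u", "v"])
    # >>> batched = NamedTensor.expand_to_batch_like(torch.rand(8, 64, 64, 2), template)
    # >>> batched.names
    # ['batch', 'lat', 'lon', 'features']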

    @property
    def device(self) -> torch.device:
        return self.tensor.device

    def pin_memory_(self) -> None:
        """
        'In place' operation to pin the underlying tensor to memory.
        """
        self.tensor = self.tensor.pin_memory()

    def to_(self, *args: Any, **kwargs: Any) -> None:
        """
        'In place' operation to call torch's 'to' method on the underlying tensor.
        """
        self.tensor = self.tensor.to(*args, **kwargs)

    @staticmethod
    def collate_fn(
        batch: Sequence["NamedTensor"],
        pad_dims: tuple[str, ...] | tuple[()] = (),
        pad_value: int | float = 0,
    ) -> "NamedTensor":
        """
        Collate a list of NamedTensors into a batched single NamedTensor.
        Optionally pads the dimensions specified in pad_dims with pad_value.
        """
        if len(batch) == 0:
            raise ValueError("Cannot collate an empty list of NamedTensors")

        if len(batch) == 1:
            # add batch dim to the single namedtensor (in place operation)
            batch[0].unsqueeze_(dim_name="batch", dim_index=0)
            return batch[0]
        else:
            if len(pad_dims):
                # Find the maximum size for each dimension in pad_dims looking
                # at each tensor of each NamedTensor in the batch
                max_sizes: dict[str, int] = {
                    dim: max(nt.dim_size(dim) for nt in batch) for dim in pad_dims
                }

                # Pad each tensor in the batch to the maximum size, for each dim we pad after the data
                padded_batch = []
                for nt in batch:
                    padded_shape = list(nt.tensor.shape)
                    do_pad: bool = False
                    for dim in pad_dims:
                        dim_idx = nt.dim_index(dim)
                        if nt.dim_size(dim) < max_sizes[dim]:
                            padding_size = max_sizes[dim] - nt.dim_size(dim)
                            padded_shape[dim_idx] += padding_size
                            do_pad = True
                    if do_pad:
                        padded_tensor = nt.tensor.new_full(
                            padded_shape,
                            pad_value,
                        )
                        slicer = tuple(slice(0, nt.dim_size(dim)) for dim in nt.names)
                        padded_tensor[slicer] = nt.tensor
                    else:
                        padded_tensor = nt.tensor.clone()
                    padded_batch.append(NamedTensor.new_like(padded_tensor, nt))
                return NamedTensor.stack(padded_batch, dim_name="batch", dim=0)
            else:
                return NamedTensor.stack(batch, dim_name="batch", dim=0)
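

# Sketch of collation with padding (assumed shapes): samples whose "lat" sizes
# differ are right-padded with pad_value up to the largest size, then stacked
# along a new leading "batch" dimension.
#
# >>> samples = [
# ...     NamedTensor(torch.rand(60, 64, 2), ["lat", "lon", "features"], ["u", "v"]),
# ...     NamedTensor(torch.rand(64, 64, 2), ["lat", "lon", "features"], ["u", "v"]),
# ... ]
# >>> batch = NamedTensor.collate_fn(samples, pad_dims=("lat",), pad_value=0.0)
# >>> batch.names, tuple(batch.tensor.shape)
# (['batch', 'lat', 'lon', 'features'], (2, 64, 64, 2))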