unetrpp

UNetR++ vision transformer, based on: Shaker et al., “UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation”. Adapted from https://github.com/Amshaker/unetr_plus_plus, with added 2d support and bilinear interpolation for upsampling.

class mfai.pytorch.models.unetrpp.EPA(input_size, hidden_size, num_heads=4, qkv_bias=False, channel_attn_drop=0.1, spatial_attn_drop=0.1, proj_size=64, attention_code='torch')[source]

Bases: Module

Efficient Paired Attention block, based on: “Shaker et al., UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation”.

Modifications:
  • adds compatibility with 2d inputs

  • adds an option to use torch’s scaled dot product instead of the original implementation. This should enable the use of flash attention in the future.

Parameters:
  • input_size (int)

  • hidden_size (int)

  • num_heads (int)

  • qkv_bias (bool)

  • channel_attn_drop (float)

  • spatial_attn_drop (float)

  • proj_size (int)

  • attention_code (str)
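As a rough illustration of the "scaled dot product" mentioned above, the sketch below computes softmax(QKᵀ / sqrt(d)) V on small nested lists, which is the default computation of torch's scaled dot product attention. It is a pure-Python sketch for intuition only: the helper names are hypothetical, and it does not reproduce the paired channel and spatial attention of the EPA block itself.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(value - peak) for value in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q, k, v):
    """Compute softmax(q @ k^T / sqrt(d)) @ v for 2D lists (tokens x features)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for q_row in q:
        # Similarity of this query with every key, scaled by 1 / sqrt(d).
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        # Weighted average of the value rows.
        out.append(
            [sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))]
        )
    return out


q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
attended = scaled_dot_product_attention(q, k, v)
```

Each output row is a weighted average of the rows of v, with more weight on the value whose key best matches the query.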

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

x (Tensor)

Return type:

Tensor
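The note about calling the module instance rather than forward can be illustrated with a small stand-in. TinyModule below is a hypothetical pure-Python class, not torch.nn.Module; it only mimics the idea that the call operator runs registered hooks while a direct forward call skips them.

```python
class TinyModule:
    """Hypothetical stand-in: __call__ runs hooks, forward alone does not."""

    def __init__(self):
        self.forward_hooks = []

    def forward(self, x):
        return x * 2

    def __call__(self, x):
        out = self.forward(x)
        # Hooks only run on this path, never when forward is called directly.
        for hook in self.forward_hooks:
            hook(self, x, out)
        return out


calls = []
module = TinyModule()
module.forward_hooks.append(lambda mod, inp, out: calls.append((inp, out)))

via_call = module(3)             # the hook records this call
via_forward = module.forward(3)  # the hook is silently skipped
```

Both calls return the same value, but only the first one is seen by the hook.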

no_weight_decay()[source]
Return type:

set[str]

class mfai.pytorch.models.unetrpp.LayerNorm(normalized_shape, eps=1e-06, data_format='channels_last')[source]

Bases: Module

Parameters:
forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

x (Tensor)

Return type:

Tensor

class mfai.pytorch.models.unetrpp.TransformerBlock(input_size, hidden_size, num_heads, dropout_rate=0.0, pos_embed=False, spatial_dims=2, proj_size=64, attention_code='torch')[source]

Bases: Module

A transformer block, based on: “Shaker et al., UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation” Modified to work with both 2d and 3d data (spatial_dims).

Parameters:
  • input_size (int)

  • hidden_size (int)

  • num_heads (int)

  • dropout_rate (float)

  • pos_embed (bool)

  • spatial_dims (int)

  • proj_size (int)

  • attention_code (str)

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

x (Tensor)

Return type:

Tensor

class mfai.pytorch.models.unetrpp.UNetRPP(in_channels, out_channels, input_shape, settings=UNetRPPSettings(hidden_size=256, num_heads_encoder=4, num_heads_decoder=4, pos_embed='perceptron', norm_name='instance', dropout_rate=0.0, depths=(3, 3, 3, 3), conv_op='Conv2d', linear_upsampling=False, downsampling_rate=4, decoder_proj_size=64, encoder_proj_sizes=(64, 64, 64, 32), autopad_enabled=False, add_skip_connections=True, attention_code='torch'))[source]

Bases: BaseModel, AutoPaddingModel

UNetR++ based on: “Shaker et al., UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation”.

Parameters:
features_last: bool = False
forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

x (Tensor)

Return type:

Tensor | list[Tensor]

model_type = 3
property num_spatial_dims: int

Returns the number of spatial dimensions of the instantiated model.

onnx_supported = False
proj_feat(x)[source]
Return type:

Tensor

Parameters:

x (Tensor)

register: bool = True
property settings: UNetRPPSettings

Returns the settings instance used to configure this model.

settings_kls

alias of UNetRPPSettings

supported_num_spatial_dims = (2, 3)
validate_input_shape(input_shape)[source]
Given an input shape, verifies whether the inputs fit with the calling model’s specifications.

Parameters:

input_shape (Size) – The shape of the input data, excluding any batch dimension and channel dimension. For example, for a batch of 2D tensors of shape [B,C,W,H], [W,H] should be passed. For 3D data of shape [B,C,W,H,D], [W,H,D] should be passed instead.

Returns:

Returns a tuple where the first element is a boolean signaling whether the given input shape already fits the model’s requirements. If that value is False, the second element contains the closest shape that fits the model, otherwise it will be None.

Return type:

tuple[bool, Size]
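The return contract described above can be sketched as follows. This is an illustration under stated assumptions, not the library's logic: closest_valid_shape is a hypothetical helper, the divisor of 32 is an arbitrary example value, and rounding each dimension up to the next multiple is only one plausible definition of "closest".

```python
import math


def closest_valid_shape(input_shape, divisor):
    """Return (fits, closest_shape_or_None) for a hypothetical divisibility rule."""
    # Assumption: a shape fits when every dimension is a multiple of `divisor`.
    rounded = tuple(divisor * math.ceil(size / divisor) for size in input_shape)
    fits = rounded == tuple(input_shape)
    return fits, (None if fits else rounded)


already_ok = closest_valid_shape((64, 64), divisor=32)
needs_padding = closest_valid_shape((100, 70), divisor=32)
```

Here already_ok is (True, None), while needs_padding is (False, (128, 96)), mirroring the documented pair of a boolean and an optional replacement shape.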

class mfai.pytorch.models.unetrpp.UNetRPPEncoder(input_size=[32768, 4096, 512, 64], dims=[32, 64, 128, 256], depths=[3, 3, 3, 3], num_heads=4, spatial_dims=2, in_channels=4, dropout=0.0, transformer_dropout_rate=0.1, downsampling_rate=4, proj_sizes=(64, 64, 64, 32), attention_code='torch')[source]

Bases: Module
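The default input_size=[32768, 4096, 512, 64] is consistent with one token per position of a 128×128×128 volume that is first downsampled by downsampling_rate=4 and then halved along each axis at every later stage. The source does not state how these values are derived, so the sketch below is only one reading of the defaults; tokens_per_stage is a hypothetical helper.

```python
import math


def tokens_per_stage(input_shape, downsampling_rate=4, num_stages=4):
    """Token count per stage, assuming each stage halves every spatial dimension."""
    counts = []
    for stage in range(num_stages):
        factor = downsampling_rate * 2**stage
        counts.append(math.prod(dim // factor for dim in input_shape))
    return counts


volume_tokens = tokens_per_stage((128, 128, 128))
image_tokens = tokens_per_stage((256, 256))
```

Under these assumptions volume_tokens equals [32768, 4096, 512, 64], and a 256×256 image gives [4096, 1024, 256, 64].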

Parameters:
forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

x (Tensor)

Return type:

tuple[Tensor, list[Tensor]]

forward_features(x)[source]
Return type:

tuple[Tensor, list[Tensor]]

Parameters:

x (Tensor)

class mfai.pytorch.models.unetrpp.UNetRPPSettings(hidden_size=256, num_heads_encoder=4, num_heads_decoder=4, pos_embed='perceptron', norm_name='instance', dropout_rate=0.0, depths=(3, 3, 3, 3), conv_op='Conv2d', linear_upsampling=False, downsampling_rate=4, decoder_proj_size=64, encoder_proj_sizes=(64, 64, 64, 32), autopad_enabled=False, add_skip_connections=True, attention_code='torch')[source]

Bases: object

Settings dataclass for UNetRPP. Contains all the hyperparameters needed to initialize the model.

Parameters:
  • hidden_size (int) – dimensions of the last encoder.

  • num_heads – number of attention heads.

  • pos_embed (str) – position embedding layer type.

  • norm_name (Union[tuple, str]) – feature normalization type and arguments.

  • dropout_rate (float) – fraction of the input units to drop.

  • depths (tuple[int, ...]) – number of blocks for each stage.

  • conv_op (str) – type of convolution operation.

  • do_ds – use deep supervision to compute the loss.

  • num_heads_encoder (int)

  • num_heads_decoder (int)

  • linear_upsampling (bool)

  • downsampling_rate (int)

  • decoder_proj_size (int)

  • encoder_proj_sizes (tuple[int, ...])

  • autopad_enabled (bool)

  • add_skip_connections (bool)

  • attention_code (str)
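To show how a settings dataclass of this kind is typically overridden and serialized, the sketch below uses a standard-library stand-in. SettingsSketch is hypothetical and copies only a few of the documented defaults; it is not UNetRPPSettings, and the value "Conv3d" for conv_op is an assumption. The real class exposes its own to_dict, to_json, from_dict and from_json methods, listed further down.

```python
import json
from dataclasses import asdict, dataclass, replace


@dataclass
class SettingsSketch:
    """Hypothetical stand-in mirroring a few documented UNetRPPSettings defaults."""

    hidden_size: int = 256
    num_heads_encoder: int = 4
    num_heads_decoder: int = 4
    depths: tuple = (3, 3, 3, 3)
    conv_op: str = "Conv2d"
    linear_upsampling: bool = False
    autopad_enabled: bool = False


# Override selected fields while keeping every other default.
custom = replace(SettingsSketch(), conv_op="Conv3d", linear_upsampling=True)

# Round trip through JSON; JSON has no tuple type, so depths is restored by hand.
payload = json.dumps(asdict(custom))
loaded = json.loads(payload)
loaded["depths"] = tuple(loaded["depths"])
restored = SettingsSketch(**loaded)
```

Keeping hyperparameters in one serializable object makes it easy to store a configuration next to a checkpoint and rebuild the same model later.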

add_skip_connections: bool = True
attention_code: str = 'torch'
autopad_enabled: bool = False
conv_op: str = 'Conv2d'
decoder_proj_size: int = 64
depths: tuple[int, ...] = (3, 3, 3, 3)
do_ds = False
downsampling_rate: int = 4
dropout_rate: float = 0.0
encoder_proj_sizes: tuple[int, ...] = (64, 64, 64, 32)
classmethod from_dict(kvs, *, infer_missing=False)
Return type:

TypeVar(A, bound= DataClassJsonMixin)

Parameters:

kvs (dict | list | str | int | float | bool | None)

classmethod from_json(s, *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw)
Return type:

TypeVar(A, bound= DataClassJsonMixin)

Parameters:

s (str | bytes | bytearray)

hidden_size: int = 256
linear_upsampling: bool = False
norm_name: Union[tuple, str] = 'instance'
num_heads_decoder: int = 4
num_heads_encoder: int = 4
pos_embed: str = 'perceptron'
classmethod schema(*, infer_missing=False, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)
Return type:

SchemaF[TypeVar(A, bound= DataClassJsonMixin)]

Parameters:
spatial_dims = 2
to_dict(encode_json=False)
Return type:

Dict[str, Union[dict, list, str, int, float, bool, None]]

to_json(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, indent=None, separators=None, default=None, sort_keys=False, **kw)
Return type:

str

Parameters:
class mfai.pytorch.models.unetrpp.UNetRUpBlock(spatial_dims, in_channels, out_channels, kernel_size, upsample_kernel_size, norm_name, num_heads=4, out_size=0, depth=3, conv_decoder=False, linear_upsampling=False, proj_size=64, attention_code='torch')[source]

Bases: Module

Parameters:
forward(inp, skip=None)[source]

Forward pass:
  1. Upsamples using bi/tri-linear interpolation or a transposed convolution (ConvTranspose{2,3}d).
  2. Adds the skip connection, if available.
  3. Applies a Conv or Transformer block.

Return type:

Tensor

Parameters:
mfai.pytorch.models.unetrpp.init_(tensor)[source]
Return type:

Tensor

Parameters:

tensor (Tensor)

mfai.pytorch.models.unetrpp.trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0)[source]

Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution \(\mathcal{N}(\text{mean}, \text{std}^2)\) with values outside \([a, b]\) redrawn until they are within the bounds. The method used for generating the random values works best when \(a \leq \text{mean} \leq b\).

NOTE: this implementation is similar to the PyTorch trunc_normal_. The bounds [a, b] are applied while sampling the normal distribution with mean and std already applied, so the a and b arguments should be adjusted to match the range implied by the mean and std arguments.

Parameters:
  • tensor (Tensor) – an n-dimensional Tensor

  • mean (float) – the mean of the normal distribution

  • std (float) – the standard deviation of the normal distribution

  • a (float) – the minimum cutoff value

  • b (float) – the maximum cutoff value

Return type:

Tensor
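The point about absolute bounds can be checked with a small rejection-sampling sketch built on the standard library. It follows the "redraw until within bounds" description above rather than the library's actual sampling method, and trunc_normal_sample is a hypothetical helper that returns a single float instead of filling a tensor.

```python
import random


def trunc_normal_sample(mean=0.0, std=1.0, a=-2.0, b=2.0, rng=random):
    """Draw one value from N(mean, std^2), redrawing until it lies in [a, b]."""
    while True:
        value = rng.gauss(mean, std)
        if a <= value <= b:
            return value


rng = random.Random(0)

# The bounds are absolute values: with std=0.02 the default [-2, 2] sits
# 100 standard deviations from the mean and truncates almost nothing.
loose = [trunc_normal_sample(std=0.02, rng=rng) for _ in range(1000)]

# Scaling the bounds with std truncates at two standard deviations instead.
tight = [trunc_normal_sample(std=0.02, a=-0.04, b=0.04, rng=rng) for _ in range(1000)]
```

With a small std, a and b therefore need to be chosen on the same scale as std for the truncation to have any practical effect.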

Examples

>>> import torch
>>> from torch import nn
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)