# Source code for mfai.pytorch.models.unetrpp

"""
UnetR++ Vision transformer based on: "Shaker et al.,
Adapted from https://github.com/Amshaker/unetr_plus_plus
Added 2d support and Bilinear interpolation for upsampling.
"""

import warnings
from dataclasses import dataclass
from math import ceil, erf, sqrt
from typing import Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses_json import dataclass_json
from monai.networks.blocks.dynunet_block import (
    UnetOutBlock,
    UnetResBlock,
    get_conv_layer,
    get_output_padding,
    get_padding,
)
from monai.networks.layers.utils import get_norm_layer
from monai.utils import optional_import
from torch import Tensor
from torch.nn.functional import scaled_dot_product_attention

from .base import AutoPaddingModel, BaseModel, ModelType


def _trunc_normal_(
    tensor: Tensor, mean: float, std: float, a: float, b: float
) -> Tensor:
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x: float) -> float:
        # Computes standard normal cumulative distribution function
        return (1.0 + erf(x / sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    lower = norm_cdf((a - mean) / std)
    upper = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * lower - 1, 2 * upper - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * sqrt(2.0))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
) -> Tensor:
    r"""Fill ``tensor`` in-place from a truncated normal distribution.

    Values are effectively drawn from
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
    :math:`[a, b]` redrawn until they fall within the bounds. The method
    works best when :math:`a \leq \text{mean} \leq b`.

    NOTE: this impl is similar to the PyTorch trunc_normal_: the bounds
    [a, b] are applied while sampling the normal with mean/std applied,
    therefore a, b args should be adjusted to match the range of mean,
    std args.

    Args:
        tensor: an n-dimensional `Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    # Weight initialization is not part of the autograd graph.
    with torch.no_grad():
        return _trunc_normal_(tensor, mean, std, a, b)
class LayerNorm(nn.Module):
    """LayerNorm supporting two data layouts.

    ``channels_last`` normalizes over the trailing channel dimension via
    ``F.layer_norm``. ``channels_first`` normalizes over dim 1 with a
    manual implementation that now broadcasts over any number of
    trailing spatial dimensions (the original hard-coded
    ``weight[:, None, None]`` and therefore only supported 4-D
    ``(N, C, H, W)`` inputs, although this file also handles 3d data).

    Args:
        normalized_shape: number of channels to normalize over.
        eps: numerical-stability constant added to the variance.
        data_format: ``"channels_last"`` or ``"channels_first"``.

    Raises:
        NotImplementedError: if ``data_format`` is not supported.
    """

    def __init__(
        self,
        normalized_shape: int,
        eps: float = 1e-6,
        data_format: str = "channels_last",
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: Tensor) -> Tensor:
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        elif self.data_format == "channels_first":
            # Manual normalization over the channel dim (dim 1); biased
            # variance, matching F.layer_norm.
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            # Broadcast the affine parameters over every spatial dim so
            # both (N, C, H, W) and (N, C, H, W, D) inputs work.
            shape = (-1,) + (1,) * (x.dim() - 2)
            x = self.weight.view(shape) * x + self.bias.view(shape)
            return x
        else:
            raise NotImplementedError(
                f"LayerNorm with data_format {self.data_format} is not supported."
            )
class TransformerBlock(nn.Module):
    """
    A transformer block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate
    3D Medical Image Segmentation"

    Modified to work with both 2d and 3d data (spatial_dims).
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        pos_embed: bool = False,
        spatial_dims: int = 2,
        proj_size: int = 64,
        attention_code: str = "torch",
    ) -> None:
        """
        Args:
            input_size: the size of the input for each stage.
            hidden_size: dimension of hidden layer.
            num_heads: number of attention heads.
            dropout_rate: fraction of the input units to drop.
            pos_embed: bool argument to determine if positional embedding is used.
            spatial_dims: number of spatial dimensions (2 or 3).
            proj_size: size of the projection space for Spatial Attention.
            attention_code: type of attention implementation to use. See EPA for more details.

        Raises:
            ValueError: if dropout_rate is out of [0, 1] or hidden_size
                is not divisible by num_heads.
        """
        super().__init__()

        if not (0 <= dropout_rate <= 1):
            raise ValueError("dropout_rate should be between 0 and 1.")

        if hidden_size % num_heads != 0:
            # Report the offending values in the exception itself instead
            # of printing them to stdout (the original used debug prints).
            raise ValueError(
                "hidden_size should be divisible by num_heads. "
                f"Got hidden_size={hidden_size} and num_heads={num_heads}."
            )

        self.norm = nn.LayerNorm(hidden_size)
        # Learnable residual scaling for the attention branch, initialized
        # close to zero (LayerScale-style).
        self.gamma = nn.Parameter(1e-6 * torch.ones(hidden_size), requires_grad=True)
        self.epa_block = EPA(
            input_size=input_size,
            hidden_size=hidden_size,
            num_heads=num_heads,
            channel_attn_drop=dropout_rate,
            spatial_attn_drop=dropout_rate,
            proj_size=proj_size,
            attention_code=attention_code,
        )
        self.conv51 = UnetResBlock(
            spatial_dims,
            hidden_size,
            hidden_size,
            kernel_size=3,
            stride=1,
            norm_name="batch",
        )
        # 1x1 conv + channel dropout applied after the residual conv block.
        if spatial_dims == 2:
            self.conv8 = nn.Sequential(
                nn.Dropout2d(0.1, False), nn.Conv2d(hidden_size, hidden_size, 1)
            )
        else:
            self.conv8 = nn.Sequential(
                nn.Dropout3d(0.1, False), nn.Conv3d(hidden_size, hidden_size, 1)
            )

        self.pos_embed = None
        self.spatial_dims = spatial_dims
        if pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, input_size, hidden_size))

    def forward(self, x: Tensor) -> Tensor:
        # Flatten spatial dims to a token sequence: (B, C, *spatial) ->
        # (B, N, C) with N = prod(spatial).
        if self.spatial_dims == 2:
            B, C, H, W = x.shape
            x = x.reshape(B, C, H * W).permute(0, 2, 1)
        else:
            B, C, H, W, D = x.shape
            x = x.reshape(B, C, H * W * D).permute(0, 2, 1)

        if self.pos_embed is not None:
            x = x + self.pos_embed
        # Pre-norm attention with learnable residual scale.
        attn = x + self.gamma * self.epa_block(self.norm(x))

        # Back to the spatial layout for the convolutional sub-block.
        if self.spatial_dims == 2:
            attn_skip = attn.reshape(B, H, W, C).permute(0, 3, 1, 2)
        else:
            attn_skip = attn.reshape(B, H, W, D, C).permute(0, 4, 1, 2, 3)
        attn = self.conv51(attn_skip)
        x = attn_skip + self.conv8(attn)
        return x
def init_(tensor: Tensor) -> Tensor:
    """Uniformly initialize ``tensor`` in-place on ``[-1/sqrt(d), 1/sqrt(d)]``,
    where ``d`` is its last dimension; returns the same tensor."""
    bound = 1 / sqrt(tensor.shape[-1])
    return tensor.uniform_(-bound, bound)
class EPA(nn.Module):
    """
    Efficient Paired Attention Block, based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate
    3D Medical Image Segmentation"

    Modifications :
    - adds compatibility with 2d inputs
    - adds an option to use torch's scaled dot product instead of the original implementation
    This should enable the use of flash attention in the future.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_heads: int = 4,
        qkv_bias: bool = False,
        channel_attn_drop: float = 0.1,
        spatial_attn_drop: float = 0.1,
        proj_size: int = 64,
        attention_code: str = "torch",
    ):
        super().__init__()
        self.num_heads = num_heads
        if attention_code not in ["torch", "flash", "manual"]:
            raise NotImplementedError(
                "Attention code should be one of 'torch', 'flash' or 'manual'"
            )
        self.attention_code = attention_code
        if attention_code == "flash":
            # Loose dependency: imported only when explicitly requested.
            from flash_attn import flash_attn_func

            self.attn_func = flash_attn_func
            self.use_scaled_dot_product_CA = True
        elif attention_code == "torch":
            self.attn_func = scaled_dot_product_attention
            self.use_scaled_dot_product_CA = True
        else:
            self.use_scaled_dot_product_CA = False

        # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel)
        self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)

        # E and F are projection matrices with shared weights used in spatial attention module to project
        # keys and values from HWD-dimension to P-dimension
        self.EF = nn.Parameter(init_(torch.zeros(input_size, proj_size)))

        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))
        if not self.use_scaled_dot_product_CA:
            # Only the manual channel-attention path uses this temperature.
            self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

    def forward(self, x: Tensor) -> Tensor:
        # TODO: fully optimize this function for each attention code
        B, N, C = x.shape

        qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads)
        # Matrix index, Batch, Head, Dimensions, Features
        qkvv = qkvv.permute(2, 0, 3, 1, 4)
        # Batch, Head, Dimensions, Features
        q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3]
        # Batch, Head, Features, Dimensions
        q_shared = q_shared.transpose(-2, -1)
        k_shared = k_shared.transpose(-2, -1)
        v_CA = v_CA.transpose(-2, -1)
        v_SA = v_SA.transpose(-2, -1)

        # Project keys and values for spatial attention from N tokens down
        # to proj_size with the shared EF matrix.
        k_shared_projected, v_SA_projected = map(
            lambda args: torch.einsum("bhdn,nk->bhdk", *args),
            zip((k_shared, v_SA), (self.EF, self.EF)),
        )

        q_shared = torch.nn.functional.normalize(q_shared, dim=-1).type_as(q_shared)
        k_shared = torch.nn.functional.normalize(k_shared, dim=-1).type_as(k_shared)

        # Channel attention.
        if self.use_scaled_dot_product_CA:
            # BUGFIX: functional attention kernels know nothing about the
            # module's train/eval mode, so the dropout probability must be
            # gated on self.training. The original passed self.attn_drop.p
            # unconditionally, applying dropout even at inference.
            drop_p = self.attn_drop.p if self.training else 0.0
            if self.attention_code == "torch":
                x_CA = self.attn_func(q_shared, k_shared, v_CA, dropout_p=drop_p)
            elif self.attention_code == "flash":
                # flash attention expects inputs of shape (batch_size, seqlen, nheads, headdim)
                # so we need to permute the dimensions from (batch, head, channels, spatial_dim)
                # to (batch, channels, head, spatial_dim)
                q_shared = q_shared.permute(0, 2, 1, 3)
                k_shared = k_shared.permute(0, 2, 1, 3)
                v_CA = v_CA.permute(0, 2, 1, 3)
                x_CA = self.attn_func(q_shared, k_shared, v_CA, dropout_p=drop_p)
                # flash attention returns the output in the same shape as the input
                # so we need to permute it back
                x_CA = x_CA.permute(0, 2, 1, 3)
                # we permute back the inputs
                q_shared = q_shared.permute(0, 2, 1, 3)
                k_shared = k_shared.permute(0, 2, 1, 3)
        else:
            # Manual channel attention from the original paper; the Dropout
            # module here already respects train/eval mode.
            attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature
            attn_CA = attn_CA.softmax(dim=-1)
            attn_CA = self.attn_drop(attn_CA)
            x_CA = attn_CA @ v_CA

        x_CA = x_CA.permute(0, 3, 1, 2).reshape(B, N, C)

        # Spatial attention over the projected keys/values.
        attn_SA = (
            q_shared.permute(0, 1, 3, 2) @ k_shared_projected
        ) * self.temperature2
        attn_SA = attn_SA.softmax(dim=-1)
        attn_SA = self.attn_drop_2(attn_SA)
        x_SA = attn_SA @ v_SA_projected.transpose(-2, -1)
        x_SA = x_SA.permute(0, 3, 1, 2).reshape(B, N, C)

        return x_CA + x_SA

    @torch.jit.ignore
    def no_weight_decay(self) -> set[str]:
        # Temperatures are scaling scalars; exclude them from weight decay.
        return {"temperature", "temperature2"}
einops, _ = optional_import("einops")
class UNetRPPEncoder(nn.Module):
    """UNETR++ hierarchical encoder.

    A convolutional stem (stride = ``downsampling_rate``) followed by three
    stride-2 downsampling convolutions; each of the four resolutions is
    processed by a stack of TransformerBlocks.

    Args:
        input_size: token count (flattened spatial size) at each stage.
        dims: channel width of each stage.
        depths: number of transformer blocks per stage.
        num_heads: number of attention heads.
        spatial_dims: number of spatial dimensions (2 or 3).
        in_channels: number of channels of the network input.
        dropout: dropout rate of the convolution layers.
        transformer_dropout_rate: dropout rate inside transformer blocks.
        downsampling_rate: stride/kernel of the stem convolution.
        proj_sizes: spatial-attention projection size per stage.
        attention_code: attention implementation, see EPA.

    Note: the sequence defaults are tuples (immutable) — the original used
    mutable list defaults, a classic Python pitfall; list arguments are
    still accepted.
    """

    def __init__(
        self,
        input_size: tuple[int, ...] = (32 * 32 * 32, 16 * 16 * 16, 8 * 8 * 8, 4 * 4 * 4),
        dims: tuple[int, ...] = (32, 64, 128, 256),
        depths: tuple[int, ...] = (3, 3, 3, 3),
        num_heads: int = 4,
        spatial_dims: int = 2,
        in_channels: int = 4,
        dropout: float = 0.0,
        transformer_dropout_rate: float = 0.1,
        downsampling_rate: int = 4,
        proj_sizes: tuple[int, ...] = (64, 64, 64, 32),
        attention_code: str = "torch",
    ):
        super().__init__()

        self.downsample_layers = (
            nn.ModuleList()
        )  # stem and 3 intermediate downsampling conv layers
        stem_layer = nn.Sequential(
            get_conv_layer(
                spatial_dims,
                in_channels,
                dims[0],
                kernel_size=downsampling_rate,
                stride=downsampling_rate,
                dropout=dropout,
                conv_only=True,
            ),
            get_norm_layer(name=("group", {"num_groups": 4}), channels=dims[0]),
        )
        self.downsample_layers.append(stem_layer)
        for i in range(3):
            downsample_layer = nn.Sequential(
                get_conv_layer(
                    spatial_dims,
                    dims[i],
                    dims[i + 1],
                    kernel_size=2,
                    stride=2,
                    dropout=dropout,
                    conv_only=True,
                ),
                get_norm_layer(
                    name=("group", {"num_groups": dims[i]}), channels=dims[i + 1]
                ),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = (
            nn.ModuleList()
        )  # 4 feature resolution stages, each consisting of multiple Transformer blocks
        for i in range(4):
            stage_blocks = []
            for _ in range(depths[i]):
                stage_blocks.append(
                    TransformerBlock(
                        input_size=input_size[i],
                        hidden_size=dims[i],
                        num_heads=num_heads,
                        dropout_rate=transformer_dropout_rate,
                        pos_embed=True,
                        proj_size=proj_sizes[i],
                        attention_code=attention_code,
                        spatial_dims=spatial_dims,
                    )
                )
            self.stages.append(nn.Sequential(*stage_blocks))
        self.hidden_states: list[Tensor] = []
        self.apply(self._init_weights)
        self.spatial_dims = spatial_dims

    def _init_weights(self, m: nn.Module) -> None:
        # Truncated-normal init for conv/linear weights, constant for norms.
        # NOTE(review): nn.Conv3d is not matched here, so 3d convolutions
        # keep their default init — confirm whether this is intentional.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (LayerNorm, nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward_features(self, x: Tensor) -> tuple[Tensor, list[Tensor]]:
        """Run all four stages, collecting each stage's output."""
        hidden_states = []

        x = self.downsample_layers[0](x)
        x = self.stages[0](x)

        hidden_states.append(x)

        for i in range(1, 4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
            if i == 3:
                # Reshape the output of the last stage to a token sequence.
                if self.spatial_dims == 2:
                    x = einops.rearrange(x, "b c h w -> b (h w) c")
                else:
                    x = einops.rearrange(x, "b c h w d -> b (h w d) c")
            hidden_states.append(x)
        return x, hidden_states

    def forward(self, x: Tensor) -> tuple[Tensor, list[Tensor]]:
        x, hidden_states = self.forward_features(x)
        return x, hidden_states
class UNetRUpBlock(nn.Module):
    # One decoder stage: upsampling (linear interpolation + conv, or a
    # transposed convolution), optional skip connection, then either a
    # convolutional or a transformer refinement block.

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: tuple[int, int] | tuple[int, int, int] | int,
        upsample_kernel_size: tuple[int, int, int] | tuple[int, int] | int,
        norm_name: tuple | str,
        num_heads: int = 4,
        out_size: int = 0,
        depth: int = 3,
        conv_decoder: bool = False,
        linear_upsampling: bool = False,
        proj_size: int = 64,
        attention_code: str = "torch",
    ) -> None:
        """
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
            upsample_kernel_size: convolution kernel size for transposed convolution layers.
            norm_name: feature normalization type and arguments.
            num_heads: number of heads inside each EPA module.
            out_size: spatial size for each decoder.
            depth: number of blocks for the current decoder stage.
            conv_decoder: whether to use convolutional blocks instead of transformer
                blocks in the decoder.
            linear_upsampling: whether to use linear upsampling instead of transposed
                convolution for upsampling.
            proj_size: projection size for the spatial attention module in the EPA block.
            attention_code: type of attention implementation to use. See EPA for more details.
        """
        super().__init__()
        # Padding is derived from the upsampling kernel; used by the
        # transposed-convolution branches below.
        padding = get_padding(upsample_kernel_size, upsample_kernel_size)
        self.transp_conv: nn.Module
        if spatial_dims == 2:
            if linear_upsampling:
                # Bilinear upsampling followed by a regular conv. Tuple
                # kernel/scale arguments are truncated to 2 components.
                if isinstance(upsample_kernel_size, tuple):
                    scale_factor: tuple[float, float] | float = (
                        float(upsample_kernel_size[0]),
                        float(upsample_kernel_size[1]),
                    )
                    upsample_kernel_size = upsample_kernel_size[:2]
                else:
                    scale_factor = float(upsample_kernel_size)
                if isinstance(kernel_size, tuple):
                    kernel_size = kernel_size[:2]
                self.transp_conv = nn.Sequential(
                    nn.UpsamplingBilinear2d(scale_factor=scale_factor),
                    nn.Conv2d(
                        in_channels,
                        out_channels,
                        # NOTE(review): padding=1 is hard-coded here; it only
                        # preserves spatial size for kernel_size=3 — confirm.
                        kernel_size=kernel_size,
                        padding=1,
                    ),
                )
            else:
                # Transposed convolution: validate that every tuple argument
                # has exactly 2 components for 2D data.
                if isinstance(padding, tuple):
                    if len(padding) != 2:
                        raise ValueError(
                            "padding should be a tuple of 2 integers for 2D data."
                        )
                output_padding = get_output_padding(
                    upsample_kernel_size, upsample_kernel_size, padding
                )
                if isinstance(output_padding, tuple):
                    if len(output_padding) != 2:
                        raise ValueError(
                            "output_padding should be a tuple of 2 integers for 2D data."
                        )
                if isinstance(upsample_kernel_size, tuple):
                    if len(upsample_kernel_size) != 2:
                        raise ValueError(
                            "upsample_kernel_size should be a tuple of 2 integers for 2D data."
                        )
                self.transp_conv = nn.ConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=upsample_kernel_size,
                    stride=upsample_kernel_size,
                    padding=padding,
                    output_padding=output_padding,
                    dilation=1,
                )
        else:
            # 3D case: same structure, tuples must have 3 components.
            if isinstance(kernel_size, tuple):
                if len(kernel_size) != 3:
                    raise ValueError(
                        "kernel_size should be a tuple of 3 integers for 3D data."
                    )
            if isinstance(upsample_kernel_size, tuple):
                if len(upsample_kernel_size) != 3:
                    raise ValueError(
                        "upsample_kernel_size should be a tuple of 3 integers for 3D data."
                    )
            if linear_upsampling:
                self.transp_conv = nn.Sequential(
                    nn.Upsample(scale_factor=upsample_kernel_size, mode="trilinear"),
                    nn.Conv3d(
                        in_channels, out_channels, kernel_size=kernel_size, padding=1
                    ),
                )
            else:
                if isinstance(padding, tuple):
                    if len(padding) != 3:
                        raise ValueError(
                            "padding should be a tuple of 3 integers for 3D data."
                        )
                output_padding = get_output_padding(
                    upsample_kernel_size, upsample_kernel_size, padding
                )
                if isinstance(output_padding, tuple):
                    if len(output_padding) != 3:
                        raise ValueError(
                            "output_padding should be a tuple of 3 integers for 3D data."
                        )
                self.transp_conv = nn.ConvTranspose3d(
                    in_channels,
                    out_channels,
                    kernel_size=upsample_kernel_size,
                    stride=upsample_kernel_size,
                    padding=padding,
                    output_padding=output_padding,
                    dilation=1,
                )

        # 4 feature resolution stages, each consisting of multiple residual blocks
        self.decoder_block = nn.ModuleList()

        # If this is the last decoder, use ConvBlock(UnetResBlock) instead of EPA_Block
        # (see suppl. material in the paper)
        if conv_decoder:
            self.decoder_block.append(
                UnetResBlock(
                    spatial_dims,
                    out_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=1,
                    norm_name=norm_name,
                )
            )
        else:
            stage_blocks = []
            for _ in range(depth):
                stage_blocks.append(
                    TransformerBlock(
                        input_size=out_size,
                        hidden_size=out_channels,
                        num_heads=num_heads,
                        dropout_rate=0.1,
                        pos_embed=True,
                        proj_size=proj_size,
                        attention_code=attention_code,
                        spatial_dims=spatial_dims,
                    )
                )
            self.decoder_block.append(nn.Sequential(*stage_blocks))

    def _init_weights(self, m: nn.Module) -> None:
        # Truncated-normal init for conv/linear, constant for LayerNorm.
        # NOTE(review): this method is never passed to self.apply() in this
        # class, so it appears unused — confirm whether init is intended.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, inp: Tensor, skip: Tensor | None = None) -> Tensor:
        """
        Forward pass:
        1. Upsampling using bi/tri-linear OR Conv{2,3}dTranspose.
        2. Adds skip connection if available.
        3. Conv or Transformer block.
        """
        out = self.transp_conv(inp)
        out = out + skip if skip is not None else out
        out = self.decoder_block[0](out)
        return out
@dataclass_json
@dataclass
class UNetRPPSettings:
    """Settings dataclass for UNetRPP.
    Contains all the hyperparameters needed to initialize the model.

    Args:
        hidden_size: dimensions of the last encoder.
        num_heads_encoder: number of attention heads in the encoder.
        num_heads_decoder: number of attention heads in the decoder.
        pos_embed: position embedding layer type.
        norm_name: feature normalization type and arguments.
        dropout_rate: fraction of the input units to drop.
        depths: number of blocks for each stage.
        conv_op: type of convolution operation.
        do_ds: use deep supervision to compute the loss.
        spatial_dims: number of spatial dimensions (2 or 3).
    """

    hidden_size: int = 256
    num_heads_encoder: int = 4
    num_heads_decoder: int = 4
    pos_embed: str = "perceptron"
    norm_name: Union[tuple, str] = "instance"
    dropout_rate: float = 0.0
    depths: tuple[int, ...] = (3, 3, 3, 3)
    conv_op: str = "Conv2d"
    # BUGFIX: these two were previously unannotated (`do_ds = False`,
    # `spatial_dims = 2`), which makes them plain class attributes rather
    # than dataclass fields — they were missing from __init__ and from
    # dataclass_json (de)serialization. Annotating them restores both,
    # with the same defaults, so existing code keeps working.
    do_ds: bool = False
    spatial_dims: int = 2
    linear_upsampling: bool = False
    downsampling_rate: int = 4
    decoder_proj_size: int = 64
    encoder_proj_sizes: tuple[int, ...] = (64, 64, 64, 32)
    autopad_enabled: bool = False

    # Adds skip connection between encoder layers outputs
    # and corresponding decoder layers inputs
    add_skip_connections: bool = True

    # Specify the attention implementation to use
    # Options: "torch" : scaled_dot_product_attention from torch.nn.functional
    #          "flash" : flash_attention from flash_attn (loose dependency imported only if needed)
    #          "manual" : manual implementation from the original paper
    attention_code: str = "torch"
class UNetRPP(BaseModel, AutoPaddingModel):
    """
    UNetR++ based on: "Shaker et al.,
    UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation".
    """

    # Capability/registry flags read by the mfai model framework.
    onnx_supported = False
    supported_num_spatial_dims = (2, 3)
    settings_kls = UNetRPPSettings
    model_type = ModelType.VISION_TRANSFORMER
    features_last: bool = False
    register: bool = True

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        input_shape: tuple[int, ...],
        # NOTE(review): a shared mutable dataclass instance as default
        # argument — safe only if callers never mutate it; confirm.
        settings: UNetRPPSettings = UNetRPPSettings(),
    ) -> None:
        """
        Args:
            in_channels: dimension of input channels.
            out_channels: dimension of output channels.
            input_shape: dimension of input image.
            settings: UNetRPPSettings dataclass containing the model hyperparameters.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self._settings = settings

        # we have first a stem layer with stride=subsampling_rate and k_size=subsampling_rate
        # followed by 3 successive downsampling layer (k=2, stride=2)
        self.dim_divider = (2**3) * settings.downsampling_rate

        if self._settings.autopad_enabled:
            # Round the stored input shape up to the next multiple of
            # dim_divider; actual padding happens in forward().
            _, self.input_shape = self.validate_input_shape(torch.Size(input_shape))
        else:
            self.input_shape = input_shape

        self.do_ds = settings.do_ds
        self.add_skip_connections = settings.add_skip_connections
        self.conv_op = getattr(nn, settings.conv_op)
        self.num_classes = out_channels
        if not (0 <= settings.dropout_rate <= 1):
            raise AssertionError("dropout_rate should be between 0 and 1.")

        if settings.pos_embed not in ["conv", "perceptron"]:
            raise KeyError(
                f"Position embedding layer of type {settings.pos_embed} is not supported."
            )

        # Spatial size of the bottleneck feature map.
        self.feat_size: tuple[int, ...]
        if settings.spatial_dims == 2:
            self.feat_size = (
                self.input_shape[0] // self.dim_divider,
                self.input_shape[1] // self.dim_divider,
            )
        else:
            self.feat_size = (
                self.input_shape[0] // self.dim_divider,
                self.input_shape[1] // self.dim_divider,
                self.input_shape[2] // self.dim_divider,
            )
        self.hidden_size = settings.hidden_size
        self.spatial_dims = settings.spatial_dims

        # Number of pixels after stem layer
        if settings.spatial_dims == 2:
            no_pixels = (self.input_shape[0] * self.input_shape[1]) // (
                settings.downsampling_rate**2
            )
        else:
            no_pixels = (
                self.input_shape[0] * self.input_shape[1] * self.input_shape[2]
            ) // (settings.downsampling_rate**3)
        # after the stem layer, the input is spatially downsampled
        # 3 times by a factor of 2 along each spatial dimension
        subsampling_ratio = 2**settings.spatial_dims
        encoder_input_size = [
            no_pixels,
            no_pixels // subsampling_ratio,
            no_pixels // subsampling_ratio**2,
            no_pixels // subsampling_ratio**3,
        ]
        h_size = settings.hidden_size

        # Encoder: channel widths double at each of the 4 stages.
        self.unetr_pp_encoder = UNetRPPEncoder(
            input_size=encoder_input_size,
            dims=[
                h_size // 8,
                h_size // 4,
                h_size // 2,
                h_size,
            ],
            depths=list(settings.depths),
            num_heads=settings.num_heads_encoder,
            spatial_dims=settings.spatial_dims,
            in_channels=in_channels,
            downsampling_rate=settings.downsampling_rate,
            proj_sizes=settings.encoder_proj_sizes,
            attention_code=settings.attention_code,
        )

        # Full-resolution conv branch used as the final skip connection.
        self.encoder1 = UnetResBlock(
            spatial_dims=settings.spatial_dims,
            in_channels=in_channels,
            out_channels=settings.hidden_size // 16,
            kernel_size=3,
            stride=1,
            norm_name=settings.norm_name,
        )
        # Decoder stages: each halves the channels and upsamples by 2
        # (decoder2 upsamples by downsampling_rate back to full resolution).
        self.decoder5 = UNetRUpBlock(
            spatial_dims=settings.spatial_dims,
            in_channels=settings.hidden_size,
            out_channels=settings.hidden_size // 2,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=settings.norm_name,
            out_size=no_pixels // subsampling_ratio**2,
            linear_upsampling=settings.linear_upsampling,
            proj_size=settings.decoder_proj_size,
            attention_code=settings.attention_code,
            num_heads=settings.num_heads_decoder,
        )
        self.decoder4 = UNetRUpBlock(
            spatial_dims=settings.spatial_dims,
            in_channels=settings.hidden_size // 2,
            out_channels=settings.hidden_size // 4,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=settings.norm_name,
            out_size=no_pixels // subsampling_ratio,
            linear_upsampling=settings.linear_upsampling,
            proj_size=settings.decoder_proj_size,
            attention_code=settings.attention_code,
            num_heads=settings.num_heads_decoder,
        )
        self.decoder3 = UNetRUpBlock(
            spatial_dims=settings.spatial_dims,
            in_channels=settings.hidden_size // 4,
            out_channels=settings.hidden_size // 8,
            kernel_size=3,
            upsample_kernel_size=2,
            norm_name=settings.norm_name,
            out_size=no_pixels,
            linear_upsampling=settings.linear_upsampling,
            proj_size=settings.decoder_proj_size,
            attention_code=settings.attention_code,
            num_heads=settings.num_heads_decoder,
        )
        self.decoder2 = UNetRUpBlock(
            spatial_dims=settings.spatial_dims,
            in_channels=settings.hidden_size // 8,
            out_channels=settings.hidden_size // 16,
            kernel_size=3,
            upsample_kernel_size=settings.downsampling_rate,
            norm_name=settings.norm_name,
            out_size=no_pixels * (settings.downsampling_rate**2),
            conv_decoder=True,
            linear_upsampling=settings.linear_upsampling,
            proj_size=settings.decoder_proj_size,
            attention_code=settings.attention_code,
            num_heads=settings.num_heads_decoder,
        )
        self.out1 = UnetOutBlock(
            spatial_dims=settings.spatial_dims,
            in_channels=settings.hidden_size // 16,
            out_channels=out_channels,
        )
        if self.do_ds:
            # Extra output heads for deep supervision at coarser scales.
            self.out2 = UnetOutBlock(
                spatial_dims=settings.spatial_dims,
                in_channels=settings.hidden_size // 8,
                out_channels=out_channels,
            )
            self.out3 = UnetOutBlock(
                spatial_dims=settings.spatial_dims,
                in_channels=settings.hidden_size // 4,
                out_channels=out_channels,
            )

        self.check_required_attributes()

    @property
    def settings(self) -> UNetRPPSettings:
        """The hyperparameter dataclass this model was built from."""
        return self._settings

    @property
    def num_spatial_dims(self) -> int:
        """Number of spatial dimensions of the data (2 or 3)."""
        return self.settings.spatial_dims

    def proj_feat(self, x: Tensor) -> Tensor:
        """Reshape the encoder's token sequence (B, N, C) back to the
        bottleneck feature map (B, C, *feat_size)."""
        if self.spatial_dims == 2:
            x = x.view(
                x.size(0), self.feat_size[0], self.feat_size[1], self.hidden_size
            )
        else:
            x = x.view(
                x.size(0),
                self.feat_size[0],
                self.feat_size[1],
                self.feat_size[2],
                self.hidden_size,
            )
        if self.spatial_dims == 2:
            x = x.permute(0, 3, 1, 2).contiguous()
        else:
            x = x.permute(0, 4, 1, 2, 3).contiguous()
        return x

    def forward(self, x: Tensor) -> Tensor | list[Tensor]:
        """Encode, decode with optional skip connections, and project to
        class logits; returns a list of logits when deep supervision is on."""
        # Pad to a dim_divider multiple when autopadding is enabled.
        x, old_shape = self._maybe_padding(data_tensor=x)
        _, hidden_states = self.unetr_pp_encoder(x)
        convBlock = self.encoder1(x)

        # Four encoders
        enc1 = hidden_states[0]
        enc2 = hidden_states[1]
        enc3 = hidden_states[2]
        enc4 = hidden_states[3]

        # Four decoders
        dec4 = self.proj_feat(enc4)
        dec3 = (
            self.decoder5(dec4, enc3)
            if self.add_skip_connections
            else self.decoder5(dec4)
        )
        dec2 = (
            self.decoder4(dec3, enc2)
            if self.add_skip_connections
            else self.decoder4(dec3)
        )
        dec1 = (
            self.decoder3(dec2, enc1)
            if self.add_skip_connections
            else self.decoder3(dec2)
        )
        out = (
            self.decoder2(dec1, convBlock)
            if self.add_skip_connections
            else self.decoder2(dec1)
        )
        if self.do_ds:
            # Deep supervision: one logits tensor per resolution, each
            # cropped back to the pre-padding shape.
            list_logits: list[Tensor] = [
                self._maybe_unpadding(self.out1(out), old_shape=old_shape),
                self._maybe_unpadding(self.out2(dec1), old_shape=old_shape),
                self._maybe_unpadding(self.out3(dec2), old_shape=old_shape),
            ]
            return list_logits
        else:
            logits: Tensor = self.out1(out)
            logits = self._maybe_unpadding(logits, old_shape=old_shape)
            return logits

    def validate_input_shape(self, input_shape: torch.Size) -> tuple[bool, torch.Size]:
        """Return (is_valid, padded_shape): each spatial dim rounded up to
        the nearest multiple of dim_divider."""
        d = self.dim_divider
        new_shape = torch.Size(
            [d * ceil(input_shape[i] / d) for i in range(len(input_shape))]
        )
        return new_shape == input_shape, new_shape