Source code for mfai.pytorch.models.utils
import math
from typing import Tuple
import einops
import torch
import torch.nn as nn
from torch import Tensor
[docs]
def patch_first_conv(
    model: torch.nn.Module,
    new_in_channels: int,
    default_in_channels: int = 3,
    pretrained: bool = True,
) -> None:
    """Change the number of input channels of the model's first convolution.

    The first ``nn.Conv2d`` found in ``model.modules()`` order whose
    ``in_channels`` equals ``default_in_channels`` is patched in place.

    Weight handling:
        * ``pretrained=False``: a fresh weight of the new shape is allocated
          and the layer's own ``reset_parameters()`` re-initializes it.
        * ``pretrained=True`` and ``new_in_channels == 1``: the original
          weights are summed over the input-channel axis.
        * ``pretrained=True`` otherwise: the original input-channel slices are
          reused cyclically (channel ``i`` copies original channel
          ``i % n_original``) and rescaled by
          ``default_in_channels / new_in_channels``.

    The new weight always keeps the dtype and device of the original weight.

    Args:
        model: Module searched for the convolution to patch.
        new_in_channels: Number of input channels the convolution should accept.
        default_in_channels: ``in_channels`` value identifying the convolution
            to patch.
        pretrained: Whether to derive the new weights from the existing ones.

    Raises:
        ValueError: If ``model`` contains no ``nn.Conv2d`` with
            ``default_in_channels`` input channels.
    """
    # Locate the first matching convolution; fail loudly when there is none
    # instead of hitting an unbound local further down.
    for module in model.modules():
        if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels:
            first_conv: nn.Conv2d = module
            break
    else:
        raise ValueError(
            f"No nn.Conv2d with in_channels == {default_in_channels} found in model."
        )

    weight = first_conv.weight.detach()
    first_conv.in_channels = new_in_channels
    # The weight's second axis holds in_channels // groups slices.
    channels_per_group = new_in_channels // first_conv.groups

    if not pretrained:
        # new_empty keeps dtype/device of the existing weight; the values are
        # overwritten by reset_parameters() right after.
        first_conv.weight = nn.parameter.Parameter(
            weight.new_empty(
                first_conv.out_channels,
                channels_per_group,
                *first_conv.kernel_size,
            )
        )
        first_conv.reset_parameters()
    elif new_in_channels == 1:
        first_conv.weight = nn.parameter.Parameter(weight.sum(1, keepdim=True))
    else:
        # Cyclically gather the original input-channel slices, then rescale so
        # that the overall magnitude stays comparable to the original layer.
        source_idx = (
            torch.arange(channels_per_group, device=weight.device) % weight.shape[1]
        )
        new_weight = weight[:, source_idx] * (default_in_channels / new_in_channels)
        first_conv.weight = nn.parameter.Parameter(new_weight)
[docs]
def replace_strides_with_dilation(module: torch.nn.Module, dilation: int) -> None:
    """Patch every ``nn.Conv2d`` in ``module``, replacing strides with dilation.

    Each convolution is modified in place: its stride becomes ``(1, 1)``, its
    dilation becomes ``(dilation, dilation)`` and its padding is set to
    ``(k // 2) * dilation`` per spatial axis, using that axis' own kernel size.

    Args:
        module: Module whose ``nn.Conv2d`` sub-modules are patched.
        dilation: Dilation applied along both spatial axes.
    """
    for mod in module.modules():
        if not isinstance(mod, nn.Conv2d):
            continue
        mod.stride = (1, 1)
        mod.dilation = (dilation, dilation)
        kh, kw = mod.kernel_size
        # Height and width padding are derived from their respective kernel
        # sizes so that non-square kernels are padded correctly.
        mod.padding = ((kh // 2) * dilation, (kw // 2) * dilation)
        # Workaround for EfficientNet: convolutions exposing a
        # ``static_padding`` attribute get it replaced by a no-op
        # (presumably because that layer pads on its own - verify against the
        # EfficientNet implementation in use).
        if hasattr(mod, "static_padding"):
            mod.static_padding = nn.Identity()
[docs]
class AbsolutePosEmdebding(nn.Module):
    """
    Learned absolute positional embedding.

    Holds one trainable bias value per position of ``input_shape`` and per
    feature, and adds it to the input in ``forward``.
    """

    def __init__(
        self,
        input_shape: Tuple[int, ...],
        num_features: int,
        feature_last: bool = False,
    ):
        """
        Create the embedding parameter.

        Args:
            input_shape: Sizes of the grid dimensions (without batch and
                feature dimensions).
            num_features: Size of the feature dimension.
            feature_last: If True the parameter has shape
                ``(1, *input_shape, num_features)``; otherwise
                ``(1, num_features, *input_shape)``.
        """
        super().__init__()
        # Pick the parameter layout and the axis that carries the features;
        # that axis' size drives the initialization range in ``init_``.
        if feature_last:
            embedding_shape = (1, *input_shape, num_features)
            feature_axis = -1
        else:
            embedding_shape = (1, num_features, *input_shape)
            feature_axis = 1
        initial_values = init_(torch.zeros(embedding_shape), dim_idx=feature_axis)
        self.pos_embedding = nn.Parameter(initial_values, requires_grad=True)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` with the positional embedding added (via broadcasting)."""
        return x + self.pos_embedding
[docs]
def init_(tensor: Tensor, dim_idx: int = -1) -> Tensor:
    """Fill ``tensor`` in place with values drawn from U(-std, std).

    ``std`` is ``1 / sqrt(d)`` where ``d`` is the size of ``tensor`` along
    ``dim_idx``.

    Args:
        tensor: Tensor to initialize; modified in place.
        dim_idx: Index of the dimension whose size determines the range.

    Returns:
        The same ``tensor`` object, after initialization.

    Raises:
        ZeroDivisionError: If the selected dimension has size 0.
    """
    dim: int = tensor.shape[dim_idx]
    std: float = 1 / math.sqrt(dim)
    # nn.init.uniform_ performs the in-place fill under torch.no_grad(), so
    # this also works on parameters / leaf tensors that require grad.
    return nn.init.uniform_(tensor, -std, std)
[docs]
def features_last_to_second(x: Tensor) -> Tensor:
    """
    Moves features from the last dimension to the second dimension.

    Works for any number of grid dimensions between the batch and feature
    dimensions, e.g. ``(b, x, n) -> (b, n, x)`` or
    ``(b, x, y, n) -> (b, n, x, y)``. The order of the grid dimensions is
    preserved and the result is contiguous in memory.
    """
    return x.movedim(-1, 1).contiguous()
[docs]
def features_second_to_last(y: Tensor) -> Tensor:
    """
    Moves features from the second dimension to the last dimension.

    Works for any number of grid dimensions after the feature dimension,
    e.g. ``(b, n, x) -> (b, x, n)`` or ``(b, n, x, y) -> (b, x, y, n)``. The
    order of the grid dimensions is preserved and the result is contiguous in
    memory.
    """
    return y.movedim(1, -1).contiguous()
[docs]
def expand_to_batch(x: Tensor, batch_size: int) -> Tensor:
    """
    Prepend a batch dimension of size ``batch_size`` to ``x``.

    The result is an expanded view: the existing dimensions keep their sizes
    and no data is copied.
    """
    # One "-1" (keep current size) per existing dimension, so the helper stays
    # agnostic to the grid rank (1D, 2D, ...).
    keep_existing = (-1,) * x.dim()
    with_batch_axis = x[None]
    return with_batch_axis.expand(batch_size, *keep_existing)