blocks

Building-block modules for the DGMR GAN: residual blocks and conditioning stacks.

class mfai.pytorch.models.gan_dgmr.blocks.ContextConditioningStack(input_channels=1, output_channels=768, num_context_steps=4, conv_type='standard')

Bases: Module

Context conditioning stack: encodes the context frames into four conditioning representations, one per spatial scale.

Parameters:
  • input_channels (int)

  • output_channels (int)

  • num_context_steps (int)

  • conv_type (Literal['standard', 'coord', '3d'])

forward(x)

Generate the four conditioning representations from the context frames.

Return type:

tuple[Tensor, Tensor, Tensor, Tensor]

Parameters:

x (Tensor)
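The four-tensor return type reflects the multi-scale conditioning used in DGMR. The sketch below is a hypothetical, heavily simplified illustration of that idea, not mfai's implementation: the class name `TinyContextStack`, plain stride-2 convolutions in place of D blocks, and the channel widths are all assumptions. It applies space-to-depth to each context frame, downsamples four times with weights shared across time steps, and stacks the time steps along the channel axis at each scale. It expects an input of shape (b, t, c, h, w) with h and w divisible by 32, so that every scale halves cleanly.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyContextStack(nn.Module):
    """Hypothetical, simplified context conditioning stack (not mfai's code)."""

    def __init__(self, in_ch: int = 1, base_ch: int = 8) -> None:
        super().__init__()
        # Space-to-depth by 2 multiplies the channel count by 4.
        chans = [in_ch * 4] + [base_ch * 2**i for i in range(4)]
        # One stride-2 conv per scale; weights are shared across time steps.
        self.stages = nn.ModuleList(
            [
                nn.Conv2d(chans[i], chans[i + 1], kernel_size=3, stride=2, padding=1)
                for i in range(4)
            ]
        )

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # x: (b, t, c, h, w). Fold time into the batch axis.
        b, t, c, h, w = x.shape
        y = F.pixel_unshuffle(x.reshape(b * t, c, h, w), 2)
        outputs = []
        for stage in self.stages:
            y = F.relu(stage(y))
            _, ch, hh, ww = y.shape
            # Unfold time and stack it along channels: (b, t * ch, hh, ww).
            outputs.append(y.reshape(b, t * ch, hh, ww))
        return tuple(outputs)


stack = TinyContextStack()
scales = stack(torch.randn(2, 4, 1, 64, 64))
print([tuple(s.shape) for s in scales])
# [(2, 32, 16, 16), (2, 64, 8, 8), (2, 128, 4, 4), (2, 256, 2, 2)]
```

Folding time into the batch axis lets every context step share the same weights, which is why the per-scale channel count grows with the number of context steps.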

class mfai.pytorch.models.gan_dgmr.blocks.DBlock(input_channels=12, output_channels=12, conv_type='standard', first_relu=True, keep_same_output=False)

Bases: Module

Residual D block (discriminator-style residual block).

Parameters:
  • input_channels (int)

  • output_channels (int)

  • conv_type (Literal['standard', 'coord', '3d'])

  • first_relu (bool)

  • keep_same_output (bool)

forward(x)

Apply the D residual block.

Parameters:

x (Tensor) – a tensor with a complete observation (b, c, h, w)

Return type:

Tensor

Returns:

A feature tensor with output_channels channels.
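A minimal sketch of a residual D block follows, assuming 2D convolutions, a 1x1 shortcut projection when the channel counts differ, and 2x average-pool downsampling that `keep_same_output=True` switches off. That last point is an inference from the parameter name, and the class `SketchDBlock` is hypothetical; mfai's `DBlock` also supports `conv_type='coord'` and `'3d'`, which this sketch ignores.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchDBlock(nn.Module):
    """Hypothetical residual D block (illustration only)."""

    def __init__(self, in_ch: int, out_ch: int, first_relu: bool = True,
                 keep_same_output: bool = False) -> None:
        super().__init__()
        self.first_relu = first_relu
        self.keep_same_output = keep_same_output
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        # 1x1 projection so the shortcut matches the output channel count.
        self.shortcut = (
            nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = self.shortcut(x)
        h = F.relu(x) if self.first_relu else x
        h = self.conv2(F.relu(self.conv1(h)))
        if not self.keep_same_output:
            # Assumed behaviour: halve height and width on both branches.
            skip = F.avg_pool2d(skip, 2)
            h = F.avg_pool2d(h, 2)
        return h + skip


x = torch.randn(2, 12, 32, 32)
print(SketchDBlock(12, 24)(x).shape)                         # torch.Size([2, 24, 16, 16])
print(SketchDBlock(12, 24, keep_same_output=True)(x).shape)  # torch.Size([2, 24, 32, 32])
```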

class mfai.pytorch.models.gan_dgmr.blocks.GBlock(input_channels=12, output_channels=12, conv_type='standard', spectral_normalized_eps=0.0001)

Bases: Module

Residual generator block without upsampling.

Parameters:
  • input_channels (int)

  • output_channels (int)

  • conv_type (Literal['standard', 'coord', '3d'])

  • spectral_normalized_eps (float)

forward(x)

Apply the residual generator block to x.

Return type:

Tensor

Parameters:

x (Tensor)
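The sketch below shows one way to build a residual generator block without upsampling, in the style of the G blocks described for DGMR: batch norm, ReLU and a spectrally normalized 3x3 convolution, applied twice, plus a 1x1 shortcut when the channel counts differ. The class `SketchGBlock` is hypothetical, and passing `spectral_normalized_eps` as the `eps` argument of `torch.nn.utils.spectral_norm` is an assumption about what that parameter controls.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


class SketchGBlock(nn.Module):
    """Hypothetical residual G block without upsampling."""

    def __init__(self, in_ch: int, out_ch: int, sn_eps: float = 1e-4) -> None:
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.bn2 = nn.BatchNorm2d(out_ch)
        # sn_eps is passed to spectral_norm for numerical stability (assumption).
        self.conv1 = spectral_norm(nn.Conv2d(in_ch, out_ch, 3, padding=1), eps=sn_eps)
        self.conv2 = spectral_norm(nn.Conv2d(out_ch, out_ch, 3, padding=1), eps=sn_eps)
        self.shortcut = (
            spectral_norm(nn.Conv2d(in_ch, out_ch, 1), eps=sn_eps)
            if in_ch != out_ch
            else nn.Identity()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        skip = self.shortcut(x)
        h = self.conv1(F.relu(self.bn1(x)))
        h = self.conv2(F.relu(self.bn2(h)))
        return h + skip  # spatial size is unchanged


block = SketchGBlock(12, 6)
print(block(torch.randn(2, 12, 16, 16)).shape)  # torch.Size([2, 6, 16, 16])
```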

class mfai.pytorch.models.gan_dgmr.blocks.LBlock(input_channels=12, output_channels=12, kernel_size=3, conv_type='standard')

Bases: Module

Residual block for the latent conditioning stack.

Parameters:
  • input_channels (int)

  • output_channels (int)

  • kernel_size (int)

  • conv_type (Literal['standard', 'coord', '3d'])

forward(x)

Apply the L residual block to x.

Return type:

Tensor

Parameters:

x (Tensor)
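In DGMR, L blocks are used to widen the channel dimension of the latent stack. A sketch of one such formulation follows, assuming the shortcut keeps the input channels and appends `output_channels - input_channels` new ones from a 1x1 convolution, while the main branch applies ReLU and a `kernel_size` convolution twice. `SketchLBlock` is a hypothetical name, and the sketch requires `out_ch >= in_ch`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchLBlock(nn.Module):
    """Hypothetical L block: widens channels with a concatenating shortcut."""

    def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3) -> None:
        super().__init__()
        if out_ch < in_ch:
            raise ValueError("this sketch assumes out_ch >= in_ch")
        pad = kernel_size // 2  # keeps h and w unchanged for odd kernels
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size, padding=pad)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size, padding=pad)
        # 1x1 conv that supplies only the extra (out_ch - in_ch) channels.
        self.extra = nn.Conv2d(in_ch, out_ch - in_ch, 1) if out_ch > in_ch else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shortcut keeps the input channels untouched and appends new ones.
        skip = x if self.extra is None else torch.cat([x, self.extra(x)], dim=1)
        h = self.conv2(F.relu(self.conv1(F.relu(x))))
        return h + skip


print(SketchLBlock(8, 24, kernel_size=5)(torch.randn(1, 8, 8, 8)).shape)
# torch.Size([1, 24, 8, 8])
```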

class mfai.pytorch.models.gan_dgmr.blocks.LatentConditioningStack(input_channels=8, output_channels=768, use_attention=True)

Bases: Module

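Latent conditioning stack: turns a latent sample into a conditioning tensor through a convolution, L blocks and an optional spatial attention module.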

Parameters:
  • input_channels (int)

  • output_channels (int)

  • use_attention (bool)

forward(x)

Apply the convolution, the L blocks and, if use_attention is True, the spatial attention module to the latent sample.

Parameters:

x (Tensor) – a tensor on the target device; it is used to move the latent distribution sample onto that device.

Return type:

Tensor

Returns:

The latent conditioning tensor, with output_channels channels.
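Following the description of `x` above, the sketch below uses the input only as a reference for device and dtype (and, as an added assumption, for the batch size) when drawing a standard-normal latent sample. Everything else is hypothetical and simplified: the class `SketchLatentStack`, the fixed `latent_hw` spatial size, a plain convolution standing in for the L blocks, and `torch.nn.MultiheadAttention` over flattened spatial positions standing in for mfai's spatial attention module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchLatentStack(nn.Module):
    """Hypothetical, simplified latent conditioning stack."""

    def __init__(self, in_ch: int = 8, out_ch: int = 64, latent_hw: int = 4,
                 use_attention: bool = True) -> None:
        super().__init__()
        self.in_ch = in_ch
        self.latent_hw = latent_hw
        mid = out_ch // 2
        self.conv = nn.Conv2d(in_ch, in_ch, 3, padding=1)
        # Plain conv standing in for the L blocks.
        self.widen = nn.Conv2d(in_ch, mid, 3, padding=1)
        self.attn = (
            nn.MultiheadAttention(mid, num_heads=1, batch_first=True)
            if use_attention
            else None
        )
        self.head = nn.Conv2d(mid, out_ch, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x only supplies batch size, device and dtype; its values are unused.
        z = torch.randn(x.shape[0], self.in_ch, self.latent_hw, self.latent_hw,
                        device=x.device, dtype=x.dtype)
        z = F.relu(self.widen(self.conv(z)))
        if self.attn is not None:
            b, c, h, w = z.shape
            seq = z.flatten(2).transpose(1, 2)  # (b, h*w, c): one token per position
            attended, _ = self.attn(seq, seq, seq)
            z = z + attended.transpose(1, 2).reshape(b, c, h, w)  # residual add
        return self.head(z)


ref = torch.zeros(3, 1)
print(SketchLatentStack()(ref).shape)  # torch.Size([3, 64, 4, 4])
```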

class mfai.pytorch.models.gan_dgmr.blocks.UpsampleGBlock(input_channels=12, output_channels=12, conv_type='standard', spectral_normalized_eps=0.0001)

Bases: Module

Residual generator block with upsampling.

Parameters:
  • input_channels (int)

  • output_channels (int)

  • conv_type (Literal['standard', 'coord', '3d'])

  • spectral_normalized_eps (float)

forward(x)

Apply the residual generator block with upsampling to x.

Return type:

Tensor

Parameters:

x (Tensor)
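Upsampling can be added to the same residual pattern by doubling height and width on both branches before the convolutions that follow. The sketch assumes nearest-neighbour upsampling and spectrally normalized convolutions, with `spectral_normalized_eps` mapped onto `spectral_norm`'s `eps`; `SketchUpsampleGBlock` and `upsample2x` are hypothetical names, not mfai's API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm


def upsample2x(t: torch.Tensor) -> torch.Tensor:
    """Nearest-neighbour upsampling by a factor of 2."""
    return F.interpolate(t, scale_factor=2, mode="nearest")


class SketchUpsampleGBlock(nn.Module):
    """Hypothetical residual G block that doubles height and width."""

    def __init__(self, in_ch: int, out_ch: int, sn_eps: float = 1e-4) -> None:
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_ch)
        self.bn2 = nn.BatchNorm2d(out_ch)
        self.conv1 = spectral_norm(nn.Conv2d(in_ch, out_ch, 3, padding=1), eps=sn_eps)
        self.conv2 = spectral_norm(nn.Conv2d(out_ch, out_ch, 3, padding=1), eps=sn_eps)
        self.shortcut = spectral_norm(nn.Conv2d(in_ch, out_ch, 1), eps=sn_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shortcut: upsample, then 1x1 conv to match channels.
        skip = self.shortcut(upsample2x(x))
        # Main branch: BN -> ReLU -> upsample -> conv -> BN -> ReLU -> conv.
        h = self.conv1(upsample2x(F.relu(self.bn1(x))))
        h = self.conv2(F.relu(self.bn2(h)))
        return h + skip


block = SketchUpsampleGBlock(16, 8)
print(block(torch.randn(2, 16, 8, 8)).shape)  # torch.Size([2, 8, 16, 16])
```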