blocks
Modules for generator blocks.
class mfai.pytorch.models.gan_dgmr.blocks.ContextConditioningStack(input_channels=1, output_channels=768, num_context_steps=4, conv_type='standard')
Bases: Module
Context conditioning stack.
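- Parameters:
input_channels (int)
output_channels (int)
num_context_steps (int)
conv_type (Literal['standard', 'coord', '3d'])
forward(x)
Generate the conditioning representation as a tuple of four tensors.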
- Return type:
tuple[Tensor, Tensor, Tensor, Tensor]
- Parameters:
x (Tensor)
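For orientation, below is a simplified PyTorch sketch of what a context conditioning stack does. Each context frame is processed separately by shared downsampling stages, and the per-frame features at each scale are stacked along channels and mixed into one conditioning tensor per scale. It assumes an input layout of (batch, time, channels, height, width) and uses stride-2 convolutions in place of D blocks. The class name `TinyContextStack` and its `widths` argument are hypothetical; this is not the mfai implementation.

```python
import torch
from torch import nn


class TinyContextStack(nn.Module):
    """Hypothetical, simplified context conditioning stack (not the mfai class)."""

    def __init__(self, input_channels: int = 1, num_context_steps: int = 4,
                 widths: tuple[int, ...] = (8, 16, 32, 64)) -> None:
        super().__init__()
        # Space-to-depth: (C, H, W) -> (4C, H/2, W/2)
        self.space2depth = nn.PixelUnshuffle(downscale_factor=2)
        self.down = nn.ModuleList()
        self.mix = nn.ModuleList()
        in_ch = 4 * input_channels
        for width in widths:
            # Stride-2 conv stands in for a downsampling D block (assumption)
            self.down.append(nn.Conv2d(in_ch, width, kernel_size=3, stride=2, padding=1))
            # Mixing conv merges the stacked time steps at this scale
            self.mix.append(nn.Conv2d(num_context_steps * width, width, kernel_size=3, padding=1))
            in_ch = width

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, ...]:
        # x: (batch, time, channels, height, width); time must equal num_context_steps
        per_scale: list[list[torch.Tensor]] = [[] for _ in self.down]
        for step in range(x.shape[1]):
            h = self.space2depth(x[:, step])
            for i, block in enumerate(self.down):
                h = torch.relu(block(h))
                per_scale[i].append(h)
        # Concatenate the time steps along channels, then mix each scale
        return tuple(torch.relu(mix(torch.cat(feats, dim=1)))
                     for mix, feats in zip(self.mix, per_scale))


stack = TinyContextStack(input_channels=1, num_context_steps=4)
outs = stack(torch.randn(2, 4, 1, 64, 64))
```

Stacking all context frames before the mixing convolution is what lets every scale of the output see the whole context window.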
class mfai.pytorch.models.gan_dgmr.blocks.DBlock(input_channels=12, output_channels=12, conv_type='standard', first_relu=True, keep_same_output=False)
Bases: Module
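Residual D block.
- Parameters:
input_channels (int)
output_channels (int)
conv_type (Literal['standard', 'coord', '3d'])
first_relu (bool)
keep_same_output (bool)
forward(x)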
Apply the D residual block.
- Parameters:
x (Tensor) – a tensor with a complete observation (b, c, h, w)
- Return type:
Tensor
- Returns:
A feature tensor with output_channels channels.
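A minimal sketch of a residual downsampling block in the spirit of a D block, assuming 2D inputs, spectrally normalized convolutions and 2x2 average pooling applied to both the main path and the shortcut. `SketchDBlock` is a hypothetical name and the layer layout is an assumption, not the mfai source; it only illustrates how first_relu and keep_same_output could affect the computation.

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm


class SketchDBlock(nn.Module):
    """Hypothetical residual downsampling block (not the mfai class)."""

    def __init__(self, input_channels: int, output_channels: int,
                 first_relu: bool = True, keep_same_output: bool = False) -> None:
        super().__init__()
        self.first_relu = first_relu
        self.keep_same_output = keep_same_output
        # Shortcut 1x1 conv matches the channel count
        self.conv_1x1 = spectral_norm(nn.Conv2d(input_channels, output_channels, kernel_size=1))
        self.conv_a = spectral_norm(nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1))
        self.conv_b = spectral_norm(nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1))
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.conv_1x1(x)
        if self.first_relu:
            x = torch.relu(x)
        x = self.conv_b(torch.relu(self.conv_a(x)))
        if not self.keep_same_output:
            # Halve height and width on both paths so they can be summed
            x = self.pool(x)
            shortcut = self.pool(shortcut)
        return x + shortcut


down = SketchDBlock(4, 8)(torch.randn(2, 4, 32, 32))
same = SketchDBlock(4, 8, keep_same_output=True)(torch.randn(2, 4, 32, 32))
```

Pooling both paths is what keeps the residual sum shape-consistent; with keep_same_output=True the sketch skips pooling and preserves the spatial size.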
class mfai.pytorch.models.gan_dgmr.blocks.GBlock(input_channels=12, output_channels=12, conv_type='standard', spectral_normalized_eps=0.0001)
Bases: Module
Residual generator block without upsampling.
- Parameters:
input_channels (int)
output_channels (int)
conv_type (Literal['standard', 'coord', '3d'])
spectral_normalized_eps (float)
forward(x)
Apply the residual block (no upsampling).
- Return type:
Tensor
- Parameters:
x (Tensor)
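The sketch below shows one common way to build a residual generator block without upsampling: BatchNorm, ReLU and two spectrally normalized 3x3 convolutions on the main path, plus a 1x1 shortcut convolution when the channel count changes. It assumes that spectral_normalized_eps is passed as the `eps` of `torch.nn.utils.spectral_norm`; `SketchGBlock` is hypothetical and not the mfai implementation.

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm


class SketchGBlock(nn.Module):
    """Hypothetical residual generator block without upsampling (not the mfai class)."""

    def __init__(self, input_channels: int, output_channels: int,
                 spectral_normalized_eps: float = 1e-4) -> None:
        super().__init__()
        eps = spectral_normalized_eps  # assumed to map to spectral_norm's eps
        self.output_channels = output_channels
        self.bn1 = nn.BatchNorm2d(input_channels)
        self.bn2 = nn.BatchNorm2d(input_channels)
        # 1x1 shortcut conv, only used when the channel count changes
        self.conv_1x1 = spectral_norm(nn.Conv2d(input_channels, output_channels, 1), eps=eps)
        self.conv_a = spectral_norm(nn.Conv2d(input_channels, input_channels, 3, padding=1), eps=eps)
        self.conv_b = spectral_norm(nn.Conv2d(input_channels, output_channels, 3, padding=1), eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.conv_1x1(x) if x.shape[1] != self.output_channels else x
        # Main path: BN -> ReLU -> 3x3 conv -> BN -> ReLU -> 3x3 conv
        h = self.conv_a(torch.relu(self.bn1(x)))
        h = self.conv_b(torch.relu(self.bn2(h)))
        return h + shortcut


narrowed = SketchGBlock(8, 4)(torch.randn(2, 8, 16, 16))
kept = SketchGBlock(8, 8)(torch.randn(2, 8, 16, 16))
```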
class mfai.pytorch.models.gan_dgmr.blocks.LBlock(input_channels=12, output_channels=12, kernel_size=3, conv_type='standard')
Bases: Module
Residual block for the Latent Stack.
- Parameters:
input_channels (int)
output_channels (int)
kernel_size (int)
conv_type (Literal['standard', 'coord', '3d'])
forward(x)
Apply the L residual block to this tensor.
- Return type:
Tensor
- Parameters:
x (Tensor)
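A minimal sketch of a channel-growing residual block, assuming the shortcut keeps x and concatenates output_channels - input_channels extra channels produced by a 1x1 convolution, while the main path applies ReLU and two kxk convolutions. `SketchLBlock` is hypothetical and this layout is an assumption, not the mfai source.

```python
import torch
from torch import nn


class SketchLBlock(nn.Module):
    """Hypothetical L block that grows channels residually (not the mfai class)."""

    def __init__(self, input_channels: int, output_channels: int, kernel_size: int = 3) -> None:
        super().__init__()
        if output_channels < input_channels:
            raise ValueError("this sketch only supports output_channels >= input_channels")
        padding = kernel_size // 2  # keeps height and width unchanged for odd kernels
        # Extra shortcut channels, only needed when the block widens the input
        self.conv_1x1 = (nn.Conv2d(input_channels, output_channels - input_channels, 1)
                         if output_channels > input_channels else None)
        self.conv_a = nn.Conv2d(input_channels, output_channels, kernel_size, padding=padding)
        self.conv_b = nn.Conv2d(output_channels, output_channels, kernel_size, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.conv_1x1 is not None:
            # Shortcut: keep x and append the extra channels from the 1x1 conv
            shortcut = torch.cat([x, self.conv_1x1(x)], dim=1)
        else:
            shortcut = x
        h = self.conv_b(torch.relu(self.conv_a(torch.relu(x))))
        return h + shortcut


widened = SketchLBlock(8, 24)(torch.randn(2, 8, 8, 8))
unchanged = SketchLBlock(8, 8)(torch.randn(2, 8, 8, 8))
```

Concatenating x into the shortcut preserves the original channels unchanged, so the block only has to learn the additional ones.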
class mfai.pytorch.models.gan_dgmr.blocks.LatentConditioningStack(input_channels=8, output_channels=768, use_attention=True)
Bases: Module
Latent conditioning stack class.
- Parameters:
input_channels (int)
output_channels (int)
use_attention (bool)
forward(x)
Apply a convolution, the L blocks and, when use_attention is True, the spatial attention module.
- Parameters:
x (Tensor) – tensor on the correct device, to move over the latent distribution
- Return type:
Tensor
- Returns:
The latent conditioning tensor.
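The sketch below illustrates the flow described above under explicit assumptions: the latent is drawn from N(0, 1) with one sample per batch element, x only supplies batch size, device and dtype, 1x1 convolutions stand in for the L blocks, and the standard `torch.nn.MultiheadAttention` over spatial positions is swapped in for the spatial attention module. `SketchLatentStack` and its `spatial_size` argument are hypothetical; this is not the mfai implementation.

```python
import torch
from torch import nn


class SketchLatentStack(nn.Module):
    """Hypothetical latent conditioning stack (not the mfai class)."""

    def __init__(self, input_channels: int = 8, output_channels: int = 64,
                 use_attention: bool = True, spatial_size: tuple[int, int] = (8, 8)) -> None:
        super().__init__()
        self.input_channels = input_channels
        self.spatial_size = spatial_size
        self.conv = nn.Conv2d(input_channels, input_channels, 3, padding=1)
        # 1x1 convs stand in for the residual L blocks that widen the channels
        self.widen1 = nn.Conv2d(input_channels, output_channels // 4, 1)
        self.widen2 = nn.Conv2d(output_channels // 4, output_channels, 1)
        # Standard multi-head attention over spatial positions, swapped in
        # for the spatial attention module
        self.attention = (
            nn.MultiheadAttention(output_channels // 4, num_heads=1, batch_first=True)
            if use_attention else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x only supplies batch size, device and dtype; the latent is N(0, 1) noise
        z = torch.randn(x.shape[0], self.input_channels, *self.spatial_size,
                        device=x.device, dtype=x.dtype)
        h = torch.relu(self.widen1(self.conv(z)))
        if self.attention is not None:
            b, c, height, width = h.shape
            seq = h.flatten(2).transpose(1, 2)  # (b, height * width, c)
            attended, _ = self.attention(seq, seq, seq)
            # Residual connection back to the (b, c, height, width) layout
            h = h + attended.transpose(1, 2).reshape(b, c, height, width)
        return self.widen2(h)


with_attn = SketchLatentStack()(torch.zeros(2, 1))
without_attn = SketchLatentStack(use_attention=False)(torch.zeros(3, 1))
```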
class mfai.pytorch.models.gan_dgmr.blocks.UpsampleGBlock(input_channels=12, output_channels=12, conv_type='standard', spectral_normalized_eps=0.0001)
Bases: Module
Residual generator block with upsampling.
- Parameters:
input_channels (int)
output_channels (int)
conv_type (Literal['standard', 'coord', '3d'])
spectral_normalized_eps (float)
forward(x)
Apply the residual block with upsampling.
- Return type:
Tensor
- Parameters:
x (Tensor)
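A minimal sketch of a residual generator block with 2x upsampling, assuming nearest-neighbour upsampling on both paths, BatchNorm and ReLU before the first 3x3 convolution, and spectral_normalized_eps passed as the `eps` of `torch.nn.utils.spectral_norm`. `SketchUpsampleGBlock` is hypothetical and not the mfai source.

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm


class SketchUpsampleGBlock(nn.Module):
    """Hypothetical residual generator block with 2x upsampling (not the mfai class)."""

    def __init__(self, input_channels: int, output_channels: int,
                 spectral_normalized_eps: float = 1e-4) -> None:
        super().__init__()
        eps = spectral_normalized_eps  # assumed to map to spectral_norm's eps
        self.upsample = nn.Upsample(scale_factor=2, mode="nearest")
        self.bn1 = nn.BatchNorm2d(input_channels)
        self.bn2 = nn.BatchNorm2d(input_channels)
        self.conv_1x1 = spectral_norm(nn.Conv2d(input_channels, output_channels, 1), eps=eps)
        self.conv_a = spectral_norm(nn.Conv2d(input_channels, input_channels, 3, padding=1), eps=eps)
        self.conv_b = spectral_norm(nn.Conv2d(input_channels, output_channels, 3, padding=1), eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Shortcut: upsample, then match channels with a 1x1 conv
        shortcut = self.conv_1x1(self.upsample(x))
        # Main path: BN -> ReLU -> upsample -> 3x3 conv -> BN -> ReLU -> 3x3 conv
        h = self.conv_a(self.upsample(torch.relu(self.bn1(x))))
        h = self.conv_b(torch.relu(self.bn2(h)))
        return h + shortcut


up = SketchUpsampleGBlock(16, 8)(torch.randn(2, 16, 8, 8))
```

Upsampling both paths before the sum doubles height and width while the convolutions set the channel count to output_channels.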