conv_gru
Implementation of Conv GRU and cell module.
class mfai.pytorch.models.gan_dgmr.layers.conv_gru.ConvGRU(input_channels, output_channels, kernel_size=3, sn_eps=0.0001)[source]
Bases: Module
Wrapper around ConvGRUCell that replaces tf.static_rnn from the TensorFlow implementation.
- Parameters:
input_channels (int)
output_channels (int)
kernel_size (int)
sn_eps (float)
forward(x, hidden_state)[source]
Apply the cell's forward pass at each timestep and return the outputs stacked along the time dimension.
- Parameters:
x (Tensor) – Tensor from the conditioning stack or from the previous ConvGRU. (timesteps, B, input_channels, H, W)
hidden_state (Tensor) – Tensor from the latent conditioning stack. (B, latent_channels, H, W)
- Return type:
Tensor
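The wrapper's behaviour can be sketched as a loop over the leading time axis, feeding each step's output back in as the next hidden state. This is an illustrative sketch, not the library's code; `cell` stands for any ConvGRUCell-like callable, and the function name `conv_gru_forward` is hypothetical.

```python
import torch
from torch import Tensor


def conv_gru_forward(cell, x: Tensor, hidden_state: Tensor) -> Tensor:
    # Hypothetical sketch of the ConvGRU wrapper semantics.
    # x: (timesteps, B, input_channels, H, W)
    # hidden_state: (B, latent_channels, H, W)
    outputs = []
    for step in x:  # iterate over the leading time axis
        hidden_state = cell(step, hidden_state)
        outputs.append(hidden_state)
    # Stack per-timestep states: (timesteps, B, output_channels, H, W)
    return torch.stack(outputs, dim=0)
```

The returned tensor keeps the same leading time dimension as the input, so several ConvGRU layers can be chained directly.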
class mfai.pytorch.models.gan_dgmr.layers.conv_gru.ConvGRUCell(input_channels, output_channels, kernel_size=3, sn_eps=0.0001)[source]
Bases: Module
A ConvGRU implementation.
- Parameters:
input_channels (int)
output_channels (int)
kernel_size (int)
sn_eps (float)
forward(x, prev_state)[source]
Conv GRU forward pass, returning the updated state.
The channel dimensions of the two inputs should sum to input_channels,
i.e. C1 + C2 = input_channels.
- Parameters:
x (Tensor) – Input tensor of shape (B, C1, H, W).
prev_state (Tensor) – Previous state, tensor of shape (B, C2, H, W).
- Returns:
Output of the ConvGRU, tensor of shape (B, output_channels, H, W).
- Return type:
Tensor
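A single cell update can be sketched with standard GRU gates built from 2D convolutions. This is a minimal sketch of the technique, not the library's implementation: it omits the spectral normalisation implied by `sn_eps`, and it assumes output_channels equals the state's channel count C2 so the gated blend is well-defined. The class name `ConvGRUCellSketch` is hypothetical.

```python
import torch
from torch import Tensor, nn


class ConvGRUCellSketch(nn.Module):
    """Illustrative ConvGRU cell: GRU gating with Conv2d layers (no spectral norm)."""

    def __init__(self, input_channels: int, output_channels: int, kernel_size: int = 3):
        super().__init__()
        pad = kernel_size // 2  # keep spatial size H, W unchanged
        self.read_gate = nn.Conv2d(input_channels, output_channels, kernel_size, padding=pad)
        self.update_gate = nn.Conv2d(input_channels, output_channels, kernel_size, padding=pad)
        self.candidate_conv = nn.Conv2d(input_channels, output_channels, kernel_size, padding=pad)

    def forward(self, x: Tensor, prev_state: Tensor) -> Tensor:
        # x: (B, C1, H, W), prev_state: (B, C2, H, W), with C1 + C2 = input_channels
        xh = torch.cat([x, prev_state], dim=1)
        r = torch.sigmoid(self.read_gate(xh))      # read (reset) gate
        u = torch.sigmoid(self.update_gate(xh))    # update gate
        # Candidate state from the input and the gated previous state
        c = torch.tanh(self.candidate_conv(torch.cat([x, r * prev_state], dim=1)))
        # Blend candidate and previous state: (B, output_channels, H, W)
        return u * c + (1 - u) * prev_state
```

With C1 = 4 input channels and a C2 = 8 channel state, the cell takes input_channels = 12 and returns a tensor of shape (B, 8, H, W).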