conv_gru

Implementation of the ConvGRU module and its cell.

class mfai.pytorch.models.gan_dgmr.layers.conv_gru.ConvGRU(input_channels, output_channels, kernel_size=3, sn_eps=0.0001)

Bases: Module

Wrapper around a ConvGRU cell that replaces tf.static_rnn from the original TensorFlow implementation.

Parameters:
  • input_channels (int)

  • output_channels (int)

  • kernel_size (int)

  • sn_eps (float)

forward(x, hidden_state)

Apply the cell's forward function at each timestep, then return the outputs stacked along the time dimension.

Parameters:
  • x (Tensor) – Tensor from the conditioning stack or from the previous ConvGRU, of shape (timesteps, B, input_channels, H, W).

  • hidden_state (Tensor) – Tensor from the latent conditioning stack, of shape (B, latent_channels, H, W).

Return type:

Tensor
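The loop that stands in for tf.static_rnn can be sketched as follows. This is a minimal illustration, not the mfai source: `unroll` and `ToyCell` are hypothetical names, and `ToyCell` is a deliberately simple stand-in (a 1x1 convolution plus tanh) rather than the real ConvGRU cell. It assumes, as the cell documentation states, that each cell call returns the new state, which is both that step's output and the next step's hidden state.

```python
import torch
from torch import nn


def unroll(cell: nn.Module, x: torch.Tensor, hidden_state: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: run `cell` over the time axis of `x` and stack outputs.

    x: (timesteps, B, C_in, H, W); hidden_state: (B, C_h, H, W).
    """
    outputs = []
    for x_t in x:  # iterating a tensor yields its slices along dim 0 (time)
        hidden_state = cell(x_t, hidden_state)  # the new state feeds the next step
        outputs.append(hidden_state)
    return torch.stack(outputs, dim=0)  # (timesteps, B, C_h, H, W)


class ToyCell(nn.Module):
    """Hypothetical stand-in cell: 1x1 conv over [x, h] followed by tanh."""

    def __init__(self, in_channels: int, hidden_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels + hidden_channels, hidden_channels, kernel_size=1)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        return torch.tanh(self.conv(torch.cat([x, h], dim=1)))


cell = ToyCell(in_channels=4, hidden_channels=8)
x = torch.randn(5, 2, 4, 16, 16)  # 5 timesteps, batch of 2
h0 = torch.zeros(2, 8, 16, 16)
out = unroll(cell, x, h0)
print(out.shape)  # torch.Size([5, 2, 8, 16, 16])
```

Unrolling with an explicit Python loop keeps the recurrence visible and lets autograd track every step, which is the usual PyTorch substitute for TensorFlow's static RNN helpers.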

class mfai.pytorch.models.gan_dgmr.layers.conv_gru.ConvGRUCell(input_channels, output_channels, kernel_size=3, sn_eps=0.0001)

Bases: Module

A single ConvGRU cell, applied to one timestep.

Parameters:
  • input_channels (int)

  • output_channels (int)

  • kernel_size (int)

  • sn_eps (float)

forward(x, prev_state)

ConvGRU forward pass, returning the current state. The channel dimensions of the two inputs must sum to input_channels, i.e. C1 + C2 = input_channels.

Parameters:
  • x (Tensor) – Input tensor of shape (B, C1, H, W).

  • prev_state (Tensor) – Previous state, tensor of shape (B, C2, H, W).

Returns:

Output of the ConvGRU cell, tensor of shape (B, output_channels, H, W).

Return type:

Tensor
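A cell that honours the channel contract above can be sketched like this. It is a hypothetical illustration (`ConvGRUCellSketch` is not the mfai class) built on several assumptions: every gate convolves the concatenation of x and prev_state, so each convolution takes input_channels = C1 + C2 channels; sn_eps is the epsilon passed to spectral normalization, as its name suggests; the candidate uses tanh as in a textbook GRU, although the real module may use a different activation; and prev_state must have output_channels channels so the gated blend is well defined.

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm


class ConvGRUCellSketch(nn.Module):
    """Hypothetical ConvGRU cell sketch; not the mfai implementation."""

    def __init__(self, input_channels: int, output_channels: int,
                 kernel_size: int = 3, sn_eps: float = 0.0001) -> None:
        super().__init__()
        self.input_channels = input_channels
        padding = kernel_size // 2  # keeps H and W unchanged for odd kernel sizes

        def conv() -> nn.Module:
            # Assumption: sn_eps is the spectral-norm epsilon.
            return spectral_norm(
                nn.Conv2d(input_channels, output_channels, kernel_size, padding=padding),
                eps=sn_eps,
            )

        self.read_gate_conv = conv()    # reset ("read") gate
        self.update_gate_conv = conv()  # update gate
        self.output_conv = conv()       # candidate state

    def forward(self, x: torch.Tensor, prev_state: torch.Tensor) -> torch.Tensor:
        # Documented contract: C1 + C2 == input_channels.
        if x.shape[1] + prev_state.shape[1] != self.input_channels:
            raise ValueError("channels of x and prev_state must sum to input_channels")
        xh = torch.cat([x, prev_state], dim=1)
        read_gate = torch.sigmoid(self.read_gate_conv(xh))
        update_gate = torch.sigmoid(self.update_gate_conv(xh))
        # Gate the previous state before computing the candidate.
        gated_input = torch.cat([x, read_gate * prev_state], dim=1)
        candidate = torch.tanh(self.output_conv(gated_input))
        # Blend old state and candidate (assumes C2 == output_channels).
        return update_gate * prev_state + (1.0 - update_gate) * candidate


cell = ConvGRUCellSketch(input_channels=4 + 8, output_channels=8)
x = torch.randn(2, 4, 16, 16)             # (B, C1, H, W)
prev_state = torch.zeros(2, 8, 16, 16)    # (B, C2, H, W)
new_state = cell(x, prev_state)
print(new_state.shape)  # torch.Size([2, 8, 16, 16])
```

Because the returned tensor is also the next state, this sketch can be passed directly as the cell argument of a time-unrolling loop.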