discriminators

Discriminators.

class mfai.pytorch.models.gan_dgmr.discriminators.Discriminator(input_channels=12, num_spatial_frames=8, conv_type='standard', temporal_num_layers=3, spatial_num_layers=4)

Bases: Module

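Full DGMR discriminator: applies a spatial discriminator and a temporal discriminator to the same input sequence and combines their outputs.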

Parameters:
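  • input_channels (int) – number of channels c in the input tensor.

  • num_spatial_frames (int) – number of frames scored by the spatial discriminator.

  • conv_type (Literal['standard', 'coord', '3d']) – type of convolution layer to use.

  • temporal_num_layers (int) – number of layers in the temporal discriminator.

  • spatial_num_layers (int) – number of layers in the spatial discriminator.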
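The conv_type choices are not explained on this page. Assuming, as the names suggest, that 'standard' selects an ordinary 2D convolution, '3d' a 3D convolution, and 'coord' a CoordConv-style layer (a 2D convolution that also receives the pixel coordinates as extra input channels), the selection could be sketched as follows. `make_conv_layer` and `CoordConv2d` are hypothetical names for illustration, not part of mfai's documented API.

```python
import torch
from torch import nn


class CoordConv2d(nn.Module):
    """Hypothetical CoordConv-style layer: appends normalised y/x coordinate maps
    to the input before a regular Conv2d (an assumption about what 'coord' means)."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int, **kwargs):
        super().__init__()
        # Two extra input channels carry the y and x coordinates.
        self.conv = nn.Conv2d(in_channels + 2, out_channels, kernel_size, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, _, h, w = x.shape
        ys = torch.linspace(-1.0, 1.0, h, device=x.device, dtype=x.dtype)
        xs = torch.linspace(-1.0, 1.0, w, device=x.device, dtype=x.dtype)
        yy, xx = torch.meshgrid(ys, xs, indexing="ij")         # each (h, w)
        coords = torch.stack([yy, xx]).expand(b, -1, -1, -1)   # (b, 2, h, w)
        return self.conv(torch.cat([x, coords], dim=1))


def make_conv_layer(conv_type: str) -> type:
    """Hypothetical helper mapping a conv_type string to a convolution class."""
    layers = {"standard": nn.Conv2d, "coord": CoordConv2d, "3d": nn.Conv3d}
    if conv_type not in layers:
        raise ValueError(f"unknown conv_type: {conv_type!r}")
    return layers[conv_type]


if __name__ == "__main__":
    conv = make_conv_layer("coord")(4, 8, kernel_size=3, padding=1)
    print(conv(torch.randn(2, 4, 16, 16)).shape)  # torch.Size([2, 8, 16, 16])
```

Returning the class rather than an instance leaves each discriminator free to choose its own channel counts and kernel sizes when it builds its layers.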

forward(x)

Runs the spatial and temporal discriminators on x and combines their outputs (the spatial and temporal losses) into a single tensor, which is returned.

Parameters:

x (Tensor) – tensor holding a complete observation sequence, of shape (b, t, c, h, w).

Return type:

Tensor
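The docstring does not say how the two outputs are "mixed". The sketch below assumes the simplest reading: each sub-discriminator returns one score per sample and the two scores are concatenated along dimension 1. `ToyDiscriminator` and the placeholder `MeanScore` are hypothetical and only illustrate the data flow, not mfai's actual layers.

```python
import torch
from torch import nn


class MeanScore(nn.Module):
    """Hypothetical placeholder sub-discriminator: one score per sample,
    the mean over all non-batch dimensions."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.flatten(start_dim=1).mean(dim=1, keepdim=True)  # (b, 1)


class ToyDiscriminator(nn.Module):
    """Hypothetical stand-in for Discriminator: runs two sub-discriminators on the
    same (b, t, c, h, w) tensor and concatenates their scores (assumed mixing rule)."""

    def __init__(self, spatial: nn.Module, temporal: nn.Module):
        super().__init__()
        self.spatial = spatial
        self.temporal = temporal

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        spatial_score = self.spatial(x)    # (b, 1)
        temporal_score = self.temporal(x)  # (b, 1)
        return torch.cat([spatial_score, temporal_score], dim=1)  # (b, 2)


if __name__ == "__main__":
    x = torch.randn(2, 22, 1, 32, 32)  # (b, t, c, h, w)
    print(ToyDiscriminator(MeanScore(), MeanScore())(x).shape)  # torch.Size([2, 2])
```

With sub-discriminators that each return a (b, 1) score, the combined output under this assumption has shape (b, 2).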

class mfai.pytorch.models.gan_dgmr.discriminators.SpatialDiscriminator(input_channels=12, num_timesteps=8, num_layers=4, conv_type='standard')

Bases: Module

Spatial discriminator: scores the realism of individual frames of a sequence.

Parameters:
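  • input_channels (int) – number of channels c in the input tensor.

  • num_timesteps (int) – number of frames (time steps) scored per sample.

  • num_layers (int) – number of layers in the discriminator.

  • conv_type (Literal['standard', 'coord', '3d']) – type of convolution layer to use.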

forward(x)
Parameters:

x (Tensor) – tensor of predictions or observations of shape (b, t, c, h, w).

Return type:

Tensor
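In the DGMR design, the spatial discriminator judges individual frames drawn at random from the sequence. The sketch below illustrates that idea under assumptions this page does not confirm: num_timesteps frames are sampled without replacement, each one goes through the same small 2D CNN, and the per-frame scores are summed per sample. `ToySpatialDiscriminator`, its `hidden` width and its layer layout are hypothetical; only the parameter names `input_channels`, `num_timesteps` and `num_layers` mirror the signature above.

```python
import torch
from torch import nn


class ToySpatialDiscriminator(nn.Module):
    """Hypothetical sketch: scores num_timesteps randomly chosen frames with a
    shared 2D CNN and sums the per-frame scores (not mfai's implementation)."""

    def __init__(self, input_channels: int, num_timesteps: int = 8,
                 num_layers: int = 4, hidden: int = 16):
        super().__init__()
        self.num_timesteps = num_timesteps
        layers, in_ch = [], input_channels
        for _ in range(num_layers):
            # Each stride-2 conv halves the spatial resolution.
            layers += [nn.Conv2d(in_ch, hidden, kernel_size=3, stride=2, padding=1), nn.ReLU()]
            in_ch = hidden
        self.frame_net = nn.Sequential(
            *layers,
            nn.AdaptiveAvgPool2d(1),  # (b * n, in_ch, 1, 1)
            nn.Flatten(),             # (b * n, in_ch)
            nn.Linear(in_ch, 1),      # (b * n, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, c, h, w = x.shape
        # Sample frame indices without replacement (assumed; keeps at most t frames).
        idx = torch.randperm(t, device=x.device)[: self.num_timesteps]
        frames = x[:, idx]                       # (b, n, c, h, w)
        n = frames.shape[1]
        # Fold time into the batch so one 2D network scores every frame at once.
        scores = self.frame_net(frames.reshape(b * n, c, h, w))  # (b * n, 1)
        return scores.reshape(b, n).sum(dim=1, keepdim=True)     # (b, 1)
```

Folding the time axis into the batch axis (b * n) is what lets a single 2D network score every sampled frame in one call.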

class mfai.pytorch.models.gan_dgmr.discriminators.TemporalDiscriminator(input_channels=12, num_layers=3, conv_type='standard')

Bases: Module

Temporal discriminator: scores the realism of a sequence as a whole, i.e. its consistency over time.

Parameters:
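  • input_channels (int) – number of channels c in the input tensor.

  • num_layers (int) – number of layers in the discriminator.

  • conv_type (Literal['standard', 'coord', '3d']) – type of convolution layer to use.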

forward(x)
Parameters:

x (Tensor) – tensor of predictions or observations of shape (b, t, c, h, w).

Return type:

Tensor
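In the DGMR design, the temporal discriminator looks at the sequence as a whole, which 3D convolutions over (t, h, w) make natural. PyTorch's nn.Conv3d expects (batch, channels, depth, height, width), so the documented (b, t, c, h, w) input has to be permuted first. `ToyTemporalDiscriminator` is a hypothetical sketch of that data flow; its width, strides and pooling are assumptions, not mfai's layers.

```python
import torch
from torch import nn


class ToyTemporalDiscriminator(nn.Module):
    """Hypothetical sketch: stacked 3D convolutions over the (t, h, w) volume,
    pooled to one score per sample (not mfai's implementation)."""

    def __init__(self, input_channels: int, num_layers: int = 3, hidden: int = 16):
        super().__init__()
        layers, in_ch = [], input_channels
        for _ in range(num_layers):
            # Each stride-2 conv roughly halves time, height and width.
            layers += [nn.Conv3d(in_ch, hidden, kernel_size=3, stride=2, padding=1), nn.ReLU()]
            in_ch = hidden
        self.net = nn.Sequential(
            *layers,
            nn.AdaptiveAvgPool3d(1),  # (b, in_ch, 1, 1, 1)
            nn.Flatten(),             # (b, in_ch)
            nn.Linear(in_ch, 1),      # (b, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Conv3d expects (b, c, depth, h, w), so move time into the depth slot.
        x = x.permute(0, 2, 1, 3, 4)  # (b, t, c, h, w) -> (b, c, t, h, w)
        return self.net(x)
```

Because time is treated as a convolved dimension, the score depends on how frames evolve across the sequence, not only on each frame in isolation.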