gan_dgmr
This module defines the DGMRLightningModule class, which is a PyTorch Lightning module for training a Deep Generative Model of Radar (DGMR). The model is designed for forecasting future radar images using a Generative Adversarial Network (GAN) architecture.
The DGMRLightningModule includes:
- Initialization of model parameters and components, including the generator and discriminator.
- A forward pass method that generates predictions from input radar data.
- Discriminator and generator training steps, including loss calculations.
- Configuration of the optimizers for the generator and the discriminator.
The implementation is inspired by the Skillful Nowcasting GAN from OpenClimateFix and has been modified to support multiple satellite channels.
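This page does not spell out the loss functions. The original DGMR paper trains the generator and the discriminators with hinge losses. Assuming this module does the same (not confirmed here), the two adversarial terms can be sketched in plain Python over lists of discriminator scores. The function names are hypothetical, and the real module works on tensors rather than lists:

```python
from statistics import fmean
from typing import Sequence


def discriminator_hinge_loss(score_real: Sequence[float], score_fake: Sequence[float]) -> float:
    """Hinge loss for the discriminator: mean(relu(1 - D(real))) + mean(relu(1 + D(fake)))."""
    # Real samples are pushed towards scores >= 1.
    real_term = fmean(max(0.0, 1.0 - s) for s in score_real)
    # Generated samples are pushed towards scores <= -1.
    fake_term = fmean(max(0.0, 1.0 + s) for s in score_fake)
    return real_term + fake_term


def generator_hinge_loss(score_fake: Sequence[float]) -> float:
    """Adversarial generator loss: -mean(D(fake)), i.e. raise the scores of generated samples."""
    return -fmean(score_fake)
```

For example, real scores `[2.0, 0.0]` and fake scores `[-2.0, 0.0]` give a discriminator loss of 1.0: only the two samples scored 0.0 sit inside the margin.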
- class mfai.pytorch.lightning_modules.gan_dgmr.DGMRLightningModule(forecast_steps=18, input_channels=1, gen_lr=5e-05, disc_lr=0.0002, conv_type='standard', grid_lambda=20.0, beta1=0.0, beta2=0.999, latent_channels=768, context_channels=384, generation_steps=6, precip_weight_cap=24.0, use_attention=True, temporal_num_layers=3, spatial_num_layers=4, **kwargs)
Bases: LightningModule
PyTorch Lightning module to train the GAN.
- Parameters:
forecast_steps (int)
input_channels (int)
gen_lr (float)
disc_lr (float)
conv_type (Literal['standard', 'coord', '3d'])
grid_lambda (float)
beta1 (float)
beta2 (float)
latent_channels (int)
context_channels (int)
generation_steps (int)
precip_weight_cap (float)
use_attention (bool)
temporal_num_layers (int)
spatial_num_layers (int)
kwargs (Any)
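The parameter names grid_lambda, precip_weight_cap and generation_steps suggest a DGMR-style grid cell regularizer. Under that reading, the generator produces generation_steps samples for the same input. Those samples are averaged pixel by pixel and compared with the target. The comparison is weighted so that heavier rain counts more, with the weight capped at precip_weight_cap. The result is then scaled by grid_lambda and added to the adversarial generator loss. The sketch below assumes exactly that, including the weighting rule (targets clipped to [0, precip_weight_cap]). It works on flattened pixel lists, and grid_cell_regularizer and generator_loss are hypothetical names, not the module's code:

```python
from statistics import fmean
from typing import Sequence


def grid_cell_regularizer(
    samples: Sequence[Sequence[float]],
    target: Sequence[float],
    precip_weight_cap: float = 24.0,
) -> float:
    """Weighted L1 distance between the sample mean and the target (a sketch).

    `samples` holds `generation_steps` generator outputs for one input, each
    flattened to one value per pixel; `target` is the observed field.
    """
    n_pixels = len(target)
    # Average the generated samples pixel by pixel.
    sample_mean = [fmean(sample[i] for sample in samples) for i in range(n_pixels)]
    # Assumed weighting: heavier rain counts more, capped at precip_weight_cap.
    weights = [min(max(y, 0.0), precip_weight_cap) for y in target]
    return fmean(abs(m - y) * w for m, y, w in zip(sample_mean, target, weights))


def generator_loss(adversarial: float, regularizer: float, grid_lambda: float = 20.0) -> float:
    """Total generator objective: adversarial term plus the weighted regularizer."""
    return adversarial + grid_lambda * regularizer
```

Because the penalty is applied to the mean of several samples rather than to each sample, individual forecasts remain free to differ from one another while their average is pulled towards the observations.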
- forward(x, mask=None)
Apply the generator to the input tensor x.
- Parameters:
x (NamedTensor) – The past observations. NamedTensor of shape (B T H W C) with a feature ‘rain’.
mask (Tensor) – Optional boolean Tensor of shape (B T H W) marking the radar coverage. Wherever it is False (outside the radar range), the predicted value is replaced by NaN. Defaults to None.
- Return type:
NamedTensor
- Returns:
The NamedTensor containing the prediction made by the DGMR model: a NamedTensor of shape (B T H W C) with a feature ‘rain’.
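The mask semantics described for forward can be illustrated with a plain-Python sketch on nested lists. apply_radar_mask is a hypothetical helper and is not part of mfai. The sketch assumes, as the shapes suggest, that each (B T H W) mask entry covers all C features of the matching pixel. With tensors, the same effect would come from broadcasting the mask over the last dimension. Leaving the prediction unchanged when the mask is None is this sketch's own choice:

```python
import math
from typing import Any, Optional


def apply_radar_mask(pred: Any, mask: Optional[Any]) -> Any:
    """Return a copy of `pred` with NaN wherever `mask` is False (hypothetical helper).

    `pred` is a nested list shaped (B, T, H, W, C) and `mask` a nested list of
    booleans shaped (B, T, H, W). Each mask entry covers the C feature values
    of the matching pixel. In this sketch, a `mask` of None leaves `pred` untouched.
    """
    if mask is None:
        return pred
    if isinstance(mask, bool):
        # Mask leaf: `pred` is the list of C feature values at one pixel.
        return list(pred) if mask else [math.nan] * len(pred)
    if len(pred) != len(mask):
        raise ValueError("pred and mask disagree on a shared dimension")
    # Walk the shared (B, T, H, W) dimensions in lockstep.
    return [apply_radar_mask(p, m) for p, m in zip(pred, mask)]
```

For instance, `apply_radar_mask([[[[[0.5], [1.2]]]]], [[[[True, False]]]])` keeps 0.5 and turns the second, out-of-range pixel into NaN.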
- training_step(batch)
Performs the training step for the batch.
- Return type:
- Parameters:
batch (tuple[NamedTensor, NamedTensor])
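A Lightning module with separate generator and discriminator optimizers is typically trained with manual optimization. This means setting `automatic_optimization = False` and, inside `training_step`, fetching the optimizers with `self.optimizers()` and back-propagating with `self.manual_backward()`. Whether this module does so is not stated on this page. The sketch below abstracts those calls into callbacks so that only the ordering of one GAN update is shown. It assumes, without confirmation from this page, that batch unpacks into past observations x and target frames y. gan_training_step and all of its arguments are hypothetical:

```python
from typing import Callable, Dict, Sequence, TypeVar

T = TypeVar("T")


def gan_training_step(
    x: T,
    y: T,
    generate: Callable[[T], T],
    disc_loss_fn: Callable[[T, T, T], float],
    gen_loss_fn: Callable[[T, T, Sequence[T]], float],
    disc_update: Callable[[float], None],
    gen_update: Callable[[float], None],
    generation_steps: int = 6,
) -> Dict[str, float]:
    """One GAN update: discriminator first, then generator (hypothetical sketch)."""
    # Discriminator update: score the real targets against one generated forecast.
    # With tensors, the forecast would be detached so no gradient reaches the generator.
    fake = generate(x)
    disc_loss = disc_loss_fn(x, y, fake)
    disc_update(disc_loss)

    # Generator update: draw several samples (e.g. for a grid cell regularizer).
    fakes = [generate(x) for _ in range(generation_steps)]
    gen_loss = gen_loss_fn(x, y, fakes)
    gen_update(gen_loss)
    return {"disc_loss": disc_loss, "gen_loss": gen_loss}
```

Updating the discriminator first means that the generator step is scored by a discriminator that has already seen the current batch. The step also returns both losses, which makes them easy to log.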