gan_dgmr
Module for various loss functions used with DGMR GAN.
- class mfai.pytorch.losses.gan_dgmr.GridCellLoss(precip_weight_cap=24.0)
Bases: Module
Grid Cell Regularizer loss from Skillful Nowcasting; see equation (3) of the original paper (https://arxiv.org/pdf/2104.00954.pdf).
- Parameters:
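precip_weight_cap (float) – upper bound on the per-pixel precipitation weight w in equation (3). Defaults to 24.0, the cap used in the paper.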
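The paper describes w as the target rain rate plus one, clipped at 24. A minimal sketch of that weight, assuming it takes the form w(y) = min(y + 1, cap) with cap = precip_weight_cap; the helper name `precip_weight` is hypothetical and not part of mfai, and NumPy stands in for PyTorch only to keep the example framework-agnostic:

```python
import numpy as np


def precip_weight(targets: np.ndarray, precip_weight_cap: float = 24.0) -> np.ndarray:
    """Hypothetical helper: element-wise weight w(y) = min(y + 1, cap).

    Assumes `targets` holds non-negative rain rates. Heavier rain gets a
    larger weight, but never more than `precip_weight_cap`, which limits the
    influence of spuriously large radar values.
    """
    return np.minimum(targets + 1.0, precip_weight_cap)


# Dry pixels get weight 1; very heavy rain is capped at 24.
print(precip_weight(np.array([0.0, 5.0, 30.0])))
```

Capping rather than using y + 1 unbounded keeps a single extreme pixel from dominating the regularizer.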
- forward(generated_images, targets)
Forward function.
Calculates the grid cell regularizer value. Assumes that generated_images is the mean of the predictions from 6 calls to the generator, i.e. a Monte Carlo estimate of the expectation over the latent variable Z.
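\[L_R(\theta) = \frac{1}{HWN} \left\| \left( \mathbb{E}_Z\left[ G_\theta(Z; X_{1:M}) \right] - X_{M+1:M+N} \right) \circ w(X_{M+1:M+N}) \right\|_1\]
where H, W and N are the height, width and number of lead times, \(\circ\) denotes element-wise multiplication and w is the precipitation-dependent weight.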
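The formula can be illustrated with a short NumPy sketch. It is not the mfai implementation: the function name `grid_cell_loss`, the array layout (samples, batch, N, H, W) and the extra averaging over the batch are assumptions, and w is assumed to be min(y + 1, precip_weight_cap). Unlike the documented forward, which expects the mean prediction to be computed beforehand, this sketch averages the generator samples itself to make the Monte Carlo step explicit.

```python
import numpy as np


def grid_cell_loss(
    generated_samples: np.ndarray,
    targets: np.ndarray,
    precip_weight_cap: float = 24.0,
) -> float:
    """Hypothetical sketch of L_R: weighted L1 between sample mean and targets.

    generated_samples: shape (S, batch, N, H, W), S generator draws
        (6 in the paper) conditioned on the same context frames X_{1:M}.
    targets: shape (batch, N, H, W), the observed frames X_{M+1:M+N}.
    """
    # Monte Carlo estimate of E_Z[G_theta(Z; X_{1:M})]: average over draws.
    mean_prediction = generated_samples.mean(axis=0)
    # Assumed weight w(y) = min(y + 1, cap), applied element-wise.
    weights = np.minimum(targets + 1.0, precip_weight_cap)
    # Weights are positive, so |a| * w equals |a * w| (the Hadamard product).
    weighted_error = np.abs(mean_prediction - targets) * weights
    # The L1 norm divided by HWN is a mean over (N, H, W); the batch is
    # averaged as well, which is an assumption of this sketch.
    return float(weighted_error.mean())


rng = np.random.default_rng(0)
targets = rng.gamma(1.0, 2.0, size=(2, 4, 8, 8))
samples = targets[None] + rng.normal(0.0, 0.5, size=(6, 2, 4, 8, 8))
print(grid_cell_loss(samples, targets))
```

The same operations (mean over a dimension, element-wise minimum, absolute value) map directly onto PyTorch tensors, so the loss stays differentiable with respect to the generator output.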