pangu¶
- class mfai.pytorch.models.pangu.CustomPad2d(data_size, patch_size, value=0.0)[source]¶
Bases: ConstantPad2d
Custom 2D padding based on the token embedding patch size. The padding is centered.
- Parameters:
data_size (torch.Size) – data size
patch_size (Tuple[int, int]) – patch size for the token embedding operation
value (float, optional) – padding value. Defaults to 0.
- class mfai.pytorch.models.pangu.CustomPad3d(data_size, patch_size, value=0.0)[source]¶
Bases: ConstantPad3d
Custom 3D padding based on the token embedding patch size. The padding is centered.
- Parameters:
data_size (torch.Size) – data size
patch_size (Tuple[int, int, int]) – patch size for the token embedding operation
value (float, optional) – padding value. Defaults to 0.
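Neither docstring spells out how the centered padding amounts are derived. As an illustration (not the library's actual code), one way to compute them is to pad each dimension up to the next multiple of its patch size and split the total evenly between the two sides:

```python
from math import ceil

def center_pad_amounts(data_size, patch_size):
    """For each dimension, pad up to the next multiple of the patch
    size, splitting the padding as evenly as possible between the
    two sides (centered padding)."""
    pads = []
    for size, patch in zip(data_size, patch_size):
        total = ceil(size / patch) * patch - size
        pads.append((total // 2, total - total // 2))
    return pads

# Pressure-level grid of the original Pangu paper: 13 levels on a
# 721 x 1440 grid, with patch size (2, 4, 4):
print(center_pad_amounts((13, 721, 1440), (2, 4, 4)))
# -> [(0, 1), (1, 2), (0, 0)]
```

Note that `torch.nn.ConstantPad3d` expects these pairs flattened in reverse dimension order, so an actual implementation would reorder them before passing them on.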
- class mfai.pytorch.models.pangu.DownSample(data_size, dim)[source]¶
Bases: Module
Down-sampling operation. The number of tokens is divided by 4 while their size is multiplied by 2, e.g., from (8x360x181) tokens of size 192 to (8x180x91) tokens of size 384.
- Parameters:
data_size (torch.Size) – data size in terms of embedded plevel, latitude, longitude
dim (int) – initial size of the tokens
- forward(x, embedding_shape)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type: Tuple[Tensor, Size]
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
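The shape arithmetic in the docstring's example can be checked directly. The sketch below is an illustration of the bookkeeping only, not the library's code; it assumes odd lat/lon dimensions are padded before 2x2 neighbourhoods are merged, hence the ceiling division:

```python
from math import ceil

def downsample_shape(z, lat, lon, dim):
    """Merge 2x2 lat/lon neighbourhoods of tokens: the token count is
    divided by roughly 4, the token size (channels) is doubled, and the
    Z (pressure-level) dimension is unchanged."""
    return (z, ceil(lat / 2), ceil(lon / 2)), dim * 2

# Example from the docstring: (8 x 360 x 181) tokens of size 192
print(downsample_shape(8, 360, 181, 192))
# -> ((8, 180, 91), 384)
```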
- class mfai.pytorch.models.pangu.EarthAttention3D(data_size, dim, num_heads, dropout_rate, window_size)[source]¶
Bases: Module
3D sliding window attention with the Earth-Specific bias; see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D sliding window attention.
- Parameters:
- forward(x, mask, batch_size)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type: Tensor
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class mfai.pytorch.models.pangu.EarthSpecificBlock(data_size, dim, drop_path_ratio, num_heads, window_size=(2, 6, 12), dropout_rate=0.0, checkpoint_activation=False, lam=False)[source]¶
Bases: Module
3D transformer block with Earth-Specific bias and window attention; see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention. The major difference is that the dimensions are expanded to 3 and the relative position bias is replaced with the Earth-Specific bias.
- Parameters:
data_size (torch.Size) – data size in terms of plevel, latitude, longitude
dim (int) – token size
drop_path_ratio (float) – ratio to apply to drop path
num_heads (int) – number of attention heads
window_size (Tuple[int, int, int], optional) – window size for the sliding window attention. Defaults to (2, 6, 12).
dropout_rate (float, optional) – dropout rate in the MLP. Defaults to 0.0.
checkpoint_activation (bool, optional) – whether to use checkpoint activation. Defaults to False.
lam (bool, optional) – whether to use the limited area attention mask. Defaults to False.
- forward(x, embedding_shape, roll)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type: Tensor
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class mfai.pytorch.models.pangu.EarthSpecificLayer(depth, data_size, dim, drop_path_ratio_list, num_heads, window_size, dropout_rate, checkpoint_activation, lam)[source]¶
Bases: Module
Basic layer of our network; contains 2 or 6 blocks.
- Parameters:
depth (int) – number of blocks
data_size (torch.Size) – see EarthSpecificBlock
dim (int) – see EarthSpecificBlock
drop_path_ratio_list (Tensor) – see EarthSpecificBlock
num_heads (int) – see EarthSpecificBlock
window_size (Tuple[int, int, int], optional) – see EarthSpecificBlock
dropout_rate (float, optional) – see EarthSpecificBlock
checkpoint_activation (bool, optional) – see EarthSpecificBlock
lam (bool, optional) – see EarthSpecificBlock
- forward(x, embedding_shape)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type: Tensor
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class mfai.pytorch.models.pangu.MLP(dim, dropout_rate)[source]¶
Bases: Module
MLP layers, same as in most vision transformer architectures.
- Parameters:
- forward(x)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type: Tensor
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class mfai.pytorch.models.pangu.PanguWeather(in_channels, out_channels, input_shape, settings=PanguWeatherSettings(plevel_patch_size=(2, 4, 4), token_size=192, layer_depth=(2, 6), num_heads=(6, 12), spatial_dims=2, surface_variables=4, plevel_variables=5, plevels=13, static_length=3, window_size=(2, 6, 12), dropout_rate=0.0, checkpoint_activation=False, lam=False))[source]¶
Bases: BaseModel
PanguWeather network as described in http://arxiv.org/abs/2211.02556 and https://www.nature.com/articles/s41586-023-06185-3. This implementation follows the official pseudo code: https://github.com/198808xc/Pangu-Weather.
- Parameters:
in_channels (int)
out_channels (int)
settings (PanguWeatherSettings)
- forward(input_plevel, input_surface, static_data=None)[source]¶
Forward pass of the PanguWeather model.
- Parameters:
input_plevel (Tensor) – Input tensor of shape (N, C, Z, H, W) for pressure level data.
input_surface (Tensor) – Input tensor of shape (N, C, H, W) for surface data.
static_data (Tensor, optional) – Static data tensor, e.g., land sea mask, of shape (N, C, H, W). Defaults to None.
- Return type:
- model_type = 6¶
- property settings: PanguWeatherSettings¶
Returns the settings instance used to configure this model.
- settings_kls¶
alias of
PanguWeatherSettings
- class mfai.pytorch.models.pangu.PanguWeatherSettings(plevel_patch_size=(2, 4, 4), token_size=192, layer_depth=(2, 6), num_heads=(6, 12), spatial_dims=2, surface_variables=4, plevel_variables=5, plevels=13, static_length=3, window_size=(2, 6, 12), dropout_rate=0.0, checkpoint_activation=False, lam=False)[source]¶
Bases: object
PanguWeather configuration class, containing the hyperparameters for the model.
- Parameters:
plevel_patch_size (Tuple[int, int, int]) – patch size for the pressure level data. Default is (2, 4, 4); setting (2, 8, 8) leads to Pangu Lite.
token_size (int) – size of the tokens (equivalent to channel size) of the first layer. Default is 192.
layer_depth (Tuple[int, int]) – number of blocks per layer. Default is (2, 6), meaning that the first and fourth layers contain 2 blocks, and the second and third contain 6.
num_heads (Tuple[int, int]) – number of heads in the attention layers. Default is (6, 12), for respectively the first and fourth layers, and the second and third.
spatial_dims (int) – number of spatial dimensions (2 or 3).
surface_variables (int) – number of surface variables.
plevel_variables (int) – number of pressure level variables.
plevels (int) – number of pressure levels.
static_length (int) – number of static variables (e.g., land sea mask).
window_size (Tuple[int, int, int]) – size of the sliding window.
dropout_rate (float) – fraction of the input units to drop.
checkpoint_activation (bool) – whether to use checkpoint activation.
lam (bool) – whether to use the limited area attention mask.
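Given the signature above, selecting the Pangu Lite variant only requires changing the pressure-level patch size. The snippet below is an untested configuration sketch based solely on the documented defaults:

```python
from mfai.pytorch.models.pangu import PanguWeatherSettings

# Pangu Lite: coarser (2, 8, 8) patches on the pressure-level data;
# every other hyperparameter keeps its documented default.
lite_settings = PanguWeatherSettings(plevel_patch_size=(2, 8, 8))
```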
- classmethod from_dict(kvs, *, infer_missing=False)¶
- classmethod from_json(s, *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw)¶
- classmethod schema(*, infer_missing=False, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)¶
- to_json(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, indent=None, separators=None, default=None, sort_keys=False, **kw)¶
- class mfai.pytorch.models.pangu.PatchEmbedding(c_dim, patch_size, plevel_size, surface_size)[source]¶
Bases: Module
Patch embedding operation. Applies a linear projection to patches of patch_size[0]*patch_size[1]*patch_size[2] elements; patch_size = (2, 4, 4) in the original paper.
- Parameters:
c_dim (int) – embedding channel size
patch_size (Tuple[int, int, int]) – patch size for pressure level data
plevel_size (torch.Size) – pressure level data size
surface_size (torch.Size) – surface data size
- forward(input_plevel, input_surface)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type: Tuple[Tensor, Size]
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
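Concretely, the token grid produced by this embedding can be computed from the grid size and patch size alone. The sketch below is an illustration, not the library's code; the +1 along Z for the embedded surface data is inferred from the (8x360x181) example in DownSample above:

```python
from math import ceil

def embedded_grid(plevels, lat, lon, patch_size):
    """Token grid after patch embedding: each spatial dimension is
    padded to a multiple of the patch size and divided by it; the
    surface field adds one extra plane along Z."""
    pz, ph, pw = patch_size
    z = ceil(plevels / pz) + 1  # +1: the embedded surface data
    return z, ceil(lat / ph), ceil(lon / pw)

# Original Pangu resolution: 13 pressure levels on a 721 x 1440 grid,
# patch size (2, 4, 4):
print(embedded_grid(13, 721, 1440, (2, 4, 4)))
# -> (8, 181, 360)
```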
- class mfai.pytorch.models.pangu.PatchRecovery(dim, patch_size, plevel_channels=5, surface_channels=4)[source]¶
Bases: Module
Patch recovery operation: the inverse of the patch embedding operation.
- Parameters:
- forward(x, embedding_shape)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type: Tuple[Tensor, Tensor]
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class mfai.pytorch.models.pangu.UpSample(input_dim, output_dim)[source]¶
Bases: Module
- forward(x, embedding_shape)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
- Return type: Tensor
Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- mfai.pytorch.models.pangu.define_3d_earth_position_index(window_size)[source]¶
Build the index for the Earth specific positional bias of sliding attention windows from PanguWeather. See http://arxiv.org/abs/2211.02556.
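For intuition, the relative-position indexing that the Earth-Specific scheme generalises can be sketched in one dimension. Per the paper, the Earth-Specific bias keeps absolute indexing along the pressure-level and latitude axes and stays relative only along longitude, so this illustrative sketch (not the library's code) covers just the relative part:

```python
def relative_position_index_1d(window_size):
    """Relative position index along one axis of a window: the pair
    (i, j) maps to (i - j) + window_size - 1, giving
    2 * window_size - 1 distinct learnable bias entries."""
    return [[(i - j) + window_size - 1 for j in range(window_size)]
            for i in range(window_size)]

idx = relative_position_index_1d(4)
print(idx[0])  # -> [3, 2, 1, 0]
print(idx[3])  # -> [6, 5, 4, 3]
```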
- mfai.pytorch.models.pangu.generate_3d_attention_mask(x, window_size, shift_size, lam=False)[source]¶
Generates an attention mask for sliding window attention on 3D data. Based on https://pytorch.org/vision/main/_modules/torchvision/models/swin_transformer.html#swin_s.
- Parameters:
- Returns:
attention mask.
- Return type:
Tensor
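The idea behind this mask can be illustrated in one dimension. After a cyclic shift, the last window mixes positions wrapped around the boundary with genuine neighbours, and attention between the two groups must be blocked. The sketch below is a simplified 1D version of the torchvision-style approach referenced above, not the library's code; the actual function operates on 3D windows and, when lam=True, additionally applies the limited area mask:

```python
def shifted_window_mask_1d(length, window_size, shift_size):
    """Assign a region id to each position after a cyclic shift, then
    allow attention within a window only between positions sharing a
    region id. Returns one boolean matrix per window (True = allowed)."""
    # Region ids: 0 for the bulk, 1 and 2 for the two slices that the
    # cyclic shift brings together in the last window.
    ids = [0] * (length - window_size) \
        + [1] * (window_size - shift_size) \
        + [2] * shift_size
    masks = []
    for start in range(0, length, window_size):
        win = ids[start:start + window_size]
        masks.append([[a == b for b in win] for a in win])
    return masks

masks = shifted_window_mask_1d(8, 4, 2)
print(masks[0])  # first window: every pair may attend
print(masks[1])  # last window: cross-region pairs are masked out
```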