pangu

class mfai.pytorch.models.pangu.CustomPad2d(data_size, patch_size, value=0.0)[source]

Bases: ConstantPad2d

Custom 2d padding based on token embedding patch size. Padding direction is center.

Parameters:
  • data_size (torch.Size) – data size

  • patch_size (Tuple[int, int]) – patch size for the token embedding operation

  • value (float, optional) – padding value. Defaults to 0.0.
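
The padding arithmetic implied by "padding direction is center" can be sketched as follows. This is a hypothetical helper, not part of mfai, and the split convention (the extra row/column going to the end when the total is odd) is an assumption:

```python
def center_pad_2d(data_size, patch_size):
    """Compute (left, right, top, bottom) padding, in the order used by
    torch.nn.ConstantPad2d, so each spatial dim becomes divisible by the
    token-embedding patch size, split evenly around the centre."""
    (h, w), (ph, pw) = data_size, patch_size
    pad_h = (-h) % ph  # rows needed for divisibility
    pad_w = (-w) % pw  # columns needed
    return (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
```

With a 181x360 grid and a (4, 4) patch this yields (0, 0, 1, 2): the 181 rows are padded to 184 with one row before and two after.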

class mfai.pytorch.models.pangu.CustomPad3d(data_size, patch_size, value=0.0)[source]

Bases: ConstantPad3d

Custom 3d padding based on token embedding patch size. Padding direction is center.

Parameters:
  • data_size (torch.Size) – data size

  • patch_size (Tuple[int, int, int]) – patch size for the token embedding operation

  • value (float, optional) – padding value. Defaults to 0.0.

class mfai.pytorch.models.pangu.DownSample(data_size, dim)[source]

Bases: Module

Down-sampling operation. The number of tokens is divided by 4 while their size is multiplied by 2. E.g., from (8x360x181) tokens of size 192 to (8x180x91) tokens of size 384.

Parameters:
  • data_size (torch.Size) – data size in terms of embedded plevel, latitude, longitude

  • dim (int) – initial size of the tokens
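
The token arithmetic can be illustrated with a Swin-style patch-merging step. This is a minimal sketch of the idea, not the mfai implementation (which additionally pads odd latitude/longitude sizes, skipped here by assuming even dims):

```python
import torch
from torch import nn

class PatchMerging3d(nn.Module):
    """Swin-style down-sampling sketch: merge each 2x2 neighbourhood of
    tokens along (lat, lon) into one token, then project 4*dim -> 2*dim,
    so the token count is divided by 4 and the token size doubled."""

    def __init__(self, dim: int):
        super().__init__()
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, plevel, lat, lon, dim); lat and lon assumed even here
        b, z, h, w, c = x.shape
        x = x.reshape(b, z, h // 2, 2, w // 2, 2, c)
        x = x.permute(0, 1, 2, 4, 3, 5, 6).reshape(b, z, h // 2, w // 2, 4 * c)
        return self.reduction(x)

tokens = torch.randn(1, 8, 180, 90, 192)
out = PatchMerging3d(192)(tokens)  # -> (1, 8, 90, 45, 384)
```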

forward(x, embedding_shape)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type:

Tuple[Tensor, Size]

class mfai.pytorch.models.pangu.EarthAttention3D(data_size, dim, num_heads, dropout_rate, window_size)[source]

Bases: Module

3D sliding-window attention with the Earth-Specific bias; see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D sliding-window attention.

Parameters:
  • data_size (torch.Size) – data size in terms of plevel, latitude, longitude

  • dim (int) – token size

  • num_heads (int) – number of heads

  • dropout_rate (float) – dropout rate

  • window_size (Tuple[int, int, int]) – window size (z, h, w)
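
The core computation is standard per-window multi-head attention with an additive learned bias on the logits. A minimal sketch of that mechanism follows; the Earth-Specific construction of the bias itself (which also depends on absolute latitude and pressure level) is omitted:

```python
import torch

def biased_window_attention(q, k, v, bias):
    """Scaled dot-product attention per window with an additive positional
    bias. q, k, v: (windows, heads, tokens, head_dim);
    bias: (heads, tokens, tokens), broadcast over windows."""
    scale = q.shape[-1] ** -0.5
    logits = (q * scale) @ k.transpose(-2, -1) + bias
    return torch.softmax(logits, dim=-1) @ v

q, k, v = (torch.randn(6, 2, 4, 8) for _ in range(3))
out = biased_window_attention(q, k, v, torch.randn(2, 4, 4))
```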

forward(x, mask, batch_size)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

class mfai.pytorch.models.pangu.EarthSpecificBlock(data_size, dim, drop_path_ratio, num_heads, window_size=(2, 6, 12), dropout_rate=0.0, checkpoint_activation=False, lam=False)[source]

Bases: Module

3D transformer block with Earth-Specific bias and window attention,

see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention. The major difference is that we expand the dimensions to 3 and replace the relative position bias with Earth-Specific bias.

Parameters:
  • data_size (torch.Size) – data size in terms of plevel, latitude, longitude

  • dim (int) – token size

  • drop_path_ratio (float) – ratio to apply to drop path

  • num_heads (int) – number of attention heads

  • window_size (Tuple[int, int, int], optional) – window size for the sliding window attention. Defaults to (2, 6, 12).

  • dropout_rate (float, optional) – dropout rate in the MLP. Defaults to 0.0.

  • checkpoint_activation (bool, optional) – whether to use checkpoint activation. Defaults to False.

  • lam (bool, optional) – whether to use the limited area attention mask. Defaults to False.

forward(x, embedding_shape, roll)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

class mfai.pytorch.models.pangu.EarthSpecificLayer(depth, data_size, dim, drop_path_ratio_list, num_heads, window_size, dropout_rate, checkpoint_activation, lam)[source]

Bases: Module

Basic layer of the network, containing 2 or 6 blocks.

Parameters:
  • depth (int) – number of blocks

  • data_size (torch.Size) – see EarthSpecificBlock

  • dim (int) – see EarthSpecificBlock

  • drop_path_ratio_list (Tensor) – see EarthSpecificBlock

  • num_heads (int) – see EarthSpecificBlock

  • window_size (Tuple[int, int, int], optional) – see EarthSpecificBlock

  • dropout_rate (float, optional) – see EarthSpecificBlock

  • checkpoint_activation (bool, optional) – see EarthSpecificBlock

  • lam (bool, optional) – see EarthSpecificBlock

forward(x, embedding_shape)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

class mfai.pytorch.models.pangu.MLP(dim, dropout_rate)[source]

Bases: Module

MLP layers, same as most vision transformer architectures.

Parameters:
  • dim (int) – input and output token size

  • dropout_rate (float) – dropout rate applied after each linear layer
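
A sketch of the usual vision-transformer MLP this likely corresponds to; the hidden width of 4*dim and the GELU activation are assumptions (the documentation only fixes dim and dropout_rate):

```python
import torch
from torch import nn

class Mlp(nn.Module):
    """Two-layer MLP in the usual ViT style:
    Linear -> GELU -> Dropout -> Linear -> Dropout, same input/output size."""

    def __init__(self, dim: int, dropout_rate: float):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(4 * dim, dim),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

x = torch.randn(2, 10, 192)
y = Mlp(192, 0.0)(x)  # token size preserved
```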

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Parameters:

x (Tensor)

Return type:

Tensor

class mfai.pytorch.models.pangu.PanguWeather(in_channels, out_channels, input_shape, settings=PanguWeatherSettings(plevel_patch_size=(2, 4, 4), token_size=192, layer_depth=(2, 6), num_heads=(6, 12), spatial_dims=2, surface_variables=4, plevel_variables=5, plevels=13, static_length=3, window_size=(2, 6, 12), dropout_rate=0.0, checkpoint_activation=False, lam=False))[source]

Bases: BaseModel

PanguWeather network as described in http://arxiv.org/abs/2211.02556 and https://www.nature.com/articles/s41586-023-06185-3. This implementation follows the official pseudo code here: https://github.com/198808xc/Pangu-Weather.

features_last: bool = False
forward(input_plevel, input_surface, static_data=None)[source]

Forward pass of the PanguWeather model.

Parameters:
  • input_plevel (Tensor) – Input tensor of shape (N, C, Z, H, W) for pressure level data.

  • input_surface (Tensor) – Input tensor of shape (N, C, H, W) for surface data.

  • static_data (Tensor, optional) – Static data tensor, e.g., land sea mask, of shape (N, C, H, W). Defaults to None.

Return type:

Tuple[Tensor, Tensor]

model_type = 6
property num_spatial_dims: int

Returns the number of spatial dimensions of the instantiated model.

onnx_supported: bool = False
register: bool = True
property settings: PanguWeatherSettings

Returns the settings instance used to configure this model.

settings_kls

alias of PanguWeatherSettings

supported_num_spatial_dims: Tuple = (2,)
class mfai.pytorch.models.pangu.PanguWeatherSettings(plevel_patch_size=(2, 4, 4), token_size=192, layer_depth=(2, 6), num_heads=(6, 12), spatial_dims=2, surface_variables=4, plevel_variables=5, plevels=13, static_length=3, window_size=(2, 6, 12), dropout_rate=0.0, checkpoint_activation=False, lam=False)[source]

Bases: object

PanguWeather configuration class, containing the hyperparameters for the model.

Parameters:
  • plevel_patch_size (Tuple[int, int, int]) – Patch size for the pressure level data. Default is (2, 4, 4). Setting (2, 8, 8) leads to Pangu Lite.

  • token_size (int) – Size of the tokens (equivalent to channel size) of the first layer. Default is 192.

  • layer_depth (Tuple[int, int]) – Number of blocks in each layer. Default is (2, 6), meaning that the first and fourth layers contain 2 blocks, and the second and third contain 6.

  • num_heads (Tuple[int, int]) – Number of heads in attention layers. Default is (6, 12), corresponding to respectively first and fourth layers, and second and third.

  • spatial_dims (int) – number of spatial dimensions (2 or 3).

  • surface_variables (int) – number of surface variables.

  • plevel_variables (int) – number of pressure level variables.

  • plevels (int) – number of pressure levels.

  • static_length (int) – number of static variables (e.g., land sea mask).

  • window_size (Tuple[int, int, int]) – size of the sliding window.

  • dropout_rate (float) – fraction of the input units to drop.

  • checkpoint_activation (bool) – whether to use checkpoint activation.

  • lam (bool) – whether to use the limited area attention mask.
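
For instance, a Lite-style configuration (the docs above associate plevel_patch_size=(2, 8, 8) with Pangu Lite) can be built and round-tripped through the documented dataclasses-json helpers; the chosen dropout value here is just illustrative:

```python
from mfai.pytorch.models.pangu import PanguWeatherSettings

# (2, 8, 8) is the patch size associated with Pangu Lite
settings = PanguWeatherSettings(plevel_patch_size=(2, 8, 8), dropout_rate=0.1)

config = settings.to_dict()                        # plain-dict form
restored = PanguWeatherSettings.from_dict(config)  # round-trip
```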

checkpoint_activation: bool = False
dropout_rate: float = 0.0
classmethod from_dict(kvs, *, infer_missing=False)
Return type:

TypeVar(A, bound= DataClassJsonMixin)

Parameters:

kvs (dict | list | str | int | float | bool | None)

classmethod from_json(s, *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw)
Return type:

TypeVar(A, bound= DataClassJsonMixin)

Parameters:

s (str | bytes | bytearray)

lam: bool = False
layer_depth: Tuple[int, int] = (2, 6)
num_heads: Tuple[int, int] = (6, 12)
plevel_patch_size: Tuple[int, int, int] = (2, 4, 4)
plevel_variables: int = 5
plevels: int = 13
classmethod schema(*, infer_missing=False, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)
Return type:

SchemaF[TypeVar(A, bound= DataClassJsonMixin)]

spatial_dims: int = 2
static_length: int = 3
surface_variables: int = 4
to_dict(encode_json=False)
Return type:

Dict[str, Union[dict, list, str, int, float, bool, None]]

to_json(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, indent=None, separators=None, default=None, sort_keys=False, **kw)
Return type:

str

token_size: int = 192
window_size: Tuple[int, int, int] = (2, 6, 12)
class mfai.pytorch.models.pangu.PatchEmbedding(c_dim, patch_size, plevel_size, surface_size)[source]

Bases: Module

Patch embedding operation. Applies a linear projection to patches of patch_size[0]*patch_size[1]*patch_size[2] elements; patch_size = (2, 4, 4) in the original paper.

Parameters:
  • c_dim (int) – embedding channel size

  • patch_size (Tuple[int, int, int]) – patch size for pressure level data

  • plevel_size (torch.Size) – pressure level data size

  • surface_size (torch.Size) – surface data size
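
The shape of the resulting token grid can be sketched arithmetically. The "+1" slot for the embedded surface data is an assumption, but it is consistent with the (8x360x181) example given for DownSample: ceil(13/2) = 7 pressure-level slices plus one surface slice:

```python
import math

def token_grid(plevels, dim1, dim2, patch=(2, 4, 4)):
    """Token-grid shape after patch embedding: each axis is ceil-divided by
    its patch size (padding absorbs the remainder), and the embedded surface
    data occupies one extra slot along the level axis."""
    z = math.ceil(plevels / patch[0]) + 1
    return z, math.ceil(dim1 / patch[1]), math.ceil(dim2 / patch[2])
```

On a 13-level, 1440x721 grid with the default (2, 4, 4) patch this gives the (8, 360, 181) grid quoted in the DownSample example.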

forward(input_plevel, input_surface)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tuple[Tensor, Size]

class mfai.pytorch.models.pangu.PatchRecovery(dim, patch_size, plevel_channels=5, surface_channels=4)[source]

Bases: Module

Patch recovery operation. The inverse operation of the patch embedding operation.

Parameters:
  • dim (int) – number of channels

  • patch_size (Tuple[int, int, int]) – pressure level patch size, e.g., (2, 4, 4) as in the original paper

  • plevel_channels (int, optional) – pressure level data channel size

  • surface_channels (int, optional) – surface data channel size

forward(x, embedding_shape)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tuple[Tensor, Tensor]

class mfai.pytorch.models.pangu.UpSample(input_dim, output_dim)[source]

Bases: Module

Up-sampling operation, the inverse of DownSample.

Parameters:
  • input_dim (int)

  • output_dim (int)

forward(x, embedding_shape)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

mfai.pytorch.models.pangu.define_3d_earth_position_index(window_size)[source]

Build the index for the Earth specific positional bias of sliding attention windows from PanguWeather. See http://arxiv.org/abs/2211.02556.

Parameters:

window_size (Tuple[int, int, int]) – size of the sliding window

Returns:

index

Return type:

Tensor

mfai.pytorch.models.pangu.generate_3d_attention_mask(x, window_size, shift_size, lam=False)[source]

Method to generate attention mask for sliding window attention in the context of 3D data. Based on https://pytorch.org/vision/main/_modules/torchvision/models/swin_transformer.html#swin_s.

Parameters:
  • x (Tensor) – input data, used to generate the mask on the same device.

  • window_size (Tuple[int, int, int]) – size of the sliding window.

  • shift_size (Tuple[int, ...]) – size of the shift for the sliding window.

  • lam (bool) – whether to use the LAM attention mechanism.

Returns:

attention mask.

Return type:

Tensor
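
The construction can be illustrated in 1D; this is a simplified sketch of the same Swin-style technique, not the 3D mfai code. Tokens of the cyclically shifted input are labelled by origin region, and attention between differently labelled tokens inside the same window is masked with -inf:

```python
import torch

def sliding_mask_1d(length, window, shift):
    """Label the three regions produced by a cyclic shift, split into
    windows, and forbid attention between tokens from different regions."""
    labels = torch.zeros(length)
    slices = (slice(0, -window), slice(-window, -shift), slice(-shift, None))
    for region, s in enumerate(slices):
        labels[s] = region
    windows = labels.view(length // window, window)
    # mask[w, i, j] != 0 iff tokens i and j of window w came from
    # different regions of the unshifted input
    mask = windows.unsqueeze(1) - windows.unsqueeze(2)
    return mask.masked_fill(mask != 0, float("-inf")).masked_fill(mask == 0, 0.0)

m = sliding_mask_1d(8, window=4, shift=2)  # (2 windows, 4, 4)
```

The first window contains only unshifted tokens, so its mask is all zeros; the last window mixes two regions and gets -inf entries between them.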