archesweather

class mfai.pytorch.models.archesweather.ArchesWeather(in_channels, out_channels, input_shape, settings=ArchesWeatherSettings(plevel_patch_size=(2, 2, 2), token_size=192, layer_depth=(2, 6), num_heads=(6, 12, 12, 6), spatial_dims=2, surface_variables=4, plevel_variables=5, plevels=13, static_length=3, window_size=(1, 6, 10), dropout_rate=0.0, checkpoint_activation=False, lam=False, cond_dim=32, droppath_coeff=0.2, depth_multiplier=1, position_embs_dim=0, use_prev=False, use_skip=False, conv_head=False, first_interaction_layer=False, axial_attn=False, axial_attn_heads=8))[source]

Bases: BaseModel

ArchesWeather model as described in http://arxiv.org/abs/2405.14527.

features_last: bool = False
forward(input_level, input_surface, static_data=None, cond_emb=None)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type:

tuple[Tensor, Tensor]

model_type = 6
property num_spatial_dims: int

Returns the number of spatial dimensions of the instantiated model.

onnx_supported: bool = False
register: bool = True
property settings: ArchesWeatherSettings

Returns the settings instance used to configure this model.

settings_kls

alias of ArchesWeatherSettings

supported_num_spatial_dims: tuple[int, ...] = (2,)
class mfai.pytorch.models.archesweather.ArchesWeatherSettings(plevel_patch_size=(2, 2, 2), token_size=192, layer_depth=(2, 6), num_heads=(6, 12, 12, 6), spatial_dims=2, surface_variables=4, plevel_variables=5, plevels=13, static_length=3, window_size=(1, 6, 10), dropout_rate=0.0, checkpoint_activation=False, lam=False, cond_dim=32, droppath_coeff=0.2, depth_multiplier=1, position_embs_dim=0, use_prev=False, use_skip=False, conv_head=False, first_interaction_layer=False, axial_attn=False, axial_attn_heads=8)[source]

Bases: PanguWeatherSettings

ArchesWeather configuration class. Inherits from PanguWeatherSettings, with additional hyperparameters for Earth-Specific bias and window attention.

Parameters:
  • token_size (int) – embedding size

  • cond_dim (int) – conditioning embedding size

  • num_heads (tuple) – number of heads per EarthSpecificLayer

  • droppath_coeff (float) – drop path coefficient

  • plevel_patch_size (tuple) – patch size for input data embedding

  • window_size (tuple) – window size for shifted-window attention of EarthSpecificBlock

  • depth_multiplier (int) – depth multiplier for the number of blocks in EarthSpecificLayer

  • position_embs_dim (int) – dimension of positional embeddings

  • use_prev (bool) – whether to use previous state

  • use_skip (bool) – whether to use skip connections

  • conv_head (bool) – whether to use a convolutional head for patch recovery

  • dropout_rate (float) – dropout rate

  • first_interaction_layer (bool) – whether to use a linear interaction layer before the first EarthSpecificLayer

  • checkpoint_activation (bool) – whether to use gradient checkpointing

  • axial_attn (bool) – whether to use axial attention

  • axial_attn_heads (int) – number of heads for axial attention

  • lam (bool) – whether to use limited area setting in the attention mask

  • lon_resolution – longitude resolution

  • lat_resolution – latitude resolution

  • surface_variables (int) – number of variables in the surface data

  • static_length (int) – number of variables in the mask data

  • plevel_variables (int) – number of variables in the level data

  • plevels (int) – number of atmospheric levels in the level data

  • spatial_dims (int) – number of spatial dimensions (2).

  • layer_depth (Tuple[int, int])

axial_attn: bool = False
axial_attn_heads: int = 8
cond_dim: int = 32
conv_head: bool = False
depth_multiplier: int = 1
droppath_coeff: float = 0.2
first_interaction_layer: bool = False
classmethod from_dict(kvs, *, infer_missing=False)
Return type:

TypeVar(A, bound= DataClassJsonMixin)

Parameters:

kvs (dict | list | str | int | float | bool | None)

classmethod from_json(s, *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw)
Return type:

TypeVar(A, bound= DataClassJsonMixin)

Parameters:

s (str | bytes | bytearray)

num_heads: tuple = (6, 12, 12, 6)
plevel_patch_size: tuple = (2, 2, 2)
position_embs_dim: int = 0
classmethod schema(*, infer_missing=False, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)
Return type:

SchemaF[TypeVar(A, bound= DataClassJsonMixin)]

to_dict(encode_json=False)
Return type:

Dict[str, Union[dict, list, str, int, float, bool, None]]

to_json(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, indent=None, separators=None, default=None, sort_keys=False, **kw)
Return type:

str

use_prev: bool = False
use_skip: bool = False
window_size: tuple = (1, 6, 10)
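The roles of plevel_patch_size and token_size can be pictured with a small, self-contained sketch (an assumption about the embedding scheme for illustration, not mfai's internal code): a Conv3d whose kernel and stride equal the patch size turns the pressure-level cube into token_size-dimensional tokens.

```python
import torch
from torch import nn

# Hypothetical illustration of how plevel_patch_size and token_size interact.
# The tensor layout is an assumption for the sketch, not mfai's internal layout.
plevel_patch_size = (2, 2, 2)   # (plevels, lat, lon) patch
token_size = 192
plevel_variables, plevels, lat, lon = 5, 14, 32, 64  # plevels padded to an even number

embed = nn.Conv3d(
    in_channels=plevel_variables,
    out_channels=token_size,
    kernel_size=plevel_patch_size,
    stride=plevel_patch_size,
)

x = torch.randn(1, plevel_variables, plevels, lat, lon)
tokens = embed(x)
print(tuple(tokens.shape))  # (1, 192, 7, 16, 32): each patch becomes one 192-dim token
```

With the default plevels=13 the vertical dimension is not divisible by the patch size, so the real model presumably pads before embedding; the sketch sidesteps this with an even level count.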
class mfai.pytorch.models.archesweather.CondBasicLayer(depth, data_size, dim, cond_dim, drop_path_ratio_list, num_heads, window_size, dropout_rate, lam, axial_attn, axial_attn_heads, checkpoint_activation)[source]

Bases: EarthSpecificLayer

Wrapper for EarthSpecificLayer with conditional embeddings.

Parameters:
  • dim (int) – token size

  • cond_dim (int) – size of the conditional embedding

forward(x, embedding_shape, cond_emb=None)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type:

Tensor

class mfai.pytorch.models.archesweather.EarthSpecificBlock(data_size, dim, drop_path_ratio, heads, window_size=(2, 6, 12), dropout_rate=0.0, axial_attn=False, axial_attn_heads=8, checkpoint_activation=False, lam=False)[source]

Bases: Module

3D transformer block with Earth-Specific bias and window attention, see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention. The major difference is that we expand the dimensions to 3 and replace the relative position bias with Earth-Specific bias.

Parameters:
  • data_size (torch.Size) – data size in terms of plevel, latitude, longitude

  • dim (int) – token size

  • drop_path_ratio (float) – ratio to apply to drop path

  • heads (int) – number of attention heads

  • window_size (tuple[int], optional) – window size for the sliding window attention. Defaults to (2, 6, 12).

  • dropout_rate (float, optional) – dropout rate in the MLP. Defaults to 0.0.

  • axial_attn (bool, optional) – whether to use axial attention. Defaults to False.

  • axial_attn_heads (int, optional) – number of heads for axial attention. Defaults to 8.

  • checkpoint_activation (bool, optional) – whether to use gradient checkpointing for the activations. Defaults to False.

  • lam (bool, optional) – whether to use the limited area setting for shifted-window attention. Defaults to False.

forward(x, embedding_shape, cond_embed=None, roll=False)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type:

Tensor
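The 3D window mechanism the block relies on can be sketched as follows (a simplified illustration of shifted-window partitioning, not mfai's implementation): the token grid is cyclically rolled by half a window, then split into non-overlapping 3D windows of window_size, and attention runs independently inside each window.

```python
import torch

# Simplified sketch of 3D window partitioning with a cyclic shift,
# as used by shifted-window attention. Sizes are chosen divisible by the window.
window_size = (2, 6, 12)                  # (plevel, lat, lon)
pl, lat, lon, c = 4, 12, 24, 16
x = torch.randn(1, pl, lat, lon, c)

# Cyclic shift by half a window in each spatial dimension.
shift = tuple(w // 2 for w in window_size)
x = torch.roll(x, shifts=tuple(-s for s in shift), dims=(1, 2, 3))

# Partition into non-overlapping 3D windows: one row of tokens per window.
b = x.shape[0]
wp, wl, ww = window_size
x = x.view(b, pl // wp, wp, lat // wl, wl, lon // ww, ww, c)
windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).reshape(-1, wp * wl * ww, c)
print(tuple(windows.shape))  # (8, 144, 16): 8 windows of 2*6*12 = 144 tokens
```

In the real block the Earth-Specific bias is added to the attention logits inside each window, replacing the relative position bias of 2D Swin.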

class mfai.pytorch.models.archesweather.EarthSpecificLayer(depth, data_size, dim, drop_path_ratio_list, num_heads, window_size, dropout_rate, axial_attn, axial_attn_heads, checkpoint_activation, lam)[source]

Bases: Module

Basic layer of the network, containing 2 or 6 blocks.

Parameters:
  • depth (int) – number of blocks

  • data_size (torch.Size) – see EarthSpecificBlock

  • dim (int) – see EarthSpecificBlock

  • drop_path_ratio_list (list[float]) – see EarthSpecificBlock

  • num_heads (int) – see EarthSpecificBlock

  • window_size (tuple[int, int, int])

  • dropout_rate (float)

  • axial_attn (bool)

  • axial_attn_heads (int)

  • checkpoint_activation (bool)

  • lam (bool)

forward(x, embedding_shape, cond_embed=None)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type:

Tensor
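A common convention for building the per-block drop_path_ratio_list is a linear stochastic-depth ramp from 0 to droppath_coeff across all blocks, split per layer. The sketch below assumes that convention; the library may distribute the ratios differently.

```python
import torch

# Hypothetical: a linear stochastic-depth schedule from 0 to droppath_coeff,
# split into one drop_path_ratio_list per EarthSpecificLayer.
droppath_coeff = 0.2
layer_depth = (2, 6)                       # blocks per layer
total = sum(layer_depth)

ratios = torch.linspace(0, droppath_coeff, total).tolist()
drop_path_ratio_lists = [
    ratios[sum(layer_depth[:i]): sum(layer_depth[: i + 1])]
    for i in range(len(layer_depth))
]
print([len(l) for l in drop_path_ratio_lists])  # [2, 6]
```

Deeper blocks thus get larger drop probabilities, which is the usual stochastic-depth design choice.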

class mfai.pytorch.models.archesweather.Interpolate(scale_factor, mode, align_corners=False)[source]

Bases: Module

Interpolation module.

Parameters:
  • scale_factor (float) – scaling

  • mode (str) – interpolation mode

  • align_corners (bool) – align corners

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

x (Tensor)

Return type:

Tensor
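Judging by its signature, the module is a thin wrapper over torch.nn.functional.interpolate; a minimal stand-in (assumed-equivalent behaviour, shown for illustration only) looks like:

```python
import torch
from torch import nn
import torch.nn.functional as F

class InterpolateSketch(nn.Module):
    """Assumed-equivalent stand-in for the Interpolate module (illustration only)."""

    def __init__(self, scale_factor, mode, align_corners=False):
        super().__init__()
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        return F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )

up = InterpolateSketch(scale_factor=2.0, mode="bilinear")
y = up(torch.randn(1, 8, 16, 32))
print(tuple(y.shape))  # (1, 8, 32, 64)
```

Wrapping the functional call in a Module lets the interpolation sit inside an nn.Sequential, which is presumably why it exists alongside PatchRecoveryConv.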

class mfai.pytorch.models.archesweather.LinVert(in_features, embedding_size)[source]

Bases: Module

Linear layer for the vertical dimension.

Parameters:
  • in_features (int) – input feature size

  • embedding_size (tuple[int, ...]) – embedding size

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

x (Tensor)

Return type:

Tensor

class mfai.pytorch.models.archesweather.PatchRecoveryConv(input_dim, downfactor=4, hidden_dim=96, plevel_variables=5, surface_variables=4, plevels=13)[source]

Bases: Module

Upsampling with interpolation + convolution to avoid the checkerboard effect.

Parameters:
  • input_dim (int) – input feature size

  • downfactor (int) – downsampling factor (patch size in latitude and longitude)

  • hidden_dim (int) – hidden feature size

  • plevel_variables (int) – number of level variables

  • surface_variables (int) – number of surface variables

  • plevels (int) – number of levels

forward(x)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Parameters:

x (Tensor)

Return type:

tuple[Tensor, Tensor]
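The two returned tensors correspond to the pressure-level and surface fields. One plausible way such a split works (an assumption for illustration; the real channel layout may differ) is to carve the recovered channels into plevel_variables * plevels level channels plus surface_variables surface channels:

```python
import torch

# Hypothetical split of a recovered field into (plevel, surface) tensors.
plevel_variables, plevels, surface_variables = 5, 13, 4
lat, lon = 32, 64
channels = plevel_variables * plevels + surface_variables  # 69

x = torch.randn(1, channels, lat, lon)
plevel = x[:, : plevel_variables * plevels].view(1, plevel_variables, plevels, lat, lon)
surface = x[:, plevel_variables * plevels :]
print(tuple(plevel.shape), tuple(surface.shape))  # (1, 5, 13, 32, 64) (1, 4, 32, 64)
```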