archesweather¶
- class mfai.pytorch.models.archesweather.ArchesWeather(in_channels, out_channels, input_shape, settings=ArchesWeatherSettings(plevel_patch_size=(2, 2, 2), token_size=192, layer_depth=(2, 6), num_heads=(6, 12, 12, 6), spatial_dims=2, surface_variables=4, plevel_variables=5, plevels=13, static_length=3, window_size=(1, 6, 10), dropout_rate=0.0, checkpoint_activation=False, lam=False, cond_dim=32, droppath_coeff=0.2, depth_multiplier=1, position_embs_dim=0, use_prev=False, use_skip=False, conv_head=False, first_interaction_layer=False, axial_attn=False, axial_attn_heads=8))[source]¶
Bases: BaseModel

ArchesWeather model as described in http://arxiv.org/abs/2405.14527.
- Parameters:
in_channels (int)
out_channels (int)
settings (ArchesWeatherSettings)
- forward(input_level, input_surface, static_data=None, cond_emb=None)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.

Return type: tuple[Tensor, Tensor]

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- model_type = 6¶
- property settings: ArchesWeatherSettings¶
Returns the settings instance used to configure this model.
- settings_kls¶
alias of ArchesWeatherSettings
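The default settings above imply a patch-embedding step that maps the (plevels, latitude, longitude) grid onto a coarser token grid. A minimal sketch of that arithmetic, assuming ceil-style padding so the patches cover the whole grid (the model's exact padding behaviour is an assumption here), with a hypothetical 120 x 240 lat/lon grid:

```python
import math


def token_grid(plevels: int, lat: int, lon: int,
               patch: tuple[int, int, int] = (2, 2, 2)) -> tuple[int, int, int]:
    """Tokens per axis after patch embedding, assuming padding up to a multiple
    of the patch size (plevel_patch_size defaults to (2, 2, 2))."""
    return (math.ceil(plevels / patch[0]),
            math.ceil(lat / patch[1]),
            math.ceil(lon / patch[2]))


# With the default 13 pressure levels and a hypothetical 120 x 240 grid:
print(token_grid(13, 120, 240))  # (7, 60, 120)
```

Each of these tokens then carries an embedding of size token_size (192 by default).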
- class mfai.pytorch.models.archesweather.ArchesWeatherSettings(plevel_patch_size=(2, 2, 2), token_size=192, layer_depth=(2, 6), num_heads=(6, 12, 12, 6), spatial_dims=2, surface_variables=4, plevel_variables=5, plevels=13, static_length=3, window_size=(1, 6, 10), dropout_rate=0.0, checkpoint_activation=False, lam=False, cond_dim=32, droppath_coeff=0.2, depth_multiplier=1, position_embs_dim=0, use_prev=False, use_skip=False, conv_head=False, first_interaction_layer=False, axial_attn=False, axial_attn_heads=8)[source]¶
Bases: PanguWeatherSettings

ArchesWeather configuration class. Inherits from PanguWeatherSettings, with additional hyperparameters for Earth-Specific bias and window attention.
- Parameters:
token_size (int) – embedding size
cond_dim (int) – conditioning embedding size
num_heads (tuple) – number of heads per EarthSpecificLayer
droppath_coeff (float) – drop path coefficient
plevel_patch_size (tuple) – patch size for input data embedding
window_size (tuple) – window size for shifted-window attention of EarthSpecificBlock
depth_multiplier (int) – depth multiplier for the number of blocks in EarthSpecificLayer
position_embs_dim (int) – dimension of positional embeddings
use_prev (bool) – whether to use the previous state
use_skip (bool) – whether to use skip connections
conv_head (bool) – whether to use a convolutional head for patch recovery
dropout_rate (float) – dropout rate
first_interaction_layer (bool) – whether to use a linear interaction layer before the first EarthSpecificLayer
checkpoint_activation (bool) – whether to use gradient checkpointing
axial_attn (bool) – whether to use axial attention
axial_attn_heads (int) – number of heads for axial attention
lam (bool) – whether to use the limited area setting in the attention mask
lon_resolution – longitude resolution
lat_resolution – latitude resolution
surface_variables (int) – number of variables in the surface data
static_length (int) – number of variables in the mask data
plevel_variables (int) – number of variables in the level data
plevels (int) – number of atmospheric levels in the level data
spatial_dims (int) – number of spatial dimensions (2)
- classmethod from_dict(kvs, *, infer_missing=False)¶
- classmethod from_json(s, *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw)¶
- classmethod schema(*, infer_missing=False, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)¶
- to_json(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, indent=None, separators=None, default=None, sort_keys=False, **kw)¶
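The from_json / to_json / schema methods indicate the settings class is JSON-serializable in the dataclasses-json style. A truncated sketch of the JSON such a round-trip would produce, assuming the constructor field names map one-to-one onto keys and tuples serialize as arrays:

```json
{
  "plevel_patch_size": [2, 2, 2],
  "token_size": 192,
  "layer_depth": [2, 6],
  "num_heads": [6, 12, 12, 6],
  "window_size": [1, 6, 10],
  "droppath_coeff": 0.2,
  "lam": false,
  "cond_dim": 32
}
```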
- class mfai.pytorch.models.archesweather.CondBasicLayer(depth, data_size, dim, cond_dim, drop_path_ratio_list, num_heads, window_size, dropout_rate, lam, axial_attn, axial_attn_heads, checkpoint_activation)[source]¶
Bases: EarthSpecificLayer

Wrapper for EarthSpecificLayer with conditional embeddings.

- Parameters:
dim (int) – token size
cond_dim (int) – size of the conditional embedding
- forward(x, embedding_shape, cond_emb=None)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.

Return type: Tensor

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class mfai.pytorch.models.archesweather.EarthSpecificBlock(data_size, dim, drop_path_ratio, heads, window_size=(2, 6, 12), dropout_rate=0.0, axial_attn=False, axial_attn_heads=8, checkpoint_activation=False, lam=False)[source]¶
Bases:
Module3D transformer block with Earth-Specific bias and window attention, see https://github.com/microsoft/Swin-Transformer for the official implementation of 2D window attention. The major difference is that we expand the dimensions to 3 and replace the relative position bias with Earth-Specific bias.
- Parameters:
data_size (torch.Size) – data size in terms of plevel, latitude, longitude
dim (int) – token size
drop_path_ratio (float) – ratio to apply to drop path
heads (int) – number of attention heads
window_size (tuple[int], optional) – window size for the sliding window attention. Defaults to (2, 6, 12).
dropout_rate (float, optional) – dropout rate in the MLP. Defaults to 0.0.
axial_attn (bool, optional) – whether to use axial attention. Defaults to False.
axial_attn_heads (int, optional) – number of heads for axial attention. Defaults to 8.
checkpoint_activation (bool, optional) – whether to use checkpointing for activations. Defaults to False.
lam (bool, optional) – whether to use the limited area setting for shifted-window attention. Defaults to False.
- forward(x, embedding_shape, cond_embed=None, roll=False)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.

Return type: Tensor

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
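Window attention partitions the 3D token grid into non-overlapping windows, padding each axis up to a multiple of the window size first. A sketch of that bookkeeping for EarthSpecificBlock's default window_size of (2, 6, 12); the padding-to-multiple behaviour is an assumption about the implementation:

```python
import math


def num_windows(grid: tuple[int, int, int],
                window: tuple[int, int, int] = (2, 6, 12)) -> int:
    """Number of attention windows for a (plevel, lat, lon) token grid,
    after padding each axis up to a multiple of the window size."""
    padded = [math.ceil(g / w) * w for g, w in zip(grid, window)]
    return (padded[0] // window[0]) * (padded[1] // window[1]) * (padded[2] // window[2])


# A hypothetical 7 x 60 x 120 token grid pads to 8 x 60 x 120,
# giving 4 * 10 * 10 windows:
print(num_windows((7, 60, 120)))  # 400
```

Attention cost then scales with the number of windows times the cost within one window, rather than with the square of the full grid size.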
- class mfai.pytorch.models.archesweather.EarthSpecificLayer(depth, data_size, dim, drop_path_ratio_list, num_heads, window_size, dropout_rate, axial_attn, axial_attn_heads, checkpoint_activation, lam)[source]¶
Bases: Module

Basic layer of our network; contains 2 or 6 blocks.
- Parameters:
- forward(x, embedding_shape, cond_embed=None)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.

Return type: Tensor

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class mfai.pytorch.models.archesweather.Interpolate(scale_factor, mode, align_corners=False)[source]¶
Bases: Module

Interpolation module.
- Parameters:
- forward(x)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.

Return type: Tensor

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
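The scale_factor / mode / align_corners signature matches torch.nn.functional.interpolate, which this module presumably wraps. A stdlib-only sketch of what mode="nearest" with an integer scale_factor does along one axis, to illustrate the semantics only; it is not the real implementation:

```python
def interpolate_nearest(xs: list[float], scale_factor: int) -> list[float]:
    """Nearest-neighbour upsampling of a 1-D sequence: output index i maps
    back to input index i // scale_factor."""
    out_len = len(xs) * scale_factor
    return [xs[i // scale_factor] for i in range(out_len)]


print(interpolate_nearest([1.0, 2.0], 2))  # [1.0, 1.0, 2.0, 2.0]
```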
- class mfai.pytorch.models.archesweather.LinVert(in_features, embedding_size)[source]¶
Bases: Module

Linear layer for the vertical dimension.

- Parameters:
in_features (int) – input feature size
embedding_size (tuple[int, ...]) – embedding size

- forward(x)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.

Return type: Tensor

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- class mfai.pytorch.models.archesweather.PatchRecoveryConv(input_dim, downfactor=4, hidden_dim=96, plevel_variables=5, surface_variables=4, plevels=13)[source]¶
Bases: Module

Upsampling with interpolation + convolution to avoid the checkerboard effect.

- Parameters:
input_dim (int) – input feature size
downfactor (int) – downsampling factor (patch size in latitude and longitude)
hidden_dim (int) – hidden feature size
plevel_variables (int) – number of level variables
surface_variables (int) – number of surface variables
plevels (int) – number of levels
- forward(x)[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.

Return type: tuple[Tensor, Tensor]

Note

Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
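Consistent with the tuple[Tensor, Tensor] return type, patch recovery upsamples the lat/lon token grid by downfactor and splits the channels into a level output and a surface output. A sketch of that shape bookkeeping using the constructor defaults; the actual layer layout of PatchRecoveryConv (interpolation + convolutions) is not reproduced here:

```python
def recovery_shapes(tokens_lat: int, tokens_lon: int, downfactor: int = 4,
                    plevel_variables: int = 5, plevels: int = 13,
                    surface_variables: int = 4) -> tuple[tuple, tuple]:
    """Per-sample output shapes implied by the defaults: a level tensor of
    shape (variables, levels, lat, lon) and a surface tensor (variables, lat, lon)."""
    lat, lon = tokens_lat * downfactor, tokens_lon * downfactor
    level_out = (plevel_variables, plevels, lat, lon)
    surface_out = (surface_variables, lat, lon)
    return level_out, surface_out


# From a hypothetical 60 x 120 token grid:
level, surface = recovery_shapes(60, 120)
print(level, surface)  # (5, 13, 240, 480) (4, 240, 480)
```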