nlam¶
Graph Neural Network architectures adapted from https://github.com/mllam/neural-lam.
- class mfai.pytorch.models.nlam.BaseGraphModel(in_channels, out_channels, input_shape, settings, *args, **kwargs)[source]¶
Bases:
BaseModel
Base (abstract) class for graph-based models building on the encode-process-decode idea.
- Parameters:
- embedd_mesh_nodes()[source]¶
Embed static mesh features. Returns a tensor of shape (N_mesh, d_h).
- Return type:
- finalize_graph_model()[source]¶
Method to be overridden by subclasses for finalizing the graph model.
- Return type:
- forward(x)[source]¶
Step the state one step ahead using the prediction model: X_{t-1}, X_t -> X_{t+1}.
prev_state: (B, N_grid, feature_dim), X_t
prev_prev_state: (B, N_grid, feature_dim), X_{t-1}
forcing: (B, N_grid, forcing_dim)
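The two-state time window above can be rolled out autoregressively. The sketch below is illustrative only (not mfai code): `rollout` and `toy_step` are hypothetical names, and plain floats stand in for the (B, N_grid, feature_dim) tensors so the example is self-contained.

```python
# Illustrative sketch (not from mfai): rolling a single-step predictor
# X_{t-1}, X_t -> X_{t+1} forward in time. `step` stands in for the
# model's forward pass; states are plain floats instead of tensors.

def rollout(step, x_prev_prev, x_prev, forcings):
    """Apply `step` autoregressively over a list of forcing inputs."""
    states = []
    for f in forcings:
        x_next = step(x_prev_prev, x_prev, f)  # X_{t+1} from X_{t-1}, X_t
        states.append(x_next)
        x_prev_prev, x_prev = x_prev, x_next   # shift the time window
    return states

# A toy "model": persistence plus forcing.
toy_step = lambda a, b, f: b + f

print(rollout(toy_step, 0.0, 1.0, [0.5, 0.5, 0.5]))  # [1.5, 2.0, 2.5]
```

The shifting of `(x_prev_prev, x_prev)` mirrors how the two most recent states are fed back into `forward` at each step.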
- get_num_mesh()[source]¶
Compute the number of mesh nodes from the loaded features, and the number of mesh nodes that should be ignored in encoding/decoding.
- hierarchical = False¶
- load_graph(device='cpu')[source]¶
Loads a graph from its serialised on-disk format and sets its components as model attributes.
- model_type = 1¶
- onnx_supported = False¶
- process_step(mesh_rep)[source]¶
Process step of the embed-process-decode framework. Processes the representation on the mesh, possibly in multiple steps.
mesh_rep: has shape (B, N_mesh, d_h)
Returns mesh_rep: (B, N_mesh, d_h)
- classmethod rank_zero_setup(settings, meshgrid)[source]¶
This classmethod allows multi-GPU training frameworks to call it once on rank zero before instantiating the model.
- Return type:
- Parameters:
settings (GraphLamSettings)
meshgrid (Tensor)
- property settings: GraphLamSettings¶
Returns the settings instance used to configure this model.
- settings_kls¶
alias of
GraphLamSettings
- supported_num_spatial_dims = (1,)¶
- class mfai.pytorch.models.nlam.BaseHiGraphModel(in_channels, out_channels, input_shape, settings, *args, **kwargs)[source]¶
Bases:
BaseGraphModel
Base class for hierarchical graph models.
- Parameters:
- embedd_mesh_nodes()[source]¶
Embed static mesh features. This embeds only the bottom level; the rest is done at the beginning of the processing step. Returns a tensor of shape (N_mesh[0], d_h).
- Return type:
- finalize_graph_model()[source]¶
Method to be overridden by subclasses for finalizing the graph model.
- Return type:
- get_num_mesh()[source]¶
Compute the number of mesh nodes from the loaded features, and the number of mesh nodes that should be ignored in encoding/decoding.
- hi_processor_step(mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep)[source]¶
Internal processor step of hierarchical graph models, between mesh initialisation and read-out.
Each input is a list of representations, with shapes:
mesh_rep_levels: (B, N_mesh[l], d_h)
mesh_same_rep: (B, M_same[l], d_h)
mesh_up_rep: (B, M_up[l -> l+1], d_h)
mesh_down_rep: (B, M_down[l <- l+1], d_h)
Returns the same lists.
- hierarchical = True¶
- class mfai.pytorch.models.nlam.GraphLAM(in_channels, out_channels, input_shape, settings, *args, **kwargs)[source]¶
Bases:
BaseGraphModel
Full graph-based LAM model that can be used with different (non-hierarchical) graphs. Mainly based on GraphCast, but the model from Keisler (2022) is almost identical. Used for GC-LAM and L1-LAM in Oskarsson et al. (2023).
- Parameters:
- embedd_mesh_nodes()[source]¶
Embed static mesh features. Returns a tensor of shape (N_mesh, d_h).
- Return type:
- finalize_graph_model()[source]¶
Method to be overridden by subclasses for finalizing the graph model.
- Return type:
- get_num_mesh()[source]¶
Compute the number of mesh nodes from the loaded features, and the number of mesh nodes that should be ignored in encoding/decoding.
- class mfai.pytorch.models.nlam.GraphLamSettings(tmp_dir=PosixPath('/tmp'), hidden_dims=64, hidden_layers=1, use_checkpointing=False, offload_to_cpu=False, mesh_aggr='sum', processor_layers=4)[source]¶
Bases:
object
Settings for graph-based models.
- Parameters:
- classmethod from_dict(kvs, *, infer_missing=False)¶
- classmethod from_json(s, *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw)¶
- classmethod schema(*, infer_missing=False, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)¶
- to_json(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, indent=None, separators=None, default=None, sort_keys=False, **kw)¶
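The from_dict/from_json/to_json methods above suggest a dataclasses-json style settings object. The sketch below is a hypothetical stdlib stand-in, not the mfai class: it mirrors the documented field names and defaults (with `tmp_dir` simplified to a string instead of a PosixPath) to show the kind of JSON round-trip those methods provide.

```python
# Illustrative stand-in (not the mfai class): a minimal dataclass with
# the same fields and defaults as GraphLamSettings, plus a JSON
# round-trip built from the standard library only.
import json
from dataclasses import dataclass, asdict


@dataclass
class SettingsSketch:
    tmp_dir: str = "/tmp"   # simplified: the real field is a PosixPath
    hidden_dims: int = 64
    hidden_layers: int = 1
    use_checkpointing: bool = False
    offload_to_cpu: bool = False
    mesh_aggr: str = "sum"
    processor_layers: int = 4

    def to_json(self) -> str:
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, s: str) -> "SettingsSketch":
        return cls(**json.loads(s))


# Round-trip: non-default values survive serialisation.
s = SettingsSketch(hidden_dims=128)
assert SettingsSketch.from_json(s.to_json()) == s
```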
- class mfai.pytorch.models.nlam.HiLAM(in_channels, out_channels, input_shape, settings, *args, **kwargs)[source]¶
Bases:
BaseHiGraphModel
Hierarchical graph model with message passing that goes sequentially down and up the hierarchy during processing. The Hi-LAM model from Oskarsson et al. (2023).
- Parameters:
- finalize_graph_model()[source]¶
Method to be overridden by subclasses for finalizing the graph model.
- Return type:
- hi_processor_step(mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep)[source]¶
Internal processor step of hierarchical graph models, between mesh initialisation and read-out.
Each input is a list of representations, with shapes:
mesh_rep_levels: (B, N_mesh[l], d_h)
mesh_same_rep: (B, M_same[l], d_h)
mesh_up_rep: (B, M_up[l -> l+1], d_h)
mesh_down_rep: (B, M_down[l <- l+1], d_h)
Returns the same lists.
- mesh_down_step(mesh_rep_levels, mesh_same_rep, mesh_down_rep, down_gnns, same_gnns)[source]¶
Run the down-part of vertical processing, sequentially alternating between processing with down edges and same-level edges.
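The down/same alternation in mesh_down_step can be sketched with plain callables. This is illustrative only (not mfai code): `down_pass` is a hypothetical name, toy functions stand in for the down_gnns/same_gnns layers, and numbers stand in for node representations.

```python
# Illustrative sketch (not from mfai): walk the hierarchy from the top
# level down, applying a "down" update followed by a "same-level"
# update at each step, as in mesh_down_step's alternation pattern.

def down_pass(top_rep, down_fns, same_fns):
    """Alternate down-edge and same-level updates, top level to bottom."""
    rep = top_rep
    trace = []
    for down, same in zip(down_fns, same_fns):
        rep = down(rep)   # propagate information to the level below
        rep = same(rep)   # refine within that level
        trace.append(rep)
    return trace

# Toy layers: doubling stands in for a down-edge GNN, incrementing for
# a same-level GNN; two hierarchy steps.
print(down_pass(1, [lambda r: 2 * r] * 2, [lambda r: r + 1] * 2))  # [3, 7]
```

Pairing the two layer lists with `zip` makes the strict down-then-same ordering explicit, which is the point of the sequential variant (contrast with HiLAMParallel below, where all message passing runs at once).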
- class mfai.pytorch.models.nlam.HiLAMParallel(in_channels, out_channels, input_shape, settings, *args, **kwargs)[source]¶
Bases:
BaseHiGraphModel
Version of HiLAM where all message passing in the hierarchical mesh (up, down, inter-level) is run in parallel.
This is a somewhat simpler alternative to the sequential message passing of Hi-LAM.
- Parameters:
- finalize_graph_model()[source]¶
Method to be overridden by subclasses for finalizing the graph model.
- Return type:
- hi_processor_step(mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep)[source]¶
Internal processor step of hierarchical graph models, between mesh initialisation and read-out.
Each input is a list of representations, with shapes:
mesh_rep_levels: (B, N_mesh[l], d_h)
mesh_same_rep: (B, M_same[l], d_h)
mesh_up_rep: (B, M_up[l -> l+1], d_h)
mesh_down_rep: (B, M_down[l <- l+1], d_h)
Returns the same lists.
- mfai.pytorch.models.nlam.offload_to_cpu(model)[source]¶
- Return type:
- Parameters:
model (ModuleList)