nlam

Graph Neural Network architectures adapted from https://github.com/mllam/neural-lam.

class mfai.pytorch.models.nlam.BaseGraphModel(in_channels, out_channels, input_shape, settings, *args, **kwargs)[source]

Bases: BaseModel

Base (abstract) class for graph-based models building on the encode-process-decode idea.

Parameters:
  • in_channels (int)

  • out_channels (int)

  • input_shape (tuple[int, ...])

  • settings (GraphLamSettings)

embedd_mesh_nodes()[source]

Embed static mesh features. Returns a tensor of shape (N_mesh, d_h).

Return type:

Tensor

features_last: bool = True
finalize_graph_model()[source]

Method to be overridden by subclasses for finalizing the graph model.

Return type:

None

forward(x)[source]

Steps the state one step ahead using the prediction model: X_{t-1}, X_t -> X_{t+1}.

  • prev_state: (B, N_grid, feature_dim), X_t

  • prev_prev_state: (B, N_grid, feature_dim), X_{t-1}

  • forcing: (B, N_grid, forcing_dim)

Return type:

Tensor

Parameters:

x (Tensor)
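The one-step-ahead contract above implies an autoregressive rollout: predictions are fed back in as the new states. A minimal stdlib sketch, with a toy linear-extrapolation rule standing in for the real graph network:

```python
# Toy autoregressive rollout matching the forward() contract:
# X_{t+1} = model(X_{t-1}, X_t, forcing). The "model" here is a
# stand-in extrapolation rule, not the real GNN.
def toy_forward(prev_prev_state, prev_state, forcing):
    # X_{t+1} = X_t + (X_t - X_{t-1}) + forcing  (linear extrapolation)
    return [x + (x - xp) + f for xp, x, f in zip(prev_prev_state, prev_state, forcing)]

def rollout(model, x_tm1, x_t, forcings):
    """Roll the model forward len(forcings) steps, feeding predictions back in."""
    states = [x_tm1, x_t]
    for f in forcings:
        states.append(model(states[-2], states[-1], f))
    return states[2:]  # the predicted states only

preds = rollout(toy_forward, [0.0], [1.0], [[0.0], [0.0]])
# preds == [[2.0], [3.0]]: pure linear extrapolation under zero forcing
```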

get_num_mesh()[source]

Compute number of mesh nodes from loaded features, and number of mesh nodes that should be ignored in encoding/decoding.

Return type:

tuple[int, int]

hierarchical = False
load_graph(device='cpu')[source]

Loads a graph from its serialised on-disk format and sets its components as model attributes.

Return type:

None

Parameters:

device (device | Literal['cpu', 'cuda'])

model_type = 1
num_spatial_dims: int = 1
onnx_supported = False
process_step(mesh_rep)[source]

Processing step of the embed-process-decode framework. Processes the representation on the mesh, possibly in multiple steps.

mesh_rep has shape (B, N_mesh, d_h). Returns mesh_rep of the same shape (B, N_mesh, d_h).

Return type:

Tensor

Parameters:

mesh_rep (Tensor)

classmethod rank_zero_setup(settings, meshgrid)[source]

Class method that multi-GPU training frameworks can call once on rank zero before instantiating the model.

Return type:

None

Parameters:
  • settings (GraphLamSettings)

  • meshgrid (Tensor)

property settings: GraphLamSettings

Returns the settings instance used to configure this model.

settings_kls

alias of GraphLamSettings

supported_num_spatial_dims = (1,)
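The encode-process-decode idea that BaseGraphModel builds on can be sketched as follows. This is a stand-in with toy arithmetic, not the real mfai API: grid features are embedded into a latent mesh representation, processed there by message passing, then decoded back to the grid.

```python
# Stand-in sketch of encode-process-decode (toy arithmetic, hypothetical
# functions): grid -> latent mesh -> processed mesh -> grid prediction.
def encode(grid_state):
    return [x * 0.5 for x in grid_state]       # toy grid -> mesh embedding

def process(mesh_rep, steps=4):
    for _ in range(steps):                     # toy message-passing steps
        mesh_rep = [x + 1.0 for x in mesh_rep]
    return mesh_rep

def decode(mesh_rep):
    return [x * 2.0 for x in mesh_rep]         # toy mesh -> grid read-out

pred = decode(process(encode([2.0, 4.0])))
# pred == [10.0, 12.0]
```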
class mfai.pytorch.models.nlam.BaseHiGraphModel(in_channels, out_channels, input_shape, settings, *args, **kwargs)[source]

Bases: BaseGraphModel

Base class for hierarchical graph models.

Parameters:
  • in_channels (int)

  • out_channels (int)

  • input_shape (tuple[int, ...])

  • settings (GraphLamSettings)

embedd_mesh_nodes()[source]

Embed static mesh features. Only the bottom level is embedded here; the rest is done at the beginning of the processing step. Returns a tensor of shape (N_mesh[0], d_h).

Return type:

Tensor

finalize_graph_model()[source]

Method to be overridden by subclasses for finalizing the graph model.

Return type:

None

get_num_mesh()[source]

Compute number of mesh nodes from loaded features, and number of mesh nodes that should be ignored in encoding/decoding.

Return type:

tuple[int, int]

hi_processor_step(mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep)[source]

Internal processor step of hierarchical graph models, between mesh initialisation and read-out.

Each input is a list of representations, one per level, with shapes:

  • mesh_rep_levels: (B, N_mesh[l], d_h)

  • mesh_same_rep: (B, M_same[l], d_h)

  • mesh_up_rep: (B, M_up[l -> l+1], d_h)

  • mesh_down_rep: (B, M_down[l <- l+1], d_h)

Returns the same lists, updated.

Return type:

tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]

Parameters:
  • mesh_rep_levels (list[Tensor])

  • mesh_same_rep (list[Tensor])

  • mesh_up_rep (list[Tensor])

  • mesh_down_rep (list[Tensor])

hierarchical = True
process_step(mesh_rep)[source]

Processing step of the embed-process-decode framework. Processes the representation on the mesh, possibly in multiple steps.

mesh_rep has shape (B, N_mesh, d_h). Returns mesh_rep of the same shape (B, N_mesh, d_h).

Return type:

Tensor

Parameters:

mesh_rep (Tensor)
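The per-level representation lists passed through hi_processor_step can be sketched as shape bookkeeping. All counts below are made-up illustrative numbers; tensors are represented by their (B, N, d_h) shape tuples only.

```python
# Shape bookkeeping for the hi_processor_step inputs (hypothetical
# 3-level hierarchy; counts are illustrative, not from any real graph).
B, d_h = 2, 64
n_mesh = [100, 25, 5]   # mesh nodes per level, fine -> coarse
m_same = [400, 90, 12]  # intra-level edges per level
m_up = [150, 40]        # edges from level l up to l+1
m_down = [150, 40]      # edges from level l+1 down to l

mesh_rep_levels = [(B, n, d_h) for n in n_mesh]
mesh_same_rep = [(B, m, d_h) for m in m_same]
mesh_up_rep = [(B, m, d_h) for m in m_up]
mesh_down_rep = [(B, m, d_h) for m in m_down]

# With L levels there are L node and same-level lists, but only
# L - 1 up/down lists (one per pair of adjacent levels).
assert len(mesh_up_rep) == len(mesh_rep_levels) - 1
assert len(mesh_down_rep) == len(mesh_rep_levels) - 1
```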

class mfai.pytorch.models.nlam.GraphLAM(in_channels, out_channels, input_shape, settings, *args, **kwargs)[source]

Bases: BaseGraphModel

Full graph-based LAM model that can be used with different (non-hierarchical) graphs. Mainly based on GraphCast, but the model from Keisler (2022) is almost identical. Used for GC-LAM and L1-LAM in Oskarsson et al. (2023).

Parameters:
  • in_channels (int)

  • out_channels (int)

  • input_shape (tuple[int, ...])

  • settings (GraphLamSettings)

embedd_mesh_nodes()[source]

Embed static mesh features. Returns a tensor of shape (N_mesh, d_h).

Return type:

Tensor

finalize_graph_model()[source]

Method to be overridden by subclasses for finalizing the graph model.

Return type:

None

get_num_mesh()[source]

Compute number of mesh nodes from loaded features, and number of mesh nodes that should be ignored in encoding/decoding.

Return type:

tuple[int, int]

process_step(mesh_rep)[source]

Processing step of the embed-process-decode framework. Processes the representation on the mesh, possibly in multiple steps.

mesh_rep has shape (B, N_mesh, d_h). Returns mesh_rep of the same shape (B, N_mesh, d_h).

Return type:

Tensor

Parameters:

mesh_rep (Tensor)

register: bool = True
class mfai.pytorch.models.nlam.GraphLamSettings(tmp_dir=PosixPath('/tmp'), hidden_dims=64, hidden_layers=1, use_checkpointing=False, offload_to_cpu=False, mesh_aggr='sum', processor_layers=4)[source]

Bases: object

Settings for graph-based models.

Parameters:
  • tmp_dir (Path)

  • hidden_dims (int)

  • hidden_layers (int)

  • use_checkpointing (bool)

  • offload_to_cpu (bool)

  • mesh_aggr (Literal['sum', 'mean'])

  • processor_layers (int)

classmethod from_dict(kvs, *, infer_missing=False)
Return type:

TypeVar(A, bound= DataClassJsonMixin)

Parameters:

kvs (dict | list | str | int | float | bool | None)

classmethod from_json(s, *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw)
Return type:

TypeVar(A, bound= DataClassJsonMixin)

Parameters:

s (str | bytes | bytearray)

hidden_dims: int
hidden_layers: int
mesh_aggr: Literal['sum', 'mean']
offload_to_cpu: bool
processor_layers: int
classmethod schema(*, infer_missing=False, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)
Return type:

SchemaF[TypeVar(A, bound= DataClassJsonMixin)]

tmp_dir: Path
to_dict(encode_json=False)
Return type:

Dict[str, Union[dict, list, str, int, float, bool, None]]

to_json(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, indent=None, separators=None, default=None, sort_keys=False, **kw)
Return type:

str

use_checkpointing: bool
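GraphLamSettings is a dataclass whose from_dict/from_json/to_dict/to_json methods come from dataclasses-json. The round trip can be illustrated with a stand-in dataclass mirroring the fields above, emulating the serialisation with stdlib dataclasses and json only:

```python
import json
from dataclasses import asdict, dataclass
from pathlib import Path

# Stand-in mirroring the GraphLamSettings fields and defaults shown above.
# The real class gains its (de)serialisation methods from dataclasses-json.
@dataclass
class Settings:
    tmp_dir: Path = Path("/tmp")
    hidden_dims: int = 64
    hidden_layers: int = 1
    use_checkpointing: bool = False
    offload_to_cpu: bool = False
    mesh_aggr: str = "sum"  # Literal['sum', 'mean'] in the real class
    processor_layers: int = 4

s = Settings(hidden_dims=128, mesh_aggr="mean")
# Path is not JSON-serialisable, so stringify it before dumping.
payload = json.dumps({**asdict(s), "tmp_dir": str(s.tmp_dir)})
d = json.loads(payload)
restored = Settings(**{**d, "tmp_dir": Path(d["tmp_dir"])})
assert restored == s
```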
class mfai.pytorch.models.nlam.HiLAM(in_channels, out_channels, input_shape, settings, *args, **kwargs)[source]

Bases: BaseHiGraphModel

Hierarchical graph model with message passing that goes sequentially down and up the hierarchy during processing. The Hi-LAM model from Oskarsson et al. (2023).

Parameters:
  • in_channels (int)

  • out_channels (int)

  • input_shape (tuple[int, ...])

  • settings (GraphLamSettings)

finalize_graph_model()[source]

Method to be overridden by subclasses for finalizing the graph model.

Return type:

None

hi_processor_step(mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep)[source]

Internal processor step of hierarchical graph models, between mesh initialisation and read-out.

Each input is a list of representations, one per level, with shapes:

  • mesh_rep_levels: (B, N_mesh[l], d_h)

  • mesh_same_rep: (B, M_same[l], d_h)

  • mesh_up_rep: (B, M_up[l -> l+1], d_h)

  • mesh_down_rep: (B, M_down[l <- l+1], d_h)

Returns the same lists, updated.

Return type:

tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]

Parameters:
  • mesh_rep_levels (list[Tensor])

  • mesh_same_rep (list[Tensor])

  • mesh_up_rep (list[Tensor])

  • mesh_down_rep (list[Tensor])

make_down_gnns()[source]

Make GNNs for processing steps down through the hierarchy.

Return type:

ModuleList

make_same_gnns()[source]

Make intra-level GNNs.

Return type:

ModuleList

make_up_gnns()[source]

Make GNNs for processing steps up through the hierarchy.

Return type:

ModuleList

mesh_down_step(mesh_rep_levels, mesh_same_rep, mesh_down_rep, down_gnns, same_gnns)[source]

Run the downward part of the vertical processing, sequentially alternating between processing along down edges and same-level edges.

Return type:

tuple[list[Tensor], list[Tensor], list[Tensor]]

Parameters:
  • mesh_rep_levels (list[Tensor])

  • mesh_same_rep (list[Tensor])

  • mesh_down_rep (list[Tensor])

  • down_gnns (ModuleList)

  • same_gnns (ModuleList)

mesh_up_step(mesh_rep_levels, mesh_same_rep, mesh_up_rep, up_gnns, same_gnns)[source]

Run the upward part of the vertical processing, sequentially alternating between processing along up edges and same-level edges.

Return type:

tuple[list[Tensor], list[Tensor], list[Tensor]]

Parameters:
  • mesh_rep_levels (list[Tensor])

  • mesh_same_rep (list[Tensor])

  • mesh_up_rep (list[Tensor])

  • up_gnns (ModuleList)

  • same_gnns (ModuleList)

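The sequential up-then-down sweep that mesh_up_step and mesh_down_step describe can be sketched with stand-in callables. This is assumed control flow based on the docstrings above, not the real code; per-level states are plain numbers, and the update functions stand in for the same/up/down GNNs.

```python
# Hedged sketch of Hi-LAM's sequential vertical sweep: process each level
# with same-level edges, pass information up to the next level, then do
# the mirror image coming back down.
def sweep(levels, same, up, down):
    for l in range(len(levels) - 1):            # upward pass
        levels[l] = same[l](levels[l])          # same-level edges
        levels[l + 1] = up[l](levels[l], levels[l + 1])  # up edges
    for l in reversed(range(len(levels) - 1)):  # downward pass
        levels[l + 1] = same[l + 1](levels[l + 1])
        levels[l] = down[l](levels[l + 1], levels[l])    # down edges
    return levels

inc = lambda x: x + 1      # toy same-level update
add = lambda a, b: a + b   # toy cross-level update
result = sweep([1, 1, 1], [inc] * 3, [add] * 2, [add] * 2)
```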
register: bool = True
class mfai.pytorch.models.nlam.HiLAMParallel(in_channels, out_channels, input_shape, settings, *args, **kwargs)[source]

Bases: BaseHiGraphModel

Version of HiLAM where all message passing in the hierarchical mesh (up, down, inter-level) is run in parallel.

This is a somewhat simpler alternative to the sequential message passing of Hi-LAM.

Parameters:
  • in_channels (int)

  • out_channels (int)

  • input_shape (tuple[int, ...])

  • settings (GraphLamSettings)

finalize_graph_model()[source]

Method to be overridden by subclasses for finalizing the graph model.

Return type:

None

hi_processor_step(mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep)[source]

Internal processor step of hierarchical graph models, between mesh initialisation and read-out.

Each input is a list of representations, one per level, with shapes:

  • mesh_rep_levels: (B, N_mesh[l], d_h)

  • mesh_same_rep: (B, M_same[l], d_h)

  • mesh_up_rep: (B, M_up[l -> l+1], d_h)

  • mesh_down_rep: (B, M_down[l <- l+1], d_h)

Returns the same lists, updated.

Return type:

tuple[list[Tensor], list[Tensor], list[Tensor], list[Tensor]]

Parameters:
  • mesh_rep_levels (list[Tensor])

  • mesh_same_rep (list[Tensor])

  • mesh_up_rep (list[Tensor])

  • mesh_down_rep (list[Tensor])

register: bool = True
mfai.pytorch.models.nlam.offload_to_cpu(model)[source]
Return type:

ModuleList

Parameters:

model (ModuleList)