cross_attention

class mfai.pytorch.models.llms.cross_attention.XAttMultiModalLM(settings=XAttMultiModalLMSettings(emb_dim=768, context_length=1024, n_heads=12, n_layers=12, drop_rate=0.1, qkv_bias=False, vision_input_shape=(3, 256, 256), x_att_ratio=4, resnet_num_tokens=32, resnet_pos_embedding=False, resnet_mlp_output=False, vision_encoder='linear', layer_norm_vis=True, patch_size=None), vocab_size=50257)[source]

Bases: FreezeMLMMixin, Module

A multimodal LLM with cross-attention. Can use GPT2 or Llama2 as its LLM backend.

Parameters:
  • settings (XAttMultiModalLMSettings) – model configuration; defaults as shown in the signature above

  • vocab_size (int) – size of the output vocabulary, default 50257
property context_length: int
forward(txt_token_ids, vision_inputs)[source]

Forward function of the Cross-Attention Multimodal language model.

Parameters:
  • txt_token_ids (Tensor) – tensor of shape (B, n_tok)

  • vision_inputs (Tensor | list[Tensor]) – tensor or list of tensors of shape (B, channels, lat, lon)

Returns:

tensor of shape (B, n_tok, vocab_size)

Return type:

Tensor
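The shape contract of forward() can be sketched without the library. The helper below is purely illustrative (xatt_forward_shape is not part of mfai): it validates the documented input shapes and returns the logits shape the model would produce.

```python
from typing import List, Tuple, Union

Shape = Tuple[int, ...]

def xatt_forward_shape(
    txt_token_ids: Shape,
    vision_inputs: Union[Shape, List[Shape]],
    vocab_size: int = 50257,
) -> Shape:
    """Logits shape produced by forward() for the given input shapes.

    txt_token_ids: (B, n_tok); vision_inputs: (B, channels, lat, lon),
    or a list of such shapes (one per vision modality).
    """
    batch, n_tok = txt_token_ids
    vis_shapes = vision_inputs if isinstance(vision_inputs, list) else [vision_inputs]
    for shape in vis_shapes:
        if len(shape) != 4 or shape[0] != batch:
            raise ValueError(f"expected (B, channels, lat, lon) with B={batch}, got {shape}")
    return (batch, n_tok, vocab_size)

# Batch of 2 sequences of 16 tokens, one vision field of shape (2, 3, 256, 256)
print(xatt_forward_shape((2, 16), (2, 3, 256, 256)))  # (2, 16, 50257)
```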

model_type = 5
settings_kls

alias of XAttMultiModalLMSettings

class mfai.pytorch.models.llms.cross_attention.XAttMultiModalLMSettings(emb_dim=768, context_length=1024, n_heads=12, n_layers=12, drop_rate=0.1, qkv_bias=False, vision_input_shape=(3, 256, 256), x_att_ratio=4, resnet_num_tokens=32, resnet_pos_embedding=False, resnet_mlp_output=False, vision_encoder='linear', layer_norm_vis=True, patch_size=None)[source]

Bases: object

Settings for our cross-attention multimodal language model.

Parameters:
context_length: int = 1024
drop_rate: float = 0.1
emb_dim: int = 768
classmethod from_dict(kvs, *, infer_missing=False)
Return type:

TypeVar(A, bound=DataClassJsonMixin)

Parameters:

kvs (dict | list | str | int | float | bool | None)

classmethod from_json(s, *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw)
Return type:

TypeVar(A, bound=DataClassJsonMixin)

Parameters:

s (str | bytes | bytearray)

layer_norm_vis: bool = True
n_heads: int = 12
n_layers: int = 12
patch_size: None | int | tuple[int, int] = None
qkv_bias: bool = False
resnet_mlp_output: bool = False
resnet_num_tokens: int = 32
resnet_pos_embedding: bool = False
classmethod schema(*, infer_missing=False, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)
Return type:

SchemaF[TypeVar(A, bound=DataClassJsonMixin)]

to_dict(encode_json=False)
Return type:

Dict[str, Union[dict, list, str, int, float, bool, None]]

to_json(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, indent=None, separators=None, default=None, sort_keys=False, **kw)
Return type:

str
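The from_dict / to_dict / to_json trio follows the usual dataclasses_json round-trip pattern: a dict or JSON string whose keys match the field names above rebuilds a settings instance. A stdlib-only sketch of that round trip, using a stand-in dataclass with a subset of the fields (SettingsStub and this from_dict are illustrative, not the real mfai class):

```python
import json
from dataclasses import asdict, dataclass, fields

@dataclass
class SettingsStub:
    # Stand-in mirroring a subset of XAttMultiModalLMSettings defaults
    emb_dim: int = 768
    context_length: int = 1024
    n_heads: int = 12
    drop_rate: float = 0.1
    vision_encoder: str = "linear"

def from_dict(kvs: dict) -> SettingsStub:
    # Keep only known field names; unknown keys are dropped
    known = {f.name for f in fields(SettingsStub)}
    return SettingsStub(**{k: v for k, v in kvs.items() if k in known})

settings = from_dict({"emb_dim": 512, "vision_encoder": "vit"})
payload = json.dumps(asdict(settings))     # to_json analogue
restored = from_dict(json.loads(payload))  # from_json analogue
assert restored == settings
```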

vision_encoder: Literal['resnet50', 'linear', 'vit'] = 'linear'
vision_input_shape: tuple[int, int, int] = (3, 256, 256)
x_att_ratio: int = 4