cross_attention
- class mfai.pytorch.models.llms.cross_attention.XAttMultiModalLM(settings=XAttMultiModalLMSettings(emb_dim=768, context_length=1024, n_heads=12, n_layers=12, drop_rate=0.1, qkv_bias=False, vision_input_shape=(3, 256, 256), x_att_ratio=4, resnet_num_tokens=32, resnet_pos_embedding=False, resnet_mlp_output=False, vision_encoder='linear', layer_norm_vis=True, patch_size=None), vocab_size=50257)[source]
Bases: FreezeMLMMixin, Module

A multimodal LLM with cross attention. Can use GPT2 or Llama2 as its LLM backend.
- Parameters:
settings (XAttMultiModalLMSettings)
vocab_size (int)
- forward(txt_token_ids, vision_inputs)[source]
Forward function of the Cross-Attention Multimodal language model.
- Parameters:
txt_token_ids (Tensor) – tensor of shape (B, n_tok)
vision_inputs (Tensor | list[Tensor]) – tensor or list of tensors of shape (B, channels, lat, lon)
- Returns:
tensor of shape (B, n_tok, vocab_size)
- Return type:
Tensor
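The shape contract of forward() can be illustrated with a small stdlib-only sketch (the real model requires torch and mfai; the helper name below is illustrative, not part of the library):

```python
# Shape-level sketch of XAttMultiModalLM.forward as documented above:
# token ids of shape (B, n_tok) and a vision input of shape
# (B, channels, lat, lon) produce logits of shape (B, n_tok, vocab_size).
def xatt_forward_shape(txt_shape, vision_shape, vocab_size=50257):
    """Return the output shape of forward() given the input shapes."""
    b, n_tok = txt_shape
    # The text and vision inputs share the batch dimension.
    assert vision_shape[0] == b, "batch sizes must match"
    return (b, n_tok, vocab_size)

print(xatt_forward_shape((4, 128), (4, 3, 256, 256)))  # (4, 128, 50257)
```

Note that vision_inputs may also be a list of such tensors, in which case each element follows the same (B, channels, lat, lon) layout.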
- model_type = 5
- settings_kls
alias of XAttMultiModalLMSettings
- class mfai.pytorch.models.llms.cross_attention.XAttMultiModalLMSettings(emb_dim=768, context_length=1024, n_heads=12, n_layers=12, drop_rate=0.1, qkv_bias=False, vision_input_shape=(3, 256, 256), x_att_ratio=4, resnet_num_tokens=32, resnet_pos_embedding=False, resnet_mlp_output=False, vision_encoder='linear', layer_norm_vis=True, patch_size=None)[source]
Bases: object

Settings for our cross attention multimodal language model.

- Parameters:
emb_dim (int) – default 768
context_length (int) – default 1024
n_heads (int) – default 12
n_layers (int) – default 12
drop_rate (float) – default 0.1
qkv_bias (bool) – default False
vision_input_shape (tuple[int, int, int]) – default (3, 256, 256)
x_att_ratio (int) – default 4
resnet_num_tokens (int) – default 32
resnet_pos_embedding (bool) – default False
resnet_mlp_output (bool) – default False
vision_encoder (str) – default 'linear'
layer_norm_vis (bool) – default True
patch_size – default None
- classmethod from_dict(kvs, *, infer_missing=False)
- classmethod from_json(s, *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw)
- classmethod schema(*, infer_missing=False, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)
- to_json(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, indent=None, separators=None, default=None, sort_keys=False, **kw)
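These from_dict/from_json/to_json helpers follow the dataclasses-json pattern. The round trip they provide can be sketched with the standard library alone, using a stand-in dataclass with a subset of the real settings fields (the stand-in name and helper functions below are illustrative, not part of mfai):

```python
import json
from dataclasses import dataclass, asdict

# Stand-in for XAttMultiModalLMSettings with a few of its fields;
# the real class gets these helpers from its dataclasses-json base.
@dataclass
class SettingsSketch:
    emb_dim: int = 768
    context_length: int = 1024
    n_heads: int = 12
    drop_rate: float = 0.1

def to_json(settings):
    """Serialize the settings dataclass to a JSON string."""
    return json.dumps(asdict(settings), sort_keys=True)

def from_dict(kvs):
    """Rebuild the settings dataclass from a plain dict."""
    return SettingsSketch(**kvs)

# Round trip: settings -> JSON -> dict -> settings, preserving overrides.
s = SettingsSketch(emb_dim=512)
restored = from_dict(json.loads(to_json(s)))
print(restored == s)  # True
```

This is the typical use of these helpers: persist a settings object alongside a checkpoint with to_json, then restore it with from_json or from_dict when reloading the model.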