fuyu¶
- class mfai.pytorch.models.llms.fuyu.Fuyu(settings=FuyuSettings(backend='gpt2', emb_dim=768, context_length=1024, n_heads=12, n_layers=12, drop_rate=0.1, qkv_bias=False, hidden_dim=768, num_kv_groups=2, rope_base=500000.0, model_size='124M', attn_tf_compat=False, vision_input_shape=(3, 256, 256), inject_vision_each_stage=False, vision_encoder='linear', resnet_num_tokens=32, resnet_pos_embedding=False, resnet_mlp_output=False, layer_norm_viz_txt=False, layer_norm_viz=True, patch_size=None), vocab_size=50257)[source]¶
Bases:
FreezeMLMMixin, Module

A multimodal LLM: vision/weather and text, inspired by Fuyu. Can use GPT2, Llama2, or Llama3 as its LLM backend.
- Parameters:
settings (FuyuSettings)
vocab_size (int)
- forward(txt_token_ids, vision_inputs)[source]¶
Forward function of the Fuyu Multimodal language model.
- Parameters:
txt_token_ids (Tensor) – tensor of shape (B, n_tok).
vision_inputs (Tensor | list[Tensor]) – tensor or list of tensors of shape (B, channels, lat, lon).
- Returns:
tensor of shape (B, n_tok, vocab_size)
- Return type:
Tensor
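As a sanity check on the forward() contract above, the shape mapping can be sketched with a small pure-Python helper. This is illustrative only: `fuyu_output_shape` is a hypothetical stand-in for documentation purposes, not part of mfai, and the real computation happens inside `Fuyu.forward`.

```python
# Illustrative sketch of the documented forward() shape contract:
# (B, n_tok) text tokens plus one or more (B, channels, lat, lon)
# vision inputs map to (B, n_tok, vocab_size) logits.

def fuyu_output_shape(txt_shape, vision_shapes, vocab_size=50257):
    """Return the logits shape implied by the forward() docs."""
    b, n_tok = txt_shape
    # vision_inputs may be a single tensor or a list of tensors,
    # so accept a single shape tuple or a list of shape tuples.
    if isinstance(vision_shapes, tuple):
        vision_shapes = [vision_shapes]
    for vs in vision_shapes:
        assert len(vs) == 4 and vs[0] == b, \
            "each vision input must be (B, channels, lat, lon) with matching B"
    return (b, n_tok, vocab_size)
```

For example, a batch of 2 token sequences of length 128 with one (2, 3, 256, 256) vision input yields logits of shape (2, 128, 50257) with the default GPT2 vocabulary.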
- model_type = 5¶
- settings_kls¶
alias of
FuyuSettings
- class mfai.pytorch.models.llms.fuyu.FuyuSettings(backend='gpt2', emb_dim=768, context_length=1024, n_heads=12, n_layers=12, drop_rate=0.1, qkv_bias=False, hidden_dim=768, num_kv_groups=2, rope_base=500000.0, model_size='124M', attn_tf_compat=False, vision_input_shape=(3, 256, 256), inject_vision_each_stage=False, vision_encoder='linear', resnet_num_tokens=32, resnet_pos_embedding=False, resnet_mlp_output=False, layer_norm_viz_txt=False, layer_norm_viz=True, patch_size=None)[source]¶
Bases:
object

Settings for a multimodal language model.
- Parameters:
backend (Literal['gpt2', 'llama2', 'llama3'])
emb_dim (int)
context_length (int)
n_heads (int)
n_layers (int)
drop_rate (float)
qkv_bias (bool)
hidden_dim (int)
num_kv_groups (int)
rope_base (float)
model_size (Literal['124M', '355M', '774M', '1558M'])
attn_tf_compat (bool)
inject_vision_each_stage (bool)
vision_encoder (Literal['resnet50', 'linear', 'vit'])
resnet_num_tokens (int)
resnet_pos_embedding (bool)
resnet_mlp_output (bool)
layer_norm_viz_txt (bool)
layer_norm_viz (bool)
- classmethod from_dict(kvs, *, infer_missing=False)¶
- classmethod from_json(s, *, parse_float=None, parse_int=None, parse_constant=None, infer_missing=False, **kw)¶
- classmethod schema(*, infer_missing=False, only=None, exclude=(), many=False, context=None, load_only=(), dump_only=(), partial=False, unknown=None)¶
- to_json(*, skipkeys=False, ensure_ascii=True, check_circular=True, allow_nan=True, indent=None, separators=None, default=None, sort_keys=False, **kw)¶
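The from_dict / from_json / to_json classmethods above follow the dataclasses-json pattern: the settings object round-trips losslessly through a plain dict or JSON string. A minimal stdlib sketch of that round-trip, using a trimmed hypothetical `MiniSettings` stand-in rather than the real FuyuSettings:

```python
import json
from dataclasses import dataclass, asdict

# Trimmed stand-in for FuyuSettings, illustrating the dict/JSON
# round-trip that from_dict/from_json/to_json provide.
@dataclass
class MiniSettings:
    backend: str = "gpt2"
    emb_dim: int = 768
    context_length: int = 1024

settings = MiniSettings(backend="llama2")
payload = json.dumps(asdict(settings))          # analogous to to_json()
restored = MiniSettings(**json.loads(payload))  # analogous to from_dict()
assert restored == settings
```

This pattern is useful for persisting a model configuration alongside checkpoints so the exact architecture can be rebuilt at load time.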