dummy_dataset¶
- class mfai.pytorch.dummy_dataset.DummyDataModule(task='binary', batch_size=2, dim_x=64, dim_y=64, nb_input_channels=2, nb_output_channels=1)[source]¶
Bases: LightningDataModule
A Lightning DataModule wrapping our dummy dataset. It defines the train/valid/test/predict datasets and their dataloaders.
- predict_dataloader()[source]¶
An iterable or collection of iterables specifying prediction samples.
For more information about multiple dataloaders, see this section.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return type:
DataLoader
- Returns:
A torch.utils.data.DataLoader or a sequence of them specifying prediction samples.
- setup(stage='')[source]¶
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()
        # don't do this: prepare_data runs on a single process,
        # so state assigned here is not shared across processes
        self.something = something_else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
- test_dataloader()[source]¶
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see this section.
For data processing use the following pattern:
download in prepare_data()
process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
- Return type:
DataLoader
- train_dataloader()[source]¶
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
download in prepare_data()
process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return type:
DataLoader
- val_dataloader()[source]¶
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
- Return type:
DataLoader
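The download-in-prepare_data(), build-splits-in-setup() pattern that the dataloader notes above repeat can be sketched with stdlib stand-ins. The class and attribute names below are illustrative only, not the real Lightning or mfai API:

```python
import random

class SketchDataModule:
    def __init__(self, batch_size=2, n_samples=8):
        self.batch_size = batch_size
        self.n_samples = n_samples
        self.train_ds = None
        self.val_ds = None

    def prepare_data(self):
        # download / tokenize here; this hook runs on a single process,
        # so do NOT assign state in it
        pass

    def setup(self, stage=""):
        # runs on every process; safe to build and assign datasets here
        rng = random.Random(0)
        data = [rng.random() for _ in range(self.n_samples)]
        split = self.n_samples // 2
        self.train_ds, self.val_ds = data[:split], data[split:]

    def train_dataloader(self):
        # yield fixed-size batches from the training split
        for i in range(0, len(self.train_ds), self.batch_size):
            yield self.train_ds[i:i + self.batch_size]

dm = SketchDataModule()
dm.prepare_data()
dm.setup(stage="fit")
batches = list(dm.train_dataloader())
```

The same split keeps heavy, one-off work (downloads) off the per-process path while per-process state (the datasets) is rebuilt safely everywhere under DDP.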
- class mfai.pytorch.dummy_dataset.DummyDataset(split, task='binary', dim_x=64, dim_y=64, nb_input_channels=2, nb_output_channels=1)[source]¶
Bases: Dataset
A dummy segmentation dataset to test our training modules. X is a random float tensor of chosen size. Y is a random binary tensor of chosen size. X and Y share the same height and width.
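A stdlib-only sketch of what one such sample looks like, mirroring the shapes described above (nested lists stand in for torch tensors; the helper name is hypothetical):

```python
import random

# One dummy segmentation sample: X is a random float "tensor" of shape
# (nb_input_channels, dim_x, dim_y) and Y a random binary one of shape
# (nb_output_channels, dim_x, dim_y); both share height and width.
def dummy_sample(dim_x=4, dim_y=4, nb_input_channels=2,
                 nb_output_channels=1, seed=0):
    rng = random.Random(seed)
    x = [[[rng.random() for _ in range(dim_y)] for _ in range(dim_x)]
         for _ in range(nb_input_channels)]
    y = [[[rng.randint(0, 1) for _ in range(dim_y)] for _ in range(dim_x)]
         for _ in range(nb_output_channels)]
    return x, y

x, y = dummy_sample()
```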
- class mfai.pytorch.dummy_dataset.DummyMultiModalDataModule(batch_size=2, dim_x=64, dim_y=64, nb_input_channels=2, context_length=8)[source]¶
Bases: LightningDataModule
A Lightning DataModule wrapping our dummy dataset. It defines the train/valid/test/predict datasets and their dataloaders.
- collate_fn_fit(batch)[source]¶
Collate a batch of multimodal data.
- Return type:
- Parameters:
batch (List[Tuple[NamedTensor, Tensor, Tensor]])
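The generic shape of such a collate function can be sketched without the real NamedTensor/Tensor types: it receives a list of per-sample tuples and transposes it into per-field batches (a real implementation would stack tensors rather than keep lists):

```python
# Transpose a list of (image, input_text, output_text) tuples into
# per-field batches; plain strings and lists stand in for tensors.
def collate_sketch(batch):
    images, inputs, outputs = zip(*batch)
    return list(images), list(inputs), list(outputs)

samples = [("img0", [1, 2], [1, 2, 3]), ("img1", [4, 5], [4, 5, 6])]
imgs, inps, outs = collate_sketch(samples)
```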
- setup(stage='')[source]¶
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
Example:
class LitModel(...):
    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()
        # don't do this: prepare_data runs on a single process,
        # so state assigned here is not shared across processes
        self.something = something_else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)
- test_dataloader()[source]¶
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see this section.
For data processing use the following pattern:
download in prepare_data()
process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
- Return type:
DataLoader
- train_dataloader()[source]¶
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
download in prepare_data()
process and split in setup()
However, the above are only necessary for distributed processing.
Warning
do not assign state in prepare_data
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return type:
DataLoader
- val_dataloader()[source]¶
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.
- Return type:
DataLoader
- class mfai.pytorch.dummy_dataset.DummyMultiModalDataset(split, dim_x=64, dim_y=64, nb_input_channels=2, context_length=8)[source]¶
Bases: Dataset
A dummy multimodal dataset to test our training modules.
image is a random float tensor of size (nb_input_channels, dim_x, dim_y).
input_text is a random integer tensor of size (n_tokens, emb_dim).
output_text is a random integer tensor of size (n_tokens+1, emb_dim).
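A stdlib-only sketch of one such multimodal sample, mirroring the shapes listed above. Nested lists stand in for tensors; it assumes n_tokens corresponds to context_length, and emb_dim here is a hypothetical illustrative size:

```python
import random

# One dummy multimodal sample: a random float "image" of shape
# (nb_input_channels, dim_x, dim_y), a random integer input_text of shape
# (context_length, emb_dim), and an output_text one token longer
# (next-token targets).
def dummy_multimodal_sample(dim_x=4, dim_y=4, nb_input_channels=2,
                            context_length=8, emb_dim=3, seed=0):
    rng = random.Random(seed)
    image = [[[rng.random() for _ in range(dim_y)] for _ in range(dim_x)]
             for _ in range(nb_input_channels)]
    input_text = [[rng.randrange(100) for _ in range(emb_dim)]
                  for _ in range(context_length)]
    output_text = [[rng.randrange(100) for _ in range(emb_dim)]
                   for _ in range(context_length + 1)]
    return image, input_text, output_text

image, inp, out = dummy_multimodal_sample()
```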