MLFlowSystemMonitorCallback¶
- class mfai.pytorch.callbacks.MLFlowSystemMonitorCallback[source]¶
Bases: Callback
A Lightning callback to log system metrics (GPU usage, etc.) in MLflow. We use this callback because the default MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING option from mlflow doesn't work with Lightning. See this issue: https://github.com/Lightning-AI/pytorch-lightning/issues/20563.
- load_state_dict(state_dict)¶
Called when loading a checkpoint; implement to reload the callback's state from the given state_dict.
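The state_dict / load_state_dict contract can be sketched in plain Python (a real callback would subclass lightning.pytorch.Callback; the BatchCounter class here is a hypothetical stand-in):

```python
# Hypothetical stateful callback: state_dict() is called when saving a
# checkpoint, load_state_dict() when loading one.
class BatchCounter:
    def __init__(self):
        self.batches_seen = 0

    def state_dict(self):
        # return everything needed to restore this callback later
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict):
        # restore state produced by a previous state_dict() call
        self.batches_seen = state_dict["batches_seen"]

src = BatchCounter()
src.batches_seen = 17
fresh = BatchCounter()
fresh.load_state_dict(src.state_dict())  # round-trip through a checkpoint
print(fresh.batches_seen)  # 17
```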
- on_after_backward(trainer, pl_module)¶
Called after loss.backward() and before optimizers are stepped.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_before_backward(trainer, pl_module, loss)¶
Called before loss.backward().
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
loss (Tensor)
- on_before_optimizer_step(trainer, pl_module, optimizer)¶
Called before optimizer.step().
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
optimizer (Optimizer)
- on_before_zero_grad(trainer, pl_module, optimizer)¶
Called before optimizer.zero_grad().
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
optimizer (Optimizer)
- on_exception(trainer, pl_module, exception)¶
Called when any trainer execution is interrupted by an exception.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
exception (BaseException)
- on_fit_end(trainer, pl_module)[source]¶
Called when fit ends.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_fit_start(trainer, pl_module)[source]¶
Called when fit begins.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_load_checkpoint(trainer, pl_module, checkpoint)¶
Called when loading a model checkpoint, use to reload state.
- Parameters:
trainer (Trainer) – the current Trainer instance.
pl_module (LightningModule) – the current LightningModule instance.
checkpoint (dict[str, Any]) – the full checkpoint dictionary that got loaded by the Trainer.
- Return type: None
- on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)¶
Called when the predict batch ends.
- on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)¶
Called when the predict batch begins.
- on_predict_end(trainer, pl_module)¶
Called when predict ends.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_predict_epoch_end(trainer, pl_module)¶
Called when the predict epoch ends.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_predict_epoch_start(trainer, pl_module)¶
Called when the predict epoch begins.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_predict_start(trainer, pl_module)¶
Called when the predict begins.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_sanity_check_end(trainer, pl_module)¶
Called when the validation sanity check ends.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_sanity_check_start(trainer, pl_module)¶
Called when the validation sanity check starts.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_save_checkpoint(trainer, pl_module, checkpoint)¶
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
trainer (Trainer) – the current Trainer instance.
pl_module (LightningModule) – the current LightningModule instance.
checkpoint (dict[str, Any]) – the checkpoint dictionary that will be saved.
- Return type: None
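The save/load hook pair can also carry extra entries in the checkpoint dictionary itself. A plain-Python sketch of that contract (the BestScoreTracker class and "best_score" key are hypothetical; a real callback subclasses lightning.pytorch.Callback):

```python
# Hypothetical callback that persists one extra value through the
# checkpoint dict passed to on_save_checkpoint / on_load_checkpoint.
class BestScoreTracker:
    def __init__(self):
        self.best_score = float("inf")

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        checkpoint["best_score"] = self.best_score  # store extra state

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        self.best_score = checkpoint["best_score"]  # reload it

tracker = BestScoreTracker()
tracker.best_score = 0.25
ckpt = {}  # stands in for the checkpoint dict the Trainer builds
tracker.on_save_checkpoint(None, None, ckpt)

restored = BestScoreTracker()
restored.on_load_checkpoint(None, None, ckpt)
print(restored.best_score)  # 0.25
```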
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)¶
Called when the test batch ends.
- on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)¶
Called when the test batch begins.
- on_test_end(trainer, pl_module)¶
Called when the test ends.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_test_epoch_end(trainer, pl_module)¶
Called when the test epoch ends.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_test_epoch_start(trainer, pl_module)¶
Called when the test epoch begins.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_test_start(trainer, pl_module)¶
Called when the test begins.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)¶
Called when the train batch ends.
- Return type: None
Note: the value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
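The normalization above is just a division; a toy numeric sketch (the values are illustrative, not Lightning code):

```python
# With gradient accumulation, the loss seen in on_train_batch_end's
# outputs["loss"] is the training_step loss divided by
# accumulate_grad_batches, so accumulated gradients average out correctly.
accumulate_grad_batches = 4
raw_training_step_loss = 2.0
normalized = raw_training_step_loss / accumulate_grad_batches
print(normalized)  # 0.5
```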
- on_train_batch_start(trainer, pl_module, batch, batch_idx)¶
Called when the train batch begins.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
batch (Any)
batch_idx (int)
- on_train_end(trainer, pl_module)¶
Called when the train ends.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_train_epoch_end(trainer, pl_module)¶
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

```python
class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
```
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_train_epoch_start(trainer, pl_module)¶
Called when the train epoch begins.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_train_start(trainer, pl_module)¶
Called when the train begins.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)¶
Called when the validation batch ends.
- on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)¶
Called when the validation batch begins.
- on_validation_end(trainer, pl_module)¶
Called when the validation loop ends.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_validation_epoch_end(trainer, pl_module)¶
Called when the val epoch ends.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_validation_epoch_start(trainer, pl_module)¶
Called when the val epoch begins.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- on_validation_start(trainer, pl_module)¶
Called when the validation loop begins.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
- setup(trainer, pl_module, stage)¶
Called when fit, validate, test, predict, or tune begins.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
stage (str)
- state_dict()¶
Called when saving a checkpoint; implement to generate the callback's state_dict.
- property state_key: str¶
Identifier for the state of the callback.
Used to store and retrieve a callback’s state from the checkpoint dictionary by
checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
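The uniqueness requirement can be sketched in plain Python (the Thresholder class and its key format are hypothetical; a real callback subclasses lightning.pytorch.Callback):

```python
# Hypothetical stateful callback: deriving state_key from the constructor
# arguments keeps two instances from overwriting each other's slot in
# checkpoint["callbacks"].
class Thresholder:
    def __init__(self, monitor):
        self.monitor = monitor

    @property
    def state_key(self):
        # include the distinguishing argument in the key
        return f"Thresholder{{'monitor': {self.monitor!r}}}"

    def state_dict(self):
        return {"monitor": self.monitor}

a, b = Thresholder("val_loss"), Thresholder("train_loss")
checkpoint = {"callbacks": {cb.state_key: cb.state_dict() for cb in (a, b)}}
assert a.state_key != b.state_key  # distinct slots, both states survive
print(sorted(checkpoint["callbacks"]))
```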
- teardown(trainer, pl_module, stage)¶
Called when fit, validate, test, predict, or tune ends.
- Return type: None
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
stage (str)