Add another pitfall to docs that can happen when using rank_zero_only decorator in lightning #2719

Merged 5 commits on Sep 6, 2024
24 changes: 24 additions & 0 deletions docs/source/pages/lightning.rst
@@ -3,6 +3,7 @@
import torch
from torch.nn import Module
from lightning.pytorch import LightningModule
from lightning.pytorch.utilities import rank_zero_only
from torchmetrics import Metric

#################################
@@ -193,6 +194,29 @@ The following contains a list of pitfalls to be aware of:
  Because the object is logged in the first case, Lightning will reset the metric before calling the second line,
  leading to errors or nonsensical results.

* If you decorate a ``LightningModule`` method with the ``rank_zero_only`` decorator with the goal of only computing a
  particular metric on the main process, you need to disable the metric's default behavior of synchronizing its values
  across all processes. This is done by setting the ``sync_on_compute`` flag to ``False`` when initializing the metric.
  Not doing so can lead to race conditions and hanging processes, because the synchronization is a collective operation
  that every process must enter, while under ``rank_zero_only`` only the main process ever reaches the ``compute`` call.

.. testcode:: python

    class MyModule(LightningModule):

        def __init__(self, num_classes):
            ...
            # only the main process will ever reach ``compute``, so turn off the
            # default cross-process synchronization to avoid hanging the other ranks
            self.metric = torchmetrics.image.FrechetInceptionDistance(sync_on_compute=False)

        @rank_zero_only
        def validation_step(self, batch, batch_idx):
            image, target = batch
            generated_image = self(image)
            ...
            self.metric(image, real=True)
            self.metric(generated_image, real=False)
            val = self.metric.compute()  # this will only be called on the main process
            self.log('val_fid', val)

* Calling ``self.log("val", self.metric(preds, target))`` with the intention of logging the metric object. Because
  ``self.metric(preds, target)`` corresponds to calling the forward method, this will return a tensor and not the
  metric object, so the logged value will be wrong. Instead, it is essential to separate the call into several lines.
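  A minimal sketch of such a separation, assuming ``self.metric`` was initialized in ``__init__``:

  .. testcode:: python

      class MyModule(LightningModule):

          def validation_step(self, batch, batch_idx):
              preds, target = batch
              ...
              # update the metric states by calling forward (or ``update``)
              self.metric(preds, target)
              # log the metric object itself, not the tensor returned above
              self.log("val", self.metric)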