diff --git a/common/checkpointing/snapshot.py b/common/checkpointing/snapshot.py index 2703efd..f20b25d 100644 --- a/common/checkpointing/snapshot.py +++ b/common/checkpointing/snapshot.py @@ -73,8 +73,8 @@ def restore(self, checkpoint: str) -> None: snapshot.restore(self.state) # we still need to ensure that extra_state has walltime in it self.state["extra_state"] = torchsnapshot.StateDict(step=self.step, walltime=0.0) - - logging.info(f"Restored snapshot from {snapshot.path}. ({time.time() - start_time:.05}s") + else: + logging.info(f"Restored snapshot from {snapshot.path}. ({time.time() - start_time:.05}s") @classmethod def get_torch_snapshot(