From 3e06a266437ae27017a10e70b272ce6aee15e9eb Mon Sep 17 00:00:00 2001 From: The kauldron Authors Date: Thu, 24 Oct 2024 01:17:21 -0700 Subject: [PATCH] Mark training complete after last checkpoint saving is completed. PiperOrigin-RevId: 689279915 --- kauldron/train/train_lib.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/kauldron/train/train_lib.py b/kauldron/train/train_lib.py index 4333d062..4a30d1da 100644 --- a/kauldron/train/train_lib.py +++ b/kauldron/train/train_lib.py @@ -141,15 +141,16 @@ def train_impl( log_summaries=log_summaries, ) + # Ensure all hosts exit together. See section in dm/jax-faqs. + _sync() + ckpt.wait_until_finished() + # Notify the eval job training is complete if trainer.workdir.exists(): # `TrainEvaluator` do not have a workdir epath.Path(trainer.workdir).joinpath( eval_impl.TRAIN_COMPLETE_FILENAME ).touch() - # Ensure all hosts exit together. See section in dm/jax-faqs. - _sync() - ckpt.wait_until_finished() # Returning the final state is convenient for interactive training in colab return state, aux