From c8154313aca1685f5b2a0545e8a21afefd3d6709 Mon Sep 17 00:00:00 2001 From: Klaus Greff Date: Wed, 23 Oct 2024 05:21:42 -0700 Subject: [PATCH] deuglify mnist example PiperOrigin-RevId: 688922829 --- examples/mnist_autoencoder.py | 3 +-- kauldron/data/py/base.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/mnist_autoencoder.py b/examples/mnist_autoencoder.py index 9c1508b5..cdfbd0e2 100644 --- a/examples/mnist_autoencoder.py +++ b/examples/mnist_autoencoder.py @@ -87,8 +87,7 @@ def get_config(): def _make_ds(training: bool): - Tfds = kd.data.py.Tfds - return Tfds( + return kd.data.py.Tfds( name="mnist", split="train" if training else "test", shuffle=True if training else False, diff --git a/kauldron/data/py/base.py b/kauldron/data/py/base.py index 74135fa3..57907eca 100644 --- a/kauldron/data/py/base.py +++ b/kauldron/data/py/base.py @@ -170,7 +170,7 @@ def ds_for_current_process(self, rng: random.PRNGKey) -> grain.MapDataset: def _get_num_workers(num_workers: int) -> int: """Set the number of workers.""" - if epy.is_notebook(): # in colab worker_count has to be 0 + if epy.is_notebook() or epy.is_test(): # in colab worker_count has to be 0 # TODO(klausg): autodetect if Kernel supports multiprocessing # Could check # from multiprocessing import spawn