diff --git a/examples/mnist_autoencoder.py b/examples/mnist_autoencoder.py index 9c1508b5..cdfbd0e2 100644 --- a/examples/mnist_autoencoder.py +++ b/examples/mnist_autoencoder.py @@ -87,8 +87,7 @@ def get_config(): def _make_ds(training: bool): - Tfds = kd.data.py.Tfds - return Tfds( + return kd.data.py.Tfds( name="mnist", split="train" if training else "test", shuffle=True if training else False, diff --git a/kauldron/data/py/base.py b/kauldron/data/py/base.py index 74135fa3..57907eca 100644 --- a/kauldron/data/py/base.py +++ b/kauldron/data/py/base.py @@ -170,7 +170,7 @@ def ds_for_current_process(self, rng: random.PRNGKey) -> grain.MapDataset: def _get_num_workers(num_workers: int) -> int: """Set the number of workers.""" - if epy.is_notebook(): # in colab worker_count has to be 0 + if epy.is_notebook() or epy.is_test(): # in colab worker_count has to be 0 # TODO(klausg): autodetect if Kernel supports multiprocessing # Could check # from multiprocessing import spawn diff --git a/kauldron/evals/eval_impl_test.py b/kauldron/evals/eval_impl_test.py index 6c9b38b1..36a63805 100644 --- a/kauldron/evals/eval_impl_test.py +++ b/kauldron/evals/eval_impl_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test.""" - from collections.abc import Iterator import os from unittest import mock @@ -28,7 +26,9 @@ def test_eval_impl(tmp_path: epath.Path): # Load config and reduce size cfg = mnist_autoencoder.get_config() - cfg.train_ds.batch_size = 2 + # TODO(klausg): remove this once data mocking works correctly with grain + cfg.train_ds.__qualname__ = 'kauldron.kd:data.Tfds' + cfg.train_ds.batch_size = 1 cfg.evals.eval.ds.batch_size = 1 # pytype: disable=attribute-error cfg.model.encoder.features = 3 cfg.num_train_steps = 1 diff --git a/kauldron/utils/sharding_utils_test.py b/kauldron/utils/sharding_utils_test.py index a8f257d6..3fb35ffa 100644 --- a/kauldron/utils/sharding_utils_test.py +++ b/kauldron/utils/sharding_utils_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Test.""" - import os from etils import epath @@ -28,7 +26,9 @@ def test_sharding(tmp_path: epath.Path): # Load config and reduce size cfg = mnist_autoencoder.get_config() - cfg.train_ds.batch_size = 2 + # TODO(klausg): remove this once data mocking works correctly with grain + cfg.train_ds.__qualname__ = 'kauldron.kd:data.Tfds' + cfg.train_ds.batch_size = 1 cfg.model.encoder.features = 3 cfg.workdir = os.fspath(tmp_path)