Skip to content

Commit

Permalink
Change the order of batching to avoid None
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689075797
  • Loading branch information
The kauldron Authors committed Oct 23, 2024
1 parent 87f7474 commit 0cdb80e
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions kauldron/data/py/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,6 @@ def ds_with_transforms(self, rng: random.PRNGKey) -> grain.MapDataset:

ds = transform_utils.apply_transforms(ds, self.transforms)

if self.batch_size:
ds = ds.batch(self.batch_size, drop_remainder=self.batch_drop_remainder)

return ds

@functools.cached_property
Expand All @@ -111,6 +108,10 @@ def _root_ds(self) -> grain.IterDataset:
# `_root_map_ds` because `_root_map_ds` does not propagate `len`
ds = self._root_map_ds
ds = ds.to_iter_dataset(read_options=self.read_options)
# We do batching after conversion to `IterDataset` to avoid None during
# batching.
if self.batch_size:
ds = ds.batch(self.batch_size, drop_remainder=self.batch_drop_remainder)

# Distribute the execution across multiple worker processes.
num_workers = _get_num_workers(self.num_workers)
Expand Down

0 comments on commit 0cdb80e

Please sign in to comment.