
Training CLMBR - AssertionError: Can only have one batch when collating #227

Open
EricBorland opened this issue Aug 27, 2024 · 0 comments
EricBorland commented Aug 27, 2024

Describe the bug

When following the tutorials to train a CLMBR model, exactly at cell [5] of https://github.com/som-shahlab/femr/blob/main/tutorials/4_Train%20CLMBR.ipynb, I get the following error because the trainer sends multiple batches to the collate function:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[9], line 48
     38     return processor.collate(data)
     41 trainer = transformers.Trainer(
     42     model=model,
     43     data_collator=processor.collate,
   (...)
     46     args=trainer_config
     47 )
---> 48 trainer.train()
     49 model.save_pretrained(MODEL_DIR)

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/transformers/trainer.py:1948, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   1946         hf_hub_utils.enable_progress_bars()
   1947 else:
-> 1948     return inner_training_loop(
   1949         args=args,
   1950         resume_from_checkpoint=resume_from_checkpoint,
   1951         trial=trial,
   1952         ignore_keys_for_eval=ignore_keys_for_eval,
   1953     )

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/transformers/trainer.py:2246, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2243     rng_to_sync = True
   2245 step = -1
-> 2246 for step, inputs in enumerate(epoch_iterator):
   2247     total_batched_samples += 1
   2249     if self.args.include_num_input_tokens_seen:

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/accelerate/data_loader.py:454, in DataLoaderShard.__iter__(self)
    452 # We iterate one batch ahead to check when we are at the end
    453 try:
--> 454     current_batch = next(dataloader_iter)
    455 except StopIteration:
    456     yield

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
    627 if self._sampler_iter is None:
    628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    629     self._reset()  # type: ignore[call-arg]
--> 630 data = self._next_data()
    631 self._num_yielded += 1
    632 if self._dataset_kind == _DatasetKind.Iterable and \
    633         self._IterableDataset_len_called is not None and \
    634         self._num_yielded > self._IterableDataset_len_called:

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/torch/utils/data/dataloader.py:673, in _SingleProcessDataLoaderIter._next_data(self)
    671 def _next_data(self):
    672     index = self._next_index()  # may raise StopIteration
--> 673     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    674     if self._pin_memory:
    675         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:55, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     53 else:
     54     data = self.dataset[possibly_batched_index]
---> 55 return self.collate_fn(data)

File ~/anaconda3/envs/python3/lib/python3.10/site-packages/femr/models/processor.py:400, in FEMRBatchProcessor.collate(self, batches)
    398 def collate(self, batches: List[Mapping[str, Any]]) -> Mapping[str, Any]:
    399     """A collate function that prepares batches for being fed into a dataloader."""
--> 400     assert len(batches) == 1, "Can only have one batch when collating"
    401     return {"batch": _add_dimension(self.creator.cleanup_batch(batches[0]))}

AssertionError: Can only have one batch when collating
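For context, here is a minimal plain-Python sketch (not FEMR code; the `collate` stand-in and batch contents are hypothetical) of why the assertion fires: the FEMR processor appears to yield whole, pre-formed batches, so its collate function only accepts a list of length one, while a dataloader batch size greater than one hands it several elements at once.

```python
from typing import Any, List, Mapping

def collate(batches: List[Mapping[str, Any]]) -> Mapping[str, Any]:
    """Simplified stand-in mirroring the shape of
    FEMRBatchProcessor.collate: the dataset already yields whole
    batches, so exactly one element is expected per call."""
    assert len(batches) == 1, "Can only have one batch when collating"
    return {"batch": batches[0]}

def collate_raises(batches: List[Mapping[str, Any]]) -> bool:
    """Returns True if collate rejects the given list of batches."""
    try:
        collate(batches)
        return False
    except AssertionError:
        return True

# A dataloader batch size > 1 passes several pre-formed batches,
# which trips the assertion:
print(collate_raises([{"tokens": [1]}, {"tokens": [2]}]))  # True

# A dataloader batch size of 1 passes a single-element list and works:
print(collate([{"tokens": [1, 2, 3]}]))  # {'batch': {'tokens': [1, 2, 3]}}
```

If this reading is right, the dataloader feeding `transformers.Trainer` presumably needs an effective batch size of 1 (e.g. `per_device_train_batch_size=1` in the training arguments), with the actual batching done upstream by the FEMR processor.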

Steps to reproduce the bug

Follow the https://github.com/som-shahlab/femr/blob/main/tutorials/4_Train%20CLMBR.ipynb tutorial

Expected results

To be able to train the model

Actual results

The indicated error

Environment info

  • datasets version: 2.20.0
  • Platform: Linux-5.10.217-205.860.amzn2.x86_64-x86_64-with-glibc2.26
  • Python version: 3.10.14
  • huggingface_hub version: 0.24.5
  • PyArrow version: 16.1.0
  • Pandas version: 2.2.2
  • fsspec version: 2024.5.0