diff --git a/train.py b/train.py index e482cfb..1ca634b 100644 --- a/train.py +++ b/train.py @@ -203,6 +203,7 @@ def _setup_devices(self) -> "torch.device": device = torch.device("cpu") self._n_gpu = 0 elif is_torch_tpu_available(): + import torch_xla.core.xla_model as xm device = xm.xla_device() self._n_gpu = 0 elif self.local_rank == -1: