diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index 0b7f191..3a04f2e 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -37,8 +37,9 @@ from cebra.data.datatypes import Batch from cebra.data.datatypes import BatchIndex +import warnings -class TensorDataset(cebra_data.SingleSessionDataset): +class TensorDataset(SingleSessionDataset): """Discrete and/or continuously indexed dataset based on torch/numpy arrays. If dealing with datasets sufficiently small to fit :py:func:`numpy.array` or :py:class:`torch.Tensor`, this @@ -74,23 +75,25 @@ def __init__(self, offset: int = 1, device: str = "cpu"): super().__init__(device=device) - self.neural = self._to_tensor(neural, torch.FloatTensor).float() - self.continuous = self._to_tensor(continuous, torch.FloatTensor) - self.discrete = self._to_tensor(discrete, torch.LongTensor) + self.neural = self._to_tensor(neural, 'float').float() + self.continuous = self._to_tensor(continuous, 'float') + self.discrete = self._to_tensor(discrete, 'long') if self.continuous is None and self.discrete is None: - raise ValueError( - "You have to pass at least one of the arguments 'continuous' or 'discrete'." + warnings.warn( + "You should pass at least one of the arguments 'continuous' or 'discrete'." ) self.offset = offset - def _to_tensor(self, array, check_dtype=None): + def _to_tensor(self, array, dtype=None): if array is None: return None if isinstance(array, np.ndarray): - array = torch.from_numpy(array) - if check_dtype is not None: - if not isinstance(array, check_dtype): - raise TypeError(f"{type(array)} instead of {check_dtype}.") + if dtype == 'float': + array = torch.from_numpy(array).float() + elif dtype == 'long': + array = torch.from_numpy(array).long() + else: + array = torch.from_numpy(array) return array @property