Skip to content

Commit

Permalink
Update datasets.py
Browse files Browse the repository at this point in the history
Workaround to allow `torch.Tensor` instead of `torch.FloatTensor` since `torch.FloatTensor` is only there for compatibility.
  • Loading branch information
icarosadero authored Oct 20, 2024
1 parent 9e14790 commit cac869b
Showing 1 changed file with 14 additions and 11 deletions.
25 changes: 14 additions & 11 deletions cebra/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@
from cebra.data.datatypes import Batch
from cebra.data.datatypes import BatchIndex

import warnings

class TensorDataset(cebra_data.SingleSessionDataset):
class TensorDataset(SingleSessionDataset):
"""Discrete and/or continuously indexed dataset based on torch/numpy arrays.
If dealing with datasets sufficiently small to fit :py:func:`numpy.array` or :py:class:`torch.Tensor`, this
Expand Down Expand Up @@ -74,23 +75,25 @@ def __init__(self,
offset: int = 1,
device: str = "cpu"):
super().__init__(device=device)
self.neural = self._to_tensor(neural, torch.FloatTensor).float()
self.continuous = self._to_tensor(continuous, torch.FloatTensor)
self.discrete = self._to_tensor(discrete, torch.LongTensor)
self.neural = self._to_tensor(neural, 'float').float()
self.continuous = self._to_tensor(continuous, 'float')
self.discrete = self._to_tensor(discrete, 'long')
if self.continuous is None and self.discrete is None:
raise ValueError(
"You have to pass at least one of the arguments 'continuous' or 'discrete'."
warnings.warn(
"You should pass at least one of the arguments 'continuous' or 'discrete'."
)
self.offset = offset

def _to_tensor(self, array, check_dtype=None):
def _to_tensor(self, array, dtype=None):
if array is None:
return None
if isinstance(array, np.ndarray):
array = torch.from_numpy(array)
if check_dtype is not None:
if not isinstance(array, check_dtype):
raise TypeError(f"{type(array)} instead of {check_dtype}.")
if dtype == 'float':
array = torch.from_numpy(array).float()
elif dtype == 'long':
array = torch.from_numpy(array).long()
else:
array = torch.from_numpy(array)
return array

@property
Expand Down

0 comments on commit cac869b

Please sign in to comment.