# device_mgmt.py
# Helpers for picking a torch device and moving tensors/batches onto it.
import torch
def get_default_device() -> torch.device:
    """Pick GPU if available, else CPU.

    Returns:
        ``torch.device('cuda')`` when ``torch.cuda.is_available()`` reports
        True, otherwise ``torch.device('cpu')``.
    """
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def to_device(data, device):
    """Move tensor(s) to the chosen device, recursing through containers.

    Args:
        data: A tensor (any object exposing ``.to``), or an arbitrarily
            nested list/tuple/dict of such objects.
        device: Target device; passed unchanged to each leaf's ``.to``.

    Returns:
        The same structure with every leaf moved to ``device``. Lists and
        tuples both come back as lists; dicts keep their keys. Leaves that
        have no ``.to`` method are returned unchanged.
    """
    if isinstance(data, (list, tuple)):
        return [to_device(item, device) for item in data]
    if isinstance(data, dict):
        return {key: to_device(value, device) for key, value in data.items()}
    if not hasattr(data, "to"):
        # Non-tensor leaves (str/int labels, None, ...) have nothing to
        # move; pass them through instead of raising AttributeError.
        return data
    return data.to(device, non_blocking=True)
class DeviceDataLoader():
    """Wrap a dataloader to move data to a device.

    ``dl`` is the wrapped iterable of batches; ``device`` is handed to
    ``to_device`` together with each batch during iteration.
    """

    def __init__(self, dl, device):
        """Store the wrapped dataloader and the target device."""
        self.dl = dl
        self.device = device

    def __iter__(self):
        """Yield each batch of ``dl`` after moving it to ``device``."""
        for batch in self.dl:
            yield to_device(batch, self.device)

    def __len__(self):
        """Return the number of batches, i.e. ``len(dl)``."""
        return len(self.dl)