forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Device.cpp
281 lines (255 loc) · 8.66 KB
/
Device.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
#include <torch/csrc/Device.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <ATen/Device.h>
#include <c10/util/Exception.h>
#include <structmember.h>
#include <limits>
#include <sstream>
// Module object handed to THPDevice_init. It is stored as a borrowed pointer
// (THPDevice_init does not Py_INCREF it). THPDevice_pynew passes it to
// handle_torch_function when the constructor arguments carry a torch function
// override.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PyObject* THPUpperModuleOfDevice = nullptr;
// Wraps a copy of `device` in a freshly allocated torch.device Python object.
// Returns a new reference. Throws python_error if tp_alloc fails (the Python
// error indicator is already set in that case).
PyObject* THPDevice_New(const at::Device& device) {
  PyTypeObject* const device_type = &THPDeviceType;
  THPObjectPtr obj{device_type->tp_alloc(device_type, 0)};
  if (obj.get() == nullptr) {
    throw python_error();
  }
  reinterpret_cast<THPDevice*>(obj.get())->device = device;
  // Hand ownership of the new reference to the caller.
  return obj.release();
}
// tp_repr for torch.device. Produces "device(type='<type>')", with
// ", index=<n>" appended before the closing parenthesis when an index is set.
PyObject* THPDevice_repr(THPDevice* self) {
  const at::Device& dev = self->device;
  std::ostringstream out;
  out << "device(type=\'" << dev.type() << "\'";
  if (dev.has_index()) {
    // Widen before streaming. If DeviceIndex is a one-byte integer type it
    // would otherwise be printed as a character instead of a number.
    // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout
    out << ", index=" << static_cast<uint16_t>(dev.index());
  }
  out << ")";
  return THPUtils_packString(out.str().c_str());
}
// tp_str for torch.device: the text produced by at::Device's stream operator.
PyObject* THPDevice_str(THPDevice* self) {
  std::ostringstream buf;
  buf << self->device;
  const std::string text = buf.str();
  return THPUtils_packString(text.c_str());
}
// tp_new for torch.device. Two call forms are accepted:
//   device(Device device)                          -- a single Device argument
//   device(c10::string_view type, int64_t? index)  -- a type string and an
//                                                     optional explicit index
// Returns a new reference. C++ exceptions are translated into Python
// exceptions by HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS.
PyObject* THPDevice_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static torch::PythonArgParser parser(
      {"device(Device device)",
       "device(c10::string_view type, int64_t? index=-1)"});
  torch::ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  if (r.has_torch_function()) {
    return handle_torch_function(
        r, nullptr, args, kwargs, THPUpperModuleOfDevice, "torch");
  }
  if (r.idx == 0) {
    auto device = r.device(0);
    return THPDevice_New(device);
  } else if (r.idx == 1) {
    auto as_device = r.device(0); // this works, because device can take strings
    if (as_device.has_index()) {
      auto device_type = r.string(0);
      throw std::runtime_error(
          "type (string) must not include an index because index "
          "was passed explicitly: " +
          device_type);
    }
    int64_t device_index = -1;
    if (!r.isNone(1)) {
      device_index = r.toInt64(1);
      // -1 is allowed in ATen/C++, to mean the default device, but not in
      // Python.
      TORCH_CHECK(device_index >= 0, "Device index must not be negative");
      // The index is narrowed to c10::DeviceIndex below. Reject values that
      // type cannot represent; otherwise the static_cast would silently wrap
      // them onto a different device.
      constexpr auto kMaxDeviceIndex =
          std::numeric_limits<c10::DeviceIndex>::max();
      TORCH_CHECK(
          device_index <= kMaxDeviceIndex,
          "Device index must not exceed ",
          static_cast<int64_t>(kMaxDeviceIndex),
          ", but got ",
          device_index);
    }
    at::Device device(
        as_device.type(), static_cast<c10::DeviceIndex>(device_index));
    return THPDevice_New(device);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
// Getter for the torch.device `type` property. Returns the device type as a
// Python string, rendered with DeviceType's stream operator.
PyObject* THPDevice_type(THPDevice* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  std::ostringstream oss;
  oss << self->device.type();
  // The statement after this return was unreachable and has been removed.
  return THPUtils_packString(oss.str().c_str());
  END_HANDLE_TH_ERRORS
}
// Getter for the torch.device `index` property. Returns the index as a Python
// int, or None when the device carries no index.
PyObject* THPDevice_index(THPDevice* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  if (!self->device.has_index()) {
    Py_RETURN_NONE;
  }
  return THPUtils_packInt64(self->device.index());
  END_HANDLE_TH_ERRORS
}
// tp_hash for torch.device. Reduces std::hash<at::Device> modulo the largest
// Py_ssize_t, so the result is always non-negative and never equals -1, which
// CPython reserves as the error value. Returns -1 only if an exception
// escapes.
static Py_ssize_t THPDevice_hash(THPDevice* self) {
  HANDLE_TH_ERRORS
  const size_t raw = std::hash<at::Device>{}(self->device);
  const auto modulus = std::numeric_limits<Py_ssize_t>::max();
  return static_cast<Py_ssize_t>(raw % modulus);
  END_HANDLE_TH_ERRORS_RET(-1)
}
// tp_richcompare for torch.device.
// - Returns NotImplemented unless both operands are torch.device objects.
// - Supports == and != only.
// - Ordering comparisons raise TypeError.
PyObject* THPDevice_rc(PyObject* a, PyObject* b, int op) {
  HANDLE_TH_ERRORS
  if (!(THPDevice_Check(a) && THPDevice_Check(b))) {
    Py_RETURN_NOTIMPLEMENTED;
  }
  const bool same = reinterpret_cast<THPDevice*>(a)->device ==
      reinterpret_cast<THPDevice*>(b)->device;
  switch (op) {
    case Py_EQ:
      return PyBool_FromLong(same);
    case Py_NE:
      return PyBool_FromLong(!same);
    case Py_LT:
    case Py_LE:
    case Py_GT:
    case Py_GE:
      throw torch::TypeError("comparison not implemented");
    default:
      throw torch::TypeError("unexpected comparison op");
  }
  END_HANDLE_TH_ERRORS
}
// __reduce__ for torch.device, used for pickling. Returns the 2-tuple
// (torch.device, args). `args` is (type_str, index) when the device has an
// index and (type_str,) otherwise.
PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto* self = reinterpret_cast<THPDevice*>(_self);
  THPObjectPtr result{PyTuple_New(2)};
  if (result.get() == nullptr) {
    throw python_error();
  }

  // Slot 0: the callable that rebuilds the object.
  py::object ctor = py::module::import("torch").attr("device");
  PyTuple_SET_ITEM(result.get(), 0, ctor.release().ptr());

  // Slot 1: positional arguments for that callable.
  std::ostringstream type_stream;
  type_stream << self->device.type();
  const std::string type_name = type_stream.str();
  THPObjectPtr ctor_args{
      self->device.has_index()
          ? Py_BuildValue(
                "(si)",
                type_name.c_str(),
                static_cast<int>(self->device.index()))
          : Py_BuildValue("(s)", type_name.c_str())};
  if (ctor_args.get() == nullptr) {
    throw python_error();
  }
  PyTuple_SET_ITEM(result.get(), 1, ctor_args.release());
  return result.release();
  END_HANDLE_TH_ERRORS
}
// __enter__ for torch.device. Builds a torch.utils._device.DeviceContext for
// this device and pushes it onto the Python torch-function mode stack. Returns
// self (new reference) so that `with torch.device('cuda') as dev:` binds dev.
PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  py::object device_context_cls =
      py::module::import("torch.utils._device").attr("DeviceContext");
  py::object mode = device_context_cls(py::handle(self));
  // release() hands our reference on `mode` over to the SafePyObject.
  at::impl::PythonTorchFunctionTLS::push_onto_stack(
      std::make_shared<c10::SafePyObject>(
          mode.release().ptr(), getPyInterpreter()));
  Py_INCREF(self);
  return self;
  END_HANDLE_TH_ERRORS
}
// __exit__ for torch.device. Pops the top of the torch-function mode stack and
// returns None. The exception-info arguments are ignored.
// NOTE(review): this assumes the top of the stack is the DeviceContext pushed
// by the matching __enter__.
PyObject* THPDevice_exit(PyObject* self, PyObject* unused) {
  HANDLE_TH_ERRORS
  at::impl::PythonTorchFunctionTLS::pop_stack();
  Py_INCREF(Py_None);
  return Py_None;
  END_HANDLE_TH_ERRORS
}
// Would-be tp_call: forwards (self, *args, **kwargs) to
// torch.utils._device.device_decorator and returns its result.
// Currently not installed on THPDeviceType (its tp_call slot is nullptr).
PyObject* THPDevice_call(PyObject* self, PyObject* args, PyObject* kwargs) {
  HANDLE_TH_ERRORS
  py::object device_decorator =
      py::module::import("torch.utils._device").attr("device_decorator");
  py::object wrapped = device_decorator(
      py::handle(self), *py::handle(args), **py::handle(kwargs));
  return wrapped.release().ptr();
  END_HANDLE_TH_ERRORS
}
typedef PyObject* (*getter)(PyObject*, void*);
// NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static struct PyGetSetDef THPDevice_properties[] = {
{"type", (getter)THPDevice_type, nullptr, nullptr, nullptr},
{"index", (getter)THPDevice_index, nullptr, nullptr, nullptr},
{nullptr}};
// Python-visible methods of torch.device: pickling support (__reduce__) and
// the context-manager protocol (__enter__/__exit__).
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static PyMethodDef THPDevice_methods[] = {
    {"__reduce__", THPDevice_reduce, METH_NOARGS, nullptr},
    {"__enter__", THPDevice_enter, METH_NOARGS, nullptr},
    // METH_VARARGS: __exit__ is called with the exception-info arguments.
    {"__exit__", THPDevice_exit, METH_VARARGS, nullptr},
    {nullptr, nullptr, 0, nullptr} // sentinel
};
// Static type object for torch.device. The initializer is positional, so every
// entry must stay in PyTypeObject field order; the trailing /* ... */ comments
// name the slot each value fills.
// - Instances are created through tp_new (THPDevice_pynew); there is no
//   tp_init.
// - tp_dealloc and tp_alloc are nullptr here. NOTE(review): presumably
//   PyType_Ready inherits them from the base type, so no C++ destructor runs
//   for the embedded at::Device. Confirm that remains acceptable if THPDevice
//   gains non-trivial members.
// - tp_flags is Py_TPFLAGS_DEFAULT only (no Py_TPFLAGS_BASETYPE).
PyTypeObject THPDeviceType = {
PyVarObject_HEAD_INIT(nullptr, 0) "torch.device", /* tp_name */
sizeof(THPDevice), /* tp_basicsize */
0, /* tp_itemsize */
nullptr, /* tp_dealloc */
0, /* tp_vectorcall_offset */
nullptr, /* tp_getattr */
nullptr, /* tp_setattr */
nullptr, /* tp_reserved */
(reprfunc)THPDevice_repr, /* tp_repr */
nullptr, /* tp_as_number */
nullptr, /* tp_as_sequence */
nullptr, /* tp_as_mapping */
(hashfunc)THPDevice_hash, /* tp_hash */
// TODO: We're not sure if this is a good idea or not, because making
// torch.device callable means that it will start returning true
// for callable() queries, and that is unexpected. We can always add
// this later, so for now, don't actually implement this
// THPDevice_call, /* tp_call */
nullptr, /* tp_call */
(reprfunc)THPDevice_str, /* tp_str */
nullptr, /* tp_getattro */
nullptr, /* tp_setattro */
nullptr, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
nullptr, /* tp_doc */
nullptr, /* tp_traverse */
nullptr, /* tp_clear */
(richcmpfunc)THPDevice_rc, /* tp_richcompare */
0, /* tp_weaklistoffset */
nullptr, /* tp_iter */
nullptr, /* tp_iternext */
THPDevice_methods, /* tp_methods */
nullptr, /* tp_members */
THPDevice_properties, /* tp_getset */
nullptr, /* tp_base */
nullptr, /* tp_dict */
nullptr, /* tp_descr_get */
nullptr, /* tp_descr_set */
0, /* tp_dictoffset */
nullptr, /* tp_init */
nullptr, /* tp_alloc */
THPDevice_pynew, /* tp_new */
};
// Registers torch.device on `module`. It readies the type object, records
// `module` (borrowed) for torch function dispatch in THPDevice_pynew, and
// publishes the type as `module.device`. Throws python_error on failure, with
// the Python error indicator set.
void THPDevice_init(PyObject* module) {
  if (PyType_Ready(&THPDeviceType) < 0) {
    throw python_error();
  }
  // This reference is handed to PyModule_AddObject, which steals it only
  // when it succeeds.
  Py_INCREF(&THPDeviceType);
  THPUpperModuleOfDevice = module;
  if (PyModule_AddObject(module, "device", (PyObject*)&THPDeviceType) != 0) {
    // On failure the reference was not stolen; drop it so it does not leak.
    Py_DECREF(reinterpret_cast<PyObject*>(&THPDeviceType));
    throw python_error();
  }
}