Skip to content

Commit

Permalink
Distinguish between DML and the generic 'GPU' term. This is needed fo…
Browse files Browse the repository at this point in the history
…r packaging DML EP in the same ORT GPU pkg.
  • Loading branch information
pranavsharma committed Oct 25, 2024
1 parent c5b6be0 commit 7690e26
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
1 change: 1 addition & 0 deletions include/onnxruntime/core/framework/ortdevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ struct OrtDevice {
static const DeviceType GPU = 1; // Nvidia or AMD
static const DeviceType FPGA = 2;
static const DeviceType NPU = 3; // Ascend
static const DeviceType DML = 4;

struct MemType {
// Pre-defined memory types.
Expand Down
8 changes: 3 additions & 5 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,9 @@ const char* GetDeviceName(const OrtDevice& device) {
case OrtDevice::CPU:
return CPU;
case OrtDevice::GPU:
#ifdef USE_DML
return DML;
#else
return CUDA;
#endif
case OrtDevice::DML:
return DML;
case OrtDevice::FPGA:
return "FPGA";
case OrtDevice::NPU:
Expand Down Expand Up @@ -1579,7 +1577,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra
.def_static("cann", []() { return OrtDevice::NPU; })
.def_static("fpga", []() { return OrtDevice::FPGA; })
.def_static("npu", []() { return OrtDevice::NPU; })
.def_static("dml", []() { return OrtDevice::GPU; })
.def_static("dml", []() { return OrtDevice::DML; })
.def_static("default_memory", []() { return OrtDevice::MemType::DEFAULT; });

py::class_<OrtArenaCfg> ort_arena_cfg_binding(m, "OrtArenaCfg");
Expand Down

0 comments on commit 7690e26

Please sign in to comment.