add cuda.mem_get_info (#82) (#10398)
* register mem_get_info

* add unittest

* refine unittest

* Update env.cpp

* format

* format

* format

* format

* rename GetMemInfo to CudaMemGetInfo
marigoold authored Jan 5, 2024
1 parent 891c710 commit af0807a
Showing 3 changed files with 31 additions and 2 deletions.
oneflow/api/python/env/env.cpp (7 additions, 0 deletions)
@@ -96,6 +96,13 @@ ONEFLOW_API_PYBIND11_MODULE("", m) {
   m.def("CudaSynchronize", &CudaSynchronize);
   m.def("GetCUDAMemoryUsed", &GetCUDAMemoryUsed);
   m.def("GetCPUMemoryUsed", &GetCPUMemoryUsed);
+  m.def("CudaMemGetInfo", [](int device) -> std::pair<size_t, size_t> {
+    CudaCurrentDeviceGuard guard(device);
+    size_t device_free = 0;
+    size_t device_total = 0;
+    OF_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
+    return {device_free, device_total};
+  });
   m.def(
       "_get_device_properties",
       [](int device) -> cudaDeviceProp* { return GetDeviceProperties(device); },
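
The binding guards the call with CudaCurrentDeviceGuard so cudaMemGetInfo runs on the requested device, and it returns the (free, total) byte counts as a Python tuple. For reference, a minimal sketch of exercising the raw binding directly (assuming a CUDA-enabled OneFlow build with at least one visible device):

    import oneflow as flow

    # Query device 0 through the pybind11 binding added above; it returns
    # (free_bytes, total_bytes) exactly as reported by cudaMemGetInfo.
    free_bytes, total_bytes = flow._oneflow_internal.CudaMemGetInfo(0)
    print(f"free: {free_bytes / 2**20:.1f} MiB of {total_bytes / 2**20:.1f} MiB total")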
python/oneflow/cuda/__init__.py (18 additions, 0 deletions)
@@ -195,6 +195,24 @@ def empty_cache() -> None:
     return flow._oneflow_internal.EmptyCache()


+def mem_get_info(device: Any = None) -> Tuple[int, int]:
+    r"""Returns the global free and total GPU memory for a given
+    device using cudaMemGetInfo.
+
+    The documentation is referenced from:
+    https://pytorch.org/docs/stable/generated/torch.cuda.mem_get_info.html
+
+    Args:
+        device (flow.device or int, optional): selected device. Returns
+            statistic for the current device, given by :func:`~flow.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+    """
+    if device is None:
+        device = current_device()
+    device = _get_device_index(device)
+    return flow._oneflow_internal.CudaMemGetInfo(device)
+
+
 from .random import * # noqa: F403
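
The Python wrapper mirrors torch.cuda.mem_get_info: it resolves device (None, an integer index, or a flow.device) to an index via _get_device_index and forwards it to the binding above. A short usage sketch, assuming a CUDA-enabled build:

    import oneflow as flow

    if flow.cuda.is_available():
        # No argument: report the device returned by flow.cuda.current_device().
        free, total = flow.cuda.mem_get_info()
        # An explicit index (or a flow.device such as flow.device("cuda:0")) also works.
        free0, total0 = flow.cuda.mem_get_info(0)
        print(f"cuda:0 -> {free0 / 2**30:.2f} GiB free of {total0 / 2**30:.2f} GiB")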
python/oneflow/test/misc/test_env_cuda.py (6 additions, 2 deletions)
@@ -16,9 +16,9 @@
 import os
 import unittest
 import oneflow as flow
-from oneflow.test_utils.automated_test_util.generators import nothing, oneof
+from oneflow.test_utils.automated_test_util.generators import nothing, oneof, random
+from oneflow.test_utils.automated_test_util import torch
 import oneflow.unittest
-import torch


 @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
@@ -55,6 +55,10 @@ def test_cuda_get_device_name(test_case):
     def test_cuda_get_device_capability(test_case):
         return torch.cuda.get_device_capability(oneof(0, nothing()))

+    def test_cuda_mem_get_info(test_case):
+        device_idx = random(0, flow.cuda.device_count()).to(int).value()
+        return torch.cuda.mem_get_info(device_idx)
+

 if __name__ == "__main__":
     unittest.main()
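
The new test follows the file's automated_test_util pattern, where the wrapped torch module runs the same call against both OneFlow and PyTorch so the results can be compared. A plain-unittest equivalent is sketched below; it assumes PyTorch is installed, at least one CUDA device is visible, and uses an arbitrary tolerance for the free-memory comparison:

    import unittest

    import oneflow as flow
    import torch


    class TestMemGetInfo(unittest.TestCase):
        @unittest.skipUnless(flow.cuda.is_available(), "requires a CUDA device")
        def test_matches_torch(self):
            of_free, of_total = flow.cuda.mem_get_info(0)
            pt_free, pt_total = torch.cuda.mem_get_info(0)
            # Total memory is a fixed device property, so it must match exactly;
            # free memory can change between the two calls, so allow some slack.
            self.assertEqual(of_total, pt_total)
            self.assertLess(abs(of_free - pt_free), 512 * 1024 * 1024)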
