Skip to content

Commit

Permalink
address review
Browse files Browse the repository at this point in the history
  • Loading branch information
Matt711 committed Oct 25, 2024
1 parent f6e8d4d commit ba909d6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
14 changes: 9 additions & 5 deletions python/pylibcudf/pylibcudf/interop.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -342,16 +342,18 @@ cpdef Table from_dlpack(object managed_tensor):
Table with a copy of the tensor data.
"""
if not PyCapsule_IsValid(managed_tensor, "dltensor"):
raise ValueError("Invalid capsule object")
raise ValueError("Invalid PyCapsule object")
cdef unique_ptr[table] c_result
cdef DLManagedTensor* dlpack_tensor = <DLManagedTensor*>PyCapsule_GetPointer(
managed_tensor, "dltensor"
)
if dlpack_tensor is NULL:
raise ValueError("PyCapsule object contained a NULL pointer")
PyCapsule_SetName(managed_tensor, "used_dltensor")

# Note: A copy is always performed when converting the dlpack
# data to a libcudf table. We also delete the dlpack_tensor pointer
# as the poionter is not deleted by libcudf's from_dlpack function.
# as the pointer is not deleted by libcudf's from_dlpack function.
# TODO: https://github.com/rapidsai/cudf/issues/10874
# TODO: https://github.com/rapidsai/cudf/issues/10849
with nogil:
Expand Down Expand Up @@ -400,6 +402,8 @@ cdef void dlmanaged_tensor_pycapsule_deleter(object pycap_obj) noexcept:
if PyCapsule_IsValid(pycap_obj, "used_dltensor"):
# we do not call a used capsule's deleter
return
cdef DLManagedTensor* dlpack_tensor
dlpack_tensor = <DLManagedTensor*>PyCapsule_GetPointer(pycap_obj, "dltensor")
dlpack_tensor.deleter(dlpack_tensor)
cdef DLManagedTensor* dlpack_tensor = <DLManagedTensor*>PyCapsule_GetPointer(
pycap_obj, "dltensor"
)
if dlpack_tensor is not NULL:
dlpack_tensor.deleter(dlpack_tensor)
2 changes: 1 addition & 1 deletion python/pylibcudf/pylibcudf/tests/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,5 @@ def test_to_dlpack_error():


def test_from_dlpack_error():
with pytest.raises(ValueError, match="Invalid capsule object"):
with pytest.raises(ValueError, match="Invalid PyCapsule object"):
plc.interop.from_dlpack(1)

0 comments on commit ba909d6

Please sign in to comment.