Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Arrow PyCapsule Interface #5070

Merged
merged 28 commits into from
Nov 15, 2023
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b1c43a4
arrow ffi array copy
kylebarron Nov 13, 2023
c61d270
remove copy_ffi_array
kylebarron Nov 13, 2023
6521cba
docstring
kylebarron Nov 13, 2023
92b070a
wip: pycapsule support
kylebarron Nov 13, 2023
6460701
return
kylebarron Nov 13, 2023
dfdcfae
Update arrow/src/pyarrow.rs
kylebarron Nov 13, 2023
5a4f738
remove sync impl
kylebarron Nov 13, 2023
8a1a05e
Update arrow/src/pyarrow.rs
kylebarron Nov 13, 2023
e109c1a
Remove copy()
kylebarron Nov 13, 2023
3b95ebc
Merge branch 'kyle/arrow-ffi-copy' of github.com:kylebarron/arrow-rs …
kylebarron Nov 13, 2023
05ea67d
Need &mut FFI_ArrowArray for std::mem::replace
kylebarron Nov 13, 2023
e7ed58d
Use std::ptr::replace
kylebarron Nov 13, 2023
dc04b13
update comments
kylebarron Nov 13, 2023
86918fa
Minimize unsafe block
kylebarron Nov 13, 2023
0e273a3
revert pub release functions
kylebarron Nov 13, 2023
252e746
Add RecordBatch and Stream conversion
kylebarron Nov 14, 2023
60bee4a
fix returns
kylebarron Nov 14, 2023
46612ce
Fix return type
kylebarron Nov 14, 2023
becda12
Fix name
kylebarron Nov 14, 2023
2f7767b
fix ci
kylebarron Nov 14, 2023
1e7bcd3
Add tests
kylebarron Nov 15, 2023
ae909fb
Add table test
kylebarron Nov 15, 2023
6f01c91
skip if pre pyarrow 14
kylebarron Nov 15, 2023
f183057
bump python version in CI to use pyarrow 14
kylebarron Nov 15, 2023
107acef
Add record batch test
kylebarron Nov 15, 2023
6c44e01
Update arrow/src/pyarrow.rs
kylebarron Nov 15, 2023
6020247
run on pyarrow 13 and 14
kylebarron Nov 15, 2023
2e42926
Update .github/workflows/integration.yml
kylebarron Nov 15, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ jobs:
key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}-
- uses: actions/setup-python@v4
with:
python-version: '3.7'
python-version: '3.8'
- name: Upgrade pip and setuptools
run: pip install --upgrade pip setuptools wheel virtualenv
- name: Create virtualenv and install dependencies
Expand Down
2 changes: 2 additions & 0 deletions arrow-pyarrow-integration-testing/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Note that this crate uses two languages and an external ABI:
* `Rust`
* `Python`
* C ABI privately exposed by `Pyarrow`.
* PyCapsule ABI publicly exposed by `pyarrow`

## Basic idea

Expand All @@ -36,6 +37,7 @@ we can use pyarrow's interface to move pointers from and to Rust.
## Relevant literature

* [Arrow's CDataInterface](https://arrow.apache.org/docs/format/CDataInterface.html)
* [Arrow PyCapsule Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html)
* [Rust's FFI](https://doc.rust-lang.org/nomicon/ffi.html)
* [Pyarrow private binds](https://github.com/apache/arrow/blob/ae1d24efcc3f1ac2a876d8d9f544a34eb04ae874/python/pyarrow/array.pxi#L1226)
* [PyO3](https://docs.rs/pyo3/0.12.1/pyo3/index.html)
Expand Down
138 changes: 134 additions & 4 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

import arrow_pyarrow_integration_testing as rust

PYARROW_PRE_14 = int(pa.__version__.split('.')[0]) < 14


@contextlib.contextmanager
def no_pyarrow_leak():
Expand Down Expand Up @@ -113,13 +115,49 @@ def assert_pyarrow_leak():
_unsupported_pyarrow_types = [
]

# As of pyarrow 14, pyarrow implements the Arrow PyCapsule interface
# (https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
# This defines that Arrow consumers should allow any object that has specific "dunder"
# methods, `__arrow_c_*_`. These wrapper classes ensure that arrow-rs is able to handle
# _any_ class, without pyarrow-specific handling.
class SchemaWrapper:
def __init__(self, schema):
self.schema = schema

def __arrow_c_schema__(self):
return self.schema.__arrow_c_schema__()


class ArrayWrapper:
def __init__(self, array):
self.array = array

def __arrow_c_array__(self):
return self.array.__arrow_c_array__()


class StreamWrapper:
def __init__(self, stream):
self.stream = stream

def __arrow_c_stream__(self):
return self.stream.__arrow_c_stream__()


@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
def test_type_roundtrip(pyarrow_type):
restored = rust.round_trip_type(pyarrow_type)
assert restored == pyarrow_type
assert restored is not pyarrow_type

@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
def test_type_roundtrip_pycapsule(pyarrow_type):
wrapped = SchemaWrapper(pyarrow_type)
restored = rust.round_trip_type(wrapped)
assert restored == pyarrow_type
assert restored is not pyarrow_type


@pytest.mark.parametrize("pyarrow_type", _unsupported_pyarrow_types, ids=str)
def test_type_roundtrip_raises(pyarrow_type):
Expand All @@ -138,6 +176,20 @@ def test_field_roundtrip(pyarrow_type):
field = rust.round_trip_field(pyarrow_field)
assert field == pyarrow_field

@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
@pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str)
def test_field_roundtrip_pycapsule(pyarrow_type):
pyarrow_field = pa.field("test", pyarrow_type, nullable=True)
wrapped = SchemaWrapper(pyarrow_field)
field = rust.round_trip_field(wrapped)
assert field == wrapped.schema

if pyarrow_type != pa.null():
# A null type field may not be non-nullable
pyarrow_field = pa.field("test", pyarrow_type, nullable=False)
field = rust.round_trip_field(wrapped)
assert field == wrapped.schema

def test_field_metadata_roundtrip():
metadata = {"hello": "World! 😊", "x": "2"}
pyarrow_field = pa.field("test", pa.int32(), metadata=metadata)
Expand All @@ -163,6 +215,17 @@ def test_primitive_python():
del b


@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
def test_primitive_python_pycapsule():
"""
Python -> Rust -> Python
"""
a = pa.array([1, 2, 3])
wrapped = ArrayWrapper(a)
b = rust.double(wrapped)
assert b == pa.array([2, 4, 6])


def test_primitive_rust():
"""
Rust -> Python -> Rust
Expand Down Expand Up @@ -433,6 +496,33 @@ def test_record_batch_reader():
got_batches = list(b)
assert got_batches == batches

@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
def test_record_batch_reader_pycapsule():
"""
Python -> Rust -> Python
"""
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
]
a = pa.RecordBatchReader.from_batches(schema, batches)
wrapped = StreamWrapper(a)
b = rust.round_trip_record_batch_reader(wrapped)

assert b.schema == schema
got_batches = list(b)
assert got_batches == batches

# Also try the boxed reader variant
a = pa.RecordBatchReader.from_batches(schema, batches)
wrapped = StreamWrapper(a)
b = rust.boxed_reader_roundtrip(wrapped)
assert b.schema == schema
got_batches = list(b)
assert got_batches == batches


def test_record_batch_reader_error():
schema = pa.schema([('ints', pa.list_(pa.int32()))])

Expand All @@ -453,24 +543,64 @@ def iter_batches():
with pytest.raises(ValueError, match="invalid utf-8"):
rust.round_trip_record_batch_reader(reader)


@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
def test_record_batch_pycapsule():
"""
Python -> Rust -> Python
"""
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
batch = pa.record_batch([[[1], [2, 42]]], schema)
wrapped = StreamWrapper(batch)
b = rust.round_trip_record_batch_reader(wrapped)
new_table = b.read_all()
new_batches = new_table.to_batches()

assert len(new_batches) == 1
new_batch = new_batches[0]

assert batch == new_batch
assert batch.schema == new_batch.schema


@pytest.mark.skipif(PYARROW_PRE_14, reason="requires pyarrow 14")
def test_table_pycapsule():
"""
Python -> Rust -> Python
"""
schema = pa.schema([('ints', pa.list_(pa.int32()))], metadata={b'key1': b'value1'})
batches = [
pa.record_batch([[[1], [2, 42]]], schema),
pa.record_batch([[None, [], [5, 6]]], schema),
]
table = pa.Table.from_batches(batches)
wrapped = StreamWrapper(table)
b = rust.round_trip_record_batch_reader(wrapped)
new_table = b.read_all()

assert table.schema == new_table.schema
assert table == new_table
assert len(table.to_batches()) == len(new_table.to_batches())


def test_reject_other_classes():
# Arbitrary type that is not a PyArrow type
not_pyarrow = ["hello"]

with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Array, got builtins.list"):
rust.round_trip_array(not_pyarrow)

with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Schema, got builtins.list"):
rust.round_trip_schema(not_pyarrow)

with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Field, got builtins.list"):
rust.round_trip_field(not_pyarrow)

with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.DataType, got builtins.list"):
rust.round_trip_type(not_pyarrow)

with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.RecordBatch, got builtins.list"):
rust.round_trip_record_batch(not_pyarrow)

with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.RecordBatchReader, got builtins.list"):
rust.round_trip_record_batch_reader(not_pyarrow)
2 changes: 2 additions & 0 deletions arrow-schema/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ impl Drop for FFI_ArrowSchema {
}
}

unsafe impl Send for FFI_ArrowSchema {}

impl TryFrom<&FFI_ArrowSchema> for DataType {
type Error = ArrowError;

Expand Down
Loading
Loading