Commit

Add a C++ merge of numpy partitions

sjperkins committed Oct 3, 2024
1 parent 3c57c60 commit 9437b8e
Showing 5 changed files with 124 additions and 0 deletions.
1 change: 1 addition & 0 deletions HISTORY.rst
@@ -4,6 +4,7 @@ History

X.Y.Z (YYYY-MM-DD)
------------------
* Add partition sorting and merging utilities (:pr:`127`)
* Deprecate Python 3.9 support (:pr:`125`)
* Upgrade to casacore 3.6.1 (:pr:`124`)
* Build against NumPy 2 (:pr:`122`)
1 change: 1 addition & 0 deletions src/arcae/lib/CMakeLists.txt
@@ -22,6 +22,7 @@ target_link_directories(arrow_tables PUBLIC ${PYARROW_LIBDIRS})
target_include_directories(arrow_tables PUBLIC
PkgConfig::casacore
absl::span
${CMAKE_CURRENT_SOURCE_DIR}
${PYARROW_INCLUDE}
${NUMPY_INCLUDE}
${CMAKE_SOURCE_DIR}/cpp)
8 changes: 8 additions & 0 deletions src/arcae/lib/arrow_tables.pxd
@@ -9,6 +9,8 @@ from libcpp.memory import shared_ptr
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *

cimport numpy as cnp

cdef extern from "<climits>" nogil:
cdef unsigned int UINT_MAX

@@ -51,6 +53,7 @@ cdef extern from "arcae/descriptor.h" namespace "arcae" nogil:
cdef CResult[string] CMSDescriptor" arcae::MSDescriptor"(
const string & table, bool complete)


cdef extern from "arcae/new_table_proxy.h" namespace "arcae" nogil:
cdef cppclass CCasaTable" arcae::NewTableProxy":
@staticmethod
@@ -96,3 +99,8 @@ cdef extern from "arcae/table_factory.h" namespace "arcae" nogil:
cdef CResult[shared_ptr[CCasaTable]] CTaql" arcae::Taql"(
const string & taql,
const vector[shared_ptr[CCasaTable]] & tables)

cdef extern from "merge_sort.cc" namespace "arcae" nogil:
int PartitionMerge(
const vector[vector[cnp.PyArrayObject *]] & array_partitions,
vector[cnp.PyArrayObject*] * merged_arrays) except *
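The extern block above only declares the C++ entry point; the merge itself is implemented in merge_sort.cc, which is not part of this diff. As a rough illustration of the idea — an assumption about the approach, not the actual C++ code — a k-way merge of partitions whose columns are already lexicographically sorted can be sketched in pure Python with heapq:

import heapq
import numpy as np

def kway_merge_sketch(partitions):
    # Each partition: a dict of equal-length 1-D arrays, already sorted by
    # the tuple of its values in key order. Names here are illustrative.
    keys = list(partitions[0].keys())
    heap = []
    for p, part in enumerate(partitions):
        if len(part[keys[0]]):
            heapq.heappush(heap, (tuple(part[k][0] for k in keys), p, 0))
    out = {k: [] for k in keys}
    while heap:
        row, p, i = heapq.heappop(heap)
        for k, v in zip(keys, row):
            out[k].append(v)
        if i + 1 < len(partitions[p][keys[0]]):
            nxt = tuple(partitions[p][k][i + 1] for k in keys)
            heapq.heappush(heap, (nxt, p, i + 1))
    return {k: np.asarray(v) for k, v in out.items()}

The declaration above uses except * so that errors raised during the merge propagate as Python exceptions.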
27 changes: 27 additions & 0 deletions src/arcae/lib/arrow_tables.pyx
@@ -1,6 +1,7 @@
# distutils: language = c++
# cython: language_level = 3


from collections.abc import MutableMapping, Sequence
import cython
import json
@@ -15,6 +16,8 @@ from libcpp.vector cimport vector
import numpy as np
import pyarrow as pa

cimport numpy as cnp

from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *

@@ -36,13 +39,17 @@ from arcae.lib.arrow_tables cimport (
CSelection,
CSelectionBuilder,
CTaql,
PartitionMerge,
IndexType)


DimIndex = Union[slice, list, np.ndarray]
FullIndex = Union[list[DimIndex], tuple[DimIndex]]


cnp.import_array()


def ms_descriptor(table: str, complete: bool = False) -> Dict:
cdef string ctable = tobytes(table)

@@ -431,3 +438,23 @@ class Configuration(MutableMapping):
config: cython.pointer(CConfiguration) = &CServiceLocator.configuration()

return config.Size()


def merge_np_partitions(
partitions: List[Dict[str, np.ndarray]]
) -> Dict[str, np.ndarray]:
cdef vector[vector[cnp.PyArrayObject*]] partition_arrays
cdef vector[cnp.PyArrayObject*] arrays

if len(partitions) == 0:
return {}

for partition in partitions:
for _, array in partition.items():
arrays.push_back(<cnp.PyArrayObject*> array)
partition_arrays.push_back(move(arrays))

# Rely on the C++ implementation to drop the GIL
PartitionMerge(partition_arrays, &arrays)
values = [<cnp.ndarray> array for array in arrays]
return dict(zip(partitions[0].keys(), values))
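merge_np_partitions expects each partition to be a dict of equal-length, 1-D arrays already sorted lexicographically by the dict's key order, and returns a single dict sorted the same way; note that the same arrays vector is reused as the output argument of PartitionMerge. A minimal NumPy reference for the expected result, assuming that contract (see the tests below), would be:

import numpy as np

def merge_np_partitions_reference(partitions):
    # Assumed semantics: concatenate each column across partitions, then
    # lexsort with the first dict key as the primary key. np.lexsort treats
    # its *last* key as primary, hence the reversed().
    keys = list(partitions[0].keys())
    merged = {k: np.concatenate([p[k] for p in partitions]) for k in keys}
    order = np.lexsort(tuple(reversed(merged.values())))
    return {k: v[order] for k, v in merged.items()}

The C++ path should be cheaper than this full re-sort, since the per-partition inputs are already ordered and only need to be merged.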
87 changes: 87 additions & 0 deletions src/arcae/tests/test_partition_sort.py
@@ -0,0 +1,87 @@
import numpy as np
import pytest
from numpy.testing import assert_equal

from arcae.lib.arrow_tables import merge_np_partitions


@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("n, chunk", [(100, 33), (125, 42)])
def test_merge_np_partitions(seed, n, chunk):
rng = np.random.default_rng(seed=seed)
ddid = rng.integers(0, 10, n)
field_id = rng.integers(0, 10, n)
time = rng.random(n)
interval = rng.random(n)
ant1 = rng.integers(0, 10, n)
ant2 = rng.integers(0, 10, n)
row = rng.integers(0, 10, n)

partitions = [
{
"DATA_DESC_ID": ddid[start : start + chunk],
"FIELD_ID": field_id[start : start + chunk],
"TIME": time[start : start + chunk],
"ANTENNA1": ant1[start : start + chunk],
"ANTENNA2": ant2[start : start + chunk],
"INTERVAL": interval[start : start + chunk],
"ROW": row[start : start + chunk],
}
for start in range(0, n, chunk)
]

sorts = [np.lexsort(tuple(reversed(p.values()))) for p in partitions]
partitions = [{k: v[s] for k, v in p.items()} for p, s in zip(partitions, sorts)]

expected = {
"DATA_DESC_ID": ddid,
"FIELD_ID": field_id,
"TIME": time,
"ANTENNA1": ant1,
"ANTENNA2": ant2,
"INTERVAL": interval,
"ROW": row,
}

sort = np.lexsort(tuple(reversed(expected.values())))
expected = {k: v[sort] for k, v in expected.items()}
merged = merge_np_partitions(partitions)
assert_equal(merged, expected)


@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("n, chunk", [(100, 33)])
def test_merge_fail_1d(seed, n, chunk):
rng = np.random.default_rng(seed=seed)
ddid = rng.integers(0, 10, 100)
field_id = rng.integers(0, 10, (100, 4))

partitions = [
{
"DATA_DESC_ID": ddid[start : start + chunk],
"FIELD_ID": field_id[start : start + chunk],
}
for start in range(0, n, chunk)
]

with pytest.raises(ValueError, match="Array must be 1-dimensional"):
merge_np_partitions(partitions)


@pytest.mark.parametrize("seed", [42])
@pytest.mark.parametrize("n, chunk", [(100, 33)])
def test_merge_fail_array_length(seed, n, chunk):
rng = np.random.default_rng(seed=seed)
ddid = rng.integers(0, 10, 100)
field_id = rng.integers(0, 10, 50)

partitions = [
{
"DATA_DESC_ID": ddid[start : start + chunk],
"FIELD_ID": field_id[start : start + chunk],
}
for start in range(0, n, chunk)
]

with pytest.raises(ValueError, match="Array lengths do not match"):
merge_np_partitions(partitions)
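A minimal usage of the new function with two pre-sorted partitions (values purely illustrative) looks like this:

import numpy as np
from arcae.lib.arrow_tables import merge_np_partitions

p0 = {"TIME": np.array([0.0, 2.0]), "ROW": np.array([0, 1])}
p1 = {"TIME": np.array([1.0, 3.0]), "ROW": np.array([2, 3])}

merged = merge_np_partitions([p0, p1])
# merged["TIME"] == [0.0, 1.0, 2.0, 3.0]
# merged["ROW"]  == [0, 2, 1, 3]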
