Skip to content

Commit

Permalink
Expose group sorting mechanism in Cython
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Sep 30, 2024
1 parent 09fa980 commit 9ffbcd7
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 35 deletions.
80 changes: 56 additions & 24 deletions cpp/arcae/group_sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <memory>
#include <numeric>
#include <queue>
#include <string>
#include <vector>

#include "arcae/type_traits.h"
Expand All @@ -19,10 +20,12 @@ using ::arrow::AllocateBuffer;
using ::arrow::Array;
using ::arrow::Buffer;
using ::arrow::DoubleArray;
using ::arrow::Field;
using ::arrow::Int32Array;
using ::arrow::Int64Array;
using ::arrow::Result;
using ::arrow::Status;
using ::arrow::Table;

using ::arcae::detail::AggregateAdapter;

Expand Down Expand Up @@ -142,32 +145,62 @@ Result<std::shared_ptr<GroupSortData>> GroupSortData::Sort() const {
std::make_shared<Int64Array>(nrow, std::move(rows_buffer)));
}

std::shared_ptr<Table> GroupSortData::ToTable() const {
std::vector<std::shared_ptr<Array>> arrays;
std::vector<std::shared_ptr<Field>> fields;

// Groups + TIME, ANTENNA1. ANTENNA2, ROW
auto narrays = groups_.size() + 4;
arrays.reserve(narrays);
fields.reserve(narrays);

for (std::size_t g = 0; g < groups_.size(); ++g) {
fields.push_back(field("GROUP_" + std::to_string(g), arrow::int32()));
arrays.push_back(groups_[g]);
}
fields.push_back(field("TIME", arrow::float64()));
fields.push_back(field("ANTENNA1", arrow::int32()));
fields.push_back(field("ANTENNA2", arrow::int32()));
fields.push_back(field("ROW", arrow::int64()));

arrays.push_back(time_);
arrays.push_back(ant1_);
arrays.push_back(ant2_);
arrays.push_back(rows_);

return Table::Make(schema(std::move(fields)), std::move(arrays));
}

Result<std::shared_ptr<GroupSortData>> MergeGroups(
const std::vector<std::shared_ptr<GroupSortData>>& group_data) {
if (group_data.empty())
return std::make_shared<AggregateAdapter<GroupSortData>>(
GroupSortData::GroupsType{}, nullptr, nullptr, nullptr, nullptr);

struct MergeData {
std::size_t gd;
GroupSortData* group;
GroupSortData* group_;
std::int64_t r;

double time(std::int64_t r) const { return group->time()[r]; }
std::int32_t ant1(std::int64_t r) const { return group->ant1()[r]; }
std::int32_t ant2(std::int64_t r) const { return group->ant2()[r]; }

bool operator<(const MergeData& rhs) const {
// To obtain a descending sort, we reverse the comparison
for (std::size_t g = 0; g < group->nGroups(); ++g) {
auto lhs_group = group->group(g)[r];
auto rhs_group = rhs.group->group(g)[rhs.r];
if (lhs_group != rhs_group) return lhs_group > rhs_group;
inline std::int32_t group(std::size_t g, std::int64_t r) const {
return group_->group(g, r);
}
inline double time(std::int64_t r) const { return group_->time(r); }
inline std::int32_t ant1(std::int64_t r) const { return group_->ant1(r); }
inline std::int32_t ant2(std::int64_t r) const { return group_->ant2(r); }

bool compare(const MergeData& rhs) const {
for (std::size_t g = 0; g < group_->nGroups(); ++g) {
auto lhs_group = group(g, r);
auto rhs_group = rhs.group(g, rhs.r);
if (lhs_group != rhs_group) return lhs_group < rhs_group;
}
if (time(r) != rhs.time(rhs.r)) return time(r) > rhs.time(rhs.r);
if (ant1(r) != rhs.ant1(rhs.r)) return ant1(r) > rhs.ant1(rhs.r);
return ant2(r) > rhs.ant2(rhs.r);
if (time(r) != rhs.time(rhs.r)) return time(r) < rhs.time(rhs.r);
if (ant1(r) != rhs.ant1(rhs.r)) return ant1(r) < rhs.ant1(rhs.r);
return ant2(r) < rhs.ant2(rhs.r);
}

// To obtain a descending sort, we reverse the comparison
inline bool operator<(const MergeData& rhs) const { return !compare(rhs); }
};

std::int64_t nrows = 0;
Expand Down Expand Up @@ -204,27 +237,26 @@ Result<std::shared_ptr<GroupSortData>> MergeGroups(

for (std::size_t gd = 0; gd < group_data.size(); ++gd) {
if (group_data[gd]->nRows() > 0) {
queue.emplace(MergeData{gd, group_data[gd].get(), 0});
queue.emplace(MergeData{group_data[gd].get(), 0});
}
}

while (!queue.empty()) {
auto [gd, dummy, gr] = queue.top();
const auto& top_group = group_data[gd];
auto [top_group, gr] = queue.top();
queue.pop();

for (std::size_t g = 0; g < ngroups; ++g) {
group_spans[g][row] = top_group->group(g)[gr];
group_spans[g][row] = top_group->group(g, gr);
}

time_span[row] = top_group->time()[gr];
ant1_span[row] = top_group->ant1()[gr];
ant2_span[row] = top_group->ant2()[gr];
rows_span[row] = top_group->rows()[gr];
time_span[row] = top_group->time(gr);
ant1_span[row] = top_group->ant1(gr);
ant2_span[row] = top_group->ant2(gr);
rows_span[row] = top_group->rows(gr);
++row;

if (gr + 1 < top_group->nRows()) {
queue.emplace(MergeData{gd, top_group.get(), gr + 1});
queue.emplace(MergeData{top_group, gr + 1});
}
}

Expand Down
15 changes: 10 additions & 5 deletions cpp/arcae/group_sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ struct GroupSortData {
std::shared_ptr<arrow::Int32Array> ant2_;
std::shared_ptr<arrow::Int64Array> rows_;

const std::int32_t* group(int g) const { return groups_[g]->raw_values(); }
const double* time() const { return time_->raw_values(); }
const std::int32_t* ant1() const { return ant1_->raw_values(); }
const std::int32_t* ant2() const { return ant2_->raw_values(); }
const std::int64_t* rows() const { return rows_->raw_values(); }
inline std::int32_t group(std::size_t group, std::size_t row) const {
return groups_[group]->raw_values()[row];
}
inline double time(std::size_t row) const { return time_->raw_values()[row]; }
inline std::int32_t ant1(std::size_t row) const { return ant1_->raw_values()[row]; }
inline std::int32_t ant2(std::size_t row) const { return ant2_->raw_values()[row]; }
inline std::int64_t rows(std::size_t row) const { return rows_->raw_values()[row]; }

// Create the GroupSortData from grouping and sorting arrays
static arrow::Result<std::shared_ptr<GroupSortData>> Make(
Expand All @@ -33,6 +35,9 @@ struct GroupSortData {
const std::shared_ptr<arrow::Array>& ant2,
const std::shared_ptr<arrow::Array>& rows);

// Convert to an Arrow Table
std::shared_ptr<arrow::Table> ToTable() const;

// Number of group columns
std::size_t nGroups() const { return groups_.size(); }

Expand Down
7 changes: 1 addition & 6 deletions cpp/tests/group_sort_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,8 @@ TEST(GroupSortTest, TestSort) {
ArrayFromJSON(arrow::int32(), "[1, 1, 1, 2, 2, 2, 3, 3, 3]"));
ASSERT_OK_AND_ASSIGN(auto group2,
ArrayFromJSON(arrow::int32(), "[1, 1, 1, 2, 2, 2, 3, 3, 3]"));
ASSERT_OK_AND_ASSIGN(auto group3,
ArrayFromJSON(arrow::int32(), "[1, 1, 1, 2, 2, 2, 3, 3, 3]"));

groups.push_back(group1);
groups.push_back(group2);
groups.push_back(group3);

ASSERT_OK_AND_ASSIGN(auto time,
ArrayFromJSON(arrow::float64(), "[1, 2, 3, 4, 5, 6, 7, 8, 9]"));
ASSERT_OK_AND_ASSIGN(auto ant1,
Expand All @@ -42,7 +37,7 @@ TEST(GroupSortTest, TestSort) {

ASSERT_OK_AND_ASSIGN(auto base, GroupSortData::Make(groups, time, ant1, ant2, rows));
ASSERT_OK_AND_ASSIGN(auto sorted, base->Sort());
ASSERT_OK_AND_ASSIGN(auto merged, MergeGroups({sorted, sorted}));
ASSERT_OK_AND_ASSIGN(auto merged, MergeGroups({sorted, sorted, sorted, sorted}));
}

} // namespace
19 changes: 19 additions & 0 deletions src/arcae/lib/arrow_tables.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,25 @@ cdef extern from "arcae/descriptor.h" namespace "arcae" nogil:
cdef CResult[string] CMSDescriptor" arcae::MSDescriptor"(
const string & table, bool complete)

cdef extern from "arcae/group_sort.h" namespace "arcae" nogil:
cdef cppclass CGroupSortData" arcae::GroupSortData":
@staticmethod
CResult[shared_ptr[CGroupSortData]] Make" GroupSortData::Make"(
const vector[shared_ptr[CArray]] & columns,
const shared_ptr[CArray] & time,
const shared_ptr[CArray] & ant1,
const shared_ptr[CArray] & ant2,
const shared_ptr[CArray] & rows)

size_t nGroups" GroupSortData::nGroups"()
size_t nRows" GroupSortData::nRows"()
CResult[shared_ptr[CGroupSortData]] Sort" GroupSortData::Sort"()
shared_ptr[CTable] ToTable" GroupSortData::ToTable"()

CResult[shared_ptr[CGroupSortData]] MergeGroups" MergeGroups"(
const vector[shared_ptr[CGroupSortData]] & group_data)


cdef extern from "arcae/new_table_proxy.h" namespace "arcae" nogil:
cdef cppclass CCasaTable" arcae::NewTableProxy":
@staticmethod
Expand Down
64 changes: 64 additions & 0 deletions src/arcae/lib/arrow_tables.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ from pyarrow.lib import (tobytes, frombytes)
from arcae.lib.arrow_tables cimport (
CCasaTable,
CConfiguration,
CGroupSortData,
CMSDescriptor,
CServiceLocator,
COpenTable,
CDefaultMS,
CSelection,
CSelectionBuilder,
CTaql,
MergeGroups,
IndexType)


Expand Down Expand Up @@ -431,3 +433,65 @@ class Configuration(MutableMapping):
config: cython.pointer(CConfiguration) = &CServiceLocator.configuration()

return config.Size()


cdef class GroupSortData:
cdef shared_ptr[CGroupSortData] c_data

def __init__(
self,
groups: Sequence[pa.array],
time: pa.array,
ant1: pa.array,
ant2: pa.array,
rows: pa.array
):
cdef:
vector[shared_ptr[CArray]] c_groups
shared_ptr[CArray] c_time = pyarrow_unwrap_array(time)
shared_ptr[CArray] c_ant1 = pyarrow_unwrap_array(ant1)
shared_ptr[CArray] c_ant2 = pyarrow_unwrap_array(ant2)
shared_ptr[CArray] c_rows = pyarrow_unwrap_array(rows)

for g in groups:
c_groups.push_back(pyarrow_unwrap_array(g))

with nogil:
self.c_data = GetResultValue(
CGroupSortData.Make(c_groups, c_time, c_ant1, c_ant2, c_rows)
)

def sort(self) -> GroupSortData:
cdef shared_ptr[CGroupSortData] c_gsd
cdef GroupSortData gsd

with nogil:
c_gsd = GetResultValue(self.c_data.get().Sort())

gsd = GroupSortData.__new__(GroupSortData)
gsd.c_data = c_gsd
return gsd

def to_arrow(self) -> pa.Table:
cdef shared_ptr[CTable] table

with nogil:
table = self.c_data.get().ToTable()

return pyarrow_wrap_table(table)


def merge_groups(groups: Sequence[GroupSortData]) -> GroupSortData:
cdef vector[shared_ptr[CGroupSortData]] c_groups
cdef shared_ptr[CGroupSortData] c_merged
cdef GroupSortData gsd

for g in groups:
c_groups.push_back((<GroupSortData?> g).c_data)

with nogil:
c_merged = GetResultValue(MergeGroups(c_groups))

gsd = GroupSortData.__new__(GroupSortData)
gsd.c_data = c_merged
return gsd
38 changes: 38 additions & 0 deletions src/arcae/tests/test_group_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import pyarrow as pa

from arcae.lib.arrow_tables import GroupSortData


def test_sorting():
data = pa.Table.from_pydict(
{
"GROUP_0": pa.array([0, 1, 0, 1], pa.int32()),
"GROUP_1": pa.array([1, 0, 1, 0], pa.int32()),
"TIME": pa.array([3, 2, 1, 0], pa.float64()),
"ANTENNA1": pa.array([0, 0, 0, 0], pa.int32()),
"ANTENNA2": pa.array([1, 1, 1, 1], pa.int32()),
"ROW": pa.array([0, 1, 2, 3], pa.int64()),
}
)

gsd = GroupSortData(
[
data["GROUP_0"].combine_chunks(),
data["GROUP_1"].combine_chunks(),
],
data["TIME"].combine_chunks(),
data["ANTENNA1"].combine_chunks(),
data["ANTENNA2"].combine_chunks(),
data["ROW"].combine_chunks(),
)

keys = [
("GROUP_0", "ascending"),
("GROUP_1", "ascending"),
("TIME", "ascending"),
("ANTENNA1", "ascending"),
("ANTENNA2", "ascending"),
("ROW", "ascending"),
]

assert gsd.sort().to_arrow().equals(data.sort_by(keys))

0 comments on commit 9ffbcd7

Please sign in to comment.