Skip to content

Commit

Permalink
ENH: geography constructor from geoarrow
Browse files Browse the repository at this point in the history
  • Loading branch information
jorisvandenbossche committed Oct 3, 2024
1 parent e41dbfd commit 577821e
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 0 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ endif()
add_library(spherely MODULE
src/geography.cpp
src/accessors-geog.cpp
src/geoarrow.cpp
src/predicates.cpp
src/spherely.cpp
)
Expand Down
1 change: 1 addition & 0 deletions ci/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ dependencies:
- ninja
- pytest
- pip
- geoarrow-pyarrow
117 changes: 117 additions & 0 deletions src/geoarrow.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#include <s2geography.h>

#include "geography.hpp"
#include "pybind11.hpp"

namespace py = pybind11;
namespace s2geog = s2geography;
using namespace spherely;

// PyObjectGeography from_wkt(std::string a) {
// s2geog::WKTReader reader;
// std::unique_ptr<s2geog::Geography> s2geog = reader.read_feature(a);
// auto geog_ptr = std::make_unique<spherely::Geography>(std::move(s2geog));
// return PyObjectGeography::from_geog(std::move(geog_ptr));
// }

// void init_geoarrow(py::module& m) {
// m.def("from_wkt",
// py::vectorize(&from_wkt),
// py::arg("a"),
// R"pbdoc(
// Creates a geography object from a WKT string.

// Parameters
// ----------
// a : str
// WKT string

// )pbdoc");
// }

#ifdef __cplusplus
extern "C" {
#endif

// Extra guard for versions of Arrow without the canonical guard
#ifndef ARROW_FLAG_DICTIONARY_ORDERED

#ifndef ARROW_C_DATA_INTERFACE
#define ARROW_C_DATA_INTERFACE

#define ARROW_FLAG_DICTIONARY_ORDERED 1
#define ARROW_FLAG_NULLABLE 2
#define ARROW_FLAG_MAP_KEYS_SORTED 4

// Local declaration of the Arrow C data interface schema struct. It is only
// compiled when no other header has already defined ARROW_C_DATA_INTERFACE /
// ARROW_FLAG_DICTIONARY_ORDERED (see the surrounding guards).
// NOTE(review): the field order and types presumably must match the canonical
// Arrow definition exactly, since pointers to producer-owned structs are
// reinterpreted as this type -- do not reorder or add members.
struct ArrowSchema {
// Array type description
const char* format;
const char* name;
const char* metadata;
int64_t flags;
int64_t n_children;
// Child schemas; presumably `n_children` entries (see the Arrow spec)
struct ArrowSchema** children;
struct ArrowSchema* dictionary;

// Release callback
void (*release)(struct ArrowSchema*);
// Opaque producer-specific data
void* private_data;
};

// Local declaration of the Arrow C data interface array struct. It is only
// compiled when no other header has already defined ARROW_C_DATA_INTERFACE /
// ARROW_FLAG_DICTIONARY_ORDERED (see the surrounding guards).
// NOTE(review): the field order and types presumably must match the canonical
// Arrow definition exactly, since pointers to producer-owned structs are
// reinterpreted as this type -- do not reorder or add members.
struct ArrowArray {
// Array data description
int64_t length;
int64_t null_count;
int64_t offset;
int64_t n_buffers;
int64_t n_children;
// Buffer pointers; presumably `n_buffers` entries (see the Arrow spec)
const void** buffers;
// Child arrays; presumably `n_children` entries (see the Arrow spec)
struct ArrowArray** children;
struct ArrowArray* dictionary;

// Release callback
void (*release)(struct ArrowArray*);
// Opaque producer-specific data
void* private_data;
};

#endif // ARROW_C_DATA_INTERFACE
#endif // ARROW_FLAG_DICTIONARY_ORDERED

#ifdef __cplusplus
}
#endif

/// Build a numpy array of Geography objects from an Arrow-compatible array.
///
/// @param input Python object providing an `__arrow_c_array__()` method. The
///     method is called without arguments and its result is unpacked as a
///     (schema capsule, array capsule) pair.
/// @return A 1-d object array with `array->length` elements, each wrapping
///     one geography produced by the s2geography GeoArrow reader.
/// @throws std::runtime_error (via pybind11_fail) if the reader does not
///     produce exactly `array->length` geographies. Errors raised by pybind11
///     (e.g. missing attribute, wrong capsule) or by the reader propagate
///     unchanged.
py::array_t<PyObjectGeography> from_geoarrow(py::object input) {
    py::tuple capsules = input.attr("__arrow_c_array__")();
    py::capsule schema_capsule = capsules[0];
    py::capsule array_capsule = capsules[1];

    // Borrowed pointers: the capsule locals above stay alive until this
    // function returns, so the structs remain valid while they are read.
    const ArrowSchema* schema = static_cast<const ArrowSchema*>(schema_capsule);
    const ArrowArray* array = static_cast<const ArrowArray*>(array_capsule);

    s2geog::geoarrow::Reader reader;
    std::vector<std::unique_ptr<s2geog::Geography>> s2geog_vec;

    reader.Init(schema, s2geog::geoarrow::ImportOptions());
    reader.ReadGeography(array, 0, array->length, &s2geog_vec);

    // The output buffer below is sized from array->length while the loop
    // writes one slot per element of s2geog_vec: refuse to continue if the
    // two disagree rather than writing past the end of the buffer.
    if (static_cast<int64_t>(s2geog_vec.size()) != array->length) {
        py::pybind11_fail(
            "from_geoarrow: number of geographies read does not match the "
            "length of the input array");
    }

    // Convert resulting vector to array of python objects
    auto result = py::array_t<PyObjectGeography>(array->length);
    py::buffer_info rbuf = result.request();
    py::object* rptr = static_cast<py::object*>(rbuf.ptr);

    py::ssize_t i = 0;
    for (auto& s2geog_ptr : s2geog_vec) {
        // Transfer ownership: s2geography object -> spherely wrapper -> Python
        auto geog_ptr = std::make_unique<spherely::Geography>(std::move(s2geog_ptr));
        rptr[i] = py::cast(std::move(geog_ptr));
        ++i;
    }
    return result;
}

/// Register the GeoArrow-related functions on the Python module `m`.
void init_geoarrow(py::module& m) {
    m.def("from_geoarrow",
          &from_geoarrow,
          py::arg("input"),
          R"pbdoc(
        Create an array of geographies from an Arrow array object.

        Parameters
        ----------
        input : object
            Object providing an ``__arrow_c_array__`` method, e.g. a
            geoarrow-pyarrow array with WKT, WKB or native GeoArrow encoding.

        Returns
        -------
        numpy.ndarray
            Object array of geographies, one per element of ``input``.

    )pbdoc");
}
14 changes: 14 additions & 0 deletions src/pybind11.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,20 @@ struct npy_format_descriptor<spherely::PyObjectGeography> {
}
};

// // Register std::string as a valid numpy dtype (numpy.object alias)
// // from: https://github.com/pybind/pybind11/pull/1152
// template <>
// struct npy_format_descriptor<std::string> {
// static constexpr auto name = _("object");
// enum { value = npy_api::NPY_OBJECT_ };
// static pybind11::dtype dtype() {
// if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) {
// return reinterpret_borrow<pybind11::dtype>(ptr);
// }
// pybind11_fail("Unsupported buffer format!");
// }
// };

// Override signature type hint for vectorized Geography arguments
template <int Flags>
struct handle_type_name<array_t<spherely::PyObjectGeography, Flags>> {
Expand Down
2 changes: 2 additions & 0 deletions src/spherely.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace py = pybind11;
void init_geography(py::module&);
void init_predicates(py::module&);
void init_accessors(py::module&);
void init_geoarrow(py::module&);

PYBIND11_MODULE(spherely, m) {
m.doc() = R"pbdoc(
Expand All @@ -21,6 +22,7 @@ PYBIND11_MODULE(spherely, m) {
init_geography(m);
init_predicates(m);
init_accessors(m);
init_geoarrow(m);

#ifdef VERSION_INFO
m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
Expand Down
38 changes: 38 additions & 0 deletions tests/test_geoarrow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import numpy as np
import pyarrow as pa
import geoarrow.pyarrow as ga

import pytest

import spherely


def test_from_geoarrow_wkt():
    """A WKT-encoded geoarrow array converts to the expected point geographies."""
    wkt_input = ga.as_wkt(["POINT (1 1)", "POINT(2 2)", "POINT(3 3)"])

    expected = spherely.create([1, 2, 3], [1, 2, 3])
    result = spherely.from_geoarrow(wkt_input)

    # Element-wise object equality is not supported yet, so
    # np.testing.assert_array_equal(result, expected) cannot be used here;
    # compare through spherely.equals instead.
    matches = spherely.equals(result, expected)
    assert matches.all()


def test_from_geoarrow_wkb():
    """A WKB-encoded geoarrow array converts to the expected point geographies."""
    # Build the WKB input by re-encoding a WKT array
    wkt_input = ga.as_wkt(["POINT (1 1)", "POINT(2 2)", "POINT(3 3)"])
    wkb_input = ga.as_wkb(wkt_input)

    expected = spherely.create([1, 2, 3], [1, 2, 3])
    result = spherely.from_geoarrow(wkb_input)

    matches = spherely.equals(result, expected)
    assert matches.all()


def test_from_geoarrow_native():
    """A native-encoded geoarrow array converts to the expected point geographies."""
    # Build the native input by re-encoding a WKT array
    wkt_input = ga.as_wkt(["POINT (1 1)", "POINT(2 2)", "POINT(3 3)"])
    native_input = ga.as_geoarrow(wkt_input)

    expected = spherely.create([1, 2, 3], [1, 2, 3])
    result = spherely.from_geoarrow(native_input)

    matches = spherely.equals(result, expected)
    assert matches.all()

0 comments on commit 577821e

Please sign in to comment.