Skip to content

Commit

Permalink
Pass in Python callback to native FlightSQLServer, invoke in GetFligh…
Browse files Browse the repository at this point in the history
…tInfoStatement (#492)

- Create in-memory map to store and retrieve query data, using efficient
concurrent hash table from libcuckoo
- Build schema from query result field types

---------

Co-authored-by: Sophie Zhang <[email protected]>
  • Loading branch information
sopzha and Sophie Zhang authored Apr 21, 2024
1 parent a1e6207 commit 56f18b0
Show file tree
Hide file tree
Showing 10 changed files with 227 additions and 73 deletions.
25 changes: 8 additions & 17 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,6 @@ find_package(Boost REQUIRED)

add_subdirectory(third_party)

add_library(brad_server_lib OBJECT
server/brad_server_simple.cc
server/brad_sql_info.cc
server/brad_statement_batch_reader.cc
server/brad_statement.cc
server/brad_tables_schema_batch_reader.cc)

add_library(sqlite_server_lib OBJECT
sqlite_server/sqlite_server.cc
sqlite_server/sqlite_sql_info.cc
Expand All @@ -31,12 +24,18 @@ add_library(sqlite_server_lib OBJECT
sqlite_server/sqlite_tables_schema_batch_reader.cc
sqlite_server/sqlite_type_info.cc)

pybind11_add_module(pybind_brad_server pybind/brad_server.cc)
pybind11_add_module(pybind_brad_server pybind/brad_server.cc
server/brad_server_simple.cc
server/brad_sql_info.cc
server/brad_statement_batch_reader.cc
server/brad_statement.cc
server/brad_tables_schema_batch_reader.cc)

target_link_libraries(pybind_brad_server
PRIVATE Arrow::arrow_shared
PRIVATE ArrowFlight::arrow_flight_shared
PRIVATE ArrowFlightSql::arrow_flight_sql_shared
PRIVATE brad_server_lib)
PUBLIC libcuckoo)

add_executable(flight_sql_example_client flight_sql_example_client.cc)
target_link_libraries(flight_sql_example_client
Expand All @@ -55,14 +54,6 @@ target_link_libraries(flight_sql_example_server
${SQLite3_LIBRARIES}
${Boost_LIBRARIES})

add_executable(flight_sql_brad_server flight_sql_brad_server.cc)
target_link_libraries(flight_sql_brad_server
PRIVATE Arrow::arrow_shared
PRIVATE ArrowFlight::arrow_flight_shared
PRIVATE ArrowFlightSql::arrow_flight_sql_shared
PRIVATE brad_server_lib
gflags)

add_executable(brad_front_end brad_front_end.cc)
target_link_libraries(brad_front_end
PRIVATE Arrow::arrow_shared
Expand Down
2 changes: 2 additions & 0 deletions cpp/pybind/brad_server.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include <pybind11/pybind11.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>

#include <iostream>

Expand Down
86 changes: 68 additions & 18 deletions cpp/server/brad_server_simple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@ using arrow::internal::checked_cast;
using namespace arrow::flight;
using namespace arrow::flight::sql;

arrow::Result<Ticket> EncodeTransactionQuery(
const std::string &query,
std::string GetQueryTicket(
const std::string &autoincrement_id,
const std::string &transaction_id) {
std::string transaction_query = transaction_id;
transaction_query += ':';
transaction_query += query;
return transaction_id + ':' + autoincrement_id;
}

arrow::Result<Ticket> EncodeTransactionQuery(
const std::string &query_ticket) {
ARROW_ASSIGN_OR_RAISE(auto ticket_string,
CreateStatementQueryTicket(transaction_query));
CreateStatementQueryTicket(query_ticket));
return Ticket{std::move(ticket_string)};
}

Expand All @@ -40,17 +42,35 @@ arrow::Result<std::pair<std::string, std::string>> DecodeTransactionQuery(
return arrow::Status::Invalid("Malformed ticket");
}
std::string transaction_id = ticket.substr(0, divider);
std::string query = ticket.substr(divider + 1);
return std::make_pair(std::move(query), std::move(transaction_id));
std::string autoincrement_id = ticket.substr(divider + 1);
return std::make_pair(std::move(autoincrement_id), std::move(transaction_id));
}

BradFlightSqlServer::BradFlightSqlServer() = default;
// Converts the rows returned by the Python query callback into plain C++
// values that can be used without touching Python objects.
//
// Each field of each tuple is mapped as follows:
//   - Python int   -> int
//   - Python float -> float
//   - anything else is cast to std::string
//
// The returned outer vector has one entry per input tuple, in the same order,
// and each inner vector has one std::any per tuple field.
//
// This function calls into pybind11 (isinstance / cast) on every field.
// NOTE(review): the visible caller invokes it while holding
// py::gil_scoped_acquire; presumably that is required of every caller.
// NOTE(review): py::cast<float> narrows a Python float, and the string
// fallback is applied to every non-int/non-float object; confirm both are
// intended.
std::vector<std::vector<std::any>> TransformQueryResult(
    std::vector<py::tuple> query_result) {
  std::vector<std::vector<std::any>> transformed_query_result;
  // One output row per input tuple: allocate once instead of growing.
  transformed_query_result.reserve(query_result.size());

  for (const auto &row : query_result) {
    std::vector<std::any> transformed_row;
    transformed_row.reserve(row.size());

    for (const auto &field : row) {
      if (py::isinstance<py::int_>(field)) {
        transformed_row.push_back(std::make_any<int>(py::cast<int>(field)));
      } else if (py::isinstance<py::float_>(field)) {
        transformed_row.push_back(
            std::make_any<float>(py::cast<float>(field)));
      } else {
        transformed_row.push_back(
            std::make_any<std::string>(py::cast<std::string>(field)));
      }
    }

    // Move the finished row into the result rather than deep-copying it.
    transformed_query_result.push_back(std::move(transformed_row));
  }
  return transformed_query_result;
}

// Constructs a server whose ticket counter starts at zero. The counter is
// pre-incremented when a query ticket is generated, so the first ticket
// issued carries id 1.
BradFlightSqlServer::BradFlightSqlServer() : autoincrement_id_(0ULL) {}

// Defaulted destructor; members release their own resources.
BradFlightSqlServer::~BradFlightSqlServer() = default;

std::shared_ptr<BradFlightSqlServer>
BradFlightSqlServer::Create() {
// std::shared_ptr<BradFlightSqlServer> result(new BradFlightSqlServer());
std::shared_ptr<BradFlightSqlServer> result =
std::make_shared<BradFlightSqlServer>();
for (const auto &id_to_result : GetSqlInfoResultMap()) {
Expand All @@ -59,9 +79,15 @@ std::shared_ptr<BradFlightSqlServer>
return result;
}

void BradFlightSqlServer::InitWrapper(const std::string &host, int port) {
void BradFlightSqlServer::InitWrapper(
const std::string &host,
int port,
std::function<std::vector<py::tuple>(std::string)> handle_query) {
auto location = arrow::flight::Location::ForGrpcTcp(host, port).ValueOrDie();
arrow::flight::FlightServerOptions options(location);

handle_query_ = handle_query;

this->Init(options);
}

Expand All @@ -79,10 +105,25 @@ arrow::Result<std::unique_ptr<FlightInfo>>
const StatementQuery &command,
const FlightDescriptor &descriptor) {
const std::string &query = command.query;
ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(query));
ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema());

const std::string &autoincrement_id = std::to_string(++autoincrement_id_);
const std::string &query_ticket = GetQueryTicket(autoincrement_id, command.transaction_id);
ARROW_ASSIGN_OR_RAISE(auto ticket,
EncodeTransactionQuery(query, command.transaction_id));
EncodeTransactionQuery(query_ticket));

std::vector<std::vector<std::any>> transformed_query_result;

{
py::gil_scoped_acquire guard;
std::vector<py::tuple> query_result = handle_query_(query);
transformed_query_result = TransformQueryResult(query_result);
}

ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(transformed_query_result));
query_data_.insert(query_ticket, statement);

ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema());

std::vector<FlightEndpoint> endpoints{
FlightEndpoint{std::move(ticket), {}, std::nullopt, ""}};

Expand All @@ -103,14 +144,23 @@ arrow::Result<std::unique_ptr<FlightDataStream>>
const StatementQueryTicket &command) {
ARROW_ASSIGN_OR_RAISE(auto pair,
DecodeTransactionQuery(command.statement_handle));
const std::string &sql = pair.first;
const std::string &autoincrement_id = pair.first;
const std::string transaction_id = pair.second;

std::shared_ptr<BradStatement> statement;
ARROW_ASSIGN_OR_RAISE(statement, BradStatement::Create(sql));
const std::string &query_ticket = transaction_id + ':' + autoincrement_id;

std::shared_ptr<BradStatement> result;
const bool found = query_data_.erase_fn(query_ticket, [&result](auto& qr) {
result = qr;
return true;
});

if (!found) {
return arrow::Status::Invalid("Invalid ticket.");
}

std::shared_ptr<BradStatementBatchReader> reader;
ARROW_ASSIGN_OR_RAISE(reader, BradStatementBatchReader::Create(statement));
ARROW_ASSIGN_OR_RAISE(reader, BradStatementBatchReader::Create(result));

return std::make_unique<RecordBatchStream>(reader);
}
Expand Down
22 changes: 21 additions & 1 deletion cpp/server/brad_server_simple.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
#pragma once

#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include <arrow/flight/sql/server.h>
#include "brad_statement.h"
#include <arrow/result.h>

#include "libcuckoo/cuckoohash_map.hh"

#include <pybind11/pybind11.h>

namespace py = pybind11;
using namespace pybind11::literals;

namespace brad {

class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase {
Expand All @@ -17,7 +28,9 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase {

static std::shared_ptr<BradFlightSqlServer> Create();

void InitWrapper(const std::string &host, int port);
void InitWrapper(const std::string &host,
int port,
std::function<std::vector<py::tuple>(std::string)>);

void ServeWrapper();

Expand All @@ -33,6 +46,13 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase {
DoGetStatement(
const arrow::flight::ServerCallContext &context,
const arrow::flight::sql::StatementQueryTicket &command) override;

private:
std::function<std::vector<py::tuple>(std::string)> handle_query_;

libcuckoo::cuckoohash_map<std::string, std::shared_ptr<BradStatement>> query_data_;

std::atomic<uint64_t> autoincrement_id_;
};

} // namespace brad
112 changes: 81 additions & 31 deletions cpp/server/brad_statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,47 +24,97 @@ arrow::Result<std::shared_ptr<BradStatement>> BradStatement::Create(
return result;
}

// Creates a statement that owns an already-materialized query result.
//
// query_result  Rows of type-erased field values (one inner vector per row).
//               Taken by value and moved into the new statement, so the
//               result set is not deep-copied a second time.
//
// Returns the newly created statement wrapped in an arrow::Result.
arrow::Result<std::shared_ptr<BradStatement>> BradStatement::Create(
    std::vector<std::vector<std::any>> query_result) {
  return std::make_shared<BradStatement>(std::move(query_result));
}

// Constructs a statement that takes ownership of `query_result`; the rows are
// moved into query_result_ rather than copied.
BradStatement::BradStatement(std::vector<std::vector<std::any>> query_result) :
query_result_(std::move(query_result)) {}

// Nothing to release explicitly.
BradStatement::~BradStatement() {
}

// Returns the Arrow schema describing the stored query result.
//
// The schema is inferred from the FIRST row only: each field becomes one
// column, typed by the C++ type held in its std::any:
//   - int         -> "INT FIELD <n>"    (arrow::int8)
//   - float       -> "FLOAT FIELD <n>"  (arrow::float32)
//   - otherwise   -> "STRING FIELD <n>" (arrow::utf8)
// where <n> is the 1-based column position. An empty result yields a schema
// with no fields.
//
// The schema is computed once and cached in schema_; later calls return the
// cached pointer.
// NOTE(review): the cache is written without synchronization; confirm a
// statement is never shared between threads before its first GetSchema().
// NOTE(review): int columns are declared as int8 although the stored values
// are `int`; widening needs a coordinated change in FetchResult().
arrow::Result<std::shared_ptr<arrow::Schema>> BradStatement::GetSchema() const {
  if (schema_) {
    return schema_;
  }

  std::vector<std::shared_ptr<arrow::Field>> fields;

  if (!query_result_.empty()) {
    const std::vector<std::any> &first_row = query_result_.front();
    fields.reserve(first_row.size());

    int counter = 0;
    for (const auto &value : first_row) {
      const std::string position = std::to_string(++counter);
      // Compare type_info objects directly: type_info::name() is
      // implementation-defined ("i"/"f" only holds for the Itanium ABI).
      if (value.type() == typeid(int)) {
        fields.push_back(arrow::field("INT FIELD " + position, arrow::int8()));
      } else if (value.type() == typeid(float)) {
        fields.push_back(
            arrow::field("FLOAT FIELD " + position, arrow::float32()));
      } else {
        fields.push_back(
            arrow::field("STRING FIELD " + position, arrow::utf8()));
      }
    }
  }

  schema_ = arrow::schema(fields);
  return schema_;
}

// Materializes the stored query result as a single Arrow record batch.
//
// Columns follow the schema returned by GetSchema(): one array is built per
// schema field, in field order, with one element per stored row.
//
// Returns the record batch on success. Returns an error Status when
//   - GetSchema() fails,
//   - a row has fewer fields than the schema,
//   - a value's stored type does not match its column type (the schema is
//     inferred from the first row only, so later rows may disagree),
//   - an Arrow builder reports a failure, or
//   - the schema contains a column type this function does not build.
//
// NOTE(review): int columns are int8 in the schema, so `int` values are
// narrowed with static_cast and values outside [-128, 127] are truncated.
// Widening requires a matching change in GetSchema().
arrow::Result<std::shared_ptr<arrow::RecordBatch>> BradStatement::FetchResult() {
  // Propagate schema errors instead of aborting via ValueOrDie().
  ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema());

  const int64_t num_rows = static_cast<int64_t>(query_result_.size());
  const int num_fields = schema->num_fields();

  // Guard against ragged rows before indexing into them below.
  for (const auto &row : query_result_) {
    if (row.size() < static_cast<std::size_t>(num_fields)) {
      return arrow::Status::Invalid(
          "Query result row has fewer fields than the schema.");
    }
  }

  std::vector<std::shared_ptr<arrow::Array>> columns;
  columns.reserve(num_fields);

  for (int field_ix = 0; field_ix < num_fields; ++field_ix) {
    const arrow::Type::type type_id = schema->field(field_ix)->type()->id();
    std::shared_ptr<arrow::Array> column;

    if (type_id == arrow::Type::INT8) {
      arrow::Int8Builder builder;
      // Reserve up front so values can be appended straight into the builder
      // (no stack-allocated variable-length staging array).
      ARROW_RETURN_NOT_OK(builder.Reserve(num_rows));
      for (const auto &row : query_result_) {
        const int *value = std::any_cast<int>(&row[field_ix]);
        if (value == nullptr) {
          return arrow::Status::Invalid(
              "Query result value does not match its int column type.");
        }
        builder.UnsafeAppend(static_cast<int8_t>(*value));
      }
      ARROW_ASSIGN_OR_RAISE(column, builder.Finish());
    } else if (type_id == arrow::Type::FLOAT) {
      arrow::FloatBuilder builder;
      ARROW_RETURN_NOT_OK(builder.Reserve(num_rows));
      for (const auto &row : query_result_) {
        const float *value = std::any_cast<float>(&row[field_ix]);
        if (value == nullptr) {
          return arrow::Status::Invalid(
              "Query result value does not match its float column type.");
        }
        builder.UnsafeAppend(*value);
      }
      ARROW_ASSIGN_OR_RAISE(column, builder.Finish());
    } else if (type_id == arrow::Type::STRING) {
      arrow::StringBuilder builder;
      for (const auto &row : query_result_) {
        const std::string *value = std::any_cast<std::string>(&row[field_ix]);
        if (value == nullptr) {
          return arrow::Status::Invalid(
              "Query result value does not match its string column type.");
        }
        ARROW_RETURN_NOT_OK(builder.Append(value->data(), value->size()));
      }
      ARROW_ASSIGN_OR_RAISE(column, builder.Finish());
    } else {
      return arrow::Status::Invalid("Unsupported column type in schema.");
    }

    // Every schema field must contribute exactly one column; the string
    // column was previously built but never added to the batch.
    columns.push_back(std::move(column));
  }

  return arrow::RecordBatch::Make(schema, num_rows, columns);
}

std::string* BradStatement::GetBradStmt() const { return stmt_; }
Expand Down
Loading

0 comments on commit 56f18b0

Please sign in to comment.