From b54818f6d629f0393e2a316dff1457d3b7f742f8 Mon Sep 17 00:00:00 2001 From: Sophie Zhang Date: Sat, 20 Apr 2024 18:30:17 -0400 Subject: [PATCH] Address PR comments, add in Arrow conversion for string types --- cpp/server/brad_server_simple.h | 4 ---- cpp/server/brad_statement.cc | 37 ++++++++++++++++++++++----------- cpp/server/brad_statement.h | 9 ++------ 3 files changed, 27 insertions(+), 23 deletions(-) diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index cbffe9e2..48056a04 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -3,10 +3,6 @@ #include #include #include -#include -#include -#include -#include #include #include "brad_statement.h" diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index 3d66c6ef..e9ce1588 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -37,23 +37,26 @@ BradStatement::BradStatement(std::vector> query_result) : BradStatement::~BradStatement() { } -arrow::Result> BradStatement::GetSchema() { +arrow::Result> BradStatement::GetSchema() const { if (schema_) { return schema_; } std::vector> fields; - const std::vector &row = query_result_[0]; - - int counter = 0; - for (const auto &field : row) { - std::string field_type = field.type().name(); - if (field_type == "i") { - fields.push_back(arrow::field("INT FIELD " + std::to_string(++counter), arrow::int8())); - } else if (field_type == "f") { - fields.push_back(arrow::field("FLOAT FIELD " + std::to_string(++counter), arrow::float32())); - } else { - fields.push_back(arrow::field("STRING FIELD " + std::to_string(++counter), arrow::utf8())); + + if (query_result_.size() > 0) { + const std::vector &row = query_result_[0]; + + int counter = 0; + for (const auto &field : row) { + std::string field_type = field.type().name(); + if (field_type == "i") { + fields.push_back(arrow::field("INT FIELD " + std::to_string(++counter), arrow::int8())); + } else if (field_type == "f") { + fields.push_back(arrow::field("FLOAT FIELD " + std::to_string(++counter), arrow::float32())); + } else { + fields.push_back(arrow::field("STRING FIELD " + std::to_string(++counter), arrow::utf8())); + } } } @@ -67,6 +70,8 @@ arrow::Result> BradStatement::FetchResult() const int num_rows = query_result_.size(); std::vector> columns; + columns.reserve(schema->num_fields()); + for (int field_ix = 0; field_ix < schema->num_fields(); ++field_ix) { const auto &field = schema->fields()[field_ix]; if (field->type() == arrow::int8()) { @@ -94,6 +99,14 @@ arrow::Result> BradStatement::FetchResult() columns.push_back(values); } else if (field->type() == arrow::utf8()) { + arrow::StringBuilder stringbuilder; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + const std::string* str = std::any_cast(&(query_result_[row_ix][field_ix])); + ARROW_RETURN_NOT_OK(stringbuilder.Append(str->data(), str->size())); + } + + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, stringbuilder.Finish()); } } diff --git a/cpp/server/brad_statement.h b/cpp/server/brad_statement.h index 5c62dfea..6f13bc70 100644 --- a/cpp/server/brad_statement.h +++ b/cpp/server/brad_statement.h @@ -7,11 +7,6 @@ #include #include -#include - -namespace py = pybind11; -using namespace pybind11::literals; - namespace brad { /// \brief Create an object ColumnMetadata using the column type and @@ -38,7 +33,7 @@ class BradStatement { /// \brief Creates an Arrow Schema based on the results of this statement. /// \return The resulting Schema. - arrow::Result> GetSchema(); + arrow::Result> GetSchema() const; arrow::Result> FetchResult(); @@ -47,7 +42,7 @@ class BradStatement { private: std::vector> query_result_; - std::shared_ptr schema_; + mutable std::shared_ptr schema_; std::string* stmt_;