Skip to content

Commit

Permalink
Address PR comments, add in Arrow conversion for string types
Browse files Browse the repository at this point in the history
  • Loading branch information
Sophie Zhang committed Apr 20, 2024
1 parent dcc081a commit b54818f
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 23 deletions.
4 changes: 0 additions & 4 deletions cpp/server/brad_server_simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
#include <cstdint>
#include <memory>
#include <string>
#include <functional>
#include <any>
#include <atomic>
#include <mutex>

#include <arrow/flight/sql/server.h>
#include "brad_statement.h"
Expand Down
37 changes: 25 additions & 12 deletions cpp/server/brad_statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,26 @@ BradStatement::BradStatement(std::vector<std::vector<std::any>> query_result) :
BradStatement::~BradStatement() {
}

arrow::Result<std::shared_ptr<arrow::Schema>> BradStatement::GetSchema() {
arrow::Result<std::shared_ptr<arrow::Schema>> BradStatement::GetSchema() const {
if (schema_) {
return schema_;
}

std::vector<std::shared_ptr<arrow::Field>> fields;
const std::vector<std::any> &row = query_result_[0];

int counter = 0;
for (const auto &field : row) {
std::string field_type = field.type().name();
if (field_type == "i") {
fields.push_back(arrow::field("INT FIELD " + std::to_string(++counter), arrow::int8()));
} else if (field_type == "f") {
fields.push_back(arrow::field("FLOAT FIELD " + std::to_string(++counter), arrow::float32()));
} else {
fields.push_back(arrow::field("STRING FIELD " + std::to_string(++counter), arrow::utf8()));

if (query_result_.size() > 0) {
const std::vector<std::any> &row = query_result_[0];

int counter = 0;
for (const auto &field : row) {
std::string field_type = field.type().name();
if (field_type == "i") {
fields.push_back(arrow::field("INT FIELD " + std::to_string(++counter), arrow::int8()));
} else if (field_type == "f") {
fields.push_back(arrow::field("FLOAT FIELD " + std::to_string(++counter), arrow::float32()));
} else {
fields.push_back(arrow::field("STRING FIELD " + std::to_string(++counter), arrow::utf8()));
}
}
}

Expand All @@ -67,6 +70,8 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>> BradStatement::FetchResult()
const int num_rows = query_result_.size();

std::vector<std::shared_ptr<arrow::Array>> columns;
columns.reserve(schema->num_fields());

for (int field_ix = 0; field_ix < schema->num_fields(); ++field_ix) {
const auto &field = schema->fields()[field_ix];
if (field->type() == arrow::int8()) {
Expand Down Expand Up @@ -94,6 +99,14 @@ arrow::Result<std::shared_ptr<arrow::RecordBatch>> BradStatement::FetchResult()

columns.push_back(values);
} else if (field->type() == arrow::utf8()) {
arrow::StringBuilder stringbuilder;
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
const std::string* str = std::any_cast<const std::string>(&(query_result_[row_ix][field_ix]));
ARROW_RETURN_NOT_OK(stringbuilder.Append(str->data(), str->size()));
}

std::shared_ptr<arrow::Array> values;
ARROW_ASSIGN_OR_RAISE(values, stringbuilder.Finish());
}
}

Expand Down
9 changes: 2 additions & 7 deletions cpp/server/brad_statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@
#include <arrow/flight/sql/column_metadata.h>
#include <arrow/type_fwd.h>

#include <pybind11/pybind11.h>

namespace py = pybind11;
using namespace pybind11::literals;

namespace brad {

/// \brief Create an object ColumnMetadata using the column type and
Expand All @@ -38,7 +33,7 @@ class BradStatement {

/// \brief Creates an Arrow Schema based on the results of this statement.
/// \return The resulting Schema.
arrow::Result<std::shared_ptr<arrow::Schema>> GetSchema();
arrow::Result<std::shared_ptr<arrow::Schema>> GetSchema() const;

arrow::Result<std::shared_ptr<arrow::RecordBatch>> FetchResult();

Expand All @@ -47,7 +42,7 @@ class BradStatement {
private:
std::vector<std::vector<std::any>> query_result_;

std::shared_ptr<arrow::Schema> schema_;
mutable std::shared_ptr<arrow::Schema> schema_;

std::string* stmt_;

Expand Down

0 comments on commit b54818f

Please sign in to comment.