Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

keep read_callable callback for queries #172

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 182 additions & 49 deletions tree_sitter/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ typedef struct {
// Python object that wraps a tree-sitter syntax tree.
typedef struct {
PyObject_HEAD
// Underlying tree-sitter tree; tree_dealloc releases it with ts_tree_delete().
TSTree *tree;
// NOTE(review): `source` and `source_or_callback` are both listed here. The surrounding
// text looks like a diff rendered without +/- markers, in which `source` would be the
// pre-rename line and `source_or_callback` its replacement -- confirm before relying on
// both fields existing.
PyObject *source;
// Set by tree_new_internal to the object handed to the parser when keep_text is true,
// otherwise to Py_None; tree_edit resets it to Py_None.
PyObject *source_or_callback;
} Tree;

typedef struct {
Expand Down Expand Up @@ -650,46 +650,114 @@ static PyObject *node_get_text(Node *self, void *payload) {
PyErr_SetString(PyExc_ValueError, "No tree");
return NULL;
}
if (tree->source == Py_None || tree->source == NULL) {
if (tree->source_or_callback == Py_None || tree->source_or_callback == NULL) {
Py_RETURN_NONE;
}

PyObject *start_byte = PyLong_FromSize_t((size_t)ts_node_start_byte(self->node));
if (start_byte == NULL) {
PyErr_SetString(PyExc_RuntimeError, "Failed to determine start byte");
return NULL;
}
PyObject *end_byte = PyLong_FromSize_t((size_t)ts_node_end_byte(self->node));
if (end_byte == NULL) {
Py_DECREF(start_byte);
PyErr_SetString(PyExc_RuntimeError, "Failed to determine end byte");
return NULL;
}
PyObject *slice = PySlice_New(start_byte, end_byte, NULL);
Py_DECREF(start_byte);
Py_DECREF(end_byte);
if (slice == NULL) {
PyErr_SetString(PyExc_RuntimeError, "PySlice_New failed");
return NULL;
}
PyObject *node_mv = PyMemoryView_FromObject(tree->source);
if (node_mv == NULL) {
size_t start_offset = (size_t)ts_node_start_byte(self->node);
size_t end_offset = (size_t)ts_node_end_byte(self->node);
PyObject *result = NULL;

// Case 1: source_or_callback is a byte buffer
if (!PyCallable_Check(tree->source_or_callback)) {
PyObject *start_byte = PyLong_FromSize_t(start_offset);
PyObject *end_byte = PyLong_FromSize_t(end_offset);
PyObject *slice = PySlice_New(start_byte, end_byte, NULL);
Py_XDECREF(start_byte);
Py_XDECREF(end_byte);
if (slice == NULL) {
PyErr_SetString(PyExc_RuntimeError, "PySlice_New failed");
return NULL;
}
PyObject *node_mv = PyMemoryView_FromObject(tree->source_or_callback);
if (node_mv == NULL) {
Py_DECREF(slice);
PyErr_SetString(PyExc_RuntimeError, "PyMemoryView_FromObject failed");
return NULL;
}
PyObject *node_slice = PyObject_GetItem(node_mv, slice);
Py_DECREF(slice);
PyErr_SetString(PyExc_RuntimeError, "PyMemoryView_FromObject failed");
return NULL;
}
PyObject *node_slice = PyObject_GetItem(node_mv, slice);
Py_DECREF(slice);
Py_DECREF(node_mv);
if (node_slice == NULL) {
PyErr_SetString(PyExc_RuntimeError, "PyObject_GetItem failed");
return NULL;
Py_DECREF(node_mv);
if (node_slice == NULL) {
PyErr_SetString(PyExc_RuntimeError, "PyObject_GetItem failed");
return NULL;
}
result = PyBytes_FromObject(node_slice);
Py_DECREF(node_slice);
}
// Case 2: source_or_callback is a callable
else {
PyObject *collected_bytes = PyBytes_FromString("");
if (collected_bytes == NULL) {
PyErr_SetString(PyExc_RuntimeError, "Initialization failed");
return NULL;
}

TSPoint start_point = ts_node_start_point(self->node);
TSPoint end_point = ts_node_end_point(self->node);

TSPoint current_point = start_point;
size_t current_offset = start_offset;

while (current_offset < end_offset) {
PyObject *byte_offset_obj = PyLong_FromSize_t(current_offset);
PyObject *point_obj = point_new(current_point);
if (!point_obj) {
Py_XDECREF(collected_bytes);
PyErr_SetString(PyExc_RuntimeError, "Failed to create point object");
return NULL;
}

PyObject *args = PyTuple_Pack(2, byte_offset_obj, point_obj);
Py_XDECREF(byte_offset_obj);
Py_XDECREF(point_obj);

PyObject *rv = PyObject_Call(tree->source_or_callback, args, NULL);
Py_XDECREF(args);

if (rv == NULL || rv == Py_None || !PyBytes_Check(rv)) {
Py_XDECREF(rv);
Py_XDECREF(collected_bytes);
PyErr_SetString(PyExc_RuntimeError, "Callback execution failed or returned invalid type");
return NULL;
}

PyBytes_Concat(&collected_bytes, rv);
Py_XDECREF(rv);

if (collected_bytes == NULL) {
PyErr_SetString(PyExc_RuntimeError, "Byte concatenation failed");
return NULL;
}

// Update current_point and current_offset
size_t bytes_read = PyBytes_Size(rv);
for (size_t i = 0; i < bytes_read; i++) {
if (PyBytes_AsString(rv)[i] == '\n') {
current_point.row++;
current_point.column = 0;
} else {
current_point.column++;
}
}
current_offset += bytes_read;
}

PyObject *start_byte = PyLong_FromSize_t(0);
PyObject *end_byte = PyLong_FromSize_t(end_offset - start_offset);
PyObject *slice = PySlice_New(start_byte, end_byte, NULL);
Py_XDECREF(start_byte);
Py_XDECREF(end_byte);

result = PyObject_GetItem(collected_bytes, slice);
Py_DECREF(slice);
Py_XDECREF(collected_bytes);
}
PyObject *result = PyBytes_FromObject(node_slice);
Py_DECREF(node_slice);

return result;
}


static PyMethodDef node_methods[] = {
{
.ml_name = "walk",
Expand Down Expand Up @@ -871,7 +939,7 @@ static bool node_is_instance(ModuleState *state, PyObject *self) {

// Deallocator for Tree objects: deletes the wrapped TSTree, drops the references held in
// the source fields, and finally frees the Python object through its type's tp_free slot.
static void tree_dealloc(Tree *self) {
ts_tree_delete(self->tree);
// NOTE(review): the next two lines look like the old and new side of a rename
// (`source` -> `source_or_callback`) shown without diff markers -- confirm which one the
// real file contains. Py_XDECREF is used, so a NULL field is tolerated.
Py_XDECREF(self->source);
Py_XDECREF(self->source_or_callback);
Py_TYPE(self)->tp_free((PyObject *)self);
}

Expand All @@ -881,14 +949,82 @@ static PyObject *tree_get_root_node(Tree *self, void *payload) {
}

// Getter for the text a Tree was parsed from.
//
// Returns a new reference to:
//   - None, when no source was kept (field is NULL or Py_None);
//   - the stored byte buffer itself, when the stored object is not callable;
//   - a fresh bytes object, when the stored object is a read callable. The callable is
//     invoked repeatedly as callback(byte_offset, point) and its chunks are concatenated
//     until it returns an empty bytes object.
//
// Returns NULL with RuntimeError set when an allocation fails, or when the callable
// raises or returns anything other than bytes.
//
// `payload` is the unused getset closure argument.
static PyObject *tree_get_text(Tree *self, void *payload) {
    PyObject *source_or_callback = self->source_or_callback;
    if (source_or_callback == NULL || source_or_callback == Py_None) {
        Py_RETURN_NONE;
    }

    // A non-callable source is the buffer that was parsed; hand it back unchanged.
    if (!PyCallable_Check(source_or_callback)) {
        Py_INCREF(source_or_callback);
        return source_or_callback;
    }

    PyObject *collected_bytes = PyBytes_FromString("");
    if (collected_bytes == NULL) {
        PyErr_SetString(PyExc_RuntimeError, "Initialization failed");
        return NULL;
    }

    size_t current_offset = 0;
    TSPoint current_point = {0, 0}; // start of the document

    while (true) {
        PyObject *byte_offset_obj = PyLong_FromSize_t(current_offset);
        if (byte_offset_obj == NULL) {
            Py_DECREF(collected_bytes);
            PyErr_SetString(PyExc_RuntimeError, "Failed to create byte offset object");
            return NULL;
        }

        PyObject *point_obj = point_new(current_point);
        if (point_obj == NULL) {
            // Release the offset object too; it was leaked on this path before.
            Py_DECREF(byte_offset_obj);
            Py_DECREF(collected_bytes);
            PyErr_SetString(PyExc_RuntimeError, "Failed to create point object");
            return NULL;
        }

        // PyTuple_Pack takes its own references, so ours can be dropped right away.
        PyObject *args = PyTuple_Pack(2, byte_offset_obj, point_obj);
        Py_DECREF(byte_offset_obj);
        Py_DECREF(point_obj);
        if (args == NULL) {
            Py_DECREF(collected_bytes);
            PyErr_SetString(PyExc_RuntimeError, "Failed to build callback arguments");
            return NULL;
        }

        PyObject *rv = PyObject_Call(source_or_callback, args, NULL);
        Py_DECREF(args);

        // Py_None is not a bytes object, so PyBytes_Check rejects it as well.
        if (rv == NULL || !PyBytes_Check(rv)) {
            Py_XDECREF(rv);
            Py_DECREF(collected_bytes);
            PyErr_SetString(PyExc_RuntimeError,
                            "Callback execution failed or returned invalid type");
            return NULL;
        }

        Py_ssize_t bytes_read = PyBytes_GET_SIZE(rv);
        if (bytes_read == 0) {
            Py_DECREF(rv);
            break; // an empty chunk marks the end of the input
        }

        // Advance the position while `rv` is still alive. Reading it after the
        // DECREF below would be a use-after-free.
        const char *chunk = PyBytes_AS_STRING(rv);
        for (Py_ssize_t i = 0; i < bytes_read; i++) {
            if (chunk[i] == '\n') {
                current_point.row++;
                current_point.column = 0;
            } else {
                current_point.column++;
            }
        }
        current_offset += (size_t)bytes_read;

        // On failure PyBytes_Concat releases the accumulator and sets it to NULL.
        PyBytes_Concat(&collected_bytes, rv);
        Py_DECREF(rv);
        if (collected_bytes == NULL) {
            PyErr_SetString(PyExc_RuntimeError, "Byte concatenation failed");
            return NULL;
        }
    }

    return collected_bytes;
}


static PyObject *tree_root_node_with_offset(Tree *self, PyObject *args) {
ModuleState *state = PyType_GetModuleState(Py_TYPE(self));

Expand Down Expand Up @@ -933,9 +1069,9 @@ static PyObject *tree_edit(Tree *self, PyObject *args, PyObject *kwargs) {
.new_end_point = {new_end_row, new_end_column},
};
ts_tree_edit(self->tree, &edit);
Py_XDECREF(self->source);
self->source = Py_None;
Py_INCREF(self->source);
Py_XDECREF(self->source_or_callback);
self->source_or_callback = Py_None;
Py_INCREF(self->source_or_callback);
}
Py_RETURN_NONE;
}
Expand Down Expand Up @@ -1045,19 +1181,19 @@ static PyType_Spec tree_type_spec = {
.slots = tree_type_slots,
};

// Allocates a new Python Tree object around `tree`.
//
// state              module state providing the Tree type object.
// tree               tree-sitter tree stored in the new object.
// source_or_callback object the tree was parsed from (byte buffer or read callable).
// keep_text          when non-zero, a reference to source_or_callback is stored on the
//                    object; otherwise Py_None is stored instead.
//
// Returns a new reference, or NULL (with the exception set by tp_alloc) when allocation
// fails. NOTE(review): on that failure path `tree` is not deleted here -- confirm whether
// the callers expect to keep ownership of it.
static PyObject *tree_new_internal(ModuleState *state, TSTree *tree, PyObject *source_or_callback,
                                   int keep_text) {
    Tree *self = (Tree *)state->tree_type->tp_alloc(state->tree_type, 0);
    if (self == NULL) {
        // Previously execution fell through and dereferenced the NULL pointer below.
        return NULL;
    }

    self->tree = tree;
    // A NULL source is stored as Py_None so the INCREF below is always valid.
    self->source_or_callback =
        (keep_text && source_or_callback != NULL) ? source_or_callback : Py_None;
    Py_INCREF(self->source_or_callback);
    return (PyObject *)self;
}

Expand Down Expand Up @@ -1516,9 +1652,6 @@ static PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) {
new_tree = ts_parser_parse(self->parser, old_tree, input);
Py_XDECREF(payload.previous_return_value);

// don't allow tree_new_internal to keep the source text
source_or_callback = Py_None;
keep_text = 0;
} else {
PyErr_SetString(PyExc_TypeError, "First argument byte buffer type or callable");
return NULL;
Expand Down Expand Up @@ -1852,7 +1985,7 @@ static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tr
ModuleState *state = PyType_GetModuleState(Py_TYPE(query));
PyObject *pattern_text_predicates = PyList_GetItem(query->text_predicates, match.pattern_index);
// if there is no source, ignore the text predicates
if (tree->source == Py_None || tree->source == NULL) {
if (tree->source_or_callback == Py_None || tree->source_or_callback == NULL) {
return true;
}

Expand Down