diff --git a/tree_sitter/binding.c b/tree_sitter/binding.c index 8a0ebd4..ff9cd03 100644 --- a/tree_sitter/binding.c +++ b/tree_sitter/binding.c @@ -15,7 +15,7 @@ typedef struct { typedef struct { PyObject_HEAD TSTree *tree; - PyObject *source; + PyObject *source_or_callback; } Tree; typedef struct { @@ -650,46 +650,114 @@ static PyObject *node_get_text(Node *self, void *payload) { PyErr_SetString(PyExc_ValueError, "No tree"); return NULL; } - if (tree->source == Py_None || tree->source == NULL) { + if (tree->source_or_callback == Py_None || tree->source_or_callback == NULL) { Py_RETURN_NONE; } - PyObject *start_byte = PyLong_FromSize_t((size_t)ts_node_start_byte(self->node)); - if (start_byte == NULL) { - PyErr_SetString(PyExc_RuntimeError, "Failed to determine start byte"); - return NULL; - } - PyObject *end_byte = PyLong_FromSize_t((size_t)ts_node_end_byte(self->node)); - if (end_byte == NULL) { - Py_DECREF(start_byte); - PyErr_SetString(PyExc_RuntimeError, "Failed to determine end byte"); - return NULL; - } - PyObject *slice = PySlice_New(start_byte, end_byte, NULL); - Py_DECREF(start_byte); - Py_DECREF(end_byte); - if (slice == NULL) { - PyErr_SetString(PyExc_RuntimeError, "PySlice_New failed"); - return NULL; - } - PyObject *node_mv = PyMemoryView_FromObject(tree->source); - if (node_mv == NULL) { + size_t start_offset = (size_t)ts_node_start_byte(self->node); + size_t end_offset = (size_t)ts_node_end_byte(self->node); + PyObject *result = NULL; + + // Case 1: source_or_callback is a byte buffer + if (!PyCallable_Check(tree->source_or_callback)) { + PyObject *start_byte = PyLong_FromSize_t(start_offset); + PyObject *end_byte = PyLong_FromSize_t(end_offset); + PyObject *slice = PySlice_New(start_byte, end_byte, NULL); + Py_XDECREF(start_byte); + Py_XDECREF(end_byte); + if (slice == NULL) { + PyErr_SetString(PyExc_RuntimeError, "PySlice_New failed"); + return NULL; + } + PyObject *node_mv = PyMemoryView_FromObject(tree->source_or_callback); + if (node_mv == NULL) { + Py_DECREF(slice); + PyErr_SetString(PyExc_RuntimeError, "PyMemoryView_FromObject failed"); + return NULL; + } + PyObject *node_slice = PyObject_GetItem(node_mv, slice); Py_DECREF(slice); - PyErr_SetString(PyExc_RuntimeError, "PyMemoryView_FromObject failed"); - return NULL; - } - PyObject *node_slice = PyObject_GetItem(node_mv, slice); - Py_DECREF(slice); - Py_DECREF(node_mv); - if (node_slice == NULL) { - PyErr_SetString(PyExc_RuntimeError, "PyObject_GetItem failed"); - return NULL; + Py_DECREF(node_mv); + if (node_slice == NULL) { + PyErr_SetString(PyExc_RuntimeError, "PyObject_GetItem failed"); + return NULL; + } + result = PyBytes_FromObject(node_slice); + Py_DECREF(node_slice); + } + // Case 2: source_or_callback is a callable + else { + PyObject *collected_bytes = PyBytes_FromString(""); + if (collected_bytes == NULL) { + PyErr_SetString(PyExc_RuntimeError, "Initialization failed"); + return NULL; + } + + TSPoint start_point = ts_node_start_point(self->node); + TSPoint end_point = ts_node_end_point(self->node); + + TSPoint current_point = start_point; + size_t current_offset = start_offset; + + while (current_offset < end_offset) { + PyObject *byte_offset_obj = PyLong_FromSize_t(current_offset); + PyObject *point_obj = point_new(current_point); + if (!point_obj) { + Py_XDECREF(collected_bytes); + PyErr_SetString(PyExc_RuntimeError, "Failed to create point object"); + return NULL; + } + + PyObject *args = PyTuple_Pack(2, byte_offset_obj, point_obj); + Py_XDECREF(byte_offset_obj); + Py_XDECREF(point_obj); + + PyObject *rv = PyObject_Call(tree->source_or_callback, args, NULL); + Py_XDECREF(args); + + if (rv == NULL || rv == Py_None || !PyBytes_Check(rv)) { + Py_XDECREF(rv); + Py_XDECREF(collected_bytes); + PyErr_SetString(PyExc_RuntimeError, "Callback execution failed or returned invalid type"); + return NULL; + } + + PyBytes_Concat(&collected_bytes, rv); + Py_XDECREF(rv); + + if (collected_bytes == NULL) { + PyErr_SetString(PyExc_RuntimeError, "Byte concatenation failed"); + return NULL; + } + + // Update current_point and current_offset + size_t bytes_read = PyBytes_Size(rv); + for (size_t i = 0; i < bytes_read; i++) { + if (PyBytes_AsString(rv)[i] == '\n') { + current_point.row++; + current_point.column = 0; + } else { + current_point.column++; + } + } + current_offset += bytes_read; + } + + PyObject *start_byte = PyLong_FromSize_t(0); + PyObject *end_byte = PyLong_FromSize_t(end_offset - start_offset); + PyObject *slice = PySlice_New(start_byte, end_byte, NULL); + Py_XDECREF(start_byte); + Py_XDECREF(end_byte); + + result = PyObject_GetItem(collected_bytes, slice); + Py_DECREF(slice); + Py_XDECREF(collected_bytes); } - PyObject *result = PyBytes_FromObject(node_slice); - Py_DECREF(node_slice); + return result; } + static PyMethodDef node_methods[] = { { .ml_name = "walk", @@ -871,7 +939,7 @@ static bool node_is_instance(ModuleState *state, PyObject *self) { static void tree_dealloc(Tree *self) { ts_tree_delete(self->tree); - Py_XDECREF(self->source); + Py_XDECREF(self->source_or_callback); Py_TYPE(self)->tp_free((PyObject *)self); } @@ -881,14 +949,82 @@ static PyObject *tree_get_root_node(Tree *self, void *payload) { } static PyObject *tree_get_text(Tree *self, void *payload) { - PyObject *source = self->source; - if (source == NULL) { + PyObject *source_or_callback = self->source_or_callback; + if (source_or_callback == NULL || source_or_callback == Py_None) { Py_RETURN_NONE; } - Py_INCREF(source); - return source; + + // If source_or_callback is a byte buffer, return it directly + if (!PyCallable_Check(source_or_callback)) { + Py_INCREF(source_or_callback); + return source_or_callback; + } + // If source_or_callback is a callable, call it to get the full text + else { + PyObject *collected_bytes = PyBytes_FromString(""); + if (collected_bytes == NULL) { + PyErr_SetString(PyExc_RuntimeError, "Initialization failed"); + return NULL; + } + + size_t current_offset = 0; + TSPoint current_point = {0, 0}; // Initialize to the start of the file + + while (true) { // Continue reading until the callable returns None or an empty bytes object + PyObject *byte_offset_obj = PyLong_FromSize_t(current_offset); + PyObject *point_obj = point_new(current_point); + if (!point_obj) { + Py_XDECREF(collected_bytes); + PyErr_SetString(PyExc_RuntimeError, "Failed to create point object"); + return NULL; + } + + PyObject *args = PyTuple_Pack(2, byte_offset_obj, point_obj); + Py_XDECREF(byte_offset_obj); + Py_XDECREF(point_obj); + + PyObject *rv = PyObject_Call(source_or_callback, args, NULL); + Py_XDECREF(args); + + if (rv == NULL || rv == Py_None || !PyBytes_Check(rv)) { + Py_XDECREF(rv); + Py_XDECREF(collected_bytes); + PyErr_SetString(PyExc_RuntimeError, "Callback execution failed or returned invalid type"); + return NULL; + } + + if (PyBytes_Size(rv) == 0) { + Py_XDECREF(rv); + break; // Stop reading if an empty bytes object is returned + } + + PyBytes_Concat(&collected_bytes, rv); + Py_XDECREF(rv); + + if (collected_bytes == NULL) { + PyErr_SetString(PyExc_RuntimeError, "Byte concatenation failed"); + return NULL; + } + + // Update current_point based on the returned bytes + size_t bytes_read = PyBytes_Size(rv); + for (size_t i = 0; i < bytes_read; i++) { + if (PyBytes_AsString(rv)[i] == '\n') { + current_point.row++; + current_point.column = 0; + } else { + current_point.column++; + } + } + + current_offset += bytes_read; + } + + return collected_bytes; + } } + static PyObject *tree_root_node_with_offset(Tree *self, PyObject *args) { ModuleState *state = PyType_GetModuleState(Py_TYPE(self)); @@ -933,9 +1069,9 @@ static PyObject *tree_edit(Tree *self, PyObject *args, PyObject *kwargs) { .new_end_point = {new_end_row, new_end_column}, }; ts_tree_edit(self->tree, &edit); - Py_XDECREF(self->source); - self->source = Py_None; - Py_INCREF(self->source); + Py_XDECREF(self->source_or_callback); + self->source_or_callback = Py_None; + Py_INCREF(self->source_or_callback); } Py_RETURN_NONE; } @@ -1045,7 +1181,7 @@ static PyType_Spec tree_type_spec = { .slots = tree_type_slots, }; -static PyObject *tree_new_internal(ModuleState *state, TSTree *tree, PyObject *source, +static PyObject *tree_new_internal(ModuleState *state, TSTree *tree, PyObject *source_or_callback, int keep_text) { Tree *self = (Tree *)state->tree_type->tp_alloc(state->tree_type, 0); if (self != NULL) { @@ -1053,11 +1189,11 @@ static PyObject *tree_new_internal(ModuleState *state, TSTree *tree, PyObject *s } if (keep_text) { - self->source = source; + self->source_or_callback = source_or_callback; } else { - self->source = Py_None; + self->source_or_callback = Py_None; } - Py_INCREF(self->source); + Py_INCREF(self->source_or_callback); return (PyObject *)self; } @@ -1516,9 +1652,6 @@ static PyObject *parser_parse(Parser *self, PyObject *args, PyObject *kwargs) { new_tree = ts_parser_parse(self->parser, old_tree, input); Py_XDECREF(payload.previous_return_value); - // don't allow tree_new_internal to keep the source text - source_or_callback = Py_None; - keep_text = 0; } else { PyErr_SetString(PyExc_TypeError, "First argument byte buffer type or callable"); return NULL; @@ -1852,7 +1985,7 @@ static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tr ModuleState *state = PyType_GetModuleState(Py_TYPE(query)); PyObject *pattern_text_predicates = PyList_GetItem(query->text_predicates, match.pattern_index); // if there is no source, ignore the text predicates - if (tree->source == Py_None || tree->source == NULL) { + if (tree->source_or_callback == Py_None || tree->source_or_callback == NULL) { return true; }