From a022d9e4474201df1a5dd2a3f456f99676436ebb Mon Sep 17 00:00:00 2001 From: Josh Handley Date: Sat, 7 Oct 2023 12:16:14 -0400 Subject: [PATCH] Make predicate w. missing capture return true If a predicate in a query references a capture that was not matched because it was optional in the query, return true rather than throw an error. In other words, if the capture was not found, ignore the predicate entirely. This matches the behavior of tree-sitter CLI and the playground. It is not a common case but we have run into the issue matching Kotlin functions where the return type of a function is optional in the grammar requiring us to do additional filtering in the app code after the query. --- tests/test_tree_sitter.py | 49 ++++++++++++++++++++++++++ tree_sitter/binding.c | 72 ++++++++++++++++++++------------------- 2 files changed, 86 insertions(+), 35 deletions(-) diff --git a/tests/test_tree_sitter.py b/tests/test_tree_sitter.py index d98fe3f..51b9356 100644 --- a/tests/test_tree_sitter.py +++ b/tests/test_tree_sitter.py @@ -1432,6 +1432,55 @@ def test_text_predicates(self): self.assertEqual(2, len(captures_notext)) self.assertSetEqual({"function-name"}, set([c[1] for c in captures_notext])) + def test_text_predicate_on_optional_capture(self): + parser = Parser() + parser.set_language(JAVASCRIPT) + source = b"fun1(1)" + tree = parser.parse(source) + root_node = tree.root_node + + # optional capture that is missing in source used in #eq? @capture string + query1 = JAVASCRIPT.query( + """ + ((call_expression + function: (identifier) @function-name + arguments: (arguments (string)? @optional-string-arg) + (#eq? @optional-string-arg "1"))) + """ + ) + captures1 = query1.captures(root_node) + self.assertEqual(1, len(captures1)) + self.assertEqual(b"fun1", captures1[0][0].text) + self.assertEqual("function-name", captures1[0][1]) + + # optional capture that is missing in source used in #eq? @capture @capture + query2 = JAVASCRIPT.query( + """ + ((call_expression + function: (identifier) @function-name + arguments: (arguments (string)? @optional-string-arg) + (#eq? @optional-string-arg @function-name))) + """ + ) + captures2 = query2.captures(root_node) + self.assertEqual(1, len(captures2)) + self.assertEqual(b"fun1", captures2[0][0].text) + self.assertEqual("function-name", captures2[0][1]) + + # optional capture that is missing in source used in #match? @capture string + query3 = JAVASCRIPT.query( + """ + ((call_expression + function: (identifier) @function-name + arguments: (arguments (string)? @optional-string-arg) + (#match? @optional-string-arg "\\d+"))) + """ + ) + captures3 = query3.captures(root_node) + self.assertEqual(1, len(captures3)) + self.assertEqual(b"fun1", captures3[0][0].text) + self.assertEqual("function-name", captures3[0][1]) + def test_text_predicates_errors(self): parser = Parser() parser.set_language(JAVASCRIPT) diff --git a/tree_sitter/binding.c b/tree_sitter/binding.c index c8011be..94f8ff6 100644 --- a/tree_sitter/binding.c +++ b/tree_sitter/binding.c @@ -1853,7 +1853,6 @@ static Node *node_for_capture_index(ModuleState *state, uint32_t index, TSQueryM return capture_node; } } - PyErr_SetString(PyExc_ValueError, "An error occurred, capture was not found with given index"); return NULL; } @@ -1879,19 +1878,20 @@ static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tr node1 = node_for_capture_index(state, capture1_value_id, match, tree); node2 = node_for_capture_index(state, capture2_value_id, match, tree); if (node1 == NULL || node2 == NULL) { - goto error; - } - node1_text = node_get_text(node1, NULL); - node2_text = node_get_text(node2, NULL); - if (node1_text == NULL || node2_text == NULL) { - goto error; + is_satisfied = true; + } else { + node1_text = node_get_text(node1, NULL); + node2_text = node_get_text(node2, NULL); + if (node1_text == NULL || node2_text == NULL) { + goto error; + } + is_satisfied = PyObject_RichCompareBool(node1_text, node2_text, Py_EQ) == + ((CaptureEqCapture *)text_predicate)->is_positive; + Py_XDECREF(node1); + Py_XDECREF(node2); + Py_XDECREF(node1_text); + Py_XDECREF(node2_text); } - Py_XDECREF(node1); - Py_XDECREF(node2); - is_satisfied = PyObject_RichCompareBool(node1_text, node2_text, Py_EQ) == - ((CaptureEqCapture *)text_predicate)->is_positive; - Py_XDECREF(node1_text); - Py_XDECREF(node2_text); if (!is_satisfied) { return false; } @@ -1899,16 +1899,17 @@ static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tr uint32_t capture_value_id = ((CaptureEqString *)text_predicate)->capture_value_id; node1 = node_for_capture_index(state, capture_value_id, match, tree); if (node1 == NULL) { - goto error; - } - node1_text = node_get_text(node1, NULL); - if (node1_text == NULL) { - goto error; + is_satisfied = true; + } else { + node1_text = node_get_text(node1, NULL); + if (node1_text == NULL) { + goto error; + } + PyObject *string_value = ((CaptureEqString *)text_predicate)->string_value; + is_satisfied = PyObject_RichCompareBool(node1_text, string_value, Py_EQ) == + ((CaptureEqString *)text_predicate)->is_positive; } Py_XDECREF(node1); - PyObject *string_value = ((CaptureEqString *)text_predicate)->string_value; - is_satisfied = PyObject_RichCompareBool(node1_text, string_value, Py_EQ) == - ((CaptureEqString *)text_predicate)->is_positive; Py_XDECREF(node1_text); if (!is_satisfied) { return false; @@ -1917,22 +1918,23 @@ static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tr uint32_t capture_value_id = ((CaptureMatchString *)text_predicate)->capture_value_id; node1 = node_for_capture_index(state, capture_value_id, match, tree); if (node1 == NULL) { - goto error; - } - node1_text = node_get_text(node1, NULL); - if (node1_text == NULL) { - goto error; + is_satisfied = true; + } else { + node1_text = node_get_text(node1, NULL); + if (node1_text == NULL) { + goto error; + } + PyObject *search_result = + PyObject_CallMethod(((CaptureMatchString *)text_predicate)->regex, "search", + "s", PyBytes_AsString(node1_text)); + Py_XDECREF(node1_text); + is_satisfied = (search_result != NULL && search_result != Py_None) == + ((CaptureMatchString *)text_predicate)->is_positive; + if (search_result != NULL) { + Py_DECREF(search_result); + } } Py_XDECREF(node1); - PyObject *search_result = - PyObject_CallMethod(((CaptureMatchString *)text_predicate)->regex, "search", "s", - PyBytes_AsString(node1_text)); - Py_XDECREF(node1_text); - is_satisfied = (search_result != NULL && search_result != Py_None) == - ((CaptureMatchString *)text_predicate)->is_positive; - if (search_result != NULL) { - Py_DECREF(search_result); - } if (!is_satisfied) { return false; }