Skip to content

Commit

Permalink
Make predicate w. missing capture return true
Browse files Browse the repository at this point in the history
If a predicate in a query references a capture that was not matched
because it was optional in the query, return true rather than throw an
error. In other words, if the capture was not found, ignore the
predicate entirely.

This matches the behavior of the tree-sitter CLI and the playground. It
is not a common case, but we ran into it when matching Kotlin functions,
where the return type of a function is optional in the grammar, which required
us to do additional filtering in the app code after the query.
  • Loading branch information
jhandley authored and amaanq committed Nov 12, 2023
1 parent 882ece6 commit a022d9e
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 35 deletions.
49 changes: 49 additions & 0 deletions tests/test_tree_sitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1432,6 +1432,55 @@ def test_text_predicates(self):
self.assertEqual(2, len(captures_notext))
self.assertSetEqual({"function-name"}, set([c[1] for c in captures_notext]))

def test_text_predicate_on_optional_capture(self):
    """Check that predicates on an unmatched optional capture are ignored.

    The parsed source ``fun1(1)`` contains no string argument, so
    ``@optional-string-arg`` is never captured. For each predicate form
    below, the query is expected to yield exactly one capture: the
    ``fun1`` identifier captured as ``function-name``.
    """
    js_parser = Parser()
    js_parser.set_language(JAVASCRIPT)
    tree = js_parser.parse(b"fun1(1)")
    root = tree.root_node

    # Shared query skeleton; only the predicate differs between cases.
    query_template = (
        "\n"
        "((call_expression\n"
        "function: (identifier) @function-name\n"
        "arguments: (arguments (string)? @optional-string-arg)\n"
        "({predicate})))\n"
    )
    predicates = (
        # missing optional capture used in #eq? @capture string
        '#eq? @optional-string-arg "1"',
        # missing optional capture used in #eq? @capture @capture
        "#eq? @optional-string-arg @function-name",
        # missing optional capture used in #match? @capture string
        '#match? @optional-string-arg "\\d+"',
    )

    for predicate in predicates:
        query = JAVASCRIPT.query(query_template.format(predicate=predicate))
        found = query.captures(root)
        self.assertEqual(1, len(found))
        first = found[0]
        self.assertEqual(b"fun1", first[0].text)
        self.assertEqual("function-name", first[1])

def test_text_predicates_errors(self):
parser = Parser()
parser.set_language(JAVASCRIPT)
Expand Down
72 changes: 37 additions & 35 deletions tree_sitter/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -1853,7 +1853,6 @@ static Node *node_for_capture_index(ModuleState *state, uint32_t index, TSQueryM
return capture_node;
}
}
PyErr_SetString(PyExc_ValueError, "An error occurred, capture was not found with given index");
return NULL;
}

Expand All @@ -1879,36 +1878,38 @@ static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tr
node1 = node_for_capture_index(state, capture1_value_id, match, tree);
node2 = node_for_capture_index(state, capture2_value_id, match, tree);
if (node1 == NULL || node2 == NULL) {
goto error;
}
node1_text = node_get_text(node1, NULL);
node2_text = node_get_text(node2, NULL);
if (node1_text == NULL || node2_text == NULL) {
goto error;
is_satisfied = true;
} else {
node1_text = node_get_text(node1, NULL);
node2_text = node_get_text(node2, NULL);
if (node1_text == NULL || node2_text == NULL) {
goto error;
}
is_satisfied = PyObject_RichCompareBool(node1_text, node2_text, Py_EQ) ==
((CaptureEqCapture *)text_predicate)->is_positive;
Py_XDECREF(node1);
Py_XDECREF(node2);
Py_XDECREF(node1_text);
Py_XDECREF(node2_text);
}
Py_XDECREF(node1);
Py_XDECREF(node2);
is_satisfied = PyObject_RichCompareBool(node1_text, node2_text, Py_EQ) ==
((CaptureEqCapture *)text_predicate)->is_positive;
Py_XDECREF(node1_text);
Py_XDECREF(node2_text);
if (!is_satisfied) {
return false;
}
} else if (capture_eq_string_is_instance(text_predicate)) {
uint32_t capture_value_id = ((CaptureEqString *)text_predicate)->capture_value_id;
node1 = node_for_capture_index(state, capture_value_id, match, tree);
if (node1 == NULL) {
goto error;
}
node1_text = node_get_text(node1, NULL);
if (node1_text == NULL) {
goto error;
is_satisfied = true;
} else {
node1_text = node_get_text(node1, NULL);
if (node1_text == NULL) {
goto error;
}
PyObject *string_value = ((CaptureEqString *)text_predicate)->string_value;
is_satisfied = PyObject_RichCompareBool(node1_text, string_value, Py_EQ) ==
((CaptureEqString *)text_predicate)->is_positive;
}
Py_XDECREF(node1);
PyObject *string_value = ((CaptureEqString *)text_predicate)->string_value;
is_satisfied = PyObject_RichCompareBool(node1_text, string_value, Py_EQ) ==
((CaptureEqString *)text_predicate)->is_positive;
Py_XDECREF(node1_text);
if (!is_satisfied) {
return false;
Expand All @@ -1917,22 +1918,23 @@ static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tr
uint32_t capture_value_id = ((CaptureMatchString *)text_predicate)->capture_value_id;
node1 = node_for_capture_index(state, capture_value_id, match, tree);
if (node1 == NULL) {
goto error;
}
node1_text = node_get_text(node1, NULL);
if (node1_text == NULL) {
goto error;
is_satisfied = true;
} else {
node1_text = node_get_text(node1, NULL);
if (node1_text == NULL) {
goto error;
}
PyObject *search_result =
PyObject_CallMethod(((CaptureMatchString *)text_predicate)->regex, "search",
"s", PyBytes_AsString(node1_text));
Py_XDECREF(node1_text);
is_satisfied = (search_result != NULL && search_result != Py_None) ==
((CaptureMatchString *)text_predicate)->is_positive;
if (search_result != NULL) {
Py_DECREF(search_result);
}
}
Py_XDECREF(node1);
PyObject *search_result =
PyObject_CallMethod(((CaptureMatchString *)text_predicate)->regex, "search", "s",
PyBytes_AsString(node1_text));
Py_XDECREF(node1_text);
is_satisfied = (search_result != NULL && search_result != Py_None) ==
((CaptureMatchString *)text_predicate)->is_positive;
if (search_result != NULL) {
Py_DECREF(search_result);
}
if (!is_satisfied) {
return false;
}
Expand Down

0 comments on commit a022d9e

Please sign in to comment.