Skip to content

Commit

Permalink
feat: make match captures a dictionary (#165)
Browse files Browse the repository at this point in the history
Return the captures of a match as a dict where the key is the capture
name and the value is a Node or a list of Nodes. For regular captures,
the value is a Node. For captures that can contain multiple nodes
because they use a * or + quantifier, the value is a list of Nodes.
  • Loading branch information
jhandley authored Feb 26, 2024
1 parent 9c86022 commit c1d1126
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 15 deletions.
47 changes: 42 additions & 5 deletions tests/test_tree_sitter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from os import path
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
from unittest import TestCase

from tree_sitter import Language, LookaheadIterator, Node, Parser, Query, Range, Tree
Expand Down Expand Up @@ -1260,22 +1260,29 @@ def test_errors(self):

def collect_matches(
self,
matches: List[Tuple[int, List[Tuple[Node, str]]]],
matches: List[Tuple[int, Dict[str, Union[Node, List[Node]]]]],
) -> List[Tuple[int, List[Tuple[str, str]]]]:
return [(m[0], self.format_captures(m[1])) for m in matches]

def format_captures(
self,
captures: List[Tuple[Node, str]],
captures: Dict[str, Union[Node, List[Node]]],
) -> List[Tuple[str, str]]:
return [(capture[1], capture[0].text.decode("utf-8")) for capture in captures]
return [(name, self.format_capture(capture)) for name, capture in captures.items()]

def format_capture(self, capture: Union[Node, List[Node]]) -> str:
    """Render one capture value as text so it can be compared in assertions.

    Args:
        capture: Either a single node, or a list of nodes (the form used for
            captures that carry a ``*`` or ``+`` quantifier).

    Returns:
        For a single node, its ``text`` decoded as UTF-8. For a list, the
        decoded texts wrapped in single quotes, joined with ``", "`` and
        enclosed in square brackets, e.g. ``['x = 1;', 'y = 2;']``; an
        empty list yields ``[]``.
    """
    # Check against the builtin ``list`` rather than the ``typing.List``
    # alias: the alias is meant for annotations, not runtime type checks.
    if isinstance(capture, list):
        quoted = ("'" + node.text.decode("utf-8") + "'" for node in capture)
        return "[" + ", ".join(quoted) + "]"
    return capture.text.decode("utf-8")

def assert_query_matches(
self,
language: Language,
query: Query,
source: bytes,
expected: List[Tuple[int, List[Tuple[str, str]]]],
expected: List[Tuple[int, Dict[str, str]]]
):
parser = Parser()
parser.set_language(language)
Expand Down Expand Up @@ -1367,6 +1374,36 @@ def test_matches_with_nesting_and_no_fields(self):
],
)

def test_matches_with_list_capture(self):
    """A ``*``-quantified capture is reported as a list of nodes per match.

    Two function declarations are parsed; each match is expected to hold
    the function name as a single capture and the body statements as a
    list capture.
    """
    query = JAVASCRIPT.query(
        """(function_declaration name: (identifier) @fn-name
body: (statement_block (_)* @fn-statements)
)"""
    )
    source = b"""function one() {
x = 1;
y = 2;
z = 3;
}
function two() {
x = 1;
}
"""
    # One entry per function declaration, in source order; both come from
    # the query's only pattern (index 0).
    match_for_one = (
        0,
        [
            ("fn-name", "one"),
            ("fn-statements", "['x = 1;', 'y = 2;', 'z = 3;']"),
        ],
    )
    match_for_two = (0, [("fn-name", "two"), ("fn-statements", "['x = 1;']")])

    self.assert_query_matches(
        JAVASCRIPT,
        query,
        source,
        [match_for_one, match_for_two],
    )

def test_captures(self):
parser = Parser()
parser.set_language(PYTHON)
Expand Down
4 changes: 2 additions & 2 deletions tree_sitter/_binding.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Callable, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import tree_sitter

Expand Down Expand Up @@ -349,7 +349,7 @@ class Query:
end_point: Optional[Tuple[int, int]] = None,
start_byte: Optional[int] = None,
end_byte: Optional[int] = None,
) -> List[Tuple[int, List[Tuple[Node, str]]]]:
) -> List[Tuple[int, Dict[str, Union[Node, List[Node]]]]]:
"""Get a list of all of the matches within the given node."""

def captures(
Expand Down
27 changes: 19 additions & 8 deletions tree_sitter/binding.c
Original file line number Diff line number Diff line change
Expand Up @@ -1989,6 +1989,14 @@ static bool satisfies_text_predicates(Query *query, TSQueryMatch match, Tree *tr
return false;
}

/*
 * Tell whether the capture at position `capture_index` inside `match` should
 * be exposed as a list of nodes.
 *
 * The quantifier declared for that capture in the match's pattern is looked
 * up on `query`; the result is true only for the zero-or-more and
 * one-or-more quantifiers, and false for every other quantifier.
 */
static bool is_list_capture(TSQuery *query, TSQueryMatch *match, unsigned int capture_index) {
    switch (ts_query_capture_quantifier_for_id(query, match->pattern_index,
                                               match->captures[capture_index].index)) {
        case TSQuantifierZeroOrMore:
        case TSQuantifierOneOrMore:
            return true;
        default:
            return false;
    }
}

static PyObject *query_matches(Query *self, PyObject *args, PyObject *kwargs) {
ModuleState *state = PyType_GetModuleState(Py_TYPE(self));
char *keywords[] = {
Expand Down Expand Up @@ -2028,7 +2036,7 @@ static PyObject *query_matches(Query *self, PyObject *args, PyObject *kwargs) {
if (match == NULL) {
goto error;
}
PyObject *captures_for_match = PyList_New(0);
PyObject *captures_for_match = PyDict_New();
if (captures_for_match == NULL) {
goto error;
}
Expand All @@ -2045,15 +2053,18 @@ static PyObject *query_matches(Query *self, PyObject *args, PyObject *kwargs) {
PyList_GetItem(self->capture_names, capture->capture.index);
PyObject *capture_node =
node_new_internal(state, capture->capture.node, node->tree);
PyObject *item = PyTuple_Pack(2, capture_node, capture_name);
if (item == NULL) {
Py_XDECREF(captures_for_match);
Py_XDECREF(capture_node);
goto error;

if (is_list_capture(self->query, &_match, i)) {
PyObject *defult_new_capture_list = PyList_New(0);
PyObject *capture_list = PyDict_SetDefault(captures_for_match, capture_name, defult_new_capture_list);
Py_INCREF(capture_list);
Py_DECREF(defult_new_capture_list);
PyList_Append(capture_list, capture_node);
Py_DECREF(capture_list);
} else {
PyDict_SetItem(captures_for_match, capture_name, capture_node);
}
Py_XDECREF(capture_node);
PyList_Append(captures_for_match, item);
Py_XDECREF(item);
}
Py_XDECREF(capture);
}
Expand Down

0 comments on commit c1d1126

Please sign in to comment.