diff --git a/README.md b/README.md index e0df789..6e95096 100644 --- a/README.md +++ b/README.md @@ -119,27 +119,27 @@ def foo(): ) -def read_callable(byte_offset, point): +def read_callable_byte_offset(byte_offset, point): return src[byte_offset : byte_offset + 1] -tree = parser.parse(read_callable) +tree = parser.parse(read_callable_byte_offset) ``` And to use the point: ```python -src_lines = ["def foo():\n", " if bar:\n", " baz()"] +src_lines = ["\n", "def foo():\n", " if bar:\n", " baz()\n"] -def read_callable(byte_offset, point): +def read_callable_point(byte_offset, point): row, column = point if row >= len(src_lines) or column >= len(src_lines[row]): return None return src_lines[row][column:].encode("utf8") -tree = parser.parse(read_callable) +tree = parser.parse(read_callable_point) ``` Inspect the resulting `Tree`: @@ -148,7 +148,7 @@ Inspect the resulting `Tree`: root_node = tree.root_node assert root_node.type == 'module' assert root_node.start_point == (1, 0) -assert root_node.end_point == (3, 13) +assert root_node.end_point == (4, 0) function_node = root_node.children[0] assert function_node.type == 'function_definition' @@ -159,17 +159,34 @@ assert function_name_node.type == 'identifier' assert function_name_node.start_point == (1, 4) assert function_name_node.end_point == (1, 7) -assert root_node.sexp() == "(module " - "(function_definition " - "name: (identifier) " - "parameters: (parameters) " - "body: (block " - "(if_statement " - "condition: (identifier) " - "consequence: (block " - "(expression_statement (call " - "function: (identifier) " - "arguments: (argument_list))))))))" +function_body_node = function_node.child_by_field_name("body") + +if_statement_node = function_body_node.child(0) +assert if_statement_node.type == "if_statement" + +function_call_node = if_statement_node.child_by_field_name("consequence").child(0).child(0) +assert function_call_node.type == "call" + +function_call_name_node = function_call_node.child_by_field_name("function") +assert function_call_name_node.type == "identifier" + +function_call_args_node = function_call_node.child_by_field_name("arguments") +assert function_call_args_node.type == "argument_list" + + +assert root_node.sexp() == ( + "(module " + "(function_definition " + "name: (identifier) " + "parameters: (parameters) " + "body: (block " + "(if_statement " + "condition: (identifier) " + "consequence: (block " + "(expression_statement (call " + "function: (identifier) " + "arguments: (argument_list))))))))" +) ``` ### Walking syntax trees @@ -209,7 +226,7 @@ When a source file is edited, you can edit the syntax tree to keep it in sync wi the source: ```python -new_src = src[:5] + src[5:5 + 2].upper() + src[5 + 2:] +new_src = src[:5] + src[5 : 5 + 2].upper() + src[5 + 2 :] tree.edit( start_byte=5, @@ -250,13 +267,19 @@ You can search for patterns in a syntax tree using a [tree query]: query = PY_LANGUAGE.query( """ (function_definition - name: (identifier) @function.def) + name: (identifier) @function.def + body: (block) @function.block) (call - function: (identifier) @function.call) + function: (identifier) @function.call + arguments: (argument_list) @function.args) """ ) +``` + +#### Captures +```python captures = query.captures(tree.root_node) assert len(captures) == 2 assert captures[0][0] == function_name_node @@ -269,9 +292,32 @@ query's range. Only one of the `..._byte` or `..._point` pairs need to be given to restrict the range. If all are omitted, the entire range of the passed node is used. +#### Matches + +```python +matches = query.matches(tree.root_node) +assert len(matches) == 2 + +# first match +assert matches[0][1]["function.def"] == function_name_node +assert matches[0][1]["function.block"] == function_body_node + +# second match +assert matches[1][1]["function.call"] == function_call_name_node +assert matches[1][1]["function.args"] == function_call_args_node +``` + +The `Query.matches()` method takes the same optional arguments as `Query.captures()`. +The difference between the two methods is that `Query.matches()` groups captures into matches, +which is much more useful when your captures within a query relate to each other. It maps the +capture's name to the node that was captured via a dictionary. + +To try out and explore the code referenced in this README, check out [examples/usage.py]. + [tree-sitter]: https://tree-sitter.github.io/tree-sitter/ [issue]: https://github.com/tree-sitter/py-tree-sitter/issues/new [tree-sitter-python]: https://github.com/tree-sitter/tree-sitter-python [tree query]: https://tree-sitter.github.io/tree-sitter/using-parsers#query-syntax [ci]: https://github.com/tree-sitter/py-tree-sitter/actions/workflows/ci.yml [examples/walk_tree.py]: https://github.com/tree-sitter/py-tree-sitter/blob/master/examples/walk_tree.py +[examples/usage.py]: https://github.com/tree-sitter/py-tree-sitter/blob/master/examples/usage.py diff --git a/examples/usage.py b/examples/usage.py new file mode 100644 index 0000000..55156b2 --- /dev/null +++ b/examples/usage.py @@ -0,0 +1,178 @@ +from tree_sitter import Language, Parser +import tree_sitter_python + +PY_LANGUAGE = Language(tree_sitter_python.language(), "python") + +parser = Parser() +parser.set_language(PY_LANGUAGE) + +# parsing a string of code +tree = parser.parse( + bytes( + """ +def foo(): + if bar: + baz() +""", + "utf8", + ) +) + +# parsing a callable by using the byte offset +src = bytes( + """ +def foo(): + if bar: + baz() +""", + "utf8", +) + + +def read_callable_byte_offset(byte_offset, point): + return src[byte_offset : byte_offset + 1] + + +tree = parser.parse(read_callable_byte_offset) + + +# parsing a callable by using the point +src_lines = ["\n", "def foo():\n", " if bar:\n", " baz()\n"] + + +def read_callable_point(byte_offset, point): + row, column = point + if row >= len(src_lines) or column >= len(src_lines[row]): + return None + return src_lines[row][column:].encode("utf8") + + +tree = parser.parse(read_callable_point) + +# inspecting nodes in the tree +root_node = tree.root_node +assert root_node.type == "module" +assert root_node.start_point == (1, 0) +assert root_node.end_point == (4, 0) + +function_node = root_node.child(0) +assert function_node.type == "function_definition" +assert function_node.child_by_field_name("name").type == "identifier" + +function_name_node = function_node.child(1) +assert function_name_node.type == "identifier" +assert function_name_node.start_point == (1, 4) +assert function_name_node.end_point == (1, 7) + +function_body_node = function_node.child_by_field_name("body") + +if_statement_node = function_body_node.child(0) +assert if_statement_node.type == "if_statement" + +function_call_node = if_statement_node.child_by_field_name("consequence").child(0).child(0) +assert function_call_node.type == "call" + +function_call_name_node = function_call_node.child_by_field_name("function") +assert function_call_name_node.type == "identifier" + +function_call_args_node = function_call_node.child_by_field_name("arguments") +assert function_call_args_node.type == "argument_list" + + +# getting the sexp representation of the tree +assert root_node.sexp() == ( + "(module " + "(function_definition " + "name: (identifier) " + "parameters: (parameters) " + "body: (block " + "(if_statement " + "condition: (identifier) " + "consequence: (block " + "(expression_statement (call " + "function: (identifier) " + "arguments: (argument_list))))))))" +) + +# walking the tree +cursor = tree.walk() + +assert cursor.node.type == "module" + +assert cursor.goto_first_child() +assert cursor.node.type == "function_definition" + +assert cursor.goto_first_child() +assert cursor.node.type == "def" + +# Returns `False` because the `def` node has no children +assert not cursor.goto_first_child() + +assert cursor.goto_next_sibling() +assert cursor.node.type == "identifier" + +assert cursor.goto_next_sibling() +assert cursor.node.type == "parameters" + +assert cursor.goto_parent() +assert cursor.node.type == "function_definition" + +# editing the tree +new_src = src[:5] + src[5 : 5 + 2].upper() + src[5 + 2 :] + +tree.edit( + start_byte=5, + old_end_byte=5, + new_end_byte=5 + 2, + start_point=(0, 5), + old_end_point=(0, 5), + new_end_point=(0, 5 + 2), +) + +new_tree = parser.parse(new_src, tree) + +# inspecting the changes +for changed_range in tree.changed_ranges(new_tree): + print("Changed range:") + print(f" Start point {changed_range.start_point}") + print(f" Start byte {changed_range.start_byte}") + print(f" End point {changed_range.end_point}") + print(f" End byte {changed_range.end_byte}") + + +# querying the tree +query = PY_LANGUAGE.query( + """ +(function_definition + name: (identifier) @function.def + body: (block) @function.block) + +(call + function: (identifier) @function.call + arguments: (argument_list) @function.args) +""" +) + +# ...with captures +captures = query.captures(tree.root_node) +assert len(captures) == 4 +assert captures[0][0] == function_name_node +assert captures[0][1] == "function.def" +assert captures[1][0] == function_body_node +assert captures[1][1] == "function.block" +assert captures[2][0] == function_call_name_node +assert captures[2][1] == "function.call" +assert captures[3][0] == function_call_args_node +assert captures[3][1] == "function.args" + +# ...with matches +matches = query.matches(tree.root_node) +assert len(matches) == 2 + +# first match +assert matches[0][1]["function.def"] == function_name_node +assert matches[0][1]["function.block"] == function_body_node + +# second match +assert matches[1][1]["function.call"] == function_call_name_node +assert matches[1][1]["function.args"] == function_call_args_node