Skip to content

Commit

Permalink
docs: improve examples and add usage file
Browse files Browse the repository at this point in the history
  • Loading branch information
amaanq committed Mar 5, 2024
1 parent 4d2d35a commit 4396e05
Show file tree
Hide file tree
Showing 2 changed files with 244 additions and 20 deletions.
86 changes: 66 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,27 +119,27 @@ def foo():
)


def read_callable(byte_offset, point):
def read_callable_byte_offset(byte_offset, point):
return src[byte_offset : byte_offset + 1]


tree = parser.parse(read_callable)
tree = parser.parse(read_callable_byte_offset)
```

And to use the point:

```python
src_lines = ["def foo():\n", " if bar:\n", " baz()"]
src_lines = ["\n", "def foo():\n", " if bar:\n", " baz()\n"]


def read_callable(byte_offset, point):
def read_callable_point(byte_offset, point):
row, column = point
if row >= len(src_lines) or column >= len(src_lines[row]):
return None
return src_lines[row][column:].encode("utf8")


tree = parser.parse(read_callable)
tree = parser.parse(read_callable_point)
```

Inspect the resulting `Tree`:
Expand All @@ -148,7 +148,7 @@ Inspect the resulting `Tree`:
root_node = tree.root_node
assert root_node.type == 'module'
assert root_node.start_point == (1, 0)
assert root_node.end_point == (3, 13)
assert root_node.end_point == (4, 0)

function_node = root_node.children[0]
assert function_node.type == 'function_definition'
Expand All @@ -159,17 +159,34 @@ assert function_name_node.type == 'identifier'
assert function_name_node.start_point == (1, 4)
assert function_name_node.end_point == (1, 7)

assert root_node.sexp() == "(module "
"(function_definition "
"name: (identifier) "
"parameters: (parameters) "
"body: (block "
"(if_statement "
"condition: (identifier) "
"consequence: (block "
"(expression_statement (call "
"function: (identifier) "
"arguments: (argument_list))))))))"
function_body_node = function_node.child_by_field_name("body")

if_statement_node = function_body_node.child(0)
assert if_statement_node.type == "if_statement"

function_call_node = if_statement_node.child_by_field_name("consequence").child(0).child(0)
assert function_call_node.type == "call"

function_call_name_node = function_call_node.child_by_field_name("function")
assert function_call_name_node.type == "identifier"

function_call_args_node = function_call_node.child_by_field_name("arguments")
assert function_call_args_node.type == "argument_list"


assert root_node.sexp() == (
"(module "
"(function_definition "
"name: (identifier) "
"parameters: (parameters) "
"body: (block "
"(if_statement "
"condition: (identifier) "
"consequence: (block "
"(expression_statement (call "
"function: (identifier) "
"arguments: (argument_list))))))))"
)
```

### Walking syntax trees
Expand Down Expand Up @@ -209,7 +226,7 @@ When a source file is edited, you can edit the syntax tree to keep it in sync wi
the source:

```python
new_src = src[:5] + src[5:5 + 2].upper() + src[5 + 2:]
new_src = src[:5] + src[5 : 5 + 2].upper() + src[5 + 2 :]

tree.edit(
start_byte=5,
Expand Down Expand Up @@ -250,13 +267,19 @@ You can search for patterns in a syntax tree using a [tree query]:
query = PY_LANGUAGE.query(
"""
(function_definition
name: (identifier) @function.def)
name: (identifier) @function.def
body: (block) @function.block)
(call
function: (identifier) @function.call)
function: (identifier) @function.call
arguments: (argument_list) @function.args)
"""
)
```

#### Captures

```python
captures = query.captures(tree.root_node)
assert len(captures) == 2
assert captures[0][0] == function_name_node
Expand All @@ -269,9 +292,32 @@ query's range. Only one of the `..._byte` or `..._point` pairs need to be given
to restrict the range. If all are omitted, the entire range of the passed node
is used.

#### Matches

```python
matches = query.matches(tree.root_node)
assert len(matches) == 2

# first match
assert matches[0][1]["function.def"] == function_name_node
assert matches[0][1]["function.block"] == function_body_node

# second match
assert matches[1][1]["function.call"] == function_call_name_node
assert matches[1][1]["function.args"] == function_call_args_node
```

The `Query.matches()` method takes the same optional arguments as `Query.captures()`.
The difference between the two methods is that `Query.matches()` groups captures into matches,
which is much more useful when your captures within a query relate to each other. It maps the
capture's name to the node that was captured via a dictionary.

To try out and explore the code referenced in this README, check out [examples/usage.py].

[tree-sitter]: https://tree-sitter.github.io/tree-sitter/
[issue]: https://github.com/tree-sitter/py-tree-sitter/issues/new
[tree-sitter-python]: https://github.com/tree-sitter/tree-sitter-python
[tree query]: https://tree-sitter.github.io/tree-sitter/using-parsers#query-syntax
[ci]: https://github.com/tree-sitter/py-tree-sitter/actions/workflows/ci.yml
[examples/walk_tree.py]: https://github.com/tree-sitter/py-tree-sitter/blob/master/examples/walk_tree.py
[examples/usage.py]: https://github.com/tree-sitter/py-tree-sitter/blob/master/examples/usage.py
178 changes: 178 additions & 0 deletions examples/usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from tree_sitter import Language, Parser
import tree_sitter_python

PY_LANGUAGE = Language(tree_sitter_python.language(), "python")

parser = Parser()
parser.set_language(PY_LANGUAGE)

# parsing a string of code
tree = parser.parse(
bytes(
"""
def foo():
if bar:
baz()
""",
"utf8",
)
)

# parsing a callable by using the byte offset
src = bytes(
"""
def foo():
if bar:
baz()
""",
"utf8",
)


def read_callable_byte_offset(byte_offset, point):
return src[byte_offset : byte_offset + 1]


tree = parser.parse(read_callable_byte_offset)


# parsing a callable by using the point
src_lines = ["\n", "def foo():\n", " if bar:\n", " baz()\n"]


def read_callable_point(byte_offset, point):
row, column = point
if row >= len(src_lines) or column >= len(src_lines[row]):
return None
return src_lines[row][column:].encode("utf8")


tree = parser.parse(read_callable_point)

# inspecting nodes in the tree
root_node = tree.root_node
assert root_node.type == "module"
assert root_node.start_point == (1, 0)
assert root_node.end_point == (4, 0)

function_node = root_node.child(0)
assert function_node.type == "function_definition"
assert function_node.child_by_field_name("name").type == "identifier"

function_name_node = function_node.child(1)
assert function_name_node.type == "identifier"
assert function_name_node.start_point == (1, 4)
assert function_name_node.end_point == (1, 7)

function_body_node = function_node.child_by_field_name("body")

if_statement_node = function_body_node.child(0)
assert if_statement_node.type == "if_statement"

function_call_node = if_statement_node.child_by_field_name("consequence").child(0).child(0)
assert function_call_node.type == "call"

function_call_name_node = function_call_node.child_by_field_name("function")
assert function_call_name_node.type == "identifier"

function_call_args_node = function_call_node.child_by_field_name("arguments")
assert function_call_args_node.type == "argument_list"


# getting the sexp representation of the tree
assert root_node.sexp() == (
"(module "
"(function_definition "
"name: (identifier) "
"parameters: (parameters) "
"body: (block "
"(if_statement "
"condition: (identifier) "
"consequence: (block "
"(expression_statement (call "
"function: (identifier) "
"arguments: (argument_list))))))))"
)

# walking the tree
cursor = tree.walk()

assert cursor.node.type == "module"

assert cursor.goto_first_child()
assert cursor.node.type == "function_definition"

assert cursor.goto_first_child()
assert cursor.node.type == "def"

# Returns `False` because the `def` node has no children
assert not cursor.goto_first_child()

assert cursor.goto_next_sibling()
assert cursor.node.type == "identifier"

assert cursor.goto_next_sibling()
assert cursor.node.type == "parameters"

assert cursor.goto_parent()
assert cursor.node.type == "function_definition"

# editing the tree
new_src = src[:5] + src[5 : 5 + 2].upper() + src[5 + 2 :]

tree.edit(
start_byte=5,
old_end_byte=5,
new_end_byte=5 + 2,
start_point=(0, 5),
old_end_point=(0, 5),
new_end_point=(0, 5 + 2),
)

new_tree = parser.parse(new_src, tree)

# inspecting the changes
for changed_range in tree.changed_ranges(new_tree):
print("Changed range:")
print(f" Start point {changed_range.start_point}")
print(f" Start byte {changed_range.start_byte}")
print(f" End point {changed_range.end_point}")
print(f" End byte {changed_range.end_byte}")


# querying the tree
query = PY_LANGUAGE.query(
"""
(function_definition
name: (identifier) @function.def
body: (block) @function.block)
(call
function: (identifier) @function.call
arguments: (argument_list) @function.args)
"""
)

# ...with captures
captures = query.captures(tree.root_node)
assert len(captures) == 4
assert captures[0][0] == function_name_node
assert captures[0][1] == "function.def"
assert captures[1][0] == function_body_node
assert captures[1][1] == "function.block"
assert captures[2][0] == function_call_name_node
assert captures[2][1] == "function.call"
assert captures[3][0] == function_call_args_node
assert captures[3][1] == "function.args"

# ...with matches
matches = query.matches(tree.root_node)
assert len(matches) == 2

# first match
assert matches[0][1]["function.def"] == function_name_node
assert matches[0][1]["function.block"] == function_body_node

# second match
assert matches[1][1]["function.call"] == function_call_name_node
assert matches[1][1]["function.args"] == function_call_args_node

0 comments on commit 4396e05

Please sign in to comment.