From ebd543c0ac9b8a5f17636d0a42c425e5f693860e Mon Sep 17 00:00:00 2001 From: Jelle Zijlstra Date: Mon, 11 Dec 2023 21:37:15 -0800 Subject: [PATCH] Fix feature detection for parenthesized context managers (#4104) --- CHANGES.md | 1 + src/black/__init__.py | 18 ++- tests/data/cases/pep_572_remove_parens.py | 2 +- tests/test_black.py | 130 ++++++++++++---------- 4 files changed, 93 insertions(+), 58 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index dcf6613b70c..e3b5b7392b9 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -10,6 +10,7 @@ - Fix bug where `# fmt: off` automatically dedents when used with the `--line-ranges` option, even when it is not within the specified line range. (#4084) +- Fix feature detection for parenthesized context managers (#4104) ### Preview style diff --git a/src/black/__init__.py b/src/black/__init__.py index 5073fa748d5..735ba713b8f 100644 --- a/src/black/__init__.py +++ b/src/black/__init__.py @@ -1351,7 +1351,7 @@ def get_features_used( # noqa: C901 if ( len(atom_children) == 3 and atom_children[0].type == token.LPAR - and atom_children[1].type == syms.testlist_gexp + and _contains_asexpr(atom_children[1]) and atom_children[2].type == token.RPAR ): features.add(Feature.PARENTHESIZED_CONTEXT_MANAGERS) @@ -1384,6 +1384,22 @@ def get_features_used( # noqa: C901 return features +def _contains_asexpr(node: Union[Node, Leaf]) -> bool: + """Return True if `node` contains an as-pattern.""" + if node.type == syms.asexpr_test: + return True + elif node.type == syms.atom: + if ( + len(node.children) == 3 + and node.children[0].type == token.LPAR + and node.children[2].type == token.RPAR + ): + return _contains_asexpr(node.children[1]) + elif node.type == syms.testlist_gexp: + return any(_contains_asexpr(child) for child in node.children) + return False + + def detect_target_versions( node: Node, *, future_imports: Optional[Set[str]] = None ) -> Set[TargetVersion]: diff --git a/tests/data/cases/pep_572_remove_parens.py b/tests/data/cases/pep_572_remove_parens.py index 88774d81649..24f1ac29168 100644 --- a/tests/data/cases/pep_572_remove_parens.py +++ b/tests/data/cases/pep_572_remove_parens.py @@ -1,4 +1,4 @@ -# flags: --minimum-version=3.8 --no-preview-line-length-1 +# flags: --minimum-version=3.8 if (foo := 0): pass diff --git a/tests/test_black.py b/tests/test_black.py index 23815da9042..0af5fd2a1f4 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -25,6 +25,7 @@ List, Optional, Sequence, + Set, Type, TypeVar, Union, @@ -874,71 +875,88 @@ def test_get_features_used_decorator(self) -> None: ) def test_get_features_used(self) -> None: - node = black.lib2to3_parse("def f(*, arg): ...\n") - self.assertEqual(black.get_features_used(node), set()) - node = black.lib2to3_parse("def f(*, arg,): ...\n") - self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF}) - node = black.lib2to3_parse("f(*arg,)\n") - self.assertEqual( - black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL} + self.check_features_used("def f(*, arg): ...\n", set()) + self.check_features_used( + "def f(*, arg,): ...\n", {Feature.TRAILING_COMMA_IN_DEF} ) - node = black.lib2to3_parse("def f(*, arg): f'string'\n") - self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS}) - node = black.lib2to3_parse("123_456\n") - self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES}) - node = black.lib2to3_parse("123456\n") - self.assertEqual(black.get_features_used(node), set()) + self.check_features_used("f(*arg,)\n", {Feature.TRAILING_COMMA_IN_CALL}) + self.check_features_used("def f(*, arg): f'string'\n", {Feature.F_STRINGS}) + self.check_features_used("123_456\n", {Feature.NUMERIC_UNDERSCORES}) + self.check_features_used("123456\n", set()) + source, expected = read_data("cases", "function") - node = black.lib2to3_parse(source) expected_features = { Feature.TRAILING_COMMA_IN_CALL, Feature.TRAILING_COMMA_IN_DEF, Feature.F_STRINGS, } - self.assertEqual(black.get_features_used(node), expected_features) - node = black.lib2to3_parse(expected) - self.assertEqual(black.get_features_used(node), expected_features) + self.check_features_used(source, expected_features) + self.check_features_used(expected, expected_features) + source, expected = read_data("cases", "expression") - node = black.lib2to3_parse(source) - self.assertEqual(black.get_features_used(node), set()) - node = black.lib2to3_parse(expected) - self.assertEqual(black.get_features_used(node), set()) - node = black.lib2to3_parse("lambda a, /, b: ...") - self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS}) - node = black.lib2to3_parse("def fn(a, /, b): ...") - self.assertEqual(black.get_features_used(node), {Feature.POS_ONLY_ARGUMENTS}) - node = black.lib2to3_parse("def fn(): yield a, b") - self.assertEqual(black.get_features_used(node), set()) - node = black.lib2to3_parse("def fn(): return a, b") - self.assertEqual(black.get_features_used(node), set()) - node = black.lib2to3_parse("def fn(): yield *b, c") - self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW}) - node = black.lib2to3_parse("def fn(): return a, *b, c") - self.assertEqual(black.get_features_used(node), {Feature.UNPACKING_ON_FLOW}) - node = black.lib2to3_parse("x = a, *b, c") - self.assertEqual(black.get_features_used(node), set()) - node = black.lib2to3_parse("x: Any = regular") - self.assertEqual(black.get_features_used(node), set()) - node = black.lib2to3_parse("x: Any = (regular, regular)") - self.assertEqual(black.get_features_used(node), set()) - node = black.lib2to3_parse("x: Any = Complex(Type(1))[something]") - self.assertEqual(black.get_features_used(node), set()) - node = black.lib2to3_parse("x: Tuple[int, ...] = a, b, c") - self.assertEqual( - black.get_features_used(node), {Feature.ANN_ASSIGN_EXTENDED_RHS} + self.check_features_used(source, set()) + self.check_features_used(expected, set()) + + self.check_features_used("lambda a, /, b: ...\n", {Feature.POS_ONLY_ARGUMENTS}) + self.check_features_used("def fn(a, /, b): ...", {Feature.POS_ONLY_ARGUMENTS}) + + self.check_features_used("def fn(): yield a, b", set()) + self.check_features_used("def fn(): return a, b", set()) + self.check_features_used("def fn(): yield *b, c", {Feature.UNPACKING_ON_FLOW}) + self.check_features_used( + "def fn(): return a, *b, c", {Feature.UNPACKING_ON_FLOW} ) - node = black.lib2to3_parse("try: pass\nexcept Something: pass") - self.assertEqual(black.get_features_used(node), set()) - node = black.lib2to3_parse("try: pass\nexcept (*Something,): pass") - self.assertEqual(black.get_features_used(node), set()) - node = black.lib2to3_parse("try: pass\nexcept *Group: pass") - self.assertEqual(black.get_features_used(node), {Feature.EXCEPT_STAR}) - node = black.lib2to3_parse("a[*b]") - self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS}) - node = black.lib2to3_parse("a[x, *y(), z] = t") - self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS}) - node = black.lib2to3_parse("def fn(*args: *T): pass") - self.assertEqual(black.get_features_used(node), {Feature.VARIADIC_GENERICS}) + self.check_features_used("x = a, *b, c", set()) + + self.check_features_used("x: Any = regular", set()) + self.check_features_used("x: Any = (regular, regular)", set()) + self.check_features_used("x: Any = Complex(Type(1))[something]", set()) + self.check_features_used( + "x: Tuple[int, ...] = a, b, c", {Feature.ANN_ASSIGN_EXTENDED_RHS} + ) + + self.check_features_used("try: pass\nexcept Something: pass", set()) + self.check_features_used("try: pass\nexcept (*Something,): pass", set()) + self.check_features_used( + "try: pass\nexcept *Group: pass", {Feature.EXCEPT_STAR} + ) + + self.check_features_used("a[*b]", {Feature.VARIADIC_GENERICS}) + self.check_features_used("a[x, *y(), z] = t", {Feature.VARIADIC_GENERICS}) + self.check_features_used("def fn(*args: *T): pass", {Feature.VARIADIC_GENERICS}) + + self.check_features_used("with a: pass", set()) + self.check_features_used("with a, b: pass", set()) + self.check_features_used("with a as b: pass", set()) + self.check_features_used("with a as b, c as d: pass", set()) + self.check_features_used("with (a): pass", set()) + self.check_features_used("with (a, b): pass", set()) + self.check_features_used("with (a, b) as (c, d): pass", set()) + self.check_features_used( + "with (a as b): pass", {Feature.PARENTHESIZED_CONTEXT_MANAGERS} + ) + self.check_features_used( + "with ((a as b)): pass", {Feature.PARENTHESIZED_CONTEXT_MANAGERS} + ) + self.check_features_used( + "with (a, b as c): pass", {Feature.PARENTHESIZED_CONTEXT_MANAGERS} + ) + self.check_features_used( + "with (a, (b as c)): pass", {Feature.PARENTHESIZED_CONTEXT_MANAGERS} + ) + self.check_features_used( + "with ((a, ((b as c)))): pass", {Feature.PARENTHESIZED_CONTEXT_MANAGERS} + ) + + def check_features_used(self, source: str, expected: Set[Feature]) -> None: + node = black.lib2to3_parse(source) + actual = black.get_features_used(node) + msg = f"Expected {expected} but got {actual} for {source!r}" + try: + self.assertEqual(actual, expected, msg=msg) + except AssertionError: + DebugVisitor.show(node) + raise def test_get_features_used_for_future_flags(self) -> None: for src, features in [