From af74857dc89299e88cd7fdeee2a3bde48b096a67 Mon Sep 17 00:00:00 2001 From: Matej Aleksandrov Date: Mon, 14 Oct 2024 07:26:35 -0700 Subject: [PATCH] Rebase Pyink to Black v24.10.0. PiperOrigin-RevId: 685699118 --- .github/workflows/test.yml | 2 +- patches/pyink.patch | 299 +++++++++--------- pyproject.toml | 24 +- src/pyink/__init__.py | 133 ++++---- src/pyink/_width_table.py | 4 +- src/pyink/brackets.py | 14 +- src/pyink/cache.py | 17 +- src/pyink/comments.py | 14 +- src/pyink/concurrency.py | 6 +- src/pyink/debug.py | 4 +- src/pyink/files.py | 31 +- src/pyink/handle_ipynb_magics.py | 65 +++- src/pyink/ink.py | 24 +- src/pyink/ink_adjusted_lines.py | 8 +- src/pyink/linegen.py | 124 ++++---- src/pyink/lines.py | 45 +-- src/pyink/mode.py | 22 +- src/pyink/nodes.py | 32 +- src/pyink/output.py | 4 +- src/pyink/parsing.py | 26 +- src/pyink/ranges.py | 30 +- src/pyink/resources/pyink.schema.json | 3 +- src/pyink/schema.py | 7 +- src/pyink/strings.py | 8 +- src/pyink/trans.py | 76 +++-- .../funcdef_return_type_trailing_comma.py | 1 + tests/data/cases/function_trailing_comma.py | 130 ++++++++ ...ew_pep646_typed_star_arg_type_var_tuple.py | 8 + tests/test_black.py | 67 +++- tests/test_format.py | 4 +- 30 files changed, 713 insertions(+), 519 deletions(-) create mode 100644 tests/data/cases/preview_pep646_typed_star_arg_type_var_tuple.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 91f1682394f..5c9047649e7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,7 +19,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] os: [ubuntu-latest, macOS-latest, windows-latest] steps: diff --git a/patches/pyink.patch b/patches/pyink.patch index 2445a78ef06..f3fd4ab923c 100644 --- a/patches/pyink.patch +++ b/patches/pyink.patch @@ -25,7 +25,7 @@ from typing import ( Any, Collection, -@@ -28,12 +28,13 @@ from typing import ( +@@ -24,12 +24,13 @@ from typing import ( Union, ) @@ -40,7 +40,7 @@ from pyink._pyink_version import version as __version__ from pyink.cache import Cache from pyink.comments import normalize_fmt_off -@@ -66,9 +67,15 @@ from pyink.handle_ipynb_magics import ( +@@ -62,9 +63,15 @@ from pyink.handle_ipynb_magics import ( ) from pyink.linegen import LN, LineGenerator, transform_line from pyink.lines import EmptyLineTracker, LinesBlock @@ -58,7 +58,7 @@ from pyink.nodes import STARS, is_number_token, is_simple_decorator_expression, syms from pyink.output import color_diff, diff, dump_to_file, err, ipynb_diff, out from pyink.parsing import ( # noqa F401 -@@ -84,9 +91,8 @@ from pyink.ranges import ( +@@ -80,9 +87,8 @@ from pyink.ranges import ( parse_line_ranges, sanitized_lines, ) @@ -69,7 +69,7 @@ COMPILED = Path(__file__).suffix in (".pyd", ".so") -@@ -264,25 +270,26 @@ def validate_regex( +@@ -260,25 +266,26 @@ def validate_regex( multiple=True, help=( "Python versions that should be supported by Black's output. You should" @@ -103,7 +103,7 @@ ), ) @click.option( -@@ -317,17 +324,17 @@ def validate_regex( +@@ -313,17 +320,17 @@ def validate_regex( "--preview", is_flag=True, help=( @@ -126,7 +126,7 @@ ), ) @click.option( -@@ -342,20 +349,68 @@ def validate_regex( +@@ -338,20 +345,68 @@ def validate_regex( ), ) @click.option( @@ -200,7 +200,7 @@ ), ) @click.option( -@@ -368,11 +423,11 @@ def validate_regex( +@@ -364,11 +419,11 @@ def validate_regex( multiple=True, metavar="START-END", help=( @@ -217,7 +217,7 @@ ), default=(), ) -@@ -380,9 +435,9 @@ def validate_regex( +@@ -376,9 +431,9 @@ def validate_regex( "--fast/--safe", is_flag=True, help=( @@ -230,7 +230,7 @@ ), ) @click.option( -@@ -392,8 +447,8 @@ def validate_regex( +@@ -388,8 +443,8 @@ def validate_regex( "Require a specific version of Black to be running. This is useful for" " ensuring that all contributors to your project are using the same" " version, because different versions of Black may format code a little" @@ -241,7 +241,7 @@ ), ) @click.option( -@@ -401,11 +456,12 @@ def validate_regex( +@@ -397,11 +452,12 @@ def validate_regex( type=str, callback=validate_regex, help=( @@ -259,7 +259,7 @@ ), show_default=False, ) -@@ -414,8 +470,8 @@ def validate_regex( +@@ -410,8 +466,8 @@ def validate_regex( type=str, callback=validate_regex, help=( @@ -270,7 +270,7 @@ ), ) @click.option( -@@ -423,10 +479,10 @@ def validate_regex( +@@ -419,10 +475,10 @@ def validate_regex( type=str, callback=validate_regex, help=( @@ -285,7 +285,7 @@ ), ) @click.option( -@@ -434,9 +490,9 @@ def validate_regex( +@@ -430,9 +486,9 @@ def validate_regex( type=str, is_eager=True, help=( @@ -298,7 +298,7 @@ ), ) @click.option( -@@ -446,10 +502,10 @@ def validate_regex( +@@ -442,10 +498,10 @@ def validate_regex( callback=validate_regex, help=( "A regular expression that matches files and directories that should be" @@ -313,7 +313,7 @@ ), show_default=True, ) -@@ -459,10 +515,10 @@ def validate_regex( +@@ -455,10 +511,10 @@ def validate_regex( type=click.IntRange(min=1), default=None, help=( @@ -328,7 +328,7 @@ ), ) @click.option( -@@ -470,8 +526,8 @@ def validate_regex( +@@ -466,8 +522,8 @@ def validate_regex( "--quiet", is_flag=True, help=( @@ -339,7 +339,7 @@ ), ) @click.option( -@@ -487,15 +543,20 @@ def validate_regex( +@@ -483,15 +539,20 @@ def validate_regex( @click.version_option( version=__version__, message=( @@ -363,19 +363,19 @@ ), is_eager=True, metavar="SRC ...", -@@ -534,6 +595,11 @@ def main( # noqa: C901 +@@ -530,6 +591,11 @@ def main( # noqa: C901 preview: bool, unstable: bool, - enable_unstable_feature: List[Preview], + enable_unstable_feature: list[Preview], + pyink: bool, + pyink_indentation: str, + pyink_ipynb_indentation: str, -+ pyink_annotation_pragmas: List[str], ++ pyink_annotation_pragmas: list[str], + pyink_use_majority_quotes: bool, quiet: bool, verbose: bool, required_version: Optional[str], -@@ -631,7 +697,15 @@ def main( # noqa: C901 +@@ -636,7 +702,15 @@ def main( # noqa: C901 preview=preview, unstable=unstable, python_cell_magics=set(python_cell_magics), @@ -391,22 +391,8 @@ + ), ) - lines: List[Tuple[int, int]] = [] -@@ -1098,9 +1172,10 @@ def validate_cell(src: str, mode: Mode) - """ - if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS): - raise NothingChanged -- if ( -- src[:2] == "%%" -- and src.split()[0][2:] not in PYTHON_CELL_MAGICS | mode.python_cell_magics -+ line = ink.get_code_start(src) -+ if line.startswith("%%") and ( -+ line.split(maxsplit=1)[0][2:] -+ not in PYTHON_CELL_MAGICS | mode.python_cell_magics - ): - raise NothingChanged - -@@ -1153,6 +1228,17 @@ def validate_metadata(nb: MutableMapping + lines: list[tuple[int, int]] = [] +@@ -1132,6 +1206,17 @@ def validate_metadata(nb: MutableMapping if language is not None and language != "python": raise NothingChanged from None @@ -424,7 +410,7 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileContent: """Format Jupyter notebook. -@@ -1164,7 +1250,6 @@ def format_ipynb_string(src_contents: st +@@ -1143,7 +1228,6 @@ def format_ipynb_string(src_contents: st raise NothingChanged trailing_newline = src_contents[-1] == "\n" @@ -432,7 +418,7 @@ nb = json.loads(src_contents) validate_metadata(nb) for cell in nb["cells"]: -@@ -1176,14 +1261,15 @@ def format_ipynb_string(src_contents: st +@@ -1155,14 +1239,15 @@ def format_ipynb_string(src_contents: st pass else: cell["source"] = dst.splitlines(keepends=True) @@ -456,7 +442,7 @@ def format_str( -@@ -1244,6 +1330,8 @@ def _format_str_once( +@@ -1223,6 +1308,8 @@ def _format_str_once( future_imports = get_future_imports(src_node) versions = detect_target_versions(src_node, future_imports=future_imports) @@ -469,14 +455,14 @@ +++ b/_width_table.py @@ -3,7 +3,7 @@ # Unicode 15.0.0 - from typing import Final, List, Tuple + from typing import Final --WIDTH_TABLE: Final[List[Tuple[int, int, int]]] = [ -+WIDTH_TABLE: Final[Tuple[Tuple[int, int, int], ...]] = ( +-WIDTH_TABLE: Final[list[tuple[int, int, int]]] = [ ++WIDTH_TABLE: Final[tuple[tuple[int, int, int], ...]] = ( (0, 0, 0), (1, 31, -1), (127, 159, -1), -@@ -475,4 +475,4 @@ WIDTH_TABLE: Final[List[Tuple[int, int, +@@ -475,4 +475,4 @@ WIDTH_TABLE: Final[list[tuple[int, int, (131072, 196605, 2), (196608, 262141, 2), (917760, 917999, 0), @@ -486,7 +472,7 @@ +++ b/comments.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from functools import lru_cache - from typing import Collection, Final, Iterator, List, Optional, Tuple, Union + from typing import Collection, Final, Iterator, Optional, Union +from pyink import ink_comments from pyink.mode import Mode, Preview @@ -496,8 +482,8 @@ return False --def contains_pragma_comment(comment_list: List[Leaf]) -> bool: -+def contains_pragma_comment(comment_list: List[Leaf], mode: Mode) -> bool: +-def contains_pragma_comment(comment_list: list[Leaf]) -> bool: ++def contains_pragma_comment(comment_list: list[Leaf], mode: Mode) -> bool: """ Returns: True iff one of the comments in @comment_list is a pragma used by one @@ -512,7 +498,7 @@ return False --- a/files.py +++ b/files.py -@@ -231,7 +231,7 @@ def strip_specifier_set(specifier_set: S +@@ -228,7 +228,7 @@ def strip_specifier_set(specifier_set: S def find_user_pyproject_toml() -> Path: r"""Return the path to the top-level user configuration for pyink. @@ -523,7 +509,7 @@ May raise: --- a/handle_ipynb_magics.py +++ b/handle_ipynb_magics.py -@@ -148,12 +148,14 @@ def mask_cell(src: str) -> Tuple[str, Li +@@ -178,12 +178,14 @@ def mask_cell(src: str) -> tuple[str, li from IPython.core.inputtransformer2 import TransformerManager transformer_manager = TransformerManager() @@ -539,7 +525,7 @@ # Multi-line magic, not supported. raise NothingChanged replacements += magic_replacements -@@ -239,7 +241,7 @@ def replace_magics(src: str) -> Tuple[st +@@ -269,7 +271,7 @@ def replace_magics(src: str) -> tuple[st magic_finder = MagicFinder() magic_finder.visit(ast.parse(src)) new_srcs = [] @@ -548,7 +534,7 @@ if i in magic_finder.magics: offsets_and_magics = magic_finder.magics[i] if len(offsets_and_magics) != 1: # pragma: nocover -@@ -273,6 +275,8 @@ def unmask_cell(src: str, replacements: +@@ -303,6 +305,8 @@ def unmask_cell(src: str, replacements: """ for replacement in replacements: src = src.replace(replacement.mask, replacement.src) @@ -561,7 +547,7 @@ +++ b/linegen.py @@ -9,6 +9,12 @@ from enum import Enum, auto from functools import partial, wraps - from typing import Collection, Iterator, List, Optional, Set, Union, cast + from typing import Collection, Iterator, Optional, Union, cast +if sys.version_info < (3, 8): + from typing_extensions import Final, Literal @@ -691,7 +677,7 @@ + yield from self.line(_DEDENT) def visit_stmt( - self, node: Node, keywords: Set[str], parens: Set[str] + self, node: Node, keywords: set[str], parens: set[str] @@ -245,6 +272,7 @@ class LineGenerator(Visitor[Line]): maybe_make_parens_invisible_in_atom( child, @@ -855,7 +841,7 @@ + ll, sn, preferred_quote=preferred_quote, line_str=line_str + ) - transformers: List[Transformer] + transformers: list[Transformer] if ( - not line.contains_uncollapsable_type_comments() + not line.contains_uncollapsable_pragma_comments() @@ -870,16 +856,16 @@ Note: this function should not have side effects. It's relied upon by _maybe_split_omitting_optional_parens to get an opinion whether to prefer splitting on the right side of an assignment statement. -@@ -1098,7 +1161,7 @@ def bracket_split_build_line( +@@ -1139,7 +1202,7 @@ def bracket_split_build_line( result = Line(mode=original.mode, depth=original.depth) if component is _BracketSplitComponent.body: result.inside_brackets = True - result.depth += 1 + result.depth = result.depth + (Indentation.CONTINUATION,) - if leaves: - no_commas = ( - # Ensure a trailing comma for imports and standalone function arguments -@@ -1392,15 +1455,17 @@ def normalize_invisible_parens( # noqa: + if _ensure_trailing_comma(leaves, original, opening_bracket): + for i in range(len(leaves) - 1, -1, -1): + if leaves[i].type == STANDALONE_COMMENT: +@@ -1408,15 +1471,17 @@ def normalize_invisible_parens( # noqa: if maybe_make_parens_invisible_in_atom( child, parent=node, @@ -898,7 +884,7 @@ ): wrap_in_parentheses(node, child, visible=False) elif is_one_tuple(child): -@@ -1452,7 +1517,7 @@ def _normalize_import_from(parent: Node, +@@ -1468,7 +1533,7 @@ def _normalize_import_from(parent: Node, parent.append_child(Leaf(token.RPAR, "")) @@ -907,7 +893,7 @@ if node.children[0].type == token.AWAIT and len(node.children) > 1: if ( node.children[1].type == syms.atom -@@ -1461,6 +1526,7 @@ def remove_await_parens(node: Node) -> N +@@ -1477,6 +1542,7 @@ def remove_await_parens(node: Node) -> N if maybe_make_parens_invisible_in_atom( node.children[1], parent=node, @@ -915,7 +901,7 @@ remove_brackets_around_comma=True, ): wrap_in_parentheses(node, node.children[1], visible=False) -@@ -1529,7 +1595,7 @@ def _maybe_wrap_cms_in_parens( +@@ -1545,7 +1611,7 @@ def _maybe_wrap_cms_in_parens( node.insert_child(1, new_child) @@ -924,7 +910,7 @@ """Recursively hide optional parens in `with` statements.""" # Removing all unnecessary parentheses in with statements in one pass is a tad # complex as different variations of bracketed statements result in pretty -@@ -1551,21 +1617,23 @@ def remove_with_parens(node: Node, paren +@@ -1567,21 +1633,23 @@ def remove_with_parens(node: Node, paren if maybe_make_parens_invisible_in_atom( node, parent=parent, @@ -950,7 +936,7 @@ remove_brackets_around_comma=True, ): wrap_in_parentheses(node, node.children[0], visible=False) -@@ -1574,6 +1642,7 @@ def remove_with_parens(node: Node, paren +@@ -1590,6 +1658,7 @@ def remove_with_parens(node: Node, paren def maybe_make_parens_invisible_in_atom( node: LN, parent: LN, @@ -958,7 +944,7 @@ remove_brackets_around_comma: bool = False, ) -> bool: """If it's safe, make the parens in the atom `node` invisible, recursively. -@@ -1623,7 +1692,7 @@ def maybe_make_parens_invisible_in_atom( +@@ -1639,7 +1708,7 @@ def maybe_make_parens_invisible_in_atom( if ( # If the prefix of `middle` includes a type comment with # ignore annotation, then we do not remove the parentheses @@ -967,7 +953,7 @@ ): first.value = "" if first.prefix.strip(): -@@ -1633,6 +1702,7 @@ def maybe_make_parens_invisible_in_atom( +@@ -1649,6 +1718,7 @@ def maybe_make_parens_invisible_in_atom( maybe_make_parens_invisible_in_atom( middle, parent=parent, @@ -975,7 +961,7 @@ remove_brackets_around_comma=remove_brackets_around_comma, ) -@@ -1691,7 +1761,7 @@ def generate_trailers_to_omit(line: Line +@@ -1707,7 +1777,7 @@ def generate_trailers_to_omit(line: Line if not line.magic_trailing_comma: yield omit @@ -983,8 +969,8 @@ + length = line.indentation_spaces() opening_bracket: Optional[Leaf] = None closing_bracket: Optional[Leaf] = None - inner_brackets: Set[LeafID] = set() -@@ -1776,7 +1846,7 @@ def run_transformer( + inner_brackets: set[LeafID] = set() +@@ -1792,7 +1862,7 @@ def run_transformer( or not line.bracket_tracker.invisible or any(bracket.value for bracket in line.bracket_tracker.invisible) or line.contains_multiline_strings() @@ -1000,7 +986,7 @@ import itertools import math from dataclasses import dataclass, field -@@ -28,7 +29,7 @@ from pyink.nodes import ( +@@ -17,7 +18,7 @@ from pyink.nodes import ( is_multiline_string, is_one_sequence_between, is_type_comment, @@ -1009,7 +995,7 @@ is_with_or_async_with_stmt, make_simple_prefix, replace_child, -@@ -46,12 +47,24 @@ LeafID = int +@@ -35,12 +36,24 @@ LeafID = int LN = Union[Leaf, Node] @@ -1031,11 +1017,11 @@ mode: Mode = field(repr=False) - depth: int = 0 -+ depth: Tuple[Indentation, ...] = field(default_factory=tuple) - leaves: List[Leaf] = field(default_factory=list) ++ depth: tuple[Indentation, ...] = field(default_factory=tuple) + leaves: list[Leaf] = field(default_factory=list) # keys ordered like `leaves` - comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict) -@@ -60,6 +73,9 @@ class Line: + comments: dict[LeafID, list[Leaf]] = field(default_factory=dict) +@@ -49,6 +62,9 @@ class Line: should_split_rhs: bool = False magic_trailing_comma: Optional[Leaf] = None @@ -1045,7 +1031,7 @@ def append( self, leaf: Leaf, preformatted: bool = False, track_bracket: bool = False ) -> None: -@@ -108,7 +124,7 @@ class Line: +@@ -97,7 +113,7 @@ class Line: or when a standalone comment is not the first leaf on the line. """ if ( @@ -1054,7 +1040,7 @@ or self.bracket_tracker.any_open_for_or_lambda() ): if self.is_comment: -@@ -273,7 +289,7 @@ class Line: +@@ -262,7 +278,7 @@ class Line: return True return False @@ -1063,7 +1049,7 @@ ignored_ids = set() try: last_leaf = self.leaves[-1] -@@ -298,11 +314,9 @@ class Line: +@@ -287,11 +303,9 @@ class Line: comment_seen = False for leaf_id, comments in self.comments.items(): for comment in comments: @@ -1078,7 +1064,7 @@ return True comment_seen = True -@@ -337,7 +351,7 @@ class Line: +@@ -326,7 +340,7 @@ class Line: # line. for node in self.leaves[-2:]: for comment in self.comments.get(id(node), []): @@ -1087,7 +1073,7 @@ return True return False -@@ -492,7 +506,7 @@ class Line: +@@ -481,7 +495,7 @@ class Line: if not self: return "\n" @@ -1096,7 +1082,7 @@ leaves = iter(self.leaves) first = next(leaves) res = f"{first.prefix}{indent}{first.value}" -@@ -564,7 +578,7 @@ class EmptyLineTracker: +@@ -553,7 +567,7 @@ class EmptyLineTracker: lines (two on module-level). """ form_feed = ( @@ -1105,16 +1091,16 @@ and bool(current_line.leaves) and "\f\n" in current_line.leaves[0].prefix ) -@@ -609,7 +623,7 @@ class EmptyLineTracker: +@@ -598,7 +612,7 @@ class EmptyLineTracker: - def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: # noqa: C901 + def _maybe_empty_lines(self, current_line: Line) -> tuple[int, int]: # noqa: C901 max_allowed = 1 - if current_line.depth == 0: + if not current_line.depth: max_allowed = 1 if self.mode.is_pyi else 2 if current_line.leaves: -@@ -626,7 +640,7 @@ class EmptyLineTracker: +@@ -615,7 +629,7 @@ class EmptyLineTracker: # Mutate self.previous_defs, remainder of this function should be pure previous_def = None @@ -1123,7 +1109,7 @@ previous_def = self.previous_defs.pop() if current_line.is_def or current_line.is_class: self.previous_defs.append(current_line) -@@ -682,10 +696,25 @@ class EmptyLineTracker: +@@ -671,10 +685,25 @@ class EmptyLineTracker: ) if ( @@ -1151,7 +1137,7 @@ ): return (before or 1), 0 -@@ -702,8 +731,9 @@ class EmptyLineTracker: +@@ -691,8 +720,9 @@ class EmptyLineTracker: return 0, 1 return 0, 0 @@ -1163,7 +1149,7 @@ ): if self.mode.is_pyi: return 0, 0 -@@ -712,7 +742,7 @@ class EmptyLineTracker: +@@ -701,7 +731,7 @@ class EmptyLineTracker: comment_to_add_newlines: Optional[LinesBlock] = None if ( self.previous_line.is_comment @@ -1172,7 +1158,7 @@ and before == 0 ): slc = self.semantic_leading_comment -@@ -729,9 +759,9 @@ class EmptyLineTracker: +@@ -718,9 +748,9 @@ class EmptyLineTracker: if self.mode.is_pyi: if current_line.is_class or self.previous_line.is_class: @@ -1184,7 +1170,7 @@ newlines = 1 elif current_line.is_stub_class and self.previous_line.is_stub_class: # No blank line between classes with an empty body -@@ -760,7 +790,11 @@ class EmptyLineTracker: +@@ -749,7 +779,11 @@ class EmptyLineTracker: newlines = 1 if current_line.depth else 2 # If a user has left no space after a dummy implementation, don't insert # new lines. This is useful for instance for @overload or Protocols. @@ -1197,7 +1183,7 @@ newlines = 0 if comment_to_add_newlines is not None: previous_block = comment_to_add_newlines.previous_block -@@ -1031,7 +1065,7 @@ def can_omit_invisible_parens( +@@ -1020,7 +1054,7 @@ def can_omit_invisible_parens( def _can_omit_opening_paren(line: Line, *, first: Leaf, line_length: int) -> bool: """See `can_omit_invisible_parens`.""" remainder = False @@ -1206,7 +1192,7 @@ _index = -1 for _index, leaf, leaf_length in line.enumerate_with_length(): if leaf.type in CLOSING_BRACKETS and leaf.opening_bracket is first: -@@ -1055,7 +1089,7 @@ def _can_omit_opening_paren(line: Line, +@@ -1044,7 +1078,7 @@ def _can_omit_opening_paren(line: Line, def _can_omit_closing_paren(line: Line, *, last: Leaf, line_length: int) -> bool: """See `can_omit_invisible_parens`.""" @@ -1221,12 +1207,12 @@ from enum import Enum, auto from hashlib import sha256 from operator import attrgetter --from typing import Dict, Final, Set -+from typing import Dict, Final, Literal, Set, Tuple +-from typing import Final ++from typing import Final, Literal from pyink.const import DEFAULT_LINE_LENGTH -@@ -224,7 +224,31 @@ class Deprecated(UserWarning): +@@ -229,7 +229,31 @@ class Deprecated(UserWarning): """Visible deprecation warning.""" @@ -1259,8 +1245,8 @@ @dataclass -@@ -232,12 +256,20 @@ class Mode: - target_versions: Set[TargetVersion] = field(default_factory=set) +@@ -237,12 +261,20 @@ class Mode: + target_versions: set[TargetVersion] = field(default_factory=set) line_length: int = DEFAULT_LINE_LENGTH string_normalization: bool = True + # No effect if string_normalization is False @@ -1271,16 +1257,16 @@ is_ipynb: bool = False skip_source_first_line: bool = False magic_trailing_comma: bool = True - python_cell_magics: Set[str] = field(default_factory=set) + python_cell_magics: set[str] = field(default_factory=set) preview: bool = False + is_pyink: bool = False + pyink_indentation: Literal[2, 4] = 4 + pyink_ipynb_indentation: Literal[1, 2] = 1 -+ pyink_annotation_pragmas: Tuple[str, ...] = DEFAULT_ANNOTATION_PRAGMAS ++ pyink_annotation_pragmas: tuple[str, ...] = DEFAULT_ANNOTATION_PRAGMAS unstable: bool = False - enabled_features: Set[Preview] = field(default_factory=set) + enabled_features: set[Preview] = field(default_factory=set) -@@ -249,6 +281,9 @@ class Mode: +@@ -254,6 +286,9 @@ class Mode: except those in UNSTABLE_FEATURES are enabled. Any features in `self.enabled_features` are also enabled. """ @@ -1290,7 +1276,7 @@ if self.unstable: return True if feature in self.enabled_features: -@@ -280,11 +315,26 @@ class Mode: +@@ -285,12 +320,27 @@ class Mode: version_str, str(self.line_length), str(int(self.string_normalization)), @@ -1301,6 +1287,7 @@ str(int(self.skip_source_first_line)), str(int(self.magic_trailing_comma)), str(int(self.preview)), + str(int(self.unstable)), + str(int(self.is_pyink)), + str(self.pyink_indentation), + str(self.pyink_ipynb_indentation), @@ -1319,7 +1306,7 @@ + return Quote.DOUBLE --- a/nodes.py +++ b/nodes.py -@@ -23,6 +23,7 @@ else: +@@ -12,6 +12,7 @@ else: from mypy_extensions import mypyc_attr @@ -1327,7 +1314,7 @@ from pyink.cache import CACHE_DIR from pyink.mode import Mode, Preview from pyink.strings import get_string_prefix, has_triple_quotes -@@ -798,9 +799,13 @@ def is_function_or_class(node: Node) -> +@@ -793,9 +794,13 @@ def is_function_or_class(node: Node) -> return node.type in {syms.funcdef, syms.classdef, syms.async_funcdef} @@ -1343,7 +1330,7 @@ return False # If there is a comment, we want to keep it. -@@ -919,11 +924,13 @@ def is_type_comment(leaf: Leaf) -> bool: +@@ -914,11 +919,13 @@ def is_type_comment(leaf: Leaf) -> bool: return t in {token.COMMENT, STANDALONE_COMMENT} and v.startswith("# type:") @@ -1374,7 +1361,7 @@ +# Yes, we use the _Black_ style to format _Pyink_ code. +pyink = false line-length = 88 - target-version = ['py38'] + target-version = ['py39'] include = '\.pyi?$' -extend-exclude = ''' -/( @@ -1405,7 +1392,7 @@ +name = "pyink" +description = "Pyink is a python formatter, forked from Black with slightly different behavior." license = { text = "MIT" } - requires-python = ">=3.8" + requires-python = ">=3.9" -authors = [ - { name = "Łukasz Langa", email = "lukasz@langa.pl" }, -] @@ -1423,7 +1410,7 @@ classifiers = [ "Development Status :: 5 - Production/Stable", "Environment :: Console", -@@ -71,53 +42,38 @@ dependencies = [ +@@ -71,50 +42,38 @@ dependencies = [ "platformdirs>=2", "tomli>=1.1.0; python_version < '3.11'", "typing_extensions>=4.0.1; python_version < '3.11'", @@ -1435,10 +1422,7 @@ [project.optional-dependencies] colorama = ["colorama>=0.4.3"] uvloop = ["uvloop>=0.15.2"] --d = [ -- "aiohttp>=3.7.4; sys_platform != 'win32' or implementation_name != 'pypy'", -- "aiohttp>=3.7.4, !=3.9.0; sys_platform == 'win32' and implementation_name == 'pypy'", --] +-d = ["aiohttp>=3.10"] jupyter = [ "ipython>=7.8.0", "tokenize-rt>=3.2.0", @@ -1485,7 +1469,7 @@ [tool.hatch.build.targets.wheel] only-include = ["src"] sources = ["src"] -@@ -128,7 +84,6 @@ macos-max-compat = true +@@ -125,7 +84,6 @@ macos-max-compat = true # Option below requires `tests/optional.py` addopts = "--strict-config --strict-markers" optional-tests = [ @@ -1493,10 +1477,10 @@ "no_jupyter: run when `jupyter` extra NOT installed", ] markers = [ -@@ -152,36 +107,3 @@ filterwarnings = [ - # https://github.com/aio-libs/aiohttp/pull/7302 - "ignore:datetime.*utcfromtimestamp\\(\\) is deprecated and scheduled for removal:DeprecationWarning", +@@ -133,36 +91,3 @@ markers = [ ] + xfail_strict = true + filterwarnings = ["error"] -[tool.coverage.report] -omit = [ - "src/blib2to3/*", @@ -1512,7 +1496,7 @@ -# Specify the target platform details in config, so your developers are -# free to run mypy on Windows, Linux, or macOS and get consistent -# results. --python_version = "3.8" +-python_version = "3.9" -mypy_path = "src" -strict = true -# Unreachable blocks have been an issue when compiling mypyc, let's try to avoid 'em in the first place. @@ -1535,7 +1519,7 @@ @@ -1,7 +1,7 @@ { "$schema": "http://json-schema.org/draft-07/schema#", -- "$id": "https://github.com/psf/black/blob/main/black/resources/black.schema.json", +- "$id": "https://github.com/psf/black/blob/main/src/black/resources/black.schema.json", - "$comment": "tool.black table in pyproject.toml", + "$id": "https://github.com/google/pyink/blob/pyink/src/pyink/resources/pyink.schema.json", + "$comment": "tool.pyink table in pyproject.toml", @@ -1554,7 +1538,7 @@ --- a/strings.py +++ b/strings.py @@ -8,6 +8,7 @@ from functools import lru_cache - from typing import Final, List, Match, Pattern, Tuple + from typing import Final, Match, Pattern from pyink._width_table import WIDTH_TABLE +from pyink.mode import Quote @@ -1605,7 +1589,36 @@ +pyink = false --- a/tests/test_black.py +++ b/tests/test_black.py -@@ -2809,6 +2809,82 @@ class TestFileCollection: +@@ -44,7 +44,7 @@ from pyink import Feature, TargetVersion + from pyink import re_compile_maybe_verbose as compile_pattern + from pyink.cache import FileData, get_cache_dir, get_cache_file + from pyink.debug import DebugVisitor +-from pyink.mode import Mode, Preview ++from pyink.mode import Mode, Preview, Quote, QuoteStyle + from pyink.output import color_diff, diff + from pyink.parsing import ASTSafetyError + from pyink.report import Report +@@ -2365,6 +2365,19 @@ class TestCaching: + {Preview.docstring_check_for_newline}, + {Preview.hex_codes_in_unicode_sequences}, + ] ++ elif field.type is Quote: ++ values = list(Quote) ++ elif field.type is QuoteStyle: ++ values = list(QuoteStyle) ++ elif field.name == "pyink_indentation": ++ values = [2, 4] ++ elif field.name == "pyink_ipynb_indentation": ++ values = [1, 2] ++ elif field.name == "pyink_annotation_pragmas": ++ values = [ ++ ("type: ignore",), ++ ("noqa", "pylint:", "pytype: ignore", "@param"), ++ ] + elif field.type is bool: + values = [True, False] + elif field.type is int: +@@ -2845,6 +2858,82 @@ class TestFileCollection: stdin_filename=stdin_filename, ) @@ -1732,26 +1745,10 @@ "mode, expected_output, expectation", [ pytest.param( -@@ -208,6 +227,29 @@ def test_cell_magic_with_custom_python_m - assert result == expected_output +@@ -224,6 +243,13 @@ def test_cell_magic_with_custom_python_m + format_cell(src, fast=True, mode=JUPYTER_MODE) -+@pytest.mark.parametrize( -+ "src", -+ ( -+ " %%custom_magic \nx=2", -+ "\n\n%%custom_magic\nx=2", -+ "# comment\n%%custom_magic\nx=2", -+ "\n \n # comment with %%time\n\t\n %%custom_magic # comment \nx=2", -+ ), -+) -+def test_cell_magic_with_custom_python_magic_after_spaces_and_comments_noop( -+ src: str, -+) -> None: -+ with pytest.raises(NothingChanged): -+ format_cell(src, fast=True, mode=JUPYTER_MODE) -+ -+ +def test_cell_magic_with_forced_single_quoted_strings() -> None: + src = "%time" + mode = replace(JUPYTER_MODE, quote_style=QuoteStyle.SINGLE) @@ -1762,7 +1759,7 @@ def test_cell_magic_nested() -> None: src = "%%time\n%%time\n2+2" result = format_cell(src, fast=True, mode=JUPYTER_MODE) -@@ -381,6 +423,45 @@ def test_entire_notebook_no_trailing_new +@@ -397,6 +423,45 @@ def test_entire_notebook_no_trailing_new assert result == expected @@ -1808,7 +1805,7 @@ def test_entire_notebook_without_changes() -> None: content = read_jupyter_notebook("jupyter", "notebook_without_changes") with pytest.raises(NothingChanged): -@@ -432,6 +513,30 @@ def test_ipynb_diff_with_no_change() -> +@@ -448,6 +513,30 @@ def test_ipynb_diff_with_no_change() -> assert expected in result.output @@ -1891,7 +1888,7 @@ + pyink --check {toxinidir}/src {toxinidir}/tests {toxinidir}/docs {toxinidir}/scripts --- a/trans.py +++ b/trans.py -@@ -28,8 +28,8 @@ from typing import ( +@@ -24,8 +24,8 @@ from typing import ( from mypy_extensions import trait from pyink.comments import contains_pragma_comment @@ -1902,7 +1899,7 @@ from pyink.nodes import ( CLOSING_BRACKETS, OPENING_BRACKETS, -@@ -279,9 +279,18 @@ class StringTransformer(ABC): +@@ -275,9 +275,18 @@ class StringTransformer(ABC): # Ideally this would be a dataclass, but unfortunately mypyc breaks when used with # `abc.ABC`. @@ -1922,7 +1919,7 @@ @abstractmethod def do_match(self, line: Line) -> TMatchResult: -@@ -759,7 +768,9 @@ class StringMerger(StringTransformer, Cu +@@ -755,7 +764,9 @@ class StringMerger(StringTransformer, Cu S_leaf = Leaf(token.STRING, S) if self.normalize_strings: @@ -1933,7 +1930,7 @@ # Fill the 'custom_splits' list with the appropriate CustomSplit objects. temp_string = S_leaf.value[len(prefix) + 1 : -1] -@@ -860,7 +871,7 @@ class StringMerger(StringTransformer, Cu +@@ -856,7 +867,7 @@ class StringMerger(StringTransformer, Cu if id(leaf) in line.comments: num_of_inline_string_comments += 1 @@ -1942,7 +1939,7 @@ return TErr("Cannot merge strings which have pragma comments.") if num_of_strings < 2: -@@ -1000,7 +1011,13 @@ class StringParenStripper(StringTransfor +@@ -996,7 +1007,13 @@ class StringParenStripper(StringTransfor idx += 1 if string_indices: @@ -1957,7 +1954,7 @@ return TErr("This line has no strings wrapped in parens.") def do_transform( -@@ -1162,7 +1179,7 @@ class BaseStringSplitter(StringTransform +@@ -1158,7 +1175,7 @@ class BaseStringSplitter(StringTransform ) if id(line.leaves[string_idx]) in line.comments and contains_pragma_comment( @@ -1966,7 +1963,7 @@ ): return TErr( "Line appears to end with an inline pragma comment. Splitting the line" -@@ -1204,7 +1221,7 @@ class BaseStringSplitter(StringTransform +@@ -1200,7 +1217,7 @@ class BaseStringSplitter(StringTransform # NN: The leaf that is after N. # WMA4 the whitespace at the beginning of the line. @@ -1975,7 +1972,7 @@ if is_valid_index(string_idx - 1): p_idx = string_idx - 1 -@@ -1558,7 +1575,7 @@ class StringSplitter(BaseStringSplitter, +@@ -1554,7 +1571,7 @@ class StringSplitter(BaseStringSplitter, characters expand to two columns). """ result = self.line_length @@ -1984,7 +1981,7 @@ result -= 1 if ends_with_comma else 0 result -= string_op_leaves_length return result -@@ -1569,11 +1586,11 @@ class StringSplitter(BaseStringSplitter, +@@ -1565,11 +1582,11 @@ class StringSplitter(BaseStringSplitter, # The last index of a string of length N is N-1. max_break_width -= 1 # Leading whitespace is not present in the string value (e.g. Leaf.value). @@ -1998,7 +1995,7 @@ ) return -@@ -1870,7 +1887,9 @@ class StringSplitter(BaseStringSplitter, +@@ -1866,7 +1883,9 @@ class StringSplitter(BaseStringSplitter, def _maybe_normalize_string_quotes(self, leaf: Leaf) -> None: if self.normalize_strings: @@ -2009,7 +2006,7 @@ def _normalize_f_string(self, string: str, prefix: str) -> str: """ -@@ -1993,7 +2012,8 @@ class StringParenWrapper(BaseStringSplit +@@ -1989,7 +2008,8 @@ class StringParenWrapper(BaseStringSplit char == " " or char in SPLIT_SAFE_CHARS for char in string_value ): # And will still violate the line length limit when split... @@ -2019,7 +2016,7 @@ if str_width(string_value) > max_string_width: # And has no associated custom splits... if not self.has_custom_splits(string_value): -@@ -2239,7 +2259,7 @@ class StringParenWrapper(BaseStringSplit +@@ -2235,7 +2255,7 @@ class StringParenWrapper(BaseStringSplit string_value = LL[string_idx].value string_line = Line( mode=line.mode, diff --git a/pyproject.toml b/pyproject.toml index 5dcad2052b5..1f37e49776b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ # Yes, we use the _Black_ style to format _Pyink_ code. pyink = false line-length = 88 -target-version = ['py38'] +target-version = ['py39'] include = '\.pyi?$' extend-exclude = 'tests/data' unstable = true @@ -15,7 +15,7 @@ build-backend = "hatchling.build" name = "pyink" description = "Pyink is a python formatter, forked from Black with slightly different behavior." license = { text = "MIT" } -requires-python = ">=3.8" +requires-python = ">=3.9" readme = "README.md" authors = [{name = "The Pyink Maintainers", email = "pyink-maintainers@google.com"}] classifiers = [ @@ -26,11 +26,11 @@ classifiers = [ "Operating System :: OS Independent", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Software Development :: Quality Assurance", ] @@ -90,20 +90,4 @@ markers = [ "incompatible_with_mypyc: run when testing mypyc compiled black" ] xfail_strict = true -filterwarnings = [ - "error", - # this is mitigated by a try/catch in https://github.com/psf/black/pull/2974/ - # this ignore can be removed when support for aiohttp 3.7 is dropped. - '''ignore:Decorator `@unittest_run_loop` is no longer needed in aiohttp 3\.8\+:DeprecationWarning''', - # this is mitigated by a try/catch in https://github.com/psf/black/pull/3198/ - # this ignore can be removed when support for aiohttp 3.x is dropped. - '''ignore:Middleware decorator is deprecated since 4\.0 and its behaviour is default, you can simply remove this decorator:DeprecationWarning''', - # aiohttp is using deprecated cgi modules - Safe to remove when fixed: - # https://github.com/aio-libs/aiohttp/issues/6905 - '''ignore:'cgi' is deprecated and slated for removal in Python 3.13:DeprecationWarning''', - # Work around https://github.com/pytest-dev/pytest/issues/10977 for Python 3.12 - '''ignore:(Attribute s|Attribute n|ast.Str|ast.Bytes|ast.NameConstant|ast.Num) is deprecated and will be removed in Python 3.14:DeprecationWarning''', - # Will be fixed with aiohttp 3.9.0 - # https://github.com/aio-libs/aiohttp/pull/7302 - "ignore:datetime.*utcfromtimestamp\\(\\) is deprecated and scheduled for removal:DeprecationWarning", -] +filterwarnings = ["error"] diff --git a/src/pyink/__init__.py b/src/pyink/__init__.py index 3fa6c6338c8..97433748faa 100644 --- a/src/pyink/__init__.py +++ b/src/pyink/__init__.py @@ -14,17 +14,13 @@ from typing import ( Any, Collection, - Dict, Generator, Iterator, - List, MutableMapping, Optional, Pattern, Sequence, - Set, Sized, - Tuple, Union, ) @@ -58,12 +54,12 @@ ) from pyink.handle_ipynb_magics import ( PYTHON_CELL_MAGICS, - TRANSFORMED_MAGICS, jupyter_dependencies_are_installed, mask_cell, put_trailing_semicolon_back, remove_trailing_semicolon, unmask_cell, + validate_cell, ) from pyink.linegen import LN, LineGenerator, transform_line from pyink.lines import EmptyLineTracker, LinesBlock @@ -182,7 +178,7 @@ def read_pyproject_toml( "line-ranges", "Cannot use line-ranges in the pyproject.toml file." ) - default_map: Dict[str, Any] = {} + default_map: dict[str, Any] = {} if ctx.default_map: default_map.update(ctx.default_map) default_map.update(config) @@ -192,9 +188,9 @@ def read_pyproject_toml( def spellcheck_pyproject_toml_keys( - ctx: click.Context, config_keys: List[str], config_file_path: str + ctx: click.Context, config_keys: list[str], config_file_path: str ) -> None: - invalid_keys: List[str] = [] + invalid_keys: list[str] = [] available_config_options = {param.name for param in ctx.command.params} for key in config_keys: if key not in available_config_options: @@ -208,8 +204,8 @@ def spellcheck_pyproject_toml_keys( def target_version_option_callback( - c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...] -) -> List[TargetVersion]: + c: click.Context, p: Union[click.Option, click.Parameter], v: tuple[str, ...] +) -> list[TargetVersion]: """Compute the target versions from a --target-version flag. This is its own function because mypy couldn't infer the type correctly @@ -219,8 +215,8 @@ def target_version_option_callback( def enable_unstable_feature_callback( - c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...] -) -> List[Preview]: + c: click.Context, p: Union[click.Option, click.Parameter], v: tuple[str, ...] +) -> list[Preview]: """Compute the features from an --enable-unstable-feature flag.""" return [Preview[val] for val in v] @@ -580,7 +576,7 @@ def main( # noqa: C901 ctx: click.Context, code: Optional[str], line_length: int, - target_version: List[TargetVersion], + target_version: list[TargetVersion], check: bool, diff: bool, line_ranges: Sequence[str], @@ -594,11 +590,11 @@ def main( # noqa: C901 skip_magic_trailing_comma: bool, preview: bool, unstable: bool, - enable_unstable_feature: List[Preview], + enable_unstable_feature: list[Preview], pyink: bool, pyink_indentation: str, pyink_ipynb_indentation: str, - pyink_annotation_pragmas: List[str], + pyink_annotation_pragmas: list[str], pyink_use_majority_quotes: bool, quiet: bool, verbose: bool, @@ -609,12 +605,21 @@ def main( # noqa: C901 force_exclude: Optional[Pattern[str]], stdin_filename: Optional[str], workers: Optional[int], - src: Tuple[str, ...], + src: tuple[str, ...], config: Optional[str], ) -> None: """The uncompromising code formatter.""" ctx.ensure_object(dict) + assert sys.version_info >= (3, 9), "Black requires Python 3.9+" + if sys.version_info[:3] == (3, 12, 5): + out( + "Python 3.12.5 has a memory safety issue that can cause Black's " + "AST safety checks to fail. " + "Please upgrade to Python 3.12.6 or downgrade to Python 3.12.4" + ) + ctx.exit(1) + if src and code is not None: out( main.get_usage(ctx) @@ -708,7 +713,7 @@ def main( # noqa: C901 ), ) - lines: List[Tuple[int, int]] = [] + lines: list[tuple[int, int]] = [] if line_ranges: if ipynb: err("Cannot use --line-ranges with ipynb files.") @@ -798,7 +803,7 @@ def main( # noqa: C901 def get_sources( *, root: Path, - src: Tuple[str, ...], + src: tuple[str, ...], quiet: bool, verbose: bool, include: Pattern[str], @@ -807,14 +812,14 @@ def get_sources( force_exclude: Optional[Pattern[str]], report: "Report", stdin_filename: Optional[str], -) -> Set[Path]: +) -> set[Path]: """Compute the set of files to be formatted.""" - sources: Set[Path] = set() + sources: set[Path] = set() assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}" using_default_exclude = exclude is None exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude - gitignore: Optional[Dict[Path, PathSpec]] = None + gitignore: Optional[dict[Path, PathSpec]] = None root_gitignore = get_gitignore(root) for s in src: @@ -906,7 +911,7 @@ def reformat_code( mode: Mode, report: Report, *, - lines: Collection[Tuple[int, int]] = (), + lines: Collection[tuple[int, int]] = (), ) -> None: """ Reformat and print out `content` without spawning child processes. @@ -939,7 +944,7 @@ def reformat_one( mode: Mode, report: "Report", *, - lines: Collection[Tuple[int, int]] = (), + lines: Collection[tuple[int, int]] = (), ) -> None: """Reformat a single file under `src` without spawning child processes. @@ -995,7 +1000,7 @@ def format_file_in_place( write_back: WriteBack = WriteBack.NO, lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy *, - lines: Collection[Tuple[int, int]] = (), + lines: Collection[tuple[int, int]] = (), ) -> bool: """Format file under `src` path. Return True if changed. @@ -1062,7 +1067,7 @@ def format_stdin_to_stdout( content: Optional[str] = None, write_back: WriteBack = WriteBack.NO, mode: Mode, - lines: Collection[Tuple[int, int]] = (), + lines: Collection[tuple[int, int]] = (), ) -> bool: """Format file on stdin. Return True if changed. @@ -1113,7 +1118,7 @@ def check_stability_and_equivalence( dst_contents: str, *, mode: Mode, - lines: Collection[Tuple[int, int]] = (), + lines: Collection[tuple[int, int]] = (), ) -> None: """Perform stability and equivalence checks. @@ -1130,7 +1135,7 @@ def format_file_contents( *, fast: bool, mode: Mode, - lines: Collection[Tuple[int, int]] = (), + lines: Collection[tuple[int, int]] = (), ) -> FileContent: """Reformat contents of a file and return new contents. @@ -1153,33 +1158,6 @@ def format_file_contents( return dst_contents -def validate_cell(src: str, mode: Mode) -> None: - """Check that cell does not already contain TransformerManager transformations, - or non-Python cell magics, which might cause tokenizer_rt to break because of - indentations. - - If a cell contains ``!ls``, then it'll be transformed to - ``get_ipython().system('ls')``. However, if the cell originally contained - ``get_ipython().system('ls')``, then it would get transformed in the same way: - - >>> TransformerManager().transform_cell("get_ipython().system('ls')") - "get_ipython().system('ls')\n" - >>> TransformerManager().transform_cell("!ls") - "get_ipython().system('ls')\n" - - Due to the impossibility of safely roundtripping in such situations, cells - containing transformed magics will be ignored. - """ - if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS): - raise NothingChanged - line = ink.get_code_start(src) - if line.startswith("%%") and ( - line.split(maxsplit=1)[0][2:] - not in PYTHON_CELL_MAGICS | mode.python_cell_magics - ): - raise NothingChanged - - def format_cell(src: str, *, fast: bool, mode: Mode) -> str: """Format code in given cell of Jupyter notebook. @@ -1273,7 +1251,7 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon def format_str( - src_contents: str, *, mode: Mode, lines: Collection[Tuple[int, int]] = () + src_contents: str, *, mode: Mode, lines: Collection[tuple[int, int]] = () ) -> str: """Reformat a string and return new contents. @@ -1320,10 +1298,10 @@ def f( def _format_str_once( - src_contents: str, *, mode: Mode, lines: Collection[Tuple[int, int]] = () + src_contents: str, *, mode: Mode, lines: Collection[tuple[int, int]] = () ) -> str: src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions) - dst_blocks: List[LinesBlock] = [] + dst_blocks: list[LinesBlock] = [] if mode.target_versions: versions = mode.target_versions else: @@ -1375,7 +1353,7 @@ def _format_str_once( return "".join(dst_contents) -def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]: +def decode_bytes(src: bytes) -> tuple[FileContent, Encoding, NewLine]: """Return a tuple of (decoded_contents, encoding, newline). `newline` is either CRLF or LF but `decoded_contents` is decoded with @@ -1393,8 +1371,8 @@ def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]: def get_features_used( # noqa: C901 - node: Node, *, future_imports: Optional[Set[str]] = None -) -> Set[Feature]: + node: Node, *, future_imports: Optional[set[str]] = None +) -> set[Feature]: """Return a set of (relatively) new Python features used in this file. Currently looking for: @@ -1412,7 +1390,7 @@ def get_features_used( # noqa: C901 - except* clause; - variadic generics; """ - features: Set[Feature] = set() + features: set[Feature] = set() if future_imports: features |= { FUTURE_FLAG_TO_FEATURE[future_import] @@ -1550,8 +1528,8 @@ def _contains_asexpr(node: Union[Node, Leaf]) -> bool: def detect_target_versions( - node: Node, *, future_imports: Optional[Set[str]] = None -) -> Set[TargetVersion]: + node: Node, *, future_imports: Optional[set[str]] = None +) -> set[TargetVersion]: """Detect the version to target based on the nodes used.""" features = get_features_used(node, future_imports=future_imports) return { @@ -1559,11 +1537,11 @@ def detect_target_versions( } -def get_future_imports(node: Node) -> Set[str]: +def get_future_imports(node: Node) -> set[str]: """Return a set of __future__ imports in the file.""" - imports: Set[str] = set() + imports: set[str] = set() - def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]: + def get_imports_from_children(children: list[LN]) -> Generator[str, None, None]: for child in children: if isinstance(child, Leaf): if child.type == token.NAME: @@ -1609,6 +1587,13 @@ def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]: return imports +def _black_info() -> str: + return ( + f"Black {__version__} on " + f"Python ({platform.python_implementation()}) {platform.python_version()}" + ) + + def assert_equivalent(src: str, dst: str) -> None: """Raise AssertionError if `src` and `dst` aren't equivalent.""" try: @@ -1626,7 +1611,7 @@ def assert_equivalent(src: str, dst: str) -> None: except Exception as exc: log = dump_to_file("".join(traceback.format_tb(exc.__traceback__)), dst) raise ASTSafetyError( - f"INTERNAL ERROR: Black produced invalid code: {exc}. " + f"INTERNAL ERROR: {_black_info()} produced invalid code: {exc}. " "Please report a bug on https://github.com/psf/black/issues. " f"This invalid output might be helpful: {log}" ) from None @@ -1636,14 +1621,14 @@ def assert_equivalent(src: str, dst: str) -> None: if src_ast_str != dst_ast_str: log = dump_to_file(diff(src_ast_str, dst_ast_str, "src", "dst")) raise ASTSafetyError( - "INTERNAL ERROR: Black produced code that is not equivalent to the" - " source. Please report a bug on " - f"https://github.com/psf/black/issues. This diff might be helpful: {log}" + f"INTERNAL ERROR: {_black_info()} produced code that is not equivalent to" + " the source. Please report a bug on https://github.com/psf/black/issues." + f" This diff might be helpful: {log}" ) from None def assert_stable( - src: str, dst: str, mode: Mode, *, lines: Collection[Tuple[int, int]] = () + src: str, dst: str, mode: Mode, *, lines: Collection[tuple[int, int]] = () ) -> None: """Raise AssertionError if `dst` reformats differently the second time.""" if lines: @@ -1664,9 +1649,9 @@ def assert_stable( diff(dst, newdst, "first pass", "second pass"), ) raise AssertionError( - "INTERNAL ERROR: Black produced different code on the second pass of the" - " formatter. Please report a bug on https://github.com/psf/black/issues." - f" This diff might be helpful: {log}" + f"INTERNAL ERROR: {_black_info()} produced different code on the second" + " pass of the formatter. Please report a bug on" + f" https://github.com/psf/black/issues. This diff might be helpful: {log}" ) from None diff --git a/src/pyink/_width_table.py b/src/pyink/_width_table.py index b19b5e3129b..12fd15d6e6c 100644 --- a/src/pyink/_width_table.py +++ b/src/pyink/_width_table.py @@ -1,9 +1,9 @@ # Generated by make_width_table.py # wcwidth 0.2.6 # Unicode 15.0.0 -from typing import Final, List, Tuple +from typing import Final -WIDTH_TABLE: Final[Tuple[Tuple[int, int, int], ...]] = ( +WIDTH_TABLE: Final[tuple[tuple[int, int, int], ...]] = ( (0, 0, 0), (1, 31, -1), (127, 159, -1), diff --git a/src/pyink/brackets.py b/src/pyink/brackets.py index 771a51c3eb7..188a318feda 100644 --- a/src/pyink/brackets.py +++ b/src/pyink/brackets.py @@ -1,7 +1,7 @@ """Builds on top of nodes.py to track brackets.""" from dataclasses import dataclass, field -from typing import Dict, Final, Iterable, List, Optional, Sequence, Set, Tuple, Union +from typing import Final, Iterable, Optional, Sequence, Union from pyink.nodes import ( BRACKET, @@ -60,12 +60,12 @@ class BracketTracker: """Keeps track of brackets on a line.""" depth: int = 0 - bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict) - delimiters: Dict[LeafID, Priority] = field(default_factory=dict) + bracket_match: dict[tuple[Depth, NodeType], Leaf] = field(default_factory=dict) + delimiters: dict[LeafID, Priority] = field(default_factory=dict) previous: Optional[Leaf] = None - _for_loop_depths: List[int] = field(default_factory=list) - _lambda_argument_depths: List[int] = field(default_factory=list) - invisible: List[Leaf] = field(default_factory=list) + _for_loop_depths: list[int] = field(default_factory=list) + _lambda_argument_depths: list[int] = field(default_factory=list) + invisible: list[Leaf] = field(default_factory=list) def mark(self, leaf: Leaf) -> None: """Mark `leaf` with bracket-related metadata. Keep track of delimiters. @@ -353,7 +353,7 @@ def max_delimiter_priority_in_atom(node: LN) -> Priority: return 0 -def get_leaves_inside_matching_brackets(leaves: Sequence[Leaf]) -> Set[LeafID]: +def get_leaves_inside_matching_brackets(leaves: Sequence[Leaf]) -> set[LeafID]: """Return leaves that are inside matching brackets. The input `leaves` can have non-matching brackets at the head or tail parts. diff --git a/src/pyink/cache.py b/src/pyink/cache.py index 66054249b83..f834dacda71 100644 --- a/src/pyink/cache.py +++ b/src/pyink/cache.py @@ -7,7 +7,7 @@ import tempfile from dataclasses import dataclass, field from pathlib import Path -from typing import Dict, Iterable, NamedTuple, Set, Tuple +from typing import Iterable, NamedTuple from platformdirs import user_cache_dir @@ -55,7 +55,7 @@ def get_cache_file(mode: Mode) -> Path: class Cache: mode: Mode cache_file: Path - file_data: Dict[str, FileData] = field(default_factory=dict) + file_data: dict[str, FileData] = field(default_factory=dict) @classmethod def read(cls, mode: Mode) -> Self: @@ -76,7 +76,7 @@ def read(cls, mode: Mode) -> Self: with cache_file.open("rb") as fobj: try: - data: Dict[str, Tuple[float, int, str]] = pickle.load(fobj) + data: dict[str, tuple[float, int, str]] = pickle.load(fobj) file_data = {k: FileData(*v) for k, v in data.items()} except (pickle.UnpicklingError, ValueError, IndexError): return cls(mode, cache_file) @@ -114,14 +114,14 @@ def is_changed(self, source: Path) -> bool: return True return False - def filtered_cached(self, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]: + def filtered_cached(self, sources: Iterable[Path]) -> tuple[set[Path], set[Path]]: """Split an iterable of paths in `sources` into two sets. The first contains paths of files that modified on disk or are not in the cache. The other contains paths to non-modified files. """ - changed: Set[Path] = set() - done: Set[Path] = set() + changed: set[Path] = set() + done: set[Path] = set() for src in sources: if self.is_changed(src): changed.add(src) @@ -139,9 +139,8 @@ def write(self, sources: Iterable[Path]) -> None: with tempfile.NamedTemporaryFile( dir=str(self.cache_file.parent), delete=False ) as f: - # We store raw tuples in the cache because pickling NamedTuples - # doesn't work with mypyc on Python 3.8, and because it's faster. - data: Dict[str, Tuple[float, int, str]] = { + # We store raw tuples in the cache because it's faster. + data: dict[str, tuple[float, int, str]] = { k: (*v,) for k, v in self.file_data.items() } pickle.dump(data, f, protocol=4) diff --git a/src/pyink/comments.py b/src/pyink/comments.py index 1202562f8d7..2bbed0dec3d 100644 --- a/src/pyink/comments.py +++ b/src/pyink/comments.py @@ -1,7 +1,7 @@ import re from dataclasses import dataclass from functools import lru_cache -from typing import Collection, Final, Iterator, List, Optional, Tuple, Union +from typing import Collection, Final, Iterator, Optional, Union from pyink import ink_comments from pyink.mode import Mode, Preview @@ -78,9 +78,9 @@ def generate_comments(leaf: LN) -> Iterator[Leaf]: @lru_cache(maxsize=4096) -def list_comments(prefix: str, *, is_endmarker: bool) -> List[ProtoComment]: +def list_comments(prefix: str, *, is_endmarker: bool) -> list[ProtoComment]: """Return a list of :class:`ProtoComment` objects parsed from the given `prefix`.""" - result: List[ProtoComment] = [] + result: list[ProtoComment] = [] if not prefix or "#" not in prefix: return result @@ -167,7 +167,7 @@ def make_comment(content: str) -> str: def normalize_fmt_off( - node: Node, mode: Mode, lines: Collection[Tuple[int, int]] + node: Node, mode: Mode, lines: Collection[tuple[int, int]] ) -> None: """Convert content between `# fmt: off`/`# fmt: on` into standalone comments.""" try_again = True @@ -176,7 +176,7 @@ def normalize_fmt_off( def convert_one_fmt_off_pair( - node: Node, mode: Mode, lines: Collection[Tuple[int, int]] + node: Node, mode: Mode, lines: Collection[tuple[int, int]] ) -> bool: """Convert content of a single `# fmt: off`/`# fmt: on` into a standalone comment. @@ -337,7 +337,7 @@ def _generate_ignored_nodes_from_fmt_skip( # statements. The ignored nodes should be previous siblings of the # parent suite node. leaf.prefix = "" - ignored_nodes: List[LN] = [] + ignored_nodes: list[LN] = [] parent_sibling = parent.prev_sibling while parent_sibling is not None and parent_sibling.type != syms.suite: ignored_nodes.insert(0, parent_sibling) @@ -377,7 +377,7 @@ def children_contains_fmt_on(container: LN) -> bool: return False -def contains_pragma_comment(comment_list: List[Leaf], mode: Mode) -> bool: +def contains_pragma_comment(comment_list: list[Leaf], mode: Mode) -> bool: """ Returns: True iff one of the comments in @comment_list is a pragma used by one diff --git a/src/pyink/concurrency.py b/src/pyink/concurrency.py index b4603729570..37fd4e54d9d 100644 --- a/src/pyink/concurrency.py +++ b/src/pyink/concurrency.py @@ -13,7 +13,7 @@ from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor from multiprocessing import Manager from pathlib import Path -from typing import Any, Iterable, Optional, Set +from typing import Any, Iterable, Optional from mypy_extensions import mypyc_attr @@ -69,7 +69,7 @@ def shutdown(loop: asyncio.AbstractEventLoop) -> None: # not ideal, but this shouldn't cause any issues ... hopefully. ~ichard26 @mypyc_attr(patchable=True) def reformat_many( - sources: Set[Path], + sources: set[Path], fast: bool, write_back: WriteBack, mode: Mode, @@ -119,7 +119,7 @@ def reformat_many( async def schedule_formatting( - sources: Set[Path], + sources: set[Path], fast: bool, write_back: WriteBack, mode: Mode, diff --git a/src/pyink/debug.py b/src/pyink/debug.py index 16c1f0e39a5..0a757cbbdc8 100644 --- a/src/pyink/debug.py +++ b/src/pyink/debug.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Any, Iterator, List, TypeVar, Union +from typing import Any, Iterator, TypeVar, Union from pyink.nodes import Visitor from pyink.output import out @@ -14,7 +14,7 @@ @dataclass class DebugVisitor(Visitor[T]): tree_depth: int = 0 - list_output: List[str] = field(default_factory=list) + list_output: list[str] = field(default_factory=list) print_output: bool = True def out(self, message: str, *args: Any, **kwargs: Any) -> None: diff --git a/src/pyink/files.py b/src/pyink/files.py index df7c2792c4a..00bf06850c2 100644 --- a/src/pyink/files.py +++ b/src/pyink/files.py @@ -6,14 +6,11 @@ from typing import ( TYPE_CHECKING, Any, - Dict, Iterable, Iterator, - List, Optional, Pattern, Sequence, - Tuple, Union, ) @@ -43,7 +40,7 @@ @lru_cache -def _load_toml(path: Union[Path, str]) -> Dict[str, Any]: +def _load_toml(path: Union[Path, str]) -> dict[str, Any]: with open(path, "rb") as f: return tomllib.load(f) @@ -56,7 +53,7 @@ def _cached_resolve(path: Path) -> Path: @lru_cache def find_project_root( srcs: Sequence[str], stdin_filename: Optional[str] = None -) -> Tuple[Path, str]: +) -> tuple[Path, str]: """Return a directory containing .git, .hg, or pyproject.toml. pyproject.toml files are only considered if they contain a [tool.pyink] @@ -106,7 +103,7 @@ def find_project_root( def find_pyproject_toml( - path_search_start: Tuple[str, ...], stdin_filename: Optional[str] = None + path_search_start: tuple[str, ...], stdin_filename: Optional[str] = None ) -> Optional[str]: """Find the absolute filepath to a pyproject.toml if it exists""" path_project_root, _ = find_project_root(path_search_start, stdin_filename) @@ -128,13 +125,13 @@ def find_pyproject_toml( @mypyc_attr(patchable=True) -def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: +def parse_pyproject_toml(path_config: str) -> dict[str, Any]: """Parse a pyproject toml file, pulling out relevant parts for Black. If parsing fails, will raise a tomllib.TOMLDecodeError. """ pyproject_toml = _load_toml(path_config) - config: Dict[str, Any] = pyproject_toml.get("tool", {}).get("pyink", {}) + config: dict[str, Any] = pyproject_toml.get("tool", {}).get("pyink", {}) config = {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} if "target_version" not in config: @@ -146,8 +143,8 @@ def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: def infer_target_version( - pyproject_toml: Dict[str, Any], -) -> Optional[List[TargetVersion]]: + pyproject_toml: dict[str, Any], +) -> Optional[list[TargetVersion]]: """Infer Black's target version from the project metadata in pyproject.toml. Supports the PyPA standard format (PEP 621): @@ -170,7 +167,7 @@ def infer_target_version( return None -def parse_req_python_version(requires_python: str) -> Optional[List[TargetVersion]]: +def parse_req_python_version(requires_python: str) -> Optional[list[TargetVersion]]: """Parse a version string (i.e. ``"3.7"``) to a list of TargetVersion. If parsing fails, will raise a packaging.version.InvalidVersion error. @@ -185,7 +182,7 @@ def parse_req_python_version(requires_python: str) -> Optional[List[TargetVersio return None -def parse_req_python_specifier(requires_python: str) -> Optional[List[TargetVersion]]: +def parse_req_python_specifier(requires_python: str) -> Optional[list[TargetVersion]]: """Parse a specifier string (i.e. ``">=3.7,<3.10"``) to a list of TargetVersion. If parsing fails, will raise a packaging.specifiers.InvalidSpecifier error. @@ -196,7 +193,7 @@ def parse_req_python_specifier(requires_python: str) -> Optional[List[TargetVers return None target_version_map = {f"3.{v.value}": v for v in TargetVersion} - compatible_versions: List[str] = list(specifier_set.filter(target_version_map)) + compatible_versions: list[str] = list(specifier_set.filter(target_version_map)) if compatible_versions: return [target_version_map[v] for v in compatible_versions] return None @@ -251,7 +248,7 @@ def find_user_pyproject_toml() -> Path: def get_gitignore(root: Path) -> PathSpec: """Return a PathSpec matching gitignore content if present.""" gitignore = root / ".gitignore" - lines: List[str] = [] + lines: list[str] = [] if gitignore.is_file(): with gitignore.open(encoding="utf-8") as gf: lines = gf.readlines() @@ -272,8 +269,6 @@ def resolves_outside_root_or_cannot_stat( root directory. Also returns True if we failed to resolve the path. """ try: - if sys.version_info < (3, 8, 6): - path = path.absolute() # https://bugs.python.org/issue33660 resolved_path = _cached_resolve(path) except OSError as e: if report: @@ -304,7 +299,7 @@ def best_effort_relative_path(path: Path, root: Path) -> Path: def _path_is_ignored( root_relative_path: str, root: Path, - gitignore_dict: Dict[Path, PathSpec], + gitignore_dict: dict[Path, PathSpec], ) -> bool: path = root / root_relative_path # Note that this logic is sensitive to the ordering of gitignore_dict. Callers must @@ -337,7 +332,7 @@ def gen_python_files( extend_exclude: Optional[Pattern[str]], force_exclude: Optional[Pattern[str]], report: Report, - gitignore_dict: Optional[Dict[Path, PathSpec]], + gitignore_dict: Optional[dict[Path, PathSpec]], *, verbose: bool, quiet: bool, diff --git a/src/pyink/handle_ipynb_magics.py b/src/pyink/handle_ipynb_magics.py index 64156c313f3..e7cc7e56557 100644 --- a/src/pyink/handle_ipynb_magics.py +++ b/src/pyink/handle_ipynb_magics.py @@ -3,17 +3,19 @@ import ast import collections import dataclasses +import re import secrets import sys from functools import lru_cache from importlib.util import find_spec -from typing import Dict, List, Optional, Tuple +from typing import Optional if sys.version_info >= (3, 10): from typing import TypeGuard else: from typing_extensions import TypeGuard +from pyink.mode import Mode from pyink.output import out from pyink.report import NothingChanged @@ -64,7 +66,35 @@ def jupyter_dependencies_are_installed(*, warn: bool) -> bool: return installed -def remove_trailing_semicolon(src: str) -> Tuple[str, bool]: +def validate_cell(src: str, mode: Mode) -> None: + """Check that cell does not already contain TransformerManager transformations, + or non-Python cell magics, which might cause tokenizer_rt to break because of + indentations. + + If a cell contains ``!ls``, then it'll be transformed to + ``get_ipython().system('ls')``. However, if the cell originally contained + ``get_ipython().system('ls')``, then it would get transformed in the same way: + + >>> TransformerManager().transform_cell("get_ipython().system('ls')") + "get_ipython().system('ls')\n" + >>> TransformerManager().transform_cell("!ls") + "get_ipython().system('ls')\n" + + Due to the impossibility of safely roundtripping in such situations, cells + containing transformed magics will be ignored. + """ + if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS): + raise NothingChanged + + line = _get_code_start(src) + if line.startswith("%%") and ( + line.split(maxsplit=1)[0][2:] + not in PYTHON_CELL_MAGICS | mode.python_cell_magics + ): + raise NothingChanged + + +def remove_trailing_semicolon(src: str) -> tuple[str, bool]: """Remove trailing semicolon from Jupyter notebook cell. For example, @@ -120,7 +150,7 @@ def put_trailing_semicolon_back(src: str, has_trailing_semicolon: bool) -> str: return str(tokens_to_src(tokens)) -def mask_cell(src: str) -> Tuple[str, List[Replacement]]: +def mask_cell(src: str) -> tuple[str, list[Replacement]]: """Mask IPython magics so content becomes parseable Python code. For example, @@ -135,7 +165,7 @@ def mask_cell(src: str) -> Tuple[str, List[Replacement]]: The replacements are returned, along with the transformed code. """ - replacements: List[Replacement] = [] + replacements: list[Replacement] = [] try: ast.parse(src) except SyntaxError: @@ -188,7 +218,7 @@ def get_token(src: str, magic: str) -> str: return f'"{token}"' -def replace_cell_magics(src: str) -> Tuple[str, List[Replacement]]: +def replace_cell_magics(src: str) -> tuple[str, list[Replacement]]: """Replace cell magic with token. Note that 'src' will already have been processed by IPython's @@ -205,7 +235,7 @@ def replace_cell_magics(src: str) -> Tuple[str, List[Replacement]]: The replacement, along with the transformed code, is returned. """ - replacements: List[Replacement] = [] + replacements: list[Replacement] = [] tree = ast.parse(src) @@ -219,7 +249,7 @@ def replace_cell_magics(src: str) -> Tuple[str, List[Replacement]]: return f"{mask}\n{cell_magic_finder.cell_magic.body}", replacements -def replace_magics(src: str) -> Tuple[str, List[Replacement]]: +def replace_magics(src: str) -> tuple[str, list[Replacement]]: """Replace magics within body of cell. Note that 'src' will already have been processed by IPython's @@ -260,7 +290,7 @@ def replace_magics(src: str) -> Tuple[str, List[Replacement]]: return "\n".join(new_srcs), replacements -def unmask_cell(src: str, replacements: List[Replacement]) -> str: +def unmask_cell(src: str, replacements: list[Replacement]) -> str: """Remove replacements from cell. For example @@ -280,6 +310,21 @@ def unmask_cell(src: str, replacements: List[Replacement]) -> str: return src +def _get_code_start(src: str) -> str: + """Provides the first line where the code starts. + + Iterates over lines of code until it finds the first line that doesn't + contain only empty spaces and comments. It removes any empty spaces at the + start of the line and returns it. If such line doesn't exist, it returns an + empty string. + """ + for match in re.finditer(".+", src): + line = match.group(0).lstrip() + if line and not line.startswith("#"): + return line + return "" + + def _is_ipython_magic(node: ast.expr) -> TypeGuard[ast.Attribute]: """Check if attribute is IPython magic. @@ -295,7 +340,7 @@ def _is_ipython_magic(node: ast.expr) -> TypeGuard[ast.Attribute]: ) -def _get_str_args(args: List[ast.expr]) -> List[str]: +def _get_str_args(args: list[ast.expr]) -> list[str]: str_args = [] for arg in args: assert isinstance(arg, ast.Constant) and isinstance(arg.value, str) @@ -379,7 +424,7 @@ class MagicFinder(ast.NodeVisitor): """ def __init__(self) -> None: - self.magics: Dict[int, List[OffsetAndMagic]] = collections.defaultdict(list) + self.magics: dict[int, list[OffsetAndMagic]] = collections.defaultdict(list) def visit_Assign(self, node: ast.Assign) -> None: """Look for system assign magics. diff --git a/src/pyink/ink.py b/src/pyink/ink.py index 2cb219972e6..b57a3094018 100644 --- a/src/pyink/ink.py +++ b/src/pyink/ink.py @@ -3,17 +3,9 @@ This is a separate module for easier patch management. """ +from collections.abc import Collection, Iterator, Sequence import re -from typing import ( - Collection, - Iterator, - List, - Optional, - Sequence, - Set, - Tuple, - Union, -) +from typing import Optional, Union from blib2to3.pgen2.token import ASYNC, FSTRING_START, NEWLINE, STRING from blib2to3.pytree import type_repr @@ -83,7 +75,7 @@ def get_code_start(src: str) -> str: return "" -def convert_unchanged_lines(src_node: Node, lines: Collection[Tuple[int, int]]): +def convert_unchanged_lines(src_node: Node, lines: Collection[tuple[int, int]]): """Converts unchanged lines to STANDALONE_COMMENT. The idea is similar to how Black implements `# fmt: on/off` where it also @@ -107,7 +99,7 @@ def convert_unchanged_lines(src_node: Node, lines: Collection[Tuple[int, int]]): more formatting to pass (1). However, it's hard to get it correct when incorrect indentations are used. So we defer this to future optimizations. """ - lines_set: Set[int] = set() + lines_set: set[int] = set() for start, end in lines: lines_set.update(range(start, end + 1)) visitor = _TopLevelStatementsVisitor(lines_set) @@ -133,7 +125,7 @@ class _TopLevelStatementsVisitor(Visitor[None]): classes/functions/statements. """ - def __init__(self, lines_set: Set[int]): + def __init__(self, lines_set: set[int]): self._lines_set = lines_set def visit_simple_stmt(self, node: Node) -> Iterator[None]: @@ -179,7 +171,7 @@ def visit_suite(self, node: Node) -> Iterator[None]: _convert_node_to_standalone_comment(semantic_parent) -def _convert_unchanged_line_by_line(node: Node, lines_set: Set[int]): +def _convert_unchanged_line_by_line(node: Node, lines_set: set[int]): """Converts unchanged to STANDALONE_COMMENT line by line.""" for leaf in node.leaves(): if leaf.type != NEWLINE: @@ -191,7 +183,7 @@ def _convert_unchanged_line_by_line(node: Node, lines_set: Set[int]): # match_stmt: "match" subject_expr ':' NEWLINE INDENT case_block+ DEDENT # Here we need to check `subject_expr`. The `case_block+` will be # checked by their own NEWLINEs. - nodes_to_ignore: List[LN] = [] + nodes_to_ignore: list[LN] = [] prev_sibling = leaf.prev_sibling while prev_sibling: nodes_to_ignore.insert(0, prev_sibling) @@ -339,7 +331,7 @@ def _leaf_line_end(leaf: Leaf) -> int: return leaf.lineno + str(leaf).count("\n") -def _get_line_range(node_or_nodes: Union[LN, List[LN]]) -> Set[int]: +def _get_line_range(node_or_nodes: Union[LN, list[LN]]) -> set[int]: """Returns the line range of this node or list of nodes.""" if isinstance(node_or_nodes, list): nodes = node_or_nodes diff --git a/src/pyink/ink_adjusted_lines.py b/src/pyink/ink_adjusted_lines.py index baaa46f46dd..295dbcfb9ce 100644 --- a/src/pyink/ink_adjusted_lines.py +++ b/src/pyink/ink_adjusted_lines.py @@ -4,21 +4,21 @@ module will be folded to pyink.ink in the future. """ +from collections.abc import Collection, Sequence import dataclasses import difflib -from typing import Collection, List, Sequence, Tuple -def is_valid_line_range(lines: Tuple[int, int]) -> bool: +def is_valid_line_range(lines: tuple[int, int]) -> bool: """Returns whether the line range is valid.""" return not lines or lines[0] <= lines[1] def adjusted_lines( - lines: Collection[Tuple[int, int]], + lines: Collection[tuple[int, int]], original_source: str, modified_source: str, -) -> List[Tuple[int, int]]: +) -> list[tuple[int, int]]: """Returns the adjusted line ranges based on edits from the original code. This computes the new line ranges by diffing original_source and diff --git a/src/pyink/linegen.py b/src/pyink/linegen.py index 0722e323528..2ab9587c509 100644 --- a/src/pyink/linegen.py +++ b/src/pyink/linegen.py @@ -7,7 +7,7 @@ from dataclasses import replace from enum import Enum, auto from functools import partial, wraps -from typing import Collection, Iterator, List, Optional, Set, Union, cast +from typing import Collection, Iterator, Optional, Union, cast if sys.version_info < (3, 8): from typing_extensions import Final, Literal @@ -224,7 +224,7 @@ def visit_DEDENT(self, node: Leaf) -> Iterator[Line]: yield from self.line(_DEDENT) def visit_stmt( - self, node: Node, keywords: Set[str], parens: Set[str] + self, node: Node, keywords: set[str], parens: set[str] ) -> Iterator[Line]: """Visit a statement. @@ -613,7 +613,7 @@ def __post_init__(self) -> None: self.current_line = Line(mode=self.mode) v = self.visit_stmt - Ø: Set[str] = set() + Ø: set[str] = set() self.visit_assert_stmt = partial(v, keywords={"assert"}, parens={"assert", ","}) self.visit_if_stmt = partial( v, keywords={"if", "else", "elif"}, parens={"if", "elif"} @@ -690,7 +690,7 @@ def transform_line( ll, sn, preferred_quote=preferred_quote, line_str=line_str ) - transformers: List[Transformer] + transformers: list[Transformer] if ( not line.contains_uncollapsable_pragma_comments() and not line.should_split_rhs @@ -790,7 +790,7 @@ def should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool: """If a funcdef has a magic trailing comma in the return type, then we should first split the line with rhs to respect the comma. """ - return_type_leaves: List[Leaf] = [] + return_type_leaves: list[Leaf] = [] in_return_type = False for leaf in line.leaves: @@ -832,9 +832,9 @@ def left_hand_split( Prefer RHS otherwise. This is why this function is not symmetrical with :func:`right_hand_split` which also handles optional parentheses. """ - tail_leaves: List[Leaf] = [] - body_leaves: List[Leaf] = [] - head_leaves: List[Leaf] = [] + tail_leaves: list[Leaf] = [] + body_leaves: list[Leaf] = [] + head_leaves: list[Leaf] = [] current_leaves = head_leaves matching_bracket: Optional[Leaf] = None for leaf in line.leaves: @@ -899,9 +899,9 @@ def _first_right_hand_split( _maybe_split_omitting_optional_parens to get an opinion whether to prefer splitting on the right side of an assignment statement. """ - tail_leaves: List[Leaf] = [] - body_leaves: List[Leaf] = [] - head_leaves: List[Leaf] = [] + tail_leaves: list[Leaf] = [] + body_leaves: list[Leaf] = [] + head_leaves: list[Leaf] = [] current_leaves = tail_leaves opening_bracket: Optional[Leaf] = None closing_bracket: Optional[Leaf] = None @@ -932,8 +932,8 @@ def _first_right_hand_split( and tail_leaves[0].opening_bracket is head_leaves[-1] ): inner_body_leaves = list(body_leaves) - hugged_opening_leaves: List[Leaf] = [] - hugged_closing_leaves: List[Leaf] = [] + hugged_opening_leaves: list[Leaf] = [] + hugged_closing_leaves: list[Leaf] = [] is_unpacking = body_leaves[0].type in [token.STAR, token.DOUBLESTAR] unpacking_offset: int = 1 if is_unpacking else 0 while ( @@ -1142,8 +1142,49 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None ) +def _ensure_trailing_comma( + leaves: list[Leaf], original: Line, opening_bracket: Leaf +) -> bool: + if not leaves: + return False + # Ensure a trailing comma for imports + if original.is_import: + return True + # ...and standalone function arguments + if not original.is_def: + return False + if opening_bracket.value != "(": + return False + # Don't add commas if we already have any commas + if any( + leaf.type == token.COMMA + and ( + Preview.typed_params_trailing_comma not in original.mode + or not is_part_of_annotation(leaf) + ) + for leaf in leaves + ): + return False + + # Find a leaf with a parent (comments don't have parents) + leaf_with_parent = next((leaf for leaf in leaves if leaf.parent), None) + if leaf_with_parent is None: + return True + # Don't add commas inside parenthesized return annotations + if get_annotation_type(leaf_with_parent) == "return": + return False + # Don't add commas inside PEP 604 unions + if ( + leaf_with_parent.parent + and leaf_with_parent.parent.next_sibling + and leaf_with_parent.parent.next_sibling.type == token.VBAR + ): + return False + return True + + def bracket_split_build_line( - leaves: List[Leaf], + leaves: list[Leaf], original: Line, opening_bracket: Leaf, *, @@ -1162,42 +1203,17 @@ def bracket_split_build_line( if component is _BracketSplitComponent.body: result.inside_brackets = True result.depth = result.depth + (Indentation.CONTINUATION,) - if leaves: - no_commas = ( - # Ensure a trailing comma for imports and standalone function arguments - original.is_def - # Don't add one after any comments or within type annotations - and opening_bracket.value == "(" - # Don't add one if there's already one there - and not any( - leaf.type == token.COMMA - and ( - Preview.typed_params_trailing_comma not in original.mode - or not is_part_of_annotation(leaf) - ) - for leaf in leaves - ) - # Don't add one inside parenthesized return annotations - and get_annotation_type(leaves[0]) != "return" - # Don't add one inside PEP 604 unions - and not ( - leaves[0].parent - and leaves[0].parent.next_sibling - and leaves[0].parent.next_sibling.type == token.VBAR - ) - ) - - if original.is_import or no_commas: - for i in range(len(leaves) - 1, -1, -1): - if leaves[i].type == STANDALONE_COMMENT: - continue + if _ensure_trailing_comma(leaves, original, opening_bracket): + for i in range(len(leaves) - 1, -1, -1): + if leaves[i].type == STANDALONE_COMMENT: + continue - if leaves[i].type != token.COMMA: - new_comma = Leaf(token.COMMA, ",") - leaves.insert(i + 1, new_comma) - break + if leaves[i].type != token.COMMA: + new_comma = Leaf(token.COMMA, ",") + leaves.insert(i + 1, new_comma) + break - leaves_to_track: Set[LeafID] = set() + leaves_to_track: set[LeafID] = set() if component is _BracketSplitComponent.head: leaves_to_track = get_leaves_inside_matching_brackets(leaves) # Populate the line @@ -1389,7 +1405,7 @@ def append_to_line(leaf: Leaf) -> Iterator[Line]: def normalize_invisible_parens( # noqa: C901 - node: Node, parens_after: Set[str], *, mode: Mode, features: Collection[Feature] + node: Node, parens_after: set[str], *, mode: Mode, features: Collection[Feature] ) -> None: """Make existing optional parentheses invisible or create new ones. @@ -1746,7 +1762,7 @@ def should_split_line(line: Line, opening_bracket: Leaf) -> bool: ) -def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[LeafID]]: +def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[set[LeafID]]: """Generate sets of closing bracket IDs that should be omitted in a RHS. Brackets can be omitted if the entire trailer up to and including @@ -1757,14 +1773,14 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf the one that needs to explode are omitted. """ - omit: Set[LeafID] = set() + omit: set[LeafID] = set() if not line.magic_trailing_comma: yield omit length = line.indentation_spaces() opening_bracket: Optional[Leaf] = None closing_bracket: Optional[Leaf] = None - inner_brackets: Set[LeafID] = set() + inner_brackets: set[LeafID] = set() for index, leaf, leaf_length in line.enumerate_with_length(is_reversed=True): length += leaf_length if length > line_length: @@ -1829,10 +1845,10 @@ def run_transformer( features: Collection[Feature], *, line_str: str = "", -) -> List[Line]: +) -> list[Line]: if not line_str: line_str = line_to_string(line) - result: List[Line] = [] + result: list[Line] = [] for transformed_line in transform(line, features, mode): if str(transformed_line).strip("\n") == line_str: raise CannotTransform("Line transformer returned an unchanged result") diff --git a/src/pyink/lines.py b/src/pyink/lines.py index 225ba37445b..590132fd793 100644 --- a/src/pyink/lines.py +++ b/src/pyink/lines.py @@ -2,18 +2,7 @@ import itertools import math from dataclasses import dataclass, field -from typing import ( - Callable, - Dict, - Iterator, - List, - Optional, - Sequence, - Tuple, - TypeVar, - Union, - cast, -) +from typing import Callable, Iterator, Optional, Sequence, TypeVar, Union, cast from pyink.brackets import COMMA_PRIORITY, DOT_PRIORITY, BracketTracker from pyink.mode import Mode, Preview @@ -64,10 +53,10 @@ class Line: """Holds leaves and comments. Can be printed with `str(line)`.""" mode: Mode = field(repr=False) - depth: Tuple[Indentation, ...] = field(default_factory=tuple) - leaves: List[Leaf] = field(default_factory=list) + depth: tuple[Indentation, ...] = field(default_factory=tuple) + leaves: list[Leaf] = field(default_factory=list) # keys ordered like `leaves` - comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict) + comments: dict[LeafID, list[Leaf]] = field(default_factory=dict) bracket_tracker: BracketTracker = field(default_factory=BracketTracker) inside_brackets: bool = False should_split_rhs: bool = False @@ -440,7 +429,7 @@ def append_comment(self, comment: Leaf) -> bool: self.comments.setdefault(id(last_leaf), []).append(comment) return True - def comments_after(self, leaf: Leaf) -> List[Leaf]: + def comments_after(self, leaf: Leaf) -> list[Leaf]: """Generate comments that should appear directly after `leaf`.""" return self.comments.get(id(leaf), []) @@ -473,13 +462,13 @@ def is_complex_subscript(self, leaf: Leaf) -> bool: def enumerate_with_length( self, is_reversed: bool = False - ) -> Iterator[Tuple[Index, Leaf, int]]: + ) -> Iterator[tuple[Index, Leaf, int]]: """Return an enumeration of leaves with their length. Stops prematurely on multiline strings and standalone comments. """ op = cast( - Callable[[Sequence[Leaf]], Iterator[Tuple[Index, Leaf]]], + Callable[[Sequence[Leaf]], Iterator[tuple[Index, Leaf]]], enumerate_reversed if is_reversed else enumerate, ) for index, leaf in op(self.leaves): @@ -545,11 +534,11 @@ class LinesBlock: previous_block: Optional["LinesBlock"] original_line: Line before: int = 0 - content_lines: List[str] = field(default_factory=list) + content_lines: list[str] = field(default_factory=list) after: int = 0 form_feed: bool = False - def all_lines(self) -> List[str]: + def all_lines(self) -> list[str]: empty_line = str(Line(mode=self.mode)) prefix = make_simple_prefix(self.before, self.form_feed, empty_line) return [prefix] + self.content_lines + [empty_line * self.after] @@ -568,7 +557,7 @@ class EmptyLineTracker: mode: Mode previous_line: Optional[Line] = None previous_block: Optional[LinesBlock] = None - previous_defs: List[Line] = field(default_factory=list) + previous_defs: list[Line] = field(default_factory=list) semantic_leading_comment: Optional[LinesBlock] = None def maybe_empty_lines(self, current_line: Line) -> LinesBlock: @@ -621,7 +610,7 @@ def maybe_empty_lines(self, current_line: Line) -> LinesBlock: self.previous_block = block return block - def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: # noqa: C901 + def _maybe_empty_lines(self, current_line: Line) -> tuple[int, int]: # noqa: C901 max_allowed = 1 if not current_line.depth: max_allowed = 1 if self.mode.is_pyi else 2 @@ -722,7 +711,7 @@ def _maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: # noqa: C9 def _maybe_empty_lines_for_class_or_def( # noqa: C901 self, current_line: Line, before: int, user_had_newline: bool - ) -> Tuple[int, int]: + ) -> tuple[int, int]: assert self.previous_line is not None if self.previous_line.is_decorator: @@ -806,7 +795,7 @@ def _maybe_empty_lines_for_class_or_def( # noqa: C901 return newlines, 0 -def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]: +def enumerate_reversed(sequence: Sequence[T]) -> Iterator[tuple[Index, T]]: """Like `reversed(enumerate(sequence))` if that were possible.""" index = len(sequence) - 1 for element in reversed(sequence): @@ -815,7 +804,7 @@ def enumerate_reversed(sequence: Sequence[T]) -> Iterator[Tuple[Index, T]]: def append_leaves( - new_line: Line, old_line: Line, leaves: List[Leaf], preformatted: bool = False + new_line: Line, old_line: Line, leaves: list[Leaf], preformatted: bool = False ) -> None: """ Append leaves (taken from @old_line) to @new_line, making sure to fix the @@ -872,10 +861,10 @@ def is_line_short_enough( # noqa: C901 # Depth (which is based on the existing bracket_depth concept) # is needed to determine nesting level of the MLS. # Includes special case for trailing commas. - commas: List[int] = [] # tracks number of commas per depth level + commas: list[int] = [] # tracks number of commas per depth level multiline_string: Optional[Leaf] = None # store the leaves that contain parts of the MLS - multiline_string_contexts: List[LN] = [] + multiline_string_contexts: list[LN] = [] max_level_to_update: Union[int, float] = math.inf # track the depth of the MLS for i, leaf in enumerate(line.leaves): @@ -899,7 +888,7 @@ def is_line_short_enough( # noqa: C901 if leaf.bracket_depth <= max_level_to_update and leaf.type == token.COMMA: # Inside brackets, ignore trailing comma # directly after MLS/MLS-containing expression - ignore_ctxs: List[Optional[LN]] = [None] + ignore_ctxs: list[Optional[LN]] = [None] ignore_ctxs += multiline_string_contexts if (line.inside_brackets or leaf.bracket_depth > 0) and ( i != len(line.leaves) - 1 or leaf.prev_sibling not in ignore_ctxs diff --git a/src/pyink/mode.py b/src/pyink/mode.py index d00ebd14b6e..19d3b49750c 100644 --- a/src/pyink/mode.py +++ b/src/pyink/mode.py @@ -8,7 +8,7 @@ from enum import Enum, auto from hashlib import sha256 from operator import attrgetter -from typing import Dict, Final, Literal, Set, Tuple +from typing import Final, Literal from pyink.const import DEFAULT_LINE_LENGTH @@ -26,6 +26,10 @@ class TargetVersion(Enum): PY312 = 12 PY313 = 13 + def pretty(self) -> str: + assert self.name[:2] == "PY" + return f"Python {self.name[2]}.{self.name[3:]}" + class Feature(Enum): F_STRINGS = 2 @@ -60,7 +64,7 @@ class Feature(Enum): } -VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { +VERSION_TO_FEATURES: dict[TargetVersion, set[Feature]] = { TargetVersion.PY33: {Feature.ASYNC_IDENTIFIERS}, TargetVersion.PY34: {Feature.ASYNC_IDENTIFIERS}, TargetVersion.PY35: {Feature.TRAILING_COMMA_IN_CALL, Feature.ASYNC_IDENTIFIERS}, @@ -185,7 +189,7 @@ class Feature(Enum): } -def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> bool: +def supports_feature(target_versions: set[TargetVersion], feature: Feature) -> bool: return all(feature in VERSION_TO_FEATURES[version] for version in target_versions) @@ -206,9 +210,10 @@ class Preview(Enum): docstring_check_for_newline = auto() remove_redundant_guard_parens = auto() parens_for_long_if_clauses_in_case_block = auto() + pep646_typed_star_arg_type_var_tuple = auto() -UNSTABLE_FEATURES: Set[Preview] = { +UNSTABLE_FEATURES: set[Preview] = { # Many issues, see summary in https://github.com/psf/black/issues/4042 Preview.string_processing, # See issues #3452 and #4158 @@ -253,7 +258,7 @@ class QuoteStyle(Enum): @dataclass class Mode: - target_versions: Set[TargetVersion] = field(default_factory=set) + target_versions: set[TargetVersion] = field(default_factory=set) line_length: int = DEFAULT_LINE_LENGTH string_normalization: bool = True # No effect if string_normalization is False @@ -264,14 +269,14 @@ class Mode: is_ipynb: bool = False skip_source_first_line: bool = False magic_trailing_comma: bool = True - python_cell_magics: Set[str] = field(default_factory=set) + python_cell_magics: set[str] = field(default_factory=set) preview: bool = False is_pyink: bool = False pyink_indentation: Literal[2, 4] = 4 pyink_ipynb_indentation: Literal[1, 2] = 1 - pyink_annotation_pragmas: Tuple[str, ...] = DEFAULT_ANNOTATION_PRAGMAS + pyink_annotation_pragmas: tuple[str, ...] = DEFAULT_ANNOTATION_PRAGMAS unstable: bool = False - enabled_features: Set[Preview] = field(default_factory=set) + enabled_features: set[Preview] = field(default_factory=set) def __contains__(self, feature: Preview) -> bool: """ @@ -322,6 +327,7 @@ def get_cache_key(self) -> str: str(int(self.skip_source_first_line)), str(int(self.magic_trailing_comma)), str(int(self.preview)), + str(int(self.unstable)), str(int(self.is_pyink)), str(self.pyink_indentation), str(self.pyink_ipynb_indentation), diff --git a/src/pyink/nodes.py b/src/pyink/nodes.py index 14b89dbf25f..1703f402954 100644 --- a/src/pyink/nodes.py +++ b/src/pyink/nodes.py @@ -3,18 +3,7 @@ """ import sys -from typing import ( - Final, - Generic, - Iterator, - List, - Literal, - Optional, - Set, - Tuple, - TypeVar, - Union, -) +from typing import Final, Generic, Iterator, Literal, Optional, TypeVar, Union if sys.version_info >= (3, 10): from typing import TypeGuard @@ -255,9 +244,15 @@ def whitespace(leaf: Leaf, *, complex_subscript: bool, mode: Mode) -> str: # no elif ( prevp.type == token.STAR and parent_type(prevp) == syms.star_expr - and parent_type(prevp.parent) == syms.subscriptlist + and ( + parent_type(prevp.parent) == syms.subscriptlist + or ( + Preview.pep646_typed_star_arg_type_var_tuple in mode + and parent_type(prevp.parent) == syms.tname_star + ) + ) ): - # No space between typevar tuples. + # No space between typevar tuples or unpacking them. return NO elif prevp.type in VARARGS_SPECIALS: @@ -457,7 +452,7 @@ def preceding_leaf(node: Optional[LN]) -> Optional[Leaf]: return None -def prev_siblings_are(node: Optional[LN], tokens: List[Optional[NodeType]]) -> bool: +def prev_siblings_are(node: Optional[LN], tokens: list[Optional[NodeType]]) -> bool: """Return if the `node` and its previous siblings match types against the provided list of tokens; the provided `node`has its type matched against the last element in the list. `None` can be used as the first element to declare that the start of the @@ -629,8 +624,8 @@ def is_tuple_containing_walrus(node: LN) -> bool: def is_one_sequence_between( opening: Leaf, closing: Leaf, - leaves: List[Leaf], - brackets: Tuple[int, int] = (token.LPAR, token.RPAR), + leaves: list[Leaf], + brackets: tuple[int, int] = (token.LPAR, token.RPAR), ) -> bool: """Return True if content between `opening` and `closing` is a one-sequence.""" if (opening.type, closing.type) != brackets: @@ -740,7 +735,7 @@ def is_yield(node: LN) -> bool: return False -def is_vararg(leaf: Leaf, within: Set[NodeType]) -> bool: +def is_vararg(leaf: Leaf, within: set[NodeType]) -> bool: """Return True if `leaf` is a star or double star in a vararg or kwarg. If `within` includes VARARGS_PARENTS, this applies to function signatures. @@ -1013,6 +1008,7 @@ def get_annotation_type(leaf: Leaf) -> Literal["return", "param", None]: def is_part_of_annotation(leaf: Leaf) -> bool: """Returns whether this leaf is part of a type annotation.""" + assert leaf.parent is not None return get_annotation_type(leaf) is not None diff --git a/src/pyink/output.py b/src/pyink/output.py index 7c7dd0fe14e..0dbd74e5e22 100644 --- a/src/pyink/output.py +++ b/src/pyink/output.py @@ -6,7 +6,7 @@ import json import re import tempfile -from typing import Any, List, Optional +from typing import Any, Optional from click import echo, style from mypy_extensions import mypyc_attr @@ -59,7 +59,7 @@ def ipynb_diff(a: str, b: str, a_name: str, b_name: str) -> str: _line_pattern = re.compile(r"(.*?(?:\r\n|\n|\r|$))") -def _splitlines_no_ff(source: str) -> List[str]: +def _splitlines_no_ff(source: str) -> list[str]: """Split a string into lines ignoring form feed and other chars. This mimics how the Python parser splits source code. diff --git a/src/pyink/parsing.py b/src/pyink/parsing.py index d6fb1f7f87d..e5f839e6e2b 100644 --- a/src/pyink/parsing.py +++ b/src/pyink/parsing.py @@ -5,7 +5,7 @@ import ast import sys import warnings -from typing import Iterable, Iterator, List, Set, Tuple +from typing import Collection, Iterator from pyink.mode import VERSION_TO_FEATURES, Feature, TargetVersion, supports_feature from pyink.nodes import syms @@ -21,7 +21,7 @@ class InvalidInput(ValueError): """Raised when input source code fails all parse attempts.""" -def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]: +def get_grammars(target_versions: set[TargetVersion]) -> list[Grammar]: if not target_versions: # No target_version specified, so try all grammars. return [ @@ -52,12 +52,20 @@ def get_grammars(target_versions: Set[TargetVersion]) -> List[Grammar]: return grammars -def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) -> Node: +def lib2to3_parse( + src_txt: str, target_versions: Collection[TargetVersion] = () +) -> Node: """Given a string with source, return the lib2to3 Node.""" if not src_txt.endswith("\n"): src_txt += "\n" grammars = get_grammars(set(target_versions)) + if target_versions: + max_tv = max(target_versions, key=lambda tv: tv.value) + tv_str = f" for target version {max_tv.pretty()}" + else: + tv_str = "" + errors = {} for grammar in grammars: drv = driver.Driver(grammar) @@ -73,14 +81,14 @@ def lib2to3_parse(src_txt: str, target_versions: Iterable[TargetVersion] = ()) - except IndexError: faulty_line = "" errors[grammar.version] = InvalidInput( - f"Cannot parse: {lineno}:{column}: {faulty_line}" + f"Cannot parse{tv_str}: {lineno}:{column}: {faulty_line}" ) except TokenError as te: # In edge cases these are raised; and typically don't have a "faulty_line". lineno, column = te.args[1] errors[grammar.version] = InvalidInput( - f"Cannot parse: {lineno}:{column}: {te.args[0]}" + f"Cannot parse{tv_str}: {lineno}:{column}: {te.args[0]}" ) else: @@ -115,7 +123,7 @@ class ASTSafetyError(Exception): def _parse_single_version( - src: str, version: Tuple[int, int], *, type_comments: bool + src: str, version: tuple[int, int], *, type_comments: bool ) -> ast.AST: filename = "" with warnings.catch_warnings(): @@ -151,7 +159,7 @@ def parse_ast(src: str) -> ast.AST: def _normalize(lineend: str, value: str) -> str: # To normalize, we strip any leading and trailing space from # each line... - stripped: List[str] = [i.strip() for i in value.splitlines()] + stripped: list[str] = [i.strip() for i in value.splitlines()] normalized = lineend.join(stripped) # ...and remove any blank lines at the beginning and end of # the whole string @@ -164,14 +172,14 @@ def stringify_ast(node: ast.AST) -> Iterator[str]: def _stringify_ast_with_new_parent( - node: ast.AST, parent_stack: List[ast.AST], new_parent: ast.AST + node: ast.AST, parent_stack: list[ast.AST], new_parent: ast.AST ) -> Iterator[str]: parent_stack.append(new_parent) yield from _stringify_ast(node, parent_stack) parent_stack.pop() -def _stringify_ast(node: ast.AST, parent_stack: List[ast.AST]) -> Iterator[str]: +def _stringify_ast(node: ast.AST, parent_stack: list[ast.AST]) -> Iterator[str]: if ( isinstance(node, ast.Constant) and isinstance(node.value, str) diff --git a/src/pyink/ranges.py b/src/pyink/ranges.py index 85a1486b5a8..12968e99f06 100644 --- a/src/pyink/ranges.py +++ b/src/pyink/ranges.py @@ -2,7 +2,7 @@ import difflib from dataclasses import dataclass -from typing import Collection, Iterator, List, Sequence, Set, Tuple, Union +from typing import Collection, Iterator, Sequence, Union from pyink.nodes import ( LN, @@ -18,8 +18,8 @@ from blib2to3.pgen2.token import ASYNC, NEWLINE -def parse_line_ranges(line_ranges: Sequence[str]) -> List[Tuple[int, int]]: - lines: List[Tuple[int, int]] = [] +def parse_line_ranges(line_ranges: Sequence[str]) -> list[tuple[int, int]]: + lines: list[tuple[int, int]] = [] for lines_str in line_ranges: parts = lines_str.split("-") if len(parts) != 2: @@ -40,14 +40,14 @@ def parse_line_ranges(line_ranges: Sequence[str]) -> List[Tuple[int, int]]: return lines -def is_valid_line_range(lines: Tuple[int, int]) -> bool: +def is_valid_line_range(lines: tuple[int, int]) -> bool: """Returns whether the line range is valid.""" return not lines or lines[0] <= lines[1] def sanitized_lines( - lines: Collection[Tuple[int, int]], src_contents: str -) -> Collection[Tuple[int, int]]: + lines: Collection[tuple[int, int]], src_contents: str +) -> Collection[tuple[int, int]]: """Returns the valid line ranges for the given source. This removes ranges that are entirely outside the valid lines. @@ -74,10 +74,10 @@ def sanitized_lines( def adjusted_lines( - lines: Collection[Tuple[int, int]], + lines: Collection[tuple[int, int]], original_source: str, modified_source: str, -) -> List[Tuple[int, int]]: +) -> list[tuple[int, int]]: """Returns the adjusted line ranges based on edits from the original code. This computes the new line ranges by diffing original_source and @@ -153,7 +153,7 @@ def adjusted_lines( return new_lines -def convert_unchanged_lines(src_node: Node, lines: Collection[Tuple[int, int]]) -> None: +def convert_unchanged_lines(src_node: Node, lines: Collection[tuple[int, int]]) -> None: """Converts unchanged lines to STANDALONE_COMMENT. The idea is similar to how `# fmt: on/off` is implemented. It also converts the @@ -177,7 +177,7 @@ def convert_unchanged_lines(src_node: Node, lines: Collection[Tuple[int, int]]) more formatting to pass (1). However, it's hard to get it correct when incorrect indentations are used. So we defer this to future optimizations. """ - lines_set: Set[int] = set() + lines_set: set[int] = set() for start, end in lines: lines_set.update(range(start, end + 1)) visitor = _TopLevelStatementsVisitor(lines_set) @@ -205,7 +205,7 @@ class _TopLevelStatementsVisitor(Visitor[None]): classes/functions/statements. """ - def __init__(self, lines_set: Set[int]): + def __init__(self, lines_set: set[int]): self._lines_set = lines_set def visit_simple_stmt(self, node: Node) -> Iterator[None]: @@ -249,7 +249,7 @@ def visit_suite(self, node: Node) -> Iterator[None]: _convert_node_to_standalone_comment(semantic_parent) -def _convert_unchanged_line_by_line(node: Node, lines_set: Set[int]) -> None: +def _convert_unchanged_line_by_line(node: Node, lines_set: set[int]) -> None: """Converts unchanged to STANDALONE_COMMENT line by line.""" for leaf in node.leaves(): if leaf.type != NEWLINE: @@ -261,7 +261,7 @@ def _convert_unchanged_line_by_line(node: Node, lines_set: Set[int]) -> None: # match_stmt: "match" subject_expr ':' NEWLINE INDENT case_block+ DEDENT # Here we need to check `subject_expr`. The `case_block+` will be # checked by their own NEWLINEs. - nodes_to_ignore: List[LN] = [] + nodes_to_ignore: list[LN] = [] prev_sibling = leaf.prev_sibling while prev_sibling: nodes_to_ignore.insert(0, prev_sibling) @@ -382,7 +382,7 @@ def _leaf_line_end(leaf: Leaf) -> int: return leaf.lineno + str(leaf).count("\n") -def _get_line_range(node_or_nodes: Union[LN, List[LN]]) -> Set[int]: +def _get_line_range(node_or_nodes: Union[LN, list[LN]]) -> set[int]: """Returns the line range of this node or list of nodes.""" if isinstance(node_or_nodes, list): nodes = node_or_nodes @@ -463,7 +463,7 @@ def _calculate_lines_mappings( modified_source.splitlines(keepends=True), ) matching_blocks = matcher.get_matching_blocks() - lines_mappings: List[_LinesMapping] = [] + lines_mappings: list[_LinesMapping] = [] # matching_blocks is a sequence of "same block of code ranges", see # https://docs.python.org/3/library/difflib.html#difflib.SequenceMatcher.get_matching_blocks # Each block corresponds to a _LinesMapping with is_changed_block=False, diff --git a/src/pyink/resources/pyink.schema.json b/src/pyink/resources/pyink.schema.json index f50a2fc08ce..0b604baaa62 100644 --- a/src/pyink/resources/pyink.schema.json +++ b/src/pyink/resources/pyink.schema.json @@ -90,7 +90,8 @@ "is_simple_lookup_for_doublestar_expression", "docstring_check_for_newline", "remove_redundant_guard_parens", - "parens_for_long_if_clauses_in_case_block" + "parens_for_long_if_clauses_in_case_block", + "pep646_typed_star_arg_type_var_tuple" ] }, "description": "Enable specific features included in the `--unstable` style. Requires `--preview`. No compatibility guarantees are provided on the behavior or existence of any unstable features." diff --git a/src/pyink/schema.py b/src/pyink/schema.py index 97201ee9a78..8cf5ef9a71d 100644 --- a/src/pyink/schema.py +++ b/src/pyink/schema.py @@ -1,6 +1,5 @@ import importlib.resources import json -import sys from typing import Any @@ -11,10 +10,6 @@ def get_schema(tool_name: str = "pyink") -> Any: pkg = "pyink.resources" fname = "pyink.schema.json" - if sys.version_info < (3, 9): - with importlib.resources.open_text(pkg, fname, encoding="utf-8") as f: - return json.load(f) - - schema = importlib.resources.files(pkg).joinpath(fname) # type: ignore[unreachable] + schema = importlib.resources.files(pkg).joinpath(fname) with schema.open(encoding="utf-8") as f: return json.load(f) diff --git a/src/pyink/strings.py b/src/pyink/strings.py index cfa878709e5..1a2f49d49ea 100644 --- a/src/pyink/strings.py +++ b/src/pyink/strings.py @@ -5,7 +5,7 @@ import re import sys from functools import lru_cache -from typing import Final, List, Match, Pattern, Tuple +from typing import Final, Match, Pattern from pyink._width_table import WIDTH_TABLE from pyink.mode import Quote @@ -44,7 +44,7 @@ def has_triple_quotes(string: str) -> bool: return raw_string[:3] in {'"""', "'''"} -def lines_with_leading_tabs_expanded(s: str) -> List[str]: +def lines_with_leading_tabs_expanded(s: str) -> list[str]: """ Splits string into lines and expands only leading tabs (following the normal Python rules) @@ -245,9 +245,9 @@ def normalize_string_quotes(s: str, *, preferred_quote: Quote) -> str: def normalize_fstring_quotes( quote: str, - middles: List[Leaf], + middles: list[Leaf], is_raw_fstring: bool, -) -> Tuple[List[Leaf], str]: +) -> tuple[list[Leaf], str]: """Prefer double quotes but only if it doesn't cause more escaping. Adds or removes backslashes as appropriate. diff --git a/src/pyink/trans.py b/src/pyink/trans.py index a063b9ddf75..e6209fb24ef 100644 --- a/src/pyink/trans.py +++ b/src/pyink/trans.py @@ -11,16 +11,12 @@ Callable, ClassVar, Collection, - Dict, Final, Iterable, Iterator, - List, Literal, Optional, Sequence, - Set, - Tuple, TypeVar, Union, ) @@ -68,7 +64,7 @@ class CannotTransform(Exception): ParserState = int StringID = int TResult = Result[T, CannotTransform] # (T)ransform Result -TMatchResult = TResult[List[Index]] +TMatchResult = TResult[list[Index]] SPLIT_SAFE_CHARS = frozenset(["\u3001", "\u3002", "\uff0c"]) # East Asian stops @@ -179,7 +175,7 @@ def original_is_simple_lookup_func( return True -def handle_is_simple_look_up_prev(line: Line, index: int, disallowed: Set[int]) -> bool: +def handle_is_simple_look_up_prev(line: Line, index: int, disallowed: set[int]) -> bool: """ Handling the determination of is_simple_lookup for the lines prior to the doublestar token. This is required because of the need to isolate the chained expression @@ -202,7 +198,7 @@ def handle_is_simple_look_up_prev(line: Line, index: int, disallowed: Set[int]) def handle_is_simple_lookup_forward( - line: Line, index: int, disallowed: Set[int] + line: Line, index: int, disallowed: set[int] ) -> bool: """ Handling decision is_simple_lookup for the lines behind the doublestar token. @@ -227,7 +223,7 @@ def handle_is_simple_lookup_forward( return True -def is_expression_chained(chained_leaves: List[Leaf]) -> bool: +def is_expression_chained(chained_leaves: list[Leaf]) -> bool: """ Function to determine if the variable is a chained call. (e.g., foo.lookup, foo().lookup, (foo.lookup())) will be recognized as chained call) @@ -307,7 +303,7 @@ def do_match(self, line: Line) -> TMatchResult: @abstractmethod def do_transform( - self, line: Line, string_indices: List[int] + self, line: Line, string_indices: list[int] ) -> Iterator[TResult[Line]]: """ Yields: @@ -397,8 +393,8 @@ class CustomSplitMapMixin: the resultant substrings go over the configured max line length. """ - _Key: ClassVar = Tuple[StringID, str] - _CUSTOM_SPLIT_MAP: ClassVar[Dict[_Key, Tuple[CustomSplit, ...]]] = defaultdict( + _Key: ClassVar = tuple[StringID, str] + _CUSTOM_SPLIT_MAP: ClassVar[dict[_Key, tuple[CustomSplit, ...]]] = defaultdict( tuple ) @@ -422,7 +418,7 @@ def add_custom_splits( key = self._get_key(string) self._CUSTOM_SPLIT_MAP[key] = tuple(custom_splits) - def pop_custom_splits(self, string: str) -> List[CustomSplit]: + def pop_custom_splits(self, string: str) -> list[CustomSplit]: """Custom Split Map Getter Method Returns: @@ -497,7 +493,7 @@ def do_match(self, line: Line) -> TMatchResult: break i += 1 - if not is_part_of_annotation(leaf) and not contains_comment: + if not contains_comment and not is_part_of_annotation(leaf): string_indices.append(idx) # Advance to the next non-STRING leaf. @@ -521,7 +517,7 @@ def do_match(self, line: Line) -> TMatchResult: return TErr("This line has no strings that need merging.") def do_transform( - self, line: Line, string_indices: List[int] + self, line: Line, string_indices: list[int] ) -> Iterator[TResult[Line]]: new_line = line @@ -552,7 +548,7 @@ def do_transform( @staticmethod def _remove_backslash_line_continuation_chars( - line: Line, string_indices: List[int] + line: Line, string_indices: list[int] ) -> TResult[Line]: """ Merge strings that were split across multiple lines using @@ -593,7 +589,7 @@ def _remove_backslash_line_continuation_chars( return Ok(new_line) def _merge_string_group( - self, line: Line, string_indices: List[int] + self, line: Line, string_indices: list[int] ) -> TResult[Line]: """ Merges string groups (i.e. set of adjacent strings). @@ -612,7 +608,7 @@ def _merge_string_group( is_valid_index = is_valid_index_factory(LL) # A dict of {string_idx: tuple[num_of_strings, string_leaf]}. - merged_string_idx_dict: Dict[int, Tuple[int, Leaf]] = {} + merged_string_idx_dict: dict[int, tuple[int, Leaf]] = {} for string_idx in string_indices: vresult = self._validate_msg(line, string_idx) if isinstance(vresult, Err): @@ -648,8 +644,8 @@ def _merge_string_group( return Ok(new_line) def _merge_one_string_group( - self, LL: List[Leaf], string_idx: int, is_valid_index: Callable[[int], bool] - ) -> Tuple[int, Leaf]: + self, LL: list[Leaf], string_idx: int, is_valid_index: Callable[[int], bool] + ) -> tuple[int, Leaf]: """ Merges one string group where the first string in the group is `LL[string_idx]`. @@ -1021,11 +1017,11 @@ def do_match(self, line: Line) -> TMatchResult: return TErr("This line has no strings wrapped in parens.") def do_transform( - self, line: Line, string_indices: List[int] + self, line: Line, string_indices: list[int] ) -> Iterator[TResult[Line]]: LL = line.leaves - string_and_rpar_indices: List[int] = [] + string_and_rpar_indices: list[int] = [] for string_idx in string_indices: string_parser = StringParser() rpar_idx = string_parser.parse(LL, string_idx) @@ -1048,7 +1044,7 @@ def do_transform( ) def _transform_to_new_line( - self, line: Line, string_and_rpar_indices: List[int] + self, line: Line, string_and_rpar_indices: list[int] ) -> Line: LL = line.leaves @@ -1301,7 +1297,7 @@ def _get_max_string_length(self, line: Line, string_idx: int) -> int: return max_string_length @staticmethod - def _prefer_paren_wrap_match(LL: List[Leaf]) -> Optional[int]: + def _prefer_paren_wrap_match(LL: list[Leaf]) -> Optional[int]: """ Returns: string_idx such that @LL[string_idx] is equal to our target (i.e. @@ -1346,14 +1342,14 @@ def _prefer_paren_wrap_match(LL: List[Leaf]) -> Optional[int]: return None -def iter_fexpr_spans(s: str) -> Iterator[Tuple[int, int]]: +def iter_fexpr_spans(s: str) -> Iterator[tuple[int, int]]: """ Yields spans corresponding to expressions in a given f-string. Spans are half-open ranges (left inclusive, right exclusive). Assumes the input string is a valid f-string, but will not crash if the input string is invalid. """ - stack: List[int] = [] # our curly paren stack + stack: list[int] = [] # our curly paren stack i = 0 while i < len(s): if s[i] == "{": @@ -1516,7 +1512,7 @@ def do_splitter_match(self, line: Line) -> TMatchResult: return Ok([string_idx]) def do_transform( - self, line: Line, string_indices: List[int] + self, line: Line, string_indices: list[int] ) -> Iterator[TResult[Line]]: LL = line.leaves assert len(string_indices) == 1, ( @@ -1618,7 +1614,7 @@ def more_splits_should_be_made() -> bool: else: return str_width(rest_value) > max_last_string_column() - string_line_results: List[Ok[Line]] = [] + string_line_results: list[Ok[Line]] = [] while more_splits_should_be_made(): if use_custom_breakpoints: # Custom User Split (manual) @@ -1747,7 +1743,7 @@ def more_splits_should_be_made() -> bool: last_line.comments = line.comments.copy() yield Ok(last_line) - def _iter_nameescape_slices(self, string: str) -> Iterator[Tuple[Index, Index]]: + def _iter_nameescape_slices(self, string: str) -> Iterator[tuple[Index, Index]]: """ Yields: All ranges of @string which, if @string were to be split there, @@ -1778,7 +1774,7 @@ def _iter_nameescape_slices(self, string: str) -> Iterator[Tuple[Index, Index]]: raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!") yield begin, end - def _iter_fexpr_slices(self, string: str) -> Iterator[Tuple[Index, Index]]: + def _iter_fexpr_slices(self, string: str) -> Iterator[tuple[Index, Index]]: """ Yields: All ranges of @string which, if @string were to be split there, @@ -1789,8 +1785,8 @@ def _iter_fexpr_slices(self, string: str) -> Iterator[Tuple[Index, Index]]: return yield from iter_fexpr_spans(string) - def _get_illegal_split_indices(self, string: str) -> Set[Index]: - illegal_indices: Set[Index] = set() + def _get_illegal_split_indices(self, string: str) -> set[Index]: + illegal_indices: set[Index] = set() iterators = [ self._iter_fexpr_slices(string), self._iter_nameescape_slices(string), @@ -1918,7 +1914,7 @@ def _normalize_f_string(self, string: str, prefix: str) -> str: else: return string - def _get_string_operator_leaves(self, leaves: Iterable[Leaf]) -> List[Leaf]: + def _get_string_operator_leaves(self, leaves: Iterable[Leaf]) -> list[Leaf]: LL = list(leaves) string_op_leaves = [] @@ -2028,7 +2024,7 @@ def do_splitter_match(self, line: Line) -> TMatchResult: return TErr("This line does not contain any non-atomic strings.") @staticmethod - def _return_match(LL: List[Leaf]) -> Optional[int]: + def _return_match(LL: list[Leaf]) -> Optional[int]: """ Returns: string_idx such that @LL[string_idx] is equal to our target (i.e. @@ -2053,7 +2049,7 @@ def _return_match(LL: List[Leaf]) -> Optional[int]: return None @staticmethod - def _else_match(LL: List[Leaf]) -> Optional[int]: + def _else_match(LL: list[Leaf]) -> Optional[int]: """ Returns: string_idx such that @LL[string_idx] is equal to our target (i.e. @@ -2080,7 +2076,7 @@ def _else_match(LL: List[Leaf]) -> Optional[int]: return None @staticmethod - def _assert_match(LL: List[Leaf]) -> Optional[int]: + def _assert_match(LL: list[Leaf]) -> Optional[int]: """ Returns: string_idx such that @LL[string_idx] is equal to our target (i.e. @@ -2115,7 +2111,7 @@ def _assert_match(LL: List[Leaf]) -> Optional[int]: return None @staticmethod - def _assign_match(LL: List[Leaf]) -> Optional[int]: + def _assign_match(LL: list[Leaf]) -> Optional[int]: """ Returns: string_idx such that @LL[string_idx] is equal to our target (i.e. @@ -2162,7 +2158,7 @@ def _assign_match(LL: List[Leaf]) -> Optional[int]: return None @staticmethod - def _dict_or_lambda_match(LL: List[Leaf]) -> Optional[int]: + def _dict_or_lambda_match(LL: list[Leaf]) -> Optional[int]: """ Returns: string_idx such that @LL[string_idx] is equal to our target (i.e. @@ -2201,7 +2197,7 @@ def _dict_or_lambda_match(LL: List[Leaf]) -> Optional[int]: return None def do_transform( - self, line: Line, string_indices: List[int] + self, line: Line, string_indices: list[int] ) -> Iterator[TResult[Line]]: LL = line.leaves assert len(string_indices) == 1, ( @@ -2367,7 +2363,7 @@ class StringParser: DONE: Final = 8 # Lookup Table for Next State - _goto: Final[Dict[Tuple[ParserState, NodeType], ParserState]] = { + _goto: Final[dict[tuple[ParserState, NodeType], ParserState]] = { # A string trailer may start with '.' OR '%'. (START, token.DOT): DOT, (START, token.PERCENT): PERCENT, @@ -2396,7 +2392,7 @@ def __init__(self) -> None: self._state = self.START self._unmatched_lpars = 0 - def parse(self, leaves: List[Leaf], string_idx: int) -> int: + def parse(self, leaves: list[Leaf], string_idx: int) -> int: """ Pre-conditions: * @leaves[@string_idx].type == token.STRING diff --git a/tests/data/cases/funcdef_return_type_trailing_comma.py b/tests/data/cases/funcdef_return_type_trailing_comma.py index 9b9b9c673de..14fd763d9d1 100644 --- a/tests/data/cases/funcdef_return_type_trailing_comma.py +++ b/tests/data/cases/funcdef_return_type_trailing_comma.py @@ -142,6 +142,7 @@ def SimplePyFn( Buffer[UInt8, 2], Buffer[UInt8, 2], ]: ... + # output # normal, short, function definition def foo(a, b) -> tuple[int, float]: ... diff --git a/tests/data/cases/function_trailing_comma.py b/tests/data/cases/function_trailing_comma.py index 92f46e27516..63cf3999c2e 100644 --- a/tests/data/cases/function_trailing_comma.py +++ b/tests/data/cases/function_trailing_comma.py @@ -60,6 +60,64 @@ def func() -> ((also_super_long_type_annotation_that_may_cause_an_AST_related_cr argument1, (one, two,), argument4, argument5, argument6 ) +def foo() -> ( + # comment inside parenthesised return type + int +): + ... + +def foo() -> ( + # comment inside parenthesised return type + # more + int + # another +): + ... + +def foo() -> ( + # comment inside parenthesised new union return type + int | str | bytes +): + ... + +def foo() -> ( + # comment inside plain tuple +): + pass + +def foo(arg: (# comment with non-return annotation + int + # comment with non-return annotation +)): + pass + +def foo(arg: (# comment with non-return annotation + int | range | memoryview + # comment with non-return annotation +)): + pass + +def foo(arg: (# only before + int +)): + pass + +def foo(arg: ( + int + # only after +)): + pass + +variable: ( # annotation + because + # why not +) + +variable: ( + because + # why not +) + # output def f( @@ -176,3 +234,75 @@ def func() -> ( argument5, argument6, ) + + +def foo() -> ( + # comment inside parenthesised return type + int +): ... + + +def foo() -> ( + # comment inside parenthesised return type + # more + int + # another +): ... + + +def foo() -> ( + # comment inside parenthesised new union return type + int + | str + | bytes +): ... + + +def foo() -> ( + # comment inside plain tuple +): + pass + + +def foo( + arg: ( # comment with non-return annotation + int + # comment with non-return annotation + ), +): + pass + + +def foo( + arg: ( # comment with non-return annotation + int + | range + | memoryview + # comment with non-return annotation + ), +): + pass + + +def foo(arg: int): # only before + pass + + +def foo( + arg: ( + int + # only after + ), +): + pass + + +variable: ( # annotation + because + # why not +) + +variable: ( + because + # why not +) diff --git a/tests/data/cases/preview_pep646_typed_star_arg_type_var_tuple.py b/tests/data/cases/preview_pep646_typed_star_arg_type_var_tuple.py new file mode 100644 index 00000000000..fb79e9983b1 --- /dev/null +++ b/tests/data/cases/preview_pep646_typed_star_arg_type_var_tuple.py @@ -0,0 +1,8 @@ +# flags: --minimum-version=3.11 --preview + + +def fn(*args: *tuple[*A, B]) -> None: + pass + + +fn.__annotations__ diff --git a/tests/test_black.py b/tests/test_black.py index 7a00fbd31fe..3695ec4ab83 100644 --- a/tests/test_black.py +++ b/tests/test_black.py @@ -12,7 +12,7 @@ import types from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager, redirect_stderr -from dataclasses import replace +from dataclasses import fields, replace from io import BytesIO from pathlib import Path, WindowsPath from platform import system @@ -44,7 +44,7 @@ from pyink import re_compile_maybe_verbose as compile_pattern from pyink.cache import FileData, get_cache_dir, get_cache_file from pyink.debug import DebugVisitor -from pyink.mode import Mode, Preview +from pyink.mode import Mode, Preview, Quote, QuoteStyle from pyink.output import color_diff, diff from pyink.parsing import ASTSafetyError from pyink.report import Report @@ -906,6 +906,9 @@ def test_get_features_used(self) -> None: self.check_features_used("a[*b]", {Feature.VARIADIC_GENERICS}) self.check_features_used("a[x, *y(), z] = t", {Feature.VARIADIC_GENERICS}) self.check_features_used("def fn(*args: *T): pass", {Feature.VARIADIC_GENERICS}) + self.check_features_used( + "def fn(*args: *tuple[*T]): pass", {Feature.VARIADIC_GENERICS} + ) self.check_features_used("with a: pass", set()) self.check_features_used("with a, b: pass", set()) @@ -2154,8 +2157,9 @@ def test_cache_single_file_already_cached(self) -> None: @event_loop() def test_cache_multiple_files(self) -> None: mode = DEFAULT_MODE - with cache_dir() as workspace, patch( - "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor + with ( + cache_dir() as workspace, + patch("concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor), ): one = (workspace / "one.py").resolve() one.write_text("print('hello')", encoding="utf-8") @@ -2177,9 +2181,10 @@ def test_no_cache_when_writeback_diff(self, color: bool) -> None: with cache_dir() as workspace: src = (workspace / "test.py").resolve() src.write_text("print('hello')", encoding="utf-8") - with patch.object(pyink.Cache, "read") as read_cache, patch.object( - pyink.Cache, "write" - ) as write_cache: + with ( + patch.object(pyink.Cache, "read") as read_cache, + patch.object(pyink.Cache, "write") as write_cache, + ): cmd = [str(src), "--diff"] if color: cmd.append("--color") @@ -2308,8 +2313,9 @@ def test_write_cache_creates_directory_if_needed(self) -> None: @event_loop() def test_failed_formatting_does_not_get_cached(self) -> None: mode = DEFAULT_MODE - with cache_dir() as workspace, patch( - "concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor + with ( + cache_dir() as workspace, + patch("concurrent.futures.ProcessPoolExecutor", new=ThreadPoolExecutor), ): failing = (workspace / "failing.py").resolve() failing.write_text("not actually python", encoding="utf-8") @@ -2341,6 +2347,49 @@ def test_read_cache_line_lengths(self) -> None: two = pyink.Cache.read(short_mode) assert two.is_changed(path) + def test_cache_key(self) -> None: + # Test that all members of the mode enum affect the cache key. + for field in fields(Mode): + values: List[Any] + if field.name == "target_versions": + values = [ + {TargetVersion.PY312}, + {TargetVersion.PY313}, + ] + elif field.name == "python_cell_magics": + values = [{"magic1"}, {"magic2"}] + elif field.name == "enabled_features": + # If you are looking to remove one of these features, just + # replace it with any other feature. + values = [ + {Preview.docstring_check_for_newline}, + {Preview.hex_codes_in_unicode_sequences}, + ] + elif field.type is Quote: + values = list(Quote) + elif field.type is QuoteStyle: + values = list(QuoteStyle) + elif field.name == "pyink_indentation": + values = [2, 4] + elif field.name == "pyink_ipynb_indentation": + values = [1, 2] + elif field.name == "pyink_annotation_pragmas": + values = [ + ("type: ignore",), + ("noqa", "pylint:", "pytype: ignore", "@param"), + ] + elif field.type is bool: + values = [True, False] + elif field.type is int: + values = [1, 2] + else: + raise AssertionError( + f"Unhandled field type: {field.type} for field {field.name}" + ) + modes = [replace(DEFAULT_MODE, **{field.name: value}) for value in values] + keys = [mode.get_cache_key() for mode in modes] + assert len(set(keys)) == len(modes) + def assert_collected_sources( src: Sequence[Union[str, Path]], diff --git a/tests/test_format.py b/tests/test_format.py index 576a2473934..0fd0190fc98 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -87,4 +87,6 @@ def test_patma_invalid() -> None: with pytest.raises(pyink.parsing.InvalidInput) as exc_info: assert_format(source, expected, mode, minimum_version=(3, 10)) - exc_info.match("Cannot parse: 10:11") + exc_info.match( + "Cannot parse for target version Python 3.10: 10:11: case a := b:" + )