Skip to content

Commit

Permalink
Use builtin generics (#4458)
Browse files Browse the repository at this point in the history

uvx ruff check --output-format concise src --target-version py39 --select UP006 --fix --unsafe-fixes
uvx ruff check --output-format concise src --target-version py39 --select F401 --fix
plus some manual fixups
  • Loading branch information
hauntsaninja authored Sep 16, 2024
1 parent 2a45cec commit 8fb2add
Show file tree
Hide file tree
Showing 27 changed files with 298 additions and 377 deletions.
72 changes: 34 additions & 38 deletions src/black/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,13 @@
from typing import (
Any,
Collection,
Dict,
Generator,
Iterator,
List,
MutableMapping,
Optional,
Pattern,
Sequence,
Set,
Sized,
Tuple,
Union,
)

Expand Down Expand Up @@ -176,7 +172,7 @@ def read_pyproject_toml(
"line-ranges", "Cannot use line-ranges in the pyproject.toml file."
)

default_map: Dict[str, Any] = {}
default_map: dict[str, Any] = {}
if ctx.default_map:
default_map.update(ctx.default_map)
default_map.update(config)
Expand All @@ -186,9 +182,9 @@ def read_pyproject_toml(


def spellcheck_pyproject_toml_keys(
ctx: click.Context, config_keys: List[str], config_file_path: str
ctx: click.Context, config_keys: list[str], config_file_path: str
) -> None:
invalid_keys: List[str] = []
invalid_keys: list[str] = []
available_config_options = {param.name for param in ctx.command.params}
for key in config_keys:
if key not in available_config_options:
Expand All @@ -202,8 +198,8 @@ def spellcheck_pyproject_toml_keys(


def target_version_option_callback(
c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[TargetVersion]:
c: click.Context, p: Union[click.Option, click.Parameter], v: tuple[str, ...]
) -> list[TargetVersion]:
"""Compute the target versions from a --target-version flag.
This is its own function because mypy couldn't infer the type correctly
Expand All @@ -213,8 +209,8 @@ def target_version_option_callback(


def enable_unstable_feature_callback(
c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...]
) -> List[Preview]:
c: click.Context, p: Union[click.Option, click.Parameter], v: tuple[str, ...]
) -> list[Preview]:
"""Compute the features from an --enable-unstable-feature flag."""
return [Preview[val] for val in v]

Expand Down Expand Up @@ -519,7 +515,7 @@ def main( # noqa: C901
ctx: click.Context,
code: Optional[str],
line_length: int,
target_version: List[TargetVersion],
target_version: list[TargetVersion],
check: bool,
diff: bool,
line_ranges: Sequence[str],
Expand All @@ -533,7 +529,7 @@ def main( # noqa: C901
skip_magic_trailing_comma: bool,
preview: bool,
unstable: bool,
enable_unstable_feature: List[Preview],
enable_unstable_feature: list[Preview],
quiet: bool,
verbose: bool,
required_version: Optional[str],
Expand All @@ -543,7 +539,7 @@ def main( # noqa: C901
force_exclude: Optional[Pattern[str]],
stdin_filename: Optional[str],
workers: Optional[int],
src: Tuple[str, ...],
src: tuple[str, ...],
config: Optional[str],
) -> None:
"""The uncompromising code formatter."""
Expand Down Expand Up @@ -643,7 +639,7 @@ def main( # noqa: C901
enabled_features=set(enable_unstable_feature),
)

lines: List[Tuple[int, int]] = []
lines: list[tuple[int, int]] = []
if line_ranges:
if ipynb:
err("Cannot use --line-ranges with ipynb files.")
Expand Down Expand Up @@ -733,7 +729,7 @@ def main( # noqa: C901
def get_sources(
*,
root: Path,
src: Tuple[str, ...],
src: tuple[str, ...],
quiet: bool,
verbose: bool,
include: Pattern[str],
Expand All @@ -742,14 +738,14 @@ def get_sources(
force_exclude: Optional[Pattern[str]],
report: "Report",
stdin_filename: Optional[str],
) -> Set[Path]:
) -> set[Path]:
"""Compute the set of files to be formatted."""
sources: Set[Path] = set()
sources: set[Path] = set()

assert root.is_absolute(), f"INTERNAL ERROR: `root` must be absolute but is {root}"
using_default_exclude = exclude is None
exclude = re_compile_maybe_verbose(DEFAULT_EXCLUDES) if exclude is None else exclude
gitignore: Optional[Dict[Path, PathSpec]] = None
gitignore: Optional[dict[Path, PathSpec]] = None
root_gitignore = get_gitignore(root)

for s in src:
Expand Down Expand Up @@ -841,7 +837,7 @@ def reformat_code(
mode: Mode,
report: Report,
*,
lines: Collection[Tuple[int, int]] = (),
lines: Collection[tuple[int, int]] = (),
) -> None:
"""
Reformat and print out `content` without spawning child processes.
Expand Down Expand Up @@ -874,7 +870,7 @@ def reformat_one(
mode: Mode,
report: "Report",
*,
lines: Collection[Tuple[int, int]] = (),
lines: Collection[tuple[int, int]] = (),
) -> None:
"""Reformat a single file under `src` without spawning child processes.
Expand Down Expand Up @@ -930,7 +926,7 @@ def format_file_in_place(
write_back: WriteBack = WriteBack.NO,
lock: Any = None, # multiprocessing.Manager().Lock() is some crazy proxy
*,
lines: Collection[Tuple[int, int]] = (),
lines: Collection[tuple[int, int]] = (),
) -> bool:
"""Format file under `src` path. Return True if changed.
Expand Down Expand Up @@ -997,7 +993,7 @@ def format_stdin_to_stdout(
content: Optional[str] = None,
write_back: WriteBack = WriteBack.NO,
mode: Mode,
lines: Collection[Tuple[int, int]] = (),
lines: Collection[tuple[int, int]] = (),
) -> bool:
"""Format file on stdin. Return True if changed.
Expand Down Expand Up @@ -1048,7 +1044,7 @@ def check_stability_and_equivalence(
dst_contents: str,
*,
mode: Mode,
lines: Collection[Tuple[int, int]] = (),
lines: Collection[tuple[int, int]] = (),
) -> None:
"""Perform stability and equivalence checks.
Expand All @@ -1065,7 +1061,7 @@ def format_file_contents(
*,
fast: bool,
mode: Mode,
lines: Collection[Tuple[int, int]] = (),
lines: Collection[tuple[int, int]] = (),
) -> FileContent:
"""Reformat contents of a file and return new contents.
Expand Down Expand Up @@ -1196,7 +1192,7 @@ def format_ipynb_string(src_contents: str, *, fast: bool, mode: Mode) -> FileCon


def format_str(
src_contents: str, *, mode: Mode, lines: Collection[Tuple[int, int]] = ()
src_contents: str, *, mode: Mode, lines: Collection[tuple[int, int]] = ()
) -> str:
"""Reformat a string and return new contents.
Expand Down Expand Up @@ -1243,10 +1239,10 @@ def f(


def _format_str_once(
src_contents: str, *, mode: Mode, lines: Collection[Tuple[int, int]] = ()
src_contents: str, *, mode: Mode, lines: Collection[tuple[int, int]] = ()
) -> str:
src_node = lib2to3_parse(src_contents.lstrip(), mode.target_versions)
dst_blocks: List[LinesBlock] = []
dst_blocks: list[LinesBlock] = []
if mode.target_versions:
versions = mode.target_versions
else:
Expand Down Expand Up @@ -1296,7 +1292,7 @@ def _format_str_once(
return "".join(dst_contents)


def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:
def decode_bytes(src: bytes) -> tuple[FileContent, Encoding, NewLine]:
"""Return a tuple of (decoded_contents, encoding, newline).
`newline` is either CRLF or LF but `decoded_contents` is decoded with
Expand All @@ -1314,8 +1310,8 @@ def decode_bytes(src: bytes) -> Tuple[FileContent, Encoding, NewLine]:


def get_features_used( # noqa: C901
node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[Feature]:
node: Node, *, future_imports: Optional[set[str]] = None
) -> set[Feature]:
"""Return a set of (relatively) new Python features used in this file.
Currently looking for:
Expand All @@ -1333,7 +1329,7 @@ def get_features_used( # noqa: C901
- except* clause;
- variadic generics;
"""
features: Set[Feature] = set()
features: set[Feature] = set()
if future_imports:
features |= {
FUTURE_FLAG_TO_FEATURE[future_import]
Expand Down Expand Up @@ -1471,20 +1467,20 @@ def _contains_asexpr(node: Union[Node, Leaf]) -> bool:


def detect_target_versions(
node: Node, *, future_imports: Optional[Set[str]] = None
) -> Set[TargetVersion]:
node: Node, *, future_imports: Optional[set[str]] = None
) -> set[TargetVersion]:
"""Detect the version to target based on the nodes used."""
features = get_features_used(node, future_imports=future_imports)
return {
version for version in TargetVersion if features <= VERSION_TO_FEATURES[version]
}


def get_future_imports(node: Node) -> Set[str]:
def get_future_imports(node: Node) -> set[str]:
"""Return a set of __future__ imports in the file."""
imports: Set[str] = set()
imports: set[str] = set()

def get_imports_from_children(children: List[LN]) -> Generator[str, None, None]:
def get_imports_from_children(children: list[LN]) -> Generator[str, None, None]:
for child in children:
if isinstance(child, Leaf):
if child.type == token.NAME:
Expand Down Expand Up @@ -1571,7 +1567,7 @@ def assert_equivalent(src: str, dst: str) -> None:


def assert_stable(
src: str, dst: str, mode: Mode, *, lines: Collection[Tuple[int, int]] = ()
src: str, dst: str, mode: Mode, *, lines: Collection[tuple[int, int]] = ()
) -> None:
"""Raise AssertionError if `dst` reformats differently the second time."""
if lines:
Expand Down
4 changes: 2 additions & 2 deletions src/black/_width_table.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Generated by make_width_table.py
# wcwidth 0.2.6
# Unicode 15.0.0
from typing import Final, List, Tuple
from typing import Final

WIDTH_TABLE: Final[List[Tuple[int, int, int]]] = [
WIDTH_TABLE: Final[list[tuple[int, int, int]]] = [
(0, 0, 0),
(1, 31, -1),
(127, 159, -1),
Expand Down
14 changes: 7 additions & 7 deletions src/black/brackets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Builds on top of nodes.py to track brackets."""

from dataclasses import dataclass, field
from typing import Dict, Final, Iterable, List, Optional, Sequence, Set, Tuple, Union
from typing import Final, Iterable, Optional, Sequence, Union

from black.nodes import (
BRACKET,
Expand Down Expand Up @@ -60,12 +60,12 @@ class BracketTracker:
"""Keeps track of brackets on a line."""

depth: int = 0
bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
delimiters: Dict[LeafID, Priority] = field(default_factory=dict)
bracket_match: dict[tuple[Depth, NodeType], Leaf] = field(default_factory=dict)
delimiters: dict[LeafID, Priority] = field(default_factory=dict)
previous: Optional[Leaf] = None
_for_loop_depths: List[int] = field(default_factory=list)
_lambda_argument_depths: List[int] = field(default_factory=list)
invisible: List[Leaf] = field(default_factory=list)
_for_loop_depths: list[int] = field(default_factory=list)
_lambda_argument_depths: list[int] = field(default_factory=list)
invisible: list[Leaf] = field(default_factory=list)

def mark(self, leaf: Leaf) -> None:
"""Mark `leaf` with bracket-related metadata. Keep track of delimiters.
Expand Down Expand Up @@ -353,7 +353,7 @@ def max_delimiter_priority_in_atom(node: LN) -> Priority:
return 0


def get_leaves_inside_matching_brackets(leaves: Sequence[Leaf]) -> Set[LeafID]:
def get_leaves_inside_matching_brackets(leaves: Sequence[Leaf]) -> set[LeafID]:
"""Return leaves that are inside matching brackets.
The input `leaves` can have non-matching brackets at the head or tail parts.
Expand Down
14 changes: 7 additions & 7 deletions src/black/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, Iterable, NamedTuple, Set, Tuple
from typing import Iterable, NamedTuple

from platformdirs import user_cache_dir

Expand Down Expand Up @@ -55,7 +55,7 @@ def get_cache_file(mode: Mode) -> Path:
class Cache:
mode: Mode
cache_file: Path
file_data: Dict[str, FileData] = field(default_factory=dict)
file_data: dict[str, FileData] = field(default_factory=dict)

@classmethod
def read(cls, mode: Mode) -> Self:
Expand All @@ -76,7 +76,7 @@ def read(cls, mode: Mode) -> Self:

with cache_file.open("rb") as fobj:
try:
data: Dict[str, Tuple[float, int, str]] = pickle.load(fobj)
data: dict[str, tuple[float, int, str]] = pickle.load(fobj)
file_data = {k: FileData(*v) for k, v in data.items()}
except (pickle.UnpicklingError, ValueError, IndexError):
return cls(mode, cache_file)
Expand Down Expand Up @@ -114,14 +114,14 @@ def is_changed(self, source: Path) -> bool:
return True
return False

def filtered_cached(self, sources: Iterable[Path]) -> Tuple[Set[Path], Set[Path]]:
def filtered_cached(self, sources: Iterable[Path]) -> tuple[set[Path], set[Path]]:
"""Split an iterable of paths in `sources` into two sets.
The first contains paths of files that modified on disk or are not in the
cache. The other contains paths to non-modified files.
"""
changed: Set[Path] = set()
done: Set[Path] = set()
changed: set[Path] = set()
done: set[Path] = set()
for src in sources:
if self.is_changed(src):
changed.add(src)
Expand All @@ -140,7 +140,7 @@ def write(self, sources: Iterable[Path]) -> None:
dir=str(self.cache_file.parent), delete=False
) as f:
# We store raw tuples in the cache because it's faster.
data: Dict[str, Tuple[float, int, str]] = {
data: dict[str, tuple[float, int, str]] = {
k: (*v,) for k, v in self.file_data.items()
}
pickle.dump(data, f, protocol=4)
Expand Down
Loading

0 comments on commit 8fb2add

Please sign in to comment.