From cf062abe04c9cad1957c650fa214e014bdf5ab82 Mon Sep 17 00:00:00 2001 From: Henry Schreiner Date: Fri, 25 Aug 2023 15:09:40 -0400 Subject: [PATCH] fix: support ND quick constructs being mistakenly passed in (#528) Adding some mypy codes and it helped work out this issue. Signed-off-by: Henry Schreiner --- pyproject.toml | 3 ++- src/hist/basehist.py | 36 +++++++++++++++++++++++------------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e8c2f643..d14fe6d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -132,7 +132,8 @@ files = "src" python_version = "3.8" strict = true show_error_codes = true -enable_error_code = ["ignore-without-code", "truthy-bool"] +enable_error_code = ["ignore-without-code", "truthy-bool", "redundant-expr"] +warn_unreachable = true [[tool.mypy.overrides]] module = [ diff --git a/src/hist/basehist.py b/src/hist/basehist.py index de19f9e8..9834996a 100644 --- a/src/hist/basehist.py +++ b/src/hist/basehist.py @@ -4,7 +4,7 @@ import operator import typing import warnings -from typing import Any, Callable, Iterator, Mapping, Sequence, Tuple, Union +from typing import Any, Callable, Generator, Iterator, Mapping, Sequence, Tuple, Union import boost_histogram as bh import histoprint @@ -39,6 +39,7 @@ def __lt__(self, __other: Any) -> bool: ] IndexingWithMapping = Union[InnerIndexing, Mapping[Union[int, str], InnerIndexing]] IndexingExpr = Union[IndexingWithMapping, Tuple[IndexingWithMapping, ...]] +AxisTypes = Union[AxisProtocol, Tuple[int, float, float]] # Workaround for bug in mplhep @@ -53,12 +54,22 @@ def _proc_kw_for_lw(kwargs: Mapping[str, Any]) -> dict[str, Any]: } +def process_mistaken_quick_construct( + axes: Sequence[AxisTypes | hist.quick_construct.ConstructProxy], +) -> Generator[AxisTypes, None, None]: + for ax in axes: + if isinstance(ax, hist.quick_construct.ConstructProxy): + yield from ax.axes + else: + yield ax + + class BaseHist(bh.Histogram, metaclass=MetaConstructor, family=hist): __slots__ = () def __init__( self, - *args: AxisProtocol | Storage | str | tuple[int, float, float], + *in_args: AxisTypes | Storage | str, storage: Storage | str | None = None, metadata: Any = None, data: np.typing.NDArray[Any] | None = None, @@ -73,17 +84,16 @@ def __init__( self.name = name self.label = label - if args and storage is None and isinstance(args[-1], (Storage, str)): - storage = args[-1] - args = args[:-1] + args: tuple[AxisTypes, ...] + + if in_args and storage is None and isinstance(in_args[-1], (Storage, str)): + storage = in_args[-1] + args = in_args[:-1] # type: ignore[assignment] + else: + args = in_args # type: ignore[assignment] # Support raw Quick Construct being accidentally passed in - args = [ # type: ignore[assignment] - a.axes[0] # type: ignore[union-attr, union-attr, union-attr, union-attr] - if isinstance(a, hist.quick_construct.ConstructProxy) and len(a.axes) == 1 - else a - for a in args - ] + args = tuple(ax for ax in process_mistaken_quick_construct(args)) if isinstance(storage, str): storage_str = storage.title() @@ -93,7 +103,7 @@ def __init__( storage_str = "WeightedMean" storage = getattr(bh.storage, storage_str)() elif isinstance(storage, type): - msg = f"Please use '{storage.__name__}()' instead of '{storage.__name__}'" + msg = f"Please use '{storage.__name__}()' instead of '{storage.__name__}'" # type: ignore[unreachable] warnings.warn(msg, stacklevel=2) storage = storage() @@ -222,7 +232,7 @@ def fill( """ data_dict = { - self._name_to_index(k) if isinstance(k, str) else k: v + self._name_to_index(k) if isinstance(k, str) else k: v # type: ignore[redundant-expr] for k, v in kwargs.items() }