Skip to content

Commit

Permalink
fix: support ND quick constructs being mistakenly passed in
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Schreiner <[email protected]>
  • Loading branch information
henryiii committed Aug 25, 2023
1 parent fcd40e7 commit de7afb0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ files = "src"
python_version = "3.8"
strict = true
show_error_codes = true
enable_error_code = ["ignore-without-code", "truthy-bool"]
enable_error_code = ["ignore-without-code", "truthy-bool", "redundant-expr"]
warn_unreachable = true

[[tool.mypy.overrides]]
module = [
Expand Down
36 changes: 23 additions & 13 deletions src/hist/basehist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import operator
import typing
import warnings
from typing import Any, Callable, Iterator, Mapping, Sequence, Tuple, Union
from typing import Any, Callable, Generator, Iterator, Mapping, Sequence, Tuple, Union

import boost_histogram as bh
import histoprint
Expand Down Expand Up @@ -39,6 +39,7 @@ def __lt__(self, __other: Any) -> bool:
]
IndexingWithMapping = Union[InnerIndexing, Mapping[Union[int, str], InnerIndexing]]
IndexingExpr = Union[IndexingWithMapping, Tuple[IndexingWithMapping, ...]]
AxisTypes = Union[AxisProtocol, Tuple[int, float, float]]


# Workaround for bug in mplhep
Expand All @@ -53,12 +54,22 @@ def _proc_kw_for_lw(kwargs: Mapping[str, Any]) -> dict[str, Any]:
}


def process_mistaken_quick_construct(
axes: Sequence[AxisTypes | hist.quick_construct.ConstructProxy],
) -> Generator[AxisTypes, None, None]:
for ax in axes:
if isinstance(ax, hist.quick_construct.ConstructProxy):
yield from ax.axes
else:
yield ax


class BaseHist(bh.Histogram, metaclass=MetaConstructor, family=hist):
__slots__ = ()

def __init__(
self,
*args: AxisProtocol | Storage | str | tuple[int, float, float],
*in_args: AxisTypes | Storage | str,
storage: Storage | str | None = None,
metadata: Any = None,
data: np.typing.NDArray[Any] | None = None,
Expand All @@ -73,17 +84,16 @@ def __init__(
self.name = name
self.label = label

if args and storage is None and isinstance(args[-1], (Storage, str)):
storage = args[-1]
args = args[:-1]
args: tuple[AxisTypes, ...]

if in_args and storage is None and isinstance(in_args[-1], (Storage, str)):
storage = in_args[-1]
args = in_args[:-1] # type: ignore[assignment]
else:
args = in_args # type: ignore[assignment]

# Support raw Quick Construct being accidentally passed in
args = [ # type: ignore[assignment]
a.axes[0] # type: ignore[union-attr, union-attr, union-attr, union-attr]
if isinstance(a, hist.quick_construct.ConstructProxy) and len(a.axes) == 1
else a
for a in args
]
args = tuple(ax for ax in process_mistaken_quick_construct(args))

if isinstance(storage, str):
storage_str = storage.title()
Expand All @@ -93,7 +103,7 @@ def __init__(
storage_str = "WeightedMean"
storage = getattr(bh.storage, storage_str)()
elif isinstance(storage, type):
msg = f"Please use '{storage.__name__}()' instead of '{storage.__name__}'"
msg = f"Please use '{storage.__name__}()' instead of '{storage.__name__}'" # type: ignore[unreachable]
warnings.warn(msg, stacklevel=2)
storage = storage()

Expand Down Expand Up @@ -222,7 +232,7 @@ def fill(
"""

data_dict = {
self._name_to_index(k) if isinstance(k, str) else k: v
self._name_to_index(k) if isinstance(k, str) else k: v # type: ignore[redundant-expr]
for k, v in kwargs.items()
}

Expand Down

0 comments on commit de7afb0

Please sign in to comment.