Skip to content

Commit

Permalink
fix: prevent dupe calls, alternative (#546)
Browse files Browse the repository at this point in the history
* use variant generation to simplify discover callbacks

* add test checking that order is not important

* remove obsolete file

* fix tests

* fix type_registered

---------

Co-authored-by: Talley Lambert <[email protected]>
  • Loading branch information
Czaki and tlambert03 authored Oct 4, 2023
1 parent edf00f7 commit 21ee452
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 60 deletions.
137 changes: 78 additions & 59 deletions src/magicgui/type_map/_type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import datetime
import inspect
import itertools
import os
import pathlib
import sys
Expand Down Expand Up @@ -366,6 +367,65 @@ def _validate_return_callback(func: Callable) -> None:
_T = TypeVar("_T", bound=type)


def _register_type_callback(
resolved_type: _T,
return_callback: ReturnCallback | None = None,
) -> list[type]:
modified_callbacks = []
if return_callback is None:
return []
_validate_return_callback(return_callback)
# if the type is a Union, add the callback to all of the types in the union
# (except NoneType)
if get_origin(resolved_type) is Union:
for type_per in _generate_union_variants(resolved_type):
if return_callback not in _RETURN_CALLBACKS[type_per]:
_RETURN_CALLBACKS[type_per].append(return_callback)
modified_callbacks.append(type_per)

for t in get_args(resolved_type):
if not _is_none_type(t) and return_callback not in _RETURN_CALLBACKS[t]:
_RETURN_CALLBACKS[t].append(return_callback)
modified_callbacks.append(t)
elif return_callback not in _RETURN_CALLBACKS[resolved_type]:
_RETURN_CALLBACKS[resolved_type].append(return_callback)
modified_callbacks.append(resolved_type)
return modified_callbacks


def _register_widget(
resolved_type: _T,
widget_type: WidgetRef | None = None,
**options: Any,
) -> WidgetTuple | None:
_options = cast(dict, options)

previous_widget = _TYPE_DEFS.get(resolved_type)

if "choices" in _options:
_TYPE_DEFS[resolved_type] = (widgets.ComboBox, _options)
if widget_type is not None:
warnings.warn(
"Providing `choices` overrides `widget_type`. Categorical widget "
f"will be used for type {resolved_type}",
stacklevel=2,
)
elif widget_type is not None:
if not isinstance(widget_type, (str, WidgetProtocol)) and not (
inspect.isclass(widget_type) and issubclass(widget_type, widgets.Widget)
):
raise TypeError(
'"widget_type" must be either a string, WidgetProtocol, or '
"Widget subclass"
)
_TYPE_DEFS[resolved_type] = (widget_type, _options)
elif "bind" in _options:
# if we're binding a value to this parameter, it doesn't matter what type
# of ValueWidget is used... it usually won't be shown
_TYPE_DEFS[resolved_type] = (widgets.EmptyWidget, _options)
return previous_widget


@overload
def register_type(
type_: _T,
Expand Down Expand Up @@ -435,43 +495,11 @@ def register_type(
"must be provided."
)

def _deco(type_: _T) -> _T:
resolved_type = resolve_single_type(type_)
if return_callback is not None:
_validate_return_callback(return_callback)
# if the type is a Union, add the callback to all of the types in the union
# (except NoneType)
if get_origin(resolved_type) is Union:
for t in get_args(resolved_type):
if not _is_none_type(t):
_RETURN_CALLBACKS[t].append(return_callback)
else:
_RETURN_CALLBACKS[resolved_type].append(return_callback)

_options = cast(dict, options)

if "choices" in _options:
_TYPE_DEFS[resolved_type] = (widgets.ComboBox, _options)
if widget_type is not None:
warnings.warn(
"Providing `choices` overrides `widget_type`. Categorical widget "
f"will be used for type {resolved_type}",
stacklevel=2,
)
elif widget_type is not None:
if not isinstance(widget_type, (str, WidgetProtocol)) and not (
inspect.isclass(widget_type) and issubclass(widget_type, widgets.Widget)
):
raise TypeError(
'"widget_type" must be either a string, WidgetProtocol, or '
"Widget subclass"
)
_TYPE_DEFS[resolved_type] = (widget_type, _options)
elif "bind" in _options:
# if we're binding a value to this parameter, it doesn't matter what type
# of ValueWidget is used... it usually won't be shown
_TYPE_DEFS[resolved_type] = (widgets.EmptyWidget, _options)
return type_
def _deco(type__: _T) -> _T:
resolved_type = resolve_single_type(type__)
_register_type_callback(resolved_type, return_callback)
_register_widget(resolved_type, widget_type, **options)
return type__

return _deco if type_ is None else _deco(type_)

Expand Down Expand Up @@ -507,23 +535,19 @@ def type_registered(
"""
resolved_type = resolve_single_type(type_)

# check if return_callback is already registered
rc_was_present = return_callback in _RETURN_CALLBACKS.get(resolved_type, [])
# store any previous widget_type and options for this type
prev_type_def: WidgetTuple | None = _TYPE_DEFS.get(resolved_type, None)
resolved_type = register_type(
resolved_type,
widget_type=widget_type,
return_callback=return_callback,
**options,
)

revert_list = _register_type_callback(resolved_type, return_callback)
prev_type_def = _register_widget(resolved_type, widget_type, **options)

new_type_def: WidgetTuple | None = _TYPE_DEFS.get(resolved_type, None)
try:
yield
finally:
# restore things to before the context
if return_callback is not None and not rc_was_present:
_RETURN_CALLBACKS[resolved_type].remove(return_callback)
if return_callback is not None: # this if is only for mypy
for return_callback_type in revert_list:
_RETURN_CALLBACKS[return_callback_type].remove(return_callback)

if _TYPE_DEFS.get(resolved_type, None) is not new_type_def:
warnings.warn("Type definition changed during context", stacklevel=2)
Expand All @@ -537,9 +561,6 @@ def type_registered(
def type2callback(type_: type) -> list[ReturnCallback]:
"""Return any callbacks that have been registered for ``type_``.
Note that if the return type is X, then the callbacks registered for Optional[X]
will be returned also be returned.
Parameters
----------
type_ : type
Expand All @@ -555,7 +576,7 @@ def type2callback(type_: type) -> list[ReturnCallback]:

# look for direct hits ...
# if it's an Optional, we need to look for the type inside the Optional
_, type_ = _is_optional(resolve_single_type(type_))
type_ = resolve_single_type(type_)
if type_ in _RETURN_CALLBACKS:
return _RETURN_CALLBACKS[type_]

Expand All @@ -566,10 +587,8 @@ def type2callback(type_: type) -> list[ReturnCallback]:
return []


def _is_optional(type_: Any) -> tuple[bool, type]:
# TODO: this function is too similar to _type_optional above... need to combine
if get_origin(type_) is Union:
args = get_args(type_)
if len(args) == 2 and any(_is_none_type(i) for i in args):
return True, next(i for i in args if not _is_none_type(i))
return False, type_
def _generate_union_variants(type_: Any) -> Iterator[type]:
type_args = get_args(type_)
for i in range(2, len(type_args) + 1):
for per in itertools.combinations(type_args, i):
yield cast(type, Union[per])
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,12 @@ def always_qapp(qapp):
for w in qapp.topLevelWidgets():
w.close()
w.deleteLater()


@pytest.fixture(autouse=True, scope="function")
def _clean_return_callbacks():
from magicgui.type_map._type_map import _RETURN_CALLBACKS

yield

_RETURN_CALLBACKS.clear()
36 changes: 35 additions & 1 deletion tests/test_magicgui.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import inspect
from enum import Enum
from typing import NewType, Optional
from typing import NewType, Optional, Union
from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -901,3 +901,37 @@ def func_optional(a: bool) -> ReturnType:
mock.reset_mock()
func_optional(a=False)
mock.assert_called_once_with(func_optional, None, ReturnType)


@pytest.mark.parametrize("optional", [True, False])
def test_no_duplication_call(optional):
mock = Mock()
mock2 = Mock()

NewInt = NewType("NewInt", int)
register_type(Optional[NewInt], return_callback=mock)
register_type(NewInt, return_callback=mock)
register_type(NewInt, return_callback=mock2)
ReturnType = Optional[NewInt] if optional else NewInt

@magicgui
def func() -> ReturnType:
return NewInt(1)

func()

mock.assert_called_once()
assert mock2.call_count == (not optional)


def test_no_order():
mock = Mock()

register_type(Union[int, None], return_callback=mock)

@magicgui
def func() -> Union[None, int]:
return 1

func()
mock.assert_called_once()
35 changes: 35 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,41 @@ def test_type_registered_warns():
assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit)


def test_type_registered_optional_callbacks():
assert not _RETURN_CALLBACKS[int]
assert not _RETURN_CALLBACKS[Optional[int]]

@magicgui
def func1(a: int) -> int:
return a

@magicgui
def func2(a: int) -> Optional[int]:
return a

mock1 = Mock()
mock2 = Mock()
mock3 = Mock()

register_type(int, return_callback=mock2)

with type_registered(Optional[int], return_callback=mock1):
func1(1)
mock1.assert_called_once_with(func1, 1, int)
mock1.reset_mock()
func2(2)
mock1.assert_called_once_with(func2, 2, Optional[int])
mock1.reset_mock()
mock2.assert_called_once_with(func1, 1, int)
assert _RETURN_CALLBACKS[int] == [mock2, mock1]
assert _RETURN_CALLBACKS[Optional[int]] == [mock1]
register_type(Optional[int], return_callback=mock3)
assert _RETURN_CALLBACKS[Optional[int]] == [mock1, mock3]

assert _RETURN_CALLBACKS[Optional[int]] == [mock3]
assert _RETURN_CALLBACKS[int] == [mock2, mock3]


def test_pick_widget_literal():
from typing import Literal

Expand Down

0 comments on commit 21ee452

Please sign in to comment.