From 54d494856c6220dba9749bfb40e4da481efec939 Mon Sep 17 00:00:00 2001 From: Andrey Rakhmatullin Date: Thu, 27 Jul 2023 20:35:41 +0400 Subject: [PATCH] Add get_generic_param(). --- tests/test_pages.py | 27 +++++++++++++++- tests/test_utils.py | 78 +++++++++++++++++++++++++++++++++++++++++++-- web_poet/_typing.py | 24 -------------- web_poet/pages.py | 33 +++++++++++++++---- web_poet/rules.py | 3 +- web_poet/utils.py | 25 ++++++++++++++- 6 files changed, 153 insertions(+), 37 deletions(-) delete mode 100644 web_poet/_typing.py diff --git a/tests/test_pages.py b/tests/test_pages.py index 9e865835..315569b9 100644 --- a/tests/test_pages.py +++ b/tests/test_pages.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Generic, List, Optional, TypeVar import attrs import pytest @@ -232,6 +232,31 @@ class SubclassStrict(BasePage, Returns[Item]): await page2.to_item() +def test_returns_inheritance() -> None: + @attrs.define + class MyItem: + name: str + + class BasePage(ItemPage[MyItem]): + @field + def name(self): + return "hello" + + MetadataT = TypeVar("MetadataT") + + class HasMetadata(Generic[MetadataT]): + pass + + class DummyMetadata: + pass + + class Page(BasePage, HasMetadata[DummyMetadata]): + pass + + page = Page() + assert page.item_cls is MyItem + + @pytest.mark.asyncio async def test_extractor(book_list_html_response) -> None: @attrs.define diff --git a/tests/test_utils.py b/tests/test_utils.py index c6f52733..9094fdbf 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,12 +2,17 @@ import inspect import random import warnings -from typing import Any +from typing import Any, Generic, TypeVar from unittest import mock import pytest -from web_poet.utils import _create_deprecated_class, cached_method, ensure_awaitable +from web_poet.utils import ( + _create_deprecated_class, + cached_method, + ensure_awaitable, + get_generic_param, +) class SomeBaseClass: @@ -466,3 +471,72 @@ async def n_called(self): foo.n_called(), ) assert results == [1, 1, 1, 1, 1] + + +ItemT = TypeVar("ItemT") + + +class Item: + pass + + +class Item2: + pass + + +class MyGeneric(Generic[ItemT]): + pass + + +class MyGeneric2(Generic[ItemT]): + pass + + +class Base(MyGeneric[ItemT]): + pass + + +class BaseSpecialized(MyGeneric[Item]): + pass + + +class BaseAny(MyGeneric): + pass + + +class Derived(Base): + pass + + +class Specialized(BaseSpecialized): + pass + + +class SpecializedAdditionalClass(BaseSpecialized, Item2): + pass + + +class SpecializedTwice(BaseSpecialized, Base[Item2]): + pass + + +class SpecializedTwoGenerics(MyGeneric2[Item2], BaseSpecialized): + pass + + +@pytest.mark.parametrize( + ["cls", "param"], + [ + (MyGeneric, None), + (Base, None), + (BaseAny, None), + (Derived, None), + (BaseSpecialized, Item), + (Specialized, Item), + (SpecializedAdditionalClass, Item), + (SpecializedTwice, Item2), + (SpecializedTwoGenerics, Item), + ], +) +def test_get_generic_param(cls, param) -> None: + assert get_generic_param(cls, expected=MyGeneric) == param diff --git a/web_poet/_typing.py b/web_poet/_typing.py deleted file mode 100644 index d7a2c984..00000000 --- a/web_poet/_typing.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Utilities for typing""" -import typing - - -def is_generic_alias(obj) -> bool: - for attr_name in ["GenericAlias", "_GenericAlias"]: - if hasattr(typing, attr_name): - if isinstance(obj, getattr(typing, attr_name)): - return True - return False - - -def get_generic_parameter(cls): - for base in getattr(cls, "__orig_bases__", []): - if is_generic_alias(base): - args = typing.get_args(base) - return args[0] - - -def get_item_cls(cls, default=None): - param = get_generic_parameter(cls) - if param is None or isinstance(param, typing.TypeVar): # class is not parametrized - return default - return param diff --git a/web_poet/pages.py b/web_poet/pages.py index e53a585b..e3cd3dca 100644 --- a/web_poet/pages.py +++ b/web_poet/pages.py @@ -1,17 +1,21 @@ import abc import inspect -import typing from contextlib import suppress from functools import wraps +from typing import Any, Generic, Optional, TypeVar, overload import attr import parsel -from web_poet._typing import get_item_cls from web_poet.fields import FieldsMixin, item_from_fields from web_poet.mixins import ResponseShortcutsMixin, SelectorShortcutsMixin from web_poet.page_inputs import HttpResponse -from web_poet.utils import CallableT, _create_deprecated_class, cached_method +from web_poet.utils import ( + CallableT, + _create_deprecated_class, + cached_method, + get_generic_param, +) class Injectable(abc.ABC, FieldsMixin): @@ -35,25 +39,40 @@ class Injectable(abc.ABC, FieldsMixin): Injectable.register(type(None)) -def is_injectable(cls: typing.Any) -> bool: +def is_injectable(cls: Any) -> bool: """Return True if ``cls`` is a class which inherits from :class:`~.Injectable`.""" return isinstance(cls, type) and issubclass(cls, Injectable) -ItemT = typing.TypeVar("ItemT") +ItemT = TypeVar("ItemT") -class Returns(typing.Generic[ItemT]): +class Returns(Generic[ItemT]): """Inherit from this generic mixin to change the item class used by :class:`~.ItemPage`""" @property - def item_cls(self) -> typing.Type[ItemT]: + def item_cls(self) -> type: """Item class""" return get_item_cls(self.__class__, default=dict) +@overload +def get_item_cls(cls: type, default: type) -> type: + ... + + +@overload +def get_item_cls(cls: type, default: None) -> Optional[type]: + ... + + +def get_item_cls(cls: type, default: Optional[type] = None) -> Optional[type]: + param = get_generic_param(cls, Returns) + return param or default + + _NOT_SET = object() diff --git a/web_poet/rules.py b/web_poet/rules.py index 53f3efe0..3337bac2 100644 --- a/web_poet/rules.py +++ b/web_poet/rules.py @@ -22,9 +22,8 @@ import attrs from url_matcher import Patterns, URLMatcher -from web_poet._typing import get_item_cls from web_poet.page_inputs.url import _Url -from web_poet.pages import ItemPage +from web_poet.pages import ItemPage, get_item_cls from web_poet.utils import _create_deprecated_class, as_list, str_to_pattern Strings = Union[str, Iterable[str]] diff --git a/web_poet/utils.py b/web_poet/utils.py index 4eccc997..49d7b8c0 100644 --- a/web_poet/utils.py +++ b/web_poet/utils.py @@ -1,9 +1,10 @@ import inspect import weakref +from collections import deque from collections.abc import Iterable from functools import lru_cache, partial, wraps from types import MethodType -from typing import Any, Callable, List, Optional, TypeVar, Union +from typing import Any, Callable, List, Optional, Tuple, TypeVar, Union, get_args from warnings import warn import packaging.version @@ -273,3 +274,25 @@ def str_to_pattern(url_pattern: Union[str, Patterns]) -> Patterns: if isinstance(url_pattern, Patterns): return url_pattern return Patterns([url_pattern]) + + +def get_generic_param( + cls: type, expected: Union[type, Tuple[type, ...]] +) -> Optional[type]: + """Search the base classes recursively breadth-first for a generic class and return its param. + + Returns the param of the first found class that is a subclass of ``expected``. + """ + visited = set() + queue = deque([cls]) + while queue: + node = queue.popleft() + visited.add(node) + for base in getattr(node, "__orig_bases__", []): + origin = getattr(base, "__origin__", None) + if origin and issubclass(origin, expected): + result = get_args(base)[0] + if not isinstance(result, TypeVar): + return result + queue.append(base) + return None