diff --git a/src/deadsimple/__init__.py b/src/deadsimple/__init__.py index bb34414..681c8e9 100644 --- a/src/deadsimple/__init__.py +++ b/src/deadsimple/__init__.py @@ -1,3 +1,4 @@ from .resolve import resolve, Depends from .lazy import Lazy, LazyResolver from .exceptions import GeneratorClosureException, InvalidGeneratorFactoryExcpetion +from .context import ResolutionContextManager, resolve_open diff --git a/src/deadsimple/context.py b/src/deadsimple/context.py new file mode 100644 index 0000000..99001d8 --- /dev/null +++ b/src/deadsimple/context.py @@ -0,0 +1,60 @@ +from typing import TypeVar, Generic, Callable +from dataclasses import dataclass + +from .resolve import ( + _Context, + _resolve, + _close_open_generators, + get_context, + _create_context, +) + + +TContextValue = TypeVar("TContextValue") + + +@dataclass +class ResolutionContextManager(Generic[TContextValue]): + + factory: Callable[..., TContextValue] + context: _Context + + def __enter__(self) -> TContextValue: + + try: + value = _resolve(self.factory, self.context) + except Exception as ex: + if len(self.context.open_generators) > 0: + _close_open_generators(self.context, ex) + raise ex + + return value + + def __exit__(self, exc_type, exc_value, exc_traceback): + if len(self.context.open_generators) > 0: + _close_open_generators(self.context, exc_value) + + +def resolve_open( + factory: Callable[..., TContextValue], + overrides: dict = None, +) -> ResolutionContextManager[TContextValue]: + + context = _create_context(overrides) + + resolved_cache = {} + + context = _Context( + resolved_cache=resolved_cache, + open_generators=[], + ) + + resolved_cache[get_context] = context + + if overrides is not None: + resolved_cache.update(overrides) + + return ResolutionContextManager[TContextValue]( + factory=factory, + context=context, + ) diff --git a/src/deadsimple/resolve.py b/src/deadsimple/resolve.py index 25daec2..66a0cbd 100644 --- a/src/deadsimple/resolve.py +++ b/src/deadsimple/resolve.py @@ -24,7 +24,9 @@ class _Context: def get_context() -> _Context: - raise NotImplementedError("get_context is an abstract dependency") + raise NotImplementedError( + "get_context is an abstract dependency and is only avaliable during injection" + ) _resolver_cache: Dict[Callable, Callable[[_Context], Any]] = {} @@ -32,17 +34,7 @@ def get_context() -> _Context: def resolve(factory: Callable[..., TReturn], overrides: dict = None) -> TReturn: - if overrides is not None: - resolved_cache = overrides - else: - resolved_cache = {} - - context = _Context( - resolved_cache=resolved_cache, - open_generators=[], - ) - - resolved_cache[get_context] = context + context = _create_context(overrides) resolve_exception = None @@ -58,6 +50,23 @@ def resolve(factory: Callable[..., TReturn], overrides: dict = None) -> TReturn: return value +def _create_context(overrides: dict = None) -> _Context: + + resolved_cache = {} + + context = _Context( + resolved_cache=resolved_cache, + open_generators=[], + ) + + resolved_cache[get_context] = context + + if overrides is not None: + resolved_cache.update(overrides) + + return context + + def _close_open_generators(context: _Context, resolve_exception: Optional[Exception]): exceptions = None diff --git a/src/tests/resolve_open_test.py b/src/tests/resolve_open_test.py new file mode 100644 index 0000000..ae6d3c0 --- /dev/null +++ b/src/tests/resolve_open_test.py @@ -0,0 +1,95 @@ +from dataclasses import dataclass +from unittest.mock import Mock, call + +from pytest import raises + +from deadsimple import Depends, resolve_open + + +@dataclass +class _TestDepA: + value: str + + +@dataclass +class _TestDepB: + dep_a: _TestDepA + + +def test_context_closes_only_after_exit(): + + mock = Mock() + + def get_dep_a(): + mock.enter_a() + yield _TestDepA(value="some value") + mock.exit_a() + + with resolve_open(get_dep_a) as dep_a: + assert dep_a.value == "some value" + mock.inside_scope() + + expected_calls = [ + call.enter_a(), + call.inside_scope(), + call.exit_a(), + ] + + assert mock.mock_calls == expected_calls + + +def test_generators_closed_on_exception_in_context(): + + mock = Mock() + + def get_dep_a(): + mock.enter_a() + yield _TestDepA(value="some value") + mock.exit_a() + + class _Exception(Exception): + pass + + with raises(_Exception): + with resolve_open(get_dep_a) as dep_a: + mock.inside_scope() + assert dep_a.value == "some value" + raise _Exception() + + expected_calls = [ + call.enter_a(), + call.inside_scope(), + call.exit_a(), + ] + + assert mock.mock_calls == expected_calls + + +def test_generators_closed_on_exception_in_dependency(): + + mock = Mock() + + class _Exception(Exception): + pass + + def get_dep_a(): + mock.enter_a() + yield _TestDepA(value="some value") + mock.exit_a() + + def get_dep_b(dep_a=Depends(get_dep_a)): + assert dep_a.value == "some value" + mock.enter_b() + raise _Exception() + + with raises(_Exception): + with resolve_open(get_dep_b): + mock.inside_scope() + + expected_calls = [ + call.enter_a(), + call.enter_b(), + call.exit_a(), + ] + + assert mock.mock_calls == expected_calls