Skip to content

Commit

Permalink
added resolve open functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
mastern2k3 committed Feb 23, 2022
1 parent 39f8a48 commit 9408553
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 12 deletions.
1 change: 1 addition & 0 deletions src/deadsimple/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .resolve import resolve, Depends
from .lazy import Lazy, LazyResolver
from .exceptions import GeneratorClosureException, InvalidGeneratorFactoryExcpetion
from .context import ResolutionContextManager, resolve_open
60 changes: 60 additions & 0 deletions src/deadsimple/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import TypeVar, Generic, Callable
from dataclasses import dataclass

from .resolve import (
_Context,
_resolve,
_close_open_generators,
get_context,
_create_context,
)


TContextValue = TypeVar("TContextValue")


@dataclass
class ResolutionContextManager(Generic[TContextValue]):

factory: Callable[..., TContextValue]
context: _Context

def __enter__(self) -> TContextValue:

try:
value = _resolve(self.factory, self.context)
except Exception as ex:
if len(self.context.open_generators) > 0:
_close_open_generators(self.context, ex)
raise ex

return value

def __exit__(self, exc_type, exc_value, exc_traceback):
if len(self.context.open_generators) > 0:
_close_open_generators(self.context, exc_value)


def resolve_open(
factory: Callable[..., TContextValue],
overrides: dict = None,
) -> ResolutionContextManager[TContextValue]:

context = _create_context(overrides)

resolved_cache = {}

context = _Context(
resolved_cache=resolved_cache,
open_generators=[],
)

resolved_cache[get_context] = context

if overrides is not None:
resolved_cache.update(overrides)

return ResolutionContextManager[TContextValue](
factory=factory,
context=context,
)
33 changes: 21 additions & 12 deletions src/deadsimple/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,17 @@ class _Context:


def get_context() -> _Context:
raise NotImplementedError("get_context is an abstract dependency")
raise NotImplementedError(
"get_context is an abstract dependency and is only avaliable during injection"
)


_resolver_cache: Dict[Callable, Callable[[_Context], Any]] = {}


def resolve(factory: Callable[..., TReturn], overrides: dict = None) -> TReturn:

if overrides is not None:
resolved_cache = overrides
else:
resolved_cache = {}

context = _Context(
resolved_cache=resolved_cache,
open_generators=[],
)

resolved_cache[get_context] = context
context = _create_context(overrides)

resolve_exception = None

Expand All @@ -58,6 +50,23 @@ def resolve(factory: Callable[..., TReturn], overrides: dict = None) -> TReturn:
return value


def _create_context(overrides: dict = None) -> _Context:

resolved_cache = {}

context = _Context(
resolved_cache=resolved_cache,
open_generators=[],
)

resolved_cache[get_context] = context

if overrides is not None:
resolved_cache.update(overrides)

return context


def _close_open_generators(context: _Context, resolve_exception: Optional[Exception]):

exceptions = None
Expand Down
95 changes: 95 additions & 0 deletions src/tests/resolve_open_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from dataclasses import dataclass
from unittest.mock import Mock, call

from pytest import raises

from deadsimple import Depends, resolve_open


@dataclass
class _TestDepA:
value: str


@dataclass
class _TestDepB:
dep_a: _TestDepA


def test_context_closes_only_after_exit():

mock = Mock()

def get_dep_a():
mock.enter_a()
yield _TestDepA(value="some value")
mock.exit_a()

with resolve_open(get_dep_a) as dep_a:
assert dep_a.value == "some value"
mock.inside_scope()

expected_calls = [
call.enter_a(),
call.inside_scope(),
call.exit_a(),
]

assert mock.mock_calls == expected_calls


def test_generators_closed_on_exception_in_context():

mock = Mock()

def get_dep_a():
mock.enter_a()
yield _TestDepA(value="some value")
mock.exit_a()

class _Exception(Exception):
pass

with raises(_Exception):
with resolve_open(get_dep_a) as dep_a:
mock.inside_scope()
assert dep_a.value == "some value"
raise _Exception()

expected_calls = [
call.enter_a(),
call.inside_scope(),
call.exit_a(),
]

assert mock.mock_calls == expected_calls


def test_generators_closed_on_exception_in_dependency():

mock = Mock()

class _Exception(Exception):
pass

def get_dep_a():
mock.enter_a()
yield _TestDepA(value="some value")
mock.exit_a()

def get_dep_b(dep_a=Depends(get_dep_a)):
assert dep_a.value == "some value"
mock.enter_b()
raise _Exception()

with raises(_Exception):
with resolve_open(get_dep_b):
mock.inside_scope()

expected_calls = [
call.enter_a(),
call.enter_b(),
call.exit_a(),
]

assert mock.mock_calls == expected_calls

0 comments on commit 9408553

Please sign in to comment.