From 8eae1ba511ea40b2462088eeb0e8a78582d0c179 Mon Sep 17 00:00:00 2001 From: Jelmer Date: Fri, 15 Jul 2022 17:15:43 +0200 Subject: [PATCH 1/3] Allow components to have __init__ functions --- zookeeper/core/component.py | 368 ++++++++++++++++++++----------- zookeeper/core/component_test.py | 221 +++++++++++++++---- zookeeper/core/field_test.py | 7 +- zookeeper/core/utils.py | 21 ++ 4 files changed, 443 insertions(+), 174 deletions(-) diff --git a/zookeeper/core/component.py b/zookeeper/core/component.py index dccfa79..acf6926 100644 --- a/zookeeper/core/component.py +++ b/zookeeper/core/component.py @@ -78,12 +78,100 @@ class C: import functools import inspect +import logging from typing import AbstractSet, Any, Dict, Iterator, List, Optional, Tuple, Type from zookeeper.core import utils from zookeeper.core.factory_registry import FACTORY_REGISTRY from zookeeper.core.field import ComponentField, Field -from zookeeper.core.utils import ConfigurationError +from zookeeper.core.utils import ( + ConfigurationError, + configuration_mode, + in_configuration_mode, +) + + +########################## +# A component `__init__` # +########################## +class ComponentInit: + """ + We wrap both the original __init__ and the __component_init__ in this descriptor. + If we would simply store e.g. `__original__init` on every component, we + would run into problems with resolution of calls to `super().__init__()`, since they + would trigger `__component_init__` on the parent class again, with the same value + for the `instance` argument. + + By using a descriptor, we can effectively create a different `__component_init__` + for every class, and thus make sure that `super().__init__()` refers to the correct + function. + """ + + def __init__(self, original_init): + self.original_init = original_init + + @staticmethod + def __component_init__(instance, **kwargs): + """Accepts keyword-arguments corresponding to fields defined on the component.""" + + # Use the `kwargs` to set field values. + for name, value in kwargs.items(): + if name not in instance.__component_fields__: + raise TypeError( + "Keyword arguments passed to component `__init__` must correspond to " + f"component fields. Received non-matching argument '{name}'." + ) + if utils.is_component_instance(value): + if value.__component_configured__: + raise ValueError( + "Sub-component instances passed to the `__init__` method of a " + "component must not already be configured. Received configured " + f"component argument '{name}={repr(value)}'." + ) + # Set the component parent correctly if the value being passed in + # does not already have a parent. + if value.__component_parent__ is None: + value.__component_parent__ = instance + + # This will contain default field values from the fields defined on this + # instance. + instance.__component_default_field_values__ = {} + + # Save a shallow-clone of the arguments. + instance.__component_instantiated_field_values__ = {**kwargs} + + # This will contain configured field values that apply to this instance, if + # any, which override everything else. + instance.__component_configured_field_values__ = {} + + # This will contain the names of every field of this instance and all + # component ancestors for which a value is defined. More names will be added + # during configuration. + instance.__component_fields_with_values_in_scope__ = set( + field.name + for field in instance.__component_fields__.values() + if field.has_default + ) | set(kwargs) + + def __get__(self, instance, type=None): + """`instance` is the instance on which `__init__` is called. If the instance + has not yet been configured, we defer to `__component_init__`. If it has been + configured, we defer to `self.original_init` instead to execute the user's code. + """ + if instance is None: + # This happens when we call `__init__` on the class, e.g. `A.__init__`. + # In that case, we want to just access the descriptor for testing purposes. + return self + + if not instance.__component_configured__: + f = functools.partial(self.__component_init__, instance) + f = functools.update_wrapper(f, self.__component_init__) + else: + f = functools.partial(self.original_init, instance) + f = functools.update_wrapper(f, self.original_init) + + return f + ################################### # Component class method wrappers # @@ -104,6 +192,39 @@ def _type_check_and_maybe_cache(instance, field: Field, result: Any) -> None: object.__setattr__(instance, field.name, result) +def _wrap_init(component_cls: Type) -> None: + # If `__init__` is a ComponentInit, we don't want to call that. + # This means we subclass another component, in which case we want to call the + # original init of the parent class. + if isinstance(component_cls.__init__, ComponentInit): + original_init = component_cls.__init__.original_init + # Otherwise, keep track of any user `__init__`so we can call it after configuration. + else: + original_init = component_cls.__init__ + if original_init is not object.__init__: + if not callable(original_init): + raise TypeError( + "The `__init__` attribute of a @component class must be a method." + ) + call_args = inspect.signature(original_init).parameters + if len(call_args) > 1 or len(call_args) == 1 and "self" not in call_args: + raise TypeError( + "The `__init__` method of a @component class must take no " + f"arguments except `self`, but `{component_cls.__name__}.__init__` " + f"accepts arguments {tuple(name for name in call_args)}." + ) + + if hasattr(component_cls, "__post_configure__"): + raise TypeError( + f"{component_cls.__name__} has a deprecated `__post_configure__` method" + " as well as a custom `__init__`! Only one can be defined at a time." + ) + + # Set __init__ to the ComponentInit descriptor, so that it will resolve correctly + # based on the configuration state of the component. + component_cls.__init__ = ComponentInit(original_init=original_init) + + def _wrap_getattribute(component_cls: Type) -> None: """The logic for this overriden `__getattribute__` is as follows: @@ -142,6 +263,16 @@ def _wrap_getattribute(component_cls: Type) -> None: @functools.wraps(fn) def base_wrapped_fn(instance, name): + if ( + not in_configuration_mode() + and not name.startswith("__") + and not instance.__component_configured__ + ): + raise ConfigurationError( + f"Component '{instance.__component_name__}' has not been configured yet!" + " Please call `configure` before accessing any attributes." + ) + # If this is not an access to a field, return via the wrapped function. if name not in fn(instance, "__component_fields__"): return fn(instance, name) @@ -303,6 +434,25 @@ def wrapped_fn(instance) -> List[str]: component_cls.__dir__ = wrapped_fn +def _wrap_call(component_cls: Type) -> None: + # We only wrap the call function if there was one to begin with. + if not hasattr(component_cls, "__call__"): + return + + fn = component_cls.__call__ # type: ignore + + @functools.wraps(fn) + def wrapped_fn(instance, *args, **kwargs) -> List[str]: + if not instance.__component_configured__: + raise ConfigurationError( + f"Component '{instance.__component_name__}' has not been configured yet!" + " Please call `configure` before calling this component." + ) + return fn(instance, *args, **kwargs) + + component_cls.__call__ = wrapped_fn + + ###################################### # Implement the `ItemsView` protocol # ###################################### @@ -441,54 +591,6 @@ def __component_str__(instance): return f"{instance.__class__.__name__}(\n{INDENT}{joined_str}\n)" -########################## -# A component `__init__` # -########################## - - -def __component_init__(instance, **kwargs): - """Accepts keyword-arguments corresponding to fields defined on the component.""" - - # Use the `kwargs` to set field values. - for name, value in kwargs.items(): - if name not in instance.__component_fields__: - raise TypeError( - "Keyword arguments passed to component `__init__` must correspond to " - f"component fields. Received non-matching argument '{name}'." - ) - if utils.is_component_instance(value): - if value.__component_configured__: - raise ValueError( - "Sub-component instances passed to the `__init__` method of a " - "component must not already be configured. Received configured " - f"component argument '{name}={repr(value)}'." - ) - # Set the component parent correctly if the value being passed in - # does not already have a parent. - if value.__component_parent__ is None: - value.__component_parent__ = instance - - # This will contain default field values from the fields defined on this - # instance. - instance.__component_default_field_values__ = {} - - # Save a shallow-clone of the arguments. - instance.__component_instantiated_field_values__ = {**kwargs} - - # This will contain configured field values that apply to this instance, if - # any, which override everything else. - instance.__component_configured_field_values__ = {} - - # This will contain the names of every field of this instance and all - # component ancestors for which a value is defined. More names will be added - # during configuration. - instance.__component_fields_with_values_in_scope__ = set( - field.name - for field in instance.__component_fields__.values() - if field.has_default - ) | set(kwargs) - - ##################################### # Recursive component configuration # ##################################### @@ -694,6 +796,9 @@ def configure_component_instance( if hasattr(instance.__class__, "__post_configure__"): instance.__post_configure__() + # This calls the original init of the instance rather than __component_init__. + instance.__init__() # TODO: Should we support any arguments? + return conf @@ -717,11 +822,7 @@ def component(cls: Type): "cannot be applied again." ) - if cls.__init__ not in (object.__init__, __component_init__): - # A component class could have `__component_init__` as its init method - # if it inherits from a component. - raise TypeError("Component classes must not define a custom `__init__` method.") - cls.__init__ = __component_init__ + _wrap_init(cls) if hasattr(cls, "__pre_configure__"): if not callable(cls.__pre_configure__): @@ -739,6 +840,10 @@ def component(cls: Type): ) if hasattr(cls, "__post_configure__"): + logging.warning( + f"{cls.__name__} has a deprecated `__post_configure__` method! " + "Rename it to __init__ instead, the functionality is the same." + ) if not callable(cls.__post_configure__): raise TypeError( "The `__post_configure__` attribute of a @component class must be a " @@ -785,6 +890,7 @@ def component(cls: Type): _wrap_setattr(cls) _wrap_delattr(cls) _wrap_dir(cls) + _wrap_call(cls) # Implement the `ItemsView` protocol if hasattr(cls, "__len__") and cls.__len__ != __component_len__: @@ -823,87 +929,93 @@ def configure( overwrite any values already set on the instance - either class defaults or those set in `__init__`. """ - # Only component instances can be configured. - if not utils.is_component_instance(instance): - raise TypeError( - "Only @component, @factory, and @task instances can be configured. " - f"Received: {instance}." - ) - - # Configuration can only happen once. - if instance.__component_configured__: - raise ValueError( - f"Component '{instance.__component_name__}' has already been configured." - ) - # Maintain a FIFO queue of component instances that need to be configured, - # along with the config dict, name that should be passed, and a set of field - # names that are in-scope for component field inheritence. - # This queue allows us to recursively configure component instances in - # the component tree in a top-down, breadth-first order. - fifo_component_queue = [(instance, conf, name, frozenset(conf.keys()))] - - while len(fifo_component_queue) > 0: - ( - current_instance, - current_conf, - current_name, - current_fields_in_scope, - ) = fifo_component_queue.pop(0) - - if current_instance.__component_configured__: - continue + # Enable configuration mode to signal that we're allowed to access unconfigured + # components + with configuration_mode(): + # Only component instances can be configured. + if not utils.is_component_instance(instance): + raise TypeError( + "Only @component, @factory, and @task instances can be configured. " + f"Received: {instance}." + ) - current_conf = configure_component_instance( - current_instance, - conf=current_conf, - name=current_name, - fields_in_scope=current_fields_in_scope, - interactive=interactive, - ) + # Configuration can only happen once. + if instance.__component_configured__: + raise ValueError( + f"Component '{instance.__component_name__}' has already been configured." + ) - # Collect the sub-component instances that need to be recursively - # configured, and add them to the queue. - for field in current_instance.__component_fields__.values(): - if not isinstance(field, ComponentField): + # Maintain a FIFO queue of component instances that need to be configured, + # along with the config dict, name that should be passed, and a set of field + # names that are in-scope for component field inheritence. + # This queue allows us to recursively configure component instances in + # the component tree in a top-down, breadth-first order. + fifo_component_queue = [(instance, conf, name, frozenset(conf.keys()))] + + while len(fifo_component_queue) > 0: + ( + current_instance, + current_conf, + current_name, + current_fields_in_scope, + ) = fifo_component_queue.pop(0) + + if current_instance.__component_configured__: continue - try: - sub_component_instance = base_getattr(current_instance, field.name) - except (AttributeError, ConfigurationError) as e: - if field.allow_missing: + current_conf = configure_component_instance( + current_instance, + conf=current_conf, + name=current_name, + fields_in_scope=current_fields_in_scope, + interactive=interactive, + ) + + # Collect the sub-component instances that need to be recursively + # configured, and add them to the queue. + for field in current_instance.__component_fields__.values(): + if not isinstance(field, ComponentField): continue - raise e from None - if ( - not utils.is_component_instance(sub_component_instance) - or sub_component_instance.__component_configured__ - ): - continue + try: + sub_component_instance = base_getattr(current_instance, field.name) + except (AttributeError, ConfigurationError) as e: + if field.allow_missing: + continue + raise e from None + + if ( + not utils.is_component_instance(sub_component_instance) + or sub_component_instance.__component_configured__ + ): + continue - # Generate the configuration dict that will be used with the nested - # sub-component. This consists of all keys scoped to `field.name`. - sub_component_conf = { - a[len(f"{field.name}.") :]: b - for a, b in current_conf.items() - if a.startswith(f"{field.name}.") - } - - # The name of the sub-component is full-stop-delimited. - sub_component_name = f"{current_instance.__component_name__}.{field.name}" - - # At this point the current instance has already been configured so - # we know that every one of its fields is in scope. - sub_component_fields_in_scope = current_fields_in_scope | frozenset( - current_instance.__component_fields__.keys() - ) + # Generate the configuration dict that will be used with the nested + # sub-component. This consists of all keys scoped to `field.name`. + sub_component_conf = { + a[len(f"{field.name}.") :]: b + for a, b in current_conf.items() + if a.startswith(f"{field.name}.") + } + + # The name of the sub-component is full-stop-delimited. + sub_component_name = ( + f"{current_instance.__component_name__}.{field.name}" + ) - # Add the sub-component to the end of the queue. - fifo_component_queue.append( - ( - sub_component_instance, - sub_component_conf, - sub_component_name, - sub_component_fields_in_scope, + # At this point the current instance has already been configured so + # we know that every one of its fields is in scope. + sub_component_fields_in_scope = current_fields_in_scope | frozenset( + current_instance.__component_fields__.keys() + ) + + # Add the sub-component to the end of the queue. + fifo_component_queue.append( + ( + sub_component_instance, + sub_component_conf, + sub_component_name, + sub_component_fields_in_scope, + ) ) - ) diff --git a/zookeeper/core/component_test.py b/zookeeper/core/component_test.py index 46187bf..15e8588 100644 --- a/zookeeper/core/component_test.py +++ b/zookeeper/core/component_test.py @@ -8,7 +8,7 @@ from zookeeper.core.component import base_hasattr, component, configure from zookeeper.core.factory import factory from zookeeper.core.field import ComponentField, Field -from zookeeper.core.utils import ConfigurationError +from zookeeper.core.utils import ConfigurationError, configuration_mode @pytest.fixture @@ -18,6 +18,9 @@ class A: a: int = Field() b: str = Field("foo") + def __init__(self): + self.c = 2 + return A @@ -45,31 +48,17 @@ def foo(self): pass -def test_init_decorate_error(): - """An error should be raised when attempting to decorate a class with an `__init__` - method.""" - with pytest.raises( - TypeError, - match="Component classes must not define a custom `__init__` method.", - ): - - @component - class A: - def __init__(self, a, b=5): - self.a = a - self.b = b - - -def test_no_init(ExampleComponentClass): - """If the decorated class does not have an `__init__` method, the decorated class - should define an `__init__` which accepts kwargs to set field values, and raises - appropriate arguments when other values are passed.""" +def test_component_init(ExampleComponentClass): + """The decorated class should define an `__init__` which accepts kwargs to set field + values, and raises appropriate arguments when other values are passed.""" x = ExampleComponentClass(a=2) + configure(x, {}) assert x.a == 2 assert x.b == "foo" x = ExampleComponentClass(a=0, b="bar") + configure(x, {}) assert x.a == 0 assert x.b == "bar" @@ -91,6 +80,97 @@ def test_no_init(ExampleComponentClass): ExampleComponentClass(some_other_field_name=0) +def test_user_init(ExampleComponentClass): + """The user-defined `__init__` method should be called exactly once, after + configuration.""" + + x = ExampleComponentClass() + descriptor = ExampleComponentClass.__init__ + with patch.object( + descriptor, "original_init", wraps=descriptor.original_init + ) as user_init: + user_init.assert_not_called() + print(user_init) + + configure(x, {"a": 0, "b": "bar"}) + + # The user-defined `__init__` method is called after configuration. + # No other argument other than `self` should have been passed! + user_init.assert_called_once_with(x) + + # These should now all exist. + assert x.a == 0 + assert x.b == "bar" + assert x.c == 2 + + +def test_user_init_subclass(): + """If we have a set of component subclasses, with a customn `__init__` that calls + `super().__init__()`, these should all be resolved correctly.""" + + class Base: + def __init__(self): + self.base = 4 + super().__init__() + + @component + class Top(Base): + def __init__(self): + self.top = 3 + super().__init__() + + @component + class Middle(Top): + def __init__(self): + self.middle = 2 + super().__init__() + + @component + class Bottom(Middle): + def __init__(self): + self.bottom = 1 + super().__init__() + + b = Bottom() + configure(b, {}) + + # Each of the __init__ functions should have been called without errors, so these + # values should all be set. + assert b.base == 4 + assert b.top == 3 + assert b.middle == 2 + assert b.bottom == 1 + + +def test_init_signature_error(): + """An error should be raised when attempting to decorate a class with an `__init__` + method that takes any arguments other than `self`.""" + with pytest.raises( + TypeError, + match="The `__init__` method of a @component class must take no arguments except `self`", + ): + + @component + class A: + def __init__(self, a, b=5): + self.a = a + self.b = b + + with pytest.raises( + TypeError, + match="The `__init__` method of a @component class must take no arguments except `self`", + ): + + @component + class B: + def __init__(var): + pass + + +def test_init(): + pass + + def test_configure_override_field_values(ExampleComponentClass): """Component fields should be overriden correctly.""" @@ -471,7 +551,6 @@ def test_type_check(ExampleComponentClass): """During configuration we should type-check all field values.""" instance = ExampleComponentClass() - configure(instance, {"a": 4.5}, name="x") # Attempting to access the field should now raise a type error. @@ -529,14 +608,21 @@ class A3: base: Tuple[float, float, float] = ComponentField(F3) # These should succeed. - A1().base - A2().base + a1 = A1() + configure(a1, {}) + a1.base + + a2 = A2() + configure(a2, {}) + a2.base # Do this here to drop any already captured output. capsys.readouterr() # This should succeed, but without a type-check (should print a warning) - A3().base + a3 = A3() + configure(a3, {}) + a3.base captured = capsys.readouterr() assert ( captured.err @@ -616,20 +702,33 @@ class B: def __post_configure__(self, x): pass + with pytest.raises( + TypeError, + match="C has a deprecated `__post_configure__` method as well as a custom `__init__`! Only one can be defined at a time.", + ): + + @component + class C: + def __init__(self): + super().__init__() + + def __post_configure__(self): + pass + # This definition should succeed. @component - class C: + class D: a: int = Field(0) b: float = Field(3.14) def __post_configure__(self): - self.c = self.a + self.b + self.d = self.a + self.b - c = C() + d = D() - configure(c, {"a": 1, "b": -3.14}) + configure(d, {"a": 1, "b": -3.14}) - assert c.c == 1 - 3.14 + assert d.d == 1 - 3.14 def test_component_configure_error_non_existant_key(): @@ -923,15 +1022,19 @@ class Parent: a: int = Field(5) instance = Parent(a=100) - assert instance.a == 100 - assert instance.child_1.a == 100 - assert instance.child_2.a == -1 + with configuration_mode(): + assert instance.a == 100 + assert instance.child_1.a == 100 + assert instance.child_2.a == -1 + # Setting an attribute is allowed even if the component is not yet configured. instance.a = 2020 - instance.child_2.a = -7 - assert instance.a == 2020 - assert instance.child_1.a == 2020 - assert instance.child_2.a == -7 + # But accessing them is not. + with configuration_mode(): + instance.child_2.a = -7 + assert instance.a == 2020 + assert instance.child_1.a == 2020 + assert instance.child_2.a == -7 configure(instance, {"a": 0, "child_2.a": 5}) assert instance.a == 0 @@ -970,17 +1073,18 @@ class A: with_value: int = Field(0) instance = A() - assert hasattr(instance, "with_value") - assert base_hasattr(instance, "with_value") - assert not base_hasattr(instance, "fake_attribute") + with configuration_mode(): + assert hasattr(instance, "with_value") + assert base_hasattr(instance, "with_value") + assert not base_hasattr(instance, "fake_attribute") - with pytest.raises(ConfigurationError): - hasattr(instance, "attribute") + with pytest.raises(ConfigurationError): + hasattr(instance, "attribute") - assert base_hasattr(instance, "attribute") + assert base_hasattr(instance, "attribute") - assert not hasattr(instance, "missing_attribute") - assert base_hasattr(instance, "missing_attribute") + assert not hasattr(instance, "missing_attribute") + assert base_hasattr(instance, "missing_attribute") def test_component_configure_component_passed_as_config(): @@ -999,3 +1103,32 @@ class Parent: assert instance.child is new_child_instance assert instance.child.__component_parent__ is instance assert instance.child.x == 7 # This value should be correctly inherited. + + +def test_no_unconfigured_access(): + @component + class TestComponent: + var: int = Field(5) + + def __call__(self, *args): + print("This is valid, but can only be used after configuration.") + self.called = True + + x = TestComponent() + + with pytest.raises( + ConfigurationError, + match="Please call `configure` before accessing any attributes.", + ): + print(x.var) + + with pytest.raises( + ConfigurationError, + match="Please call `configure` before calling this component", + ): + x() + + configure(x, {}) + assert x.var == 5 + x() + assert x.called diff --git a/zookeeper/core/field_test.py b/zookeeper/core/field_test.py index e42eecd..56aff6e 100644 --- a/zookeeper/core/field_test.py +++ b/zookeeper/core/field_test.py @@ -2,7 +2,7 @@ import pytest -from zookeeper.core.component import component +from zookeeper.core.component import component, configure from zookeeper.core.field import ComponentField, Field from zookeeper.core.partial_component import PartialComponent from zookeeper.core.utils import ConfigurationError @@ -89,7 +89,7 @@ def foo_value(self): @Field def bar(self) -> int: - return int(self.foo_value**self.foo_value) + return int(self.foo_value ** self.foo_value) instance = A() @@ -181,6 +181,8 @@ class A: assert A.foo.has_default default_value = A.foo.get_default(A()) assert isinstance(default_value, ConcreteComponent) + + configure(default_value, {}) assert default_value.a == 5 @@ -194,6 +196,7 @@ class A: assert A.foo.has_default default_value = A.foo.get_default(A()) assert isinstance(default_value, ConcreteComponent) + configure(default_value, {}) assert default_value.a == 5 diff --git a/zookeeper/core/utils.py b/zookeeper/core/utils.py index a9c85a5..6234fcb 100644 --- a/zookeeper/core/utils.py +++ b/zookeeper/core/utils.py @@ -1,6 +1,7 @@ import inspect import re from ast import literal_eval +from contextlib import contextmanager from typing import Any, Callable, Iterator, Sequence, Type, TypeVar import click @@ -15,6 +16,26 @@ def __repr__(self): missing = Missing() +# Will be set to True if and only if zookeeper is currently in the process of configuring a component. +_CONFIGURATION_MODE = False + + +def in_configuration_mode(): + return _CONFIGURATION_MODE + + +@contextmanager +def configuration_mode(): + """Context manager that toggles _CONFIGURATION_MODE.""" + global _CONFIGURATION_MODE + # It may already be True, if we're in a nested context. + prev_val = _CONFIGURATION_MODE + # But set it to True either way + _CONFIGURATION_MODE = True + yield + # And then set back to the original value. + _CONFIGURATION_MODE = prev_val + class ConfigurationError(Exception): pass From fb74cf08b7550f65c82067d54ce1301d1d707d69 Mon Sep 17 00:00:00 2001 From: Jelmer Date: Mon, 18 Jul 2022 11:23:41 +0200 Subject: [PATCH 2/3] Linting --- zookeeper/core/field_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zookeeper/core/field_test.py b/zookeeper/core/field_test.py index 56aff6e..d6c972c 100644 --- a/zookeeper/core/field_test.py +++ b/zookeeper/core/field_test.py @@ -89,7 +89,7 @@ def foo_value(self): @Field def bar(self) -> int: - return int(self.foo_value ** self.foo_value) + return int(self.foo_value**self.foo_value) instance = A() From 4218d2e1af87b99d3f5f1f0d786f51d7e0f5b945 Mon Sep 17 00:00:00 2001 From: Jelmer Date: Mon, 18 Jul 2022 12:22:03 +0200 Subject: [PATCH 3/3] Fix context manager --- zookeeper/core/component.py | 22 ++++++++++--------- zookeeper/core/component_test.py | 8 ++++++- zookeeper/core/utils.py | 36 +++++++++++++++++--------------- 3 files changed, 38 insertions(+), 28 deletions(-) diff --git a/zookeeper/core/component.py b/zookeeper/core/component.py index acf6926..bb689f0 100644 --- a/zookeeper/core/component.py +++ b/zookeeper/core/component.py @@ -95,12 +95,11 @@ class C: # A component `__init__` # ########################## class ComponentInit: - """ - We wrap both the original __init__ and the __component_init__ in this descriptor. - If we would simply store e.g. `__original__init` on every component, we - would run into problems with resolution of calls to `super().__init__()`, since they - would trigger `__component_init__` on the parent class again, with the same value - for the `instance` argument. + """We wrap both the original __init__ and the __component_init__ in this descriptor. + If we would simply store e.g. `__original__init` on every component, we would run + into problems with resolution of calls to `super().__init__()`, since they would + trigger `__component_init__` on the parent class again, with the same value for the + `instance` argument. By using a descriptor, we can effectively create a different `__component_init__` for every class, and thus make sure that `super().__init__()` refers to the correct @@ -112,7 +111,8 @@ def __init__(self, original_init): @staticmethod def __component_init__(instance, **kwargs): - """Accepts keyword-arguments corresponding to fields defined on the component.""" + """Accepts keyword-arguments corresponding to fields defined on the + component.""" # Use the `kwargs` to set field values. for name, value in kwargs.items(): @@ -154,9 +154,11 @@ def __component_init__(instance, **kwargs): ) | set(kwargs) def __get__(self, instance, type=None): - """`instance` is the instance on which `__init__` is called. If the instance - has not yet been configured, we defer to `__component_init__`. If it has been - configured, we defer to `self.original_init` instead to execute the user's code. + """`instance` is the instance on which `__init__` is called. + + If the instance has not yet been configured, we defer to `__component_init__`. + If it has been configured, we defer to `self.original_init` instead to execute + the user's code. """ if instance is None: # This happens when we call `__init__` on the class, e.g. `A.__init__`. diff --git a/zookeeper/core/component_test.py b/zookeeper/core/component_test.py index 15e8588..cdedf68 100644 --- a/zookeeper/core/component_test.py +++ b/zookeeper/core/component_test.py @@ -8,7 +8,11 @@ from zookeeper.core.component import base_hasattr, component, configure from zookeeper.core.factory import factory from zookeeper.core.field import ComponentField, Field -from zookeeper.core.utils import ConfigurationError, configuration_mode +from zookeeper.core.utils import ( + ConfigurationError, + configuration_mode, + in_configuration_mode, +) @pytest.fixture @@ -1106,6 +1110,8 @@ class Parent: def test_no_unconfigured_access(): + assert not in_configuration_mode() + @component class TestComponent: var: int = Field(5) diff --git a/zookeeper/core/utils.py b/zookeeper/core/utils.py index 6234fcb..431ac2e 100644 --- a/zookeeper/core/utils.py +++ b/zookeeper/core/utils.py @@ -1,7 +1,7 @@ import inspect import re +import threading from ast import literal_eval -from contextlib import contextmanager from typing import Any, Callable, Iterator, Sequence, Type, TypeVar import click @@ -16,25 +16,27 @@ def __repr__(self): missing = Missing() -# Will be set to True if and only if zookeeper is currently in the process of configuring a component. -_CONFIGURATION_MODE = False +# Will be set to True if and only if zookeeper is currently in the process of +# configuring a component. Local to this thread only. +thread_local = threading.local() +thread_local._CONFIGURATION_MODE = False def in_configuration_mode(): - return _CONFIGURATION_MODE - - -@contextmanager -def configuration_mode(): - """Context manager that toggles _CONFIGURATION_MODE.""" - global _CONFIGURATION_MODE - # It may already be True, if we're in a nested context. - prev_val = _CONFIGURATION_MODE - # But set it to True either way - _CONFIGURATION_MODE = True - yield - # And then set back to the original value. - _CONFIGURATION_MODE = prev_val + return thread_local._CONFIGURATION_MODE + + +class configuration_mode: + def __enter__(self): + # It may already be True, if we're in a nested context. + self.prev_val = thread_local._CONFIGURATION_MODE + # But set it to True either way + thread_local._CONFIGURATION_MODE = True + return self + + def __exit__(self, *args, **kwargs): + # And then set back to the original value. + thread_local._CONFIGURATION_MODE = self.prev_val class ConfigurationError(Exception):