Skip to content

Commit

Permalink
bugfix in a context test
Browse files Browse the repository at this point in the history
When generating two RepeatingContext objects, they are not guaranteed to
included each other except if the same configs were used for both.
  • Loading branch information
yasserfarouk committed Mar 2, 2024
1 parent 26d1056 commit 15df0b8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 23 deletions.
46 changes: 24 additions & 22 deletions tests/rl/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,28 +23,27 @@
WeakSupplierContext,
)


@pytest.mark.parametrize(
"context_type",
(
GeneralContext,
RepeatingContext,
ConsumerContext,
SupplierContext,
StrongConsumerContext,
StrongSupplierContext,
WeakConsumerContext,
WeakSupplierContext,
BalancedConsumerContext,
BalancedSupplierContext,
ANACContext,
ANACOneShotContext,
FixedPartnerNumbersContext,
FixedPartnerNumbersOneShotContext,
LimitedPartnerNumbersContext,
LimitedPartnerNumbersOneShotContext,
),
context_types = (
GeneralContext,
RepeatingContext,
ConsumerContext,
SupplierContext,
StrongConsumerContext,
StrongSupplierContext,
WeakConsumerContext,
WeakSupplierContext,
BalancedConsumerContext,
BalancedSupplierContext,
ANACContext,
ANACOneShotContext,
FixedPartnerNumbersContext,
FixedPartnerNumbersOneShotContext,
LimitedPartnerNumbersContext,
LimitedPartnerNumbersOneShotContext,
)


@pytest.mark.parametrize("context_type", context_types)
def test_context_can_generate_and_run(context_type):
context = context_type()
config = context.make_config()
Expand Down Expand Up @@ -91,7 +90,10 @@ def test_context_can_generate_and_run(context_type):
b._obj.awi # type: ignore
), f"world {i} has incorrect AWI for agent {b.id}"

c2 = context_type()
if issubclass(context_type, RepeatingContext):
c2 = context_type(configs=context.configs)
else:
c2 = context_type()
assert context.contains_context(
c2, raise_on_failure=True
), "Identical contexts do not match"
Expand Down
2 changes: 1 addition & 1 deletion tests/rl/test_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def test_rl_agent_with_a_trained_model(type_, continuous):
model.learn(total_timesteps=NTRAINING)

if issubclass(type_, RepeatingContext):
context = RepeatingContext(configs=env._context.configs) # type: ignore
context = type_(configs=env._context.configs) # type: ignore
else:
context = type_()
obs = FlexibleObservationManager(context, continuous=continuous)
Expand Down

0 comments on commit 15df0b8

Please sign in to comment.