From 80e729ca80d10ed0871c42d9a156633f7eaffec3 Mon Sep 17 00:00:00 2001 From: Yasser Mohammad Date: Fri, 12 Apr 2024 14:47:25 +0900 Subject: [PATCH] bugfix: Correcting negotiation stats in CLI --- src/scml/cli.py | 62 ++++++++++++++++++++++----- tests/oneshot/test_scml2024oneshot.py | 23 ++++++++++ 2 files changed, 74 insertions(+), 11 deletions(-) diff --git a/src/scml/cli.py b/src/scml/cli.py index b720f906..683fff21 100644 --- a/src/scml/cli.py +++ b/src/scml/cli.py @@ -1376,8 +1376,16 @@ def print_and_log(s): n_erred = int(round(n_signed * world.contract_err_fraction)) n_breached = int(round(n_signed * world.breach_fraction)) n_executed = int(round(n_signed * world.contract_execution_fraction)) - exogenous = [_ for _ in world.saved_contracts if not _["issues"]] - negotiated = [_ for _ in world.saved_contracts if _["issues"]] + exogenous = [ + _ + for _ in world.saved_contracts + if any(is_system_agent(a) for a in _["partners"]) + ] + negotiated = [ + _ + for _ in world.saved_contracts + if all(not is_system_agent(a) for a in _["partners"]) + ] n_exogenous = len(exogenous) n_negotiated = len(negotiated) n_exogenous_signed = len([_ for _ in exogenous if _["signed_at"] >= 0]) @@ -1483,7 +1491,7 @@ def print_and_log(s): ) @click.option( "--show-contracts/--no-contracts", - default=True, + default=False, help="Show or do not show all signed contracts", ) @click.option( @@ -1739,8 +1747,16 @@ def print_and_log(s): n_erred = int(round(n_signed * world.contract_err_fraction)) n_breached = int(round(n_signed * world.breach_fraction)) n_executed = int(round(n_signed * world.contract_execution_fraction)) - exogenous = [_ for _ in world.saved_contracts if not _["issues"]] - negotiated = [_ for _ in world.saved_contracts if _["issues"]] + exogenous = [ + _ + for _ in world.saved_contracts + if any(is_system_agent(a) for a in _["partners"]) + ] + negotiated = [ + _ + for _ in world.saved_contracts + if all(not is_system_agent(a) for a in _["partners"]) + ] n_exogenous = len(exogenous) n_negotiated = len(negotiated) n_exogenous_signed = len([_ for _ in exogenous if _["signed_at"] >= 0]) @@ -2069,8 +2085,16 @@ def print_and_log(s): n_erred = int(round(n_signed * world.contract_err_fraction)) n_breached = int(round(n_signed * world.breach_fraction)) n_executed = int(round(n_signed * world.contract_execution_fraction)) - exogenous = [_ for _ in world.saved_contracts if not _["issues"]] - negotiated = [_ for _ in world.saved_contracts if _["issues"]] + exogenous = [ + _ + for _ in world.saved_contracts + if any(is_system_agent(a) for a in _["partners"]) + ] + negotiated = [ + _ + for _ in world.saved_contracts + if all(not is_system_agent(a) for a in _["partners"]) + ] n_exogenous = len(exogenous) n_negotiated = len(negotiated) n_exogenous_signed = len([_ for _ in exogenous if _["signed_at"] >= 0]) @@ -2399,8 +2423,16 @@ def print_and_log(s): n_erred = int(round(n_signed * world.contract_err_fraction)) n_breached = int(round(n_signed * world.breach_fraction)) n_executed = int(round(n_signed * world.contract_execution_fraction)) - exogenous = [_ for _ in world.saved_contracts if not _["issues"]] - negotiated = [_ for _ in world.saved_contracts if _["issues"]] + exogenous = [ + _ + for _ in world.saved_contracts + if any(is_system_agent(a) for a in _["partners"]) + ] + negotiated = [ + _ + for _ in world.saved_contracts + if all(not is_system_agent(a) for a in _["partners"]) + ] n_exogenous = len(exogenous) n_negotiated = len(negotiated) n_exogenous_signed = len([_ for _ in exogenous if _["signed_at"] >= 0]) @@ -2729,8 +2761,16 @@ def print_and_log(s): n_erred = int(round(n_signed * world.contract_err_fraction)) n_breached = int(round(n_signed * world.breach_fraction)) n_executed = int(round(n_signed * world.contract_execution_fraction)) - exogenous = [_ for _ in world.saved_contracts if not _["issues"]] - negotiated = [_ for _ in world.saved_contracts if _["issues"]] + exogenous = [ + _ + for _ in world.saved_contracts + if any(is_system_agent(a) for a in _["partners"]) + ] + negotiated = [ + _ + for _ in world.saved_contracts + if all(not is_system_agent(a) for a in _["partners"]) + ] n_exogenous = len(exogenous) n_negotiated = len(negotiated) n_exogenous_signed = len([_ for _ in exogenous if _["signed_at"] >= 0]) diff --git a/tests/oneshot/test_scml2024oneshot.py b/tests/oneshot/test_scml2024oneshot.py index b4b83166..315c3555 100644 --- a/tests/oneshot/test_scml2024oneshot.py +++ b/tests/oneshot/test_scml2024oneshot.py @@ -2,6 +2,7 @@ import pytest from rich import print +from collections import Counter from scml.oneshot import PLACEHOLDER_AGENT_PREFIX from scml.oneshot.agents.greedy import ( @@ -18,6 +19,7 @@ from scml.oneshot.context import ANACOneShotContext from scml.oneshot.rl.agent import OneShotRLAgent from scml.oneshot.world import SCML2024OneShotWorld +from scml.utils import DefaultAgentsOneShot2024 from ..switches import DefaultOneShotWorld @@ -236,3 +238,24 @@ def test_combining_stats(): makefig=True, ylegend=1.0, ) + + +def test_run_defaults_gets_contracts(): + world = SCML2024OneShotWorld( + **SCML2024OneShotWorld.generate(DefaultAgentsOneShot2024, n_steps=50) + ) + world.run() + exogenous = [ + _ + for _ in world.saved_contracts + if any(is_system_agent(a) for a in _["partners"]) + ] + negotiated = [ + _ + for _ in world.saved_contracts + if all(not is_system_agent(a) for a in _["partners"]) + ] + assert len(exogenous) > 0 + assert ( + len(negotiated) > 0 + ), f"{Counter([tuple(_['partners']) for _ in world.saved_contracts])}"