Skip to content

Commit

Permalink
bugfix: Correcting negotiation stats in CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
yasserfarouk committed Apr 12, 2024
1 parent 980fb80 commit 80e729c
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 11 deletions.
62 changes: 51 additions & 11 deletions src/scml/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1376,8 +1376,16 @@ def print_and_log(s):
n_erred = int(round(n_signed * world.contract_err_fraction))
n_breached = int(round(n_signed * world.breach_fraction))
n_executed = int(round(n_signed * world.contract_execution_fraction))
exogenous = [_ for _ in world.saved_contracts if not _["issues"]]
negotiated = [_ for _ in world.saved_contracts if _["issues"]]
exogenous = [
_
for _ in world.saved_contracts
if any(is_system_agent(a) for a in _["partners"])
]
negotiated = [
_
for _ in world.saved_contracts
if all(not is_system_agent(a) for a in _["partners"])
]
n_exogenous = len(exogenous)
n_negotiated = len(negotiated)
n_exogenous_signed = len([_ for _ in exogenous if _["signed_at"] >= 0])
Expand Down Expand Up @@ -1483,7 +1491,7 @@ def print_and_log(s):
)
@click.option(
"--show-contracts/--no-contracts",
default=True,
default=False,
help="Show or do not show all signed contracts",
)
@click.option(
Expand Down Expand Up @@ -1739,8 +1747,16 @@ def print_and_log(s):
n_erred = int(round(n_signed * world.contract_err_fraction))
n_breached = int(round(n_signed * world.breach_fraction))
n_executed = int(round(n_signed * world.contract_execution_fraction))
exogenous = [_ for _ in world.saved_contracts if not _["issues"]]
negotiated = [_ for _ in world.saved_contracts if _["issues"]]
exogenous = [
_
for _ in world.saved_contracts
if any(is_system_agent(a) for a in _["partners"])
]
negotiated = [
_
for _ in world.saved_contracts
if all(not is_system_agent(a) for a in _["partners"])
]
n_exogenous = len(exogenous)
n_negotiated = len(negotiated)
n_exogenous_signed = len([_ for _ in exogenous if _["signed_at"] >= 0])
Expand Down Expand Up @@ -2069,8 +2085,16 @@ def print_and_log(s):
n_erred = int(round(n_signed * world.contract_err_fraction))
n_breached = int(round(n_signed * world.breach_fraction))
n_executed = int(round(n_signed * world.contract_execution_fraction))
exogenous = [_ for _ in world.saved_contracts if not _["issues"]]
negotiated = [_ for _ in world.saved_contracts if _["issues"]]
exogenous = [
_
for _ in world.saved_contracts
if any(is_system_agent(a) for a in _["partners"])
]
negotiated = [
_
for _ in world.saved_contracts
if all(not is_system_agent(a) for a in _["partners"])
]
n_exogenous = len(exogenous)
n_negotiated = len(negotiated)
n_exogenous_signed = len([_ for _ in exogenous if _["signed_at"] >= 0])
Expand Down Expand Up @@ -2399,8 +2423,16 @@ def print_and_log(s):
n_erred = int(round(n_signed * world.contract_err_fraction))
n_breached = int(round(n_signed * world.breach_fraction))
n_executed = int(round(n_signed * world.contract_execution_fraction))
exogenous = [_ for _ in world.saved_contracts if not _["issues"]]
negotiated = [_ for _ in world.saved_contracts if _["issues"]]
exogenous = [
_
for _ in world.saved_contracts
if any(is_system_agent(a) for a in _["partners"])
]
negotiated = [
_
for _ in world.saved_contracts
if all(not is_system_agent(a) for a in _["partners"])
]
n_exogenous = len(exogenous)
n_negotiated = len(negotiated)
n_exogenous_signed = len([_ for _ in exogenous if _["signed_at"] >= 0])
Expand Down Expand Up @@ -2729,8 +2761,16 @@ def print_and_log(s):
n_erred = int(round(n_signed * world.contract_err_fraction))
n_breached = int(round(n_signed * world.breach_fraction))
n_executed = int(round(n_signed * world.contract_execution_fraction))
exogenous = [_ for _ in world.saved_contracts if not _["issues"]]
negotiated = [_ for _ in world.saved_contracts if _["issues"]]
exogenous = [
_
for _ in world.saved_contracts
if any(is_system_agent(a) for a in _["partners"])
]
negotiated = [
_
for _ in world.saved_contracts
if all(not is_system_agent(a) for a in _["partners"])
]
n_exogenous = len(exogenous)
n_negotiated = len(negotiated)
n_exogenous_signed = len([_ for _ in exogenous if _["signed_at"] >= 0])
Expand Down
23 changes: 23 additions & 0 deletions tests/oneshot/test_scml2024oneshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pytest
from rich import print
from collections import Counter

from scml.oneshot import PLACEHOLDER_AGENT_PREFIX
from scml.oneshot.agents.greedy import (
Expand All @@ -18,6 +19,7 @@
from scml.oneshot.context import ANACOneShotContext
from scml.oneshot.rl.agent import OneShotRLAgent
from scml.oneshot.world import SCML2024OneShotWorld
from scml.utils import DefaultAgentsOneShot2024

from ..switches import DefaultOneShotWorld

Expand Down Expand Up @@ -236,3 +238,24 @@ def test_combining_stats():
makefig=True,
ylegend=1.0,
)


def test_run_defaults_gets_contracts():
world = SCML2024OneShotWorld(
**SCML2024OneShotWorld.generate(DefaultAgentsOneShot2024, n_steps=50)
)
world.run()
exogenous = [
_
for _ in world.saved_contracts
if any(is_system_agent(a) for a in _["partners"])
]
negotiated = [
_
for _ in world.saved_contracts
if all(not is_system_agent(a) for a in _["partners"])
]
assert len(exogenous) > 0
assert (
len(negotiated) > 0
), f"{Counter([tuple(_['partners']) for _ in world.saved_contracts])}"

0 comments on commit 80e729c

Please sign in to comment.