diff --git a/src/scml/oneshot/rl/common.py b/src/scml/oneshot/rl/common.py index c330db86..4dc85304 100644 --- a/src/scml/oneshot/rl/common.py +++ b/src/scml/oneshot/rl/common.py @@ -21,23 +21,40 @@ def isin(x: int | tuple[int, int], y: tuple[int, int] | int): class WorldFactory(ABC): """Generates worlds satisfying predefined conditions and tests for them""" - def __call__(self) -> tuple[World, Agent]: - """Generates a world with one agent to be controlled externally and returns both""" + def __call__( + self, + types: tuple[type[Agent], ...] = tuple(), + params: tuple[dict[str, Any], ...] | None = None, + ) -> tuple[World, tuple[Agent]]: + """ + Generates a world with one or more agents to be controlled externally and returns both + + Args: + types: The types of a list of agents to be guaranteed to exist in the world + params: The parameters to pass to the constructors of these agents. None means no parameters for any agents + + Returns: + The constructed world and a tuple of the agents created corresponding (in order) to the given agent types/params + """ ... @abstractmethod - def is_valid_world(self, world: World) -> bool: - """Checks that the given world could have been generated from this generator""" + def is_valid_world( + self, + world: World, + types: tuple[type[Agent], ...] = tuple(), + ) -> bool: + """Checks that the given world could have been generated from this factory""" ... @abstractmethod def is_valid_awi(self, awi: AgentWorldInterface) -> bool: - """Checks that the given AWI is connected to a world that could have been generated from this generator""" + """Checks that the given AWI is connected to a world that could have been generated from this factory""" ... 
@abstractmethod - def contains_factory(self, generator: "WorldFactory") -> bool: - """Checks that the any world generated from the given `generator` could have been generated from this generator""" + def contains_factory(self, factory: "WorldFactory") -> bool: + """Checks that any world generated from the given `factory` could have been generated from this factory""" ... def __contains__( diff --git a/src/scml/oneshot/rl/env.py b/src/scml/oneshot/rl/env.py index 55db1625..66ffbd55 100644 --- a/src/scml/oneshot/rl/env.py +++ b/src/scml/oneshot/rl/env.py @@ -1,19 +1,17 @@ from typing import Any import gymnasium as gym -import numpy as np from gymnasium.envs.registration import register -from scml.common import intin, make_array from scml.oneshot.agent import OneShotAgent -from scml.oneshot.awi import OneShotAWI +from scml.oneshot.agents import OneShotDummyAgent from scml.oneshot.rl.action import ActionManager from scml.oneshot.rl.factory import ( FixedPartnerNumbersOneShotFactory, OneShotWorldFactory, ) from scml.oneshot.rl.observation import ObservationManager -from scml.oneshot.world import SCML2023OneShotWorld +from scml.oneshot.world import SCML2020OneShotWorld __all__ = ["OneShotEnv"] @@ -25,6 +23,8 @@ def __init__( observation_manager: ObservationManager, render_mode=None, factory: OneShotWorldFactory = FixedPartnerNumbersOneShotFactory(), + agent_type: type[OneShotAgent] = OneShotDummyAgent, + agent_params: dict[str, Any] | None = None, extra_checks: bool = True, ): assert action_manager.factory in factory, ( @@ -40,9 +40,9 @@ def __init__( ) self._extra_checks = extra_checks - agent_type = factory.agent_type - self._world: SCML2023OneShotWorld = None # type: ignore + self._world: SCML2020OneShotWorld = None # type: ignore self._agent_type = agent_type + self._agent_params = agent_params if agent_params is not None else dict() self._agent_id: str = "" self._agent: OneShotAgent = None # type: ignore self._obs_manager = observation_manager @@ -81,7 +81,11 @@ 
def reset( import random random.seed(seed) - self._world, self._agent = self._factory() + self._world, agents = self._factory( + types=(self._agent_type,), params=(self._agent_params,) + ) + assert len(agents) == 1 + self._agent = agents[0] if self._extra_checks: assert self._world in self._factory self._agent_id = self._agent.id diff --git a/src/scml/oneshot/rl/factory.py b/src/scml/oneshot/rl/factory.py index c76edfa9..aa14b6ad 100644 --- a/src/scml/oneshot/rl/factory.py +++ b/src/scml/oneshot/rl/factory.py @@ -11,6 +11,7 @@ from scml.oneshot.agents import OneShotDummyAgent from scml.oneshot.awi import OneShotAWI from scml.oneshot.world import ( + SCML2020OneShotWorld, SCML2021OneShotWorld, SCML2022OneShotWorld, SCML2023OneShotWorld, @@ -50,32 +51,40 @@ class OneShotWorldFactory(WorldFactory, ABC): """A factory that generates oneshot worlds with a single agent of type `agent_type` with predetermined structure and settings""" - agent_type: type[OneShotAgent] = OneShotDummyAgent - agent_params: dict[str, Any] | None = None - @abstractmethod - def make(self) -> SCML2023OneShotWorld: + def make( + self, + types: tuple[type[OneShotAgent], ...] = (OneShotDummyAgent,), + params: tuple[dict[str, Any], ...] | None = None, + ) -> SCML2020OneShotWorld: """Generates the oneshot world and assigns an agent of type `agent_type` to it""" ... - def __call__(self) -> tuple[SCML2023OneShotWorld, OneShotAgent]: + def __call__( + self, + types: tuple[type[OneShotAgent], ...] = (OneShotDummyAgent,), + params: tuple[dict[str, Any], ...] 
| None = None, + ) -> tuple[SCML2020OneShotWorld, tuple[OneShotAgent]]: """Generates the world and assigns an agent to it""" - agent_type = self.agent_type - world = self.make() # type: ignore - world: SCML2023OneShotWorld - expected_type = agent_type._type_name() - agents = [ - i - for i, type_ in enumerate(world.agent_types) - if type_.split(":")[-1] == expected_type - ] - assert ( - len(agents) == 1 - ), f"Found the following agent of type {agent_type}: {agents}" - for a in world.agents.values(): - if a.type_name.split(":")[-1] == expected_type: - return world, a # type: ignore - raise RuntimeError(f"Cannot find a world of type {expected_type}") + world = self.make(types, params) + agents = [] + if types: + expected_types = [type._type_name() for type in types] + expected_set = set(expected_types) + agents = [ + i + for i, type_ in enumerate(world.agent_types) + if type_.split(":")[-1] in expected_set + ] + assert len(agents) == len( + types + ), f"Found the following agent of type {types=}: {agents=}" + agents = [] + for expected_type in expected_types: + for a in world.agents.values(): + if a.type_name.split(":")[-1] == expected_type: + agents.append(a) + return world, tuple(agents) # @define(frozen=True) @@ -146,9 +155,14 @@ def __attrs_post_init__(self): ) assert self.level == -1 or self.level < self.n_processes - def make(self) -> SCML2023OneShotWorld: + def make( + self, + types: tuple[type[OneShotAgent], ...] = (OneShotDummyAgent,), + params: tuple[dict[str, Any], ...] 
| None = None, + ) -> SCML2020OneShotWorld: """Generates a world""" - agent_type = self.agent_type + if types and params is None: + params = tuple(dict() for _ in types) n_processes = intin(self.n_processes) n_lines = intin(self.n_lines) # find my level @@ -166,13 +180,16 @@ def make(self) -> SCML2023OneShotWorld: else: n_agents_per_process[my_level + 1] = self.n_consumers n_agents_per_process[my_level - 1] = self.n_suppliers + n_competitors = intin(self.n_competitors) + 1 - n_agents_per_process[my_level] = n_competitors + n_agents_per_process[my_level] = max( + len(types), n_agents_per_process[my_level], n_competitors + ) n_agents = sum(n_agents_per_process) agent_types = list(random.choices(self.non_competitors, k=n_agents)) agent_params = None - if self.agent_params: + if params: agent_params: list[dict[str, Any]] | None = [dict() for _ in agent_types] agent_processes = np.zeros(n_agents, dtype=int) nxt, indx = 0, -1 @@ -180,16 +197,19 @@ def make(self) -> SCML2023OneShotWorld: last = nxt + n_agents_per_process[level] agent_processes[nxt:last] = level if level == my_level: - indx = random.randint(nxt, last) - agent_types[indx] = agent_type # type: ignore - if self.agent_params: - agent_params[indx]["controller_params"] = self.agent_params # type: ignore + indices = random.sample(range(nxt, last), k=len(types)) + assert params is not None and agent_params is not None + for indx, agent_type, p in zip(indices, types, params): + agent_types[indx] = agent_type + if params: + agent_params[indx]["controller_params"] = p nxt += n_agents_per_process[level] assert indx >= 0 type_ = { 2023: SCML2023OneShotWorld, 2022: SCML2022OneShotWorld, 2021: SCML2021OneShotWorld, + 2020: SCML2020OneShotWorld, }[self.year] return type_( **type_.generate( @@ -203,34 +223,40 @@ def make(self) -> SCML2023OneShotWorld: one_offer_per_step=True, ) - def is_valid_world(self, world: SCML2023OneShotWorld) -> bool: + def is_valid_world( + self, + world: SCML2020OneShotWorld, + types: 
tuple[type[OneShotAgent], ...] = (OneShotDummyAgent,), + ) -> bool: """Checks that the given world could have been generated from this factory""" - agent_type = self.agent_type - expected_type = agent_type._type_name() - agents = [ - i - for i, type_ in enumerate(world.agent_types) - if type_.split(":")[-1] == expected_type - ] - assert ( - len(agents) == 1 - ), f"Found the following agent of type {agent_type}: {agents}" - agent: OneShotAgent = None # type: ignore - for a in world.agents.values(): - if a.type_name.split(":")[-1] == expected_type: - agent = a # type: ignore - break - else: - warnings.warn(f"cannot find any agent of type {expected_type}") - return False - if not isin(world.n_processes, self.n_processes): - warnings.warn( - f"Invalid n_processes: {world.n_processes=} != {self.n_processes=}" - ) - return False - if not isin(agent.awi.n_lines, self.n_lines): - warnings.warn(f"Invalid n_lines: {agent.awi.n_lines=} != {self.n_lines=}") - return False + for agent_type in types: + expected_type = agent_type._type_name() + agents = [ + i + for i, type_ in enumerate(world.agent_types) + if type_.split(":")[-1] == expected_type + ] + assert ( + len(agents) == 1 + ), f"Found the following agent of type {agent_type}: {agents}" + agent: OneShotAgent = None # type: ignore + for a in world.agents.values(): + if a.type_name.split(":")[-1] == expected_type: + agent = a # type: ignore + break + else: + warnings.warn(f"cannot find any agent of type {expected_type}") + return False + if not isin(world.n_processes, self.n_processes): + warnings.warn( + f"Invalid n_processes: {world.n_processes=} != {self.n_processes=}" + ) + return False + if not isin(agent.awi.n_lines, self.n_lines): + warnings.warn( + f"Invalid n_lines: {agent.awi.n_lines=} != {self.n_lines=}" + ) + return False # TODO: check non-competitor types return self.is_valid_awi(agent.awi) # type: ignore @@ -300,9 +326,7 @@ def contains_factory(self, factory: WorldFactory) -> bool: return False if not 
isin(factory.n_lines, self.n_lines): return False - if set(factory.non_competitors).difference( - list(self.non_competitors) + [self.agent_type] - ): + if set(factory.non_competitors).difference(list(self.non_competitors)): return False return True @@ -330,9 +354,14 @@ def __attrs_post_init__(self): ) assert self.level == -1 or self.level < self.n_processes[-1] - def make(self) -> SCML2023OneShotWorld: + def make( + self, + types: tuple[type[OneShotAgent], ...] = (OneShotDummyAgent,), + params: tuple[dict[str, Any], ...] | None = None, + ) -> SCML2020OneShotWorld: """Generates a world""" - agent_type = self.agent_type + if types and params is None: + params = tuple(dict() for _ in types) n_processes = intin(self.n_processes) n_lines = intin(self.n_lines) # find my level @@ -354,12 +383,14 @@ def make(self) -> SCML2023OneShotWorld: n_agents_per_process[my_level + 1] = n_consumers n_agents_per_process[my_level - 1] = n_suppliers n_competitors = intin(n_competitors) + 1 - n_agents_per_process[my_level] = n_competitors + n_agents_per_process[my_level] = max( + len(types), n_agents_per_process[my_level], n_competitors + ) n_agents = sum(n_agents_per_process) agent_types = list(random.choices(self.non_competitors, k=n_agents)) agent_params = None - if self.agent_params: + if params: agent_params: list[dict[str, Any]] | None = [dict() for _ in agent_types] agent_processes = np.zeros(n_agents, dtype=int) nxt, indx = 0, -1 @@ -367,10 +398,12 @@ def make(self) -> SCML2023OneShotWorld: last = nxt + n_agents_per_process[level] agent_processes[nxt:last] = level if level == my_level: - indx = random.randint(nxt, last) - agent_types[indx] = agent_type # type: ignore - if self.agent_params: - agent_params[indx]["controller_params"] = self.agent_params # type: ignore + indices = random.sample(range(nxt, last), k=len(types)) + assert params is not None and agent_params is not None + for indx, agent_type, p in zip(indices, types, params): + agent_types[indx] = agent_type + if params: 
+ agent_params[indx]["controller_params"] = p nxt += n_agents_per_process[level] assert indx >= 0 type_ = { @@ -390,35 +423,41 @@ def make(self) -> SCML2023OneShotWorld: one_offer_per_step=True, ) - def is_valid_world(self, world: SCML2023OneShotWorld) -> bool: + def is_valid_world( + self, + world: SCML2020OneShotWorld, + types: tuple[type[OneShotAgent], ...] = (OneShotDummyAgent,), + ) -> bool: """Checks that the given world could have been generated from this factory""" - agent_type = self.agent_type - expected_type = agent_type._type_name() - agents = [ - i - for i, type_ in enumerate(world.agent_types) - if type_.split(":")[-1] == expected_type - ] - assert ( - len(agents) == 1 - ), f"Found the following agent of type {agent_type}: {agents}" - agent: OneShotAgent = None # type: ignore - for a in world.agents.values(): - if a.type_name.split(":")[-1] == expected_type: - agent = a # type: ignore - break - else: - warnings.warn(f"cannot find any agent of type {expected_type}") - return False - if not isin(world.n_processes, self.n_processes): - warnings.warn( - f"Invalid n_processes: {world.n_processes=} != {self.n_processes=}" - ) - return False - if not isin(agent.awi.n_lines, self.n_lines): - warnings.warn(f"Invalid n_lines: {agent.awi.n_lines=} != {self.n_lines=}") - return False - # TODO: check non-competitor types + for agent_type in types: + expected_type = agent_type._type_name() + agents = [ + i + for i, type_ in enumerate(world.agent_types) + if type_.split(":")[-1] == expected_type + ] + assert ( + len(agents) == 1 + ), f"Found the following agent of type {agent_type}: {agents}" + agent: OneShotAgent = None # type: ignore + for a in world.agents.values(): + if a.type_name.split(":")[-1] == expected_type: + agent = a # type: ignore + break + else: + warnings.warn(f"cannot find any agent of type {expected_type}") + return False + if not isin(world.n_processes, self.n_processes): + warnings.warn( + f"Invalid n_processes: {world.n_processes=} != 
{self.n_processes=}" + ) + return False + if not isin(agent.awi.n_lines, self.n_lines): + warnings.warn( + f"Invalid n_lines: {agent.awi.n_lines=} != {self.n_lines=}" + ) + return False + # TODO: check non-competitor types return self.is_valid_awi(agent.awi) # type: ignore def is_valid_awi(self, awi: OneShotAWI) -> bool: @@ -491,9 +530,7 @@ def contains_factory(self, factory: WorldFactory) -> bool: return False if not isin(factory.n_lines, self.n_lines): return False - if set(factory.non_competitors).difference( - list(self.non_competitors) + [self.agent_type] - ): + if set(factory.non_competitors).difference(list(self.non_competitors)): return False return True @@ -504,7 +541,11 @@ class ANACOneShotFactory(OneShotWorldFactory): year: int = 2023 - def make(self) -> SCML2023OneShotWorld: + def make( + self, + types: tuple[type[OneShotAgent], ...] = (OneShotDummyAgent,), + params: tuple[dict[str, Any], ...] | None = None, + ) -> SCML2020OneShotWorld: """Generates a world""" type_ = { 2023: SCML2023OneShotWorld, @@ -512,39 +553,45 @@ def make(self) -> SCML2023OneShotWorld: 2021: SCML2021OneShotWorld, }[self.year] d = type_.generate() - n = len(d["agent_types"]) - i = random.randint(0, n - 1) - d["agent_params"][i].update( - dict( - controller_type=self.agent_type, - controller_params=self.agent_params - if self.agent_params is not None - else dict(), - ) - ) + if types: + if params is None: + params = tuple(dict() for _ in types) + n = len(d["agent_types"]) + indices = random.sample(range(n), k=len(types)) + for i, agent_type, p in zip(indices, types, params): + d["agent_params"][i].update( + dict( + controller_type=agent_type, + controller_params=p, + ) + ) return type_(**d, one_offer_per_step=True) - def is_valid_world(self, world: SCML2023OneShotWorld) -> bool: + def is_valid_world( + self, + world: SCML2020OneShotWorld, + types: tuple[type[OneShotAgent], ...] 
= (OneShotDummyAgent,), + ) -> bool: """Checks that the given world could have been generated from this factory""" - agent_type = self.agent_type - expected_type = agent_type._type_name() - agents = [ - i - for i, type_ in enumerate(world.agent_types) - if type_.split(":")[-1] == expected_type - ] - assert ( - len(agents) == 1 - ), f"Found the following agent of type {agent_type}: {agents}" - agent: OneShotAgent = None # type: ignore - for a in world.agents.values(): - if a.type_name.split(":")[-1] == expected_type: - agent = a # type: ignore - break - else: - warnings.warn(f"cannot find any agent of type {expected_type}") - return False + for agent_type in types: + expected_type = agent_type._type_name() + agents = [ + i + for i, type_ in enumerate(world.agent_types) + if type_.split(":")[-1] == expected_type + ] + assert ( + len(agents) == 1 + ), f"Found the following agent of type {agent_type}: {agents}" + agent: OneShotAgent = None # type: ignore + for a in world.agents.values(): + if a.type_name.split(":")[-1] == expected_type: + agent = a # type: ignore + break + else: + warnings.warn(f"cannot find any agent of type {expected_type}") + return False return self.is_valid_awi(agent.awi) # type: ignore def is_valid_awi(self, awi: OneShotAWI) -> bool: diff --git a/tests/test_rl.py b/tests/test_rl.py index e91d26d0..308cbf2f 100644 --- a/tests/test_rl.py +++ b/tests/test_rl.py @@ -1,13 +1,20 @@ +import random + +import numpy as np +from negmas.gb.common import ResponseType +from negmas.sao.common import SAOResponse from pytest import mark from scml.common import intin +from scml.oneshot.common import QUANTITY from scml.oneshot.rl.action import ( + ActionManager, FixedPartnerNumbersActionManager, LimitedPartnerNumbersActionManager, UnconstrainedActionManager, ) from scml.oneshot.rl.agent import OneShotRLAgent -from scml.oneshot.rl.common import RLModel, model_wrapper +from scml.oneshot.rl.common import model_wrapper from scml.oneshot.rl.env import OneShotEnv from 
scml.oneshot.rl.factory import ( FixedPartnerNumbersOneShotFactory, @@ -100,13 +107,14 @@ def test_training(type_): def test_rl_agent_fallback(): - factory = FixedPartnerNumbersOneShotFactory(agent_type=OneShotRLAgent) + factory = FixedPartnerNumbersOneShotFactory() action, obs = ( FixedPartnerNumbersActionManager(factory), FixedPartnerNumbersObservationManager(factory), ) - world, agent = factory() - assert isinstance(agent._obj, OneShotRLAgent), agent.type_name # type: ignore + world, agents = factory(types=(OneShotRLAgent,)) + assert len(agents) == 1 + assert isinstance(agents[0]._obj, OneShotRLAgent), agents[0].type_name # type: ignore world.run() @@ -120,12 +128,53 @@ def test_rl_agent_with_a_trained_model(): factory = LimitedPartnerNumbersOneShotFactory() obs = LimitedPartnerNumbersObservationManager(factory) - factory = LimitedPartnerNumbersOneShotFactory( - agent_type=OneShotRLAgent, - agent_params=dict(models=[model_wrapper(model)], observation_managers=[obs]), + world, (agent,) = factory( + types=(OneShotRLAgent,), + params=(dict(models=[model_wrapper(model)], observation_managers=[obs]),), ) - world, agent = factory() assert isinstance(agent._obj, OneShotRLAgent), agent.type_name # type: ignore world.step() assert agent._valid_index == 0 # type: ignore world.run() + + +# @mark.parametrize( +# "type_", +# [ +# LimitedPartnerNumbersActionManager, +# FixedPartnerNumbersActionManager, +# UnconstrainedActionManager, +# ], +# ) +# def test_action_manager(type_: type[ActionManager]): +# factory = FixedPartnerNumbersOneShotFactory() +# manager = type_(factory) +# space = manager.make_space() +# world, agents = factory() +# for _ in range(100): +# agent = agents[0] +# # action = space.sample() +# responses = dict() +# awi = agent.awi +# for aid, nmi in awi.state.running_sell_nmis.items(): +# mine_indx = [i for i, x in enumerate(nmi.agent_ids) if x == agent.id][0] +# partner_indx = [i for i, x in enumerate(nmi.agent_ids) if x != agent.id][0] +# partner = [x for i, x in 
enumerate(nmi.agent_ids) if x != agent.id][0] +# resp = random.choice( +# [ +# ResponseType.REJECT_OFFER, +# ResponseType.END_NEGOTIATION, +# ResponseType.ACCEPT_OFFER, +# ] +# ) +# responses[partner] = SAOResponse( +# resp, +# awi.current_output_outcome_space.random_outcome() +# if resp != ResponseType.END_NEGOTIATION +# else None, +# ) +# world.step(1, neg_actions={agent.id: responses}) +# action = manager.encode(awi, responses) +# decoded = manager.decode(awi, action) +# encoded = manager.encode(awi, decoded) +# assert np.all(np.isclose(action, encoded)), f"{action=}\n{decoded=}\n{encoded=}"