Skip to content

Commit

Permalink
Allowing factories to create worlds with different agent types
Browse files Browse the repository at this point in the history
- Previously, a world factory was limited in two ways:
  1. OneShotWorldFactory had agent_type and agent_params as members.
     This meant that we need to pass the type when creating the factory
     but the parameters of the OneShotRLAgent should include the
     factory. They necessitated creating the factory twice when training
     and testing on a different world even from the same factory.
     Moreover, this parameter meant that the factory is always assigned
     to the same test agent type which is not necessary.
  2. There can be a single world controlled externally (e.g. from an RL
     policy). This is limiting us to use RL but we cannot use MARL. We
     resolved that by allowing __call__ (and make()) to take a tuple of
     types/params to use for multiple external agents. For RL
     applications, this should be a single-valued tuple.
  Both of these issues are resolved now by removing
  agent_type/agent_params from all WorldFactories and passing the
  types/params to the __call__() method.
  • Loading branch information
yasserfarouk committed Jul 12, 2023
1 parent 1ee6f42 commit 281ee9e
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 150 deletions.
31 changes: 24 additions & 7 deletions src/scml/oneshot/rl/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,40 @@ def isin(x: int | tuple[int, int], y: tuple[int, int] | int):
class WorldFactory(ABC):
"""Generates worlds satisfying predefined conditions and tests for them"""

def __call__(self) -> tuple[World, Agent]:
"""Generates a world with one agent to be controlled externally and returns both"""
def __call__(
self,
types: tuple[type[Agent], ...] = tuple(),
params: tuple[dict[str, Any], ...] | None = None,
) -> tuple[World, tuple[Agent]]:
"""
Generates a world with one or more agents to be controlled externally and returns both
Args:
agent_types: The types of a list of agents to be guaranteed to exist in the world
agent_params: The parameters to pass to the constructors of these agents. None means no parameters for any agents
Returns:
The constructed world and a tuple of the agents created corresponding (in order) to the given agent types/params
"""
...

@abstractmethod
def is_valid_world(self, world: World) -> bool:
"""Checks that the given world could have been generated from this generator"""
def is_valid_world(
self,
world: World,
types: tuple[type[Agent], ...] = tuple(),
) -> bool:
"""Checks that the given world could have been generated from this factory"""
...

@abstractmethod
def is_valid_awi(self, awi: AgentWorldInterface) -> bool:
"""Checks that the given AWI is connected to a world that could have been generated from this generator"""
"""Checks that the given AWI is connected to a world that could have been generated from this factory"""
...

@abstractmethod
def contains_factory(self, generator: "WorldFactory") -> bool:
"""Checks that the any world generated from the given `generator` could have been generated from this generator"""
def contains_factory(self, factory: "WorldFactory") -> bool:
"""Checks that the any world generated from the given `factory` could have been generated from this factory"""
...

def __contains__(
Expand Down
18 changes: 11 additions & 7 deletions src/scml/oneshot/rl/env.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium.envs.registration import register

from scml.common import intin, make_array
from scml.oneshot.agent import OneShotAgent
from scml.oneshot.awi import OneShotAWI
from scml.oneshot.agents import OneShotDummyAgent
from scml.oneshot.rl.action import ActionManager
from scml.oneshot.rl.factory import (
FixedPartnerNumbersOneShotFactory,
OneShotWorldFactory,
)
from scml.oneshot.rl.observation import ObservationManager
from scml.oneshot.world import SCML2023OneShotWorld
from scml.oneshot.world import SCML2020OneShotWorld

__all__ = ["OneShotEnv"]

Expand All @@ -25,6 +23,8 @@ def __init__(
observation_manager: ObservationManager,
render_mode=None,
factory: OneShotWorldFactory = FixedPartnerNumbersOneShotFactory(),
agent_type: type[OneShotAgent] = OneShotDummyAgent,
agent_params: dict[str, Any] | None = None,
extra_checks: bool = True,
):
assert action_manager.factory in factory, (
Expand All @@ -40,9 +40,9 @@ def __init__(
)
self._extra_checks = extra_checks

agent_type = factory.agent_type
self._world: SCML2023OneShotWorld = None # type: ignore
self._world: SCML2020OneShotWorld = None # type: ignore
self._agent_type = agent_type
self._agent_params = agent_params if agent_params is not None else dict()
self._agent_id: str = ""
self._agent: OneShotAgent = None # type: ignore
self._obs_manager = observation_manager
Expand Down Expand Up @@ -81,7 +81,11 @@ def reset(
import random

random.seed(seed)
self._world, self._agent = self._factory()
self._world, agents = self._factory(
types=(self._agent_type,), params=(self._agent_params,)
)
assert len(agents) == 1
self._agent = agents[0]
if self._extra_checks:
assert self._world in self._factory
self._agent_id = self._agent.id
Expand Down
Loading

0 comments on commit 281ee9e

Please sign in to comment.