diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6b0ab48c4..28abfed8b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,7 @@ repos: hooks: - id: flake8 args: - - '--per-file-ignores=*/__init__.py:F401 gymnasium/envs/registration.py:E704 docs/tutorials/*.py:E402' + - '--per-file-ignores=*/__init__.py:F401 gymnasium/envs/registration.py:E704 docs/tutorials/*.py:E402 gymnasium/experimental/wrappers/__init__.py:E402' - --ignore=E203,W503,E741 - --max-complexity=30 - --max-line-length=456 diff --git a/docs/api/experimental.md b/docs/api/experimental.md new file mode 100644 index 000000000..670d2baab --- /dev/null +++ b/docs/api/experimental.md @@ -0,0 +1,214 @@ +--- +title: Experimental +--- + +# Experimental + +```{toctree} +:hidden: +experimental/functional +experimental/wrappers +experimental/vector +experimental/vector_wrappers +``` + +## Functional Environments + +The gymnasium ``Env`` provides high flexibility for the implementation of individual environments; however, this can complicate parallelism of environments. Therefore, we propose the :class:`gymnasium.experimental.FuncEnv` where each part of the environment has its own function related to it. + +## Wrappers + +Gymnasium already contains a large collection of wrappers, but we believe that the wrappers can be improved to + + * Support arbitrarily complex observation / action spaces. As RL has advanced, action and observation spaces are becoming more complex and the current wrappers were not implemented with these spaces in mind. + * Support for numpy, jax and pytorch data. With hardware accelerated environments, e.g. Brax, written in Jax and similar pytorch based programs, numpy is not the only game in town anymore. Therefore, these upgrades will use Jumpy for calling numpy, jax and torch depending on the data. + * More wrappers. Projects like Supersuit aimed to bring more wrappers for RL; however, these wrappers can be moved into Gymnasium. + * Versioning. 
Like environments, the implementation details of wrappers can cause changes in agent performance. Therefore, we propose adding version numbers to all wrappers. + + * In v0.28, we aim to rewrite the VectorEnv to not inherit from Env; as a result, new vectorised versions of the wrappers will be provided. + +### Lambda Observation Wrappers +```{eval-rst} +.. py:currentmodule:: gymnasium + +.. list-table:: + :header-rows: 1 + + * - Old name + - New name + - Vector version + - Tree structure + * - :class:`wrappers.TransformObservation` + - :class:`experimental.wrappers.LambdaObservationV0` + - VectorLambdaObservation + - No + * - :class:`wrappers.FilterObservation` + - FilterObservation + - VectorFilterObservation (*) + - Yes + * - :class:`wrappers.FlattenObservation` + - FlattenObservation + - VectorFlattenObservation (*) + - No + * - :class:`wrappers.GrayScaleObservation` + - GrayscaleObservation + - VectorGrayscaleObservation (*) + - Yes + * - :class:`wrappers.PixelObservationWrapper` + - PixelObservation + - VectorPixelObservation (*) + - No + * - :class:`wrappers.ResizeObservation` + - ResizeObservation + - VectorResizeObservation (*) + - Yes + * - Not Implemented + - ReshapeObservation + - VectorReshapeObservation (*) + - Yes + * - Not Implemented + - RescaleObservation + - VectorRescaleObservation (*) + - Yes + * - Not Implemented + - DtypeObservation + - VectorDtypeObservation (*) + - Yes + * - :class:`wrappers.NormalizeObservation` + - NormalizeObservation + - VectorNormalizeObservation + - No + * - :class:`wrappers.TimeAwareObservation` + - TimeAwareObservation + - VectorTimeAwareObservation + - No + * - :class:`wrappers.FrameStack` + - FrameStackObservation + - VectorFrameStackObservation + - No + * - Not Implemented + - DelayObservation + - VectorDelayObservation + - No + * - :class:`wrappers.AtariPreprocessing` + - AtariPreprocessing + - Not Implemented + - No +``` + +### Lambda Action Wrappers +```{eval-rst} +.. py:currentmodule:: gymnasium + +.. 
list-table:: + :header-rows: 1 + + * - Old name + - New name + - Vector version + - Tree structure + * - Not Implemented + - :class:`experimental.wrappers.LambdaActionV0` + - VectorLambdaAction + - No + * - :class:`wrappers.ClipAction` + - ClipAction + - VectorClipAction (*) + - Yes + * - :class:`wrappers.RescaleAction` + - RescaleAction + - VectorRescaleAction (*) + - Yes + * - Not Implemented + - NanAction + - VectorNanAction (*) + - Yes + * - Not Implemented + - StickyAction + - VectorStickyAction + - No +``` + +### Lambda Reward Wrappers +```{eval-rst} +.. py:currentmodule:: gymnasium + +.. list-table:: + :header-rows: 1 + + * - Old name + - New name + - Vector version + * - :class:`wrappers.TransformReward` + - :class:`experimental.wrappers.LambdaRewardV0` + - VectorLambdaReward + * - Not Implemented + - :class:`experimental.wrappers.ClipRewardV0` + - VectorClipReward (*) + * - Not Implemented + - RescaleReward + - VectorRescaleReward (*) + * - :class:`wrappers.NormalizeReward` + - NormalizeReward + - VectorNormalizeReward +``` + +### Common Wrappers +```{eval-rst} +.. py:currentmodule:: gymnasium + +.. list-table:: + :header-rows: 1 + + * - Old name + - New name + - Vector version + * - :class:`wrappers.AutoResetWrapper` + - AutoReset + - VectorAutoReset + * - :class:`wrappers.PassiveEnvChecker` + - PassiveEnvChecker + - VectorPassiveEnvChecker + * - :class:`wrappers.OrderEnforcing` + - OrderEnforcing + - VectorOrderEnforcing (*) + * - :class:`wrappers.EnvCompatibility` + - Moved to `shimmy `_ + - Not Implemented + * - :class:`RecordEpisodeStatistics` + - RecordEpisodeStatistics + - VectorRecordEpisodeStatistics + * - :class:`RenderCollection` + - RenderCollection + - VectorRenderCollection + * - :class:`HumanRendering` + - HumanRendering + - Not Implemented + * - Not Implemented + - JaxToNumpy + - VectorJaxToNumpy + * - Not Implemented + - JaxToTorch + - VectorJaxToTorch +``` + +### Vector Only Wrappers +```{eval-rst} +.. py:currentmodule:: gymnasium + +.. 
list-table:: + :header-rows: 1 + + * - Old name + - New name + * - :class:`wrappers.VectorListInfo` + - VectorListInfo +``` + +## Vector Environment + +These changes will be made in v0.28. + +## Wrappers for Vector Environments + +These changes will be made in v0.28. diff --git a/docs/api/experimental/functional.md b/docs/api/experimental/functional.md new file mode 100644 index 000000000..24eaa9b65 --- /dev/null +++ b/docs/api/experimental/functional.md @@ -0,0 +1,36 @@ +--- +title: Functional +--- + +# Functional Environment + +## gymnasium.experimental.FuncEnv + +```{eval-rst} +.. autoclass:: gymnasium.experimental.FuncEnv + +.. autofunction:: gymnasium.experimental.FuncEnv.initial +.. autofunction:: gymnasium.experimental.FuncEnv.transition + +.. autofunction:: gymnasium.experimental.FuncEnv.observation +.. autofunction:: gymnasium.experimental.FuncEnv.initial + +.. autofunction:: gymnasium.experimental.FuncEnv.observation +.. autofunction:: gymnasium.experimental.FuncEnv.reward +.. autofunction:: gymnasium.experimental.FuncEnv.terminal + +.. autofunction:: gymnasium.experimental.FuncEnv.state_info +.. autofunction:: gymnasium.experimental.FuncEnv.step_info + +.. autofunction:: gymnasium.experimental.FuncEnv.transform + +.. autofunction:: gymnasium.experimental.FuncEnv.render_image +.. autofunction:: gymnasium.experimental.FuncEnv.render_init +.. autofunction:: gymnasium.experimental.FuncEnv.render_close +``` + +## gymnasium.experimental.func_jax_env.FunctionalJaxEnv + +```{eval-rst} +.. 
autoclass:: gymnasium.experimental.func_jax_env.FunctionalJaxEnv +``` \ No newline at end of file diff --git a/docs/api/experimental/vector.md b/docs/api/experimental/vector.md new file mode 100644 index 000000000..393b7ec21 --- /dev/null +++ b/docs/api/experimental/vector.md @@ -0,0 +1,15 @@ +--- +title: Vector +--- + +# Vectorizing Environment + +## gymnasium.experimental.VectorEnv + +## gymnasium.experimental.vector.AsyncVectorEnv + +## gymnasium.experimental.vector.SyncVectorEnv + +## Custom Vector environments + +## EnvPool diff --git a/docs/api/experimental/vector_wrappers.md b/docs/api/experimental/vector_wrappers.md new file mode 100644 index 000000000..71e64f4f1 --- /dev/null +++ b/docs/api/experimental/vector_wrappers.md @@ -0,0 +1,15 @@ +--- +title: Vector Wrappers +--- + +# Vector Environment Wrappers + +## Vector Lambda Observation Wrappers + +## Vector Lambda Action Wrappers + +## Vector Lambda Reward Wrappers + +## Vector Common Wrappers + +## Vector Only Wrappers diff --git a/docs/api/experimental/wrappers.md b/docs/api/experimental/wrappers.md new file mode 100644 index 000000000..6ae00794e --- /dev/null +++ b/docs/api/experimental/wrappers.md @@ -0,0 +1,26 @@ +# Wrappers + +## Lambda Observation Wrappers + +```{eval-rst} +.. autoclass:: gymnasium.experimental.wrappers.LambdaObservationV0 +``` + +## Lambda Action Wrappers + +```{eval-rst} +.. autoclass:: gymnasium.experimental.wrappers.LambdaActionV0 +``` + +## Lambda Reward Wrappers + +```{eval-rst} +.. autoclass:: gymnasium.experimental.wrappers.LambdaRewardV0 +.. 
autoclass:: gymnasium.experimental.wrappers.ClipRewardV0 +``` + +## Common Wrappers + +```{eval-rst} + +``` \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index 3ba19bcf3..e8dafabe3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -48,6 +48,7 @@ api/spaces api/wrappers api/vector api/utils +api/experimental ``` ```{toctree} diff --git a/gymnasium/__init__.py b/gymnasium/__init__.py index f870da348..218588f1e 100644 --- a/gymnasium/__init__.py +++ b/gymnasium/__init__.py @@ -10,10 +10,8 @@ ) from gymnasium.spaces.space import Space from gymnasium.envs.registration import make, spec, register, registry, pprint_registry -from gymnasium import envs, spaces, utils, vector, wrappers, error, logger +from gymnasium import envs, spaces, utils, vector, wrappers, error, logger, experimental -import os -import sys __all__ = [ # core classes @@ -37,6 +35,7 @@ "wrappers", "error", "logger", + "experimental", ] __version__ = "0.26.3" @@ -45,6 +44,9 @@ # pygame # DSP is far more benign (and should probably be the default in SDL anyways) +import os +import sys + if sys.platform.startswith("linux"): os.environ["SDL_AUDIODRIVER"] = "dsp" diff --git a/gymnasium/dev_wrappers/__init__.py b/gymnasium/dev_wrappers/__init__.py deleted file mode 100644 index 645710bef..000000000 --- a/gymnasium/dev_wrappers/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -"""Root __init__ of the gym dev_wrappers.""" -from typing import TypeVar - -ArgType = TypeVar("ArgType") diff --git a/gymnasium/envs/phys2d/__init__.py b/gymnasium/envs/phys2d/__init__.py index c6e51c28c..f00c65601 100644 --- a/gymnasium/envs/phys2d/__init__.py +++ b/gymnasium/envs/phys2d/__init__.py @@ -1,2 +1,2 @@ -from gymnasium.envs.phys2d.cartpole import CartPoleF -from gymnasium.envs.phys2d.pendulum import PendulumF +from gymnasium.envs.phys2d.cartpole import CartPoleFunctional +from gymnasium.envs.phys2d.pendulum import PendulumFunctional diff --git a/gymnasium/envs/phys2d/cartpole.py 
b/gymnasium/envs/phys2d/cartpole.py index 3d4c46168..e2b8a52e0 100644 --- a/gymnasium/envs/phys2d/cartpole.py +++ b/gymnasium/envs/phys2d/cartpole.py @@ -10,41 +10,42 @@ from jax.random import PRNGKey import gymnasium as gym -from gymnasium.envs.phys2d.conversion import JaxEnv from gymnasium.error import DependencyNotInstalled -from gymnasium.functional import ActType, FuncEnv, StateType +from gymnasium.experimental.func_jax_env import FunctionalJaxEnv +from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.utils import EzPickle RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821 -class CartPoleF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]): +class CartPoleFunctional( + FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType] +): """Cartpole but in jax and functional. Example usage: - ``` - import jax - import jax.numpy as jnp - key = jax.random.PRNGKey(0) + >>> import jax + >>> import jax.numpy as jnp - env = CartPole({"x_init": 0.5}) - state = env.initial(key) - print(state) - print(env.step(state, 0)) + >>> key = jax.random.PRNGKey(0) - env.transform(jax.jit) + >>> env = CartPole({"x_init": 0.5}) + >>> state = env.initial(key) + >>> print(state) + >>> print(env.step(state, 0)) - state = env.initial(key) - print(state) - print(env.step(state, 0)) + >>> env.transform(jax.jit) - vkey = jax.random.split(key, 10) - env.transform(jax.vmap) - vstate = env.initial(vkey) - print(vstate) - print(env.step(vstate, jnp.array([0 for _ in range(10)]))) - ``` + >>> state = env.initial(key) + >>> print(state) + >>> print(env.step(state, 0)) + + >>> vkey = jax.random.split(key, 10) + >>> env.transform(jax.vmap) + >>> vstate = env.initial(vkey) + >>> print(vstate) + >>> print(env.step(vstate, jnp.array([0 for _ in range(10)]))) """ gravity = 9.8 @@ -232,13 +233,13 @@ def render_close(self, render_state: RenderStateType) -> None: pygame.quit() -class CartPoleJaxEnv(JaxEnv, 
EzPickle): +class CartPoleJaxEnv(FunctionalJaxEnv, EzPickle): metadata = {"render_modes": ["rgb_array"], "render_fps": 50} def __init__(self, render_mode: Optional[str] = None, **kwargs): EzPickle.__init__(self, render_mode=render_mode, **kwargs) - env = CartPoleF(**kwargs) + env = CartPoleFunctional(**kwargs) env.transform(jax.jit) action_space = env.action_space observation_space = env.observation_space diff --git a/gymnasium/envs/phys2d/pendulum.py b/gymnasium/envs/phys2d/pendulum.py index 687f87626..3ba3b5751 100644 --- a/gymnasium/envs/phys2d/pendulum.py +++ b/gymnasium/envs/phys2d/pendulum.py @@ -10,15 +10,17 @@ from jax.random import PRNGKey import gymnasium as gym -from gymnasium.envs.phys2d.conversion import JaxEnv from gymnasium.error import DependencyNotInstalled -from gymnasium.functional import ActType, FuncEnv, StateType +from gymnasium.experimental.func_jax_env import FunctionalJaxEnv +from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.utils import EzPickle RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821 -class PendulumF(FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType]): +class PendulumFunctional( + FuncEnv[jnp.ndarray, jnp.ndarray, int, float, bool, RenderStateType] +): """Pendulum but in jax and functional.""" max_speed = 8 @@ -180,13 +182,13 @@ def render_close(self, render_state: RenderStateType) -> None: pygame.quit() -class PendulumJaxEnv(JaxEnv, EzPickle): +class PendulumJaxEnv(FunctionalJaxEnv, EzPickle): metadata = {"render_modes": ["rgb_array"], "render_fps": 30} def __init__(self, render_mode: Optional[str] = None, **kwargs): EzPickle.__init__(self, render_mode=render_mode, **kwargs) - env = PendulumF(**kwargs) + env = PendulumFunctional(**kwargs) env.transform(jax.jit) action_space = env.action_space observation_space = env.observation_space diff --git a/gymnasium/experimental/__init__.py 
b/gymnasium/experimental/__init__.py new file mode 100644 index 000000000..70ba0a3cd --- /dev/null +++ b/gymnasium/experimental/__init__.py @@ -0,0 +1,12 @@ +"""Root __init__ of the gymnasium experimental module.""" + + +from gymnasium.experimental.functional import FuncEnv + +__all__ = [ + # Functional + "FuncEnv", + "functional", + # Wrapper + "wrappers", +] diff --git a/gymnasium/envs/phys2d/conversion.py b/gymnasium/experimental/func_jax_env.py similarity index 77% rename from gymnasium/envs/phys2d/conversion.py rename to gymnasium/experimental/func_jax_env.py index 430b8d884..90bf0aa02 100644 --- a/gymnasium/envs/phys2d/conversion.py +++ b/gymnasium/experimental/func_jax_env.py @@ -1,4 +1,7 @@ -from typing import Any, Dict, Optional, Tuple +"""Functional to Environment compatibility.""" +from __future__ import annotations + +from typing import Any import jax.numpy as jnp import jax.random as jrng @@ -7,14 +10,12 @@ import gymnasium as gym from gymnasium import Space from gymnasium.envs.registration import EnvSpec -from gymnasium.functional import ActType, FuncEnv, StateType +from gymnasium.experimental.functional import ActType, FuncEnv, StateType from gymnasium.utils import seeding -class JaxEnv(gym.Env): - """ - A conversion layer for numpy-based environments. 
- """ +class FunctionalJaxEnv(gym.Env): + """A conversion layer for jax-based environments.""" state: StateType rng: jrng.PRNGKey @@ -24,20 +25,24 @@ def __init__( func_env: FuncEnv, observation_space: Space, action_space: Space, - metadata: Optional[Dict[str, Any]] = None, - render_mode: Optional[str] = None, - reward_range: Tuple[float, float] = (-float("inf"), float("inf")), - spec: Optional[EnvSpec] = None, + metadata: dict[str, Any] | None = None, + render_mode: str | None = None, + reward_range: tuple[float, float] = (-float("inf"), float("inf")), + spec: EnvSpec | None = None, ): """Initialize the environment from a FuncEnv.""" if metadata is None: - metadata = {} + metadata = {"render_mode": []} + self.func_env = func_env + self.observation_space = observation_space self.action_space = action_space + self.metadata = metadata self.render_mode = render_mode self.reward_range = reward_range + self.spec = spec self._is_box_action_space = isinstance(self.action_space, gym.spaces.Box) @@ -52,7 +57,8 @@ def __init__( self.rng = jrng.PRNGKey(seed) - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + def reset(self, *, seed: int | None = None, options: dict | None = None): + """Resets the environment using the seed.""" super().reset(seed=seed) if seed is not None: self.rng = jrng.PRNGKey(seed) @@ -68,6 +74,7 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): return obs, info def step(self, action: ActType): + """Steps through the environment using the action.""" if self._is_box_action_space: assert isinstance(self.action_space, gym.spaces.Box) # For typing action = np.clip(action, self.action_space.low, self.action_space.high) @@ -90,6 +97,7 @@ def step(self, action: ActType): return observation, float(reward), bool(terminated), False, info def render(self): + """Returns the render state if `render_mode` is "rgb_array".""" if self.render_mode == "rgb_array": self.render_state, image = 
self.func_env.render_image( self.state, self.render_state @@ -99,15 +107,16 @@ def render(self): raise NotImplementedError def close(self): + """Closes the environment and render state if set.""" if self.render_state is not None: self.func_env.render_close(self.render_state) self.render_state = None def _convert_jax_to_numpy(element: Any): - """ - Convert a jax observation/action to a numpy array, or a numpy-based container. - Currently required because all tests assume that stuff is in numpy arrays, hopefully will be removed soon. + """Convert a jax observation/action to a numpy array, or a numpy-based container. + + Required because all tests assume that data is in numpy arrays; to be removed soon. """ if isinstance(element, jnp.ndarray): return np.asarray(element) diff --git a/gymnasium/functional.py b/gymnasium/experimental/functional.py similarity index 94% rename from gymnasium/functional.py rename to gymnasium/experimental/functional.py index 86e900e07..b67c8b16e 100644 --- a/gymnasium/functional.py +++ b/gymnasium/experimental/functional.py @@ -1,6 +1,7 @@ """Base class and definitions for an alternative, functional backend for gym envs, particularly suitable for hardware accelerated and otherwise transformed environments.""" +from __future__ import annotations -from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeVar +from typing import Any, Callable, Generic, TypeVar import numpy as np @@ -35,7 +36,7 @@ class FuncEnv( we intend to flesh it out and officially expose it to end users. 
""" - def __init__(self, options: Optional[Dict[str, Any]] = None): + def __init__(self, options: dict[str, Any] | None = None): """Initialize the environment constants.""" self.__dict__.update(options or {}) @@ -43,14 +44,14 @@ def initial(self, rng: Any) -> StateType: """Initial state.""" raise NotImplementedError - def observation(self, state: StateType) -> ObsType: - """Observation.""" - raise NotImplementedError - def transition(self, state: StateType, action: ActType, rng: Any) -> StateType: """Transition.""" raise NotImplementedError + def observation(self, state: StateType) -> ObsType: + """Observation.""" + raise NotImplementedError + def reward( self, state: StateType, action: ActType, next_state: StateType ) -> RewardType: @@ -83,7 +84,7 @@ def transform(self, func: Callable[[Callable], Callable]): def render_image( self, state: StateType, render_state: RenderStateType - ) -> Tuple[RenderStateType, np.ndarray]: + ) -> tuple[RenderStateType, np.ndarray]: """Show the state.""" raise NotImplementedError diff --git a/gymnasium/experimental/wrappers/__init__.py b/gymnasium/experimental/wrappers/__init__.py new file mode 100644 index 000000000..556272fea --- /dev/null +++ b/gymnasium/experimental/wrappers/__init__.py @@ -0,0 +1,21 @@ +"""Experimental Wrappers.""" +# isort: skip_file + +from typing import TypeVar + +ArgType = TypeVar("ArgType") + +from gymnasium.experimental.wrappers.lambda_action import LambdaActionV0 +from gymnasium.experimental.wrappers.lambda_observations import LambdaObservationV0 +from gymnasium.experimental.wrappers.lambda_reward import ClipRewardV0, LambdaRewardV0 + +__all__ = [ + "ArgType", + # Lambda Action + "LambdaActionV0", + # Lambda Observation + "LambdaObservationV0", + # Lambda Reward + "LambdaRewardV0", + "ClipRewardV0", +] diff --git a/gymnasium/dev_wrappers/lambda_action.py b/gymnasium/experimental/wrappers/lambda_action.py similarity index 93% rename from gymnasium/dev_wrappers/lambda_action.py rename to 
gymnasium/experimental/wrappers/lambda_action.py index cb7ea1e10..b858b5970 100644 --- a/gymnasium/dev_wrappers/lambda_action.py +++ b/gymnasium/experimental/wrappers/lambda_action.py @@ -4,7 +4,7 @@ import gymnasium as gym from gymnasium.core import ActType -from gymnasium.dev_wrappers import ArgType +from gymnasium.experimental.wrappers import ArgType class LambdaActionV0(gym.ActionWrapper): diff --git a/gymnasium/dev_wrappers/lambda_observations.py b/gymnasium/experimental/wrappers/lambda_observations.py similarity index 87% rename from gymnasium/dev_wrappers/lambda_observations.py rename to gymnasium/experimental/wrappers/lambda_observations.py index 52e5d5564..92f4d2fb3 100644 --- a/gymnasium/dev_wrappers/lambda_observations.py +++ b/gymnasium/experimental/wrappers/lambda_observations.py @@ -4,10 +4,10 @@ import gymnasium as gym from gymnasium.core import ObsType -from gymnasium.dev_wrappers import ArgType +from gymnasium.experimental.wrappers import ArgType -class LambdaObservationsV0(gym.ObservationWrapper): +class LambdaObservationV0(gym.ObservationWrapper): """Lambda observation wrapper where a function is provided that is applied to the observation.""" def __init__( diff --git a/gymnasium/dev_wrappers/lambda_reward.py b/gymnasium/experimental/wrappers/lambda_reward.py similarity index 90% rename from gymnasium/dev_wrappers/lambda_reward.py rename to gymnasium/experimental/wrappers/lambda_reward.py index 27f7e3b02..2a0393812 100644 --- a/gymnasium/dev_wrappers/lambda_reward.py +++ b/gymnasium/experimental/wrappers/lambda_reward.py @@ -5,8 +5,8 @@ import numpy as np import gymnasium as gym -from gymnasium.dev_wrappers import ArgType from gymnasium.error import InvalidBound +from gymnasium.experimental.wrappers import ArgType class LambdaRewardV0(gym.RewardWrapper): @@ -14,7 +14,7 @@ class LambdaRewardV0(gym.RewardWrapper): Example: >>> import gymnasium as gym - >>> from gymnasium.wrappers import LambdaRewardV0 + >>> from gymnasium.experimental.wrappers 
import LambdaRewardV0 >>> env = gym.make("CartPole-v1") >>> env = LambdaRewardV0(env, lambda r: 2 * r + 1) >>> _ = env.reset() @@ -47,14 +47,14 @@ def reward(self, reward: Union[float, int, np.ndarray]) -> Any: return self.func(reward) -class ClipRewardsV0(LambdaRewardV0): +class ClipRewardV0(LambdaRewardV0): """A wrapper that clips the rewards for an environment between an upper and lower bound. Example with an upper and lower bound: >>> import gymnasium as gym - >>> from gymnasium.wrappers import ClipRewardsV0 + >>> from gymnasium.experimental.wrappers import ClipRewardV0 >>> env = gym.make("CartPole-v1") - >>> env = ClipRewardsV0(env, 0, 0.5) + >>> env = ClipRewardV0(env, 0, 0.5) >>> env.reset() >>> _, rew, _, _, _ = env.step(1) >>> rew diff --git a/gymnasium/wrappers/__init__.py b/gymnasium/wrappers/__init__.py index 17925999a..152dc4a22 100644 --- a/gymnasium/wrappers/__init__.py +++ b/gymnasium/wrappers/__init__.py @@ -1,8 +1,4 @@ """Module of wrapper classes.""" -from gymnasium import error -from gymnasium.dev_wrappers.lambda_action import LambdaActionV0 -from gymnasium.dev_wrappers.lambda_observations import LambdaObservationsV0 -from gymnasium.dev_wrappers.lambda_reward import ClipRewardsV0, LambdaRewardV0 from gymnasium.wrappers.atari_preprocessing import AtariPreprocessing from gymnasium.wrappers.autoreset import AutoResetWrapper from gymnasium.wrappers.clip_action import ClipAction diff --git a/tests/dev_wrappers/test_lambda_rewards/__init__.py b/tests/dev_wrappers/test_lambda_rewards/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/envs/functional/__init__.py b/tests/envs/functional/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/dev_wrappers/__init__.py b/tests/experimental/__init__.py similarity index 100% rename from tests/dev_wrappers/__init__.py rename to tests/experimental/__init__.py diff --git a/tests/dev_wrappers/test_lambda_actions/__init__.py 
b/tests/experimental/functional/__init__.py similarity index 100% rename from tests/dev_wrappers/test_lambda_actions/__init__.py rename to tests/experimental/functional/__init__.py diff --git a/tests/envs/functional/test_core.py b/tests/experimental/functional/test_core.py similarity index 93% rename from tests/envs/functional/test_core.py rename to tests/experimental/functional/test_core.py index 6c94fd8dc..d1282bd4e 100644 --- a/tests/envs/functional/test_core.py +++ b/tests/experimental/functional/test_core.py @@ -2,10 +2,10 @@ import numpy as np -from gymnasium.functional import FuncEnv +from gymnasium.experimental import FuncEnv -class TestEnv(FuncEnv): +class GenericTestFuncEnv(FuncEnv): def __init__(self, options: Optional[Dict[str, Any]] = None): super().__init__(options) @@ -26,7 +26,7 @@ def terminal(self, state: np.ndarray) -> bool: def test_api(): - env = TestEnv() + env = GenericTestFuncEnv() state = env.initial(None) obs = env.observation(state) assert state.shape == (2,) diff --git a/tests/envs/functional/test_jax.py b/tests/experimental/functional/test_jax.py similarity index 88% rename from tests/envs/functional/test_jax.py rename to tests/experimental/functional/test_jax.py index 41aca4ec5..284b1a748 100644 --- a/tests/envs/functional/test_jax.py +++ b/tests/experimental/functional/test_jax.py @@ -4,11 +4,11 @@ import numpy as np import pytest -from gymnasium.envs.phys2d.cartpole import CartPoleF # noqa: E402 -from gymnasium.envs.phys2d.pendulum import PendulumF # noqa: E402 +from gymnasium.envs.phys2d.cartpole import CartPoleFunctional +from gymnasium.envs.phys2d.pendulum import PendulumFunctional -@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF]) +@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional]) def test_normal(env_class): env = env_class() rng = jrng.PRNGKey(0) @@ -40,7 +40,7 @@ def test_normal(env_class): state = next_state -@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF]) 
+@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional]) def test_jit(env_class): env = env_class() rng = jrng.PRNGKey(0) @@ -73,7 +73,7 @@ def test_jit(env_class): state = next_state -@pytest.mark.parametrize("env_class", [CartPoleF, PendulumF]) +@pytest.mark.parametrize("env_class", [CartPoleFunctional, PendulumFunctional]) def test_vmap(env_class): env = env_class() num_envs = 10 diff --git a/tests/dev_wrappers/test_lambda_observations/__init__.py b/tests/experimental/wrappers/__init__.py similarity index 100% rename from tests/dev_wrappers/test_lambda_observations/__init__.py rename to tests/experimental/wrappers/__init__.py diff --git a/tests/dev_wrappers/test_lambda_actions/test_lambda_action.py b/tests/experimental/wrappers/test_lambda_action.py similarity index 96% rename from tests/dev_wrappers/test_lambda_actions/test_lambda_action.py rename to tests/experimental/wrappers/test_lambda_action.py index 7dafa9713..2e3b71fcb 100644 --- a/tests/dev_wrappers/test_lambda_actions/test_lambda_action.py +++ b/tests/experimental/wrappers/test_lambda_action.py @@ -4,8 +4,8 @@ import gymnasium as gym from gymnasium.error import InvalidAction +from gymnasium.experimental.wrappers import LambdaActionV0 from gymnasium.spaces import Box -from gymnasium.wrappers import LambdaActionV0 from tests.testing_env import GenericTestEnv NUM_ENVS = 3 diff --git a/tests/dev_wrappers/test_lambda_observations/test_lambda_observation.py b/tests/experimental/wrappers/test_lambda_observation.py similarity index 91% rename from tests/dev_wrappers/test_lambda_observations/test_lambda_observation.py rename to tests/experimental/wrappers/test_lambda_observation.py index 02eba36d6..17e0d866d 100644 --- a/tests/dev_wrappers/test_lambda_observations/test_lambda_observation.py +++ b/tests/experimental/wrappers/test_lambda_observation.py @@ -3,8 +3,8 @@ import numpy as np import gymnasium as gym +from gymnasium.experimental.wrappers import LambdaObservationV0 from 
gymnasium.spaces import Box -from gymnasium.wrappers import LambdaObservationsV0 NUM_ENVS = 3 BOX_SPACE = Box(-5, 5, (1,), dtype=np.float64) @@ -25,7 +25,7 @@ def test_lambda_observation_v0(): observation_shift = 1 env.reset(seed=SEED) - wrapped_env = LambdaObservationsV0( + wrapped_env = LambdaObservationV0( env, lambda observation: observation + observation_shift ) wrapped_obs, _, _, _, _ = wrapped_env.step(DISCRETE_ACTION) @@ -48,7 +48,7 @@ def test_lambda_observation_v0_within_vector(): observation_shift = 1 env.reset(seed=SEED) - wrapped_env = LambdaObservationsV0( + wrapped_env = LambdaObservationV0( env, lambda observation: observation + observation_shift ) wrapped_obs, _, _, _, _ = wrapped_env.step( diff --git a/tests/dev_wrappers/test_lambda_rewards/test_lambda_rewards.py b/tests/experimental/wrappers/test_lambda_rewards.py similarity index 92% rename from tests/dev_wrappers/test_lambda_rewards/test_lambda_rewards.py rename to tests/experimental/wrappers/test_lambda_rewards.py index b395122f2..f5e843525 100644 --- a/tests/dev_wrappers/test_lambda_rewards/test_lambda_rewards.py +++ b/tests/experimental/wrappers/test_lambda_rewards.py @@ -5,7 +5,7 @@ import gymnasium as gym from gymnasium.error import InvalidBound -from gymnasium.wrappers import ClipRewardsV0, LambdaRewardV0 +from gymnasium.experimental.wrappers import ClipRewardV0, LambdaRewardV0 ENV_ID = "CartPole-v1" DISCRETE_ACTION = 0 @@ -65,7 +65,7 @@ def test_clip_reward(lower_bound, upper_bound, expected_reward): accordingly to the input args. 
""" env = gym.make(ENV_ID) - env = ClipRewardsV0(env, lower_bound, upper_bound) + env = ClipRewardV0(env, lower_bound, upper_bound) env.reset(seed=SEED) _, rew, _, _, _ = env.step(DISCRETE_ACTION) @@ -84,7 +84,7 @@ def test_clip_reward_within_vector(lower_bound, upper_bound, expected_reward): actions = [DISCRETE_ACTION for _ in range(NUM_ENVS)] env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS) - env = ClipRewardsV0(env, lower_bound, upper_bound) + env = ClipRewardV0(env, lower_bound, upper_bound) env.reset(seed=SEED) _, rew, _, _, _ = env.step(actions) @@ -106,4 +106,4 @@ def test_clip_reward_incorrect_params(lower_bound, upper_bound): env = gym.make(ENV_ID) with pytest.raises(InvalidBound): - env = ClipRewardsV0(env, lower_bound, upper_bound) + env = ClipRewardV0(env, lower_bound, upper_bound)