Skip to content

Commit

Permalink
Fix type annotations of callable to Callable (openai#259)
Browse files Browse the repository at this point in the history
  • Loading branch information
ianyfan authored Jan 11, 2023
1 parent 4e6dc3e commit b34c25b
Show file tree
Hide file tree
Showing 10 changed files with 32 additions and 25 deletions.
2 changes: 1 addition & 1 deletion gymnasium/envs/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
)


def load(name: str) -> callable:
def load(name: str) -> Callable:
"""Loads an environment with name and returns an environment creation function.
Args:
Expand Down
2 changes: 1 addition & 1 deletion gymnasium/utils/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ class PlayPlot:
"""

def __init__(
self, callback: callable, horizon_timesteps: int, plot_names: List[str]
self, callback: Callable, horizon_timesteps: int, plot_names: List[str]
):
"""Constructor of :class:`PlayPlot`.
Expand Down
9 changes: 5 additions & 4 deletions gymnasium/vector/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Module for vector environments."""
from typing import Iterable, List, Optional, Union
from typing import Callable, Iterable, List, Optional, Union

import gymnasium as gym
from gymnasium.core import Env
from gymnasium.vector.async_vector_env import AsyncVectorEnv
from gymnasium.vector.sync_vector_env import SyncVectorEnv
from gymnasium.vector.vector_env import VectorEnv, VectorEnvWrapper
Expand All @@ -14,7 +15,7 @@ def make(
id: str,
num_envs: int = 1,
asynchronous: bool = True,
wrappers: Optional[Union[callable, List[callable]]] = None,
wrappers: Optional[Union[Callable[[Env], Env], List[Callable[[Env], Env]]]] = None,
disable_env_checker: Optional[bool] = None,
**kwargs,
) -> VectorEnv:
Expand Down Expand Up @@ -43,12 +44,12 @@ def make(
The vectorized environment.
"""

def create_env(env_num: int):
def create_env(env_num: int) -> Callable[[], Env]:
"""Creates an environment that can enable or disable the environment checker."""
# If the env_num > 0 then disable the environment checker otherwise use the parameter
_disable_env_checker = True if env_num > 0 else disable_env_checker

def _make_env():
def _make_env() -> Env:
env = gym.envs.registration.make(
id,
disable_env_checker=_disable_env_checker,
Expand Down
8 changes: 4 additions & 4 deletions gymnasium/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import time
from copy import deepcopy
from enum import Enum
from typing import List, Optional, Sequence, Tuple, Union
from typing import Callable, List, Optional, Sequence, Tuple, Union

import numpy as np

import gymnasium as gym
from gymnasium import logger
from gymnasium.core import ObsType
from gymnasium.core import Env, ObsType
from gymnasium.error import (
AlreadyPendingCallError,
ClosedEnvironmentError,
Expand Down Expand Up @@ -59,14 +59,14 @@ class AsyncVectorEnv(VectorEnv):

def __init__(
self,
env_fns: Sequence[callable],
env_fns: Sequence[Callable[[], Env]],
observation_space: Optional[gym.Space] = None,
action_space: Optional[gym.Space] = None,
shared_memory: bool = True,
copy: bool = True,
context: Optional[str] = None,
daemon: bool = True,
worker: Optional[callable] = None,
worker: Optional[Callable] = None,
):
"""Vectorized environment that runs multiple environments in parallel.
Expand Down
7 changes: 6 additions & 1 deletion gymnasium/vector/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""Miscellaneous utilities."""
from __future__ import annotations

import contextlib
import os
from collections.abc import Callable

from gymnasium.core import Env


__all__ = ["CloudpickleWrapper", "clear_mpi_env_vars"]
Expand All @@ -9,7 +14,7 @@
class CloudpickleWrapper:
"""Wrapper that uses cloudpickle to pickle and unpickle the result."""

def __init__(self, fn: callable):
def __init__(self, fn: Callable[[], Env]):
"""Cloudpickle wrapper for a function."""
self.fn = fn

Expand Down
4 changes: 2 additions & 2 deletions gymnasium/vector/utils/numpy_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Numpy utility functions: concatenate space samples and create empty array."""
from collections import OrderedDict
from functools import singledispatch
from typing import Iterable, Union
from typing import Callable, Iterable, Union

import numpy as np

Expand Down Expand Up @@ -84,7 +84,7 @@ def _concatenate_custom(space, items, out):

@singledispatch
def create_empty_array(
space: Space, n: int = 1, fn: callable = np.zeros
space: Space, n: int = 1, fn: Callable[..., np.ndarray] = np.zeros
) -> Union[tuple, dict, np.ndarray]:
"""Create an empty (possibly nested) numpy array.
Expand Down
4 changes: 2 additions & 2 deletions tests/spaces/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json # note: ujson fails this test due to float equality
import pickle
import tempfile
from typing import List, Union
from typing import Callable, List, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -283,7 +283,7 @@ def test_space_sample_mask(space: Space, mask, n_trials: int = 100):
elif isinstance(space, MultiDiscrete):
# Due to the multi-axis capability of MultiDiscrete, these functions need to be recursive and that the expected / observed numpy are of non-regular shapes
def _generate_frequency(
_dim: Union[np.ndarray, int], _mask, func: callable
_dim: Union[np.ndarray, int], _mask, func: Callable
) -> List:
if isinstance(_dim, np.ndarray):
return [
Expand Down
7 changes: 4 additions & 3 deletions tests/testing_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import types
from collections.abc import Callable
from typing import Any

import gymnasium as gym
Expand Down Expand Up @@ -45,9 +46,9 @@ def __init__(
self,
action_space: spaces.Space = spaces.Box(0, 1, (1,)),
observation_space: spaces.Space = spaces.Box(0, 1, (1,)),
reset_func: callable = basic_reset_func,
step_func: callable = new_step_func,
render_func: callable = basic_render_func,
reset_func: Callable = basic_reset_func,
step_func: Callable = new_step_func,
render_func: Callable = basic_render_func,
metadata: dict[str, Any] = {"render_modes": []},
render_mode: str | None = None,
spec: EnvSpec = EnvSpec(
Expand Down
8 changes: 4 additions & 4 deletions tests/utils/test_env_checker.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Tests that the `env_checker` runs as expects and all errors are possible."""
import re
import warnings
from typing import Tuple, Union
from typing import Callable, Tuple, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -106,7 +106,7 @@ def _reset_default_seed(self: GenericTestEnv, seed="Error", options=None):
],
],
)
def test_check_reset_seed(test, func: callable, message: str):
def test_check_reset_seed(test, func: Callable, message: str):
"""Tests the check reset seed function works as expected."""
if test is UserWarning:
with pytest.warns(
Expand Down Expand Up @@ -175,7 +175,7 @@ def _return_info_not_dict(self, seed=None, options=None):
],
],
)
def test_check_reset_return_type(test, func: callable, message: str):
def test_check_reset_return_type(test, func: Callable, message: str):
"""Tests the check `env.reset()` function has a correct return type."""

with pytest.raises(test, match=f"^{re.escape(message)}$"):
Expand All @@ -194,7 +194,7 @@ def test_check_reset_return_type(test, func: callable, message: str):
],
],
)
def test_check_reset_return_info_deprecation(test, func: callable, message: str):
def test_check_reset_return_info_deprecation(test, func: Callable, message: str):
"""Tests that return_info has been correct deprecated as an argument to `env.reset()`."""

with pytest.warns(test, match=f"^\\x1b\\[33mWARN: {re.escape(message)}\\x1b\\[0m$"):
Expand Down
6 changes: 3 additions & 3 deletions tests/utils/test_passive_env_checker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
import warnings
from typing import Dict, Union
from typing import Callable, Dict, Union

import numpy as np
import pytest
Expand Down Expand Up @@ -297,7 +297,7 @@ def _reset_result(self, seed=None, options=None):
],
],
)
def test_passive_env_reset_checker(test, func: callable, message: str, kwargs: Dict):
def test_passive_env_reset_checker(test, func: Callable, message: str, kwargs: Dict):
"""Tests the passive env reset check"""
if test is UserWarning:
with pytest.warns(
Expand Down Expand Up @@ -376,7 +376,7 @@ def _modified_step(
],
)
def test_passive_env_step_checker(
test: Union[UserWarning, type], func: callable, message: str
test: Union[UserWarning, type], func: Callable, message: str
):
"""Tests the passive env step checker."""
if test is UserWarning:
Expand Down

0 comments on commit b34c25b

Please sign in to comment.