Skip to content

Commit

Permalink
Merge pull request #117 from NREL/ray_v2_3
Browse files Browse the repository at this point in the history
Update to use ray v2.3 and gymnasium v0.26.3
  • Loading branch information
jlaw9 authored Mar 30, 2023
2 parents a5dcd5c + a2608ef commit f817ee3
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
8 changes: 4 additions & 4 deletions examples/benchmarks/RL_workshop.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,7 @@
"outputs": [],
"source": [
"# Here, obs is a list of MoleculeState observations, with the first entry corresponding to the parent node\n",
"obs = env.reset()"
"obs, info = env.reset()"
]
},
{
Expand Down Expand Up @@ -728,7 +728,7 @@
],
"source": [
"done = False\n",
"obs = env.reset()\n",
"obs, info = env.reset()\n",
"np.random.seed(0)\n",
"\n",
"while not done:\n",
Expand Down Expand Up @@ -770,7 +770,7 @@
"outputs": [],
"source": [
"model = policy_model()\n",
"obs = env.reset()"
"obs, info = env.reset()"
]
},
{
Expand Down Expand Up @@ -1045,7 +1045,7 @@
"\n",
"env_preprocessor = get_preprocessor(env.observation_space)(env.observation_space)\n",
"policy = trainer.get_policy()\n",
"obs = env.reset()"
"obs, info = env.reset()"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion rlmolecule/molecule_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Type, Union

import gym
import gymnasium as gym
import nfp
import numpy as np
import ray
Expand Down
2 changes: 1 addition & 1 deletion tests/test_molecule_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def molecule_env(qed_root: MoleculeState):

def test_policy_model(molecule_env, single_layer_model):

observation, reward, terminal, info = molecule_env.step(0)
observation, reward, terminal, truncated, info = molecule_env.step(0)

preprocessor = get_preprocessor(molecule_env.observation_space)
obs = preprocessor(molecule_env.observation_space).transform(observation)
Expand Down
12 changes: 6 additions & 6 deletions tests/test_molecule_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ def test_prune_terminal(builder):
assert repr(env.state.children[-1]) == "C (t)"

# select the terminal state
obs, reward, terminal, info = env.step(len(env.state.children) - 1)
obs, reward, terminal, truncated, info = env.step(len(env.state.children) - 1)
assert terminal
assert np.isclose(reward, 0.3597849378839701)

obs = env.reset()
obs, reward, terminal, info = env.step(len(env.state.children) - 1)
obs, info = env.reset()
obs, reward, terminal, truncated, info = env.step(len(env.state.children) - 1)
assert not terminal
assert np.isclose(reward, 0)

Expand All @@ -75,12 +75,12 @@ def test_prune_terminal_ray(ray_init):
assert repr(env.state.children[-1]) == "C (t)"

# select the terminal state
obs, reward, terminal, info = env.step(len(env.state.children) - 1)
obs, reward, terminal, truncated, info = env.step(len(env.state.children) - 1)
assert terminal
assert np.isclose(reward, 0.3597849378839701)

obs = env.reset()
obs, reward, terminal, info = env.step(len(env.state.children) - 1)
obs, info = env.reset()
obs, reward, terminal, truncated, info = env.step(len(env.state.children) - 1)
assert not terminal
assert np.isclose(reward, 0)

Expand Down

0 comments on commit f817ee3

Please sign in to comment.