-
Notifications
You must be signed in to change notification settings - Fork 0
/
environment.py
58 lines (45 loc) · 1.32 KB
/
environment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gym
import cv2
class Environment:
    """Interface describing what a game environment must provide.

    Every method other than the constructor raises ``NotImplementedError``;
    concrete environments override them.
    """

    def __init__(self):
        """Initialise the base environment, which keeps no state."""

    def numActions(self):
        """Return the number of actions."""
        raise NotImplementedError

    def restart(self):
        """Restart the environment."""
        raise NotImplementedError

    def act(self, action):
        """Perform ``action`` and return the reward."""
        raise NotImplementedError

    def getScreen(self):
        """Return the current game screen."""
        raise NotImplementedError

    def isTerminal(self):
        """Return whether the game is done."""
        raise NotImplementedError
class GymEnvironment(Environment):
    """Environment implementation backed by an OpenAI Gym environment.

    The latest observation and terminal flag are cached on the instance so
    that ``getScreen`` and ``isTerminal`` can be queried after ``restart``
    or ``act``.
    """

    def __init__(self, env_id, args):
        """Create the Gym environment and read the display/screen settings.

        Args:
            env_id: Identifier passed to ``gym.make``.
            args: Object exposing ``display_screen``, ``screen_width`` and
                ``screen_height`` attributes.
        """
        Environment.__init__(self)
        self.gym = gym.make(env_id)
        # Both stay None until restart() is called; getScreen() and
        # isTerminal() assert on that.
        self.obs = None
        self.terminal = None
        self.display = args.display_screen
        # cv2.resize takes its target size as (width, height).
        self.dims = (args.screen_width, args.screen_height)

    def numActions(self):
        """Return the number of discrete actions of the wrapped environment."""
        assert isinstance(self.gym.action_space, gym.spaces.Discrete), \
            "only discrete action spaces are supported"
        return self.gym.action_space.n

    def restart(self):
        """Reset the wrapped environment and clear the terminal flag."""
        result = self.gym.reset()
        # Newer Gym releases return (observation, info) from reset(); older
        # ones return the observation alone. Accept both.
        if (isinstance(result, tuple) and len(result) == 2
                and isinstance(result[1], dict)):
            result = result[0]
        self.obs = result
        self.terminal = False

    def act(self, action):
        """Perform ``action`` in the wrapped environment and return the reward.

        Updates the cached observation and terminal flag, and renders the
        environment when display is enabled.
        """
        result = self.gym.step(action)
        if len(result) == 5:
            # Newer Gym API: (obs, reward, terminated, truncated, info).
            # Either flag ends the episode.
            self.obs, reward, terminated, truncated, _ = result
            self.terminal = bool(terminated or truncated)
        else:
            # Legacy Gym API: (obs, reward, done, info).
            self.obs, reward, self.terminal, _ = result
        if self.display:
            self.gym.render()
        return reward

    def getScreen(self):
        """Return the latest observation as a resized grayscale image.

        Requires ``restart`` to have been called first.
        """
        assert self.obs is not None, "restart() must be called before getScreen()"
        # Gym image observations are RGB-ordered (see the Gym docs), so
        # convert with RGB2GRAY; BGR2GRAY would swap the red and blue
        # luminance weights.
        gray = cv2.cvtColor(self.obs, cv2.COLOR_RGB2GRAY)
        return cv2.resize(gray, self.dims)

    def isTerminal(self):
        """Return whether the current episode has ended.

        Requires ``restart`` to have been called first.
        """
        assert self.terminal is not None, \
            "restart() must be called before isTerminal()"
        return self.terminal