-
Notifications
You must be signed in to change notification settings - Fork 2
/
VecMonitorMulti.py
37 lines (30 loc) · 1.18 KB
/
VecMonitorMulti.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from stable_baselines3.common.vec_env.base_vec_env import VecEnvStepReturn, VecEnvWrapper
import numpy as np
from copy import deepcopy
class VecMonitor(VecEnvWrapper):
    """Vectorized-env wrapper that accumulates rewards over each episode.

    A running reward sum of shape ``(num_envs, agent_count, task_count)`` is
    maintained. Whenever a sub-environment reports ``done``, its accumulated
    sum is published in that sub-environment's info dict and then zeroed.

    The info dicts produced by the wrapped env are *not* forwarded: on a
    ``done`` step only ``"real_rewards"`` is copied out of them, and on every
    other step an empty dict is returned.
    """

    def __init__(self, env, task_count: int) -> None:
        """Wrap *env* and allocate the reward accumulator.

        Args:
            env: Vectorized environment to wrap. Its sub-environments must
                expose a ``num_players`` attribute.
            task_count: Size of the last axis of the reward accumulator.
        """
        super().__init__(env)
        self.task_count = task_count
        # Only the first sub-environment is consulted for the agent count.
        self.agent_count = self.venv.get_attr("num_players")[0]
        # Allocate up front so step_wait() works even before the first reset().
        self.episode_rewards = self._zero_rewards()

    def _zero_rewards(self):
        """Return a fresh all-zero accumulator, one slot per env/agent/task."""
        return np.zeros((self.venv.num_envs, self.agent_count, self.task_count))

    def reset(self):
        """Reset the wrapped env, clear the accumulator, and return the obs."""
        obs = self.venv.reset()
        self.episode_rewards = self._zero_rewards()
        return obs

    def step_async(self, act):
        """Forward the actions to the wrapped env unchanged."""
        self.venv.step_async(act)

    def step_wait(self):
        """Collect the step result and update the per-episode reward sums.

        Returns:
            ``(obs, rew, done, infos)`` where ``obs``, ``rew`` and ``done``
            come straight from the wrapped env and ``infos`` is a list with
            one dict per sub-environment. For a sub-environment that finished
            on this step the dict holds ``"real_rewards"`` (copied from the
            wrapped env's info) and ``"episode_rewards"`` (its accumulated
            sum); otherwise the dict is empty.
        """
        obs, rew, done, info = self.venv.step_wait()
        # rew has to be broadcastable to the accumulator's shape.
        self.episode_rewards += rew

        result_info = []
        for index, finished in enumerate(done):
            # Always a dict (never a list) so consumers can call .get() on
            # every entry regardless of whether the episode ended.
            new_info = {}
            if finished:
                new_info["real_rewards"] = deepcopy(
                    np.array(info[index]["real_rewards"])
                )
                # Copy before zeroing: indexing yields a view of the buffer.
                new_info["episode_rewards"] = self.episode_rewards[index].copy()
                self.episode_rewards[index, :] = 0
            result_info.append(new_info)
        return obs, rew, done, result_info