
Added RGBD diffusion policy implementation as well as Draw Triangle task #643

Open · wants to merge 14 commits into main
124 changes: 114 additions & 10 deletions examples/baselines/diffusion_policy/diffusion_policy/make_env.py
@@ -1,12 +1,73 @@
from collections import deque
from typing import Optional

import gymnasium as gym
import mani_skill.envs
import numpy as np
from gymnasium.spaces import Box
from gymnasium.wrappers.frame_stack import FrameStack as GymFrameStack
from gymnasium.wrappers.frame_stack import LazyFrames
from mani_skill.utils import gym_utils
from mani_skill.utils.wrappers import CPUGymWrapper, FrameStack, RecordEpisode
from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv


class DictFrameStack(GymFrameStack):
Review comment (Member): can this not just inherit the ManiSkill FrameStack wrapper?

Author reply: I purposely didn't use the original FrameStack wrapper from gymnasium since it was not properly GPU parallelized.

def __init__(
self,
env: gym.Env,
num_stack: int,
lz4_compress: bool = False,
):
"""Observation wrapper that stacks the observations in a rolling manner.

Args:
env (Env): The environment to apply the wrapper
num_stack (int): The number of frames to stack
lz4_compress (bool): Use lz4 to compress the frames internally
"""
gym.utils.RecordConstructorArgs.__init__(
self, num_stack=num_stack, lz4_compress=lz4_compress
)
gym.ObservationWrapper.__init__(self, env)

self.num_stack = num_stack
self.lz4_compress = lz4_compress

self.frames = deque(maxlen=num_stack)

new_observation_space = gym.spaces.Dict()
for k, v in self.observation_space.items():
low = np.repeat(v.low[np.newaxis, ...], num_stack, axis=0)
high = np.repeat(v.high[np.newaxis, ...], num_stack, axis=0)
new_observation_space[k] = Box(low=low, high=high, dtype=v.dtype)
self.observation_space = new_observation_space

def observation(self, observation):
"""Converts the wrappers current frames to lazy frames.

Args:
observation: Ignored

Returns:
:class:`LazyFrames` object for the wrapper's frame buffer, :attr:`self.frames`
"""
assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack)
return {
k: LazyFrames([x[k] for x in self.frames], self.lz4_compress)
for k in self.observation_space.keys()
}
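If it helps to see the stacking behavior in isolation, here is a minimal sketch (not part of the PR) that runs DictFrameStack on a toy dict-observation environment. The key names and shapes are made up, and it assumes a gymnasium version that still ships gym.wrappers.frame_stack, as the imports above do:

```python
import gymnasium as gym
import numpy as np


class ToyDictEnv(gym.Env):
    """Toy stand-in for a ManiSkill rgbd env: dict observations with image and state keys."""

    def __init__(self):
        self.observation_space = gym.spaces.Dict(
            {
                "rgbd": gym.spaces.Box(0, 255, shape=(64, 64, 4), dtype=np.uint8),
                "state": gym.spaces.Box(-np.inf, np.inf, shape=(9,), dtype=np.float32),
            }
        )
        self.action_space = gym.spaces.Box(-1, 1, shape=(4,), dtype=np.float32)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        return self.observation_space.sample(), {}

    def step(self, action):
        return self.observation_space.sample(), 0.0, False, False, {}


env = DictFrameStack(ToyDictEnv(), num_stack=2)
obs, _ = env.reset(seed=0)
# Every key gains a leading stack dimension of size num_stack:
assert np.asarray(obs["rgbd"]).shape == (2, 64, 64, 4)
assert np.asarray(obs["state"]).shape == (2, 9)
```

Each key of the returned dict is a LazyFrames buffer whose leading dimension is num_stack, mirroring what gym.wrappers.FrameStack does for flat Box observations.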


def make_eval_envs(
env_id,
num_envs: int,
sim_backend: str,
env_kwargs: dict,
other_kwargs: dict,
video_dir: Optional[str] = None,
wrappers: list[gym.Wrapper] = [],
):
"""Create vectorized environment for evaluation and/or recording videos.
For CPU vectorized environments, only the first parallel environment is used to record videos.
For GPU vectorized environments, all parallel environments are used to record videos.
@@ -20,29 +81,72 @@
wrappers: the list of wrappers to apply to the environment.
"""
if sim_backend == "cpu":

def cpu_make_env(
env_id, seed, video_dir=None, env_kwargs=dict(), other_kwargs=dict()
):
def thunk():
env = gym.make(env_id, reconfiguration_freq=1, **env_kwargs)
for wrapper in wrappers:
env = wrapper(env)
env = CPUGymWrapper(env, ignore_terminations=True, record_metrics=True)
if video_dir:
env = RecordEpisode(
env,
output_dir=video_dir,
save_trajectory=False,
info_on_video=True,
source_type="diffusion_policy",
source_desc="diffusion_policy evaluation rollout",
)
if env_kwargs["obs_mode"] == "state":
env = gym.wrappers.FrameStack(env, other_kwargs["obs_horizon"])
elif env_kwargs["obs_mode"] == "rgbd":
env = DictFrameStack(env, other_kwargs["obs_horizon"])
env.action_space.seed(seed)
env.observation_space.seed(seed)
return env

return thunk

vector_cls = (
gym.vector.SyncVectorEnv
if num_envs == 1
else lambda x: gym.vector.AsyncVectorEnv(x, context="forkserver")
)
env = vector_cls(
[
cpu_make_env(
env_id,
seed,
video_dir if seed == 0 else None,
env_kwargs,
other_kwargs,
)
for seed in range(num_envs)
]
)
else:
env = gym.make(
env_id,
num_envs=num_envs,
sim_backend=sim_backend,
reconfiguration_freq=1,
**env_kwargs
)
max_episode_steps = gym_utils.find_max_episode_steps_value(env)
for wrapper in wrappers:
env = wrapper(env)
env = FrameStack(env, num_stack=other_kwargs["obs_horizon"])
if video_dir:
env = RecordEpisode(
env,
output_dir=video_dir,
save_trajectory=False,
save_video=True,
source_type="diffusion_policy",
source_desc="diffusion_policy evaluation rollout",
max_steps_per_video=max_episode_steps,
)
env = ManiSkillVectorEnv(env, ignore_terminations=True, record_metrics=True)
return env
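For context, a hedged sketch of how these evaluation environments might be constructed for the RGBD policy; the env id, control mode, and horizon below are illustrative assumptions rather than values taken from this PR:

```python
# Hedged sketch: constructing GPU-parallelized eval envs for an RGBD policy.
# "DrawTriangle-v1", the control mode, and obs_horizon=2 are assumptions.
env_kwargs = dict(
    obs_mode="rgbd",
    control_mode="pd_ee_delta_pose",
    render_mode="rgb_array",
)
other_kwargs = dict(obs_horizon=2)

envs = make_eval_envs(
    env_id="DrawTriangle-v1",
    num_envs=16,
    sim_backend="gpu",
    env_kwargs=env_kwargs,
    other_kwargs=other_kwargs,
    video_dir="videos/eval",
)
obs, info = envs.reset(seed=0)
# obs is a (possibly nested) dict of batched, frame-stacked torch tensors;
# with the GPU backend the ManiSkill FrameStack and ManiSkillVectorEnv
# wrappers keep everything on-device.
```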
65 changes: 65 additions & 0 deletions examples/baselines/diffusion_policy/diffusion_policy/plain_conv.py
@@ -0,0 +1,65 @@
import torch.nn as nn


def make_mlp(in_channels, mlp_channels, act_builder=nn.ReLU, last_act=True):
c_in = in_channels
module_list = []
for idx, c_out in enumerate(mlp_channels):
module_list.append(nn.Linear(c_in, c_out))
if last_act or idx < len(mlp_channels) - 1:
module_list.append(act_builder())
c_in = c_out
return nn.Sequential(*module_list)
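As a quick illustration (not part of the PR), last_act only decides whether the final Linear layer gets a trailing activation; the sizes below are arbitrary:

```python
import torch
import torch.nn as nn

# Builds Linear(128, 256) -> ReLU -> Linear(256, 256); no trailing activation
# because last_act=False.
head = make_mlp(128, [256, 256], act_builder=nn.ReLU, last_act=False)
print(head(torch.zeros(8, 128)).shape)  # torch.Size([8, 256])
```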


class PlainConv(nn.Module):
def __init__(
self,
in_channels=3,
out_dim=256,
pool_feature_map=False,
last_act=True, # True for ConvBody, False for CNN
):
super().__init__()
# assume input image size is 64x64

self.out_dim = out_dim
self.cnn = nn.Sequential(
nn.Conv2d(in_channels, 16, 3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # [32, 32]
nn.Conv2d(16, 32, 3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # [16, 16]
nn.Conv2d(32, 64, 3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # [8, 8]
nn.Conv2d(64, 128, 3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2), # [4, 4]
nn.Conv2d(128, 128, 1, padding=0, bias=True),
nn.ReLU(inplace=True),
)

if pool_feature_map:
self.pool = nn.AdaptiveMaxPool2d((1, 1))
self.fc = make_mlp(128, [out_dim], last_act=last_act)
else:
self.pool = None
self.fc = make_mlp(128 * 4 * 4 * 4, [out_dim], last_act=last_act)

self.reset_parameters()

def reset_parameters(self):
for name, module in self.named_modules():
if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)):
if module.bias is not None:
nn.init.zeros_(module.bias)

def forward(self, image):
x = self.cnn(image)
if self.pool is not None:
x = self.pool(x)
x = x.flatten(1)
x = self.fc(x)
return x
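A quick shape check may be useful here (a hedged sketch, not part of the PR). With four 2x2 max-pools the flattened feature size is 128 * (H/16) * (W/16), so the 128 * 4 * 4 * 4 (= 128 * 8 * 8) passed to make_mlp corresponds to a 128x128 input rather than the 64x64 mentioned in the comment above; the 4-channel RGBD input is likewise an assumption:

```python
import torch

# Flattened path: expects the CNN output to flatten to 128 * 8 * 8,
# which corresponds to a 128x128 input image.
encoder = PlainConv(in_channels=4, out_dim=256, pool_feature_map=False)
rgbd = torch.zeros(2, 4, 128, 128)  # (batch, channels, height, width)
print(encoder(rgbd).shape)  # torch.Size([2, 256])

# Pooled path: AdaptiveMaxPool2d((1, 1)) removes the spatial dependence,
# so any input resolution works.
pooled = PlainConv(in_channels=4, out_dim=256, pool_feature_map=True)
print(pooled(torch.zeros(2, 4, 64, 64)).shape)  # torch.Size([2, 256])
```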