diff --git a/docs/source/tasks/drawing/index.md b/docs/source/tasks/drawing/index.md index daf03fe02..ad304f349 100644 --- a/docs/source/tasks/drawing/index.md +++ b/docs/source/tasks/drawing/index.md @@ -26,3 +26,26 @@ None + +## DrawTriangle-v1 + +:::{dropdown} Task Card +:icon: note +:color: primary + +**Task Description:** +Instantiates a table with a white canvas on it and an outlined goal triangle. A robot equipped with a stick must trace the goal triangle with a red line. + +**Supported Robots: PandaStick** + +**Randomizations:** +- the goal triangle's position on the xy-plane is randomized +- the goal triangle's z-rotation is randomized in the range [0, 2$\pi$] + +**Success Conditions:** +- the points drawn by the robot are within a Euclidean distance of 0.05 m of points on the goal triangle +::: + + \ No newline at end of file diff --git a/examples/baselines/diffusion_policy/diffusion_policy/make_env.py b/examples/baselines/diffusion_policy/diffusion_policy/make_env.py index 13b6d0add..f8543185a 100644 --- a/examples/baselines/diffusion_policy/diffusion_policy/make_env.py +++ b/examples/baselines/diffusion_policy/diffusion_policy/make_env.py @@ -1,12 +1,69 @@ +from collections import deque from typing import Optional + import gymnasium as gym import mani_skill.envs +import numpy as np +from gymnasium.spaces import Box +from gymnasium.wrappers.frame_stack import LazyFrames from mani_skill.utils import gym_utils +from mani_skill.utils.wrappers import CPUGymWrapper, FrameStack, RecordEpisode from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv -from mani_skill.utils.wrappers import RecordEpisode, FrameStack, CPUGymWrapper + +# from mani_skill.utils.wrappers.frame_stack import LazyFrames + +class DictFrameStack(FrameStack): + def __init__( + self, + env: gym.Env, + num_stack: int, + lz4_compress: bool = False, + ): + """Observation wrapper that stacks the observations in a rolling manner. + + Args: + env (Env): The environment to apply the wrapper + num_stack (int): The number of frames to stack + lz4_compress (bool): Use lz4 to compress the frames internally + """ + # gym.utils.RecordConstructorArgs.__init__( + # self, num_stack=num_stack, lz4_compress=lz4_compress + # ) + # gym.ObservationWrapper.__init__(self, env) + super().__init__(env, num_stack, lz4_compress) + + new_observation_space = gym.spaces.Dict() + for k, v in self.observation_space.items(): + low = np.repeat(v.low[np.newaxis, ...], num_stack, axis=0) + high = np.repeat(v.high[np.newaxis, ...], num_stack, axis=0) + new_observation_space[k] = Box(low=low, high=high, dtype=v.dtype) + self.observation_space = new_observation_space + + def observation(self, observation): + """Converts the wrapper's current frames to lazy frames. -def make_eval_envs(env_id, num_envs: int, sim_backend: str, env_kwargs: dict, other_kwargs: dict, video_dir: Optional[str] = None, wrappers: list[gym.Wrapper] = []): + Args: + observation: Ignored + + Returns: + :class:`LazyFrames` object for the wrapper's frame buffer, :attr:`self.frames` + """ + assert len(self.frames) == self.num_stack, (len(self.frames), self.num_stack) + return { + k: LazyFrames([x[k] for x in self.frames], self.lz4_compress) + for k in self.observation_space.keys() + } + + +def make_eval_envs( + env_id, + num_envs: int, + sim_backend: str, + env_kwargs: dict, + other_kwargs: dict, + video_dir: Optional[str] = None, + wrappers: list[gym.Wrapper] = [], +): """Create vectorized environment for evaluation and/or recording videos. 
For CPU vectorized environments only the first parallel environment is used to record videos. For GPU vectorized environments all parallel environments are used to record videos. @@ -20,29 +77,72 @@ def make_eval_envs(env_id, num_envs: int, sim_backend: str, env_kwargs: dict, ot wrappers: the list of wrappers to apply to the environment. """ if sim_backend == "cpu": - def cpu_make_env(env_id, seed, video_dir=None, env_kwargs = dict(), other_kwargs = dict()): + + def cpu_make_env( + env_id, seed, video_dir=None, env_kwargs=dict(), other_kwargs=dict() + ): def thunk(): env = gym.make(env_id, reconfiguration_freq=1, **env_kwargs) for wrapper in wrappers: env = wrapper(env) env = CPUGymWrapper(env, ignore_terminations=True, record_metrics=True) if video_dir: - env = RecordEpisode(env, output_dir=video_dir, save_trajectory=False, info_on_video=True, source_type="diffusion_policy", source_desc="diffusion_policy evaluation rollout") - env = gym.wrappers.FrameStack(env, other_kwargs['obs_horizon']) + env = RecordEpisode( + env, + output_dir=video_dir, + save_trajectory=False, + info_on_video=True, + source_type="diffusion_policy", + source_desc="diffusion_policy evaluation rollout", + ) + if env_kwargs["obs_mode"] == "state": + env = gym.wrappers.FrameStack(env, other_kwargs["obs_horizon"]) + elif env_kwargs["obs_mode"] == "rgbd": + env = DictFrameStack(env, other_kwargs["obs_horizon"]) env.action_space.seed(seed) env.observation_space.seed(seed) return env return thunk - vector_cls = gym.vector.SyncVectorEnv if num_envs == 1 else lambda x : gym.vector.AsyncVectorEnv(x, context="forkserver") - env = vector_cls([cpu_make_env(env_id, seed, video_dir if seed == 0 else None, env_kwargs, other_kwargs) for seed in range(num_envs)]) + + vector_cls = ( + gym.vector.SyncVectorEnv + if num_envs == 1 + else lambda x: gym.vector.AsyncVectorEnv(x, context="forkserver") + ) + env = vector_cls( + [ + cpu_make_env( + env_id, + seed, + video_dir if seed == 0 else None, + env_kwargs, + other_kwargs, + ) + for seed in range(num_envs) + ] + ) else: - env = gym.make(env_id, num_envs=num_envs, sim_backend=sim_backend, reconfiguration_freq=1, **env_kwargs) + env = gym.make( + env_id, + num_envs=num_envs, + sim_backend=sim_backend, + reconfiguration_freq=1, + **env_kwargs + ) max_episode_steps = gym_utils.find_max_episode_steps_value(env) for wrapper in wrappers: env = wrapper(env) - env = FrameStack(env, num_stack=other_kwargs['obs_horizon']) + env = FrameStack(env, num_stack=other_kwargs["obs_horizon"]) if video_dir: - env = RecordEpisode(env, output_dir=video_dir, save_trajectory=False, save_video=True, source_type="diffusion_policy", source_desc="diffusion_policy evaluation rollout", max_steps_per_video=max_episode_steps) + env = RecordEpisode( + env, + output_dir=video_dir, + save_trajectory=False, + save_video=True, + source_type="diffusion_policy", + source_desc="diffusion_policy evaluation rollout", + max_steps_per_video=max_episode_steps, + ) env = ManiSkillVectorEnv(env, ignore_terminations=True, record_metrics=True) return env diff --git a/examples/baselines/diffusion_policy/diffusion_policy/plain_conv.py b/examples/baselines/diffusion_policy/diffusion_policy/plain_conv.py new file mode 100644 index 000000000..7244469b8 --- /dev/null +++ b/examples/baselines/diffusion_policy/diffusion_policy/plain_conv.py @@ -0,0 +1,65 @@ +import torch.nn as nn + + +def make_mlp(in_channels, mlp_channels, act_builder=nn.ReLU, last_act=True): + c_in = in_channels + module_list = [] + for idx, c_out in enumerate(mlp_channels): + 
module_list.append(nn.Linear(c_in, c_out)) + if last_act or idx < len(mlp_channels) - 1: + module_list.append(act_builder()) + c_in = c_out + return nn.Sequential(*module_list) + + +class PlainConv(nn.Module): + def __init__( + self, + in_channels=3, + out_dim=256, + pool_feature_map=False, + last_act=True, # True for ConvBody, False for CNN + ): + super().__init__() + # assume input image size is 64x64 + + self.out_dim = out_dim + self.cnn = nn.Sequential( + nn.Conv2d(in_channels, 16, 3, padding=1, bias=True), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), # [32, 32] + nn.Conv2d(16, 32, 3, padding=1, bias=True), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), # [16, 16] + nn.Conv2d(32, 64, 3, padding=1, bias=True), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), # [8, 8] + nn.Conv2d(64, 128, 3, padding=1, bias=True), + nn.ReLU(inplace=True), + nn.MaxPool2d(2, 2), # [4, 4] + nn.Conv2d(128, 128, 1, padding=0, bias=True), + nn.ReLU(inplace=True), + ) + + if pool_feature_map: + self.pool = nn.AdaptiveMaxPool2d((1, 1)) + self.fc = make_mlp(128, [out_dim], last_act=last_act) + else: + self.pool = None + self.fc = make_mlp(128 * 4 * 4 * 4, [out_dim], last_act=last_act) + + self.reset_parameters() + + def reset_parameters(self): + for name, module in self.named_modules(): + if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d)): + if module.bias is not None: + nn.init.zeros_(module.bias) + + def forward(self, image): + x = self.cnn(image) + if self.pool is not None: + x = self.pool(x) + x = x.flatten(1) + x = self.fc(x) + return x diff --git a/examples/baselines/diffusion_policy/diffusion_policy/utils.py b/examples/baselines/diffusion_policy/diffusion_policy/utils.py index af3876fca..e3398ccee 100644 --- a/examples/baselines/diffusion_policy/diffusion_policy/utils.py +++ b/examples/baselines/diffusion_policy/diffusion_policy/utils.py @@ -1,7 +1,11 @@ -from torch.utils.data.sampler import Sampler import numpy as np import torch -from h5py import File, Group, Dataset +import torch.nn as nn +import torch.nn.functional as F +from gymnasium import spaces +from h5py import Dataset, File, Group +from torch.utils.data.sampler import Sampler + class IterationBasedBatchSampler(Sampler): """Wraps a BatchSampler. 
@@ -45,15 +49,18 @@ def worker_init_fn(worker_id, base_seed=None): # print(worker_id, base_seed) np.random.seed(base_seed + worker_id) + TARGET_KEY_TO_SOURCE_KEY = { - 'states': 'env_states', - 'observations': 'obs', - 'success': 'success', - 'next_observations': 'obs', + "states": "env_states", + "observations": "obs", + "success": "success", + "next_observations": "obs", # 'dones': 'dones', # 'rewards': 'rewards', - 'actions': 'actions', + "actions": "actions", } + + def load_content_from_h5_file(file): if isinstance(file, (File, Group)): return {key: load_content_from_h5_file(file[key]) for key in list(file.keys())} @@ -62,34 +69,40 @@ def load_content_from_h5_file(file): else: raise NotImplementedError(f"Unspported h5 file type: {type(file)}") -def load_hdf5(path, ): - print('Loading HDF5 file', path) - file = File(path, 'r') + +def load_hdf5( + path, +): + print("Loading HDF5 file", path) + file = File(path, "r") ret = load_content_from_h5_file(file) file.close() - print('Loaded') + print("Loaded") return ret + def load_traj_hdf5(path, num_traj=None): - print('Loading HDF5 file', path) - file = File(path, 'r') + print("Loading HDF5 file", path) + file = File(path, "r") keys = list(file.keys()) if num_traj is not None: assert num_traj <= len(keys), f"num_traj: {num_traj} > len(keys): {len(keys)}" - keys = sorted(keys, key=lambda x: int(x.split('_')[-1])) + keys = sorted(keys, key=lambda x: int(x.split("_")[-1])) keys = keys[:num_traj] - ret = { - key: load_content_from_h5_file(file[key]) for key in keys - } + ret = {key: load_content_from_h5_file(file[key]) for key in keys} file.close() - print('Loaded') + print("Loaded") return ret -def load_demo_dataset(path, keys=['observations', 'actions'], num_traj=None, concat=True): + + +def load_demo_dataset( + path, keys=["observations", "actions"], num_traj=None, concat=True +): # assert num_traj is None raw_data = load_traj_hdf5(path, num_traj) # raw_data has keys like: ['traj_0', 'traj_1', ...] # raw_data['traj_0'] has keys like: ['actions', 'dones', 'env_states', 'infos', ...] 
- _traj = raw_data['traj_0'] + _traj = raw_data["traj_0"] for key in keys: source_key = TARGET_KEY_TO_SOURCE_KEY[key] assert source_key in _traj, f"key: {source_key} not in traj_0: {_traj.keys()}" @@ -98,22 +111,99 @@ def load_demo_dataset(path, keys=['observations', 'actions'], num_traj=None, con # if 'next' in target_key: # raise NotImplementedError('Please carefully deal with the length of trajectory') source_key = TARGET_KEY_TO_SOURCE_KEY[target_key] - dataset[target_key] = [ raw_data[idx][source_key] for idx in raw_data ] + dataset[target_key] = [raw_data[idx][source_key] for idx in raw_data] if isinstance(dataset[target_key][0], np.ndarray) and concat: - if target_key in ['observations', 'states'] and \ - len(dataset[target_key][0]) > len(raw_data['traj_0']['actions']): - dataset[target_key] = np.concatenate([ - t[:-1] for t in dataset[target_key] - ], axis=0) - elif target_key in ['next_observations', 'next_states'] and \ - len(dataset[target_key][0]) > len(raw_data['traj_0']['actions']): - dataset[target_key] = np.concatenate([ - t[1:] for t in dataset[target_key] - ], axis=0) + if target_key in ["observations", "states"] and len( + dataset[target_key][0] + ) > len(raw_data["traj_0"]["actions"]): + dataset[target_key] = np.concatenate( + [t[:-1] for t in dataset[target_key]], axis=0 + ) + elif target_key in ["next_observations", "next_states"] and len( + dataset[target_key][0] + ) > len(raw_data["traj_0"]["actions"]): + dataset[target_key] = np.concatenate( + [t[1:] for t in dataset[target_key]], axis=0 + ) else: dataset[target_key] = np.concatenate(dataset[target_key], axis=0) - print('Load', target_key, dataset[target_key].shape) + print("Load", target_key, dataset[target_key].shape) else: - print('Load', target_key, len(dataset[target_key]), type(dataset[target_key][0])) + print( + "Load", + target_key, + len(dataset[target_key]), + type(dataset[target_key][0]), + ) return dataset + + +def convert_obs(obs, concat_fn, transpose_fn, state_obs_extractor): + img_dict = obs["sensor_data"] + new_img_dict = { + key: transpose_fn( + concat_fn([v[key] for v in img_dict.values()]) + ) # (C, H, W) or (B, C, H, W) + for key in ["rgb", "depth"] + } + # if isinstance(new_img_dict['depth'], torch.Tensor): # MS2 vec env uses float16, but gym AsyncVecEnv uses float32 + # new_img_dict['depth'] = new_img_dict['depth'].to(torch.float16) + + # Unified version + states_to_stack = state_obs_extractor(obs) + for j in range(len(states_to_stack)): + if states_to_stack[j].dtype == np.float64: + states_to_stack[j] = states_to_stack[j].astype(np.float32) + try: + state = np.hstack(states_to_stack) + except: # dirty fix for concat trajectory of states + state = np.column_stack(states_to_stack) + if state.dtype == np.float64: + for x in states_to_stack: + print(x.shape, x.dtype) + import pdb + + pdb.set_trace() + + out_dict = { + "state": state, + "rgb": new_img_dict["rgb"], + "depth": new_img_dict["depth"], + } + return out_dict + + +def build_obs_space(env, depth_dtype, state_obs_extractor): + # NOTE: We have to use float32 for gym AsyncVecEnv since it does not support float16, but we can use float16 for MS2 vec env + obs_space = env.observation_space + + # Unified version + state_dim = sum([v.shape[0] for v in state_obs_extractor(obs_space)]) + + single_img_space = next(iter(env.observation_space["image"].values())) + h, w, _ = single_img_space["rgb"].shape + n_images = len(env.observation_space["image"]) + + return spaces.Dict( + { + "state": spaces.Box( + -float("inf"), float("inf"), shape=(state_dim,), 
dtype=np.float32 + ), + "rgb": spaces.Box(0, 255, shape=(n_images * 3, h, w), dtype=np.uint8), + "depth": spaces.Box( + -float("inf"), float("inf"), shape=(n_images, h, w), dtype=depth_dtype + ), + } + ) + + +def build_state_obs_extractor(env_id): + env_name = env_id.split("-")[0] + if env_name in ["TurnFaucet", "StackCube"]: + return lambda obs: list(obs["extra"].values()) + elif env_name == "PushChair" or env_name == "PickCube": + return lambda obs: list(obs["agent"].values()) + list(obs["extra"].values()) + else: + raise NotImplementedError(f"Please tune state obs by yourself") + diff --git a/examples/baselines/diffusion_policy/examples.sh b/examples/baselines/diffusion_policy/examples.sh index d59255b05..5739b2851 100644 --- a/examples/baselines/diffusion_policy/examples.sh +++ b/examples/baselines/diffusion_policy/examples.sh @@ -44,4 +44,16 @@ python -m mani_skill.trajectory.replay_trajectory \ python train.py --env-id PegInsertionSide-v1 \ --demo-path ~/.maniskill/demos/PegInsertionSide-v1/motionplanning/trajectory.state.pd_ee_delta_pose.cpu.h5 \ --control-mode "pd_ee_delta_pose" --sim-backend "cpu" --num-demos 100 --max_episode_steps 300 \ - --total_iters 300000 \ No newline at end of file + --total_iters 300000 + +# DrawTriangle-v1 +python -m mani_skill.trajectory.replay_trajectory \ + --traj-path ~/.maniskill/demos/DrawTriangle-v1/motionplanning/trajectory.h5 \ + --use-first-env-state -c pd_ee_delta_pose -o rgbd \ + --save-traj --num-procs 10 -b cpu + +# RGBD-based observations +python train_rgbd.py --env-id DrawTriangle-v1 \ + --demo-path ~/.maniskill/demos/DrawTriangle-v1/motionplanning/trajectory.rgbd.pd_ee_delta_pose.cpu.h5 \ + --control-mode "pd_ee_delta_pose" --sim-backend "cpu" --num-demos 100 --max_episode_steps 300 \ + --total-iters 300000 \ No newline at end of file diff --git a/examples/baselines/diffusion_policy/train_rgbd.py b/examples/baselines/diffusion_policy/train_rgbd.py new file mode 100644 index 000000000..ef6e317f6 --- /dev/null +++ b/examples/baselines/diffusion_policy/train_rgbd.py @@ -0,0 +1,580 @@ +ALGO_NAME = "BC_Diffusion_rgbd_UNet" + +import argparse +import os +import random +import time +from collections import defaultdict +from dataclasses import dataclass, field +from distutils.util import strtobool +from functools import partial +from typing import List, Optional + +import gymnasium as gym +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import tyro +from diffusers.optimization import get_scheduler +from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from diffusers.training_utils import EMAModel +from gymnasium import spaces +from mani_skill.utils import gym_utils +from mani_skill.utils.registration import REGISTERED_ENVS +from mani_skill.utils.wrappers.flatten import FlattenRGBDObservationWrapper +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.dataset import Dataset +from torch.utils.data.sampler import BatchSampler, RandomSampler +from torch.utils.tensorboard import SummaryWriter + +from diffusion_policy.conditional_unet1d import ConditionalUnet1D +from diffusion_policy.evaluate import evaluate +from diffusion_policy.make_env import make_eval_envs +from diffusion_policy.plain_conv import PlainConv +from diffusion_policy.utils import (IterationBasedBatchSampler, + build_state_obs_extractor, convert_obs, + worker_init_fn) + + +@dataclass +class Args: + exp_name: Optional[str] = None + """the name of this experiment""" + seed: int = 1 + 
"""seed of the experiment""" + torch_deterministic: bool = True + """if toggled, `torch.backends.cudnn.deterministic=False`""" + cuda: bool = True + """if toggled, cuda will be enabled by default""" + track: bool = False + """if toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "ManiSkill" + """the wandb's project name""" + wandb_entity: Optional[str] = None + """the entity (team) of wandb's project""" + capture_video: bool = True + """whether to capture videos of the agent performances (check out `videos` folder)""" + + env_id: str = "PegInsertionSide-v0" + """the id of the environment""" + demo_path: str = ( + "data/ms2_official_demos/rigid_body/PegInsertionSide-v0/trajectory.state.pd_ee_delta_pose.h5" + ) + """the path of demo dataset (pkl or h5)""" + num_demos: Optional[int] = None + """number of trajectories to load from the demo dataset""" + total_iters: int = 1_000_000 + """total timesteps of the experiment""" + batch_size: int = 256 + """the batch size of sample from the replay memory""" + + # Diffusion Policy specific arguments + lr: float = 1e-4 + """the learning rate of the diffusion policy""" + obs_horizon: int = 2 # Seems not very important in ManiSkill, 1, 2, 4 work well + act_horizon: int = 8 # Seems not very important in ManiSkill, 4, 8, 15 work well + pred_horizon: int = ( + 16 # 16->8 leads to worse performance, maybe it is like generate a half image; 16->32, improvement is very marginal + ) + diffusion_step_embed_dim: int = 64 # not very important + unet_dims: List[int] = field( + default_factory=lambda: [64, 128, 256] + ) # default setting is about ~4.5M params + n_groups: int = ( + 8 # jigu says it is better to let each group have at least 8 channels; it seems 4 and 8 are simila + ) + depth: bool = True + """use depth to train""" + + # Environment/experiment specific arguments + max_episode_steps: Optional[int] = None + """Change the environments' max_episode_steps to this value. Sometimes necessary if the demonstrations being imitated are too short. Typically the default + max episode steps of environments in ManiSkill are tuned lower so reinforcement learning agents can learn faster.""" + log_freq: int = 1000 + """the frequency of logging the training metrics""" + eval_freq: int = 5000 + """the frequency of evaluating the agent on the evaluation environments""" + save_freq: Optional[int] = None + """the frequency of saving the model checkpoints. By default this is None and will only save checkpoints based on the best evaluation metrics.""" + num_eval_episodes: int = 100 + """the number of episodes to evaluate the agent on""" + num_eval_envs: int = 10 + """the number of parallel environments to evaluate the agent on""" + sim_backend: str = "cpu" + """the simulation backend to use for evaluation environments. can be "cpu" or "gpu""" + num_dataload_workers: int = 0 + """the number of workers to use for loading the training data in the torch dataloader""" + control_mode: str = "pd_joint_delta_pos" + """the control mode to use for the evaluation environments. 
Must match the control mode of the demonstration dataset.""" + + # additional tags/configs for logging purposes to wandb and shared comparisons with other algorithms + demo_type: Optional[str] = None + + +def reorder_keys(d, ref_dict): + out = dict() + for k, v in ref_dict.items(): + if isinstance(v, dict) or isinstance(v, spaces.Dict): + out[k] = reorder_keys(d[k], ref_dict[k]) + else: + out[k] = d[k] + return out + + +class SmallDemoDataset_DiffusionPolicy(Dataset): # Load everything into memory + def __init__(self, data_path, obs_process_fn, obs_space, num_traj): + if data_path[-4:] == ".pkl": + raise NotImplementedError() + else: + from diffusion_policy.utils import load_demo_dataset + + trajectories = load_demo_dataset(data_path, num_traj=num_traj, concat=False) + # trajectories['observations'] is a list of dict, each dict is a traj, with keys in obs_space, values with length L+1 + # trajectories['actions'] is a list of np.ndarray (L, act_dim) + + print("Raw trajectory loaded, start to pre-process the observations...") + + # Pre-process the observations, make them align with the obs returned by the obs_wrapper + obs_traj_dict_list = [] + for obs_traj_dict in trajectories["observations"]: + _obs_traj_dict = reorder_keys( + obs_traj_dict, obs_space + ) # key order in demo is different from key order in env obs + _obs_traj_dict = obs_process_fn(_obs_traj_dict) + _obs_traj_dict["depth"] = torch.Tensor( + _obs_traj_dict["depth"].astype(np.float32) / 1024 + ).to(torch.float16) + _obs_traj_dict["rgb"] = torch.from_numpy( + _obs_traj_dict["rgb"] + ) # still uint8 + _obs_traj_dict["state"] = torch.from_numpy(_obs_traj_dict["state"]) + obs_traj_dict_list.append(_obs_traj_dict) + trajectories["observations"] = obs_traj_dict_list + self.obs_keys = list(_obs_traj_dict.keys()) + # Pre-process the actions + for i in range(len(trajectories["actions"])): + trajectories["actions"][i] = torch.Tensor(trajectories["actions"][i]) + print( + "Obs/action pre-processing is done, start to pre-compute the slice indices..." 
+ ) + + # Pre-compute all possible (traj_idx, start, end) tuples, this is very specific to Diffusion Policy + if ( + "delta_pos" in args.control_mode + or args.control_mode == "base_pd_joint_vel_arm_pd_joint_vel" + ): + self.pad_action_arm = torch.zeros( + (trajectories["actions"][0].shape[1] - 1,) + ) + # to make the arm stay still, we pad the action with 0 in 'delta_pos' control mode + # gripper action needs to be copied from the last action + else: + raise NotImplementedError(f"Control Mode {args.control_mode} not supported") + self.obs_horizon, self.pred_horizon = obs_horizon, pred_horizon = ( + args.obs_horizon, + args.pred_horizon, + ) + self.slices = [] + num_traj = len(trajectories["actions"]) + total_transitions = 0 + for traj_idx in range(num_traj): + L = trajectories["actions"][traj_idx].shape[0] + assert trajectories["observations"][traj_idx]["state"].shape[0] == L + 1 + total_transitions += L + + # |o|o| observations: 2 + # | |a|a|a|a|a|a|a|a| actions executed: 8 + # |p|p|p|p|p|p|p|p|p|p|p|p|p|p|p|p| actions predicted: 16 + pad_before = obs_horizon - 1 + # Pad before the trajectory, so the first action of an episode is in "actions executed" + # obs_horizon - 1 is the number of "not used actions" + pad_after = pred_horizon - obs_horizon + # Pad after the trajectory, so all the observations are utilized in training + # Note that in the original code, pad_after = act_horizon - 1, but I think this is not the best choice + self.slices += [ + (traj_idx, start, start + pred_horizon) + for start in range(-pad_before, L - pred_horizon + pad_after) + ] # slice indices follow convention [start, end) + + print( + f"Total transitions: {total_transitions}, Total obs sequences: {len(self.slices)}" + ) + + self.trajectories = trajectories + + def __getitem__(self, index): + traj_idx, start, end = self.slices[index] + L, act_dim = self.trajectories["actions"][traj_idx].shape + + obs_traj = self.trajectories["observations"][traj_idx] + obs_seq = {} + for k, v in obs_traj.items(): + obs_seq[k] = v[ + max(0, start) : start + self.obs_horizon + ] # start+self.obs_horizon is at least 1 + if start < 0: # pad before the trajectory + pad_obs_seq = torch.stack([obs_seq[k][0]] * abs(start), dim=0) + obs_seq[k] = torch.cat((pad_obs_seq, obs_seq[k]), dim=0) + # don't need to pad obs after the trajectory, see the above char drawing + + act_seq = self.trajectories["actions"][traj_idx][max(0, start) : end] + if start < 0: # pad before the trajectory + act_seq = torch.cat([act_seq[0].repeat(-start, 1), act_seq], dim=0) + if end > L: # pad after the trajectory + gripper_action = act_seq[-1, -1] # assume gripper is with pos controller + pad_action = torch.cat((self.pad_action_arm, gripper_action[None]), dim=0) + act_seq = torch.cat([act_seq, pad_action.repeat(end - L, 1)], dim=0) + # making the robot (arm and gripper) stay still + assert ( + obs_seq["state"].shape[0] == self.obs_horizon + and act_seq.shape[0] == self.pred_horizon + ) + return { + "observations": obs_seq, + "actions": act_seq, + } + + def __len__(self): + return len(self.slices) + + +class Agent(nn.Module): + def __init__(self, env, args): + super().__init__() + self.obs_horizon = args.obs_horizon + self.act_horizon = args.act_horizon + self.pred_horizon = args.pred_horizon + assert ( + len(env.single_observation_space["state"].shape) == 2 + ) # (obs_horizon, obs_dim) + assert len(env.single_action_space.shape) == 1 # (act_dim, ) + assert (env.single_action_space.high == 1).all() and ( + env.single_action_space.low == -1 + ).all() + # denoising 
results will be clipped to [-1,1], so the action should be in [-1,1] as well + self.act_dim = env.single_action_space.shape[0] + obs_state_dim = env.single_observation_space["state"].shape[1] + _, H, W, C = envs.single_observation_space["rgb"].shape + + visual_feature_dim = 256 + in_c = int(C / 3 * 4) if args.depth else C + self.visual_encoder = PlainConv( + in_channels=in_c, out_dim=visual_feature_dim, pool_feature_map=True + ) + self.noise_pred_net = ConditionalUnet1D( + input_dim=self.act_dim, # act_horizon is not used (U-Net doesn't care) + global_cond_dim=self.obs_horizon * (visual_feature_dim + obs_state_dim), + diffusion_step_embed_dim=args.diffusion_step_embed_dim, + down_dims=args.unet_dims, + n_groups=args.n_groups, + ) + self.num_diffusion_iters = 100 + self.noise_scheduler = DDPMScheduler( + num_train_timesteps=self.num_diffusion_iters, + beta_schedule="squaredcos_cap_v2", # has big impact on performance, try not to change + clip_sample=True, # clip output to [-1,1] to improve stability + prediction_type="epsilon", # predict noise (instead of denoised action) + ) + + def encode_obs(self, obs_seq, eval_mode): + rgb = obs_seq["rgb"].float() / 255.0 # (B, obs_horizon, 3*k, H, W) + if args.depth: + depth = obs_seq["depth"].float() # (B, obs_horizon, 1*k, H, W) + img_seq = torch.cat([rgb, depth], dim=2) # (B, obs_horizon, C, H, W), C=4*k + else: + img_seq = rgb + img_seq = img_seq.flatten(end_dim=1) # (B*obs_horizon, C, H, W) + if hasattr(self, "aug") and not eval_mode: + img_seq = self.aug(img_seq) # (B*obs_horizon, C, H, W) + visual_feature = self.visual_encoder(img_seq) # (B*obs_horizon, D) + visual_feature = visual_feature.reshape( + rgb.shape[0], self.obs_horizon, visual_feature.shape[1] + ) # (B, obs_horizon, D) + feature = torch.cat( + (visual_feature, obs_seq["state"]), dim=-1 + ) # (B, obs_horizon, D+obs_state_dim) + return feature.flatten(start_dim=1) # (B, obs_horizon * (D+obs_state_dim)) + + def compute_loss(self, obs_seq, action_seq): + B = obs_seq["state"].shape[0] + + # observation as FiLM conditioning + obs_cond = self.encode_obs( + obs_seq, eval_mode=False + ) # (B, obs_horizon * obs_dim) + + # sample noise to add to actions + noise = torch.randn((B, self.pred_horizon, self.act_dim), device=device) + + # sample a diffusion iteration for each data point + timesteps = torch.randint( + 0, self.noise_scheduler.config.num_train_timesteps, (B,), device=device + ).long() + + # add noise to the clean images(actions) according to the noise magnitude at each diffusion iteration + # (this is the forward diffusion process) + noisy_action_seq = self.noise_scheduler.add_noise(action_seq, noise, timesteps) + + # predict the noise residual + noise_pred = self.noise_pred_net( + noisy_action_seq, timesteps, global_cond=obs_cond + ) + + return F.mse_loss(noise_pred, noise) + + def get_action(self, obs_seq): + # init scheduler + # self.noise_scheduler.set_timesteps(self.num_diffusion_iters) + # set_timesteps will change noise_scheduler.timesteps is only used in noise_scheduler.step() + # noise_scheduler.step() is only called during inference + # if we use DDPM, and inference_diffusion_steps == train_diffusion_steps, then we can skip this + + # obs_seq['state']: (B, obs_horizon, obs_state_dim) + B = obs_seq["state"].shape[0] + with torch.no_grad(): + obs_seq["rgb"] = obs_seq["rgb"].permute(0, 1, 4, 2, 3) + if args.depth: + obs_seq["depth"] = obs_seq["depth"].permute(0, 1, 4, 2, 3) / 1024 + + obs_cond = self.encode_obs( + obs_seq, eval_mode=True + ) # (B, obs_horizon * obs_dim) + + # 
initialize action from Guassian noise + noisy_action_seq = torch.randn( + (B, self.pred_horizon, self.act_dim), device=obs_seq["state"].device + ) + + for k in self.noise_scheduler.timesteps: + # predict noise + noise_pred = self.noise_pred_net( + sample=noisy_action_seq, + timestep=k, + global_cond=obs_cond, + ) + + # inverse diffusion step (remove noise) + noisy_action_seq = self.noise_scheduler.step( + model_output=noise_pred, + timestep=k, + sample=noisy_action_seq, + ).prev_sample + + # only take act_horizon number of actions + start = self.obs_horizon - 1 + end = start + self.act_horizon + return noisy_action_seq[:, start:end] # (B, act_horizon, act_dim) + + +if __name__ == "__main__": + args = tyro.cli(Args) + + if args.exp_name is None: + args.exp_name = os.path.basename(__file__)[: -len(".py")] + run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" + else: + run_name = args.exp_name + + if args.demo_path.endswith(".h5"): + import json + + json_file = args.demo_path[:-2] + "json" + with open(json_file, "r") as f: + demo_info = json.load(f) + if "control_mode" in demo_info["env_info"]["env_kwargs"]: + control_mode = demo_info["env_info"]["env_kwargs"]["control_mode"] + elif "control_mode" in demo_info["episodes"][0]: + control_mode = demo_info["episodes"][0]["control_mode"] + else: + raise Exception("Control mode not found in json") + assert ( + control_mode == args.control_mode + ), f"Control mode mismatched. Dataset has control mode {control_mode}, but args has control mode {args.control_mode}" + assert args.obs_horizon + args.act_horizon - 1 <= args.pred_horizon + assert args.obs_horizon >= 1 and args.act_horizon >= 1 and args.pred_horizon >= 1 + + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.backends.cudnn.deterministic = args.torch_deterministic + + device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu") + + env_kwargs = dict( + control_mode=args.control_mode, + reward_mode="sparse", + obs_mode="rgbd", + render_mode="rgb_array", + ) + if args.max_episode_steps is not None: + env_kwargs["max_episode_steps"] = args.max_episode_steps + other_kwargs = dict(obs_horizon=args.obs_horizon) + envs = make_eval_envs( + args.env_id, + args.num_eval_envs, + args.sim_backend, + env_kwargs, + other_kwargs, + video_dir=f"runs/{run_name}/videos" if args.capture_video else None, + wrappers=[partial(FlattenRGBDObservationWrapper, sep_depth=True)], + ) + if args.track: + import wandb + + config = vars(args) + config["eval_env_cfg"] = dict( + **env_kwargs, + num_envs=args.num_eval_envs, + env_id=args.env_id, + env_horizon=gym_utils.find_max_episode_steps_value(envs), + ) + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=config, + name=run_name, + save_code=True, + group="DiffusionPolicy", + tags=["diffusion_policy"], + ) + writer = SummaryWriter(f"runs/{run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" + % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + obs_process_fn = partial( + convert_obs, + concat_fn=partial(np.concatenate, axis=-1), + transpose_fn=partial( + np.transpose, axes=(0, 3, 1, 2) + ), # (B, H, W, C) -> (B, C, H, W) + state_obs_extractor=build_state_obs_extractor(args.env_id), + ) + tmp_env = gym.make(args.env_id, obs_mode="rgbd") + orignal_obs_space = tmp_env.observation_space + tmp_env.close() + dataset = SmallDemoDataset_DiffusionPolicy( + args.demo_path, 
obs_process_fn, orignal_obs_space, args.num_demos + ) + sampler = RandomSampler(dataset, replacement=False) + batch_sampler = BatchSampler(sampler, batch_size=args.batch_size, drop_last=True) + batch_sampler = IterationBasedBatchSampler(batch_sampler, args.total_iters) + train_dataloader = DataLoader( + dataset, + batch_sampler=batch_sampler, + num_workers=args.num_dataload_workers, + worker_init_fn=lambda worker_id: worker_init_fn(worker_id, base_seed=args.seed), + pin_memory=True, + persistent_workers=(args.num_dataload_workers > 0), + ) + sampler = RandomSampler(dataset, replacement=False) + batch_sampler = BatchSampler(sampler, batch_size=args.batch_size, drop_last=True) + batch_sampler = IterationBasedBatchSampler(batch_sampler, args.total_iters) + train_dataloader = DataLoader( + dataset, + batch_sampler=batch_sampler, + num_workers=args.num_dataload_workers, + worker_init_fn=lambda worker_id: worker_init_fn(worker_id, base_seed=args.seed), + pin_memory=True, + persistent_workers=(args.num_dataload_workers > 0), + ) + + agent = Agent(envs, args).to(device) + + optimizer = optim.AdamW( + params=agent.parameters(), lr=args.lr, betas=(0.95, 0.999), weight_decay=1e-6 + ) + + # Cosine LR schedule with linear warmup + lr_scheduler = get_scheduler( + name="cosine", + optimizer=optimizer, + num_warmup_steps=500, + num_training_steps=args.total_iters, + ) + + # Exponential Moving Average + # accelerates training and improves stability + # holds a copy of the model weights + ema = EMAModel(parameters=agent.parameters(), power=0.75) + ema_agent = Agent(envs, args).to(device) + + # ---------------------------------------------------------------------------- # + # Training begins. + # ---------------------------------------------------------------------------- # + agent.train() + + best_eval_metrics = defaultdict(float) + timings = defaultdict(float) + + for iteration, data_batch in enumerate(train_dataloader): + # # copy data from cpu to gpu + # data_batch = {k: v.cuda(non_blocking=True) for k, v in data_batch.items()} + + # forward and compute loss + obs_batch_dict = data_batch["observations"] + obs_batch_dict = { + k: v.cuda(non_blocking=True) for k, v in obs_batch_dict.items() + } + act_batch = data_batch["actions"].cuda(non_blocking=True) + + # forward and compute loss + total_loss = agent.compute_loss( + obs_seq=obs_batch_dict, # obs_batch_dict['state'] is (B, L, obs_dim) + action_seq=act_batch, # (B, L, act_dim) + ) + + # backward + optimizer.zero_grad() + total_loss.backward() + optimizer.step() + lr_scheduler.step() # step lr scheduler every batch, this is different from standard pytorch behavior + last_tick = time.time() + + ema.step(agent.parameters()) + + if iteration % args.log_freq == 0: + print(f"Iteration {iteration}, loss: {total_loss.item()}") + writer.add_scalar( + "charts/learning_rate", optimizer.param_groups[0]["lr"], iteration + ) + writer.add_scalar("losses/total_loss", total_loss.item(), iteration) + for k, v in timings.items(): + writer.add_scalar(f"time/{k}", v, iteration) + # Evaluation + if iteration % args.eval_freq == 0: + last_tick = time.time() + + ema.copy_to(ema_agent.parameters()) + # def sample_fn(obs): + + eval_metrics = evaluate( + args.num_eval_episodes, ema_agent, envs, device, args.sim_backend + ) + timings["eval"] += time.time() - last_tick + + print(f"Evaluated {len(eval_metrics['success_at_end'])} episodes") + for k in eval_metrics.keys(): + eval_metrics[k] = np.mean(eval_metrics[k]) + writer.add_scalar(f"eval/{k}", eval_metrics[k], iteration) + 
print(f"{k}: {eval_metrics[k]:.4f}") + + save_on_best_metrics = ["success_once", "success_at_end"] + for k in save_on_best_metrics: + if k in eval_metrics and eval_metrics[k] > best_eval_metrics[k]: + best_eval_metrics[k] = eval_metrics[k] + save_ckpt(run_name, f"best_eval_{k}") + print( + f"New best {k}_rate: {eval_metrics[k]:.4f}. Saving checkpoint." + ) + + # Checkpoint + if args.save_freq is not None and iteration % args.save_freq == 0: + save_ckpt(run_name, str(iteration)) + envs.close() + writer.close() diff --git a/figures/environment_demos/DrawTriangle-v1_rt.mp4 b/figures/environment_demos/DrawTriangle-v1_rt.mp4 new file mode 100644 index 000000000..40d1fa784 Binary files /dev/null and b/figures/environment_demos/DrawTriangle-v1_rt.mp4 differ diff --git a/mani_skill/envs/tasks/drawing/__init__.py b/mani_skill/envs/tasks/drawing/__init__.py index c6df1f738..f90b9a236 100644 --- a/mani_skill/envs/tasks/drawing/__init__.py +++ b/mani_skill/envs/tasks/drawing/__init__.py @@ -1 +1,2 @@ from .draw import * +from .draw_triangle import * \ No newline at end of file diff --git a/mani_skill/envs/tasks/drawing/draw_triangle.py b/mani_skill/envs/tasks/drawing/draw_triangle.py new file mode 100644 index 000000000..9bd1c10f0 --- /dev/null +++ b/mani_skill/envs/tasks/drawing/draw_triangle.py @@ -0,0 +1,335 @@ +import math +import random +from typing import Dict + +import mani_skill.envs.utils.randomization as randomization +import numpy as np +import sapien +import torch +from mani_skill.agents.robots.panda.panda_stick import PandaStick +from mani_skill.envs.sapien_env import BaseEnv +from mani_skill.sensors.camera import CameraConfig +from mani_skill.utils import sapien_utils +from mani_skill.utils.geometry.rotation_conversions import quaternion_to_matrix +from mani_skill.utils.registration import register_env +from mani_skill.utils.scene_builder.table.scene_builder import \ + TableSceneBuilder +from mani_skill.utils.structs.actor import Actor +from mani_skill.utils.structs.pose import Pose +from mani_skill.utils.structs.types import SceneConfig, SimConfig +from transforms3d.euler import euler2quat, quat2euler +from transforms3d.quaternions import quat2mat + + +@register_env("DrawTriangle-v1", max_episode_steps=300) +class DrawTriangle(BaseEnv): + MAX_DOTS = 300 + """ + The total "ink" available to use and draw with before you need to call env.reset. NOTE that on GPU simulation it is not recommended to have a very high value for this as it can slow down rendering + when too many objects are being rendered in many scenes. + """ + DOT_THICKNESS = 0.003 + """thickness of the paint drawn on to the canvas""" + CANVAS_THICKNESS = 0.02 + """How thick the canvas on the table is""" + BRUSH_RADIUS = 0.01 + """The brushes radius""" + BRUSH_COLORS = [[0.8, 0.2, 0.2, 1]] + """The colors of the brushes. If there is more than one color, each parallel environment will have a randomly sampled color.""" + THRESHOLD = 0.05 + + SUPPORTED_REWARD_MODES = ["sparse"] + + SUPPORTED_ROBOTS: ["panda_stick"] # type: ignore + agent: PandaStick + + def __init__(self, *args, robot_uids="panda_stick", **kwargs): + super().__init__(*args, robot_uids=robot_uids, **kwargs) + + @property + def _default_sim_config(self): + # we set contact_offset to a small value as we are not expecting to make any contacts really apart from the brush hitting the canvas too hard. 
+ # We set solver iterations very low as this environment is not doing a ton of manipulation (the brush is attached to the robot after all) + return SimConfig( + sim_freq=100, + control_freq=20, + scene_config=SceneConfig( + contact_offset=0.01, + solver_position_iterations=4, + solver_velocity_iterations=0, + ), + ) + + @property + def _default_sensor_configs(self): + pose = sapien_utils.look_at(eye=[0.3, 0, 0.8], target=[0, 0, 0.1]) + return [ + CameraConfig( + "base_camera", + pose=pose, + width=320, + height=240, + fov=1.2, + near=0.01, + far=100, + ) + ] + + @property + def _default_human_render_camera_configs(self): + pose = sapien_utils.look_at(eye=[0.3, 0, 0.8], target=[0, 0, 0.1]) + return CameraConfig( + "render_camera", + pose=pose, + width=1280, + height=960, + fov=1.2, + near=0.01, + far=100, + ) + + def _load_scene(self, options: dict): + + self.table_scene = TableSceneBuilder(self, robot_init_qpos_noise=0) + self.table_scene.build() + + def create_goal_triangle(name="tri", base_color=None): + + box1_half_w = 0.3 / 2 + box1_half_h = 0.01 / 2 + half_thickness = 0.001 / 2 + + radius = (box1_half_w) / math.sqrt(3) + + theta = np.pi / 2 + + # define centers and compute verticies, might need to adjust how centers are calculated or add a theta arg for variation + c1 = np.array([radius * math.cos(theta), radius * math.sin(theta), 0.01]) + c2 = np.array( + [ + radius * math.cos(theta + (2 * np.pi / 3)), + radius * math.sin(theta + (2 * np.pi / 3)), + 0.01, + ] + ) + c3 = np.array( + [ + radius * math.cos((theta + (4 * np.pi / 3))), + radius * math.sin(theta + (4 * np.pi / 3)), + 0.01, + ] + ) + self.original_verts = np.array( + [(c1 + c3) - c2, (c1 + c2) - c3, (c2 + c3) - c1] + ) + + builder = self.scene.create_actor_builder() + first_block_pose = sapien.Pose( + list(c1), euler2quat(0, 0, theta - (np.pi / 2)) + ) + first_block_size = [box1_half_w, box1_half_h, half_thickness] + builder.add_box_visual( + pose=first_block_pose, + half_size=first_block_size, + material=sapien.render.RenderMaterial( + base_color=base_color, + ), + ) + + second_block_pose = sapien.Pose( + list(c2), euler2quat(0, 0, theta - (5 * np.pi / 6)) + ) + second_block_size = [box1_half_w, box1_half_h, half_thickness] + # builder.add_box_collision(pose=second_block_pose, half_size=second_block_size) + builder.add_box_visual( + pose=second_block_pose, + half_size=second_block_size, + material=sapien.render.RenderMaterial( + base_color=base_color, + ), + ) + + third_block_pose = sapien.Pose( + list(c3), euler2quat(0, 0, theta - (np.pi / 6)) + ) + third_block_size = [box1_half_w, box1_half_h, half_thickness] + # builder.add_box_collision(pose=second_block_pose, half_size=second_block_size) + builder.add_box_visual( + pose=third_block_pose, + half_size=third_block_size, + material=sapien.render.RenderMaterial( + base_color=base_color, + ), + ) + return builder.build_kinematic(name=name) + + # build a white canvas on the table + self.canvas = self.scene.create_actor_builder() + self.canvas.add_box_visual( + half_size=[0.4, 0.6, self.CANVAS_THICKNESS / 2], + material=sapien.render.RenderMaterial(base_color=[1, 1, 1, 1]), + ) + self.canvas.add_box_collision( + half_size=[0.4, 0.6, self.CANVAS_THICKNESS / 2], + ) + self.canvas.initial_pose = sapien.Pose(p=[-0.1, 0, self.CANVAS_THICKNESS / 2]) + self.canvas = self.canvas.build_static(name="canvas") + + self.dots = [] + self.dot_pos = None + color_choices = torch.randint(0, len(self.BRUSH_COLORS), (self.num_envs,)) + for i in range(self.MAX_DOTS): + actors = [] + if 
len(self.BRUSH_COLORS) > 1: + for env_idx in range(self.num_envs): + builder = self.scene.create_actor_builder() + builder.add_cylinder_visual( + radius=self.BRUSH_RADIUS, + half_length=self.DOT_THICKNESS / 2, + material=sapien.render.RenderMaterial( + base_color=self.BRUSH_COLORS[color_choices[env_idx]] + ), + ) + builder.set_scene_idxs([env_idx]) + actor = builder.build_kinematic(name=f"dot_{i}_{env_idx}") + actors.append(actor) + self.dots.append(Actor.merge(actors)) + else: + builder = self.scene.create_actor_builder() + builder.add_cylinder_visual( + radius=self.BRUSH_RADIUS, + half_length=self.DOT_THICKNESS / 2, + material=sapien.render.RenderMaterial( + base_color=self.BRUSH_COLORS[0] + ), + ) + actor = builder.build_kinematic(name=f"dot_{i}") + self.dots.append(actor) + self.goal_tri = create_goal_triangle( + name="goal_tri", + base_color=np.array([10, 10, 10,255]) / 255, + ) + + def _initialize_episode(self, env_idx: torch.Tensor, options: dict): + self.draw_step = 0 + with torch.device(self.device): + b = len(env_idx) + self.table_scene.initialize(env_idx) + target_pos = torch.zeros((b, 3)) + + target_pos[:, :2] = torch.rand((b, 2)) * 0.02 - 0.1 + target_pos[:, -1] = 0.01 + qs = randomization.random_quaternions(b, lock_x=True, lock_y=True) + mats = quaternion_to_matrix(qs).to(self.device) + self.goal_tri.set_pose(Pose.create_from_pq(p=target_pos, q=qs)) + + self.vertices = torch.from_numpy( + np.tile(self.original_verts, (b, 1, 1)) + ).to( + self.device + ) # b, 3, 3 + self.vertices = ( + mats.double() @ self.vertices.transpose(-1, -2).double() + ).transpose( + -1, -2 + ) # apply rotation matrix + self.vertices += target_pos.unsqueeze(1) + + self.triangles = self.generate_triangle_with_points( + 100, self.vertices[:, :, :-1] + ) + + for dot in self.dots: + # initially spawn dots in the table so they aren't seen + dot.set_pose( + sapien.Pose( + p=[0, 0, -self.DOT_THICKNESS], q=euler2quat(0, np.pi / 2, 0) + ) + ) + + def _after_control_step(self): + if self.gpu_sim_enabled: + self.scene._gpu_fetch_all() + + # This is the actual, GPU parallelized, drawing code. + # This is not real drawing but seeks to mimic drawing by placing dots on the canvas whenever the robot is close enough to the canvas surface + # We do not actually check if the robot contacts the table (although that is possible) and instead use a fast method to check. + # We add a 0.005 meter of leeway to make it easier for the robot to get close to the canvas and start drawing instead of having to be super close to the table. + robot_touching_table = ( + self.agent.tcp.pose.p[:, 2] + < self.CANVAS_THICKNESS + self.DOT_THICKNESS + 0.005 + ) + robot_brush_pos = torch.zeros((self.num_envs, 3), device=self.device) + robot_brush_pos[:, 2] = -self.DOT_THICKNESS + robot_brush_pos[robot_touching_table, :2] = self.agent.tcp.pose.p[ + robot_touching_table, :2 + ] + robot_brush_pos[robot_touching_table, 2] = ( + self.DOT_THICKNESS / 2 + self.CANVAS_THICKNESS + ) + # move the next unused dot to the robot's brush position. 
All unused dots are initialized inside the table so they aren't visible + new_dot_pos = Pose.create_from_pq(robot_brush_pos, euler2quat(0, np.pi / 2, 0)) + self.dots[self.draw_step].set_pose(new_dot_pos) + if new_dot_pos.get_p()[:, -1] > 0: + if self.dot_pos == None: + self.dot_pos = new_dot_pos.get_p()[:, None, :] + self.dot_pos = torch.cat( + (self.dot_pos, new_dot_pos.get_p()[:, None, :]), dim=1 + ) + + self.draw_step += 1 + + # on GPU sim we have to call _gpu_apply_all() to apply the changes we make to object poses. + if self.gpu_sim_enabled: + self.scene._gpu_apply_all() + + def evaluate(self): + out = self.success_check() + return {"success": out} + + def _get_obs_extra(self, info: Dict): + obs = dict( + tcp_pose=self.agent.tcp.pose.raw_pose, + ) + + if "state" in self.obs_mode: + obs.update( + goal_pose = self.goal_tri.pose.raw_pose, + tcp_to_verts_pos = self.vertices - self.agent.tcp.pose.p.unsqueeze(1), + goal_pos=self.goal_tri.pose.p, + vertices = self.vertices + ) + + return obs + + def generate_triangle_with_points(self, n, vertices): + batch_size = vertices.shape[0] + + all_points = [] + + for i in range(vertices.shape[1]): + start_vertex = vertices[:, i, :] + end_vertex = vertices[:, (i + 1) % vertices.shape[1], :] + t = torch.linspace(0, 1, n + 2, device=vertices.device)[:-1] + t = t.view(1, -1, 1).repeat(batch_size, 1, 2) + intermediate_points = ( + start_vertex.unsqueeze(1) * (1 - t) + end_vertex.unsqueeze(1) * t + ) + all_points.append(intermediate_points) + all_points = torch.cat(all_points, dim=1) + + return all_points + + def success_check(self): + if self.dot_pos == None or len(self.dot_pos) == 0: + return torch.Tensor([False]).to(bool) + drawn_pts = self.dot_pos[:, :, :-1] + + distance_matrix = torch.sqrt( + torch.sum( + (drawn_pts[:, :, None, :] - self.triangles[:, None, :, :]) ** 2, axis=-1 + ) + ) + + Y_closeness = torch.min(distance_matrix, dim=1).values < self.THRESHOLD + return torch.Tensor([torch.all(Y_closeness)]).to(bool) diff --git a/mani_skill/examples/.gitignore b/mani_skill/examples/.gitignore index f3e6bd028..d29c740d4 100644 --- a/mani_skill/examples/.gitignore +++ b/mani_skill/examples/.gitignore @@ -1 +1,2 @@ -videos/ \ No newline at end of file +videos/ +demos/ \ No newline at end of file diff --git a/mani_skill/examples/motionplanning/panda_stick/motionplanner.py b/mani_skill/examples/motionplanning/panda_stick/motionplanner.py index 097ad1748..5a423884b 100644 --- a/mani_skill/examples/motionplanning/panda_stick/motionplanner.py +++ b/mani_skill/examples/motionplanning/panda_stick/motionplanner.py @@ -2,12 +2,12 @@ import numpy as np import sapien import trimesh - from mani_skill.agents.base_agent import BaseAgent from mani_skill.envs.sapien_env import BaseEnv from mani_skill.envs.scene import ManiSkillScene from mani_skill.utils.structs.pose import to_sapien_pose + class PandaStickMotionPlanningSolver: def __init__( self, @@ -113,7 +113,7 @@ def move_to_pose_with_screw( ): pose = to_sapien_pose(pose) # try screw two times before giving up - pose = sapien.Pose(p=pose.p , q=pose.q) + pose = sapien.Pose(p=pose.p, q=pose.q) result = self.planner.plan_screw( np.concatenate([pose.p, pose.q]), self.robot.get_qpos().cpu().numpy()[0], diff --git a/mani_skill/examples/motionplanning/panda_stick/run.py b/mani_skill/examples/motionplanning/panda_stick/run.py new file mode 100644 index 000000000..a0373cbfc --- /dev/null +++ b/mani_skill/examples/motionplanning/panda_stick/run.py @@ -0,0 +1,154 @@ +import argparse +import os.path as osp + +import gymnasium as 
gym +import numpy as np +from mani_skill.examples.motionplanning.panda_stick.solutions import \ + solveDrawTriangle +from mani_skill.utils.wrappers.record import RecordEpisode +from tqdm import tqdm + +MP_SOLUTIONS = {"DrawTriangle-v1": solveDrawTriangle} + + +def parse_args(args=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "-e", + "--env-id", + type=str, + default="DrawTriangle-v1", + help=f"Environment to run motion planning solver on. Available options are {list(MP_SOLUTIONS.keys())}", + ) + parser.add_argument( + "-o", + "--obs-mode", + type=str, + default="none", + help="Observation mode to use. Usually this is kept as 'none' as observations are not necesary to be stored, they can be replayed later via the mani_skill.trajectory.replay_trajectory script.", + ) + parser.add_argument( + "-n", + "--num-traj", + type=int, + default=10, + help="Number of trajectories to generate.", + ) + parser.add_argument( + "--only-count-success", + action="store_true", + help="If true, generates trajectories until num_traj of them are successful and only saves the successful trajectories/videos", + ) + parser.add_argument("--reward-mode", type=str) + parser.add_argument( + "-b", + "--sim-backend", + type=str, + default="auto", + help="Which simulation backend to use. Can be 'auto', 'cpu', 'gpu'", + ) + parser.add_argument( + "--render-mode", + type=str, + default="rgb_array", + help="can be 'sensors' or 'rgb_array' which only affect what is saved to videos", + ) + parser.add_argument( + "--vis", + action="store_true", + help="whether or not to open a GUI to visualize the solution live", + ) + parser.add_argument( + "--save-video", + action="store_true", + help="whether or not to save videos locally", + ) + parser.add_argument( + "--traj-name", + type=str, + help="The name of the trajectory .h5 file that will be created.", + ) + parser.add_argument( + "--shader", + default="default", + type=str, + help="Change shader used for rendering. Default is 'default' which is very fast. Can also be 'rt' for ray tracing and generating photo-realistic renders. Can also be 'rt-fast' for a faster but lower quality ray-traced renderer", + ) + parser.add_argument( + "--record-dir", + type=str, + default="demos", + help="where to save the recorded trajectories", + ) + return parser.parse_args() + + +def main(args): + env_id = args.env_id + env = gym.make( + env_id, + obs_mode=args.obs_mode, + control_mode="pd_joint_pos", + render_mode="rgb_array", + ) + if env_id not in MP_SOLUTIONS: + raise RuntimeError( + f"No already written motion planning solutions for {env_id}. 
Available options are {list(MP_SOLUTIONS.keys())}" + ) + env = RecordEpisode( + env, + output_dir=osp.join(args.record_dir, env_id, "motionplanning"), + trajectory_name=args.traj_name, + save_video=args.save_video, + source_type="motionplanning", + source_desc="official motion planning solution from ManiSkill contributors", + video_fps=30, + save_on_reset=False, + ) + solve = MP_SOLUTIONS[env_id] + print(f"Motion Planning Running on {env_id}") + pbar = tqdm(range(args.num_traj)) + seed = 0 + successes = [] + solution_episode_lengths = [] + failed_motion_plans = 0 + passed = 0 + while True: + res = solve(env, seed=seed, debug=False, vis=True if args.vis else False) + if res == -1: + success = False + failed_motion_plans += 1 + else: + success = res[-1]["success"].item() + elapsed_steps = res[-1]["elapsed_steps"].item() + solution_episode_lengths.append(elapsed_steps) + successes.append(success) + if args.only_count_success and not success: + seed += 1 + env.flush_trajectory(save=False) + if args.save_video: + env.flush_video(save=False) + continue + else: + env.flush_trajectory() + if args.save_video: + env.flush_video() + pbar.update(1) + pbar.set_postfix( + dict( + success_rate=np.mean(successes), + failed_motion_plan_rate=failed_motion_plans / (seed + 1), + avg_episode_length=np.mean(solution_episode_lengths), + max_episode_length=np.max(solution_episode_lengths), + # min_episode_length=np.min(solution_episode_lengths) + ) + ) + seed += 1 + passed += 1 + if passed == args.num_traj: + break + env.close() + + +if __name__ == "__main__": + main(parse_args()) diff --git a/mani_skill/examples/motionplanning/panda_stick/solutions/__init__.py b/mani_skill/examples/motionplanning/panda_stick/solutions/__init__.py new file mode 100644 index 000000000..8ad0262c7 --- /dev/null +++ b/mani_skill/examples/motionplanning/panda_stick/solutions/__init__.py @@ -0,0 +1 @@ +from .draw_triangle import solve as solveDrawTriangle diff --git a/mani_skill/examples/motionplanning/panda_stick/solutions/draw_triangle.py b/mani_skill/examples/motionplanning/panda_stick/solutions/draw_triangle.py new file mode 100644 index 000000000..18cbdc594 --- /dev/null +++ b/mani_skill/examples/motionplanning/panda_stick/solutions/draw_triangle.py @@ -0,0 +1,53 @@ +import sapien +from mani_skill.envs.tasks import PushCubeEnv +from mani_skill.examples.motionplanning.panda_stick.motionplanner import \ + PandaStickMotionPlanningSolver + + +def solve(env: PushCubeEnv, seed=None, debug=False, vis=False): + env.reset(seed=seed) + planner = PandaStickMotionPlanningSolver( + env, + debug=debug, + vis=vis, + base_pose=env.unwrapped.agent.robot.pose, + visualize_target_grasp_pose=vis, + print_env_info=False, + joint_vel_limits = 0.3 + ) + + FINGER_LENGTH = 0.025 + env = env.unwrapped + + rot = list(env.agent.tcp.pose.get_q()[0].cpu().numpy()) + + # -------------------------------------------------------------------------- # + # Move to first vertex + # -------------------------------------------------------------------------- # + + reach_pose = sapien.Pose(p=list(env.vertices[0, 0].numpy()), q=rot) + res = planner.move_to_pose_with_screw(reach_pose) + + # -------------------------------------------------------------------------- # + # Move to second vertex + # -------------------------------------------------------------------------- # + + reach_pose = sapien.Pose(p=list(env.vertices[0, 1]), q=rot) + res = planner.move_to_pose_with_screw(reach_pose) + + # -------------------------------------------------------------------------- # + # 
Move to third vertex + # -------------------------------------------------------------------------- # + + reach_pose = sapien.Pose(p=list(env.vertices[0, 2]), q=rot) + res = planner.move_to_pose_with_screw(reach_pose) + + # -------------------------------------------------------------------------- # + # Move back to first vertex + # -------------------------------------------------------------------------- # + + reach_pose = sapien.Pose(p=list(env.vertices[0, 0]), q=rot) + res = planner.move_to_pose_with_screw(reach_pose) + + planner.close() + return res diff --git a/mani_skill/utils/wrappers/flatten.py b/mani_skill/utils/wrappers/flatten.py index 64be9d0c4..989b95125 100644 --- a/mani_skill/utils/wrappers/flatten.py +++ b/mani_skill/utils/wrappers/flatten.py @@ -23,11 +23,12 @@ class FlattenRGBDObservationWrapper(gym.ObservationWrapper): Note that the returned observations will have a "rgbd" or "rgb" or "depth" key depending on the rgb/depth bool flags. """ - def __init__(self, env, rgb=True, depth=True, state=True) -> None: + def __init__(self, env, rgb=True, depth=True, state=True, sep_depth=False) -> None: self.base_env: BaseEnv = env.unwrapped super().__init__(env) self.include_rgb = rgb self.include_depth = depth + self.sep_depth = sep_depth self.include_state = state new_obs = self.observation(self.base_env._init_raw_obs) self.base_env.update_obs_space(new_obs) @@ -50,7 +51,11 @@ def observation(self, observation: Dict): if self.include_rgb and not self.include_depth: ret["rgb"] = images elif self.include_rgb and self.include_depth: - ret["rgbd"] = images + if self.sep_depth: + ret["rgb"] = images[...,:-1] + ret["depth"] = images[...,-1:] + else: + ret["rgbd"] = images elif self.include_depth and not self.include_rgb: ret["depth"] = images return ret diff --git a/mani_skill/utils/wrappers/frame_stack.py b/mani_skill/utils/wrappers/frame_stack.py index 3a3e5c4d2..fb0aaf9ce 100644 --- a/mani_skill/utils/wrappers/frame_stack.py +++ b/mani_skill/utils/wrappers/frame_stack.py @@ -76,6 +76,14 @@ def __getitem__(self, int_or_slice: Union[int, slice]): [self._check_decompress(f) for f in self._frames[int_or_slice]], axis=0 ) + def _check_decompress(self, frame): + if self.lz4_compress: + from lz4.block import decompress + return torch.frombuffer(decompress(frame), dtype=self.dtype).reshape( + self.frame_shape + ) + return frame + def __eq__(self, other): """Checks that the current frames are equal to the other object.""" return self.__array__() == other @@ -98,6 +106,7 @@ def __init__( self, env: gym.Env, num_stack: int, + lz4_compress: bool = False ): """Observation wrapper that stacks the observations in a rolling manner. @@ -110,6 +119,7 @@ def __init__( self.num_stack = num_stack self.frames = deque(maxlen=num_stack) + self.lz4_compress = lz4_compress [self.frames.append(self.base_env._init_raw_obs) for _ in range(self.num_stack)] new_obs = self.observation(self.base_env._init_raw_obs) self.base_env.update_obs_space(new_obs)
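A quick way to sanity-check the pieces this patch introduces (the `DrawTriangle-v1` registration, the new `sep_depth` option on `FlattenRGBDObservationWrapper`, and observation frame stacking) is to roll a wrapped environment for a few steps, mirroring what `make_eval_envs` and `train_rgbd.py` do. The sketch below is illustrative and not part of the patch; it assumes ManiSkill is installed with these changes applied, and the only names it relies on (`DrawTriangle-v1`, `FlattenRGBDObservationWrapper(sep_depth=True)`, `FrameStack`, `obs_mode="rgbd"`, `control_mode="pd_ee_delta_pose"`) come from the diff above. The choices of `num_stack=2`, the seed, and the step count are arbitrary.

```python
# Minimal sketch (assumes this patch is applied); names are taken from the diff above.
import gymnasium as gym

import mani_skill.envs  # noqa: F401  (importing registers DrawTriangle-v1 via @register_env)
from mani_skill.utils.wrappers import FrameStack
from mani_skill.utils.wrappers.flatten import FlattenRGBDObservationWrapper

env = gym.make(
    "DrawTriangle-v1",
    obs_mode="rgbd",
    control_mode="pd_ee_delta_pose",
    render_mode="rgb_array",
)
# sep_depth=True is the new flag: it splits the fused "rgbd" image into separate
# "rgb" and "depth" keys, which is the layout train_rgbd.py expects.
env = FlattenRGBDObservationWrapper(env, sep_depth=True)
# Stack the last two observations to match the default obs_horizon=2.
env = FrameStack(env, num_stack=2)

obs, _ = env.reset(seed=0)
print({k: v.shape for k, v in obs.items()})  # stacked state / rgb / depth tensors
for _ in range(10):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
env.close()
```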