Adding Action Chunking with Transformers (ACT) to baselines #640

Open · wants to merge 19 commits into base: `main`
96 changes: 96 additions & 0 deletions examples/baselines/act/README.md
@@ -0,0 +1,96 @@
# Action Chunking with Transformers (ACT)

Code for running the ACT algorithm based on ["Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware"](https://arxiv.org/pdf/2304.13705). It is adapted from the [original code](https://github.com/tonyzhaozh/act).

## Installation

To get started, we recommend using conda/mamba to create a new environment and install the dependencies:

```bash
conda create -n act-ms python=3.9
conda activate act-ms
pip install -e .
```

**Member:** So the conda env `act-ms` is created and you do a local pip install. However, a simple `setup.py` file is still missing; can you create that?

**Author:** I created a simple `setup.py`.
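
For reference, a minimal `setup.py` along these lines is enough for the editable install above. This is only a sketch of what such a file might look like; the package name, version, and dependency list are assumptions rather than the exact file added in this PR.

```python
# setup.py (hypothetical sketch; package name, version, and dependencies are assumptions)
from setuptools import find_packages, setup

setup(
    name="act",
    version="0.0.1",
    description="ACT (Action Chunking with Transformers) baseline for ManiSkill",
    packages=find_packages(include=["act", "act.*"]),
    python_requires=">=3.9",
    install_requires=[
        "torch",
        "torchvision",
        "h5py",
    ],
)
```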

## Demonstration Download and Preprocessing

By default, to keep downloads fast and file sizes small, ManiSkill demonstrations are stored in a highly reduced/compressed format that does not include any observation data. Run the following commands to download the demonstrations and convert them to a format that includes observation data and the desired action space.

```bash
python -m mani_skill.utils.download_demo "PickCube-v1"
```

```bash
python -m mani_skill.trajectory.replay_trajectory \
--traj-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.h5 \
--use-first-env-state -c pd_ee_delta_pos -o state \
--save-traj --num-procs 10
```

Set `-o` to `rgbd` for RGBD observations. Note that the control mode can heavily influence how well Behavior Cloning performs. In the paper, the authors reported degraded performance when using delta joint positions as actions instead of target joint positions. By default, we recommend using the `pd_joint_delta_pos` control mode, as all tasks can be solved with it, although it is harder to learn with BC than `pd_ee_delta_pos` or `pd_ee_delta_pose` for robots that support those control modes. Finally, the type of demonstration data used can also impact performance: neural-network-generated demonstrations are typically easier to learn from than human or motion-planning generated demonstrations.
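
If you want to verify which environment and control mode a converted dataset was generated with, you can inspect the JSON metadata file saved alongside the `.h5`. A minimal sketch, assuming the usual ManiSkill layout where the metadata stores the environment id and kwargs under `env_info` (the exact field names are assumptions; adjust to your file):

```python
# check_dataset_meta.py: a sketch; the exact JSON layout is an assumption.
import json
from pathlib import Path

traj_path = Path(
    "~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.cuda.h5"
).expanduser()
meta = json.loads(traj_path.with_suffix(".json").read_text())

env_info = meta.get("env_info", {})
print("env id:      ", env_info.get("env_id"))
print("control mode:", env_info.get("env_kwargs", {}).get("control_mode"))
print("episodes:    ", len(meta.get("episodes", [])))
```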

## Training

We provide scripts to train ACT on demonstrations. Make sure to use the same simulation backend as the one the demonstrations were collected with.


Note that some demonstrations are slow (e.g. generated by motion planning or human teleoperation) and can exceed the default max episode steps. This can be an issue because imitation learning algorithms learn to solve the task at the same speed as the demonstrations. In that case, you can use the `--max-episode-steps` flag to set a higher value so that the policy has enough time to solve the task. A general recommendation is to set `--max-episode-steps` to about 2x the mean length of the demonstrations you are training on; we provide recommended values for the demonstration datasets in the examples.sh script, and a short sketch for estimating this from a converted dataset is shown below.
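
For example, you can estimate the mean demonstration length directly from a converted `.h5` file. A minimal sketch, assuming the usual ManiSkill trajectory layout with one `traj_{i}` group per episode, each containing an `actions` dataset:

```python
# mean_demo_length.py: a sketch for picking --max-episode-steps;
# the per-episode "traj_{i}/actions" layout is an assumption.
from pathlib import Path

import h5py
import numpy as np

traj_path = Path(
    "~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.cuda.h5"
).expanduser()

with h5py.File(traj_path, "r") as f:
    lengths = [f[k]["actions"].shape[0] for k in f.keys() if k.startswith("traj_")]

mean_len = float(np.mean(lengths))
print(f"{len(lengths)} demos, mean length {mean_len:.1f}, max {max(lengths)}")
print(f"suggested --max-episode-steps ~ {int(2 * mean_len)}")
```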

Example training command, learning from 100 demonstrations generated via motion planning on the PickCube-v1 task:
```bash
python train.py --env-id PickCube-v1 \
--demo-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.cuda.h5 \
--control-mode "pd_ee_delta_pos" --sim-backend "cpu" --num-demos 100 --max_episode_steps 100 \
--total_iters 30000
```


## Train and Evaluate with GPU Simulation

You can also choose to train on trajectories generated in the GPU simulation and evaluate much faster with the GPU simulation. However, as most demonstrations are usually generated in the CPU simulation (via motion planning or teleoperation), you may observe worse performance when evaluating in the GPU simulation than in the CPU simulation. This can be partially alleviated by using the replay trajectory tool to replay the trajectories in the GPU simulation.

It is also recommended not to save videos if you are using many parallel environments, as the video files can get very large.

To replay trajectories in the GPU simulation, you can use the following command. Note that this can be a bit slow as the replay trajectory tool is currently not optimized for GPU parallelized environments.

```bash
python -m mani_skill.trajectory.replay_trajectory \
--traj-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.h5 \
--use-first-env-state -c pd_ee_delta_pos -o state \
--save-traj --num-procs 1 -b gpu --count 100 # process only 100 trajectories
```

Once the GPU-backend demonstration dataset is ready, you can use the following command to train and evaluate in the GPU simulation.

```bash
python train.py --env-id PickCube-v1 \
--demo-path ~/.maniskill/demos/PickCube-v1/motionplanning/trajectory.state.pd_ee_delta_pos.cuda.h5 \
--control-mode "pd_ee_delta_pos" --sim-backend "gpu" --num-demos 100 --max_episode_steps 100 \
--total_iters 30000 \
--num-eval-envs 100 --no-capture-video
```

## Citation

If you use this baseline, please cite the following:
```
@inproceedings{DBLP:conf/rss/ZhaoKLF23,
  author    = {Tony Z. Zhao and
               Vikash Kumar and
               Sergey Levine and
               Chelsea Finn},
  editor    = {Kostas E. Bekris and
               Kris Hauser and
               Sylvia L. Herbert and
               Jingjin Yu},
  title     = {Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware},
  booktitle = {Robotics: Science and Systems XIX, Daegu, Republic of Korea, July
               10-14, 2023},
  year      = {2023},
  url       = {https://doi.org/10.15607/RSS.2023.XIX.016},
  doi       = {10.15607/RSS.2023.XIX.016},
  timestamp = {Thu, 20 Jul 2023 15:37:49 +0200},
  biburl    = {https://dblp.org/rec/conf/rss/ZhaoKLF23.bib},
  bibsource = {dblp computer science bibliography, https://dblp.org}
}
```
121 changes: 121 additions & 0 deletions examples/baselines/act/act/detr/backbone.py
@@ -0,0 +1,121 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Backbone modules.
"""
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torchvision
from torch import nn
from torchvision.models._utils import IntermediateLayerGetter
from typing import Dict, List

from examples.baselines.act.act.utils import NestedTensor, is_main_process

**Member (@StoneT2000, Oct 22, 2024):** Imports should be absolute and relative to `act` (which you `pip install -e .`).

**Author:** Updated my code.

from examples.baselines.act.act.detr.position_encoding import build_position_encoding

import IPython
e = IPython.embed  # interactive debugging helper carried over from the original ACT code

class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.ops.misc with added eps before rsqrt,
    without which any other policy_models than torchvision.policy_models.resnet[18,34,50,101]
    produce nans.
    """

    def __init__(self, n):
        super(FrozenBatchNorm2d, self).__init__()
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        num_batches_tracked_key = prefix + 'num_batches_tracked'
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super(FrozenBatchNorm2d, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        eps = 1e-5
        scale = w * (rv + eps).rsqrt()
        bias = b - rm * scale
        return x * scale + bias
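
**Note:** as a quick sanity check, `FrozenBatchNorm2d` should match `nn.BatchNorm2d` in eval mode when given the same affine parameters and running statistics, since both use `eps = 1e-5`. A minimal sketch, assuming the package is importable as `act.detr.backbone` after `pip install -e .` (i.e. with the imports made relative to `act` as requested in the review):

```python
# frozen_bn_check.py: a sketch; the import path is an assumption.
import torch
from torch import nn

from act.detr.backbone import FrozenBatchNorm2d

torch.manual_seed(0)
x = torch.randn(2, 8, 16, 16)

# a regular BatchNorm2d in eval mode with non-trivial stats and affine params
bn = nn.BatchNorm2d(8).eval()
frozen = FrozenBatchNorm2d(8)
with torch.no_grad():
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    # copy the same parameters/statistics into the frozen variant
    frozen.weight.copy_(bn.weight)
    frozen.bias.copy_(bn.bias)
    frozen.running_mean.copy_(bn.running_mean)
    frozen.running_var.copy_(bn.running_var)

print(torch.allclose(bn(x), frozen(x), atol=1e-5))  # expected: True
```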


class BackboneBase(nn.Module):

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
        super().__init__()
        # for name, parameter in backbone.named_parameters(): # only train later layers # TODO do we want this?
        #     if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
        #         parameter.requires_grad_(False)
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {'layer4': "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward(self, tensor):
        xs = self.body(tensor)
        return xs
        # out: Dict[str, NestedTensor] = {}
        # for name, x in xs.items():
        #     m = tensor_list.mask
        #     assert m is not None
        #     mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
        #     out[name] = NestedTensor(x, mask)
        # return out


class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""
    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool):
        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) # pretrained # TODO do we want frozen batch_norm??
        num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers)


class Joiner(nn.Sequential):
    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        xs = self[0](tensor_list)
        out: List[NestedTensor] = []
        pos = []
        for name, x in xs.items():
            out.append(x)
            # position encoding
            pos.append(self[1](x).to(x.dtype))

        return out, pos


def build_backbone(args):
    position_embedding = build_position_encoding(args)
    train_backbone = args.lr_backbone > 0
    return_interm_layers = args.masks
    backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
    model = Joiner(backbone, position_embedding)
    model.num_channels = backbone.num_channels
    return model
139 changes: 139 additions & 0 deletions examples/baselines/act/act/detr/detr_vae.py
@@ -0,0 +1,139 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
DETR model and criterion classes.
"""
import torch
from torch import nn
from torch.autograd import Variable
from examples.baselines.act.act.detr.transformer import build_transformer, TransformerEncoder, TransformerEncoderLayer

import numpy as np

import IPython
e = IPython.embed  # interactive debugging helper carried over from the original ACT code


def reparametrize(mu, logvar):
    std = logvar.div(2).exp()
    eps = Variable(std.data.new(std.size()).normal_())
    return mu + std * eps


def get_sinusoid_encoding_table(n_position, d_hid):
    def get_position_angle_vec(position):
        return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

    sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1

    return torch.FloatTensor(sinusoid_table).unsqueeze(0)


class DETRVAE(nn.Module):
    """ The ACT model: a conditional VAE (adapted from DETR) that predicts a chunk of future actions """
    def __init__(self, backbones, transformer, encoder, state_dim, action_dim, num_queries):
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        self.encoder = encoder
        hidden_dim = transformer.d_model
        self.action_head = nn.Linear(hidden_dim, action_dim)
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        if backbones is not None:
            self.input_proj = nn.Conv2d(backbones[0].num_channels, hidden_dim, kernel_size=1)
            self.backbones = nn.ModuleList(backbones)
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
        else:
            self.input_proj_robot_state = nn.Linear(state_dim, hidden_dim)
            self.backbones = None

        # encoder extra parameters
        self.latent_dim = 32 # size of latent z
        self.cls_embed = nn.Embedding(1, hidden_dim) # extra cls token embedding
        self.encoder_state_proj = nn.Linear(state_dim, hidden_dim) # project state to embedding
        self.encoder_action_proj = nn.Linear(action_dim, hidden_dim) # project action to embedding
        self.latent_proj = nn.Linear(hidden_dim, self.latent_dim*2) # project hidden state to latent std, var
        self.register_buffer('pos_table', get_sinusoid_encoding_table(1+1+num_queries, hidden_dim)) # [CLS], state, actions

        # decoder extra parameters
        self.latent_out_proj = nn.Linear(self.latent_dim, hidden_dim) # project latent sample to embedding
        self.additional_pos_embed = nn.Embedding(2, hidden_dim) # learned position embedding for state and proprio

    def forward(self, obs, actions=None):
        is_training = actions is not None
        state = obs['state'] if self.backbones is not None else obs
        bs = state.shape[0]

        if is_training:
            # project CLS token, state sequence, and action sequence to embedding dim
            cls_embed = self.cls_embed.weight # (1, hidden_dim)
            cls_embed = torch.unsqueeze(cls_embed, axis=0).repeat(bs, 1, 1) # (bs, 1, hidden_dim)
            state_embed = self.encoder_state_proj(state) # (bs, hidden_dim)
            state_embed = torch.unsqueeze(state_embed, axis=1) # (bs, 1, hidden_dim)
            action_embed = self.encoder_action_proj(actions) # (bs, seq, hidden_dim)
            # concat them together to form an input to the CVAE encoder
            encoder_input = torch.cat([cls_embed, state_embed, action_embed], axis=1) # (bs, seq+2, hidden_dim)
            encoder_input = encoder_input.permute(1, 0, 2) # (seq+2, bs, hidden_dim)
            # no masking is applied to all parts of the CVAE encoder input
            is_pad = torch.full((bs, encoder_input.shape[0]), False).to(state.device) # False: not a padding
            # obtain position embedding
            pos_embed = self.pos_table.clone().detach()
            pos_embed = pos_embed.permute(1, 0, 2) # (seq+2, 1, hidden_dim)
            # query CVAE encoder
            encoder_output = self.encoder(encoder_input, pos=pos_embed, src_key_padding_mask=is_pad)
            encoder_output = encoder_output[0] # take cls output only
            latent_info = self.latent_proj(encoder_output)
            mu = latent_info[:, :self.latent_dim]
            logvar = latent_info[:, self.latent_dim:]
            latent_sample = reparametrize(mu, logvar)
            latent_input = self.latent_out_proj(latent_sample)
        else:
            mu = logvar = None
            latent_sample = torch.zeros([bs, self.latent_dim], dtype=torch.float32).to(state.device)
            latent_input = self.latent_out_proj(latent_sample)

        # CVAE decoder
        if self.backbones is not None:
            vis_data = obs['rgb'] if "rgb" in obs else obs['rgbd']
            num_cams = vis_data.shape[1]

            # Image observation features and position embeddings
            all_cam_features = []
            all_cam_pos = []
            for cam_id in range(num_cams):
                features, pos = self.backbones[0](vis_data[:, cam_id]) # HARDCODED
                features = features[0] # take the last layer feature # (batch, hidden_dim, H, W)
                pos = pos[0] # (1, hidden_dim, H, W)
                all_cam_features.append(self.input_proj(features))
                all_cam_pos.append(pos)

            # proprioception features (state)
            proprio_input = self.input_proj_robot_state(state)
            # fold camera dimension into width dimension
            src = torch.cat(all_cam_features, axis=3) # (batch, hidden_dim, 4, 8)
            pos = torch.cat(all_cam_pos, axis=3) # (batch, hidden_dim, 4, 8)
            hs = self.transformer(src, None, self.query_embed.weight, pos, latent_input, proprio_input, self.additional_pos_embed.weight)[0] # (batch, num_queries, hidden_dim)
        else:
            state = self.input_proj_robot_state(state)
            hs = self.transformer(None, None, self.query_embed.weight, None, latent_input, state, self.additional_pos_embed.weight)[0]

        a_hat = self.action_head(hs)
        return a_hat, [mu, logvar]
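
**Note:** for context on how `a_hat`, `mu`, and `logvar` are typically consumed, ACT trains with an L1 reconstruction loss over the predicted action chunk plus a KL term pulling the CVAE posterior toward a standard normal prior. A minimal sketch of that objective (the `kl_weight` default and the lack of padding masks are assumptions, not taken from this PR's training script):

```python
# act_loss_sketch.py: a hedged sketch of the ACT-style objective.
import torch
import torch.nn.functional as F


def act_loss(a_hat, actions, mu, logvar, kl_weight=10.0):
    """L1 over the predicted action chunk + KL(q(z | s, a) || N(0, I))."""
    recon = F.l1_loss(a_hat, actions)
    # KL divergence of a diagonal Gaussian from the standard normal, averaged over the batch
    kl = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=-1)).mean()
    return recon + kl_weight * kl


# usage with the forward pass above:
#   a_hat, (mu, logvar) = model(obs, actions)
#   loss = act_loss(a_hat, actions, mu, logvar)
```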


def build_encoder(args):
    d_model = args.hidden_dim # 256
    dropout = args.dropout # 0.1
    nhead = args.nheads # 8
    dim_feedforward = args.dim_feedforward # 2048
    num_encoder_layers = args.enc_layers # 4 # TODO shared with VAE decoder
    normalize_before = args.pre_norm # False
    activation = "relu"

    encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                            dropout, activation, normalize_before)
    encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
    encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

    return encoder
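
**Note:** `build_encoder` only reads the hyperparameters listed above, so it can be exercised with a plain namespace. A minimal smoke-test sketch; the import path is an assumption, and the encoder call signature (`pos=`, `src_key_padding_mask=`) follows its usage in `DETRVAE.forward`:

```python
# encoder_smoke_test.py: a sketch; import path and shapes are assumptions.
from argparse import Namespace

import torch

from act.detr.detr_vae import build_encoder

args = Namespace(
    hidden_dim=256,
    dropout=0.1,
    nheads=8,
    dim_feedforward=2048,
    enc_layers=4,
    pre_norm=False,
)
encoder = build_encoder(args)

seq, bs = 18, 4  # e.g. [CLS] + state + 16 action tokens
src = torch.randn(seq, bs, args.hidden_dim)
pos = torch.randn(seq, 1, args.hidden_dim)
pad = torch.zeros(bs, seq, dtype=torch.bool)  # no padding

out = encoder(src, pos=pos, src_key_padding_mask=pad)
print(out.shape)  # expected: torch.Size([18, 4, 256])
```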