This codebase is used for experimenting with inferring reward functions from videos. Given passively observed, state-only trajectories (videos), we learn a reward function based on the idea that the ranking of frames within a demonstration encodes task progress and can therefore serve as a dense guidance signal (i.e., reward) for an agent learning the task. We also incorporate data collected by the agent during training as negatives to better shape the reward landscape over the full state space.
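As a rough sketch of the frame-ranking idea (illustrative only; the actual model and training code live in reward_extraction/, and the names `RankNet` and `ranking_loss` below are hypothetical), the reward model can be trained to predict which of two frames from the same expert video comes later, so that its score on a single frame reflects task progress:

```python
# Minimal sketch of frame-ranking reward learning (hypothetical names).
import torch
import torch.nn as nn

class RankNet(nn.Module):
    """Scores a feature-encoded frame; higher score = further along the task."""
    def __init__(self, feat_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feat_dim, 256), nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x).squeeze(-1)

def ranking_loss(model: RankNet, frames: torch.Tensor, n_pairs: int = 128) -> torch.Tensor:
    """frames: (T, feat_dim) features of one expert video, in temporal order.
    Sample frame pairs and train the model so that later frames score higher."""
    T = frames.shape[0]
    idx_a = torch.randint(0, T, (n_pairs,))
    idx_b = torch.randint(0, T, (n_pairs,))
    # Bradley-Terry-style pairwise loss: P(b is later than a) = sigmoid(score_b - score_a)
    logits = model(frames[idx_b]) - model(frames[idx_a])
    labels = (idx_b > idx_a).float()
    return nn.functional.binary_cross_entropy_with_logits(logits, labels)
```

In the full method, states visited by the agent during training are additionally used as negatives so that the reward stays well shaped outside the expert distribution.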
It is built around a fork of DrQ-v2, a model-free off-policy algorithm for image-based continuous control. We utilize environments from metaworld for benchmarking. We also run simple experiments in a 2D pointmass maze environment using an implementation of SAC.
We also include implementations of the following baselines:
- GAIL [Ho & Ermon, 2016]
- VICE [Fu & Singh et al., 2018]
- AIRL [Fu et al., 2018]
- SOIL [Radosavovic et al., 2021]
- Time Contrastive Networks [Sermanet et al., 2017]
- Watch and Match: Supercharging Imitation with Regularized Optimal Transport [Haldar et al., 2022]
Note: we assume MuJoCo 2.1.0 is installed on the machine. The code was tested on Ubuntu 20.04 with NVIDIA driver 525.60.13 and CUDA toolkit 12.0. Logging is done through Weights & Biases (wandb).
# clone the repo and its submodules
git clone git@github.com:anonymous/rewardlearning-vid.git
cd rewardlearning-vid
git submodule init
git submodule update
# create a virtual environment (note: this repo contains a .python-version file specifying the name of the virtual environment to look for)
pyenv virtualenv 3.8.10 rewardlearning-vid-py38
pip install --upgrade pip
# install pytorch (system dependent)
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
# standard pip dependencies
pip install -r requirements.txt
# install submodule dependencies (each from the repo root)
cd drqv2; pip install -e .; cd ..
cd r3m; pip install -e .; cd ..
cd metaworld; pip install -e .; cd ..
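As an optional sanity check after installation (this is not a script from the repo; it assumes the mujoco-py binding for MuJoCo 2.1.0 and that you have a wandb account for logging):

```python
# Optional post-install sanity check (not part of the repo).
# Assumes the mujoco-py binding for MuJoCo 2.1.0; adjust if your fork uses different bindings.
import torch
import wandb
import mujoco_py   # fails here if MuJoCo 2.1.0 / LD_LIBRARY_PATH are misconfigured
import metaworld   # installed above from the submodule

print("CUDA available:", torch.cuda.is_available())
wandb.login()      # logging is done through Weights & Biases
```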
Codebase structure:
- ROT: git submodule. Fork of the official implementation, adapted to our codebase.
- drqv2: git submodule. Fork of the official implementation; we use the DrQ-v2 agent from here, and other parts of the codebase also pull from its replay buffer.
- metaworld: git submodule. Fork of metaworld, with some improvements to camera rendering and environment initialization.
- policy_learning: main folder. Includes code that wraps metaworld environments with learned reward functions, as well as the main training script that sets up DrQ-v2 with our framework (see the sketch below).
- pytorch_sac: git submodule. Fork of a standalone SAC implementation, used with the 2D pointmass environment defined here for quick experimentation in a simple domain.
- r3m: git submodule. Vanilla copy of the official implementation; useful for extracting features from images.
- reward_extraction: main folder. Learned reward function model and training code, some metaworld helper code, and some expert data saving code.
- scripts: misc scripts. Data egress from wandb and result-plotting code.
- tcn: implementation of Time Contrastive Networks, adapted from an external implementation.
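To give a sense of how policy_learning uses a learned reward, here is a minimal gym-style wrapper sketch; it is illustrative only (the class and argument names are hypothetical, and the repo's actual wrapper may differ):

```python
# Illustrative learned-reward wrapper (hypothetical names, old gym step API).
import gym
import torch

class LearnedRewardWrapper(gym.Wrapper):
    """Replaces the environment's reward with the learned reward model's score."""
    def __init__(self, env: gym.Env, reward_model: torch.nn.Module, encode_fn):
        super().__init__(env)
        self.reward_model = reward_model
        self.encode_fn = encode_fn  # e.g. an R3M feature extractor for image observations

    def step(self, action):
        obs, _, done, info = self.env.step(action)
        with torch.no_grad():
            feats = self.encode_fn(obs)               # image observation -> feature vector
            reward = self.reward_model(feats).item()  # learned task-progress score
        return obs, reward, done, info
```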
Valid env strings: assembly, drawer-open, hammer, door-close, push, reach, button-press-topdown, door-open
# Rank2Reward
python -m policy_learning.train_v2 --env_str hammer --use_online_lrf --seed 42
# GAIL
python -m policy_learning.train_v2 --env_str hammer --use_online_lrf --train_gail
# AIRL
python -m policy_learning.train_v2 --env_str hammer --use_online_lrf --train_airl
# VICE
python -m policy_learning.train_v2 --env_str hammer --use_online_lrf --train_vice
# SOIL
python -m policy_learning.train_v2 --env_str hammer --train_soil
# TCN
python -m policy_learning.train_v2 --env_str hammer --train_tcn
# ROT
cd ROT/ROT
python train.py task_name=hammer
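# 2D pointmass maze (SAC)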
cd pytorch_sac
python train.py