This is the implementation of the ICLR 4064 submission "COPA: Certifying Robust Policies for Offline Reinforcement Learning against Poisoning Attacks". The code is adapted from the offline RL training repo https://github.com/google-research/batch_rl.

We provide two certification methods (per-state action certification and reward certification) for three aggregation protocols (PARL, TPARL, and DPARL). Below are example commands for running these certifications.
- Generate the trajectory indices for each hash number in $[50]$:

```bash
python split.py --train-data-folder /data/common/kahlua/dqn_replay/$1/$2/replay_logs \
  --output-folder /data/common/kahlua/dqn_replay/hash_split/$1_$2
```
With the above command saved in `split_script.sh`, simply run the following commands, e.g., for the game Pong:

```bash
bash split_script.sh Pong 1
bash split_script.sh Pong 2
bash split_script.sh Pong 3
bash split_script.sh Pong 4
bash split_script.sh Pong 5
```
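The splitting step assigns each trajectory to one of the partitions deterministically. A minimal sketch of the idea in Python (the hash function and names here are illustrative assumptions, not the actual `split.py` logic):

```python
import hashlib

def assign_partition(episode_index: int, num_partitions: int) -> int:
    """Deterministically map an episode index to one of `num_partitions` partitions."""
    digest = hashlib.md5(str(episode_index).encode()).hexdigest()
    return int(digest, 16) % num_partitions

# Group 200 hypothetical episode indices into 50 partitions.
partitions = {p: [] for p in range(50)}
for ep in range(200):
    partitions[assign_partition(ep, 50)].append(ep)

# Every episode lands in exactly one partition.
assert sum(len(v) for v in partitions.values()) == 200
```

Because the assignment depends only on the episode index, re-running the split reproduces the same partitioning, which is what makes the later per-partition training and certification well-defined.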
- For each hash number, generate the corresponding datasets:

```bash
python gen_split.py --train-data-folder /data/common/kahlua/dqn_replay/$3/$1/replay_logs \
  --epi-index-path /data/common/kahlua/dqn_replay/hash_split/$3_$1/partition_$2.pt \
  --output-folder /data/common/kahlua/dqn_replay/hash_split/$3_$1/dataset/hash_$2 \
  --start-id 0 --end-id 50
```
With the above command saved in `gen_split_script.sh`, simply run the following commands, e.g., for the game Pong:

```bash
bash gen_split_script.sh 1 0 Pong
bash gen_split_script.sh 2 0 Pong
bash gen_split_script.sh 3 0 Pong
bash gen_split_script.sh 4 0 Pong
bash gen_split_script.sh 5 0 Pong
```
The above commands generate the datasets of hash number $0$ for each of the five Pong runs.
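Conceptually, this step selects from each run's full replay data the episodes whose indices fall in the given partition. A rough sketch of that selection (the data layout and names are assumptions, not the actual `gen_split.py` implementation):

```python
def select_partition(episodes, partition_indices):
    """Keep only the episodes whose index belongs to the given partition."""
    index_set = set(partition_indices)
    return [ep for i, ep in enumerate(episodes) if i in index_set]

# Ten hypothetical episodes; the partition file lists indices 1, 4, and 7.
episodes = [f"episode_{i}" for i in range(10)]
subset = select_partition(episodes, [1, 4, 7])
assert subset == ["episode_1", "episode_4", "episode_7"]
```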
- For each hash number, merge the $5$ datasets:

```bash
python merge_splits.py --input-folder /data/common/kahlua/dqn_replay/hash_split/Pong --hash-num 0
```
The above command merges the five per-run datasets of hash number $0$ for the game Pong.
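The merge step amounts to concatenating the per-run episode collections for one hash number into a single training dataset. A minimal sketch (list-of-lists layout is an assumption for illustration):

```python
def merge_datasets(splits):
    """Concatenate the per-run episode lists into one merged dataset."""
    merged = []
    for split in splits:
        merged.extend(split)
    return merged

# Five hypothetical runs, each contributing three episodes for one hash number.
splits = [[f"run{r}_ep{i}" for i in range(3)] for r in range(5)]
merged = merge_datasets(splits)
assert len(merged) == 15
```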
The following command trains the model on the Pong datasets of hash number $1$:

```bash
CUDA_VISIBLE_DEVICES=2 python -um batch_rl.fixed_replay.train \
  --base_dir=/data/common/kahlua/COPA/Pong/hash_1 \
  --replay_dir=/data/common/kahlua/dqn_replay/hash_split/Pong/hash_1/ \
  --gin_files='batch_rl/fixed_replay/configs/dqn.gin' \
  --gin_bindings='atari_lib.create_atari_environment.game_name = "Pong"' \
  --gin_bindings='FixedReplayRunner.num_iterations=100'
```
We first present per-state action certification for each aggregation protocol.

- PARL

```bash
python -um batch_rl.fixed_replay.test \
  --base_dir [base_dir] --model_dir [model_dir] \
  --cert_alg tight \
  --total_num 50 --max_steps_per_episode 1000 \
  --agent_name dqn \
  --gin_files='copa/fixed_replay/configs/dqn.gin' \
  --gin_bindings='atari_lib.create_atari_environment.game_name = "Freeway"'
```
where `base_dir` is the path for storing experimental logs and results, and `model_dir` is the path of the trained models.
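As a rough illustration of PARL-style per-state aggregation: each subpolicy (one per partition) proposes an action, the majority action is taken, and the vote margin bounds how many poisoned partitions could flip the decision. This is a simplified sketch of the idea, not the repo's implementation, and it omits the tie-breaking terms in the paper's exact bound:

```python
from collections import Counter

def parl_certify(actions):
    """Aggregate per-state actions from the subpolicies by majority vote.

    Returns (chosen_action, certified_size), where certified_size is the
    number of poisoned partitions that provably cannot change the vote
    (simplified margin-based bound; tie-breaking ignored).
    """
    counts = Counter(actions)
    (top, n_top), *rest = counts.most_common()
    n_second = rest[0][1] if rest else 0
    return top, (n_top - n_second) // 2

# Seven hypothetical subpolicies vote; action 2 wins 5-to-2 overall.
action, k = parl_certify([2, 2, 2, 1, 2, 0, 2])
assert action == 2 and k == 2
```

Intuitively, an attacker who poisons one partition corrupts at most one subpolicy's vote, so a large winning margin directly yields a certified poisoning tolerance.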
- TPARL
```bash
python -um batch_rl.fixed_replay.test \
  --base_dir [base_dir] --model_dir [model_dir] \
  --cert_alg window --window_size 4 \
  --total_num 50 --max_steps_per_episode 1000 \
  --agent_name dqn \
  --gin_files='copa/fixed_replay/configs/dqn.gin' \
  --gin_bindings='atari_lib.create_atari_environment.game_name = "Freeway"'
```
For TPARL, we explicitly pass the `cert_alg` option as `window` and configure the predetermined window size via `--window_size`.
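A simplified view of the temporal aggregation behind TPARL: rather than voting only on the current state's actions, the subpolicy votes are pooled over a fixed window of consecutive time steps before taking the majority. This sketch is illustrative only; the paper's actual TPARL protocol and bound may differ in details:

```python
from collections import Counter

def tparl_vote(per_step_actions, window_size):
    """Majority vote over subpolicy actions pooled from the last
    `window_size` time steps (simplified temporal aggregation)."""
    pooled = [a for step in per_step_actions[-window_size:] for a in step]
    return Counter(pooled).most_common(1)[0][0]

# Four time steps, three hypothetical subpolicies each; action 1 dominates.
history = [[0, 1, 1], [1, 1, 0], [1, 0, 1], [1, 1, 1]]
assert tparl_vote(history, 4) == 1
```

Pooling over a window enlarges the vote margin when consecutive states call for the same action, which is what lets TPARL certify larger poisoning sizes than per-step voting in such regimes.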
- DPARL
```bash
python -um batch_rl.fixed_replay.test \
  --base_dir [base_dir] --model_dir [model_dir] \
  --cert_alg dynamic --max_window_size 5 \
  --total_num 50 --max_steps_per_episode 1000 \
  --agent_name dqn \
  --gin_files='copa/fixed_replay/configs/dqn.gin' \
  --gin_bindings='atari_lib.create_atari_environment.game_name = "Freeway"'
```
For DPARL, we explicitly pass the `cert_alg` option as `dynamic` and configure the maximum window size via `--max_window_size`.
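The dynamic variant searches over window sizes up to the maximum instead of fixing one. As a loose illustration (not the repo's algorithm or the paper's exact criterion), one can keep the window whose pooled vote has the largest top-vs-runner-up margin:

```python
from collections import Counter

def dparl_vote(per_step_actions, max_window_size):
    """Try every window size up to the maximum and keep the one whose
    pooled vote has the largest top-vs-runner-up margin (simplified)."""
    best_action, best_margin = None, -1
    for w in range(1, max_window_size + 1):
        pooled = [a for step in per_step_actions[-w:] for a in step]
        counts = Counter(pooled).most_common()
        margin = counts[0][1] - (counts[1][1] if len(counts) > 1 else 0)
        if margin > best_margin:
            best_action, best_margin = counts[0][0], margin
    return best_action, best_margin

# Window size 2 gives the cleanest margin here (five 1s vs. one 0).
history = [[0, 0, 1], [1, 1, 1], [1, 1, 0]]
action, margin = dparl_vote(history, 3)
assert action == 1 and margin == 4
```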
The reward certification commands follow the same pattern, using `batch_rl.fixed_replay.test_reward`.

- PARL

```bash
python -um batch_rl.fixed_replay.test_reward \
  --base_dir [base_dir] --model_dir [model_dir] \
  --cert_alg tight \
  --total_num 50 --max_steps_per_episode 1000 \
  --agent_name dqn \
  --gin_files='copa/fixed_replay/configs/dqn.gin' \
  --gin_bindings='atari_lib.create_atari_environment.game_name = "Freeway"'
```
where `base_dir` is the path for storing experimental logs and results, and `model_dir` is the path of the trained models.
- TPARL
```bash
python -um batch_rl.fixed_replay.test_reward \
  --base_dir [base_dir] --model_dir [model_dir] \
  --cert_alg window --window_size 4 \
  --total_num 50 --max_steps_per_episode 1000 \
  --agent_name dqn \
  --gin_files='copa/fixed_replay/configs/dqn.gin' \
  --gin_bindings='atari_lib.create_atari_environment.game_name = "Freeway"'
```
For TPARL, we explicitly pass the `cert_alg` option as `window` and configure the predetermined window size via `--window_size`.
- DPARL
```bash
python -um batch_rl.fixed_replay.test_reward \
  --base_dir [base_dir] --model_dir [model_dir] \
  --cert_alg dynamic --max_window_size 5 \
  --total_num 50 --max_steps_per_episode 1000 \
  --agent_name dqn \
  --gin_files='copa/fixed_replay/configs/dqn.gin' \
  --gin_bindings='atari_lib.create_atari_environment.game_name = "Freeway"'
```
For DPARL, we explicitly pass the `cert_alg` option as `dynamic` and configure the maximum window size via `--max_window_size`.