Skip to content

Commit

Permalink
update changes
Browse files Browse the repository at this point in the history
  • Loading branch information
vwxyzjn committed Aug 1, 2024
1 parent b6408e7 commit 7295d6c
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 27 deletions.
38 changes: 19 additions & 19 deletions rejection_sampling/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ To run through the entire dataset you would need a lot more GPUs to finish the g

```bash
# debug job submission
python mason.py \
--cluster ai2/allennlp-cirrascale ai2/general-cirrascale-a5000 ai2/general-cirrascale-a5000 ai2/general-cirrascale-a100-80g-ib \
--priority low \
--budget ai2/allennlp \
--gpus 1 -- which python

chmod -R 777 /net/nfs.cirrascale/allennlp/.cache/hub/
python mason.py \
--cluster ai2/allennlp-cirrascale ai2/general-cirrascale-a5000 ai2/general-cirrascale-a5000 ai2/general-cirrascale-a100-80g-ib \
--priority low \
Expand All @@ -47,68 +54,61 @@ python mason.py \

# prod generations
bash rejection_sampling/batch_generation.bash
# Running shard 1 of 10 (indices 0 to 32615)

# Running shard 1 of 10 (indices 0 to 32615)
# full_command=['python', 'rejection_sampling/generation.py', '--dataset_name', 'allenai/tulu-v2-sft-mixture', '--model_name_or_path', 'allenai/llama-3-tulu-2-8b', '--dataset_start_idx', '0', '--dataset_end_idx', '32615', '--save_filename', 'rejection_sampling/shards/rejection_sampled_completions_0.jsonl', '--n', '5']
# Kicked off Beaker job. https://beaker.org/ex/01J45EDCVMJ3JZP4QPBDGENCBJ
# Kicked off Beaker job. https://beaker.org/ex/01J45EJ6GFY327Q3GMH81T7X9J
# Finished shard 1 of 10

# Running shard 2 of 10 (indices 32615 to 65230)
# full_command=['python', 'rejection_sampling/generation.py', '--dataset_name', 'allenai/tulu-v2-sft-mixture', '--model_name_or_path', 'allenai/llama-3-tulu-2-8b', '--dataset_start_idx', '32615', '--dataset_end_idx', '65230', '--save_filename', 'rejection_sampling/shards/rejection_sampled_completions_1.jsonl', '--n', '5']
# Kicked off Beaker job. https://beaker.org/ex/01J45EDDME4ZP5MCZ16RQNPTGC
# Kicked off Beaker job. https://beaker.org/ex/01J45EJ797SXMFNPCNNFF278Z5
# Finished shard 2 of 10

# Running shard 3 of 10 (indices 65230 to 97845)
# full_command=['python', 'rejection_sampling/generation.py', '--dataset_name', 'allenai/tulu-v2-sft-mixture', '--model_name_or_path', 'allenai/llama-3-tulu-2-8b', '--dataset_start_idx', '65230', '--dataset_end_idx', '97845', '--save_filename', 'rejection_sampling/shards/rejection_sampled_completions_2.jsonl', '--n', '5']
# Kicked off Beaker job. https://beaker.org/ex/01J45EDED3YQ0XEE6QNY7C8GXD
# Kicked off Beaker job. https://beaker.org/ex/01J45EJ814RVA87G38YG949NK5
# Finished shard 3 of 10

# Running shard 4 of 10 (indices 97845 to 130460)
# full_command=['python', 'rejection_sampling/generation.py', '--dataset_name', 'allenai/tulu-v2-sft-mixture', '--model_name_or_path', 'allenai/llama-3-tulu-2-8b', '--dataset_start_idx', '97845', '--dataset_end_idx', '130460', '--save_filename', 'rejection_sampling/shards/rejection_sampled_completions_3.jsonl', '--n', '5']
# Kicked off Beaker job. https://beaker.org/ex/01J45EDF5P9SV1ARKY1EF9MBRP
# Kicked off Beaker job. https://beaker.org/ex/01J45EJ8SAV2R1ZYBVZM1X4VP0
# Finished shard 4 of 10

# Running shard 5 of 10 (indices 130460 to 163075)
# full_command=['python', 'rejection_sampling/generation.py', '--dataset_name', 'allenai/tulu-v2-sft-mixture', '--model_name_or_path', 'allenai/llama-3-tulu-2-8b', '--dataset_start_idx', '130460', '--dataset_end_idx', '163075', '--save_filename', 'rejection_sampling/shards/rejection_sampled_completions_4.jsonl', '--n', '5']
# Kicked off Beaker job. https://beaker.org/ex/01J45EDFZ6MED7PVGVPX6MVS2M
# Kicked off Beaker job. https://beaker.org/ex/01J45EJ9HHEMZRREPKSBENWJB1
# Finished shard 5 of 10

# Running shard 6 of 10 (indices 163075 to 195690)
# full_command=['python', 'rejection_sampling/generation.py', '--dataset_name', 'allenai/tulu-v2-sft-mixture', '--model_name_or_path', 'allenai/llama-3-tulu-2-8b', '--dataset_start_idx', '163075', '--dataset_end_idx', '195690', '--save_filename', 'rejection_sampling/shards/rejection_sampled_completions_5.jsonl', '--n', '5']
# Kicked off Beaker job. https://beaker.org/ex/01J45EDGQDNVHMMJ8NR57T7NDB
# Kicked off Beaker job. https://beaker.org/ex/01J45EJA9B3GHJ3CTPATTV1342
# Finished shard 6 of 10

# Running shard 7 of 10 (indices 195690 to 228305)
# full_command=['python', 'rejection_sampling/generation.py', '--dataset_name', 'allenai/tulu-v2-sft-mixture', '--model_name_or_path', 'allenai/llama-3-tulu-2-8b', '--dataset_start_idx', '195690', '--dataset_end_idx', '228305', '--save_filename', 'rejection_sampling/shards/rejection_sampled_completions_6.jsonl', '--n', '5']
# Kicked off Beaker job. https://beaker.org/ex/01J45EDHFC4QRSAVD3FH15ZSAE
# Kicked off Beaker job. https://beaker.org/ex/01J45EJB1Y6KY564923RJRK39B
# Finished shard 7 of 10

# Running shard 8 of 10 (indices 228305 to 260920)
# full_command=['python', 'rejection_sampling/generation.py', '--dataset_name', 'allenai/tulu-v2-sft-mixture', '--model_name_or_path', 'allenai/llama-3-tulu-2-8b', '--dataset_start_idx', '228305', '--dataset_end_idx', '260920', '--save_filename', 'rejection_sampling/shards/rejection_sampled_completions_7.jsonl', '--n', '5']
# Kicked off Beaker job. https://beaker.org/ex/01J45EDJ89EAV0V4DMGWTG0TK8
# Kicked off Beaker job. https://beaker.org/ex/01J45EJBSTVEXV8WJM6R09R9CT
# Finished shard 8 of 10

# Running shard 9 of 10 (indices 260920 to 293535)
# full_command=['python', 'rejection_sampling/generation.py', '--dataset_name', 'allenai/tulu-v2-sft-mixture', '--model_name_or_path', 'allenai/llama-3-tulu-2-8b', '--dataset_start_idx', '260920', '--dataset_end_idx', '293535', '--save_filename', 'rejection_sampling/shards/rejection_sampled_completions_8.jsonl', '--n', '5']
# Kicked off Beaker job. https://beaker.org/ex/01J45EDK0XE1HSF953EMT3TSQ3
# Kicked off Beaker job. https://beaker.org/ex/01J45EJCJ5HEVVB0TMSVKFEBAQ
# Finished shard 9 of 10

# Running shard 10 of 10 (indices 293535 to 326154)
# full_command=['python', 'rejection_sampling/generation.py', '--dataset_name', 'allenai/tulu-v2-sft-mixture', '--model_name_or_path', 'allenai/llama-3-tulu-2-8b', '--dataset_start_idx', '293535', '--dataset_end_idx', '326154', '--save_filename', 'rejection_sampling/shards/rejection_sampled_completions_9.jsonl', '--n', '5']
# Kicked off Beaker job. https://beaker.org/ex/01J45EDKTM1YBK3YRV9V0BQT9R
# Kicked off Beaker job. https://beaker.org/ex/01J45EJDAAKXEW084RF71QYMFP
# Finished shard 10 of 10

# All shards submitted
```


# 2. tokenize them and run a reward model to filter them
python rejection_sampling.py \
--input_filename completions.jsonl \
--save_filename rejection_sampled_completions.jsonl \
--n 3 \
--num_gpus 2 \
--push_to_hub
bash rejection_sampling/batch_rejection_sampling.bash
```


Expand Down
1 change: 1 addition & 0 deletions rejection_sampling/batch_generation.bash
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mkdir -p rejection_sampling/shards
total_items=326154
num_shards=10
items_per_shard=$((total_items / num_shards))
shared_hf_repo_id=rejection_sampling_$RANDOM

# Loop through shards
for ((i=0; i<num_shards; i++))
Expand Down
38 changes: 38 additions & 0 deletions rejection_sampling/batch_rejection_sampling.bash
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env bash
# Submit one mason.py job per shard; each job runs
# rejection_sampling/rejection_sampling.py on that shard's completions file
# (rejection_sampling/shards/rejection_sampled_completions_<i>.jsonl) and
# writes rejection_sampled_completions_scores_<i>.jsonl next to it.
#
# Requires bash: relies on $RANDOM, $(( )) arithmetic and a C-style for loop.
# Run it from the repository root, since all paths below are relative.

mkdir -p rejection_sampling/shards

total_items=326154
num_shards=10
items_per_shard=$((total_items / num_shards))

# A single repo id is generated once and passed to every shard's job, together
# with --no_add_timestamp, so all shards of this run target the same repo name.
shared_hf_repo_id="rejection_sampling_$RANDOM"

# Loop through shards
for ((i = 0; i < num_shards; i++)); do
    # Start/end indices of this shard. They are only used in the progress
    # message below; the job itself locates its input through the shard
    # number in the file name.
    start_idx=$((i * items_per_shard))
    end_idx=$(((i + 1) * items_per_shard))

    # The integer division above drops the remainder, so the last shard is
    # extended to cover any remaining items.
    if ((i == num_shards - 1)); then
        end_idx=$total_items
    fi

    # Run the command for this shard
    echo "Running shard $((i+1)) of $num_shards (indices $start_idx to $end_idx)"
    # NOTE(review): ai2/general-cirrascale-a5000 is listed twice in --cluster;
    # kept as-is because mason.py's handling of duplicates is not visible here.
    python mason.py \
        --cluster ai2/allennlp-cirrascale ai2/general-cirrascale-a5000 ai2/general-cirrascale-a5000 ai2/general-cirrascale-a100-80g-ib \
        --budget ai2/allennlp \
        --priority low \
        --gpus 1 -- python rejection_sampling/rejection_sampling.py \
        --input_filename "rejection_sampling/shards/rejection_sampled_completions_$i.jsonl" \
        --model_name_or_path allenai/llama-3-tulu-2-8b-uf-mean-rm \
        --save_filename "rejection_sampling/shards/rejection_sampled_completions_scores_$i.jsonl" \
        --hf_repo_id "$shared_hf_repo_id" \
        --no_add_timestamp \
        --n 5 \
        --push_to_hub \
        --num_gpus 1

    # Printed as soon as mason.py returns, whatever its exit status: the
    # script does not check whether the submission succeeded.
    echo "Finished shard $((i+1)) of $num_shards"
    echo
done
echo "All shards submitted"
3 changes: 2 additions & 1 deletion rejection_sampling/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def main(args: Args, dataset_args: DatasetArgs, gen_args: GenerationArgs):
ds[key] = ds[key].select(range(min(dataset_args.sanity_check_size, len(ds[key]))))
if dataset_args.dataset_end_idx is None:
dataset_args.dataset_end_idx = len(ds[dataset_args.dataset_train_split])
ds[key] = ds[key].select(range(dataset_args.dataset_start_idx, dataset_args.dataset_end_idx))
for key in ds:
ds[key] = ds[key].select(range(dataset_args.dataset_start_idx, dataset_args.dataset_end_idx))
pprint([dataset_args, args, gen_args])

## DATASET specific logic: in this dataset the prompt is simply just a list of strings
Expand Down
27 changes: 20 additions & 7 deletions rejection_sampling/rejection_sampling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import time
import torch
import torch.multiprocessing as mp
Expand All @@ -10,6 +13,7 @@
DataCollatorWithPadding,
AutoTokenizer,
)
from tqdm import tqdm
from datasets import Dataset
import json
from torch.utils.data import DataLoader
Expand All @@ -21,13 +25,14 @@
class Args:
model_name_or_path: str = "cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr"
input_filename: str = "completions.jsonl"
save_filename_prefix: str = "rejected_sampling_completions"
save_filename: str = "rejected_sampling_completions.jsonl"
n: int = 1
forward_batch_size: int = 10
forward_batch_size: int = 8
num_gpus: int = 1 # New argument for specifying the number of GPUs
push_to_hub: bool = False
hf_entity: Optional[str] = None
hf_repo_id: str = "rejection_sampling"
add_timestamp: bool = True


def first_true_indices(bools: torch.Tensor, dtype=torch.long):
Expand Down Expand Up @@ -78,7 +83,7 @@ def process_shard(rank: int, args: Args, shard: List[str]):
dataloader = DataLoader(ds, batch_size=args.forward_batch_size, collate_fn=data_collator, pin_memory=True)
scores = []
with torch.no_grad():
for data in dataloader:
for data in tqdm(dataloader):
input_ids = data["input_ids"].to(device)
_, score, _ = get_reward(model, input_ids, tokenizer.pad_token_id, 0)
scores.append(score.cpu())
Expand Down Expand Up @@ -130,13 +135,21 @@ def main(args: Args):
assert worst_completions[i]["messages"][:-1] == best_completions[i]["messages"][:-1]
table["chosen_score"].append(best_completions[i]["score"])
table["rejected_score"].append(worst_completions[i]["score"])
ds = Dataset.from_dict(table)
first_key = list(table.keys())[0]
print(f"{len(table[first_key])=}")
with open(args.save_filename, 'w') as outfile:
for i in range(len(table[first_key])):
json.dump({key: table[key][i] for key in table}, outfile)
outfile.write('\n')

if args.push_to_hub:
if args.hf_entity is None:
args.hf_entity = api.whoami()["name"]
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}_{int(time.time())}"
ds.push_to_hub(full_repo_id)
for f in [__file__, args.input_filename]:
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
if args.add_timestamp:
full_repo_id += f"_{int(time.time())}"
api.create_repo(full_repo_id, repo_type="dataset")
for f in [__file__, args.save_filename]:
api.upload_file(
path_or_fileobj=f,
path_in_repo=f.split("/")[-1],
Expand Down

0 comments on commit 7295d6c

Please sign in to comment.