
Commit 7e26283: fix original

natolambert committed Jul 18, 2024
1 parent 3bb9615
Showing 1 changed file with 38 additions and 20 deletions.

58 changes: 38 additions & 20 deletions open_instruct/get_statistics.py
@@ -23,32 +23,35 @@
 from transformers import AutoTokenizer
 
 
-def get_statistics_for_messages_data(data_path, dataset=None, split="train"):
+def get_statistics_for_messages_data(
+    data_path,
+    dataset=None,
+    split="train",
+    messages_key="messages",
+    tokenizer="/net/nfs.cirrascale/allennlp/yizhongw/hf_llama2_models/7B/",
+):
     if dataset is None:
         # load dataset
         dataset = load_dataset("json", data_files={split: data_path})
     # tokenize dataset
-    tokenizer = AutoTokenizer.from_pretrained(
-        "/net/nfs.cirrascale/allennlp/yizhongw/hf_llama2_models/7B/", use_fast=False
-    )
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=False)
     # get statistics
     num_instances = len(dataset[split])
 
     # remove any messages that have "role" == "system"
     def remove_system_messages(example):
-        example["messages"] = [message for message in example["messages"] if message["role"] != "system"]
+        example[messages_key] = [message for message in example[messages_key] if message["role"] != "system"]
         return example
 
     dataset = dataset.map(remove_system_messages)
 
-    num_of_turns = [len(instance["messages"]) for instance in dataset[split]]
+    num_of_turns = [len(instance[messages_key]) for instance in dataset[split]]
     user_prompt_lengths = []
     assistant_response_lengths = []
     instance_lengths = []
     for instance in tqdm.tqdm(dataset[split], desc="Processing instances"):
         instance_length = 0
-        for message in instance["messages"]:
+        for message in instance[messages_key]:
             if message["role"] == "user":
                 user_prompt_lengths.append(
                     len(tokenizer(message["content"], truncation=False, add_special_tokens=False)["input_ids"])
@@ -94,16 +97,20 @@ def remove_system_messages(example):
     return result
 
 
-def get_statistics_for_prompt_completion_data(data_path, dataset=None, split="train"):
-    # load dataset
-    dataset = load_dataset("json", data_files={split: data_path})
+def get_statistics_for_prompt_completion_data(
+    data_path,
+    dataset=None,
+    split="train",
+    response_key="completion",
+    tokenizer="/net/nfs.cirrascale/allennlp/yizhongw/hf_llama2_models/7B/",
+):
+    if dataset is None:
+        # load dataset
+        dataset = load_dataset("json", data_files={split: data_path})
     prompts = [instance["prompt"] for instance in dataset[split]]
-    completions = [instance["completion"] for instance in dataset[split]]
+    completions = [instance[response_key] for instance in dataset[split]]
     # tokenize dataset
-    tokenizer = AutoTokenizer.from_pretrained("/net/nfs.cirrascale/allennlp/yizhongw/hf_llama_models/7B")
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer)
     tokenized_prompts = tokenizer(prompts, truncation=False, add_special_tokens=False)
     tokenized_completions = tokenizer(completions, truncation=False, add_special_tokens=False)
     # get statistics
@@ -141,21 +148,32 @@ def get_statistics_for_prompt_completion_data(data_path, dataset=None, split="tr
     parser.add_argument("--data_path", type=str, required=True)
     parser.add_argument("--save_path", type=str, help="Path to save the statistics.")
     parser.add_argument("--split", type=str, default="train")
+    parser.add_argument("--response_key", type=str, default="completion")
+    parser.add_argument("--messages_key", type=str, default="messages")
+    parser.add_argument("--tokenizer", type=str, default="/net/nfs.cirrascale/allennlp/yizhongw/hf_llama2_models/7B/")
     args = parser.parse_args()
 
-    # Check if the data_path is a dataset id
-    if repo_exists(args.data_path, repo_type="dataset"):
+    # Check whether the data_path is a local json file or a Hub dataset id
+    if "json" in args.data_path:
+        with open(args.data_path, "r") as f:
+            sample = json.loads(f.readline())
+        dataset = None
+    elif repo_exists(args.data_path, repo_type="dataset"):
         dataset = load_dataset(args.data_path)
         sample = dataset[args.split][0]
     else:
-        with open(args.data_path, "r") as f:
-            sample = json.loads(f.readline())
+        raise ValueError("Invalid data path - the data path should be either a dataset id or a path to a json file.")
 
-    if "messages" in sample:
-        statistics = get_statistics_for_messages_data(args.data_path, dataset=dataset, split=args.split)
+    if args.messages_key in sample:
+        statistics = get_statistics_for_messages_data(
+            args.data_path, dataset=dataset, split=args.split, messages_key=args.messages_key, tokenizer=args.tokenizer
+        )
     elif "prompt" in sample:
-        statistics = get_statistics_for_prompt_completion_data(args.data_path, dataset=dataset, split=args.split)
+        statistics = get_statistics_for_prompt_completion_data(
+            args.data_path, dataset=dataset, split=args.split, response_key=args.response_key, tokenizer=args.tokenizer
+        )
     else:
         raise ValueError("Invalid data format - the data should be either prompt completion data or messages data.")

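After this commit, both helpers accept non-default field names and tokenizer paths instead of the hardcoded values. Below is a minimal sketch of calling them directly from Python; the file names are hypothetical placeholders, "conversations" stands in for a chat dataset that does not use the default "messages" key, and the public tokenizer id is an assumption (the NFS default above only resolves on the authors' cluster):

from open_instruct.get_statistics import (
    get_statistics_for_messages_data,
    get_statistics_for_prompt_completion_data,
)

# Hypothetical local files; any jsonl with the matching keys works.
chat_stats = get_statistics_for_messages_data(
    "data/sharegpt_sample.jsonl",
    split="train",
    messages_key="conversations",
    tokenizer="meta-llama/Llama-2-7b-hf",  # assumed public model id
)

pc_stats = get_statistics_for_prompt_completion_data(
    "data/alpaca_sample.jsonl",
    split="train",
    response_key="completion",
    tokenizer="meta-llama/Llama-2-7b-hf",  # assumed public model id
)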
Expand Down
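The new routing in __main__ starts with a substring test on the path rather than a Hub lookup. A small sketch of that branching, mirroring the committed logic (the helper name and example paths are hypothetical, not part of the commit):

from huggingface_hub import repo_exists

def classify_data_path(data_path: str) -> str:
    # Mirrors the commit's branching. Note that '"json" in data_path'
    # is a substring test: "train.json" and "train.jsonl" both match,
    # and so would a Hub id that happens to contain "json".
    if "json" in data_path:
        return "local json file"
    elif repo_exists(data_path, repo_type="dataset"):
        return "hub dataset id"
    raise ValueError("Invalid data path - the data path should be either a dataset id or a path to a json file.")

print(classify_data_path("data/tulu_sample.jsonl"))       # local json file
print(classify_data_path("allenai/tulu-v2-sft-mixture"))  # hub dataset id

One design consequence: the "json" check is a heuristic on the path string, not a strict format test, so it takes precedence over the Hub lookup for any path containing that substring.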

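Putting the new flags together, a hypothetical end-to-end invocation (paths and model id are placeholders, not taken from the commit; all flags shown are defined in the argparse block above):

python open_instruct/get_statistics.py \
    --data_path data/sharegpt_sample.jsonl \
    --split train \
    --messages_key conversations \
    --tokenizer meta-llama/Llama-2-7b-hf \
    --save_path stats/sharegpt_sample.json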