diff --git a/open_instruct/utils.py b/open_instruct/utils.py index f11103898..a58863e60 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -67,6 +67,7 @@ def instruction_output_to_messages(example): example["messages"] = messages return example + def query_answer_to_messages(example): """ Convert a query-answer pair to a list of messages. @@ -78,6 +79,7 @@ def query_answer_to_messages(example): example["messages"] = messages return example + def query_response_to_messages(example): """ Convert a query-response pair to a list of messages. @@ -89,6 +91,7 @@ def query_response_to_messages(example): example["messages"] = messages return example + def prompt_completion_to_messages(example): """ Convert a prompt-completion pair to a list of messages. @@ -240,9 +243,11 @@ def get_datasets( dataset = dataset.add_column("id", id_col) # Remove redundant columns to avoid schema conflicts on load - dataset = dataset.remove_columns([col for col in dataset.column_names if col not in (columns_to_keep+["id"])]) + dataset = dataset.remove_columns( + [col for col in dataset.column_names if col not in (columns_to_keep + ["id"])] + ) - # add tag to the dataset corresponding to where it was sourced from, for + # add tag to the dataset corresponding to where it was sourced from, for if "train" in split: raw_train_datasets.append(dataset) elif "test" in split: