
linting
Jeffrey committed May 24, 2024
1 parent dee37d0 commit e23ec2d
Showing 1 changed file with 16 additions and 9 deletions.
open_lm/datapreprocess/ray/tokenize_shuffle.py (25 changes: 16 additions & 9 deletions)
@@ -453,21 +453,29 @@ def load_tokenizer(tokenizer, eos_overwrite=None, pad_overwrite=None):
         raise ValueError(f"Unknown Tokenizer: {tokenizer}")

     eos_token_id, pad_token_id = enc.eos_token_id, enc.pad_token_id

     if eos_overwrite is not None:
         if eos_token_id is not None and eos_overwrite != eos_token_id:
-            logger.warning(f"Default EOS id for {tokenizer} is {eos_token_id} and you are overriding it to be {eos_overwrite}. This may cause issues during training.")
+            logger.warning(
+                f"Default EOS id for {tokenizer} is {eos_token_id} and you are overriding it to be {eos_overwrite}. This may cause issues during training."
+            )
         eos_token_id = eos_overwrite

     if pad_overwrite is not None:
         if pad_overwrite != pad_token_id:
-            logger.warning(f"Default PAD id for {tokenizer} is {pad_token_id} and you are overriding it to be {pad_overwrite}. This may cause issues during training.")
+            logger.warning(
+                f"Default PAD id for {tokenizer} is {pad_token_id} and you are overriding it to be {pad_overwrite}. This may cause issues during training."
+            )
         pad_token_id = pad_overwrite

     if eos_token_id is None:
-        raise ValueError("Tokenizer does not have a specified EOS token id. Please manually pass one in via --eos_overwrite")
+        raise ValueError(
+            "Tokenizer does not have a specified EOS token id. Please manually pass one in via --eos_overwrite"
+        )
     if pad_token_id is None:
-        raise ValueError("Tokenizer does not have a specified PAD token id. Please manually pass one in via --pad_overwrite")
+        raise ValueError(
+            "Tokenizer does not have a specified PAD token id. Please manually pass one in via --pad_overwrite"
+        )

     return (lambda x: enc(x).input_ids, eos_token_id, pad_token_id)
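For reference, the tuple returned above can be consumed as below. This is a minimal sketch, not part of the commit; it assumes the transformers package is installed and uses the script's default EleutherAI/gpt-neox-20b tokenizer name.

from transformers import AutoTokenizer

enc = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

def encode_fn(text):
    return enc(text).input_ids  # same callable shape as load_tokenizer's first tuple element

eos_token_id = enc.eos_token_id  # may be None for some tokenizers, hence --eos_overwrite
pad_token_id = enc.pad_token_id if enc.pad_token_id is not None else eos_token_id  # stand-in for --pad_overwrite

tokens = encode_fn("hello world") + [eos_token_id]  # tokenize one document, then append EOS
print(len(tokens), eos_token_id, pad_token_id)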

@@ -573,7 +581,7 @@ def main(args):
     parser.add_argument("--content_key", type=str, default="text")
     parser.add_argument("--seqlen", type=int, default=2048)
     parser.add_argument("--tokenizer", type=str, default="EleutherAI/gpt-neox-20b")
-    parser.add_argument("--pretokenized", action='store_true') # For pre-tokenized data, don't load tokenizer
+    parser.add_argument("--pretokenized", action="store_true") # For pre-tokenized data, don't load tokenizer
     parser.add_argument("--wds_chunk_size", type=int, default=8192)
     parser.add_argument("--seed", type=int, default=42)
     parser.add_argument("--subset", type=int, default=None)
@@ -627,7 +635,6 @@ def main(args):
     )
     num_nodes = len(ray.nodes())

-
     SpecialTokens = enum.Enum
     Sources = enum.Enum("Sources", {item["source"]: index for index, item in enumerate(data["sources"])})
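The Sources line above uses the Enum functional API, which builds an enum class at runtime from a name-to-value mapping. A self-contained illustration, with made-up source names standing in for data["sources"]:

import enum

data = {"sources": [{"source": "COMMON_CRAWL"}, {"source": "WIKIPEDIA"}]}  # hypothetical config entries

Sources = enum.Enum("Sources", {item["source"]: index for index, item in enumerate(data["sources"])})
print(list(Sources))               # e.g. [<Sources.COMMON_CRAWL: 0>, <Sources.WIKIPEDIA: 1>]
print(Sources["WIKIPEDIA"].value)  # 1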

@@ -663,7 +670,7 @@ def main(args):

     if args.pretokenized:
         tokenizer = (lambda x: x, args.eos_overwrite, args.pad_overwrite)
-    else:
+    else:
         tokenizer = load_tokenizer(args.tokenizer, args.eos_overwrite, args.pad_overwrite)

     logger.info(f"Total number of keys = {len(input_paths)}")
