Skip to content

Commit

Permalink
Linter fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mzukowski-reef committed Sep 23, 2024
1 parent 664c7df commit d1d56aa
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 25 deletions.
6 changes: 2 additions & 4 deletions src/compute_horde_prompt_gen/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@ def random_select_str(self, arr: list[str], num: int = 5) -> str:
def generate_prompt(self, short=True) -> str:
if short:
theme = self.random_select(THEMES, num=1)[0]
return (
f"{theme}"
)
themes = self.random_select_str(arr, num=3)
return f"{theme}"
themes = self.random_select_str(THEMES, num=3)

relevance_level = random.randint(5, 20)
complexity_level = random.randint(5, 20)
Expand Down
23 changes: 7 additions & 16 deletions src/compute_horde_prompt_gen/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def generate_prompts(

# remove any duplicates
new_prompts = list(set(new_prompts))
log.info(f"{i=} generation took {seconds_taken:.2f}s; generated {len(new_prompts)} prompts")
log.info(
f"{i=} generation took {seconds_taken:.2f}s; generated {len(new_prompts)} prompts"
)
if total_prompts - len(new_prompts) < 0:
# one might want to optimize here and save additional prompts for next batch,
# but it is so parametrized that it produces on average additional 10 prompts
Expand All @@ -74,19 +76,19 @@ def generate_prompts(
parser.add_argument(
"--batch_size",
type=int,
default=20,
default=262, # on A6000 we want 240 prompts generated in single file, but not all results are valid
help="Batch size - number of prompts given as input per generation request",
)
parser.add_argument(
"--num_return_sequences",
type=int,
default=5,
default=1, # better to generate as many as possible prompts on different themes
help="Number of return sequences outputted for each prompt given as input",
)
parser.add_argument(
"--max_new_tokens",
type=int,
default=500,
default=40, # 40 new tokens is enough for reasonable length prompt - 30 caused too much cut off prompts
help="Max new tokens",
)
parser.add_argument(
Expand All @@ -108,16 +110,10 @@ def generate_prompts(
default="./saved_models/",
help="Path to load the model and tokenizer from",
)
parser.add_argument(
"--number_of_batches",
type=int,
default=None,
help="Number of batches to generate",
)
parser.add_argument(
"--number_of_prompts_per_batch",
type=int,
required=True,
default=240,
help="Number of prompts per uuid batch",
)
parser.add_argument(
Expand All @@ -137,11 +133,6 @@ def generate_prompts(

uuids = args.uuids.split(",")

if args.number_of_batches:
assert (
len(uuids) == args.number_of_batches
), "Number of uuids should be equal to number of batches requested"

model_path = os.path.join(args.model_path, args.model_name)
if args.model_name == "mock":
model = MockModel()
Expand Down
10 changes: 5 additions & 5 deletions src/compute_horde_prompt_gen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,27 @@

def clean_line(line: str) -> str:
    """Normalize one generated line into a bare prompt string.

    Strips surrounding whitespace, drops special ``<|...|>`` tag markup,
    removes any leading list numbering (e.g. ``"1. "``) and surrounding
    quotation marks.
    """
    line = line.strip()
    head, sep, tail = line.partition("<|")
    if head:
        # text before the first tag is assumed to be the prompt itself
        line = head.strip()
    else:
        # if we started with a tag we assume that inside we find our prompt
        line = tail.partition("|>")[2].partition("<|")[0].strip()
    # remove list numbering if present
    line = re.sub(r"^\s*\d+\.?\s*", "", line)
    # strip quotations
    line = line.strip("\"'")
    return line


def parse_output(output: str) -> list[str]:
    """Extract the first plausible prompt from raw model output.

    Splits *output* into lines, cleans each with :func:`clean_line`, and
    returns a one-element list containing the first line that looks like a
    question of reasonable length; returns an empty list if none qualifies.
    """
    for raw_line in output.split("\n"):
        cleaned = clean_line(raw_line)
        # skip lines that are too short/too long or do not end with "?"
        # (in most cases that is just the first, preamble line)
        if 10 < len(cleaned) < 300 and cleaned.endswith("?"):
            # bug fix: return (and test) the CLEANED line — the original
            # computed the cleaned value but returned the raw `line`,
            # silently discarding all of clean_line's work
            return [cleaned]
    return []
Expand Down

0 comments on commit d1d56aa

Please sign in to comment.