forked from ArrowM/auto1111-improved-prompt-matrix
-
Notifications
You must be signed in to change notification settings - Fork 1
/
improved_prompt_matrix.py
64 lines (53 loc) · 2.31 KB
/
improved_prompt_matrix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import re
import gradio as gr
import modules.shared as shared
import modules.scripts as scripts
import modules.sd_samplers
from modules.processing import process_images, StableDiffusionProcessingTxt2Img
class Script(scripts.Script):
    """Web UI script that expands ``<a|b>`` groups in a prompt into a matrix.

    Every ``<option 1|option 2|...>`` group in the prompt is replaced by each
    of its options in turn, producing one prompt per combination.  Groups
    that start with ``lora:`` or ``hypernet:`` are left untouched.
    """

    # Group 1 is the whole ``<...>`` span, group 3 the text between the
    # brackets.  The negative lookahead keeps ``<lora:...>`` and
    # ``<hypernet:...>`` tags out of the expansion.
    _MATRIX_PATTERN = re.compile(r'(<(?!(lora|hypernet):)([^>]+)>)')

    def title(self):
        """Return the name under which this script is listed."""
        return "Improved prompt matrix"

    def ui(self, is_img2img):
        """Build the script's UI: a single checkbox whose label shows usage.

        The checkbox value is handed to ``run`` as ``dummy`` and never read.
        """
        dummy = gr.Checkbox(label="Usage: a <corgi|cat> wearing <goggles|a hat>")
        return [dummy]

    @classmethod
    def _expand_prompt(cls, prompt):
        """Return every prompt obtained by resolving the ``<a|b>`` groups.

        Prompts are processed first-in, first-out: the first remaining group
        of the oldest unresolved prompt is replaced by each of its options
        (whitespace-stripped), and the resulting prompts are queued behind
        the others.  For ``"a <corgi|cat> wearing <goggles|a hat>"`` this
        yields the corgi variants first, then the cat variants.

        A prompt without any group is returned unchanged as the only element.
        """
        queue = [prompt]
        resolved = []
        head = 0
        # Walk the queue with an index instead of removing items from a list
        # that is being iterated; every prompt is scanned exactly once.
        while head < len(queue):
            current = queue[head]
            head += 1
            match = cls._MATRIX_PATTERN.search(current)
            if match is None:
                resolved.append(current)
                continue
            start, end = match.span(1)
            for option in match.group(3).split("|"):
                expanded = current[:start] + option.strip() + current[end:]
                queue.append(expanded.strip())
        return resolved

    def run(self, p, dummy):
        """Expand the prompt matrix and render every resulting prompt.

        Args:
            p: The processing object supplied by the web UI.  Its ``prompt``,
                ``seed``, ``n_iter`` and ``prompt_for_display`` attributes
                are overwritten before it is passed to ``process_images``.
            dummy: Value of the usage checkbox; ignored.

        Returns:
            Whatever ``process_images(p)`` returns.
        """
        # Presumably turns a random/unspecified seed into a concrete number;
        # the code below needs ``int(p.seed)`` to be meaningful.
        modules.processing.fix_seed(p)

        original_prompt = p.prompt[0] if isinstance(p.prompt, list) else p.prompt
        all_prompts = self._expand_prompt(original_prompt)

        total_images = len(all_prompts) * p.n_iter
        print(f"Prompt matrix will create {total_images} images")

        total_steps = p.steps * total_images
        if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
            # NOTE(review): doubling looks like an allowance for the extra
            # high-res pass -- confirm against the processing module.
            total_steps *= 2
        shared.total_tqdm.updateTotal(total_steps)

        # The whole prompt list is repeated once per requested iteration and
        # each repetition shares one seed: s, s, ..., s+1, s+1, ...
        first_seed = int(p.seed)
        p.prompt = all_prompts * p.n_iter
        p.seed = [
            seed
            for seed in range(first_seed, first_seed + p.n_iter)
            for _ in all_prompts
        ]
        # NOTE(review): n_iter is set to the image count regardless of
        # p.batch_size -- verify this is intended when batch_size > 1.
        p.n_iter = total_images
        p.prompt_for_display = original_prompt
        return process_images(p)