-
Notifications
You must be signed in to change notification settings - Fork 2
/
txt2img_webui.py
115 lines (101 loc) · 3.33 KB
/
txt2img_webui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from __future__ import generators
import argparse
import torch
import random
import gradio as gr
from sdxl import pipeline
class WebUI:
    """Gradio web UI for SDXL text-to-image generation (base pipeline + refiner)."""

    def __init__(self, low_vram=False):
        """Load the SDXL base and refiner pipelines and build the Gradio interface.

        Args:
            low_vram: forwarded to both pipelines to enable low-VRAM mode.
        """
        self.pipeline = pipeline(
            model="stabilityai/stable-diffusion-xl-base-0.9", low_vram=low_vram
        )
        self.refiner = pipeline(
            model="stabilityai/stable-diffusion-xl-refiner-0.9", low_vram=low_vram
        )
        inputs = [
            gr.Textbox(),  # prompt
            gr.Textbox(),  # negative prompt
            # Textbox values are strings; "-1" (any negative) means random seeds.
            gr.Textbox(value="-1"),  # seed
            gr.Slider(5, 1024 * 2, value=1024, step=8),  # width
            gr.Slider(5, 1024 * 2, value=1024, step=8),  # height
            gr.Slider(0, 200, value=50, step=1),  # steps
            gr.Slider(0, 30, value=7.5),  # guidance scale
            gr.Slider(1, 30, value=1, step=1),  # num of images
        ]
        self.webui = gr.Interface(self.text_to_img, inputs, gr.Gallery())

    def text_to_img(
        self,
        prompt,
        negative_prompt,
        seed,
        width,
        height,
        steps,
        guidance_scale,
        num_images_per_prompt,
    ):
        """Generate images for the gallery.

        Runs the base pipeline to latents (output_type="latent"), then the
        refiner on those latents, and pairs each final image with its seed.

        Returns:
            list of [image, seed] pairs consumable by gr.Gallery.
        """
        seeds, generators = self.parse_seed(seed, num_images_per_prompt)
        latent_image = self.pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            generator=generators,
            width=width,
            height=height,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            output_type="latent",
        ).images
        images = self.refiner(
            prompt=prompt,
            negative_prompt=negative_prompt,
            generator=generators,
            num_inference_steps=steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            image=latent_image,
        ).images
        return self.gr_gallery(seeds, images)

    def launch(self, *args, **kwargs):
        """Start the Gradio server; all arguments are forwarded to Interface.launch."""
        self.webui.launch(*args, **kwargs)

    def parse_seed(self, seed, num_images):
        """Derive one seed and one torch.Generator per image.

        A negative seed requests independent random seeds; otherwise seeds are
        seed, seed+1, ... so a batch is reproducible from a single value.

        Args:
            seed: int or numeric string from the seed Textbox.
            num_images: number of images to be generated.

        Returns:
            (seeds, generators) — parallel lists of length num_images.
        """
        seed = int(seed)
        if seed <= -1:
            seeds = [random.randint(1, 99999999999) for _ in range(num_images)]
        else:
            seeds = [seed + i for i in range(num_images)]
        # Generator construction is identical for both branches, so build once.
        generators = [
            torch.Generator(device=self.pipeline.device).manual_seed(s)
            for s in seeds
        ]
        return (seeds, generators)

    def gr_gallery(self, seeds, images):
        """Pair each image with its seed as [image, caption] entries for gr.Gallery."""
        return [[img, s] for img, s in zip(images, seeds)]
def arguments_parser():
    """Build the command-line argument parser for the WebUI launcher.

    Returns:
        argparse.ArgumentParser with --port (int, default 7860) and
        --lowvram (flag, default False) options.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Generate a 1024x1024 image for the given prompt "
            "to the specified output file."
        )
    )
    parser.add_argument(
        "-p", "--port", type=int, default=7860, help="the port to run WebUI on"
    )
    parser.add_argument(
        "-l", "--lowvram", action="store_true", default=False, help="enable lowvram mode."
    )
    return parser
def main():
    """Parse CLI options and launch the WebUI on the requested port."""
    args = arguments_parser().parse_args()
    ui = WebUI(low_vram=args.lowvram)
    ui.launch(server_port=args.port)


if __name__ == "__main__":
    main()