fix CLIP doing the unneeded normalization
revert SD2.1 back to use the original repo
add SDXL's force_zero_embeddings to negative prompt
AUTOMATIC1111 committed Jul 13, 2023
1 parent 21aec6f commit 594c8e7
Showing 6 changed files with 29 additions and 8 deletions.
2 changes: 1 addition & 1 deletion modules/processing.py
@@ -344,7 +344,7 @@ def get_conds_with_caching(self, function, required_prompts, steps, caches, extr
 
     def setup_conds(self):
         prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
-        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height)
+        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
 
         sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
         self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
14 changes: 10 additions & 4 deletions modules/prompt_parser.py
@@ -116,11 +116,17 @@ class SdConditioning(list):
     A list with prompts for stable diffusion's conditioner model.
     Can also specify width and height of created image - SDXL needs it.
     """
-    def __init__(self, prompts, width=None, height=None):
+    def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
         super().__init__()
         self.extend(prompts)
-        self.width = width or getattr(prompts, 'width', None)
-        self.height = height or getattr(prompts, 'height', None)
+
+        if copy_from is None:
+            copy_from = prompts
+
+        self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
+        self.width = width or getattr(copy_from, 'width', None)
+        self.height = height or getattr(copy_from, 'height', None)
+
 
 
 def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
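A minimal standalone sketch of the new copy_from behavior (the class body mirrors the diff above; the example prompts and sizes are made up):

```python
# Standalone sketch of SdConditioning as changed above; example values are hypothetical.
class SdConditioning(list):
    def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):
        super().__init__()
        self.extend(prompts)

        if copy_from is None:
            copy_from = prompts

        self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)
        self.width = width or getattr(copy_from, 'width', None)
        self.height = height or getattr(copy_from, 'height', None)


neg = SdConditioning(["", ""], is_negative_prompt=True, width=1024, height=1024)
derived = SdConditioning(["a photo of a cat"], copy_from=neg)

assert derived.is_negative_prompt                        # flag inherited via copy_from
assert derived.width == 1024 and derived.height == 1024  # size inherited as well
```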
@@ -153,7 +159,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):
             res.append(cached)
             continue
 
-        texts = [x[1] for x in prompt_schedule]
+        texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
         conds = model.get_learned_conditioning(texts)
 
         cond_schedule = []
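Why the re-wrap matters: a bare list comprehension would return a plain list and drop the width/height/is_negative_prompt metadata that SDXL's conditioner reads off the batch. A short illustration, reusing the sketch class above and assuming prompt_schedule holds [end_at_step, prompt_text] pairs:

```python
# Reuses the SdConditioning sketch above; the schedule contents are made up.
prompts = SdConditioning(["a cat"], is_negative_prompt=True, width=1024, height=1024)
prompt_schedule = [[10, "a cat"], [20, "a dog"]]  # assumed [end_at_step, text] pairs

texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)
assert texts.is_negative_prompt and texts.width == 1024  # metadata survives the re-wrap
```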
2 changes: 1 addition & 1 deletion modules/sd_hijack.py
@@ -190,7 +190,7 @@ def hijack(self, m):
                 if typename == 'FrozenCLIPEmbedder':
                     model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
                     model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
-                    m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(embedder, self)
+                    m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
                     conditioner.embedders[i] = m.cond_stage_model
                 if typename == 'FrozenOpenCLIPEmbedder2':
                     embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
15 changes: 15 additions & 0 deletions modules/sd_hijack_clip.py
@@ -323,3 +323,18 @@ def encode_embedding_init_text(self, init_text, nvpt):
         embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)
 
         return embedded
+
+
+class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
+    def __init__(self, wrapped, hijack):
+        super().__init__(wrapped, hijack)
+
+    def encode_with_transformers(self, tokens):
+        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")
+
+        if self.wrapped.layer == "last":
+            z = outputs.last_hidden_state
+        else:
+            z = outputs.hidden_states[self.wrapped.layer_idx]
+
+        return z
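The commit title refers to the parent class, which re-applies the text model's final_layer_norm when it stops at an intermediate hidden state; the override above returns the hidden state untouched, which is what SDXL's conditioner expects. A self-contained illustration of the difference using Hugging Face transformers (not webui code; the penultimate layer here stands in for SDXL's layer_idx on its CLIP-L embedder):

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer(["a photo of a cat"], return_tensors="pt").input_ids
outputs = model(input_ids=tokens, output_hidden_states=True)

raw = outputs.hidden_states[-2]                       # intermediate state, as returned by the override
normalized = model.text_model.final_layer_norm(raw)   # the extra step the parent class applied

print(torch.allclose(raw, normalized))                # False: the normalization changes the conditioning
```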
1 change: 0 additions & 1 deletion modules/sd_models_config.py
@@ -12,7 +12,6 @@
 config_default = shared.sd_default_config
 config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
-config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml")
 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
 config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
3 changes: 2 additions & 1 deletion modules/sd_models_xl.py
@@ -22,7 +22,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch:
         "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype),
     }
 
-    c = self.conditioner(sdxl_conds)
+    force_zero_negative_prompt = getattr(batch, 'is_negative_prompt', False) and all(x == '' for x in batch)
+    c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
 
     return c

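A small sketch of the new check (reusing the SdConditioning sketch from the prompt_parser section; 'txt' is the embedder key used in the diff): an all-empty negative-prompt batch now asks the SDXL conditioner to zero the text embedding via force_zero_embeddings instead of encoding the empty string.

```python
# Reuses the SdConditioning sketch above; the batch contents are hypothetical.
batch = SdConditioning(["", ""], is_negative_prompt=True)

force_zero_negative_prompt = getattr(batch, 'is_negative_prompt', False) and all(x == '' for x in batch)
force_zero_embeddings = ['txt'] if force_zero_negative_prompt else []

assert force_zero_embeddings == ['txt']  # empty negative prompt -> zeroed text conditioning
```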
