Lora unload (#55)
Unloads LoRAs when a prediction is made without one.
daanelson authored May 21, 2024
1 parent ad34452 commit b4aaa90
Showing 4 changed files with 193 additions and 155 deletions.
2 changes: 1 addition & 1 deletion cog.yaml
@@ -12,7 +12,7 @@ build:
     - "libxext6"
     - "wget"
   python_packages:
-    - "diffusers==0.19.3"
+    - "diffusers<=0.25"
     - "torch==2.0.1"
     - "transformers==4.31.0"
     - "invisible-watermark==0.2.0"
19 changes: 18 additions & 1 deletion predict.py
@@ -114,6 +114,7 @@ def load_trained_weights(self, weights, pipe):
         name_rank_map = {}
         for tk, tv in tensors.items():
             # up is N, d
+            tensors[tk] = tv.half()
             if tk.endswith("up.weight"):
                 proc_name = ".".join(tk.split(".")[:-3])
                 r = tv.shape[1]
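For context, here is what this loop computes, run on a toy state dict; the key layout follows diffusers' legacy LoRA naming, and the name_rank_map assignment is an assumption about the unshown lines, so treat this as a hedged sketch. The fp16 cast is the line this hunk adds:

import torch

tensors = {
    "mid_block.attentions.0.proc.to_q_lora.up.weight": torch.randn(640, 4),
    "mid_block.attentions.0.proc.to_q_lora.down.weight": torch.randn(4, 640),
}

name_rank_map = {}
for tk, tv in tensors.items():
    tensors[tk] = tv.half()  # new in this commit: keep LoRA weights in fp16
    if tk.endswith("up.weight"):
        proc_name = ".".join(tk.split(".")[:-3])   # drop "to_q_lora.up.weight"
        name_rank_map[proc_name] = tv.shape[1]     # up is (N, r): dim 1 is the rank

print(name_rank_map)  # {'mid_block.attentions.0.proc': 4}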
@@ -140,7 +141,7 @@ def load_trained_weights(self, weights, pipe):
                 hidden_size=hidden_size,
                 cross_attention_dim=cross_attention_dim,
                 rank=name_rank_map[name],
-            )
+            ).half()
             unet_lora_attn_procs[name] = module.to("cuda", non_blocking=True)

         unet.set_attn_processor(unet_lora_attn_procs)
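Why the added .half(): the pipeline's weights run in fp16, while a freshly constructed processor defaults to fp32, and PyTorch refuses to multiply the two. A minimal illustration with a plain nn.Linear standing in for the LoRA layer (assumed setup, not code from this repo):

import torch

lora_down = torch.nn.Linear(64, 4, bias=False)   # fp32 by default
x = torch.randn(1, 64, dtype=torch.float16)      # fp16 activations, as in the pipeline

try:
    lora_down(x)                                 # fp32 weight @ fp16 input
except RuntimeError as err:
    print("mismatch:", err)                      # mat1 and mat2 must have the same dtype

print(lora_down.half()(x).dtype)                 # torch.float16 once the dtypes agree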
@@ -160,6 +161,20 @@ def load_trained_weights(self, weights, pipe):
         self.tuned_weights = weights
         self.tuned_model = True

+    def unload_trained_weights(self, pipe: DiffusionPipeline):
+        print("unloading loras")
+
+        def _recursive_unset_lora(module: torch.nn.Module):
+            if hasattr(module, "lora_layer"):
+                module.lora_layer = None
+
+            for _, child in module.named_children():
+                _recursive_unset_lora(child)
+
+        _recursive_unset_lora(pipe.unet)
+        self.tuned_weights = None
+        self.tuned_model = False
+
     def setup(self, weights: Optional[Path] = None):
         """Load the model into memory to make running multiple predictions efficient"""

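To see the new unload path in isolation, here is a toy reproduction of _recursive_unset_lora. FakeLoraLinear is invented for the demo; it stands in for diffusers' LoRA-compatible layers, which carry a lora_layer attribute in the pinned version range:

import torch

class FakeLoraLinear(torch.nn.Linear):
    # Stand-in for a layer that carries an optional LoRA sublayer.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lora_layer = torch.nn.Linear(4, 4, bias=False)

def _recursive_unset_lora(module: torch.nn.Module):
    if hasattr(module, "lora_layer"):
        module.lora_layer = None      # detach the LoRA branch; base weights stay put
    for _, child in module.named_children():
        _recursive_unset_lora(child)

net = torch.nn.Sequential(
    FakeLoraLinear(4, 4),
    torch.nn.Sequential(FakeLoraLinear(4, 4)),   # nested, to exercise the recursion
)
_recursive_unset_lora(net)
print([m.lora_layer for m in net.modules() if hasattr(m, "lora_layer")])   # [None, None]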
@@ -350,6 +365,8 @@ def predict(

         if replicate_weights:
             self.load_trained_weights(replicate_weights, self.txt2img_pipe)
+        elif self.tuned_model:
+            self.unload_trained_weights(self.txt2img_pipe)

         # OOMs can leave vae in bad state
         if self.txt2img_pipe.vae.dtype == torch.float32:
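Across predictions, the branching above behaves like a small state machine: load when weights arrive, unload when they stop arriving, do nothing for back-to-back base-model calls. A toy model of that flow (names invented, not code from this diff):

class LoraState:
    def __init__(self):
        self.tuned_weights = None
        self.tuned_model = False

    def before_predict(self, replicate_weights):
        if replicate_weights:
            self.tuned_weights = replicate_weights   # load_trained_weights(...)
            self.tuned_model = True
        elif self.tuned_model:
            self.tuned_weights = None                # unload_trained_weights(...)
            self.tuned_model = False

state = LoraState()
state.before_predict("lora_a.tar")
print(state.tuned_model)   # True: LoRA active for this prediction
state.before_predict(None)
print(state.tuned_model)   # False: back to the base model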
5 changes: 5 additions & 0 deletions requirements_test.txt
@@ -0,0 +1,5 @@
+numpy
+pytest
+replicate
+requests
+Pillow
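These test dependencies point at an end-to-end smoke test against a deployed model. A hypothetical shape, offered only to orient readers; the model ref and prompt are placeholders, and replicate.run needs REPLICATE_API_TOKEN set in the environment:

import io
import requests
import replicate
from PIL import Image

def test_prediction_returns_an_image():
    output = replicate.run("owner/model:version-id", input={"prompt": "a cat"})
    url = str(output[0])                       # output is a list of image URLs
    img = Image.open(io.BytesIO(requests.get(url, timeout=120).content))
    assert min(img.size) > 0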