pass options to cross attention class
lllyasviel committed Mar 8, 2024
1 parent 10b5ca2 commit 29be1da
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions ldm_patched/ldm/modules/attention.py
@@ -385,7 +385,7 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
 
         self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
 
-    def forward(self, x, context=None, value=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None, transformer_options=None):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
@@ -504,7 +504,7 @@ def _forward(self, x, context=None, transformer_options={}):
             n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
             n = self.attn1.to_out(n)
         else:
-            n = self.attn1(n, context=context_attn1, value=value_attn1)
+            n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=extra_options)
 
         if "attn1_output_patch" in transformer_patches:
             patch = transformer_patches["attn1_output_patch"]
@@ -544,7 +544,7 @@ def _forward(self, x, context=None, transformer_options={}):
             n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
             n = self.attn2.to_out(n)
         else:
-            n = self.attn2(n, context=context_attn2, value=value_attn2)
+            n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=extra_options)
 
         if "attn2_output_patch" in transformer_patches:
             patch = transformer_patches["attn2_output_patch"]
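For context, here is a minimal, hedged sketch of the call path after this commit: `CrossAttention.forward` gains a `transformer_options` keyword, and the caller in `_forward` passes its `extra_options` dict through. The class below is a simplified stand-in (single split-heads path, plain `nn.Linear` instead of `operations.Linear`, no mask handling), and the keys shown in `extra_options` are illustrative assumptions based on the surrounding patch-dispatch code, not part of this diff.

```python
# Minimal sketch (not the repository implementation) of how the new
# transformer_options keyword is threaded through the attention call.
import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        inner_dim = heads * dim_head
        context_dim = query_dim if context_dim is None else context_dim
        self.heads = heads
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    # After this commit, forward also accepts transformer_options so the
    # attention backend (or any patch that needs per-block state) can read it.
    def forward(self, x, context=None, value=None, mask=None, transformer_options=None):
        context = x if context is None else context
        value = context if value is None else value
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(value)

        # Split heads: (batch, tokens, heads * dim_head) -> (batch, heads, tokens, dim_head)
        def split(t):
            b, n, _ = t.shape
            return t.reshape(b, n, self.heads, -1).transpose(1, 2)

        # transformer_options would normally be consulted here (e.g. to choose an
        # attention kernel or apply per-block patches); this sketch ignores it.
        out = F.scaled_dot_product_attention(split(q), split(k), split(v))
        out = out.transpose(1, 2).reshape(x.shape[0], x.shape[1], -1)
        return self.to_out(out)


# Caller side, mirroring the changed lines in _forward: extra_options is a dict
# of per-call state (its exact keys here are assumed, not taken from this commit).
attn1 = CrossAttention(query_dim=64)
n = torch.randn(1, 16, 64)
extra_options = {"block": ("input", 1), "block_index": 0}
n = attn1(n, context=None, value=None, transformer_options=extra_options)
print(n.shape)  # torch.Size([1, 16, 64])
```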
