diff --git a/ldm_patched/ldm/modules/attention.py b/ldm_patched/ldm/modules/attention.py
index 6bfa64f97..13183e5f5 100644
--- a/ldm_patched/ldm/modules/attention.py
+++ b/ldm_patched/ldm/modules/attention.py
@@ -385,7 +385,7 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
 
         self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
 
-    def forward(self, x, context=None, value=None, mask=None):
+    def forward(self, x, context=None, value=None, mask=None, transformer_options=None):
         q = self.to_q(x)
         context = default(context, x)
         k = self.to_k(context)
@@ -504,7 +504,7 @@ def _forward(self, x, context=None, transformer_options={}):
             n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
             n = self.attn1.to_out(n)
         else:
-            n = self.attn1(n, context=context_attn1, value=value_attn1)
+            n = self.attn1(n, context=context_attn1, value=value_attn1, transformer_options=extra_options)
 
         if "attn1_output_patch" in transformer_patches:
             patch = transformer_patches["attn1_output_patch"]
@@ -544,7 +544,7 @@ def _forward(self, x, context=None, transformer_options={}):
             n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
             n = self.attn2.to_out(n)
         else:
-            n = self.attn2(n, context=context_attn2, value=value_attn2)
+            n = self.attn2(n, context=context_attn2, value=value_attn2, transformer_options=extra_options)
 
         if "attn2_output_patch" in transformer_patches:
             patch = transformer_patches["attn2_output_patch"]
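
For context, the change threads the per-block options dict (passed as `extra_options` from `BasicTransformerBlock._forward`) into `CrossAttention.forward` instead of dropping it at the attention boundary. Below is a minimal, self-contained sketch of that calling pattern, not the repository's actual implementation: `TinyCrossAttention` and the option keys are hypothetical stand-ins, and the attention math uses plain scaled-dot-product attention rather than the optimized kernels in attention.py.

import torch
import torch.nn as nn

class TinyCrossAttention(nn.Module):
    def __init__(self, dim, heads=4):
        super().__init__()
        self.heads = heads
        self.to_q = nn.Linear(dim, dim)
        self.to_k = nn.Linear(dim, dim)
        self.to_v = nn.Linear(dim, dim)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, context=None, value=None, mask=None, transformer_options=None):
        # transformer_options defaults to None so existing callers keep working;
        # in this sketch it is only accepted and normalized, while in the patched
        # module a downstream attention function could inspect it.
        transformer_options = transformer_options or {}
        context = x if context is None else context
        value = context if value is None else value
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(value)
        b, n, d = q.shape
        h = self.heads
        # Split heads: (b, n, d) -> (b, h, n, d // h)
        q, k, v = (t.reshape(b, -1, h, d // h).transpose(1, 2) for t in (q, k, v))
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        # Merge heads back: (b, h, n, d // h) -> (b, n, d)
        out = out.transpose(1, 2).reshape(b, n, d)
        return self.to_out(out)

if __name__ == "__main__":
    attn = TinyCrossAttention(64)
    x = torch.randn(2, 16, 64)
    # Mirrors the patched call sites: options from the surrounding block are
    # forwarded instead of being dropped at the attention boundary.
    y = attn(x, transformer_options={"block": ("input", 1), "block_index": 0})
    print(y.shape)  # torch.Size([2, 16, 64])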