wan: Delete the self attention before cross attention

This saves VRAM when the cross attention and FFN are in play as the
VRAM peak.
This commit is contained in:
Rattus 2025-09-27 21:38:40 +10:00
parent 8866a22dcb
commit 98ca6030f3

View File

@ -237,6 +237,7 @@ class WanAttentionBlock(nn.Module):
freqs, transformer_options=transformer_options) freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x)) x = torch.addcmul(x, y, repeat_e(e[2], x))
del y
# cross-attention & ffn # cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options) x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)