mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-17 10:02:59 +08:00
Fix WanI2VCrossAttention so that it expects to receive transformer_options
This commit is contained in:
parent
2d13bf1c7a
commit
1ae6fe14a7
@ -117,7 +117,7 @@ class WanI2VCrossAttention(WanSelfAttention):
|
|||||||
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
# self.alpha = nn.Parameter(torch.zeros((1, )))
|
||||||
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, context, context_img_len):
|
def forward(self, x, context, context_img_len, transformer_options={}):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
x(Tensor): Shape [B, L1, C]
|
x(Tensor): Shape [B, L1, C]
|
||||||
@ -132,9 +132,9 @@ class WanI2VCrossAttention(WanSelfAttention):
|
|||||||
v = self.v(context)
|
v = self.v(context)
|
||||||
k_img = self.norm_k_img(self.k_img(context_img))
|
k_img = self.norm_k_img(self.k_img(context_img))
|
||||||
v_img = self.v_img(context_img)
|
v_img = self.v_img(context_img)
|
||||||
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
|
img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads, transformer_options=transformer_options)
|
||||||
# compute attention
|
# compute attention
|
||||||
x = optimized_attention(q, k, v, heads=self.num_heads)
|
x = optimized_attention(q, k, v, heads=self.num_heads, transformer_options=transformer_options)
|
||||||
|
|
||||||
# output
|
# output
|
||||||
x = x + img_x
|
x = x + img_x
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user