From af4d412e67bb72af64bf223385acff7f67ed3220 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Wed, 26 Nov 2025 18:14:54 +0200
Subject: [PATCH] Chunk attention map calculation for multiple speakers to
 reduce peak VRAM usage

---
 comfy/ldm/wan/model.py           |  8 +++---
 comfy/ldm/wan/model_multitalk.py | 43 +++++++++++++++++++++-----------
 comfy_extras/nodes_wan.py        |  4 +--
 3 files changed, 35 insertions(+), 20 deletions(-)

diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
index a34df1dbd..23feb9a7b 100644
--- a/comfy/ldm/wan/model.py
+++ b/comfy/ldm/wan/model.py
@@ -88,8 +88,8 @@ class WanSelfAttention(nn.Module):
             transformer_options=transformer_options,
         )
 
-        if "self_attn" in patches:
-            for p in patches["self_attn"]:
+        if "attn1_patch" in patches:
+            for p in patches["attn1_patch"]:
                 x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
 
         x = self.o(x)
@@ -251,8 +251,8 @@ class WanAttentionBlock(nn.Module):
         # cross-attention & ffn
         x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
 
-        if "cross_attn" in patches:
-            for p in patches["cross_attn"]:
+        if "attn2_patch" in patches:
+            for p in patches["attn2_patch"]:
                 x = p({"x": x, "transformer_options": transformer_options})
 
         y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
diff --git a/comfy/ldm/wan/model_multitalk.py b/comfy/ldm/wan/model_multitalk.py
index a651618c8..7e7dda7d5 100644
--- a/comfy/ldm/wan/model_multitalk.py
+++ b/comfy/ldm/wan/model_multitalk.py
@@ -4,27 +4,42 @@ import comfy
 from comfy.ldm.modules.attention import optimized_attention
 
 
-def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks):
+def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8):
     scale = 1.0 / visual_q.shape[-1] ** 0.5
     visual_q = visual_q.transpose(1, 2) * scale
-    attn = visual_q @ ref_k.permute(0, 2, 3, 1).to(visual_q)
-
-    x_ref_attn_map_source = attn.softmax(-1).to(visual_q.dtype) # B, H, x_seqlens, ref_seqlens
-    del attn
+    B, H, x_seqlens, K = visual_q.shape
 
     x_ref_attn_maps = []
-    for class_idx, ref_target_mask in enumerate(ref_target_masks):
-        ref_target_mask = ref_target_mask.view(1, 1, 1, *ref_target_mask.shape)
-        x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
-        x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
-        x_ref_attnmap = x_ref_attnmap.transpose(1, 2) # B, x_seqlens, H
-        x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens
+    for class_idx, ref_target_mask in enumerate(ref_target_masks):
+        ref_target_mask = ref_target_mask.view(1, 1, 1, -1)
+
+        x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
+        chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
+
+        for i in range(0, x_seqlens, chunk_size):
+            end_i = min(i + chunk_size, x_seqlens)
+
+            attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens
+
+            # Apply softmax
+            attn_max = attn_chunk.max(dim=-1, keepdim=True).values
+            attn_chunk = (attn_chunk - attn_max).exp()
+            attn_sum = attn_chunk.sum(dim=-1, keepdim=True)
+            attn_chunk = attn_chunk / (attn_sum + 1e-8)
+
+            # Apply mask and sum
+            masked_attn = attn_chunk * ref_target_mask
+            x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8)
+
+            del attn_chunk, masked_attn
+
+        # Average across heads
+        x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens
         x_ref_attn_maps.append(x_ref_attnmap)
-
-    del x_ref_attn_map_source
-
+
+    del visual_q, ref_k
+
     return torch.cat(x_ref_attn_maps, dim=0)
 
 
 def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
index 409025121..8e0f8287b 100644
--- a/comfy_extras/nodes_wan.py
+++ b/comfy_extras/nodes_wan.py
@@ -1429,9 +1429,9 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
                 is_extend=previous_frames is not None,
             ))
         # add cross-attention patch
-        model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "cross_attn")
+        model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch")
         if token_ref_target_masks is not None:
-            model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "self_attn")
+            model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch")
 
         out_latent = {}
         out_latent["samples"] = latent
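
Note (not part of the patch): the chunking above is exact rather than an approximation, because the softmax is taken over the reference (last) axis, so each query row's result is independent of the other rows. The standalone sketch below illustrates that idea with a plain torch softmax instead of the patch's manual max-subtract/exp form; the function names, tensor sizes, and split_num value here are made up for illustration and only mirror the shapes used by calculate_x_ref_attn_map.

```python
import torch

def ref_attn_map_full(q, k, mask):
    # q: (B, H, x_seqlens, K), k: (B, H, ref_seqlens, K), mask: (ref_seqlens,)
    attn = (q @ k.transpose(-1, -2)).softmax(-1)   # B, H, x_seqlens, ref_seqlens
    return (attn * mask).sum(-1) / mask.sum()      # B, H, x_seqlens

def ref_attn_map_chunked(q, k, mask, split_num=8):
    B, H, x_seqlens, _ = q.shape
    out = torch.zeros(B, H, x_seqlens, dtype=q.dtype)
    chunk = max(x_seqlens // split_num, 1)
    for i in range(0, x_seqlens, chunk):
        # softmax over the last dim is per-row, so slicing along x_seqlens
        # only materializes a (B, H, chunk, ref_seqlens) block at a time
        a = (q[:, :, i:i + chunk] @ k.transpose(-1, -2)).softmax(-1)
        out[:, :, i:i + chunk] = (a * mask).sum(-1) / mask.sum()
    return out

q = torch.randn(1, 4, 1000, 64)
k = torch.randn(1, 4, 32, 64)
mask = torch.zeros(32)
mask[:16] = 1.0  # toy speaker mask over the reference tokens
assert torch.allclose(ref_attn_map_full(q, k, mask),
                      ref_attn_map_chunked(q, k, mask), atol=1e-6)
```

Peak memory drops from the full B x H x x_seqlens x ref_seqlens attention matrix to one chunk of it, which is the point of the patch when several speaker masks are processed in sequence.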