From 64a984177fb24f4eba6e2f969052f7a521d739a4 Mon Sep 17 00:00:00 2001
From: kijai <40791699+kijai@users.noreply.github.com>
Date: Wed, 26 Nov 2025 18:15:12 +0200
Subject: [PATCH] Update model_multitalk.py

---
 comfy/ldm/wan/model_multitalk.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/comfy/ldm/wan/model_multitalk.py b/comfy/ldm/wan/model_multitalk.py
index 7e7dda7d5..a67fb5158 100644
--- a/comfy/ldm/wan/model_multitalk.py
+++ b/comfy/ldm/wan/model_multitalk.py
@@ -16,30 +16,30 @@ def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8):
 
         x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
         chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
-        
+
         for i in range(0, x_seqlens, chunk_size):
             end_i = min(i + chunk_size, x_seqlens)
-            
+
             attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens
-            
+
             # Apply softmax
             attn_max = attn_chunk.max(dim=-1, keepdim=True).values
             attn_chunk = (attn_chunk - attn_max).exp()
             attn_sum = attn_chunk.sum(dim=-1, keepdim=True)
             attn_chunk = attn_chunk / (attn_sum + 1e-8)
-            
+
             # Apply mask and sum
             masked_attn = attn_chunk * ref_target_mask
             x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8)
-            
+
             del attn_chunk, masked_attn
-        
+
         # Average across heads
         x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens
         x_ref_attn_maps.append(x_ref_attnmap)
-    
+
     del visual_q, ref_k
-    
+
     return torch.cat(x_ref_attn_maps, dim=0)
 
 def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
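
Note (appended below the diff, where git am ignores trailing text): the whitespace cleanup above touches a chunked, numerically stable softmax that measures how much of each visual query's attention lands on a set of target reference tokens. Here is a minimal standalone sketch of that pattern, not the patched function itself: chunked_ref_attn_map is a hypothetical name, and the tensor layouts are assumptions inferred from the permute(0, 2, 3, 1) in the hunk, namely visual_q as (B, H, x_seqlens, C), ref_k as (B, ref_seqlens, H, C), and ref_target_mask as a 0/1 vector over ref_seqlens.

import torch

def chunked_ref_attn_map(visual_q, ref_k, ref_target_mask, split_num=8, eps=1e-8):
    # visual_q: (B, H, x_seqlens, C); ref_k: (B, ref_seqlens, H, C) -- assumed layouts
    B, H, x_seqlens, _ = visual_q.shape
    out = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
    # Process queries in chunks so the (chunk, ref_seqlens) score matrix stays small.
    chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
    for i in range(0, x_seqlens, chunk_size):
        end_i = min(i + chunk_size, x_seqlens)
        # Scores for this query chunk against all reference keys: (B, H, chunk, ref_seqlens)
        attn = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1)
        # Numerically stable softmax over the reference axis; exact rather than
        # approximate, because the softmax axis (ref_seqlens) is never split.
        attn = (attn - attn.max(dim=-1, keepdim=True).values).exp()
        attn = attn / (attn.sum(dim=-1, keepdim=True) + eps)
        # Fraction of attention mass falling on the masked-in reference tokens.
        out[:, :, i:end_i] = (attn * ref_target_mask).sum(-1) / (ref_target_mask.sum() + eps)
    return out.mean(dim=1)  # average over heads -> (B, x_seqlens)

# Tiny smoke test with made-up shapes:
q = torch.randn(1, 2, 16, 4)          # (B, H, x_seqlens, C)
k = torch.randn(1, 8, 2, 4)           # (B, ref_seqlens, H, C)
mask = (torch.arange(8) < 4).float()  # target = first 4 reference tokens
print(chunked_ref_attn_map(q, k, mask, split_num=4).shape)  # torch.Size([1, 16])

One design point the diff context makes visible: dividing by ref_target_mask.sum() normalizes by the number of target tokens, so the result is a per-token average of attention mass rather than a raw sum, and the eps terms guard against empty masks and all-masked rows.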