Chunk attention map calculation for multiple speakers to reduce peak VRAM usage

commit af4d412e67 (parent b4d3f4e567)
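Summary of the change: the x→ref attention maps used for multi-speaker targeting were previously computed by materializing the full B × H × x_seqlens × ref_seqlens softmax in a single tensor. `calculate_x_ref_attn_map` now walks the query sequence in chunks (controlled by a new `split_num` argument, default 8), applying a numerically stabilized softmax and the per-speaker reference mask one chunk at a time, so peak VRAM scales with the chunk size rather than the full map. The same commit renames the MultiTalk patch registration keys from "self_attn"/"cross_attn" to "attn1_patch"/"attn2_patch" to match the lookups in the Wan attention blocks.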
@@ -88,8 +88,8 @@ class WanSelfAttention(nn.Module):
             transformer_options=transformer_options,
         )
 
-        if "self_attn" in patches:
-            for p in patches["self_attn"]:
+        if "attn1_patch" in patches:
+            for p in patches["attn1_patch"]:
                 x = p({"x": x, "q": q, "k": k, "transformer_options": transformer_options})
 
         x = self.o(x)
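For context (not part of the diff): a callable registered under the new "attn1_patch" key receives the dict built at this call site and returns the replacement hidden states. A minimal, hypothetical no-op patch could look like:

# Hypothetical no-op self-attention patch; the payload keys ("x", "q", "k",
# "transformer_options") are taken from the call site in the hunk above.
class NoopAttn1Patch:
    def __call__(self, kwargs):
        x = kwargs["x"]                      # attention output, before the self.o projection
        q, k = kwargs["q"], kwargs["k"]      # available e.g. for attention-map extraction
        return x                             # return the (possibly modified) hidden states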
@@ -251,8 +251,8 @@ class WanAttentionBlock(nn.Module):
         # cross-attention & ffn
         x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
 
-        if "cross_attn" in patches:
-            for p in patches["cross_attn"]:
+        if "attn2_patch" in patches:
+            for p in patches["attn2_patch"]:
                 x = p({"x": x, "transformer_options": transformer_options})
 
         y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
@@ -4,27 +4,42 @@ import comfy
 from comfy.ldm.modules.attention import optimized_attention
 
 
-def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks):
+def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8):
     scale = 1.0 / visual_q.shape[-1] ** 0.5
     visual_q = visual_q.transpose(1, 2) * scale
 
-    attn = visual_q @ ref_k.permute(0, 2, 3, 1).to(visual_q)
+    B, H, x_seqlens, K = visual_q.shape
 
-    x_ref_attn_map_source = attn.softmax(-1).to(visual_q.dtype) # B, H, x_seqlens, ref_seqlens
-    del attn
 
     x_ref_attn_maps = []
 
     for class_idx, ref_target_mask in enumerate(ref_target_masks):
-        ref_target_mask = ref_target_mask.view(1, 1, 1, *ref_target_mask.shape)
-        x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
-        x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
-        x_ref_attnmap = x_ref_attnmap.transpose(1, 2) # B, x_seqlens, H
-        x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens
+        ref_target_mask = ref_target_mask.view(1, 1, 1, -1)
+        x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
+        chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
+        for i in range(0, x_seqlens, chunk_size):
+            end_i = min(i + chunk_size, x_seqlens)
+
+            attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens
+
+            # Apply softmax
+            attn_max = attn_chunk.max(dim=-1, keepdim=True).values
+            attn_chunk = (attn_chunk - attn_max).exp()
+            attn_sum = attn_chunk.sum(dim=-1, keepdim=True)
+            attn_chunk = attn_chunk / (attn_sum + 1e-8)
+
+            # Apply mask and sum
+            masked_attn = attn_chunk * ref_target_mask
+            x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8)
+
+            del attn_chunk, masked_attn
+
+        # Average across heads
+        x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens
         x_ref_attn_maps.append(x_ref_attnmap)
 
-    del x_ref_attn_map_source
+    del visual_q, ref_k
 
     return torch.cat(x_ref_attn_maps, dim=0)
 
 
 def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
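As a review aid (not part of the commit): the chunked path should be numerically equivalent to the old full-map path, because softmax over the ref axis is row-wise in the query dimension, so each query chunk's softmax matches the corresponding rows of the full softmax; only the 1e-8 stabilizers differ. A self-contained sketch, with made-up tensor sizes:

import torch

torch.manual_seed(0)

def full_map(visual_q, ref_k, mask):
    # Old path: materialize the full (B, H, x_seqlens, ref_seqlens) softmax at once.
    attn = visual_q @ ref_k.permute(0, 2, 3, 1)
    attn = attn.softmax(-1)
    return (attn * mask).sum(-1) / mask.sum()  # B, H, x_seqlens

def chunked_map(visual_q, ref_k, mask, split_num=8):
    # New path: only a (chunk, ref_seqlens) slice of the map is live at any time.
    B, H, x_seqlens, _ = visual_q.shape
    out = torch.zeros(B, H, x_seqlens)
    chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
    for i in range(0, x_seqlens, chunk_size):
        end_i = min(i + chunk_size, x_seqlens)
        a = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1)
        a = (a - a.max(dim=-1, keepdim=True).values).exp()       # stabilized softmax
        a = a / (a.sum(dim=-1, keepdim=True) + 1e-8)
        out[:, :, i:end_i] = (a * mask).sum(-1) / (mask.sum() + 1e-8)
    return out

B, H, x_seqlens, ref_seqlens, K = 1, 4, 64, 32, 16
q = torch.randn(B, H, x_seqlens, K)    # already transposed and scaled, as in the function
k = torch.randn(B, ref_seqlens, H, K)  # layout expected by ref_k.permute(0, 2, 3, 1)
mask = (torch.rand(1, 1, 1, ref_seqlens) > 0.5).float()
print(torch.allclose(full_map(q, k, mask), chunked_map(q, k, mask), atol=1e-5))  # True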
@@ -1429,9 +1429,9 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
             is_extend=previous_frames is not None,
         ))
         # add cross-attention patch
-        model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "cross_attn")
+        model_patched.set_model_patch(MultiTalkCrossAttnPatch(model_patch, audio_scale), "attn2_patch")
         if token_ref_target_masks is not None:
-            model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "self_attn")
+            model_patched.set_model_patch(MultiTalkGetAttnMapPatch(token_ref_target_masks), "attn1_patch")
 
         out_latent = {}
         out_latent["samples"] = latent
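Note: since the lookup keys in WanSelfAttention and WanAttentionBlock changed in the same commit, anything still registered under the old "self_attn"/"cross_attn" names would no longer be picked up; this node's registrations are updated here to match.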