mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 06:40:16 +08:00
Update model_multitalk.py
This commit is contained in:
parent
af4d412e67
commit
64a984177f
@ -16,30 +16,30 @@ def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks, split_num=8):
|
|||||||
|
|
||||||
x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
|
x_ref_attnmap = torch.zeros(B, H, x_seqlens, device=visual_q.device, dtype=visual_q.dtype)
|
||||||
chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
|
chunk_size = min(max(x_seqlens // split_num, 1), x_seqlens)
|
||||||
|
|
||||||
for i in range(0, x_seqlens, chunk_size):
|
for i in range(0, x_seqlens, chunk_size):
|
||||||
end_i = min(i + chunk_size, x_seqlens)
|
end_i = min(i + chunk_size, x_seqlens)
|
||||||
|
|
||||||
attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens
|
attn_chunk = visual_q[:, :, i:end_i] @ ref_k.permute(0, 2, 3, 1) # B, H, chunk, ref_seqlens
|
||||||
|
|
||||||
# Apply softmax
|
# Apply softmax
|
||||||
attn_max = attn_chunk.max(dim=-1, keepdim=True).values
|
attn_max = attn_chunk.max(dim=-1, keepdim=True).values
|
||||||
attn_chunk = (attn_chunk - attn_max).exp()
|
attn_chunk = (attn_chunk - attn_max).exp()
|
||||||
attn_sum = attn_chunk.sum(dim=-1, keepdim=True)
|
attn_sum = attn_chunk.sum(dim=-1, keepdim=True)
|
||||||
attn_chunk = attn_chunk / (attn_sum + 1e-8)
|
attn_chunk = attn_chunk / (attn_sum + 1e-8)
|
||||||
|
|
||||||
# Apply mask and sum
|
# Apply mask and sum
|
||||||
masked_attn = attn_chunk * ref_target_mask
|
masked_attn = attn_chunk * ref_target_mask
|
||||||
x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8)
|
x_ref_attnmap[:, :, i:end_i] = masked_attn.sum(-1) / (ref_target_mask.sum() + 1e-8)
|
||||||
|
|
||||||
del attn_chunk, masked_attn
|
del attn_chunk, masked_attn
|
||||||
|
|
||||||
# Average across heads
|
# Average across heads
|
||||||
x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens
|
x_ref_attnmap = x_ref_attnmap.mean(dim=1) # B, x_seqlens
|
||||||
x_ref_attn_maps.append(x_ref_attnmap)
|
x_ref_attn_maps.append(x_ref_attnmap)
|
||||||
|
|
||||||
del visual_q, ref_k
|
del visual_q, ref_k
|
||||||
|
|
||||||
return torch.cat(x_ref_attn_maps, dim=0)
|
return torch.cat(x_ref_attn_maps, dim=0)
|
||||||
|
|
||||||
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
|
def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user