Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-28 23:30:16 +08:00)

Commit efe83f5a36 ("re-init"), parent 4ffea0e864
@@ -87,7 +87,7 @@ class WanSelfAttention(nn.Module):
         )

         x = self.o(x)
-        return x
+        return x, q, k


 class WanT2VCrossAttention(WanSelfAttention):
@@ -178,7 +178,8 @@ class WanAttentionBlock(nn.Module):
                  window_size=(-1, -1),
                  qk_norm=True,
                  cross_attn_norm=False,
-                 eps=1e-6, operation_settings={}):
+                 eps=1e-6, operation_settings={},
+                 block_idx=None):
         super().__init__()
         self.dim = dim
         self.ffn_dim = ffn_dim
@@ -187,6 +188,7 @@ class WanAttentionBlock(nn.Module):
         self.qk_norm = qk_norm
         self.cross_attn_norm = cross_attn_norm
         self.eps = eps
+        self.block_idx = block_idx

         # layers
         self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
@@ -225,6 +227,8 @@ class WanAttentionBlock(nn.Module):
         """
         # assert e.dtype == torch.float32

+        patches = transformer_options.get("patches", {})
+
         if e.ndim < 4:
             e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
         else:
@@ -232,7 +236,7 @@ class WanAttentionBlock(nn.Module):
         # assert e[0].dtype == torch.float32

         # self-attention
-        y = self.self_attn(
+        y, q, k = self.self_attn(
             torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
             freqs, transformer_options=transformer_options)

@@ -241,6 +245,11 @@ class WanAttentionBlock(nn.Module):

         # cross-attention & ffn
         x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
+
+        if "cross_attn" in patches:
+            for p in patches["cross_attn"]:
+                x = x + p({"x": x, "q": q, "k": k, "block_idx": self.block_idx, "transformer_options": transformer_options})
+
         y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
         x = torch.addcmul(x, y, repeat_e(e[5], x))
         return x
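Note: the hook added above treats every entry of transformer_options["patches"]["cross_attn"] as a callable that receives the block's hidden states along with its self-attention query/key and returns a residual that is added back to x. A minimal sketch of such a patch (the class name is hypothetical and not part of this commit):

    import torch

    class ExampleCrossAttnPatch:
        # invoked once per WanAttentionBlock, right after the regular cross-attention
        def __call__(self, kwargs):
            x = kwargs["x"]                  # hidden states, (batch, seq_len, dim)
            q, k = kwargs["q"], kwargs["k"]  # this block's self-attention query/key
            block_idx = kwargs["block_idx"]  # set through WanAttentionBlock(..., block_idx=i)
            # the return value is added to x; returning zeros leaves the block unchanged
            return torch.zeros_like(x)

MultiTalkCrossAttnPatch in the new file below follows exactly this shape.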
@@ -262,6 +271,7 @@ class VaceWanAttentionBlock(WanAttentionBlock):
                  ):
         super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
         self.block_id = block_id
+        self.block_idx = None
         if block_id == 0:
             self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
             self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
@@ -486,8 +496,8 @@ class WanModel(torch.nn.Module):
         cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         self.blocks = nn.ModuleList([
             wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
-                                 window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
-            for _ in range(num_layers)
+                                 window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings, block_idx=i)
+            for i in range(num_layers)
         ])

         # head
@@ -540,6 +550,7 @@ class WanModel(torch.nn.Module):
         # embeddings
         x = self.patch_embedding(x.float()).to(x.dtype)
         grid_sizes = x.shape[2:]
+        transformer_options["grid_sizes"] = grid_sizes
         x = x.flatten(2).transpose(1, 2)

         # time embeddings
@@ -722,6 +733,7 @@ class VaceWanModel(WanModel):
         # embeddings
         x = self.patch_embedding(x.float()).to(x.dtype)
         grid_sizes = x.shape[2:]
+        transformer_options["grid_sizes"] = grid_sizes
         x = x.flatten(2).transpose(1, 2)

         # time embeddings
comfy/ldm/wan/model_multitalk.py (new file, 593 lines):
import torch
from einops import rearrange, repeat
import math
import comfy
from comfy.ldm.modules.attention import optimized_attention
import latent_preview
import logging


def calculate_x_ref_attn_map(visual_q, ref_k, ref_target_masks):
    scale = 1.0 / visual_q.shape[-1] ** 0.5
    visual_q = visual_q.transpose(1, 2) * scale

    attn = visual_q @ ref_k.permute(0, 2, 3, 1).to(visual_q)

    x_ref_attn_map_source = attn.softmax(-1).to(visual_q.dtype) # B, H, x_seqlens, ref_seqlens
    del attn

    x_ref_attn_maps = []

    for class_idx, ref_target_mask in enumerate(ref_target_masks):
        ref_target_mask = ref_target_mask.view(1, 1, 1, *ref_target_mask.shape)
        x_ref_attnmap = x_ref_attn_map_source * ref_target_mask
        x_ref_attnmap = x_ref_attnmap.sum(-1) / ref_target_mask.sum() # B, H, x_seqlens, ref_seqlens --> B, H, x_seqlens
        x_ref_attnmap = x_ref_attnmap.transpose(1, 2) # B, x_seqlens, H
        x_ref_attnmap = x_ref_attnmap.mean(-1) # B, x_seqlens
        x_ref_attn_maps.append(x_ref_attnmap)

    del x_ref_attn_map_source

    return torch.cat(x_ref_attn_maps, dim=0)

def get_attn_map_with_target(visual_q, ref_k, shape, ref_target_masks=None, split_num=2):
    """Args:
        query (torch.tensor): B M H K
        key (torch.tensor): B M H K
        shape (tuple): (N_t, N_h, N_w)
        ref_target_masks: [B, N_h * N_w]
    """

    N_t, N_h, N_w = shape

    x_seqlens = N_h * N_w
    ref_k = ref_k[:, :x_seqlens]
    _, seq_lens, heads, _ = visual_q.shape
    class_num, _ = ref_target_masks.shape
    x_ref_attn_maps = torch.zeros(class_num, seq_lens).to(visual_q)

    split_chunk = heads // split_num

    for i in range(split_num):
        x_ref_attn_maps_perhead = calculate_x_ref_attn_map(
            visual_q[:, :, i*split_chunk:(i+1)*split_chunk, :],
            ref_k[:, :, i*split_chunk:(i+1)*split_chunk, :],
            ref_target_masks
        )
        x_ref_attn_maps += x_ref_attn_maps_perhead

    return x_ref_attn_maps / split_num


def normalize_and_scale(column, source_range, target_range, epsilon=1e-8):
    source_min, source_max = source_range
    new_min, new_max = target_range
    normalized = (column - source_min) / (source_max - source_min + epsilon)
    scaled = normalized * (new_max - new_min) + new_min
    return scaled


def rotate_half(x):
    x = rearrange(x, "... (d r) -> ... d r", r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, "... d r -> ... (d r)")


def get_audio_embeds(encoded_audio, audio_start, audio_end):
    audio_embs = []
    human_num = len(encoded_audio)
    audio_frames = encoded_audio[0].shape[0]

    indices = (torch.arange(4 + 1) - 2) * 1

    for human_idx in range(human_num):
        if audio_end > audio_frames: # in case of not enough audio for current window, pad with first audio frame as that's most likely silence
            pad_len = audio_end - audio_frames
            pad_shape = list(encoded_audio[human_idx].shape)
            pad_shape[0] = pad_len
            pad_tensor = encoded_audio[human_idx][:1].repeat(pad_len, *([1] * (encoded_audio[human_idx].dim() - 1)))
            encoded_audio_in = torch.cat([encoded_audio[human_idx], pad_tensor], dim=0)
        else:
            encoded_audio_in = encoded_audio[human_idx]
        center_indices = torch.arange(audio_start, audio_end, 1).unsqueeze(1) + indices.unsqueeze(0)
        center_indices = torch.clamp(center_indices, min=0, max=encoded_audio_in.shape[0] - 1)
        audio_emb = encoded_audio_in[center_indices].unsqueeze(0)
        audio_embs.append(audio_emb)

    return torch.cat(audio_embs, dim=0)


def project_audio_features(audio_proj, encoded_audio, audio_start, audio_end):
    audio_embs = get_audio_embeds(encoded_audio, audio_start, audio_end)

    first_frame_audio_emb_s = audio_embs[:, :1, ...]
    latter_frame_audio_emb = audio_embs[:, 1:, ...]
    latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=4)

    middle_index = audio_proj.seq_len // 2

    latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index+1, ...]
    latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
    latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
    latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
    latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index+1, ...]
    latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
    latter_frame_audio_emb_s = torch.cat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)

    audio_emb = audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
    audio_emb = torch.cat(audio_emb.split(1), dim=2)

    return audio_emb


class RotaryPositionalEmbedding1D(torch.nn.Module):
    def __init__(self,
                 head_dim,
                 ):
        super().__init__()
        self.head_dim = head_dim
        self.base = 10000

    def precompute_freqs_cis_1d(self, pos_indices):
        freqs = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2)[: (self.head_dim // 2)].float() / self.head_dim))
        freqs = freqs.to(pos_indices.device)
        freqs = torch.einsum("..., f -> ... f", pos_indices.float(), freqs)
        freqs = repeat(freqs, "... n -> ... (n r)", r=2)
        return freqs

    def forward(self, x, pos_indices):
        freqs_cis = self.precompute_freqs_cis_1d(pos_indices)

        x_ = x.float()

        freqs_cis = freqs_cis.float().to(x.device)
        cos, sin = freqs_cis.cos(), freqs_cis.sin()
        cos, sin = rearrange(cos, 'n d -> 1 1 n d'), rearrange(sin, 'n d -> 1 1 n d')
        x_ = (x_ * cos) + (rotate_half(x_) * sin)

        return x_.type_as(x)

class SingleStreamAttention(torch.nn.Module):
    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        device=None, dtype=None, operations=None
    ) -> None:
        super().__init__()
        self.dim = dim
        self.encoder_hidden_states_dim = encoder_hidden_states_dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.q_linear = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
        self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
        self.kv_linear = operations.Linear(encoder_hidden_states_dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)

    def forward(self, x: torch.Tensor, encoder_hidden_states: torch.Tensor, shape=None) -> torch.Tensor:
        N_t, N_h, N_w = shape

        expected_tokens = N_t * N_h * N_w
        actual_tokens = x.shape[1]
        x_extra = None

        if actual_tokens != expected_tokens:
            x_extra = x[:, -N_h * N_w:, :]
            x = x[:, :-N_h * N_w, :]
            N_t = N_t - 1

        B = x.shape[0]
        S = N_h * N_w
        x = x.view(B * N_t, S, self.dim)

        # get q for hidden_state
        q = self.q_linear(x).view(B * N_t, S, self.num_heads, self.head_dim)

        # get kv from encoder_hidden_states # shape: (B, N, num_heads, head_dim)
        kv = self.kv_linear(encoder_hidden_states)
        encoder_k, encoder_v = kv.view(B * N_t, encoder_hidden_states.shape[1], 2, self.num_heads, self.head_dim).unbind(2)

        #print("q.shape", q.shape) #torch.Size([21, 1024, 40, 128])
        x = optimized_attention(
            q.transpose(1, 2),
            encoder_k.transpose(1, 2),
            encoder_v.transpose(1, 2),
            heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)

        # linear transform
        x = self.proj(x.reshape(B * N_t, S, self.dim))
        x = x.view(B, N_t * S, self.dim)

        if x_extra is not None:
            x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)

        return x

class SingleStreamMultiAttention(SingleStreamAttention):
    def __init__(
        self,
        dim: int,
        encoder_hidden_states_dim: int,
        num_heads: int,
        qkv_bias: bool,
        class_range: int = 24,
        class_interval: int = 4,
        device=None, dtype=None, operations=None
    ) -> None:
        super().__init__(
            dim=dim,
            encoder_hidden_states_dim=encoder_hidden_states_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            device=device,
            dtype=dtype,
            operations=operations
        )

        # Rotary-embedding layout parameters
        self.class_interval = class_interval
        self.class_range = class_range
        self.max_humans = self.class_range // self.class_interval

        # Constant bucket used for background tokens
        self.rope_bak = int(self.class_range // 2)

        self.rope_1d = RotaryPositionalEmbedding1D(self.head_dim)

    def forward(
        self,
        x: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        shape=None,
        x_ref_attn_map=None
    ) -> torch.Tensor:
        encoder_hidden_states = encoder_hidden_states.squeeze(0).to(x.device)
        human_num = x_ref_attn_map.shape[0] if x_ref_attn_map is not None else 1
        # Single-speaker fall-through
        if human_num <= 1:
            return super().forward(x, encoder_hidden_states, shape)

        N_t, N_h, N_w = shape

        x_extra = None
        if x.shape[0] * N_t != encoder_hidden_states.shape[0]:
            x_extra = x[:, -N_h * N_w:, :]
            x = x[:, :-N_h * N_w, :]
            N_t = N_t - 1
        x = rearrange(x, "B (N_t S) C -> (B N_t) S C", N_t=N_t)

        # Query projection
        B, N, C = x.shape
        q = self.q_linear(x)
        q = q.view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Use `class_range` logic for 2 speakers
        rope_h1 = (0, self.class_interval)
        rope_h2 = (self.class_range - self.class_interval, self.class_range)
        rope_bak = int(self.class_range // 2)

        # Normalize and scale attention maps for each speaker
        max_values = x_ref_attn_map.max(1).values[:, None, None]
        min_values = x_ref_attn_map.min(1).values[:, None, None]
        max_min_values = torch.cat([max_values, min_values], dim=2)

        human1_max_value, human1_min_value = max_min_values[0, :, 0].max(), max_min_values[0, :, 1].min()
        human2_max_value, human2_min_value = max_min_values[1, :, 0].max(), max_min_values[1, :, 1].min()

        human1 = normalize_and_scale(x_ref_attn_map[0], (human1_min_value, human1_max_value), rope_h1)
        human2 = normalize_and_scale(x_ref_attn_map[1], (human2_min_value, human2_max_value), rope_h2)
        back = torch.full((x_ref_attn_map.size(1),), rope_bak, dtype=human1.dtype, device=human1.device)

        # Token-wise speaker dominance
        max_indices = x_ref_attn_map.argmax(dim=0)
        normalized_map = torch.stack([human1, human2, back], dim=1)
        normalized_pos = normalized_map[torch.arange(x_ref_attn_map.size(1)), max_indices]

        # Apply rotary to Q
        q = rearrange(q, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        q = self.rope_1d(q, normalized_pos)
        q = rearrange(q, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        # Keys / Values
        _, N_a, _ = encoder_hidden_states.shape
        encoder_kv = self.kv_linear(encoder_hidden_states)
        encoder_kv = encoder_kv.view(B, N_a, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        encoder_k, encoder_v = encoder_kv.unbind(0)

        # Rotary for keys – assign centre of each speaker bucket to its context tokens
        per_frame = torch.zeros(N_a, dtype=encoder_k.dtype, device=encoder_k.device)
        per_frame[: per_frame.size(0) // 2] = (rope_h1[0] + rope_h1[1]) / 2
        per_frame[per_frame.size(0) // 2 :] = (rope_h2[0] + rope_h2[1]) / 2
        encoder_pos = torch.cat([per_frame] * N_t, dim=0)

        encoder_k = rearrange(encoder_k, "(B N_t) H S C -> B H (N_t S) C", N_t=N_t)
        encoder_k = self.rope_1d(encoder_k, encoder_pos)
        encoder_k = rearrange(encoder_k, "B H (N_t S) C -> (B N_t) H S C", N_t=N_t)

        # Final attention
        q = rearrange(q, "B H M K -> B M H K")
        encoder_k = rearrange(encoder_k, "B H M K -> B M H K")
        encoder_v = rearrange(encoder_v, "B H M K -> B M H K")

        x = optimized_attention(
            q.transpose(1, 2),
            encoder_k.transpose(1, 2),
            encoder_v.transpose(1, 2),
            heads=self.num_heads, skip_reshape=True, skip_output_reshape=True).transpose(1, 2)

        # Linear projection
        x = x.reshape(B, N, C)
        x = self.proj(x)

        # Restore original layout
        x = rearrange(x, "(B N_t) S C -> B (N_t S) C", N_t=N_t)
        if x_extra is not None:
            x = torch.cat([x, torch.zeros_like(x_extra)], dim=1)

        return x


class MultiTalkAudioProjModel(torch.nn.Module):
    def __init__(
        self,
        seq_len: int = 5,
        seq_len_vf: int = 12,
        blocks: int = 12,
        channels: int = 768,
        intermediate_dim: int = 512,
        out_dim: int = 768,
        context_tokens: int = 32,
        device=None, dtype=None, operations=None
    ):
        super().__init__()

        self.seq_len = seq_len
        self.blocks = blocks
        self.channels = channels
        self.input_dim = seq_len * blocks * channels
        self.input_dim_vf = seq_len_vf * blocks * channels
        self.intermediate_dim = intermediate_dim
        self.context_tokens = context_tokens
        self.out_dim = out_dim

        # define multiple linear layers
        self.proj1 = operations.Linear(self.input_dim, intermediate_dim, device=device, dtype=dtype)
        self.proj1_vf = operations.Linear(self.input_dim_vf, intermediate_dim, device=device, dtype=dtype)
        self.proj2 = operations.Linear(intermediate_dim, intermediate_dim, device=device, dtype=dtype)
        self.proj3 = operations.Linear(intermediate_dim, context_tokens * out_dim, device=device, dtype=dtype)
        self.norm = operations.LayerNorm(out_dim, device=device, dtype=dtype)

    def forward(self, audio_embeds, audio_embeds_vf):
        video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
        B, _, _, S, C = audio_embeds.shape

        # process audio of first frame
        audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
        batch_size, window_size, blocks, channels = audio_embeds.shape
        audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)

        # process audio of latter frame
        audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
        batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
        audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)

        # first projection
        audio_embeds = torch.relu(self.proj1(audio_embeds))
        audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
        audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
        audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
        audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
        batch_size_c, N_t, C_a = audio_embeds_c.shape
        audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)

        # second projection
        audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))

        context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.out_dim)

        # normalization and reshape
        context_tokens = self.norm(context_tokens)
        context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)

        return context_tokens

class WanMultiTalkAttentionBlock(torch.nn.Module):
    def __init__(self, in_dim=5120, out_dim=768, device=None, dtype=None, operations=None):
        super().__init__()
        self.audio_cross_attn = SingleStreamMultiAttention(in_dim, out_dim, num_heads=40, qkv_bias=True, device=device, dtype=dtype, operations=operations)
        self.norm_x = operations.LayerNorm(in_dim, device=device, dtype=dtype, elementwise_affine=True)


class MultiTalkCrossAttnPatch:
    def __init__(self, model_patch, audio_scale=1.0, ref_target_masks=None):
        self.model_patch = model_patch
        self.audio_scale = audio_scale
        self.ref_target_masks = ref_target_masks

    def __call__(self, kwargs):
        x = kwargs["x"]
        block_idx = kwargs.get("block_idx", 0)
        if block_idx is None:
            return torch.zeros_like(x)

        transformer_options = kwargs.get("transformer_options", {})
        audio_embeds = transformer_options.get("audio_embeds")

        x_ref_attn_map = None
        if self.ref_target_masks is not None:
            x_ref_attn_map = get_attn_map_with_target(kwargs["q"], kwargs["k"], transformer_options["grid_sizes"], ref_target_masks=self.ref_target_masks.to(x.device))
        norm_x = self.model_patch.model.blocks[block_idx].norm_x(x)
        x_audio = self.model_patch.model.blocks[block_idx].audio_cross_attn(
            norm_x, audio_embeds.to(x.dtype),
            shape=transformer_options["grid_sizes"],
            x_ref_attn_map=x_ref_attn_map
        )
        return x_audio * self.audio_scale

    def models(self):
        return [self.model_patch]

class MultiTalkApplyModelWrapper:
    def __init__(self, init_latents):
        self.init_latents = init_latents

    def __call__(self, executor, x, *args, **kwargs):
        x[:, :, :self.init_latents.shape[2]] = self.init_latents.to(x)
        samples = executor(x, *args, **kwargs)
        return samples


class InfiniteTalkOuterSampleLoopingWrapper:
    def __init__(self, init_previous_frames, encoded_audio, model_patch, audio_scale, max_frames, frame_window_size, motion_frame_count=9, vae=None, ref_target_masks=None):
        self.init_previous_frames = init_previous_frames
        self.encoded_audio = encoded_audio
        self.total_audio_frames = encoded_audio[0].shape[0]
        self.max_frames = max_frames
        self.frame_window_size = frame_window_size
        self.latent_window_size = (frame_window_size - 1) // 4 + 1
        self.model_patch = model_patch
        self.audio_scale = audio_scale
        self.motion_frame_count = motion_frame_count
        self.vae = vae
        self.ref_target_masks = ref_target_masks

    def __call__(self, executor, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, **kwargs):
        # init variables
        previous_frames = motion_frames_latent = None
        init_from_cond = False
        frame_offset = audio_start = latent_frame_offset = latent_start_idx = 0
        audio_end = self.frame_window_size
        latent_end_idx = self.latent_window_size
        decoded_results = []

        model_patcher = executor.class_obj.model_patcher
        model_options = executor.class_obj.model_options
        process_latent_in = model_patcher.model.process_latent_in
        dtype = model_patcher.model_dtype()

        # when extending from previous frames
        if self.init_previous_frames is not None:
            previous_frames = self.init_previous_frames
            if previous_frames.shape[0] < self.motion_frame_count:
                previous_frames = torch.cat([previous_frames[:1].repeat(self.motion_frame_count - previous_frames.shape[0], 1, 1, 1), previous_frames], dim=0)
            motion_frames = previous_frames[-self.motion_frame_count:]
            frame_offset = previous_frames.shape[0] - self.motion_frame_count

        # add/replace current cross-attention patch to model options
        model_options["transformer_options"].setdefault("patches", {}).setdefault("cross_attn", []).append(
            MultiTalkCrossAttnPatch(self.model_patch, self.audio_scale, ref_target_masks=self.ref_target_masks)
        )

        frames_needed = math.ceil(min(self.max_frames, self.total_audio_frames) / 81) * 81
        estimated_iterations = frames_needed // (self.frame_window_size - self.motion_frame_count)
        total_steps = (sigmas.shape[-1] - 1) * estimated_iterations
        logging.info(f"InfiniteTalk estimated loop iterations: {estimated_iterations}, Total steps: {total_steps}")

        # custom previewer callback for full loop progress bar
        x0_output = {}
        previewer = latent_preview.get_previewer(model_patcher.load_device, model_patcher.model.latent_format)
        pbar = comfy.utils.ProgressBar(total_steps)
        def custom_callback(step, x0, x, total_steps):
            if x0_output is not None:
                x0_output["x0"] = x0

            preview_bytes = None
            if previewer:
                preview_bytes = previewer.decode_latent_to_preview_image("JPEG", x0)
            pbar.update(1)

        # outer loop start for multiple frame windows
        for i in range(estimated_iterations):

            # the first frame given to InfiniteTalk always has to be a noise-free encoded image
            # if no previous samples provided, try to get I2V cond latent from positive cond

            if previous_frames is None:
                concat_latent_image = executor.class_obj.conds["positive"][0].get("concat_latent_image", None)
                if concat_latent_image is not None:
                    motion_frames_latent = concat_latent_image[:, :, :1]
                    overlap = 1
                    init_from_cond = True
            # else, use previous samples' last frames as first frame
            else:
                audio_start = frame_offset
                audio_end = audio_start + self.frame_window_size
                latent_start_idx = latent_frame_offset
                latent_end_idx = latent_start_idx + self.latent_window_size

                if len(motion_frames.shape) == 5:
                    motion_frames = motion_frames.squeeze(0)
                spacial_compression = self.vae.spacial_compression_encode()
                if (motion_frames.shape[-3], motion_frames.shape[-2]) != (noise.shape[-2] * spacial_compression, noise.shape[-1] * spacial_compression):
                    motion_frames = comfy.utils.common_upscale(
                        motion_frames.movedim(-1, 1),
                        noise.shape[-1] * spacial_compression, noise.shape[-2] * spacial_compression,
                        "bilinear", "center")

                motion_frames_latent = self.vae.encode(motion_frames)
                overlap = motion_frames_latent.shape[2]

            audio_embed = project_audio_features(self.model_patch.model.audio_proj, self.encoded_audio, audio_start, audio_end).to(dtype)
            model_options["transformer_options"]["audio_embeds"] = audio_embed

            # the first latents of the model input always need to be replaced on every step
            if motion_frames_latent is not None:
                wrappers = model_options["transformer_options"]["wrappers"]
                w = wrappers.setdefault(comfy.patcher_extension.WrappersMP.APPLY_MODEL, {})
                w["MultiTalk_apply_model"] = [MultiTalkApplyModelWrapper(process_latent_in(motion_frames_latent))]

            # Slice possible encoded latent_image for vid2vid
            if latent_image is not None and torch.count_nonzero(latent_image) > 0:
                # Check if we have enough latents
                if latent_end_idx > latent_image.shape[2]:
                    # This window needs more frames - pad the latent_image at the end
                    pad_length = latent_end_idx - latent_image.shape[2]
                    last_frame = latent_image[:, :, -1:].repeat(1, 1, pad_length, 1, 1)
                    latent_image = torch.cat([latent_image, last_frame], dim=2)
                    new_noise_frames = torch.randn_like(latent_image[:, :, -pad_length:], device=noise.device, dtype=noise.dtype)
                    noise = torch.cat([noise, new_noise_frames], dim=2)
                noise = noise[:, :, latent_start_idx:latent_end_idx]
                latent_image = latent_image[:, :, latent_start_idx:latent_end_idx]
                if denoise_mask is not None: # todo: check if denoise mask needs adjustment for latent_image changes
                    print("Using denoise mask with shape", denoise_mask.shape)

            # run the sampling process
            result = executor(noise, latent_image, sampler, sigmas, denoise_mask=denoise_mask, callback=custom_callback, disable_pbar=False, seed=seed, **kwargs)

            # insert motion frames before decoding
            if previous_frames is not None and not init_from_cond:
                result = torch.cat([motion_frames_latent.to(result), result[:, :, overlap:]], dim=2)

            previous_frames = self.vae.decode(result)
            motion_frames = previous_frames[:, -self.motion_frame_count:]

            # Track frame progress
            new_frame_count = previous_frames.shape[1] - self.motion_frame_count
            frame_offset += new_frame_count

            motion_latent_count = (self.motion_frame_count - 1) // 4 + 1 if self.motion_frame_count > 0 else 0
            new_latent_count = self.latent_window_size - motion_latent_count

            latent_frame_offset += new_latent_count

            if init_from_cond:
                decoded_results.append(previous_frames)
                init_from_cond = False
            else:
                decoded_results.append(previous_frames[:, self.motion_frame_count:])

        return torch.cat(decoded_results, dim=1)

    def to(self, device_or_dtype):
        if isinstance(device_or_dtype, torch.device):
            if self.init_previous_frames is not None:
                self.init_previous_frames = self.init_previous_frames.to(device_or_dtype)
            if self.encoded_audio is not None:
                self.encoded_audio = [ea.to(device_or_dtype) for ea in self.encoded_audio]
            if self.ref_target_masks is not None:
                self.ref_target_masks = self.ref_target_masks.to(device_or_dtype)
        return self
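Note: a worked example of the window bookkeeping in InfiniteTalkOuterSampleLoopingWrapper, using the defaults wired up elsewhere in this commit (frame_window_size=81, motion_frame_count=9); the requested length and audio duration below are hypothetical:

    import math

    frame_window_size, motion_frame_count = 81, 9
    latent_window_size = (frame_window_size - 1) // 4 + 1        # 21 latent frames per window
    motion_latent_count = (motion_frame_count - 1) // 4 + 1      # 3 latents carried over as motion context
    new_latent_count = latent_window_size - motion_latent_count  # 18 fresh latents per loop iteration
    max_frames, total_audio_frames = 216, 300                    # hypothetical request / audio length
    frames_needed = math.ceil(min(max_frames, total_audio_frames) / 81) * 81           # 243
    estimated_iterations = frames_needed // (frame_window_size - motion_frame_count)   # 3 windows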
@@ -735,7 +735,7 @@ class AnyType(ComfyTypeIO):
     Type = Any

 @comfytype(io_type="MODEL_PATCH")
-class MODEL_PATCH(ComfyTypeIO):
+class ModelPatch(ComfyTypeIO):
     Type = Any

 @comfytype(io_type="AUDIO_ENCODER")
@@ -1603,6 +1603,7 @@ class _IO:
     ControlNet = ControlNet
     Vae = Vae
     Model = Model
+    ModelPatch = ModelPatch
     ClipVision = ClipVision
     ClipVisionOutput = ClipVisionOutput
     AudioEncoder = AudioEncoder
@@ -844,6 +844,45 @@ class SamplerCustomAdvanced:
             out_denoised = out
         return (out, out_denoised)

+
+class LoopingSamplerCustomAdvanced:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required":
+                    {
+                    "noise": ("NOISE", ),
+                    "guider": ("GUIDER", ),
+                    "sampler": ("SAMPLER", ),
+                    "sigmas": ("SIGMAS", ),
+                    "latent_image": ("LATENT", ),
+                    }
+                }
+
+    RETURN_TYPES = ("IMAGE",)
+    RETURN_NAMES = ("output",)
+
+    FUNCTION = "sample"
+
+    CATEGORY = "sampling/custom_sampling"
+    DESCRIPTION = "SamplerCustomAdvanced for models that already decode latents in a loop during generation, such as InfiniteTalk"
+
+    def sample(self, noise, guider, sampler, sigmas, latent_image):
+        latent = latent_image
+        latent_image = latent["samples"]
+        latent = latent.copy()
+        latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
+        latent["samples"] = latent_image
+
+        noise_mask = None
+        if "noise_mask" in latent:
+            noise_mask = latent["noise_mask"]
+
+        samples = guider.sample(noise.generate_noise(latent), latent_image, sampler, sigmas, denoise_mask=noise_mask, callback=None, disable_pbar=False, seed=noise.seed)
+        result = samples.to(comfy.model_management.intermediate_device())
+
+        return (result[0].cpu().float(), )
+
+
 class AddNoise:
     @classmethod
     def INPUT_TYPES(s):
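Note on the IMAGE return type above: with the InfiniteTalk OUTER_SAMPLE wrapper installed on the model, guider.sample() already returns VAE-decoded frames (the wrapper concatenates each window's decoded output along the time dimension), so the node only unbatches them. A rough stand-in for the shapes involved, assuming a 480x832, 81-frame result:

    import torch

    # what guider.sample() effectively hands back once the looping wrapper is active
    frames = torch.zeros(1, 81, 480, 832, 3)   # (batch, frames, height, width, channels)
    image_output = frames[0].cpu().float()     # (81, 480, 832, 3), ComfyUI IMAGE format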
@@ -925,6 +964,7 @@ NODE_CLASS_MAPPINGS = {
     "DisableNoise": DisableNoise,
     "AddNoise": AddNoise,
     "SamplerCustomAdvanced": SamplerCustomAdvanced,
+    "LoopingSamplerCustomAdvanced": LoopingSamplerCustomAdvanced,
 }

 NODE_DISPLAY_NAME_MAPPINGS = {
@@ -211,6 +211,14 @@ class ModelPatchLoader:
         elif 'feature_embedder.mid_layer_norm.bias' in sd:
             sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
             model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+        elif "audio_proj.proj1.weight" in sd:
+            model = MultiTalkModelPatch(
+                audio_window=5, context_tokens=32, vae_scale=4,
+                in_dim=sd["blocks.0.audio_cross_attn.proj.weight"].shape[0],
+                intermediate_dim=sd["audio_proj.proj1.weight"].shape[0],
+                out_dim=sd["audio_proj.norm.weight"].shape[0],
+                device=comfy.model_management.unet_offload_device(),
+                operations=comfy.ops.manual_cast)

         model.load_state_dict(sd)
         model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@@ -336,6 +344,40 @@ class USOStyleReference:
         return (model_patched,)


+from comfy.ldm.wan.model_multitalk import WanMultiTalkAttentionBlock, MultiTalkAudioProjModel
+
+class MultiTalkModelPatch(torch.nn.Module):
+    def __init__(
+        self,
+        audio_window: int = 5,
+        intermediate_dim: int = 512,
+        in_dim: int = 5120,
+        out_dim: int = 768,
+        context_tokens: int = 32,
+        vae_scale: int = 4,
+        num_layers: int = 40,
+
+        device=None, dtype=None, operations=None
+    ):
+        super().__init__()
+        self.audio_proj = MultiTalkAudioProjModel(
+            seq_len=audio_window,
+            seq_len_vf=audio_window+vae_scale-1,
+            intermediate_dim=intermediate_dim,
+            out_dim=out_dim,
+            context_tokens=context_tokens,
+            device=device,
+            dtype=dtype,
+            operations=operations
+        )
+        self.blocks = torch.nn.ModuleList(
+            [
+                WanMultiTalkAttentionBlock(in_dim, out_dim, device=device, dtype=dtype, operations=operations)
+                for _ in range(num_layers)
+            ]
+        )
+
+
 NODE_CLASS_MAPPINGS = {
     "ModelPatchLoader": ModelPatchLoader,
     "QwenImageDiffsynthControlnet": QwenImageDiffsynthControlnet,
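Note: a quick sanity check of the projector sizes MultiTalkModelPatch builds from its defaults (audio_window=5, vae_scale=4, and MultiTalkAudioProjModel's blocks=12, channels=768); this is a worked example, not code from the commit:

    audio_window, vae_scale, blocks, channels = 5, 4, 12, 768
    seq_len = audio_window                         # 5
    seq_len_vf = audio_window + vae_scale - 1      # 8
    input_dim = seq_len * blocks * channels        # 46080, in_features of audio_proj.proj1
    input_dim_vf = seq_len_vf * blocks * channels  # 73728, in_features of audio_proj.proj1_vf
    print(input_dim, input_dim_vf)                 # 46080 73728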
@@ -1288,6 +1288,143 @@ class Wan22ImageToVideoLatent(io.ComfyNode):
         return io.NodeOutput(out_latent)


+from comfy.ldm.wan.model_multitalk import InfiniteTalkOuterSampleLoopingWrapper
+class WanInfiniteTalkToVideo(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="WanInfiniteTalkToVideo",
+            category="conditioning/video_models",
+            inputs=[
+                io.Model.Input("model"),
+                io.ModelPatch.Input("model_patch"),
+                io.Conditioning.Input("positive"),
+                io.Conditioning.Input("negative"),
+                io.Vae.Input("vae"),
+                io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
+                io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
+                io.ClipVisionOutput.Input("clip_vision_output", optional=True),
+                io.Image.Input("start_image", optional=True),
+                io.AudioEncoderOutput.Input("audio_encoder_output_1"),
+                io.AudioEncoderOutput.Input("audio_encoder_output_2", optional=True),
+                io.Mask.Input("mask_1", optional=True, tooltip="Mask for the first speaker, required if using two audio inputs."),
+                io.Mask.Input("mask_2", optional=True, tooltip="Mask for the second speaker, required if using two audio inputs."),
+                io.Int.Input("frame_window_size", default=81, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of frames to generate in one window."),
+                io.Int.Input("motion_frame_count", default=9, min=1, max=33, step=1, tooltip="Number of previous frames to use as motion context."),
+                io.Float.Input("audio_scale", default=1.0, min=-10.0, max=10.0, step=0.01),
+                io.Image.Input("previous_frames", optional=True),
+            ],
+            outputs=[
+                io.Model.Output(display_name="model"),
+                io.Conditioning.Output(display_name="positive"),
+                io.Conditioning.Output(display_name="negative"),
+                io.Latent.Output(display_name="latent"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, model_patch, positive, negative, vae, width, height, length, audio_encoder_output_1, motion_frame_count, frame_window_size,
+                start_image=None, previous_frames=None, audio_scale=None, clip_vision_output=None, audio_encoder_output_2=None, mask_1=None, mask_2=None) -> io.NodeOutput:
+
+        if frame_window_size > length:
+            frame_window_size = length
+        if audio_encoder_output_2 is not None:
+            if mask_1 is None or mask_2 is None:
+                raise ValueError("Masks must be provided if two audio encoder outputs are used.")
+
+        ref_masks = None
+        if mask_1 is not None and mask_2 is not None:
+            if audio_encoder_output_2 is None:
+                raise ValueError("Second audio encoder output must be provided if two masks are used.")
+            ref_masks = torch.cat([mask_1, mask_2])
+
+        latent = torch.zeros([1, 16, ((frame_window_size - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+        if start_image is not None:
+            start_image = comfy.utils.common_upscale(start_image[:frame_window_size].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+            image = torch.ones((frame_window_size, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
+            image[:start_image.shape[0]] = start_image
+
+            concat_latent_image = vae.encode(image[:, :, :, :3])
+            concat_mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
+            concat_mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
+
+            positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
+            negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": concat_mask})
+
+        if clip_vision_output is not None:
+            positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
+            negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
+
+        model_patched = model.clone()
+
+        encoded_audio_list = []
+        seq_lengths = []
+
+        for audio_encoder_output in [audio_encoder_output_1, audio_encoder_output_2]:
+            if audio_encoder_output is None:
+                continue
+            all_layers = audio_encoder_output["encoded_audio_all_layers"]
+            encoded_audio = torch.stack(all_layers, dim=0).squeeze(1)[1:] # shape: [num_layers, T, 512]
+            encoded_audio = linear_interpolation(encoded_audio, input_fps=50, output_fps=25).movedim(0, 1) # shape: [T, num_layers, 512]
+            encoded_audio_list.append(encoded_audio)
+            seq_lengths.append(encoded_audio.shape[0])
+
+        # Pad / combine depending on multi_audio_type
+        multi_audio_type = "add"
+        if len(encoded_audio_list) > 1:
+            if multi_audio_type == "para":
+                max_len = max(seq_lengths)
+                padded = []
+                for emb in encoded_audio_list:
+                    if emb.shape[0] < max_len:
+                        pad = torch.zeros(max_len - emb.shape[0], *emb.shape[1:], dtype=emb.dtype)
+                        emb = torch.cat([emb, pad], dim=0)
+                    padded.append(emb)
+                encoded_audio_list = padded
+            elif multi_audio_type == "add":
+                total_len = sum(seq_lengths)
+                full_list = []
+                offset = 0
+                for emb, seq_len in zip(encoded_audio_list, seq_lengths):
+                    full = torch.zeros(total_len, *emb.shape[1:], dtype=emb.dtype)
+                    full[offset:offset+seq_len] = emb
+                    full_list.append(full)
+                    offset += seq_len
+                encoded_audio_list = full_list
+
+        token_ref_target_masks = None
+        if ref_masks is not None:
+            token_ref_target_masks = torch.nn.functional.interpolate(
+                ref_masks.unsqueeze(0), size=(latent.shape[-2] // 2, latent.shape[-1] // 2), mode='nearest')[0]
+            token_ref_target_masks = (token_ref_target_masks > 0).view(token_ref_target_masks.shape[0], -1)
+
+        init_previous_frames = None
+        if previous_frames is not None:
+            init_previous_frames = previous_frames[:, :, :, :3]
+
+        model_patched.add_wrapper_with_key(
+            comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
+            "infinite_talk_outer_sample",
+            InfiniteTalkOuterSampleLoopingWrapper(
+                init_previous_frames,
+                encoded_audio_list,
+                model_patch,
+                audio_scale,
+                length,
+                frame_window_size,
+                motion_frame_count,
+                vae=vae,
+                ref_target_masks=token_ref_target_masks)
+        )
+
+        out_latent = {}
+        out_latent["samples"] = latent
+        return io.NodeOutput(model_patched, positive, negative, out_latent)
+
+
 class WanExtension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[io.ComfyNode]]:
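Note: a worked example of the empty latent WanInfiniteTalkToVideo allocates with its default inputs (width=832, height=480, frame_window_size=81); not code from the commit:

    latent_shape = [1, 16, (81 - 1) // 4 + 1, 480 // 8, 832 // 8]
    print(latent_shape)   # [1, 16, 21, 60, 104]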
@@ -1307,6 +1444,7 @@ class WanExtension(ComfyExtension):
             WanHuMoImageToVideo,
             WanAnimateToVideo,
             Wan22ImageToVideoLatent,
+            WanInfiniteTalkToVideo,
         ]

 async def comfy_entrypoint() -> WanExtension: