Merge branch 'master' into yousef-higgsv2

Yousef R. Gamaleldin 2025-09-18 01:29:53 +03:00 committed by GitHub
commit acdb10a092
10 changed files with 644 additions and 60 deletions


@@ -1,4 +1,5 @@
from .wav2vec2 import Wav2Vec2Model
from .whisper import WhisperLargeV3
import comfy.model_management
import comfy.ops
import comfy.utils
@@ -11,13 +12,18 @@ class AudioEncoderModel():
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
model_type = config.pop("model_type")
model_config = dict(config)
model_config.update({
"dtype": self.dtype,
"device": offload_device,
"operations": comfy.ops.manual_cast
})
self.model = Wav2Vec2Model(**model_config)
if model_type == "wav2vec2":
self.model = Wav2Vec2Model(**model_config)
elif model_type == "whisper3":
self.model = WhisperLargeV3(**model_config)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
@@ -35,38 +41,51 @@ class AudioEncoderModel():
outputs = {}
outputs["encoded_audio"] = out
outputs["encoded_audio_all_layers"] = all_layers
outputs["audio_samples"] = audio.shape[2]
return outputs
def load_audio_encoder_from_sd(sd, prefix=""):
sd = comfy.utils.state_dict_prefix_replace(sd, {"wav2vec2.": ""})
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
if embed_dim == 1024:# large
config = {
"embed_dim": 1024,
"num_heads": 16,
"num_layers": 24,
"conv_norm": True,
"conv_bias": True,
"do_normalize": True,
"do_stable_layer_norm": True
if "encoder.layer_norm.bias" in sd: #wav2vec2
embed_dim = sd["encoder.layer_norm.bias"].shape[0]
if embed_dim == 1024:# large
config = {
"model_type": "wav2vec2",
"embed_dim": 1024,
"num_heads": 16,
"num_layers": 24,
"conv_norm": True,
"conv_bias": True,
"do_normalize": True,
"do_stable_layer_norm": True
}
elif embed_dim == 768: # base
config = {
"model_type": "wav2vec2",
"embed_dim": 768,
"num_heads": 12,
"num_layers": 12,
"conv_norm": False,
"conv_bias": False,
"do_normalize": False, # chinese-wav2vec2-base has this False
"do_stable_layer_norm": False
}
elif embed_dim == 768: # base
else:
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
elif "model.encoder.embed_positions.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"model.": ""})
config = {
"embed_dim": 768,
"num_heads": 12,
"num_layers": 12,
"conv_norm": False,
"conv_bias": False,
"do_normalize": False, # chinese-wav2vec2-base has this False
"do_stable_layer_norm": False
"model_type": "whisper3",
}
else:
raise RuntimeError("ERROR: audio encoder file is invalid or unsupported embed_dim: {}".format(embed_dim))
raise RuntimeError("ERROR: audio encoder not supported.")
audio_encoder = AudioEncoderModel(config)
m, u = audio_encoder.load_sd(sd)
if len(m) > 0:
logging.warning("missing audio encoder: {}".format(m))
if len(u) > 0:
logging.warning("unexpected audio encoder: {}".format(u))
return audio_encoder
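For orientation, a minimal loading sketch of the detection logic above. This is not part of the commit: the module path comfy.audio_encoders.audio_encoders and the checkpoint filename are assumptions, and comfy.utils.load_torch_file is simply the loader used elsewhere in the codebase.

import comfy.utils
from comfy.audio_encoders.audio_encoders import load_audio_encoder_from_sd  # assumed module path

sd = comfy.utils.load_torch_file("whisper_large_v3.safetensors")  # hypothetical checkpoint file
# A "model.encoder.embed_positions.weight" key routes to WhisperLargeV3 ("whisper3"),
# while "encoder.layer_norm.bias" keeps the existing wav2vec2 path.
audio_encoder = load_audio_encoder_from_sd(sd)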

comfy/audio_encoders/whisper.py (new executable file, 186 lines)

@@ -0,0 +1,186 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from typing import Optional
from comfy.ldm.modules.attention import optimized_attention_masked
import comfy.ops
class WhisperFeatureExtractor(nn.Module):
def __init__(self, n_mels=128, device=None):
super().__init__()
self.sample_rate = 16000
self.n_fft = 400
self.hop_length = 160
self.n_mels = n_mels
self.chunk_length = 30
self.n_samples = 480000
self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
sample_rate=self.sample_rate,
n_fft=self.n_fft,
hop_length=self.hop_length,
n_mels=self.n_mels,
f_min=0,
f_max=8000,
norm="slaney",
mel_scale="slaney",
).to(device)
def __call__(self, audio):
audio = torch.mean(audio, dim=1)
batch_size = audio.shape[0]
processed_audio = []
for i in range(batch_size):
aud = audio[i]
if aud.shape[0] > self.n_samples:
aud = aud[:self.n_samples]
elif aud.shape[0] < self.n_samples:
aud = F.pad(aud, (0, self.n_samples - aud.shape[0]))
processed_audio.append(aud)
audio = torch.stack(processed_audio)
mel_spec = self.mel_spectrogram(audio.to(self.mel_spectrogram.spectrogram.window.device))[:, :, :-1].to(audio.device)
log_mel_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_mel_spec = torch.maximum(log_mel_spec, log_mel_spec.max() - 8.0)
log_mel_spec = (log_mel_spec + 4.0) / 4.0
return log_mel_spec
class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, n_heads: int, dtype=None, device=None, operations=None):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.q_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.k_proj = operations.Linear(d_model, d_model, bias=False, dtype=dtype, device=device)
self.v_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, seq_len, _ = query.shape
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
attn_output = optimized_attention_masked(q, k, v, self.n_heads, mask)
attn_output = self.out_proj(attn_output)
return attn_output
class EncoderLayer(nn.Module):
def __init__(self, d_model: int, n_heads: int, d_ff: int, dtype=None, device=None, operations=None):
super().__init__()
self.self_attn = MultiHeadAttention(d_model, n_heads, dtype=dtype, device=device, operations=operations)
self.self_attn_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
self.fc1 = operations.Linear(d_model, d_ff, dtype=dtype, device=device)
self.fc2 = operations.Linear(d_ff, d_model, dtype=dtype, device=device)
self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
residual = x
x = self.self_attn_layer_norm(x)
x = self.self_attn(x, x, x, attention_mask)
x = residual + x
residual = x
x = self.final_layer_norm(x)
x = self.fc1(x)
x = F.gelu(x)
x = self.fc2(x)
x = residual + x
return x
class AudioEncoder(nn.Module):
def __init__(
self,
n_mels: int = 128,
n_ctx: int = 1500,
n_state: int = 1280,
n_head: int = 20,
n_layer: int = 32,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.conv1 = operations.Conv1d(n_mels, n_state, kernel_size=3, padding=1, dtype=dtype, device=device)
self.conv2 = operations.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1, dtype=dtype, device=device)
self.embed_positions = operations.Embedding(n_ctx, n_state, dtype=dtype, device=device)
self.layers = nn.ModuleList([
EncoderLayer(n_state, n_head, n_state * 4, dtype=dtype, device=device, operations=operations)
for _ in range(n_layer)
])
self.layer_norm = operations.LayerNorm(n_state, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.transpose(1, 2)
x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)
all_x = ()
for layer in self.layers:
all_x += (x,)
x = layer(x)
x = self.layer_norm(x)
all_x += (x,)
return x, all_x
class WhisperLargeV3(nn.Module):
def __init__(
self,
n_mels: int = 128,
n_audio_ctx: int = 1500,
n_audio_state: int = 1280,
n_audio_head: int = 20,
n_audio_layer: int = 32,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.feature_extractor = WhisperFeatureExtractor(n_mels=n_mels, device=device)
self.encoder = AudioEncoder(
n_mels, n_audio_ctx, n_audio_state, n_audio_head, n_audio_layer,
dtype=dtype, device=device, operations=operations
)
def forward(self, audio):
mel = self.feature_extractor(audio)
x, all_x = self.encoder(mel)
return x, all_x
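A rough shape check for the encoder above; a sketch, not part of the commit. It assumes the module path comfy/audio_encoders/whisper.py shown in this diff and comfy.ops.manual_cast as the operations backend, and it only exercises shapes, since the weights are never initialized here.

import torch
import comfy.ops
from comfy.audio_encoders.whisper import WhisperLargeV3

model = WhisperLargeV3(dtype=torch.float32, device="cpu", operations=comfy.ops.manual_cast)
audio = torch.randn(1, 2, 16000 * 5)  # 5 seconds of stereo audio at 16 kHz
x, all_x = model(audio)
print(x.shape)     # torch.Size([1, 1500, 1280]): audio is padded to 30 s (3000 mel frames), conv2 halves the length
print(len(all_x))  # 33: the input to each of the 32 encoder layers plus the final normalized output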


@@ -86,24 +86,24 @@ class BatchedBrownianTree:
"""A wrapper around torchsde.BrownianTree that enables batches of entropy."""
def __init__(self, x, t0, t1, seed=None, **kwargs):
self.cpu_tree = True
if "cpu" in kwargs:
self.cpu_tree = kwargs.pop("cpu")
self.cpu_tree = kwargs.pop("cpu", True)
t0, t1, self.sign = self.sort(t0, t1)
w0 = kwargs.get('w0', torch.zeros_like(x))
w0 = kwargs.pop('w0', None)
if w0 is None:
w0 = torch.zeros_like(x)
self.batched = False
if seed is None:
seed = torch.randint(0, 2 ** 63 - 1, []).item()
self.batched = True
try:
assert len(seed) == x.shape[0]
seed = (torch.randint(0, 2 ** 63 - 1, ()).item(),)
elif isinstance(seed, (tuple, list)):
if len(seed) != x.shape[0]:
raise ValueError("Passing a list or tuple of seeds to BatchedBrownianTree requires a length matching the batch size.")
self.batched = True
w0 = w0[0]
except TypeError:
seed = [seed]
self.batched = False
if self.cpu_tree:
self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
else:
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
seed = (seed,)
if self.cpu_tree:
t0, w0, t1 = t0.detach().cpu(), w0.detach().cpu(), t1.detach().cpu()
self.trees = tuple(torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed)
@staticmethod
def sort(a, b):
@@ -111,11 +111,10 @@ class BatchedBrownianTree:
def __call__(self, t0, t1):
t0, t1, sign = self.sort(t0, t1)
device, dtype = t0.device, t0.dtype
if self.cpu_tree:
w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
else:
w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
t0, t1 = t0.detach().cpu().float(), t1.detach().cpu().float()
w = torch.stack([tree(t0, t1) for tree in self.trees]).to(device=device, dtype=dtype) * (self.sign * sign)
return w if self.batched else w[0]
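A hypothetical usage sketch of the reworked seed handling; the module path comfy.k_diffusion.sampling is an assumption, since the file name is not shown in this diff. A list or tuple of seeds, one per batch element, produces independent Brownian samples per latent, and a length mismatch now raises ValueError instead of silently falling back.

import torch
from comfy.k_diffusion.sampling import BatchedBrownianTree  # assumed module path

x = torch.zeros(2, 4, 8, 8)                         # a batch of two latents
t0, t1 = torch.tensor(1.0), torch.tensor(0.0)
tree = BatchedBrownianTree(x, t0, t1, seed=[1, 2])  # one entropy seed per batch element
w = tree(t0, t1)                                    # per-sample increments, shape [2, 4, 8, 8]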


@@ -35,11 +35,10 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=torch.float32, device=pos.device)
def apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
x_out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
return x_out.reshape(*x.shape).type_as(x)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)


@@ -600,7 +600,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
mask = mask.unsqueeze(1)
try:
assert mask is None
if mask is not None:
raise RuntimeError("Mask must not be set for Flash attention")
out = flash_attn_wrapper(
q.transpose(1, 2),
k.transpose(1, 2),


@@ -8,7 +8,7 @@ from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.flux.math import apply_rope1
import comfy.ldm.common_dit
import comfy.model_management
import comfy.patcher_extension
@@ -34,7 +34,9 @@ class WanSelfAttention(nn.Module):
num_heads,
window_size=(-1, -1),
qk_norm=True,
eps=1e-6, operation_settings={}):
eps=1e-6,
kv_dim=None,
operation_settings={}):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
@@ -43,11 +45,13 @@ class WanSelfAttention(nn.Module):
self.window_size = window_size
self.qk_norm = qk_norm
self.eps = eps
if kv_dim is None:
kv_dim = dim
# layers
self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.k = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.v = operation_settings.get("operations").Linear(kv_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
@@ -60,20 +64,24 @@ class WanSelfAttention(nn.Module):
"""
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
# query, key, value function
def qkv_fn(x):
def qkv_fn_q(x):
q = self.norm_q(self.q(x)).view(b, s, n, d)
k = self.norm_k(self.k(x)).view(b, s, n, d)
v = self.v(x).view(b, s, n * d)
return q, k, v
return apply_rope1(q, freqs)
q, k, v = qkv_fn(x)
q, k = apply_rope(q, k, freqs)
def qkv_fn_k(x):
k = self.norm_k(self.k(x)).view(b, s, n, d)
return apply_rope1(k, freqs)
# q and k are VRAM hogs, so we do all of the q computation first and let
# PyTorch garbage collect the q intermediates when the helper returns,
# before we touch k.
q = qkv_fn_q(x)
k = qkv_fn_k(x)
x = optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
v,
self.v(x).view(b, s, n * d),
heads=self.num_heads,
transformer_options=transformer_options,
)
@@ -398,6 +406,7 @@ class WanModel(torch.nn.Module):
eps=1e-6,
flf_pos_embed_token_number=None,
in_dim_ref_conv=None,
wan_attn_block_class=WanAttentionBlock,
image_model=None,
device=None,
dtype=None,
@@ -475,8 +484,8 @@ class WanModel(torch.nn.Module):
# blocks
cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
self.blocks = nn.ModuleList([
WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
wan_attn_block_class(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
for _ in range(num_layers)
])
@@ -1321,3 +1330,247 @@ class WanModel_S2V(WanModel):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
class WanT2VCrossAttentionGather(WanSelfAttention):
def forward(self, x, context, transformer_options={}, **kwargs):
r"""
Args:
x(Tensor): Shape [B, L1, C] - video tokens
context(Tensor): Shape [B, L2, C] - audio tokens with shape [B, frames*16, 1536]
"""
b, n, d = x.size(0), self.num_heads, self.head_dim
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(context))
v = self.v(context)
# Handle audio temporal structure (16 tokens per frame)
k = k.reshape(-1, 16, n, d).transpose(1, 2)
v = v.reshape(-1, 16, n, d).transpose(1, 2)
# Handle video spatial structure
q = q.reshape(k.shape[0], -1, n, d).transpose(1, 2)
x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
x = x.transpose(1, 2).view(b, -1, n, d).flatten(2)
x = self.o(x)
return x
class AudioCrossAttentionWrapper(nn.Module):
def __init__(self, dim, kv_dim, num_heads, qk_norm=True, eps=1e-6, operation_settings={}):
super().__init__()
self.audio_cross_attn = WanT2VCrossAttentionGather(dim, num_heads, qk_norm=qk_norm, eps=eps, kv_dim=kv_dim, operation_settings=operation_settings)
self.norm1_audio = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x, audio, transformer_options={}):
x = x + self.audio_cross_attn(self.norm1_audio(x), audio, transformer_options=transformer_options)
return x
class WanAttentionBlockAudio(WanAttentionBlock):
def __init__(self,
cross_attn_type,
dim,
ffn_dim,
num_heads,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=False,
eps=1e-6, operation_settings={}):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings)
self.audio_cross_attn_wrapper = AudioCrossAttentionWrapper(dim, 1536, num_heads, qk_norm, eps, operation_settings=operation_settings)
def forward(
self,
x,
e,
freqs,
context,
context_img_len=257,
audio=None,
transformer_options={},
):
r"""
Args:
x(Tensor): Shape [B, L, C]
e(Tensor): Shape [B, 6, C]
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
"""
# assert e.dtype == torch.float32
if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
# assert e[0].dtype == torch.float32
# self-attention
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
# cross-attention & ffn
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
if audio is not None:
x = self.audio_cross_attn_wrapper(x, audio, transformer_options=transformer_options)
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
class DummyAdapterLayer(nn.Module):
def __init__(self, layer):
super().__init__()
self.layer = layer
def forward(self, *args, **kwargs):
return self.layer(*args, **kwargs)
class AudioProjModel(nn.Module):
def __init__(
self,
seq_len=5,
blocks=13, # add a new parameter blocks
channels=768, # add a new parameter channels
intermediate_dim=512,
output_dim=1536,
context_tokens=16,
device=None,
dtype=None,
operations=None,
):
super().__init__()
self.seq_len = seq_len
self.blocks = blocks
self.channels = channels
self.input_dim = seq_len * blocks * channels # the input is a flattened window of seq_len frames x blocks x channels
self.intermediate_dim = intermediate_dim
self.context_tokens = context_tokens
self.output_dim = output_dim
# define multiple linear layers
self.audio_proj_glob_1 = DummyAdapterLayer(operations.Linear(self.input_dim, intermediate_dim, dtype=dtype, device=device))
self.audio_proj_glob_2 = DummyAdapterLayer(operations.Linear(intermediate_dim, intermediate_dim, dtype=dtype, device=device))
self.audio_proj_glob_3 = DummyAdapterLayer(operations.Linear(intermediate_dim, context_tokens * output_dim, dtype=dtype, device=device))
self.audio_proj_glob_norm = DummyAdapterLayer(operations.LayerNorm(output_dim, dtype=dtype, device=device))
def forward(self, audio_embeds):
video_length = audio_embeds.shape[1]
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
batch_size, window_size, blocks, channels = audio_embeds.shape
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
audio_embeds = torch.relu(self.audio_proj_glob_1(audio_embeds))
audio_embeds = torch.relu(self.audio_proj_glob_2(audio_embeds))
context_tokens = self.audio_proj_glob_3(audio_embeds).reshape(batch_size, self.context_tokens, self.output_dim)
context_tokens = self.audio_proj_glob_norm(context_tokens)
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
return context_tokens
class HumoWanModel(WanModel):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
def __init__(self,
model_type='humo',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
image_model=None,
audio_token_num=16,
device=None,
dtype=None,
operations=None,
):
super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, wan_attn_block_class=WanAttentionBlockAudio, image_model=image_model, device=device, dtype=dtype, operations=operations)
self.audio_proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, intermediate_dim=512, output_dim=1536, context_tokens=audio_token_num, dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
x,
t,
context,
freqs=None,
audio_embed=None,
reference_latent=None,
transformer_options={},
**kwargs,
):
bs, _, time, height, width = x.shape
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t.flatten()).to(dtype=x[0].dtype))
e = e.reshape(t.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
if reference_latent is not None:
ref = self.patch_embedding(reference_latent.float()).to(x.dtype)
ref = ref.flatten(2).transpose(1, 2)
freqs_ref = self.rope_encode(reference_latent.shape[-3], reference_latent.shape[-2], reference_latent.shape[-1], t_start=time, device=x.device, dtype=x.dtype)
x = torch.cat([x, ref], dim=1)
freqs = torch.cat([freqs, freqs_ref], dim=1)
del ref, freqs_ref
# context
context = self.text_embedding(context)
context_img_len = None
if audio_embed is not None:
audio = self.audio_proj(audio_embed).permute(0, 3, 1, 2).flatten(2).transpose(1, 2)
else:
audio = None
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, audio=audio, transformer_options=args["transformer_options"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len, audio=audio, transformer_options=transformer_options)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
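To make the tensor flow easier to follow, a shape sketch for the audio projection path; this is not part of the commit, and only the shapes are meaningful because the projection weights are uninitialized. The import path comfy.ldm.wan.model matches the unet_model reference later in this commit.

import torch
import comfy.ops
from comfy.ldm.wan.model import AudioProjModel

proj = AudioProjModel(seq_len=8, blocks=5, channels=1280, output_dim=1536, context_tokens=16, dtype=torch.float32, device="cpu", operations=comfy.ops.manual_cast)
audio_embed = torch.randn(1, 25, 8, 5, 1280)  # [batch, latent frames, window of 8, 5 pooled layer groups, 1280 channels]
tokens = proj(audio_embed)
print(tokens.shape)  # torch.Size([1, 25, 16, 1536]); HumoWanModel.forward_orig flattens this to
                     # [batch, frames * 16, 1536] before the per-block audio cross-attention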


@@ -1214,6 +1214,23 @@ class WAN21_Camera(WAN21):
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out
class WAN21_HuMo(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.HumoWanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
audio_embed = kwargs.get("audio_embed", None)
if audio_embed is not None:
out['audio_embed'] = comfy.conds.CONDRegular(audio_embed)
reference_latents = kwargs.get("reference_latents", None)
if reference_latents is not None:
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1]))
return out
class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)


@@ -402,6 +402,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "camera_2.2"
elif '{}casual_audio_encoder.encoder.final_linear.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "s2v"
elif '{}audio_proj.audio_proj_glob_1.layer.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "humo"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"


@@ -1074,6 +1074,16 @@ class WAN21_Vace(WAN21_T2V):
out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
return out
class WAN21_HuMo(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "humo",
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_HuMo(self, image_to_video=False, device=device)
return out
class WAN22_S2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@@ -1368,6 +1378,6 @@ class HunyuanImage21Refiner(HunyuanVideo):
out = model_base.HunyuanImage21Refiner(self, device=device)
return out
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Higgsv2]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Higgsv2]
models += [SVD_img2vid]


@@ -1015,6 +1015,103 @@ class WanSoundImageToVideoExtend(io.ComfyNode):
return io.NodeOutput(positive, negative, out_latent)
def get_audio_emb_window(audio_emb, frame_num, frame0_idx, audio_shift=2):
zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
iter_ = 1 + (frame_num - 1) // 4
audio_emb_wind = []
for lt_i in range(iter_):
if lt_i == 0:
st = frame0_idx + lt_i - 2
ed = frame0_idx + lt_i + 3
wind_feat = torch.stack([
audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
for i in range(st, ed)
], dim=0)
wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)
else:
st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
ed = frame0_idx + 1 + 4 * lt_i + audio_shift
wind_feat = torch.stack([
audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
for i in range(st, ed)
], dim=0)
audio_emb_wind.append(wind_feat)
audio_emb_wind = torch.stack(audio_emb_wind, dim=0)
return audio_emb_wind, ed - audio_shift
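A worked example of the windowing above, run in the context of this module (torch and the helper defined just above are in scope); the 60-frame feature length is arbitrary. With frame_num=97 there are 1 + (97 - 1) // 4 = 25 latent frames, each paired with an 8-row audio window (the first window is front-padded with 3 zero rows), which matches AudioProjModel(seq_len=8) in the model.

audio_emb = torch.randn(60, 5, 1280)  # [audio feature frames at 25 fps, 5 pooled layer groups, 1280 channels]
windows, next_idx = get_audio_emb_window(audio_emb, frame_num=97, frame0_idx=0)
print(windows.shape)  # torch.Size([25, 8, 5, 1280]); indices past the end of audio_emb are filled with zeros
print(next_idx)       # 97: ed - audio_shift from the last window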
class WanHuMoImageToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanHuMoImageToVideo",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("length", default=97, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.AudioEncoderOutput.Input("audio_encoder_output", optional=True),
io.Image.Input("ref_image", optional=True),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
is_experimental=True,
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, ref_image=None, audio_encoder_output=None) -> io.NodeOutput:
latent_t = ((length - 1) // 4) + 1
latent = torch.zeros([batch_size, 16, latent_t, height // 8, width // 8], device=comfy.model_management.intermediate_device())
if ref_image is not None:
ref_image = comfy.utils.common_upscale(ref_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
ref_latent = vae.encode(ref_image[:, :, :, :3])
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)
else:
ref_latent = torch.zeros([batch_size, 16, 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
if audio_encoder_output is not None:
audio_emb = torch.stack(audio_encoder_output["encoded_audio_all_layers"], dim=2)
audio_len = audio_encoder_output["audio_samples"] // 640
audio_emb = audio_emb[:, :audio_len * 2]
feat0 = linear_interpolation(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
feat1 = linear_interpolation(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
feat2 = linear_interpolation(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
feat3 = linear_interpolation(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
feat4 = linear_interpolation(audio_emb[:, :, 32], 50, 25)
audio_emb = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] # [T, 5, 1280]
audio_emb, _ = get_audio_emb_window(audio_emb, length, frame0_idx=0)
# pad for ref latent
zero_audio_pad = torch.zeros(ref_latent.shape[2], *audio_emb.shape[1:], device=audio_emb.device, dtype=audio_emb.dtype)
audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
audio_emb = audio_emb.unsqueeze(0)
audio_emb_neg = torch.zeros_like(audio_emb)
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": audio_emb})
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": audio_emb_neg})
else:
zero_audio = torch.zeros([batch_size, latent_t + 1, 8, 5, 1280], device=comfy.model_management.intermediate_device())
positive = node_helpers.conditioning_set_values(positive, {"audio_embed": zero_audio})
negative = node_helpers.conditioning_set_values(negative, {"audio_embed": zero_audio})
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class Wan22ImageToVideoLatent(io.ComfyNode):
@classmethod
def define_schema(cls):
@@ -1075,6 +1172,7 @@ class WanExtension(ComfyExtension):
WanPhantomSubjectToVideo,
WanSoundImageToVideo,
WanSoundImageToVideoExtend,
WanHuMoImageToVideo,
Wan22ImageToVideoLatent,
]