From e684ff2505df66e11ac7891594998fbc6080dd9e Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Fri, 10 Oct 2025 19:11:50 +0300
Subject: [PATCH] a lot of fixes + siglip2_base support

---
 comfy/clip_model.py                     | 35 +++++++++++++++++++--
 comfy/clip_vision_siglip2_base_512.json |  3 +-
 comfy/ldm/hunyuan_foley/model.py        | 41 +++++++++++++++----------
 comfy/ldm/hunyuan_foley/vae.py          | 37 ++++++++++++----------
 comfy_extras/nodes_audio.py             |  3 +-
 comfy_extras/nodes_hunyuan_foley.py     |  2 +-
 comfy_extras/nodes_video.py             |  9 +++---
 7 files changed, 88 insertions(+), 42 deletions(-)

diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index 7c0cadab5..63d80d17f 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -1,7 +1,28 @@
 import torch
-from comfy.ldm.modules.attention import optimized_attention_for_device
+from comfy.ldm.modules.attention import optimized_attention_for_device, MultiheadAttentionComfyv
 import comfy.ops
 
+class SiglipMultiheadAttentionPoolingHead(torch.nn.Module):
+    def __init__(self, hidden_size, num_attention_heads, layer_norm_eps, intermediate_size, activation, device=None, dtype=None, operations=None):
+        super().__init__()
+
+        self.probe = torch.nn.Parameter(torch.randn(1, 1, hidden_size, device=device, dtype=dtype))
+        self.attention = MultiheadAttentionComfyv(hidden_size, num_attention_heads, batch_first=True, device=device, dtype=dtype, operations=operations)
+        self.layernorm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
+        self.mlp = CLIPMLP(hidden_size, intermediate_size, activation = activation, device=device, dtype=dtype, operations=operations)
+
+    def forward(self, hidden_state):
+        batch_size = hidden_state.shape[0]
+        probe = self.probe.repeat(batch_size, 1, 1)
+
+        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+        residual = hidden_state
+        hidden_state = self.layernorm(hidden_state)
+        hidden_state = residual + self.mlp(hidden_state)
+
+        return hidden_state[:, 0]
+
 class CLIPAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device, operations):
         super().__init__()
@@ -198,6 +219,8 @@ class CLIPVision(torch.nn.Module):
         intermediate_size = config_dict["intermediate_size"]
         intermediate_activation = config_dict["hidden_act"]
         model_type = config_dict["model_type"]
+        use_head = config_dict.get("use_head", False)
+        layer_norm_eps = config_dict.get("layer_norm_eps", 1e-6)
 
         self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
         if model_type == "siglip_vision_model":
@@ -208,6 +231,11 @@ class CLIPVision(torch.nn.Module):
             self.output_layernorm = False
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.post_layernorm = operations.LayerNorm(embed_dim)
+        self.use_head = use_head
+        if use_head:
+            self.head = SiglipMultiheadAttentionPoolingHead(
+                hidden_size=embed_dim, num_attention_heads=heads, layer_norm_eps=layer_norm_eps, intermediate_size=intermediate_size, activation=intermediate_activation, device=device, dtype=dtype, operations=operations
+            )
 
     def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
         x = self.embeddings(pixel_values)
@@ -216,7 +244,10 @@ class CLIPVision(torch.nn.Module):
         x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
         if self.output_layernorm:
             x = self.post_layernorm(x)
-            pooled_output = x
+            if self.use_head:
+                pooled_output = self.head(x)
+            else:
+                pooled_output = x
         else:
             pooled_output = self.post_layernorm(x[:, 0, :])
         return x, i, pooled_output
diff --git a/comfy/clip_vision_siglip2_base_512.json b/comfy/clip_vision_siglip2_base_512.json
index 4324857e4..f67598ca7 100644
--- a/comfy/clip_vision_siglip2_base_512.json
+++ b/comfy/clip_vision_siglip2_base_512.json
@@ -10,5 +10,6 @@
   "num_hidden_layers": 12,
   "patch_size": 16,
   "image_mean": [0.5, 0.5, 0.5],
-  "image_std": [0.5, 0.5, 0.5]
+  "image_std": [0.5, 0.5, 0.5],
+  "use_head": true
 }
diff --git a/comfy/ldm/hunyuan_foley/model.py b/comfy/ldm/hunyuan_foley/model.py
index 5f85c99f4..90ef95179 100644
--- a/comfy/ldm/hunyuan_foley/model.py
+++ b/comfy/ldm/hunyuan_foley/model.py
@@ -780,6 +780,8 @@ class HunyuanVideoFoley(nn.Module):
         self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels, **factory_kwargs), requires_grad = False)
         self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim, **factory_kwargs), requires_grad = False)
 
+        self.conditions = None
+
     def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
         len = len if len is not None else self.clip_len
         if bs is None:
@@ -858,28 +860,33 @@ class HunyuanVideoFoley(nn.Module):
         bs, _, ol = x.shape
         tl = ol // self.patch_size
 
-        uncondition, condition = torch.chunk(context, 2)
+        if self.conditions is None:
 
-        condition = condition.view(3, context.size(1) // 3, -1)
-        uncondition = uncondition.view(3, context.size(1) // 3, -1)
+            uncondition, condition = torch.chunk(context, 2)
 
-        uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
-        clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
-        cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]
+            condition = condition.view(3, context.size(1) // 3, -1)
+            uncondition = uncondition.view(3, context.size(1) // 3, -1)
 
-        uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
-        uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]
-
-        uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
+            uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
+            clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
+            cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]
 
-        diff = cond_pos.shape[1] - cond_neg.shape[1]
-        if cond_neg.shape[1] < cond_pos.shape[1]:
-            cond_neg = torch.nn.functional.pad(cond_neg, (0, 0, 0, diff))
-        elif diff < 0:
-            cond_pos = torch.nn.functional.pad(cond_pos, (0, 0, 0, torch.abs(diff)))
+            uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
+            uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]
+
+            uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
 
-        clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
-        clip_feat = clip_feat.view(2, -1, 768)
+            diff = cond_pos.shape[1] - cond_neg.shape[1]
+            if cond_neg.shape[1] < cond_pos.shape[1]:
+                cond_neg = torch.nn.functional.pad(cond_neg, (0, 0, 0, diff))
+            elif diff < 0:
+                cond_pos = torch.nn.functional.pad(cond_pos, (0, 0, 0, torch.abs(diff)))
+
+            clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
+            clip_feat = clip_feat.view(2, -1, 768)
+
+        else:
+            clip_feat, sync_feat, cond = self.conditions
 
         if drop_visual is not None:
             clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
diff --git a/comfy/ldm/hunyuan_foley/vae.py b/comfy/ldm/hunyuan_foley/vae.py
index 58c9bfdb4..68635bec6 100644
--- a/comfy/ldm/hunyuan_foley/vae.py
+++ b/comfy/ldm/hunyuan_foley/vae.py
@@ -1,6 +1,5 @@
 import math
 import torch
-import numpy as np
 from typing import List
 import torch.nn as nn
 from einops import rearrange
@@ -88,15 +87,17 @@ class DACEncoder(nn.Module):
         device = None, dtype = None, operations = None
     ):
         super().__init__()
-        # Create first convolution
         self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations)]
 
-        # Create EncoderBlocks that double channels as they downsample by `stride`
         for stride in strides:
             d_model *= 2
             self.block += [DACEncoderBlock(d_model, stride=stride, device = device, dtype = dtype, operations = operations)]
 
-        # Wrap black into nn.Sequential
+        self.block += [
+            Snake1d(d_model, device=device, dtype=dtype),
+            WNConv1d(d_model, d_latent, kernel_size=3, padding=1, device=device, dtype=dtype, operations = operations),
+        ]
+
         self.block = nn.Sequential(*self.block)
         self.enc_dim = d_model
 
@@ -145,6 +146,12 @@ class DACDecoder(nn.Module):
             output_dim = channels // 2 ** (i + 1)
             layers += [DACDecoderBlock(input_dim, output_dim, stride, device = device, dtype = dtype, operations = operations)]
 
+        layers += [
+            Snake1d(output_dim, device=device, dtype=dtype),
+            WNConv1d(output_dim, d_out, kernel_size=7, padding=3, device=device, dtype=dtype, operations = operations),
+            nn.Tanh(),
+        ]
+
         self.model = nn.Sequential(*layers)
 
     def forward(self, x):
@@ -154,11 +161,11 @@ class DAC(torch.nn.Module):
     def __init__(
         self,
         encoder_dim: int = 128,
-        encoder_rates: List[int] = [2, 3, 4, 5],
+        encoder_rates: List[int] = [2, 3, 4, 5, 8],
         latent_dim: int = 128,
         decoder_dim: int = 2048,
-        decoder_rates: List[int] = [8, 5, 4, 3],
-        sample_rate: int = 44100,
+        decoder_rates: List[int] = [8, 5, 4, 3, 2],
+        sample_rate: int = 48000,
     ):
         super().__init__()
 
@@ -173,7 +180,6 @@ class DAC(torch.nn.Module):
 
         self.latent_dim = latent_dim
 
-        self.hop_length = np.prod(encoder_rates)
         self.encoder = DACEncoder(encoder_dim, encoder_rates, latent_dim, operations = ops)
 
         self.decoder = DACDecoder(
@@ -184,8 +190,10 @@ class DAC(torch.nn.Module):
         )
 
         self.sample_rate = sample_rate
+        self.post_quant_conv = ops.Conv1d(latent_dim, latent_dim, 1)
 
     def decode(self, z: torch.Tensor):
+        z = self.post_quant_conv(z)
         return self.decoder(z)
 
     def forward(self):
@@ -205,17 +213,14 @@ class FoleyVae(torch.nn.Module):
                 v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
             ]
         )
+        self.decode_sample_rate = self.dac.sample_rate
+
     def decode(self, x, vae_options = {}):
         return self.dac.decode(x)
 
-    def encode(self, x):
-        return self.synchformer(x)
-    def forward(self, x):
-        try:
-            return self.encode(x)
-        except:
-            x = x.to(next(self.parameters()).device)
-            return self.encode(x)
+    def encode(self, x):
+        x = x.to(next(self.parameters()).device)
+        return self.synchformer(x)
 
     def video_encoding(self, video, step):
         video = torch.stack([self.syncformer_preprocess(t) for t in video])
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index 51c8b9dd9..9acfde78b 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -88,7 +88,8 @@ class VAEDecodeAudio:
         std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
         std[std < 1.0] = 1.0
         audio /= std
-        return ({"waveform": audio, "sample_rate": 44100}, )
+        sample_rate = vae.first_stage_model.decode_sample_rate or 44100
+        return ({"waveform": audio, "sample_rate": sample_rate}, )
 
 
 def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
diff --git a/comfy_extras/nodes_hunyuan_foley.py b/comfy_extras/nodes_hunyuan_foley.py
index e5f168c53..9057eaa4b 100644
--- a/comfy_extras/nodes_hunyuan_foley.py
+++ b/comfy_extras/nodes_hunyuan_foley.py
@@ -11,7 +11,7 @@ class EmptyLatentHunyuanFoley(io.ComfyNode):
             display_name="EmptyLatentHunyuanFoley",
             category="audio/latent",
             inputs = [
-                io.Int.Input("length", min = 1, max = 15, default = 12),
+                io.Float.Input("length", min = 1.0, max = 15.0, default = 12.0),
                 io.Int.Input("batch_size", min = 1, max = 48_000, default = 1),
                 io.Video.Input("video", optional=True),
             ],
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index 56a1457a1..fd3964b35 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -72,7 +72,7 @@ class EncodeVideo(io.ComfyNode):
 
         model = vae.first_stage_model if vae is not None else clip_vision.model
         vae = vae if vae is not None else clip_vision
-        # should be the offload device
+
         if hasattr(model, "video_encoding"):
             data, num_segments, output_fn = model.video_encoding(video, step_size)
             batch_size = b * num_segments
@@ -95,7 +95,7 @@ class EncodeVideo(io.ComfyNode):
                 try:
                     out = vae.encode(chunk)
                 except:
-                    out = model(chunk)
+                    out = model.encode(chunk)
             else:
                 out = vae.encode_image(chunk)
                 out = out["image_embeds"]
@@ -103,6 +103,7 @@ class EncodeVideo(io.ComfyNode):
             out_cpu = out.cpu()
             if outputs is None:
                 full_shape = (total, *out_cpu.shape[1:])
+                # should be the offload device
                 outputs = torch.empty(full_shape, dtype=out_cpu.dtype, pin_memory=True)
 
             chunk_len = out_cpu.shape[0]
@@ -141,7 +142,7 @@ class ResampleVideo(io.ComfyNode):
         if src_fps is None or target_fps > src_fps:
             for packet in container.demux(stream):
                 for frame in packet.decode():
-                    arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
+                    arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float()
                     frames.append(arr)
 
             return io.NodeOutput(torch.stack(frames))
@@ -156,7 +157,7 @@ class ResampleVideo(io.ComfyNode):
                 continue
             t = frame.time
             while t >= next_time:
-                arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
+                arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float()
                 frames.append(arr)
                 next_time += step
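
Reviewer note: below is a minimal standalone sketch (not part of the diff) of the attention-pooling path this patch adds for siglip2_base. It swaps the repo's MultiheadAttentionComfyv, CLIPMLP and `operations` wrappers for plain torch.nn layers and assumes typical base-model sizes (hidden 768, 12 heads, MLP 3072), so the probe-attention -> layernorm -> residual-MLP flow and the output shapes can be checked in isolation.

# Illustrative sketch only; layer choices and sizes are assumptions, not the patch's code.
import torch
import torch.nn as nn

class AttentionPoolingHeadSketch(nn.Module):
    def __init__(self, hidden_size=768, num_heads=12, intermediate_size=3072, eps=1e-6):
        super().__init__()
        # Learnable probe token that queries the full patch sequence.
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = nn.Sequential(  # stand-in for CLIPMLP with gelu_pytorch_tanh
            nn.Linear(hidden_size, intermediate_size),
            nn.GELU(approximate="tanh"),
            nn.Linear(intermediate_size, hidden_size),
        )

    def forward(self, hidden_state):
        # hidden_state: (batch, num_patches, hidden_size) from the vision encoder.
        probe = self.probe.expand(hidden_state.shape[0], -1, -1)
        pooled = self.attention(probe, hidden_state, hidden_state)[0]  # (batch, 1, hidden_size)
        residual = pooled
        pooled = residual + self.mlp(self.layernorm(pooled))
        return pooled[:, 0]  # (batch, hidden_size)

if __name__ == "__main__":
    tokens = torch.randn(2, (512 // 16) ** 2, 768)  # 1024 patches for a 512px, patch-16 input
    print(AttentionPoolingHeadSketch()(tokens).shape)  # torch.Size([2, 768])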