a lot of fixes + siglip2_base support

Yousef Rafat 2025-10-10 19:11:50 +03:00
parent 4c782e3395
commit e684ff2505
7 changed files with 88 additions and 42 deletions

View File

@@ -1,7 +1,28 @@
 import torch
-from comfy.ldm.modules.attention import optimized_attention_for_device
+from comfy.ldm.modules.attention import optimized_attention_for_device, MultiheadAttentionComfyv
 import comfy.ops
 
+class SiglipMultiheadAttentionPoolingHead(torch.nn.Module):
+    def __init__(self, hidden_size, num_attention_heads, layer_norm_eps, intermediate_size, activation, device=None, dtype=None, operations=None):
+        super().__init__()
+
+        self.probe = torch.nn.Parameter(torch.randn(1, 1, hidden_size, device=device, dtype=dtype))
+        self.attention = MultiheadAttentionComfyv(hidden_size, num_attention_heads, batch_first=True, device=device, dtype=dtype, operations=operations)
+        self.layernorm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
+        self.mlp = CLIPMLP(hidden_size, intermediate_size, activation = activation, device=device, dtype=dtype, operations=operations)
+
+    def forward(self, hidden_state):
+        batch_size = hidden_state.shape[0]
+        probe = self.probe.repeat(batch_size, 1, 1)
+
+        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+        residual = hidden_state
+        hidden_state = self.layernorm(hidden_state)
+        hidden_state = residual + self.mlp(hidden_state)
+
+        return hidden_state[:, 0]
+
 class CLIPAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device, operations):
         super().__init__()
@@ -198,6 +219,8 @@ class CLIPVision(torch.nn.Module):
         intermediate_size = config_dict["intermediate_size"]
         intermediate_activation = config_dict["hidden_act"]
         model_type = config_dict["model_type"]
+        use_head = config_dict.get("use_head", False)
+        layer_norm_eps = config_dict.get("layer_norm_eps", 1e-6)
 
         self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
         if model_type == "siglip_vision_model":
@@ -208,6 +231,11 @@ class CLIPVision(torch.nn.Module):
             self.output_layernorm = False
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.post_layernorm = operations.LayerNorm(embed_dim)
+        self.use_head = use_head
+        if use_head:
+            self.head = SiglipMultiheadAttentionPoolingHead(
+                hidden_size=embed_dim, num_attention_heads=heads, layer_norm_eps=layer_norm_eps, intermediate_size=intermediate_size, activation=intermediate_activation, device=device, dtype=dtype, operations=operations
+            )
 
     def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
         x = self.embeddings(pixel_values)
@@ -216,7 +244,10 @@ class CLIPVision(torch.nn.Module):
         x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
         if self.output_layernorm:
             x = self.post_layernorm(x)
-            pooled_output = x
+            if self.use_head:
+                pooled_output = self.head(x)
+            else:
+                pooled_output = x
         else:
             pooled_output = self.post_layernorm(x[:, 0, :])
         return x, i, pooled_output
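
The new SiglipMultiheadAttentionPoolingHead replaces the plain pooled_output = x path with a single learned probe token that cross-attends over all patch tokens, then refines the result with a LayerNorm + residual MLP before taking position 0. A minimal sketch of the probe-pooling idea, using plain torch.nn.MultiheadAttention as a stand-in for the repo's MultiheadAttentionComfyv wrapper and omitting the MLP refinement:

import torch

class ProbePool(torch.nn.Module):
    # A learned one-token query ("probe") attends over the patch sequence;
    # the attention output at that single position is the pooled feature.
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.probe = torch.nn.Parameter(torch.randn(1, 1, hidden_size))
        self.attn = torch.nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

    def forward(self, tokens):                        # tokens: (batch, seq, hidden)
        probe = self.probe.expand(tokens.shape[0], -1, -1)
        pooled, _ = self.attn(probe, tokens, tokens)  # query = probe, key/value = tokens
        return pooled[:, 0]                           # (batch, hidden)

x = torch.randn(2, 196, 768)                          # e.g. 14x14 patches, hidden size 768
print(ProbePool(768, 12)(x).shape)                    # torch.Size([2, 768])

Because CLIPVision reads the flag with config_dict.get("use_head", False), existing SigLIP configs without the key keep the previous pooling behavior; the JSON change in the next file opts the siglip2_base config in explicitly.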

View File

@@ -10,5 +10,6 @@
     "num_hidden_layers": 12,
     "patch_size": 16,
     "image_mean": [0.5, 0.5, 0.5],
-    "image_std": [0.5, 0.5, 0.5]
+    "image_std": [0.5, 0.5, 0.5],
+    "use_head": true
 }

View File

@@ -780,6 +780,8 @@ class HunyuanVideoFoley(nn.Module):
         self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels, **factory_kwargs), requires_grad = False)
         self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim, **factory_kwargs), requires_grad = False)
+        self.conditions = None
+
     def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
         len = len if len is not None else self.clip_len
         if bs is None:
@@ -858,28 +860,33 @@ class HunyuanVideoFoley(nn.Module):
         bs, _, ol = x.shape
         tl = ol // self.patch_size
 
-        uncondition, condition = torch.chunk(context, 2)
-        condition = condition.view(3, context.size(1) // 3, -1)
-        uncondition = uncondition.view(3, context.size(1) // 3, -1)
-
-        uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
-        clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
-        cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]
-
-        uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
-        uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]
-
-        uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
-
-        diff = cond_pos.shape[1] - cond_neg.shape[1]
-        if cond_neg.shape[1] < cond_pos.shape[1]:
-            cond_neg = torch.nn.functional.pad(cond_neg, (0, 0, 0, diff))
-        elif diff < 0:
-            cond_pos = torch.nn.functional.pad(cond_pos, (0, 0, 0, torch.abs(diff)))
-
-        clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
-        clip_feat = clip_feat.view(2, -1, 768)
+        if self.conditions is None:
+            uncondition, condition = torch.chunk(context, 2)
+            condition = condition.view(3, context.size(1) // 3, -1)
+            uncondition = uncondition.view(3, context.size(1) // 3, -1)
+
+            uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
+            clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
+            cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]
+
+            uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
+            uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]
+
+            uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
+
+            diff = cond_pos.shape[1] - cond_neg.shape[1]
+            if cond_neg.shape[1] < cond_pos.shape[1]:
+                cond_neg = torch.nn.functional.pad(cond_neg, (0, 0, 0, diff))
+            elif diff < 0:
+                cond_pos = torch.nn.functional.pad(cond_pos, (0, 0, 0, torch.abs(diff)))
+
+            clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
+            clip_feat = clip_feat.view(2, -1, 768)
+        else:
+            clip_feat, sync_feat, cond = self.conditions
 
         if drop_visual is not None:
            clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
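
The second hunk wraps the context unpacking in an if self.conditions is None: guard so the chunk/view/pad work is only done once, with the else branch reading a cached (clip_feat, sync_feat, cond) tuple. Only the read side is visible here; where self.conditions is populated and later reset to None falls outside this hunk, so stale conditioning is worth watching for if the prompt or video changes between runs. The shape of the pattern, with hypothetical names:

import torch

class CondCache(torch.nn.Module):
    # Hypothetical reduction of the caching pattern: unpack once, reuse until reset.
    def __init__(self):
        super().__init__()
        self.conditions = None

    def unpack(self, context):
        if self.conditions is None:
            clip_feat, sync_feat, cond = torch.chunk(context, 3, dim=1)  # stand-in for the real unpack/pad logic
            self.conditions = (clip_feat, sync_feat, cond)
        return self.conditions

    def reset(self):
        self.conditions = None  # call when the conditioning inputs change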

View File

@@ -1,6 +1,5 @@
 import math
 import torch
-import numpy as np
 from typing import List
 import torch.nn as nn
 from einops import rearrange
@@ -88,15 +87,17 @@ class DACEncoder(nn.Module):
         device = None, dtype = None, operations = None
     ):
         super().__init__()
-        # Create first convolution
         self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations)]
 
-        # Create EncoderBlocks that double channels as they downsample by `stride`
         for stride in strides:
             d_model *= 2
             self.block += [DACEncoderBlock(d_model, stride=stride, device = device, dtype = dtype, operations = operations)]
 
-        # Wrap black into nn.Sequential
+        self.block += [
+            Snake1d(d_model, device=device, dtype=dtype),
+            WNConv1d(d_model, d_latent, kernel_size=3, padding=1, device=device, dtype=dtype, operations = operations),
+        ]
+
         self.block = nn.Sequential(*self.block)
         self.enc_dim = d_model
@@ -145,6 +146,12 @@ class DACDecoder(nn.Module):
             output_dim = channels // 2 ** (i + 1)
             layers += [DACDecoderBlock(input_dim, output_dim, stride, device = device, dtype = dtype, operations = operations)]
 
+        layers += [
+            Snake1d(output_dim, device=device, dtype=dtype),
+            WNConv1d(output_dim, d_out, kernel_size=7, padding=3, device=device, dtype=dtype, operations = operations),
+            nn.Tanh(),
+        ]
+
         self.model = nn.Sequential(*layers)
 
     def forward(self, x):
@@ -154,11 +161,11 @@ class DAC(torch.nn.Module):
     def __init__(
         self,
         encoder_dim: int = 128,
-        encoder_rates: List[int] = [2, 3, 4, 5],
+        encoder_rates: List[int] = [2, 3, 4, 5, 8],
         latent_dim: int = 128,
         decoder_dim: int = 2048,
-        decoder_rates: List[int] = [8, 5, 4, 3],
+        decoder_rates: List[int] = [8, 5, 4, 3, 2],
-        sample_rate: int = 44100,
+        sample_rate: int = 48000,
     ):
         super().__init__()
@@ -173,7 +180,6 @@ class DAC(torch.nn.Module):
         self.latent_dim = latent_dim
-        self.hop_length = np.prod(encoder_rates)
 
         self.encoder = DACEncoder(encoder_dim, encoder_rates, latent_dim, operations = ops)
         self.decoder = DACDecoder(
@@ -184,8 +190,10 @@ class DAC(torch.nn.Module):
         )
 
         self.sample_rate = sample_rate
+        self.post_quant_conv = ops.Conv1d(latent_dim, latent_dim, 1)
 
     def decode(self, z: torch.Tensor):
+        z = self.post_quant_conv(z)
         return self.decoder(z)
 
     def forward(self):
@@ -205,17 +213,14 @@ class FoleyVae(torch.nn.Module):
                 v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
             ]
         )
+        self.decode_sample_rate = self.dac.sample_rate
 
     def decode(self, x, vae_options = {}):
         return self.dac.decode(x)
 
     def encode(self, x):
+        x = x.to(next(self.parameters()).device)
         return self.synchformer(x)
 
-    def forward(self, x):
-        try:
-            return self.encode(x)
-        except:
-            x = x.to(next(self.parameters()).device)
-            return self.encode(x)
-
     def video_encoding(self, video, step):
         video = torch.stack([self.syncformer_preprocess(t) for t in video])
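
With encoder_rates = [2, 3, 4, 5, 8] and decoder_rates = [8, 5, 4, 3, 2], the codec's overall stride grows from 120 to 960 samples per latent frame, and the default sample rate moves to 48 kHz. The removed self.hop_length = np.prod(encoder_rates) (and with it the numpy import) is easy to recompute on demand; a quick check of the resulting latent rate, assuming the hop is still the product of the per-block strides as in the original DAC:

import math

encoder_rates = [2, 3, 4, 5, 8]           # new default from this commit
hop_length = math.prod(encoder_rates)     # 960 samples folded into one latent frame
print(hop_length)                         # 960
print(48000 / hop_length)                 # 50.0 latent frames per second at 48 kHz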

View File

@@ -88,7 +88,8 @@ class VAEDecodeAudio:
         std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
         std[std < 1.0] = 1.0
         audio /= std
-        return ({"waveform": audio, "sample_rate": 44100}, )
+        sample_rate = vae.first_stage_model.decode_sample_rate or 44100
+        return ({"waveform": audio, "sample_rate": sample_rate}, )
 
     def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
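
One caveat with the new line: vae.first_stage_model.decode_sample_rate or 44100 only falls back when the attribute exists and is falsy, while first-stage models that never define decode_sample_rate (only FoleyVae sets it in this commit) would raise an AttributeError on the lookup itself. A more defensive variant, offered as an editor's suggestion rather than part of the commit:

# Suggested drop-in for the new line, plus a tiny demonstration of the fallback:
# sample_rate = getattr(vae.first_stage_model, "decode_sample_rate", None) or 44100

class _NoRate:                                                    # stands in for a VAE that never sets the attribute
    pass

print(getattr(_NoRate(), "decode_sample_rate", None) or 44100)    # 44100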

View File

@@ -11,7 +11,7 @@ class EmptyLatentHunyuanFoley(io.ComfyNode):
             display_name="EmptyLatentHunyuanFoley",
             category="audio/latent",
             inputs = [
-                io.Int.Input("length", min = 1, max = 15, default = 12),
+                io.Float.Input("length", min = 1.0, max = 15.0, default = 12.0),
                 io.Int.Input("batch_size", min = 1, max = 48_000, default = 1),
                 io.Video.Input("video", optional=True),
             ],

View File

@@ -72,7 +72,7 @@ class EncodeVideo(io.ComfyNode):
         model = vae.first_stage_model if vae is not None else clip_vision.model
         vae = vae if vae is not None else clip_vision
-        # should be the offload device
 
         if hasattr(model, "video_encoding"):
             data, num_segments, output_fn = model.video_encoding(video, step_size)
             batch_size = b * num_segments
@@ -95,7 +95,7 @@ class EncodeVideo(io.ComfyNode):
                 try:
                     out = vae.encode(chunk)
                 except:
-                    out = model(chunk)
+                    out = model.encode(chunk)
             else:
                 out = vae.encode_image(chunk)
                 out = out["image_embeds"]
@@ -103,6 +103,7 @@ class EncodeVideo(io.ComfyNode):
             out_cpu = out.cpu()
             if outputs is None:
                 full_shape = (total, *out_cpu.shape[1:])
+                # should be the offload device
                 outputs = torch.empty(full_shape, dtype=out_cpu.dtype, pin_memory=True)
 
             chunk_len = out_cpu.shape[0]
@@ -141,7 +142,7 @@ class ResampleVideo(io.ComfyNode):
         if src_fps is None or target_fps > src_fps:
             for packet in container.demux(stream):
                 for frame in packet.decode():
-                    arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
+                    arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float()
                     frames.append(arr)
             return io.NodeOutput(torch.stack(frames))
@@ -156,7 +157,7 @@ class ResampleVideo(io.ComfyNode):
                 continue
             t = frame.time
             while t >= next_time:
-                arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
+                arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float()
                 frames.append(arr)
                 next_time += step
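
The "should be the offload device" comment moves from the model selection down to the staging-buffer allocation, which is where it is relevant: torch.empty(..., pin_memory=True) creates a page-locked host buffer, and GPU-to-host copies can only run asynchronously when the destination is pinned. A small standalone illustration of that pairing, not taken from the repo:

import torch

if torch.cuda.is_available():
    out = torch.randn(8, 768, device="cuda")
    # Pinned host buffer: copy_ with non_blocking=True can overlap with GPU work.
    staging = torch.empty(out.shape, dtype=out.dtype, pin_memory=True)
    staging.copy_(out, non_blocking=True)
    torch.cuda.synchronize()   # make sure the async copy finished before using `staging`

Note also that ResampleVideo now returns frames in the 0-255 float range rather than 0-1 after dropping the / 255.0; any rescaling presumably happens downstream, but that is not visible in this diff.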