a lot of fixes + siglip2_base support

Yousef Rafat 2025-10-10 19:11:50 +03:00
parent 4c782e3395
commit e684ff2505
7 changed files with 88 additions and 42 deletions

View File

@@ -1,7 +1,28 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.ldm.modules.attention import optimized_attention_for_device, MultiheadAttentionComfyv
import comfy.ops
class SiglipMultiheadAttentionPoolingHead(torch.nn.Module):
def __init__(self, hidden_size, num_attention_heads, layer_norm_eps, intermediate_size, activation, device=None, dtype=None, operations=None):
super().__init__()
self.probe = torch.nn.Parameter(torch.randn(1, 1, hidden_size, device=device, dtype=dtype))
self.attention = MultiheadAttentionComfyv(hidden_size, num_attention_heads, batch_first=True, device=device, dtype=dtype, operations=operations)
self.layernorm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
self.mlp = CLIPMLP(hidden_size, intermediate_size, activation = activation, device=device, dtype=dtype, operations=operations)
def forward(self, hidden_state):
batch_size = hidden_state.shape[0]
probe = self.probe.repeat(batch_size, 1, 1)
hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
residual = hidden_state
hidden_state = self.layernorm(hidden_state)
hidden_state = residual + self.mlp(hidden_state)
return hidden_state[:, 0]
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__()
@@ -198,6 +219,8 @@ class CLIPVision(torch.nn.Module):
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
model_type = config_dict["model_type"]
use_head = config_dict.get("use_head", False)
layer_norm_eps = config_dict.get("layer_norm_eps", 1e-6)
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type == "siglip_vision_model":
@@ -208,6 +231,11 @@ class CLIPVision(torch.nn.Module):
self.output_layernorm = False
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.post_layernorm = operations.LayerNorm(embed_dim)
self.use_head = use_head
if use_head:
self.head = SiglipMultiheadAttentionPoolingHead(
hidden_size=embed_dim, num_attention_heads=heads, layer_norm_eps=layer_norm_eps, intermediate_size=intermediate_size, activation=intermediate_activation, device=device, dtype=dtype, operations=operations
)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
x = self.embeddings(pixel_values)
@@ -216,7 +244,10 @@ class CLIPVision(torch.nn.Module):
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
if self.output_layernorm:
x = self.post_layernorm(x)
pooled_output = x
if self.use_head:
pooled_output = self.head(x)
else:
pooled_output = x
else:
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
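
For readers unfamiliar with SigLIP's pooling: the new SiglipMultiheadAttentionPoolingHead replaces class-token pooling with a learned probe token that attends over all patch tokens. MultiheadAttentionComfyv and CLIPMLP are ComfyUI internals; the sketch below is only an illustration of the shape flow using plain torch.nn modules, not the exact implementation.

# Minimal sketch of SigLIP-style attention pooling, assuming plain PyTorch modules
# in place of ComfyUI's MultiheadAttentionComfyv / CLIPMLP wrappers.
import torch
import torch.nn as nn

class AttentionPoolSketch(nn.Module):
    def __init__(self, hidden_size=768, num_heads=12, intermediate_size=3072, eps=1e-6):
        super().__init__()
        # A single learned "probe" token queries all patch tokens.
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, intermediate_size),
            nn.GELU(approximate="tanh"),
            nn.Linear(intermediate_size, hidden_size),
        )

    def forward(self, tokens):  # tokens: (batch, num_patches, hidden_size)
        probe = self.probe.expand(tokens.shape[0], -1, -1)
        x = self.attn(probe, tokens, tokens)[0]    # (batch, 1, hidden_size)
        x = x + self.mlp(self.norm(x))             # pre-norm MLP with residual
        return x[:, 0]                             # (batch, hidden_size) pooled output

pooled = AttentionPoolSketch()(torch.randn(2, 196, 768))  # -> shape (2, 768)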

View File

@@ -10,5 +10,6 @@
"num_hidden_layers": 12,
"patch_size": 16,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
"image_std": [0.5, 0.5, 0.5],
"use_head": true
}
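
The new use_head flag is read back in the vision model with config_dict.get("use_head", False), so existing SigLIP configs that omit the key keep the old pooling path. A minimal sketch of that lookup, with a hypothetical config filename:

# Sketch: how the vision config toggles the pooling head (defaults taken from the diff above).
import json

with open("clip_vision_siglip2_base.json") as f:    # hypothetical filename
    config = json.load(f)

use_head = config.get("use_head", False)            # older SigLIP configs omit the key -> no head
layer_norm_eps = config.get("layer_norm_eps", 1e-6)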

View File

@@ -780,6 +780,8 @@ class HunyuanVideoFoley(nn.Module):
self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels, **factory_kwargs), requires_grad = False)
self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim, **factory_kwargs), requires_grad = False)
self.conditions = None
def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
len = len if len is not None else self.clip_len
if bs is None:
@@ -858,28 +860,33 @@ class HunyuanVideoFoley(nn.Module):
bs, _, ol = x.shape
tl = ol // self.patch_size
uncondition, condition = torch.chunk(context, 2)
if self.conditions is None:
condition = condition.view(3, context.size(1) // 3, -1)
uncondition = uncondition.view(3, context.size(1) // 3, -1)
uncondition, condition = torch.chunk(context, 2)
uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]
condition = condition.view(3, context.size(1) // 3, -1)
uncondition = uncondition.view(3, context.size(1) // 3, -1)
uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]
uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]
diff = cond_pos.shape[1] - cond_neg.shape[1]
if cond_neg.shape[1] < cond_pos.shape[1]:
cond_neg = torch.nn.functional.pad(cond_neg, (0, 0, 0, diff))
elif diff < 0:
cond_pos = torch.nn.functional.pad(cond_pos, (0, 0, 0, abs(diff)))
uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]
uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
clip_feat = clip_feat.view(2, -1, 768)
diff = cond_pos.shape[1] - cond_neg.shape[1]
if cond_neg.shape[1] < cond_pos.shape[1]:
cond_neg = torch.nn.functional.pad(cond_neg, (0, 0, 0, diff))
elif diff < 0:
cond_pos = torch.nn.functional.pad(cond_pos, (0, 0, 0, abs(diff)))
clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
clip_feat = clip_feat.view(2, -1, 768)
else:
clip_feat, sync_feat, cond = self.conditions
if drop_visual is not None:
clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
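
The reshuffled block above now runs the chunk/trim/pad work only on the first forward pass and caches the result in self.conditions. The core of it is padding the negative text condition up to the positive one's sequence length (or vice versa); a standalone sketch of just that step, with trim_repeats and unlock_cpu_tensor (ComfyUI helpers) omitted:

# Sketch of the pad-to-equal-length step for negative/positive text conditions.
import torch
import torch.nn.functional as F

def pad_to_match(cond_neg: torch.Tensor, cond_pos: torch.Tensor):
    """Zero-pad the shorter of two (batch, seq, dim) tensors along the sequence axis."""
    diff = cond_pos.shape[1] - cond_neg.shape[1]
    if diff > 0:
        cond_neg = F.pad(cond_neg, (0, 0, 0, diff))   # pad the sequence dimension at the end
    elif diff < 0:
        cond_pos = F.pad(cond_pos, (0, 0, 0, -diff))
    return cond_neg, cond_pos

a, b = pad_to_match(torch.randn(1, 60, 768), torch.randn(1, 77, 768))
assert a.shape == b.shape == (1, 77, 768)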

View File

@@ -1,6 +1,5 @@
import math
import torch
import numpy as np
from typing import List
import torch.nn as nn
from einops import rearrange
@@ -88,15 +87,17 @@ class DACEncoder(nn.Module):
device = None, dtype = None, operations = None
):
super().__init__()
# Create first convolution
self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations)]
# Create EncoderBlocks that double channels as they downsample by `stride`
for stride in strides:
d_model *= 2
self.block += [DACEncoderBlock(d_model, stride=stride, device = device, dtype = dtype, operations = operations)]
# Wrap block into nn.Sequential
self.block += [
Snake1d(d_model, device=device, dtype=dtype),
WNConv1d(d_model, d_latent, kernel_size=3, padding=1, device=device, dtype=dtype, operations = operations),
]
self.block = nn.Sequential(*self.block)
self.enc_dim = d_model
@@ -145,6 +146,12 @@ class DACDecoder(nn.Module):
output_dim = channels // 2 ** (i + 1)
layers += [DACDecoderBlock(input_dim, output_dim, stride, device = device, dtype = dtype, operations = operations)]
layers += [
Snake1d(output_dim, device=device, dtype=dtype),
WNConv1d(output_dim, d_out, kernel_size=7, padding=3, device=device, dtype=dtype, operations = operations),
nn.Tanh(),
]
self.model = nn.Sequential(*layers)
def forward(self, x):
@@ -154,11 +161,11 @@ class DAC(torch.nn.Module):
def __init__(
self,
encoder_dim: int = 128,
encoder_rates: List[int] = [2, 3, 4, 5],
encoder_rates: List[int] = [2, 3, 4, 5, 8],
latent_dim: int = 128,
decoder_dim: int = 2048,
decoder_rates: List[int] = [8, 5, 4, 3],
sample_rate: int = 44100,
decoder_rates: List[int] = [8, 5, 4, 3, 2],
sample_rate: int = 48000,
):
super().__init__()
@@ -173,7 +180,6 @@ class DAC(torch.nn.Module):
self.latent_dim = latent_dim
self.hop_length = np.prod(encoder_rates)
self.encoder = DACEncoder(encoder_dim, encoder_rates, latent_dim, operations = ops)
self.decoder = DACDecoder(
@@ -184,8 +190,10 @@
)
self.sample_rate = sample_rate
self.post_quant_conv = ops.Conv1d(latent_dim, latent_dim, 1)
def decode(self, z: torch.Tensor):
z = self.post_quant_conv(z)
return self.decoder(z)
def forward(self):
@@ -205,17 +213,14 @@ class FoleyVae(torch.nn.Module):
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
)
self.decode_sample_rate = self.dac.sample_rate
def decode(self, x, vae_options = {}):
return self.dac.decode(x)
def encode(self, x):
return self.synchformer(x)
def forward(self, x):
try:
return self.encode(x)
except:
x = x.to(next(self.parameters()).device)
return self.encode(x)
def encode(self, x):
x = x.to(next(self.parameters()).device)
return self.synchformer(x)
def video_encoding(self, video, step):
video = torch.stack([self.syncformer_preprocess(t) for t in video])
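
The extra stride (8) and the 48 kHz default change the DAC hop length, which is what self.hop_length = np.prod(encoder_rates) computes. A quick sanity check of the numbers implied by the new defaults:

# Sketch: latent frame rate implied by the new DAC defaults in this diff.
import numpy as np

encoder_rates = [2, 3, 4, 5, 8]            # 48 kHz DAC variant used here
hop_length = int(np.prod(encoder_rates))   # 960 samples per latent frame
sample_rate = 48000
latent_fps = sample_rate / hop_length      # 50 latent frames per second of audio
print(hop_length, latent_fps)              # 960 50.0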

View File

@@ -88,7 +88,8 @@ class VAEDecodeAudio:
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
return ({"waveform": audio, "sample_rate": 44100}, )
sample_rate = getattr(vae.first_stage_model, "decode_sample_rate", None) or 44100
return ({"waveform": audio, "sample_rate": sample_rate}, )
def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
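
The context lines above also show the loudness clamp applied to the decoded waveform before the sample rate is attached; a self-contained restatement of that step with illustrative values:

# Sketch of the loudness clamp on decoded audio (values here are illustrative).
import torch

audio = torch.randn(1, 2, 48000) * 3.0                   # unusually loud decoded batch
std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0   # per-item scale factor
std[std < 1.0] = 1.0                                     # never amplify quiet audio
audio = audio / std                                      # attenuate only overly loud outputs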

View File

@@ -11,7 +11,7 @@ class EmptyLatentHunyuanFoley(io.ComfyNode):
display_name="EmptyLatentHunyuanFoley",
category="audio/latent",
inputs = [
io.Int.Input("length", min = 1, max = 15, default = 12),
io.Float.Input("length", min = 1.0, max = 15.0, default = 12.0),
io.Int.Input("batch_size", min = 1, max = 48_000, default = 1),
io.Video.Input("video", optional=True),
],
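
Switching length from io.Int.Input to io.Float.Input allows fractional durations in seconds. How the node turns that into a latent length is not part of this diff; the helper below is purely hypothetical and only assumes the 50 Hz latent rate implied by the DAC strides above.

# Hypothetical helper: seconds -> latent frames at an assumed 50 Hz latent rate.
import math

def latent_frames(length_seconds: float, latent_rate: float = 50.0) -> int:
    return max(1, math.ceil(length_seconds * latent_rate))

print(latent_frames(12.0))   # 600
print(latent_frames(2.5))    # 125 -- fractional seconds now representable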

View File

@@ -72,7 +72,7 @@ class EncodeVideo(io.ComfyNode):
model = vae.first_stage_model if vae is not None else clip_vision.model
vae = vae if vae is not None else clip_vision
# should be the offload device
if hasattr(model, "video_encoding"):
data, num_segments, output_fn = model.video_encoding(video, step_size)
batch_size = b * num_segments
@@ -95,7 +95,7 @@ class EncodeVideo(io.ComfyNode):
try:
out = vae.encode(chunk)
except:
out = model(chunk)
out = model.encode(chunk)
else:
out = vae.encode_image(chunk)
out = out["image_embeds"]
@@ -103,6 +103,7 @@ class EncodeVideo(io.ComfyNode):
out_cpu = out.cpu()
if outputs is None:
full_shape = (total, *out_cpu.shape[1:])
# should be the offload device
outputs = torch.empty(full_shape, dtype=out_cpu.dtype, pin_memory=True)
chunk_len = out_cpu.shape[0]
@@ -141,7 +142,7 @@ class ResampleVideo(io.ComfyNode):
if src_fps is None or target_fps > src_fps:
for packet in container.demux(stream):
for frame in packet.decode():
arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float()
frames.append(arr)
return io.NodeOutput(torch.stack(frames))
@@ -156,7 +157,7 @@ class ResampleVideo(io.ComfyNode):
continue
t = frame.time
while t >= next_time:
arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float()
frames.append(arr)
next_time += step
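
The two hunks above drop the / 255.0 scaling, so resampled frames now stay in the 0-255 range. A standalone restatement of the fps-resampling loop using PyAV directly (path and target_fps are placeholder arguments; error handling omitted):

# Sketch of the frame-duplication resampling loop, assuming PyAV as in the node above.
import av
import torch

def resample(path: str, target_fps: float) -> torch.Tensor:
    container = av.open(path)
    stream = container.streams.video[0]
    frames, next_time, step = [], 0.0, 1.0 / target_fps
    for packet in container.demux(stream):
        for frame in packet.decode():
            if frame.time is None:
                continue
            # Emit (or repeat) the current frame until we catch up to the next sample time.
            while frame.time >= next_time:
                frames.append(torch.from_numpy(frame.to_ndarray(format="rgb24")).float())
                next_time += step
    return torch.stack(frames)   # (frames, H, W, 3), values in 0..255 after this commit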