Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-10 21:42:37 +08:00)
a lot of fixes + siglip2_base support

Commit e684ff2505, parent 4c782e3395.
@@ -1,7 +1,28 @@
 import torch
-from comfy.ldm.modules.attention import optimized_attention_for_device
+from comfy.ldm.modules.attention import optimized_attention_for_device, MultiheadAttentionComfyv
 import comfy.ops
 
+class SiglipMultiheadAttentionPoolingHead(torch.nn.Module):
+    def __init__(self, hidden_size, num_attention_heads, layer_norm_eps, intermediate_size, activation, device=None, dtype=None, operations=None):
+        super().__init__()
+
+        self.probe = torch.nn.Parameter(torch.randn(1, 1, hidden_size, device=device, dtype=dtype))
+        self.attention = MultiheadAttentionComfyv(hidden_size, num_attention_heads, batch_first=True, device=device, dtype=dtype, operations=operations)
+        self.layernorm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
+        self.mlp = CLIPMLP(hidden_size, intermediate_size, activation = activation, device=device, dtype=dtype, operations=operations)
+
+    def forward(self, hidden_state):
+        batch_size = hidden_state.shape[0]
+        probe = self.probe.repeat(batch_size, 1, 1)
+
+        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
+
+        residual = hidden_state
+        hidden_state = self.layernorm(hidden_state)
+        hidden_state = residual + self.mlp(hidden_state)
+
+        return hidden_state[:, 0]
+
 class CLIPAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device, operations):
         super().__init__()
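The new SiglipMultiheadAttentionPoolingHead pools the encoder's token sequence with a single learned probe query instead of a class token, which is what SigLIP-style checkpoints expect. A minimal standalone sketch of the same idea, using stock torch.nn layers in place of ComfyUI's MultiheadAttentionComfyv and operations wrappers (sizes and names below are illustrative, not taken from the commit):

    import torch
    import torch.nn as nn

    class ProbePoolingSketch(nn.Module):
        # Attention pooling with one learned probe token, as in SigLIP's pooling head.
        def __init__(self, hidden_size=768, num_heads=12):
            super().__init__()
            self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
            self.attention = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
            self.layernorm = nn.LayerNorm(hidden_size)
            self.mlp = nn.Sequential(nn.Linear(hidden_size, 4 * hidden_size), nn.GELU(), nn.Linear(4 * hidden_size, hidden_size))

        def forward(self, tokens):                               # tokens: [batch, seq, hidden]
            probe = self.probe.repeat(tokens.shape[0], 1, 1)     # one query per batch element
            pooled = self.attention(probe, tokens, tokens)[0]    # the probe attends over all tokens
            pooled = pooled + self.mlp(self.layernorm(pooled))   # pre-norm MLP with residual
            return pooled[:, 0]                                  # [batch, hidden]

    print(ProbePoolingSketch()(torch.randn(2, 196, 768)).shape)  # torch.Size([2, 768])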
@@ -198,6 +219,8 @@ class CLIPVision(torch.nn.Module):
         intermediate_size = config_dict["intermediate_size"]
         intermediate_activation = config_dict["hidden_act"]
         model_type = config_dict["model_type"]
+        use_head = config_dict.get("use_head", False)
+        layer_norm_eps = config_dict.get("layer_norm_eps", 1e-6)
 
         self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
         if model_type == "siglip_vision_model":
@@ -208,6 +231,11 @@ class CLIPVision(torch.nn.Module):
             self.output_layernorm = False
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.post_layernorm = operations.LayerNorm(embed_dim)
+        self.use_head = use_head
+        if use_head:
+            self.head = SiglipMultiheadAttentionPoolingHead(
+                hidden_size=embed_dim, num_attention_heads=heads, layer_norm_eps=layer_norm_eps, intermediate_size=intermediate_size, activation=intermediate_activation, device=device, dtype=dtype, operations=operations
+            )
 
     def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
         x = self.embeddings(pixel_values)
@@ -216,7 +244,10 @@ class CLIPVision(torch.nn.Module):
         x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
         if self.output_layernorm:
             x = self.post_layernorm(x)
-            pooled_output = x
+            if self.use_head:
+                pooled_output = self.head(x)
+            else:
+                pooled_output = x
         else:
             pooled_output = self.post_layernorm(x[:, 0, :])
         return x, i, pooled_output
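Taken together with the constructor change above, pooling now branches three ways. A condensed restatement as a plain function, for illustration only (the real logic is in CLIPVision.forward above):

    def pooled_from_tokens(x, output_layernorm, use_head, post_layernorm, head=None):
        # x: [batch, seq, hidden] token features from the vision encoder
        if output_layernorm:                      # siglip-style models normalize the full sequence
            x = post_layernorm(x)
            return head(x) if use_head else x     # attention-pooled vector vs. all tokens
        return post_layernorm(x[:, 0, :])         # CLIP-style models keep the class token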
@@ -10,5 +10,6 @@
   "num_hidden_layers": 12,
   "patch_size": 16,
   "image_mean": [0.5, 0.5, 0.5],
-  "image_std": [0.5, 0.5, 0.5]
+  "image_std": [0.5, 0.5, 0.5],
+  "use_head": true
 }
@@ -780,6 +780,8 @@ class HunyuanVideoFoley(nn.Module):
         self.empty_clip_feat = nn.Parameter(torch.zeros(1, self.visual_in_channels, **factory_kwargs), requires_grad = False)
         self.empty_sync_feat = nn.Parameter(torch.zeros(1, self.sync_feat_dim, **factory_kwargs), requires_grad = False)
 
+        self.conditions = None
+
     def get_empty_clip_sequence(self, bs=None, len=None) -> torch.Tensor:
         len = len if len is not None else self.clip_len
         if bs is None:
@@ -858,28 +860,33 @@ class HunyuanVideoFoley(nn.Module):
         bs, _, ol = x.shape
         tl = ol // self.patch_size
 
-        uncondition, condition = torch.chunk(context, 2)
-
-        condition = condition.view(3, context.size(1) // 3, -1)
-        uncondition = uncondition.view(3, context.size(1) // 3, -1)
-
-        uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
-        clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
-        cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]
-
-        uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
-        uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]
-
-        uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
-
-        diff = cond_pos.shape[1] - cond_neg.shape[1]
-        if cond_neg.shape[1] < cond_pos.shape[1]:
-            cond_neg = torch.nn.functional.pad(cond_neg, (0, 0, 0, diff))
-        elif diff < 0:
-            cond_pos = torch.nn.functional.pad(cond_pos, (0, 0, 0, torch.abs(diff)))
-
-        clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
-        clip_feat = clip_feat.view(2, -1, 768)
+        if self.conditions is None:
+
+            uncondition, condition = torch.chunk(context, 2)
+
+            condition = condition.view(3, context.size(1) // 3, -1)
+            uncondition = uncondition.view(3, context.size(1) // 3, -1)
+
+            uncond_1, uncond_2, cond_neg = torch.chunk(uncondition, 3)
+            clip_feat, sync_feat, cond_pos = torch.chunk(condition, 3)
+            cond_neg, clip_feat, sync_feat, cond_pos = [trim_repeats(t) for t in (cond_neg, clip_feat, sync_feat, cond_pos)]
+
+            uncond_1 = uncond_1[:, :clip_feat.size(1), :clip_feat.size(2)]
+            uncond_2 = uncond_2[:, :sync_feat.size(1), :sync_feat.size(2)]
+
+            uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos = [unlock_cpu_tensor(t, device) for t in (uncond_1, uncond_2, cond_neg, clip_feat, sync_feat, cond_pos)]
+
+            diff = cond_pos.shape[1] - cond_neg.shape[1]
+            if cond_neg.shape[1] < cond_pos.shape[1]:
+                cond_neg = torch.nn.functional.pad(cond_neg, (0, 0, 0, diff))
+            elif diff < 0:
+                cond_pos = torch.nn.functional.pad(cond_pos, (0, 0, 0, torch.abs(diff)))
+
+            clip_feat, sync_feat, cond = torch.cat([uncond_1, clip_feat]), torch.cat([uncond_2, sync_feat]), torch.cat([cond_neg, cond_pos])
+            clip_feat = clip_feat.view(2, -1, 768)
+
+        else:
+            clip_feat, sync_feat, cond = self.conditions
 
         if drop_visual is not None:
             clip_feat[drop_visual] = self.get_empty_clip_sequence().to(dtype=clip_feat.dtype)
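This hunk wraps the conditioning preparation in `if self.conditions is None:` so the chunking, trimming, padding and device transfers run once and later calls reuse the cached `self.conditions` tuple. The padding step is easy to misread, so here is a small self-contained check (not commit code) of how `torch.nn.functional.pad(t, (0, 0, 0, diff))` right-pads the sequence axis of a `[batch, seq, dim]` tensor:

    import torch
    import torch.nn.functional as F

    cond_neg = torch.randn(1, 7, 16)     # shorter negative conditioning
    cond_pos = torch.randn(1, 12, 16)

    diff = cond_pos.shape[1] - cond_neg.shape[1]
    if diff > 0:
        # pad spec is (last-dim left, last-dim right, seq left, seq right)
        cond_neg = F.pad(cond_neg, (0, 0, 0, diff))
    elif diff < 0:
        cond_pos = F.pad(cond_pos, (0, 0, 0, -diff))

    print(cond_neg.shape, cond_pos.shape)    # both torch.Size([1, 12, 16])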
@@ -1,6 +1,5 @@
import math
import torch
import numpy as np
from typing import List
import torch.nn as nn
from einops import rearrange
@@ -88,15 +87,17 @@ class DACEncoder(nn.Module):
        device = None, dtype = None, operations = None
    ):
        super().__init__()
        # Create first convolution
        self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3, device = device, dtype = dtype, operations = operations)]

        # Create EncoderBlocks that double channels as they downsample by `stride`
        for stride in strides:
            d_model *= 2
            self.block += [DACEncoderBlock(d_model, stride=stride, device = device, dtype = dtype, operations = operations)]

        # Wrap block into nn.Sequential
        self.block += [
            Snake1d(d_model, device=device, dtype=dtype),
            WNConv1d(d_model, d_latent, kernel_size=3, padding=1, device=device, dtype=dtype, operations = operations),
        ]

        self.block = nn.Sequential(*self.block)
        self.enc_dim = d_model
@@ -145,6 +146,12 @@ class DACDecoder(nn.Module):
             output_dim = channels // 2 ** (i + 1)
             layers += [DACDecoderBlock(input_dim, output_dim, stride, device = device, dtype = dtype, operations = operations)]
 
+        layers += [
+            Snake1d(output_dim, device=device, dtype=dtype),
+            WNConv1d(output_dim, d_out, kernel_size=7, padding=3, device=device, dtype=dtype, operations = operations),
+            nn.Tanh(),
+        ]
+
         self.model = nn.Sequential(*layers)
 
     def forward(self, x):
@@ -154,11 +161,11 @@ class DAC(torch.nn.Module):
     def __init__(
         self,
         encoder_dim: int = 128,
-        encoder_rates: List[int] = [2, 3, 4, 5],
+        encoder_rates: List[int] = [2, 3, 4, 5, 8],
         latent_dim: int = 128,
         decoder_dim: int = 2048,
-        decoder_rates: List[int] = [8, 5, 4, 3],
-        sample_rate: int = 44100,
+        decoder_rates: List[int] = [8, 5, 4, 3, 2],
+        sample_rate: int = 48000,
     ):
         super().__init__()
 
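The new defaults change the codec's temporal compression: the total hop length of the encoder is the product of its strides (as the `np.prod(encoder_rates)` line below computes), so one latent step now covers 20 ms of 48 kHz audio instead of roughly 2.7 ms of 44.1 kHz audio. Quick arithmetic, not commit code:

    import math

    old_hop = math.prod([2, 3, 4, 5])        # 120 samples; 44100 / 120 = 367.5 latent steps per second
    new_hop = math.prod([2, 3, 4, 5, 8])     # 960 samples; 48000 / 960 = 50.0 latent steps per second
    print(old_hop, new_hop)                  # 120 960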
@@ -173,7 +180,6 @@ class DAC(torch.nn.Module):

        self.latent_dim = latent_dim

        self.hop_length = np.prod(encoder_rates)
        self.encoder = DACEncoder(encoder_dim, encoder_rates, latent_dim, operations = ops)

        self.decoder = DACDecoder(
@@ -184,8 +190,10 @@ class DAC(torch.nn.Module):
         )
         self.sample_rate = sample_rate
 
+        self.post_quant_conv = ops.Conv1d(latent_dim, latent_dim, 1)
 
     def decode(self, z: torch.Tensor):
+        z = self.post_quant_conv(z)
         return self.decoder(z)
 
     def forward(self):
@@ -205,17 +213,14 @@ class FoleyVae(torch.nn.Module):
                 v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
             ]
         )
+        self.decode_sample_rate = self.dac.sample_rate
 
     def decode(self, x, vae_options = {}):
         return self.dac.decode(x)
-    def encode(self, x):
-        return self.synchformer(x)
-
-    def forward(self, x):
-        try:
-            return self.encode(x)
-        except:
-            x = x.to(next(self.parameters()).device)
-            return self.encode(x)
+    def encode(self, x):
+        x = x.to(next(self.parameters()).device)
+        return self.synchformer(x)
 
     def video_encoding(self, video, step):
         video = torch.stack([self.syncformer_preprocess(t) for t in video])
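FoleyVae also drops the try/except device fallback that used to live in forward and instead always moves the input onto the module's device inside encode, and it exposes decode_sample_rate for the audio decode node below. A tiny sketch of the device-handling pattern (class and layer names are made up for illustration); explicitly moving the input is cheaper and clearer than catching a device-mismatch RuntimeError and retrying:

    import torch
    import torch.nn as nn

    class DeviceSafeEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 8)

        def encode(self, x):
            # same trick as the updated FoleyVae.encode: follow the weights, wherever they live
            x = x.to(next(self.parameters()).device)
            return self.proj(x)

    print(DeviceSafeEncoder().encode(torch.randn(4, 16)).shape)   # torch.Size([4, 8])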
@@ -88,7 +88,8 @@ class VAEDecodeAudio:
         std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
         std[std < 1.0] = 1.0
         audio /= std
-        return ({"waveform": audio, "sample_rate": 44100}, )
+        sample_rate = vae.first_stage_model.decode_sample_rate or 44100
+        return ({"waveform": audio, "sample_rate": sample_rate}, )
 
 
     def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
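VAEDecodeAudio now reports the VAE's own decode sample rate and only falls back to 44100 when none is set, so the 48 kHz Foley DAC above gets labeled correctly. Assuming the usual ComfyUI AUDIO layout of a [batch, channels, samples] waveform plus a sample_rate field, a downstream consumer outside ComfyUI might look like this (torchaudio is used purely for illustration):

    import torch
    import torchaudio

    audio = {"waveform": torch.zeros(1, 2, 48000), "sample_rate": 48000}   # one second of silent stereo
    torchaudio.save("out.flac", audio["waveform"][0], audio["sample_rate"])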
@@ -11,7 +11,7 @@ class EmptyLatentHunyuanFoley(io.ComfyNode):
             display_name="EmptyLatentHunyuanFoley",
             category="audio/latent",
             inputs = [
-                io.Int.Input("length", min = 1, max = 15, default = 12),
+                io.Float.Input("length", min = 1.0, max = 15.0, default = 12.0),
                 io.Int.Input("batch_size", min = 1, max = 48_000, default = 1),
                 io.Video.Input("video", optional=True),
             ],
@@ -72,7 +72,7 @@ class EncodeVideo(io.ComfyNode):
        model = vae.first_stage_model if vae is not None else clip_vision.model
        vae = vae if vae is not None else clip_vision

        # should be the offload device

        if hasattr(model, "video_encoding"):
            data, num_segments, output_fn = model.video_encoding(video, step_size)
            batch_size = b * num_segments
@@ -95,7 +95,7 @@ class EncodeVideo(io.ComfyNode):
                 try:
                     out = vae.encode(chunk)
                 except:
-                    out = model(chunk)
+                    out = model.encode(chunk)
             else:
                 out = vae.encode_image(chunk)
                 out = out["image_embeds"]
@@ -103,6 +103,7 @@ class EncodeVideo(io.ComfyNode):
             out_cpu = out.cpu()
             if outputs is None:
                 full_shape = (total, *out_cpu.shape[1:])
+                # should be the offload device
                 outputs = torch.empty(full_shape, dtype=out_cpu.dtype, pin_memory=True)
 
             chunk_len = out_cpu.shape[0]
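EncodeVideo accumulates the per-chunk results in a single CPU buffer allocated with pin_memory=True and sized from the first chunk, which keeps GPU memory flat and speeds up any later host-to-device copies. A hedged sketch of that pattern with a toy encoder (the function and variable names are not from the commit):

    import torch

    def encode_in_chunks(frames, encode, chunk_size=8):
        # frames: [N, ...]; encode: callable returning [n, D] features for a chunk
        outputs = None
        for start in range(0, frames.shape[0], chunk_size):
            out_cpu = encode(frames[start:start + chunk_size]).cpu()
            if outputs is None:
                # one pinned CPU buffer for everything, shaped from the first chunk
                outputs = torch.empty((frames.shape[0], *out_cpu.shape[1:]),
                                      dtype=out_cpu.dtype, pin_memory=True)
            outputs[start:start + out_cpu.shape[0]] = out_cpu
        return outputs

    proj = torch.randn(16, 4)
    feats = encode_in_chunks(torch.randn(20, 16), lambda x: x @ proj)
    print(feats.shape, feats.is_pinned())    # torch.Size([20, 4]) True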
@@ -141,7 +142,7 @@ class ResampleVideo(io.ComfyNode):
         if src_fps is None or target_fps > src_fps:
             for packet in container.demux(stream):
                 for frame in packet.decode():
-                    arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
+                    arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float()
                     frames.append(arr)
             return io.NodeOutput(torch.stack(frames))
 
@@ -156,7 +157,7 @@ class ResampleVideo(io.ComfyNode):
                 continue
             t = frame.time
             while t >= next_time:
-                arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float() / 255.0
+                arr = torch.from_numpy(frame.to_ndarray(format="rgb24")).float()
                 frames.append(arr)
                 next_time += step
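Both decode paths in ResampleVideo now keep frames in the 0-255 float range rather than normalizing to 0-1, leaving any scaling to whatever consumes the frames (the Foley preprocessing applies its own Normalize). The node appears to decode with PyAV, so here is a hedged standalone sketch of reading RGB frames the same way:

    import av
    import torch

    def load_frames(path):
        frames = []
        with av.open(path) as container:
            stream = container.streams.video[0]
            for packet in container.demux(stream):
                for frame in packet.decode():
                    # values stay in 0-255; divide by 255.0 if a 0-1 range is needed
                    frames.append(torch.from_numpy(frame.to_ndarray(format="rgb24")).float())
        return torch.stack(frames)    # [num_frames, H, W, 3]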