mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-14 20:42:31 +08:00
Initial CogVideoX and SparkVSR support
This commit is contained in:
parent
5410ed34f5
commit
3e4d15a4f7
@ -783,3 +783,10 @@ class ZImagePixelSpace(ChromaRadiance):
|
||||
No VAE encoding/decoding — the model operates directly on RGB pixels.
|
||||
"""
|
||||
pass
|
||||
|
||||
class CogVideoX(LatentFormat):
|
||||
latent_channels = 16
|
||||
latent_dimensions = 3
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.15258426
|
||||
|
||||
0
comfy/ldm/cogvideo/__init__.py
Normal file
0
comfy/ldm/cogvideo/__init__.py
Normal file
571
comfy/ldm/cogvideo/model.py
Normal file
571
comfy/ldm/cogvideo/model.py
Normal file
@ -0,0 +1,571 @@
|
||||
# CogVideoX 3D Transformer - ported to ComfyUI native ops
|
||||
# Architecture reference: diffusers CogVideoXTransformer3DModel
|
||||
# Style reference: comfy/ldm/wan/model.py
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.patcher_extension
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
def _get_1d_rotary_pos_embed(dim, pos, theta=10000.0):
|
||||
"""Returns (cos, sin) each with shape [seq_len, dim].
|
||||
|
||||
Frequencies are computed at dim//2 resolution then repeat_interleaved
|
||||
to full dim, matching CogVideoX's interleaved (real, imag) pair format.
|
||||
"""
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim))
|
||||
angles = torch.outer(pos.float(), freqs.float())
|
||||
cos = angles.cos().repeat_interleave(2, dim=-1).float()
|
||||
sin = angles.sin().repeat_interleave(2, dim=-1).float()
|
||||
return (cos, sin)
|
||||
|
||||
|
||||
def apply_rotary_emb(x, freqs_cos_sin):
|
||||
"""Apply CogVideoX rotary embedding to query or key tensor.
|
||||
|
||||
x: [B, heads, seq_len, head_dim]
|
||||
freqs_cos_sin: (cos, sin) each [seq_len, head_dim//2]
|
||||
|
||||
Uses interleaved pair rotation (same as diffusers CogVideoX/Flux).
|
||||
head_dim is reshaped to (-1, 2) pairs, rotated, then flattened back.
|
||||
"""
|
||||
cos, sin = freqs_cos_sin
|
||||
cos = cos[None, None, :, :].to(x.device)
|
||||
sin = sin[None, None, :, :].to(x.device)
|
||||
|
||||
# Interleaved pairs: [B, H, S, D] -> [B, H, S, D//2, 2] -> (real, imag)
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half)
|
||||
args = timesteps[:, None].float() * freqs[None] * scale
|
||||
embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
|
||||
if flip_sin_to_cos:
|
||||
embedding = torch.cat([embedding[:, half:], embedding[:, :half]], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
|
||||
def get_3d_sincos_pos_embed(embed_dim, spatial_size, temporal_size, spatial_interpolation_scale=1.0, temporal_interpolation_scale=1.0, device=None):
|
||||
if isinstance(spatial_size, int):
|
||||
spatial_size = (spatial_size, spatial_size)
|
||||
|
||||
grid_w = torch.arange(spatial_size[0], dtype=torch.float32, device=device) / spatial_interpolation_scale
|
||||
grid_h = torch.arange(spatial_size[1], dtype=torch.float32, device=device) / spatial_interpolation_scale
|
||||
grid_t = torch.arange(temporal_size, dtype=torch.float32, device=device) / temporal_interpolation_scale
|
||||
|
||||
grid_t, grid_h, grid_w = torch.meshgrid(grid_t, grid_h, grid_w, indexing="ij")
|
||||
|
||||
embed_dim_spatial = 2 * (embed_dim // 3)
|
||||
embed_dim_temporal = embed_dim // 3
|
||||
|
||||
pos_embed_spatial = _get_2d_sincos_pos_embed(embed_dim_spatial, grid_h, grid_w, device=device)
|
||||
pos_embed_temporal = _get_1d_sincos_pos_embed(embed_dim_temporal, grid_t[:, 0, 0], device=device)
|
||||
|
||||
T, H, W = grid_t.shape
|
||||
pos_embed_temporal = pos_embed_temporal.unsqueeze(1).unsqueeze(1).expand(-1, H, W, -1)
|
||||
pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=-1)
|
||||
|
||||
return pos_embed
|
||||
|
||||
|
||||
def _get_2d_sincos_pos_embed(embed_dim, grid_h, grid_w, device=None):
|
||||
T, H, W = grid_h.shape
|
||||
half_dim = embed_dim // 2
|
||||
pos_h = _get_1d_sincos_pos_embed(half_dim, grid_h.reshape(-1), device=device).reshape(T, H, W, half_dim)
|
||||
pos_w = _get_1d_sincos_pos_embed(half_dim, grid_w.reshape(-1), device=device).reshape(T, H, W, half_dim)
|
||||
return torch.cat([pos_h, pos_w], dim=-1)
|
||||
|
||||
|
||||
def _get_1d_sincos_pos_embed(embed_dim, pos, device=None):
|
||||
half = embed_dim // 2
|
||||
freqs = torch.exp(-math.log(10000.0) * torch.arange(start=0, end=half, dtype=torch.float32, device=device) / half)
|
||||
args = pos.float().reshape(-1)[:, None] * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if embed_dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
|
||||
|
||||
class CogVideoXPatchEmbed(nn.Module):
|
||||
def __init__(self, patch_size=2, patch_size_t=None, in_channels=16, dim=1920,
|
||||
text_dim=4096, bias=True, sample_width=90, sample_height=60,
|
||||
sample_frames=49, temporal_compression_ratio=4,
|
||||
max_text_seq_length=226, spatial_interpolation_scale=1.875,
|
||||
temporal_interpolation_scale=1.0, use_positional_embeddings=True,
|
||||
use_learned_positional_embeddings=True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
self.dim = dim
|
||||
self.sample_height = sample_height
|
||||
self.sample_width = sample_width
|
||||
self.sample_frames = sample_frames
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
self.max_text_seq_length = max_text_seq_length
|
||||
self.spatial_interpolation_scale = spatial_interpolation_scale
|
||||
self.temporal_interpolation_scale = temporal_interpolation_scale
|
||||
self.use_positional_embeddings = use_positional_embeddings
|
||||
self.use_learned_positional_embeddings = use_learned_positional_embeddings
|
||||
|
||||
if patch_size_t is None:
|
||||
self.proj = operations.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=bias, device=device, dtype=dtype)
|
||||
else:
|
||||
self.proj = operations.Linear(in_channels * patch_size * patch_size * patch_size_t, dim, device=device, dtype=dtype)
|
||||
|
||||
self.text_proj = operations.Linear(text_dim, dim, device=device, dtype=dtype)
|
||||
|
||||
if use_positional_embeddings or use_learned_positional_embeddings:
|
||||
persistent = use_learned_positional_embeddings
|
||||
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
|
||||
self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
|
||||
|
||||
def _get_positional_embeddings(self, sample_height, sample_width, sample_frames, device=None):
|
||||
post_patch_height = sample_height // self.patch_size
|
||||
post_patch_width = sample_width // self.patch_size
|
||||
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
|
||||
if self.patch_size_t is not None:
|
||||
post_time_compression_frames = post_time_compression_frames // self.patch_size_t
|
||||
num_patches = post_patch_height * post_patch_width * post_time_compression_frames
|
||||
|
||||
pos_embedding = get_3d_sincos_pos_embed(
|
||||
self.dim,
|
||||
(post_patch_width, post_patch_height),
|
||||
post_time_compression_frames,
|
||||
self.spatial_interpolation_scale,
|
||||
self.temporal_interpolation_scale,
|
||||
device=device,
|
||||
)
|
||||
pos_embedding = pos_embedding.reshape(-1, self.dim)
|
||||
joint_pos_embedding = pos_embedding.new_zeros(
|
||||
1, self.max_text_seq_length + num_patches, self.dim, requires_grad=False
|
||||
)
|
||||
joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding)
|
||||
return joint_pos_embedding
|
||||
|
||||
def forward(self, text_embeds, image_embeds):
|
||||
text_embeds = self.text_proj(text_embeds)
|
||||
batch_size, num_frames, channels, height, width = image_embeds.shape
|
||||
|
||||
if self.patch_size_t is None:
|
||||
image_embeds = image_embeds.reshape(-1, channels, height, width)
|
||||
image_embeds = self.proj(image_embeds)
|
||||
image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
|
||||
image_embeds = image_embeds.flatten(3).transpose(2, 3)
|
||||
image_embeds = image_embeds.flatten(1, 2)
|
||||
else:
|
||||
p = self.patch_size
|
||||
p_t = self.patch_size_t
|
||||
image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
|
||||
image_embeds = image_embeds.reshape(
|
||||
batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
|
||||
)
|
||||
image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
|
||||
image_embeds = self.proj(image_embeds)
|
||||
|
||||
embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous()
|
||||
|
||||
if self.use_positional_embeddings or self.use_learned_positional_embeddings:
|
||||
text_seq_length = text_embeds.shape[1]
|
||||
num_image_patches = image_embeds.shape[1]
|
||||
|
||||
# Compute sincos pos embedding for image patches
|
||||
pos_embedding = get_3d_sincos_pos_embed(
|
||||
self.dim,
|
||||
(width // self.patch_size, height // self.patch_size),
|
||||
num_image_patches // ((height // self.patch_size) * (width // self.patch_size)),
|
||||
self.spatial_interpolation_scale,
|
||||
self.temporal_interpolation_scale,
|
||||
device=embeds.device,
|
||||
).reshape(-1, self.dim)
|
||||
|
||||
# Build joint: zeros for text + sincos for image
|
||||
joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype)
|
||||
joint_pos[:, text_seq_length:] = pos_embedding.to(dtype=embeds.dtype)
|
||||
embeds = embeds + joint_pos
|
||||
|
||||
return embeds
|
||||
|
||||
|
||||
class CogVideoXLayerNormZero(nn.Module):
|
||||
def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, bias=True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(time_dim, 6 * dim, bias=bias, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, temb):
|
||||
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
|
||||
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
|
||||
|
||||
|
||||
class CogVideoXAdaLayerNorm(nn.Module):
|
||||
def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = operations.Linear(time_dim, 2 * dim, device=device, dtype=dtype)
|
||||
self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, temb):
|
||||
temb = self.linear(self.silu(temb))
|
||||
shift, scale = temb.chunk(2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
|
||||
|
||||
|
||||
class CogVideoXBlock(nn.Module):
|
||||
def __init__(self, dim, num_heads, head_dim, time_dim,
|
||||
eps=1e-5, ff_inner_dim=None, ff_bias=True,
|
||||
device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.norm1 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
# Self-attention (joint text + latent)
|
||||
self.q = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
self.k = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
self.v = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
self.norm_q = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype)
|
||||
self.norm_k = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype)
|
||||
self.attn_out = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
self.norm2 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
# Feed-forward (GELU approximate)
|
||||
inner_dim = ff_inner_dim or dim * 4
|
||||
self.ff_proj = operations.Linear(dim, inner_dim, bias=ff_bias, device=device, dtype=dtype)
|
||||
self.ff_out = operations.Linear(inner_dim, dim, bias=ff_bias, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None, transformer_options={}):
|
||||
text_seq_length = encoder_hidden_states.size(1)
|
||||
|
||||
# Norm & modulate
|
||||
norm_hidden, norm_encoder, gate_msa, enc_gate_msa = self.norm1(hidden_states, encoder_hidden_states, temb)
|
||||
|
||||
# Joint self-attention
|
||||
qkv_input = torch.cat([norm_encoder, norm_hidden], dim=1)
|
||||
b, s, _ = qkv_input.shape
|
||||
n, d = self.num_heads, self.head_dim
|
||||
|
||||
q = self.q(qkv_input).view(b, s, n, d)
|
||||
k = self.k(qkv_input).view(b, s, n, d)
|
||||
v = self.v(qkv_input)
|
||||
|
||||
q = self.norm_q(q).view(b, s, n, d)
|
||||
k = self.norm_k(k).view(b, s, n, d)
|
||||
|
||||
# Apply rotary embeddings to image tokens only (diffusers format: [B, heads, seq, head_dim])
|
||||
if image_rotary_emb is not None:
|
||||
q_img = q[:, text_seq_length:].transpose(1, 2) # [B, heads, img_seq, head_dim]
|
||||
k_img = k[:, text_seq_length:].transpose(1, 2)
|
||||
q_img = apply_rotary_emb(q_img, image_rotary_emb)
|
||||
k_img = apply_rotary_emb(k_img, image_rotary_emb)
|
||||
q = torch.cat([q[:, :text_seq_length], q_img.transpose(1, 2)], dim=1)
|
||||
k = torch.cat([k[:, :text_seq_length], k_img.transpose(1, 2)], dim=1)
|
||||
|
||||
attn_out = optimized_attention(
|
||||
q.reshape(b, s, n * d),
|
||||
k.reshape(b, s, n * d),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
attn_out = self.attn_out(attn_out)
|
||||
|
||||
attn_encoder, attn_hidden = attn_out.split([text_seq_length, s - text_seq_length], dim=1)
|
||||
|
||||
hidden_states = hidden_states + gate_msa * attn_hidden
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder
|
||||
|
||||
# Norm & modulate for FF
|
||||
norm_hidden, norm_encoder, gate_ff, enc_gate_ff = self.norm2(hidden_states, encoder_hidden_states, temb)
|
||||
|
||||
# Feed-forward (GELU on concatenated text + latent)
|
||||
ff_input = torch.cat([norm_encoder, norm_hidden], dim=1)
|
||||
ff_output = self.ff_out(F.gelu(self.ff_proj(ff_input), approximate="tanh"))
|
||||
|
||||
hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
|
||||
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class CogVideoXTransformer3DModel(nn.Module):
|
||||
def __init__(self,
|
||||
num_attention_heads=30,
|
||||
attention_head_dim=64,
|
||||
in_channels=16,
|
||||
out_channels=16,
|
||||
flip_sin_to_cos=True,
|
||||
freq_shift=0,
|
||||
time_embed_dim=512,
|
||||
ofs_embed_dim=None,
|
||||
text_embed_dim=4096,
|
||||
num_layers=30,
|
||||
dropout=0.0,
|
||||
attention_bias=True,
|
||||
sample_width=90,
|
||||
sample_height=60,
|
||||
sample_frames=49,
|
||||
patch_size=2,
|
||||
patch_size_t=None,
|
||||
temporal_compression_ratio=4,
|
||||
max_text_seq_length=226,
|
||||
spatial_interpolation_scale=1.875,
|
||||
temporal_interpolation_scale=1.0,
|
||||
use_rotary_positional_embeddings=False,
|
||||
use_learned_positional_embeddings=False,
|
||||
patch_bias=True,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
dim = num_attention_heads * attention_head_dim
|
||||
self.dim = dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
self.max_text_seq_length = max_text_seq_length
|
||||
self.use_rotary_positional_embeddings = use_rotary_positional_embeddings
|
||||
|
||||
# 1. Patch embedding
|
||||
self.patch_embed = CogVideoXPatchEmbed(
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
in_channels=in_channels,
|
||||
dim=dim,
|
||||
text_dim=text_embed_dim,
|
||||
bias=patch_bias,
|
||||
sample_width=sample_width,
|
||||
sample_height=sample_height,
|
||||
sample_frames=sample_frames,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
max_text_seq_length=max_text_seq_length,
|
||||
spatial_interpolation_scale=spatial_interpolation_scale,
|
||||
temporal_interpolation_scale=temporal_interpolation_scale,
|
||||
use_positional_embeddings=not use_rotary_positional_embeddings,
|
||||
use_learned_positional_embeddings=use_learned_positional_embeddings,
|
||||
device=device, dtype=torch.float32, operations=operations,
|
||||
)
|
||||
|
||||
# 2. Time embedding
|
||||
self.time_proj_dim = dim
|
||||
self.time_proj_flip = flip_sin_to_cos
|
||||
self.time_proj_shift = freq_shift
|
||||
self.time_embedding_linear_1 = operations.Linear(dim, time_embed_dim, device=device, dtype=dtype)
|
||||
self.time_embedding_act = nn.SiLU()
|
||||
self.time_embedding_linear_2 = operations.Linear(time_embed_dim, time_embed_dim, device=device, dtype=dtype)
|
||||
|
||||
# Optional OFS embedding (CogVideoX 1.5 I2V)
|
||||
self.ofs_proj_dim = ofs_embed_dim
|
||||
if ofs_embed_dim:
|
||||
self.ofs_embedding_linear_1 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype)
|
||||
self.ofs_embedding_act = nn.SiLU()
|
||||
self.ofs_embedding_linear_2 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype)
|
||||
else:
|
||||
self.ofs_embedding_linear_1 = None
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
CogVideoXBlock(
|
||||
dim=dim,
|
||||
num_heads=num_attention_heads,
|
||||
head_dim=attention_head_dim,
|
||||
time_dim=time_embed_dim,
|
||||
eps=1e-5,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_final = operations.LayerNorm(dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype)
|
||||
|
||||
# 4. Output
|
||||
self.norm_out = CogVideoXAdaLayerNorm(
|
||||
time_dim=time_embed_dim, dim=dim, eps=1e-5,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
if patch_size_t is None:
|
||||
output_dim = patch_size * patch_size * out_channels
|
||||
else:
|
||||
output_dim = patch_size * patch_size * patch_size_t * out_channels
|
||||
|
||||
self.proj_out = operations.Linear(dim, output_dim, device=device, dtype=dtype)
|
||||
|
||||
self.spatial_interpolation_scale = spatial_interpolation_scale
|
||||
self.temporal_interpolation_scale = temporal_interpolation_scale
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
|
||||
def forward(self, x, timestep, context, ofs=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, ofs, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, ofs=None, transformer_options={}, **kwargs):
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
if x.shape[1] > 16:
|
||||
lq_part = x[:, :16]
|
||||
ref_part = x[:, 16:]
|
||||
logger.warning(f"[CogVideoX] x: {x.shape}, t: {timestep.item():.0f}, ofs: {ofs}, LQ: mean={lq_part.float().mean():.4f} std={lq_part.float().std():.4f}, REF: mean={ref_part.float().mean():.4f} std={ref_part.float().std():.4f} nonzero={ref_part.count_nonzero().item()}")
|
||||
else:
|
||||
logger.warning(f"[CogVideoX] x: {x.shape}, t: {timestep.item():.0f}, ofs: {ofs}")
|
||||
|
||||
# ComfyUI passes [B, C, T, H, W]
|
||||
batch_size, channels, t, h, w = x.shape
|
||||
|
||||
# Pad to patch size (temporal + spatial), same pattern as WAN
|
||||
p_t = self.patch_size_t if self.patch_size_t is not None else 1
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (p_t, self.patch_size, self.patch_size))
|
||||
|
||||
# CogVideoX expects [B, T, C, H, W]
|
||||
x = x.permute(0, 2, 1, 3, 4)
|
||||
batch_size, num_frames, channels, height, width = x.shape
|
||||
|
||||
# Time embedding
|
||||
t_emb = get_timestep_embedding(timestep, self.time_proj_dim, self.time_proj_flip, self.time_proj_shift)
|
||||
t_emb = t_emb.to(dtype=x.dtype)
|
||||
emb = self.time_embedding_linear_2(self.time_embedding_act(self.time_embedding_linear_1(t_emb)))
|
||||
|
||||
if self.ofs_embedding_linear_1 is not None and ofs is not None:
|
||||
ofs_emb = get_timestep_embedding(ofs, self.ofs_proj_dim, self.time_proj_flip, self.time_proj_shift)
|
||||
ofs_emb = ofs_emb.to(dtype=x.dtype)
|
||||
ofs_emb = self.ofs_embedding_linear_2(self.ofs_embedding_act(self.ofs_embedding_linear_1(ofs_emb)))
|
||||
emb = emb + ofs_emb
|
||||
|
||||
# Patch embedding
|
||||
hidden_states = self.patch_embed(context, x)
|
||||
|
||||
text_seq_length = context.shape[1]
|
||||
encoder_hidden_states = hidden_states[:, :text_seq_length]
|
||||
hidden_states = hidden_states[:, text_seq_length:]
|
||||
|
||||
# Rotary embeddings (if used)
|
||||
image_rotary_emb = None
|
||||
if self.use_rotary_positional_embeddings:
|
||||
post_patch_height = height // self.patch_size
|
||||
post_patch_width = width // self.patch_size
|
||||
if self.patch_size_t is None:
|
||||
post_time = num_frames
|
||||
else:
|
||||
post_time = num_frames // self.patch_size_t
|
||||
image_rotary_emb = self._get_rotary_emb(post_patch_height, post_patch_width, post_time, device=x.device)
|
||||
|
||||
# Transformer blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=emb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
|
||||
# Output projection
|
||||
hidden_states = self.norm_out(hidden_states, temb=emb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# Unpatchify
|
||||
p = self.patch_size
|
||||
p_t = self.patch_size_t
|
||||
|
||||
if p_t is None:
|
||||
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
|
||||
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
|
||||
else:
|
||||
output = hidden_states.reshape(
|
||||
batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
|
||||
)
|
||||
output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
|
||||
|
||||
# Back to ComfyUI format [B, C, T, H, W] and crop padding
|
||||
output = output.permute(0, 2, 1, 3, 4)[:, :, :t, :h, :w]
|
||||
logger.warning(f"[CogVideoX] output: {output.shape}, mean={output.float().mean():.4f}, std={output.float().std():.4f}, min={output.float().min():.4f}, max={output.float().max():.4f}")
|
||||
return output
|
||||
|
||||
def _get_rotary_emb(self, h, w, t, device):
|
||||
"""Compute CogVideoX 3D rotary positional embeddings.
|
||||
|
||||
For CogVideoX 1.5 (patch_size_t != None): uses "slice" mode — grid positions
|
||||
are integer arange computed at max_size, then sliced to actual size.
|
||||
For CogVideoX 1.0 (patch_size_t == None): uses "linspace" mode with crop coords
|
||||
scaled by spatial_interpolation_scale.
|
||||
"""
|
||||
d = self.attention_head_dim
|
||||
dim_t = d // 4
|
||||
dim_h = d // 8 * 3
|
||||
dim_w = d // 8 * 3
|
||||
|
||||
if self.patch_size_t is not None:
|
||||
# CogVideoX 1.5: "slice" mode — positions are simple integer indices
|
||||
# Compute at max(sample_size, actual_size) then slice to actual
|
||||
base_h = self.patch_embed.sample_height // self.patch_size
|
||||
base_w = self.patch_embed.sample_width // self.patch_size
|
||||
max_h = max(base_h, h)
|
||||
max_w = max(base_w, w)
|
||||
|
||||
grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
|
||||
grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
|
||||
grid_t = torch.arange(t, device=device, dtype=torch.float32)
|
||||
else:
|
||||
# CogVideoX 1.0: "linspace" mode with interpolation scale
|
||||
grid_h = torch.linspace(0, h - 1, h, device=device, dtype=torch.float32) * self.spatial_interpolation_scale
|
||||
grid_w = torch.linspace(0, w - 1, w, device=device, dtype=torch.float32) * self.spatial_interpolation_scale
|
||||
grid_t = torch.arange(t, device=device, dtype=torch.float32)
|
||||
|
||||
freqs_t = _get_1d_rotary_pos_embed(dim_t, grid_t)
|
||||
freqs_h = _get_1d_rotary_pos_embed(dim_h, grid_h)
|
||||
freqs_w = _get_1d_rotary_pos_embed(dim_w, grid_w)
|
||||
|
||||
t_cos, t_sin = freqs_t
|
||||
h_cos, h_sin = freqs_h
|
||||
w_cos, w_sin = freqs_w
|
||||
|
||||
# Slice to actual size (for "slice" mode where grids may be larger)
|
||||
t_cos, t_sin = t_cos[:t], t_sin[:t]
|
||||
h_cos, h_sin = h_cos[:h], h_sin[:h]
|
||||
w_cos, w_sin = w_cos[:w], w_sin[:w]
|
||||
|
||||
# Broadcast and concatenate into [T*H*W, head_dim]
|
||||
t_cos = t_cos[:, None, None, :].expand(-1, h, w, -1)
|
||||
t_sin = t_sin[:, None, None, :].expand(-1, h, w, -1)
|
||||
h_cos = h_cos[None, :, None, :].expand(t, -1, w, -1)
|
||||
h_sin = h_sin[None, :, None, :].expand(t, -1, w, -1)
|
||||
w_cos = w_cos[None, None, :, :].expand(t, h, -1, -1)
|
||||
w_sin = w_sin[None, None, :, :].expand(t, h, -1, -1)
|
||||
|
||||
cos = torch.cat([t_cos, h_cos, w_cos], dim=-1).reshape(t * h * w, -1)
|
||||
sin = torch.cat([t_sin, h_sin, w_sin], dim=-1).reshape(t * h * w, -1)
|
||||
return (cos, sin)
|
||||
570
comfy/ldm/cogvideo/vae.py
Normal file
570
comfy/ldm/cogvideo/vae.py
Normal file
@ -0,0 +1,570 @@
|
||||
# CogVideoX VAE - ported to ComfyUI native ops
|
||||
# Architecture reference: diffusers AutoencoderKLCogVideoX
|
||||
# Style reference: comfy/ldm/wan/vae.py
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
"""Causal 3D convolution with temporal padding.
|
||||
|
||||
Uses comfy.ops.Conv3d with autopad='causal_zero' fast path: when input has
|
||||
a single temporal frame and no cache, the 3D conv weight is sliced to act
|
||||
as a 2D conv, avoiding computation on zero-padded temporal dimensions.
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, pad_mode="constant"):
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size,) * 3
|
||||
|
||||
time_kernel, height_kernel, width_kernel = kernel_size
|
||||
self.time_kernel_size = time_kernel
|
||||
self.pad_mode = pad_mode
|
||||
|
||||
height_pad = (height_kernel - 1) // 2
|
||||
width_pad = (width_kernel - 1) // 2
|
||||
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_kernel - 1, 0)
|
||||
|
||||
stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
|
||||
dilation = (dilation, 1, 1)
|
||||
self.conv = ops.Conv3d(
|
||||
in_channels, out_channels, kernel_size,
|
||||
stride=stride, dilation=dilation,
|
||||
padding=(0, height_pad, width_pad),
|
||||
)
|
||||
|
||||
def forward(self, x, conv_cache=None):
|
||||
if self.pad_mode == "replicate":
|
||||
x = F.pad(x, self.time_causal_padding, mode="replicate")
|
||||
conv_cache = None
|
||||
else:
|
||||
kernel_t = self.time_kernel_size
|
||||
if kernel_t > 1:
|
||||
if conv_cache is None and x.shape[2] == 1:
|
||||
# Fast path: single frame, no cache. All temporal padding
|
||||
# frames are copies of the input (replicate-style), so the
|
||||
# 3D conv reduces to a 2D conv with summed temporal kernel.
|
||||
w = comfy.ops.cast_to_input(self.conv.weight, x)
|
||||
b = comfy.ops.cast_to_input(self.conv.bias, x) if self.conv.bias is not None else None
|
||||
w2d = w.sum(dim=2, keepdim=True)
|
||||
out = F.conv3d(x, w2d, b,
|
||||
self.conv.stride, self.conv.padding,
|
||||
self.conv.dilation, self.conv.groups)
|
||||
return out, None
|
||||
cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_t - 1)
|
||||
x = torch.cat(cached + [x], dim=2)
|
||||
conv_cache = x[:, :, -self.time_kernel_size + 1:].clone() if self.time_kernel_size > 1 else None
|
||||
|
||||
out = self.conv(x)
|
||||
return out, conv_cache
|
||||
|
||||
|
||||
def _interpolate_zq(zq, target_size):
|
||||
"""Interpolate latent z to target (T, H, W), matching CogVideoX's first-frame-special handling."""
|
||||
t = target_size[0]
|
||||
if t > 1 and t % 2 == 1:
|
||||
z_first = F.interpolate(zq[:, :, :1], size=(1, target_size[1], target_size[2]))
|
||||
z_rest = F.interpolate(zq[:, :, 1:], size=(t - 1, target_size[1], target_size[2]))
|
||||
return torch.cat([z_first, z_rest], dim=2)
|
||||
return F.interpolate(zq, size=target_size)
|
||||
|
||||
|
||||
class SpatialNorm3D(nn.Module):
|
||||
"""Spatially conditioned normalization."""
|
||||
def __init__(self, f_channels, zq_channels, groups=32):
|
||||
super().__init__()
|
||||
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
||||
self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
|
||||
def forward(self, f, zq, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
if zq.shape[-3:] != f.shape[-3:]:
|
||||
zq = _interpolate_zq(zq, f.shape[-3:])
|
||||
|
||||
conv_y, new_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
|
||||
conv_b, new_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
|
||||
|
||||
return self.norm_layer(f) * conv_y + conv_b, new_cache
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
"""3D ResNet block with optional spatial norm."""
|
||||
def __init__(self, in_channels, out_channels=None, temb_channels=512, groups=32,
|
||||
eps=1e-6, act_fn="silu", spatial_norm_dim=None, pad_mode="first"):
|
||||
super().__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.spatial_norm_dim = spatial_norm_dim
|
||||
|
||||
if act_fn == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
elif act_fn == "swish":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
else:
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
if spatial_norm_dim is None:
|
||||
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
||||
self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
|
||||
else:
|
||||
self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups)
|
||||
self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups)
|
||||
|
||||
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
||||
|
||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = ops.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv_shortcut = None
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
residual = x
|
||||
|
||||
if zq is not None:
|
||||
x, new_cache["norm1"] = self.norm1(x, zq, conv_cache=conv_cache.get("norm1"))
|
||||
else:
|
||||
x = self.norm1(x)
|
||||
|
||||
x = self.nonlinearity(x)
|
||||
x, new_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1"))
|
||||
|
||||
if temb is not None and hasattr(self, "temb_proj"):
|
||||
x = x + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
if zq is not None:
|
||||
x, new_cache["norm2"] = self.norm2(x, zq, conv_cache=conv_cache.get("norm2"))
|
||||
else:
|
||||
x = self.norm2(x)
|
||||
|
||||
x = self.nonlinearity(x)
|
||||
x, new_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2"))
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
residual = self.conv_shortcut(residual)
|
||||
|
||||
return x + residual, new_cache
|
||||
|
||||
|
||||
class Downsample3D(nn.Module):
|
||||
"""3D downsampling with optional temporal compression."""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
if self.compress_time:
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)
|
||||
if t % 2 == 1:
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
if x_rest.shape[-1] > 0:
|
||||
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
else:
|
||||
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
|
||||
pad = (0, 1, 0, 1)
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = self.conv(x)
|
||||
x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample3D(nn.Module):
|
||||
"""3D upsampling with optional temporal decompression."""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
if self.compress_time:
|
||||
if x.shape[2] > 1 and x.shape[2] % 2 == 1:
|
||||
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
|
||||
x_first = F.interpolate(x_first, scale_factor=2.0)
|
||||
x_rest = F.interpolate(x_rest, scale_factor=2.0)
|
||||
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
|
||||
elif x.shape[2] > 1:
|
||||
x = F.interpolate(x, scale_factor=2.0)
|
||||
else:
|
||||
x = x.squeeze(2)
|
||||
x = F.interpolate(x, scale_factor=2.0)
|
||||
x = x[:, :, None, :, :]
|
||||
else:
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = F.interpolate(x, scale_factor=2.0)
|
||||
x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = self.conv(x)
|
||||
x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
return x
|
||||
|
||||
|
||||
class DownBlock3D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1,
|
||||
eps=1e-6, act_fn="silu", groups=32, add_downsample=True,
|
||||
compress_time=False, pad_mode="first"):
|
||||
super().__init__()
|
||||
self.resnets = nn.ModuleList([
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
groups=groups, eps=eps, act_fn=act_fn, pad_mode=pad_mode,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
self.downsamplers = nn.ModuleList([Downsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_downsample else None
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
|
||||
if self.downsamplers is not None:
|
||||
for ds in self.downsamplers:
|
||||
x = ds(x)
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class MidBlock3D(nn.Module):
|
||||
def __init__(self, in_channels, temb_channels=0, num_layers=1,
|
||||
eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=None, pad_mode="first"):
|
||||
super().__init__()
|
||||
self.resnets = nn.ModuleList([
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels, out_channels=in_channels,
|
||||
temb_channels=temb_channels, groups=groups, eps=eps,
|
||||
act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class UpBlock3D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1,
|
||||
eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=16,
|
||||
add_upsample=True, compress_time=False, pad_mode="first"):
|
||||
super().__init__()
|
||||
self.resnets = nn.ModuleList([
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels, groups=groups, eps=eps,
|
||||
act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_upsample else None
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
|
||||
if self.upsamplers is not None:
|
||||
for us in self.upsamplers:
|
||||
x = us(x)
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class Encoder3D(nn.Module):
|
||||
def __init__(self, in_channels=3, out_channels=16,
|
||||
block_out_channels=(128, 256, 256, 512),
|
||||
layers_per_block=3, act_fn="silu",
|
||||
eps=1e-6, groups=32, pad_mode="first",
|
||||
temporal_compression_ratio=4):
|
||||
super().__init__()
|
||||
temporal_compress_level = int(np.log2(temporal_compression_ratio))
|
||||
|
||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
self.down_blocks = nn.ModuleList()
|
||||
output_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final = i == len(block_out_channels) - 1
|
||||
compress_time = i < temporal_compress_level
|
||||
|
||||
self.down_blocks.append(DownBlock3D(
|
||||
in_channels=input_channel, out_channels=output_channel,
|
||||
temb_channels=0, num_layers=layers_per_block,
|
||||
eps=eps, act_fn=act_fn, groups=groups,
|
||||
add_downsample=not is_final, compress_time=compress_time,
|
||||
))
|
||||
|
||||
self.mid_block = MidBlock3D(
|
||||
in_channels=block_out_channels[-1], temb_channels=0,
|
||||
num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
self.norm_out = nn.GroupNorm(groups, block_out_channels[-1], eps=1e-6)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, x, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
x, new_cache["conv_in"] = self.conv_in(x, conv_cache=conv_cache.get("conv_in"))
|
||||
|
||||
for i, block in enumerate(self.down_blocks):
|
||||
key = f"down_block_{i}"
|
||||
x, new_cache[key] = block(x, None, None, conv_cache.get(key))
|
||||
|
||||
x, new_cache["mid_block"] = self.mid_block(x, None, None, conv_cache=conv_cache.get("mid_block"))
|
||||
|
||||
x = self.norm_out(x)
|
||||
x = self.conv_act(x)
|
||||
x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out"))
|
||||
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class Decoder3D(nn.Module):
|
||||
def __init__(self, in_channels=16, out_channels=3,
|
||||
block_out_channels=(128, 256, 256, 512),
|
||||
layers_per_block=3, act_fn="silu",
|
||||
eps=1e-6, groups=32, pad_mode="first",
|
||||
temporal_compression_ratio=4):
|
||||
super().__init__()
|
||||
reversed_channels = list(reversed(block_out_channels))
|
||||
temporal_compress_level = int(np.log2(temporal_compression_ratio))
|
||||
|
||||
self.conv_in = CausalConv3d(in_channels, reversed_channels[0], kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
self.mid_block = MidBlock3D(
|
||||
in_channels=reversed_channels[0], temb_channels=0,
|
||||
num_layers=2, eps=eps, act_fn=act_fn, groups=groups,
|
||||
spatial_norm_dim=in_channels, pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
self.up_blocks = nn.ModuleList()
|
||||
output_channel = reversed_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
prev_channel = output_channel
|
||||
output_channel = reversed_channels[i]
|
||||
is_final = i == len(block_out_channels) - 1
|
||||
compress_time = i < temporal_compress_level
|
||||
|
||||
self.up_blocks.append(UpBlock3D(
|
||||
in_channels=prev_channel, out_channels=output_channel,
|
||||
temb_channels=0, num_layers=layers_per_block + 1,
|
||||
eps=eps, act_fn=act_fn, groups=groups,
|
||||
spatial_norm_dim=in_channels,
|
||||
add_upsample=not is_final, compress_time=compress_time,
|
||||
))
|
||||
|
||||
self.norm_out = SpatialNorm3D(reversed_channels[-1], in_channels, groups=groups)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(reversed_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, sample, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
x, new_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
||||
|
||||
x, new_cache["mid_block"] = self.mid_block(x, None, sample, conv_cache=conv_cache.get("mid_block"))
|
||||
|
||||
for i, block in enumerate(self.up_blocks):
|
||||
key = f"up_block_{i}"
|
||||
x, new_cache[key] = block(x, None, sample, conv_cache=conv_cache.get(key))
|
||||
|
||||
x, new_cache["norm_out"] = self.norm_out(x, sample, conv_cache=conv_cache.get("norm_out"))
|
||||
x = self.conv_act(x)
|
||||
x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out"))
|
||||
|
||||
return x, new_cache
|
||||
|
||||
|
||||
|
||||
class AutoencoderKLCogVideoX(nn.Module):
|
||||
"""CogVideoX VAE. Spatial tiling/slicing handled by ComfyUI's VAE wrapper.
|
||||
|
||||
Uses rolling temporal decode: conv_in + mid_block + temporal up_blocks run
|
||||
on the full (low-res) tensor, then the expensive spatial-only up_blocks +
|
||||
norm_out + conv_out are processed in small temporal chunks with conv_cache
|
||||
carrying causal state between chunks. This keeps peak VRAM proportional to
|
||||
chunk_size rather than total frame count.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3, out_channels=3,
|
||||
block_out_channels=(128, 256, 256, 512),
|
||||
latent_channels=16, layers_per_block=3,
|
||||
act_fn="silu", eps=1e-6, groups=32,
|
||||
temporal_compression_ratio=4,
|
||||
):
|
||||
super().__init__()
|
||||
self.latent_channels = latent_channels
|
||||
self.temporal_compression_ratio = temporal_compression_ratio
|
||||
|
||||
self.encoder = Encoder3D(
|
||||
in_channels=in_channels, out_channels=latent_channels,
|
||||
block_out_channels=block_out_channels, layers_per_block=layers_per_block,
|
||||
act_fn=act_fn, eps=eps, groups=groups,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
)
|
||||
self.decoder = Decoder3D(
|
||||
in_channels=latent_channels, out_channels=out_channels,
|
||||
block_out_channels=block_out_channels, layers_per_block=layers_per_block,
|
||||
act_fn=act_fn, eps=eps, groups=groups,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
)
|
||||
|
||||
self.num_latent_frames_batch_size = 2
|
||||
self.num_sample_frames_batch_size = 8
|
||||
|
||||
def encode(self, x):
|
||||
t = x.shape[2]
|
||||
frame_batch = self.num_sample_frames_batch_size
|
||||
num_batches = max(t // frame_batch, 1)
|
||||
conv_cache = None
|
||||
enc = []
|
||||
for i in range(num_batches):
|
||||
remaining = t % frame_batch
|
||||
start = frame_batch * i + (0 if i == 0 else remaining)
|
||||
end = frame_batch * (i + 1) + remaining
|
||||
chunk, conv_cache = self.encoder(x[:, :, start:end], conv_cache=conv_cache)
|
||||
enc.append(chunk.to(x.device))
|
||||
enc = torch.cat(enc, dim=2)
|
||||
mean, _ = enc.chunk(2, dim=1)
|
||||
return mean
|
||||
|
||||
def decode(self, z):
|
||||
return self._decode_rolling(z)
|
||||
|
||||
def _decode_batched(self, z):
|
||||
"""Original batched decode - processes 2 latent frames through full decoder."""
|
||||
t = z.shape[2]
|
||||
frame_batch = self.num_latent_frames_batch_size
|
||||
num_batches = max(t // frame_batch, 1)
|
||||
conv_cache = None
|
||||
dec = []
|
||||
for i in range(num_batches):
|
||||
remaining = t % frame_batch
|
||||
start = frame_batch * i + (0 if i == 0 else remaining)
|
||||
end = frame_batch * (i + 1) + remaining
|
||||
chunk, conv_cache = self.decoder(z[:, :, start:end], conv_cache=conv_cache)
|
||||
dec.append(chunk.cpu())
|
||||
return torch.cat(dec, dim=2).to(z.device)
|
||||
|
||||
def _decode_rolling(self, z):
|
||||
"""Rolling decode - processes low-res layers on full tensor, then rolls
|
||||
through expensive high-res layers in temporal chunks."""
|
||||
decoder = self.decoder
|
||||
device = z.device
|
||||
|
||||
# Determine which up_blocks have temporal upsample vs spatial-only.
|
||||
# Temporal up_blocks are cheap (low res), spatial-only are expensive.
|
||||
temporal_compress_level = int(np.log2(self.temporal_compression_ratio))
|
||||
split_at = temporal_compress_level # first N up_blocks do temporal upsample
|
||||
|
||||
# Phase 1: conv_in + mid_block + temporal up_blocks on full tensor (low/medium res)
|
||||
x, _ = decoder.conv_in(z)
|
||||
x, _ = decoder.mid_block(x, None, z)
|
||||
|
||||
for i in range(split_at):
|
||||
x, _ = decoder.up_blocks[i](x, None, z)
|
||||
|
||||
# Phase 2: remaining spatial-only up_blocks + norm_out + conv_out in temporal chunks
|
||||
remaining_blocks = list(range(split_at, len(decoder.up_blocks)))
|
||||
chunk_size = 4 # pixel frames per chunk through high-res layers
|
||||
t_expanded = x.shape[2]
|
||||
|
||||
if t_expanded <= chunk_size or len(remaining_blocks) == 0:
|
||||
# Small enough to process in one go
|
||||
for i in remaining_blocks:
|
||||
x, _ = decoder.up_blocks[i](x, None, z)
|
||||
x, _ = decoder.norm_out(x, z)
|
||||
x = decoder.conv_act(x)
|
||||
x, _ = decoder.conv_out(x)
|
||||
return x
|
||||
|
||||
# Pre-interpolate z to each spatial resolution used by Phase 2 blocks.
|
||||
# Uses the exact same interpolation logic as SpatialNorm3D so chunked
|
||||
# output is identical to non-chunked.
|
||||
# Determine spatial sizes: run a dummy pass to find feature map sizes,
|
||||
# or compute from block structure. Simpler: compute from x's current size
|
||||
# and the known upsample factor (2x per block with upsample).
|
||||
z_at_res = {} # keyed by (h, w) → pre-interpolated z [B, C, t_expanded, h, w]
|
||||
h, w = x.shape[3], x.shape[4]
|
||||
for i in remaining_blocks:
|
||||
block = decoder.up_blocks[i]
|
||||
# Resnets operate at current h, w
|
||||
target = (t_expanded, h, w)
|
||||
if target not in z_at_res:
|
||||
z_at_res[target] = _interpolate_zq(z, target)
|
||||
# If block has upsample, next block's input is 2x spatial
|
||||
if block.upsamplers is not None:
|
||||
h, w = h * 2, w * 2
|
||||
# norm_out operates at final resolution
|
||||
target = (t_expanded, h, w)
|
||||
if target not in z_at_res:
|
||||
z_at_res[target] = _interpolate_zq(z, target)
|
||||
|
||||
# Process in temporal chunks
|
||||
dec_out = []
|
||||
conv_caches = {}
|
||||
|
||||
for chunk_start in range(0, t_expanded, chunk_size):
|
||||
chunk_end = min(chunk_start + chunk_size, t_expanded)
|
||||
x_chunk = x[:, :, chunk_start:chunk_end]
|
||||
|
||||
for i in remaining_blocks:
|
||||
block = decoder.up_blocks[i]
|
||||
cache_key = f"up_block_{i}"
|
||||
# Get pre-interpolated z at the block's input spatial resolution
|
||||
res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4])
|
||||
z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end]
|
||||
x_chunk, new_cache = block(x_chunk, None, z_chunk, conv_cache=conv_caches.get(cache_key))
|
||||
conv_caches[cache_key] = new_cache
|
||||
|
||||
# norm_out at final resolution
|
||||
res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4])
|
||||
z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end]
|
||||
x_chunk, new_cache = decoder.norm_out(x_chunk, z_chunk, conv_cache=conv_caches.get("norm_out"))
|
||||
conv_caches["norm_out"] = new_cache
|
||||
x_chunk = decoder.conv_act(x_chunk)
|
||||
x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out"))
|
||||
conv_caches["conv_out"] = new_cache
|
||||
|
||||
dec_out.append(x_chunk.cpu())
|
||||
|
||||
del x
|
||||
return torch.cat(dec_out, dim=2).to(device)
|
||||
485
comfy/ldm/cogvideo/vae_backup.py
Normal file
485
comfy/ldm/cogvideo/vae_backup.py
Normal file
@ -0,0 +1,485 @@
|
||||
# CogVideoX VAE - ported to ComfyUI native ops
|
||||
# Architecture reference: diffusers AutoencoderKLCogVideoX
|
||||
# Style reference: comfy/ldm/wan/vae.py
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class SafeConv3d(nn.Conv3d):
|
||||
"""3D convolution that splits large inputs along temporal dim to avoid OOM."""
|
||||
def forward(self, x):
|
||||
mem = x.shape[0] * x.shape[1] * x.shape[2] * x.shape[3] * x.shape[4] * 2 / 1024**3
|
||||
if mem > 2 and x.shape[2] >= self.kernel_size[0]:
|
||||
kernel_t = self.kernel_size[0]
|
||||
parts = int(mem / 2) + 1
|
||||
# Ensure each chunk has at least kernel_t frames
|
||||
max_parts = max(1, x.shape[2] // kernel_t)
|
||||
parts = min(parts, max_parts)
|
||||
if parts <= 1:
|
||||
return super().forward(x)
|
||||
chunks = torch.chunk(x, parts, dim=2)
|
||||
if kernel_t > 1:
|
||||
chunks = [chunks[0]] + [
|
||||
torch.cat((chunks[i - 1][:, :, -kernel_t + 1:], chunks[i]), dim=2)
|
||||
for i in range(1, len(chunks))
|
||||
]
|
||||
out = []
|
||||
for chunk in chunks:
|
||||
out.append(super().forward(chunk))
|
||||
return torch.cat(out, dim=2)
|
||||
return super().forward(x)
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
"""Causal 3D convolution with temporal padding."""
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, pad_mode="constant"):
|
||||
super().__init__()
|
||||
if isinstance(kernel_size, int):
|
||||
kernel_size = (kernel_size,) * 3
|
||||
|
||||
time_kernel, height_kernel, width_kernel = kernel_size
|
||||
time_pad = time_kernel - 1
|
||||
height_pad = (height_kernel - 1) // 2
|
||||
width_pad = (width_kernel - 1) // 2
|
||||
|
||||
self.pad_mode = pad_mode
|
||||
self.time_pad = time_pad
|
||||
self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
|
||||
self.const_padding = (0, width_pad, height_pad)
|
||||
self.time_kernel_size = time_kernel
|
||||
|
||||
stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
|
||||
dilation = (dilation, 1, 1)
|
||||
self.conv = SafeConv3d(
|
||||
in_channels, out_channels, kernel_size,
|
||||
stride=stride, dilation=dilation,
|
||||
padding=0 if pad_mode == "replicate" else self.const_padding,
|
||||
)
|
||||
|
||||
def forward(self, x, conv_cache=None):
|
||||
if self.pad_mode == "replicate":
|
||||
x = F.pad(x, self.time_causal_padding, mode="replicate")
|
||||
conv_cache = None
|
||||
else:
|
||||
kernel_t = self.time_kernel_size
|
||||
if kernel_t > 1:
|
||||
cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_t - 1)
|
||||
x = torch.cat(cached + [x], dim=2)
|
||||
conv_cache = x[:, :, -self.time_kernel_size + 1:].clone() if self.time_kernel_size > 1 else None
|
||||
|
||||
out = self.conv(x)
|
||||
return out, conv_cache
|
||||
|
||||
|
||||
class SpatialNorm3D(nn.Module):
|
||||
"""Spatially conditioned normalization."""
|
||||
def __init__(self, f_channels, zq_channels, groups=32):
|
||||
super().__init__()
|
||||
self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
||||
self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
|
||||
def forward(self, f, zq, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
|
||||
z_first = F.interpolate(z_first, size=f_first.shape[-3:])
|
||||
z_rest = F.interpolate(z_rest, size=f_rest.shape[-3:])
|
||||
zq = torch.cat([z_first, z_rest], dim=2)
|
||||
else:
|
||||
zq = F.interpolate(zq, size=f.shape[-3:])
|
||||
|
||||
conv_y, new_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
|
||||
conv_b, new_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
|
||||
|
||||
return self.norm_layer(f) * conv_y + conv_b, new_cache
|
||||
|
||||
|
||||
class ResnetBlock3D(nn.Module):
|
||||
"""3D ResNet block with optional spatial norm."""
|
||||
def __init__(self, in_channels, out_channels=None, temb_channels=512, groups=32,
|
||||
eps=1e-6, act_fn="silu", spatial_norm_dim=None, pad_mode="first"):
|
||||
super().__init__()
|
||||
out_channels = out_channels or in_channels
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.spatial_norm_dim = spatial_norm_dim
|
||||
|
||||
if act_fn == "silu":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
elif act_fn == "swish":
|
||||
self.nonlinearity = nn.SiLU()
|
||||
else:
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
if spatial_norm_dim is None:
|
||||
self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
||||
self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
|
||||
else:
|
||||
self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups)
|
||||
self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups)
|
||||
|
||||
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
||||
|
||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = SafeConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv_shortcut = None
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
residual = x
|
||||
|
||||
if zq is not None:
|
||||
x, new_cache["norm1"] = self.norm1(x, zq, conv_cache=conv_cache.get("norm1"))
|
||||
else:
|
||||
x = self.norm1(x)
|
||||
|
||||
x = self.nonlinearity(x)
|
||||
x, new_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1"))
|
||||
|
||||
if temb is not None and hasattr(self, "temb_proj"):
|
||||
x = x + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
|
||||
|
||||
if zq is not None:
|
||||
x, new_cache["norm2"] = self.norm2(x, zq, conv_cache=conv_cache.get("norm2"))
|
||||
else:
|
||||
x = self.norm2(x)
|
||||
|
||||
x = self.nonlinearity(x)
|
||||
x, new_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2"))
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
residual = self.conv_shortcut(residual)
|
||||
|
||||
return x + residual, new_cache
|
||||
|
||||
|
||||
class Downsample3D(nn.Module):
|
||||
"""3D downsampling with optional temporal compression."""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
if self.compress_time:
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)
|
||||
if t % 2 == 1:
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
if x_rest.shape[-1] > 0:
|
||||
x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
else:
|
||||
x = F.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
|
||||
pad = (0, 1, 0, 1)
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = self.conv(x)
|
||||
x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample3D(nn.Module):
|
||||
"""3D upsampling with optional temporal decompression."""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x):
|
||||
if self.compress_time:
|
||||
if x.shape[2] > 1 and x.shape[2] % 2 == 1:
|
||||
x_first, x_rest = x[:, :, 0], x[:, :, 1:]
|
||||
x_first = F.interpolate(x_first, scale_factor=2.0)
|
||||
x_rest = F.interpolate(x_rest, scale_factor=2.0)
|
||||
x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2)
|
||||
elif x.shape[2] > 1:
|
||||
x = F.interpolate(x, scale_factor=2.0)
|
||||
else:
|
||||
x = x.squeeze(2)
|
||||
x = F.interpolate(x, scale_factor=2.0)
|
||||
x = x[:, :, None, :, :]
|
||||
else:
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = F.interpolate(x, scale_factor=2.0)
|
||||
x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
b, c, t, h, w = x.shape
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
x = self.conv(x)
|
||||
x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
return x
|
||||
|
||||
|
||||
class DownBlock3D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1,
|
||||
eps=1e-6, act_fn="silu", groups=32, add_downsample=True,
|
||||
compress_time=False, pad_mode="first"):
|
||||
super().__init__()
|
||||
self.resnets = nn.ModuleList([
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
groups=groups, eps=eps, act_fn=act_fn, pad_mode=pad_mode,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
self.downsamplers = nn.ModuleList([Downsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_downsample else None
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
|
||||
if self.downsamplers is not None:
|
||||
for ds in self.downsamplers:
|
||||
x = ds(x)
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class MidBlock3D(nn.Module):
|
||||
def __init__(self, in_channels, temb_channels=0, num_layers=1,
|
||||
eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=None, pad_mode="first"):
|
||||
super().__init__()
|
||||
self.resnets = nn.ModuleList([
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels, out_channels=in_channels,
|
||||
temb_channels=temb_channels, groups=groups, eps=eps,
|
||||
act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class UpBlock3D(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1,
|
||||
eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=16,
|
||||
add_upsample=True, compress_time=False, pad_mode="first"):
|
||||
super().__init__()
|
||||
self.resnets = nn.ModuleList([
|
||||
ResnetBlock3D(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels, groups=groups, eps=eps,
|
||||
act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_upsample else None
|
||||
|
||||
def forward(self, x, temb=None, zq=None, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
for i, resnet in enumerate(self.resnets):
|
||||
x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}"))
|
||||
if self.upsamplers is not None:
|
||||
for us in self.upsamplers:
|
||||
x = us(x)
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class Encoder3D(nn.Module):
|
||||
def __init__(self, in_channels=3, out_channels=16,
|
||||
block_out_channels=(128, 256, 256, 512),
|
||||
layers_per_block=3, act_fn="silu",
|
||||
eps=1e-6, groups=32, pad_mode="first",
|
||||
temporal_compression_ratio=4):
|
||||
super().__init__()
|
||||
temporal_compress_level = int(np.log2(temporal_compression_ratio))
|
||||
|
||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
self.down_blocks = nn.ModuleList()
|
||||
output_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final = i == len(block_out_channels) - 1
|
||||
compress_time = i < temporal_compress_level
|
||||
|
||||
self.down_blocks.append(DownBlock3D(
|
||||
in_channels=input_channel, out_channels=output_channel,
|
||||
temb_channels=0, num_layers=layers_per_block,
|
||||
eps=eps, act_fn=act_fn, groups=groups,
|
||||
add_downsample=not is_final, compress_time=compress_time,
|
||||
))
|
||||
|
||||
self.mid_block = MidBlock3D(
|
||||
in_channels=block_out_channels[-1], temb_channels=0,
|
||||
num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
self.norm_out = nn.GroupNorm(groups, block_out_channels[-1], eps=1e-6)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, x, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
x, new_cache["conv_in"] = self.conv_in(x, conv_cache=conv_cache.get("conv_in"))
|
||||
|
||||
for i, block in enumerate(self.down_blocks):
|
||||
key = f"down_block_{i}"
|
||||
x, new_cache[key] = block(x, None, None, conv_cache.get(key))
|
||||
|
||||
x, new_cache["mid_block"] = self.mid_block(x, None, None, conv_cache=conv_cache.get("mid_block"))
|
||||
|
||||
x = self.norm_out(x)
|
||||
x = self.conv_act(x)
|
||||
x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out"))
|
||||
|
||||
return x, new_cache
|
||||
|
||||
|
||||
class Decoder3D(nn.Module):
|
||||
def __init__(self, in_channels=16, out_channels=3,
|
||||
block_out_channels=(128, 256, 256, 512),
|
||||
layers_per_block=3, act_fn="silu",
|
||||
eps=1e-6, groups=32, pad_mode="first",
|
||||
temporal_compression_ratio=4):
|
||||
super().__init__()
|
||||
reversed_channels = list(reversed(block_out_channels))
|
||||
temporal_compress_level = int(np.log2(temporal_compression_ratio))
|
||||
|
||||
self.conv_in = CausalConv3d(in_channels, reversed_channels[0], kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
self.mid_block = MidBlock3D(
|
||||
in_channels=reversed_channels[0], temb_channels=0,
|
||||
num_layers=2, eps=eps, act_fn=act_fn, groups=groups,
|
||||
spatial_norm_dim=in_channels, pad_mode=pad_mode,
|
||||
)
|
||||
|
||||
self.up_blocks = nn.ModuleList()
|
||||
output_channel = reversed_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
prev_channel = output_channel
|
||||
output_channel = reversed_channels[i]
|
||||
is_final = i == len(block_out_channels) - 1
|
||||
compress_time = i < temporal_compress_level
|
||||
|
||||
self.up_blocks.append(UpBlock3D(
|
||||
in_channels=prev_channel, out_channels=output_channel,
|
||||
temb_channels=0, num_layers=layers_per_block + 1,
|
||||
eps=eps, act_fn=act_fn, groups=groups,
|
||||
spatial_norm_dim=in_channels,
|
||||
add_upsample=not is_final, compress_time=compress_time,
|
||||
))
|
||||
|
||||
self.norm_out = SpatialNorm3D(reversed_channels[-1], in_channels, groups=groups)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(reversed_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode)
|
||||
|
||||
def forward(self, sample, conv_cache=None):
|
||||
new_cache = {}
|
||||
conv_cache = conv_cache or {}
|
||||
|
||||
x, new_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
|
||||
|
||||
x, new_cache["mid_block"] = self.mid_block(x, None, sample, conv_cache=conv_cache.get("mid_block"))
|
||||
|
||||
for i, block in enumerate(self.up_blocks):
|
||||
key = f"up_block_{i}"
|
||||
x, new_cache[key] = block(x, None, sample, conv_cache=conv_cache.get(key))
|
||||
|
||||
x, new_cache["norm_out"] = self.norm_out(x, sample, conv_cache=conv_cache.get("norm_out"))
|
||||
x = self.conv_act(x)
|
||||
x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out"))
|
||||
|
||||
return x, new_cache
|
||||
|
||||
|
||||
|
||||
class AutoencoderKLCogVideoX(nn.Module):
|
||||
"""CogVideoX VAE. Spatial tiling/slicing handled by ComfyUI's VAE wrapper.
|
||||
|
||||
Temporal frame batching with conv_cache is kept here since the causal 3D
|
||||
convolutions need state passed between temporal chunks.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3, out_channels=3,
|
||||
block_out_channels=(128, 256, 256, 512),
|
||||
latent_channels=16, layers_per_block=3,
|
||||
act_fn="silu", eps=1e-6, groups=32,
|
||||
temporal_compression_ratio=4,
|
||||
):
|
||||
super().__init__()
|
||||
self.latent_channels = latent_channels
|
||||
|
||||
self.encoder = Encoder3D(
|
||||
in_channels=in_channels, out_channels=latent_channels,
|
||||
block_out_channels=block_out_channels, layers_per_block=layers_per_block,
|
||||
act_fn=act_fn, eps=eps, groups=groups,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
)
|
||||
self.decoder = Decoder3D(
|
||||
in_channels=latent_channels, out_channels=out_channels,
|
||||
block_out_channels=block_out_channels, layers_per_block=layers_per_block,
|
||||
act_fn=act_fn, eps=eps, groups=groups,
|
||||
temporal_compression_ratio=temporal_compression_ratio,
|
||||
)
|
||||
|
||||
self.num_latent_frames_batch_size = 2
|
||||
self.num_sample_frames_batch_size = 8
|
||||
|
||||
def encode(self, x):
|
||||
t = x.shape[2]
|
||||
frame_batch = self.num_sample_frames_batch_size
|
||||
num_batches = max(t // frame_batch, 1)
|
||||
conv_cache = None
|
||||
enc = []
|
||||
for i in range(num_batches):
|
||||
remaining = t % frame_batch
|
||||
start = frame_batch * i + (0 if i == 0 else remaining)
|
||||
end = frame_batch * (i + 1) + remaining
|
||||
chunk, conv_cache = self.encoder(x[:, :, start:end], conv_cache=conv_cache)
|
||||
enc.append(chunk.to(x.device))
|
||||
enc = torch.cat(enc, dim=2)
|
||||
mean, _ = enc.chunk(2, dim=1)
|
||||
return mean
|
||||
|
||||
def decode(self, z):
|
||||
t = z.shape[2]
|
||||
frame_batch = self.num_latent_frames_batch_size
|
||||
num_batches = max(t // frame_batch, 1)
|
||||
conv_cache = None
|
||||
dec = []
|
||||
for i in range(num_batches):
|
||||
remaining = t % frame_batch
|
||||
start = frame_batch * i + (0 if i == 0 else remaining)
|
||||
end = frame_batch * (i + 1) + remaining
|
||||
chunk, conv_cache = self.decoder(z[:, :, start:end], conv_cache=conv_cache)
|
||||
dec.append(chunk.cpu())
|
||||
return torch.cat(dec, dim=2).to(z.device)
|
||||
@ -52,6 +52,7 @@ import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
import comfy.ldm.cogvideo.model
|
||||
import comfy.ldm.rt_detr.rtdetr_v4
|
||||
|
||||
import comfy.model_management
|
||||
@ -79,6 +80,7 @@ class ModelType(Enum):
|
||||
IMG_TO_IMG = 9
|
||||
FLOW_COSMOS = 10
|
||||
IMG_TO_IMG_FLOW = 11
|
||||
V_PREDICTION_DDPM = 12
|
||||
|
||||
|
||||
def model_sampling(model_config, model_type):
|
||||
@ -113,6 +115,8 @@ def model_sampling(model_config, model_type):
|
||||
s = comfy.model_sampling.ModelSamplingCosmosRFlow
|
||||
elif model_type == ModelType.IMG_TO_IMG_FLOW:
|
||||
c = comfy.model_sampling.IMG_TO_IMG_FLOW
|
||||
elif model_type == ModelType.V_PREDICTION_DDPM:
|
||||
c = comfy.model_sampling.V_PREDICTION_DDPM
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@ -1962,3 +1966,41 @@ class Kandinsky5Image(Kandinsky5):
|
||||
class RT_DETR_v4(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
||||
|
||||
class CogVideoX(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_DDPM, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cogvideo.model.CogVideoXTransformer3DModel)
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
# Detect extra channels needed (e.g. 32 - 16 = 16 for ref latent)
|
||||
extra_channels = self.diffusion_model.in_channels - noise.shape[1]
|
||||
if extra_channels == 0:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
shape = list(noise.shape)
|
||||
shape[1] = extra_channels
|
||||
return torch.zeros(shape, dtype=noise.dtype, layout=noise.layout, device=noise.device)
|
||||
|
||||
latent_dim = self.latent_format.latent_channels
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
for i in range(0, image.shape[1], latent_dim):
|
||||
image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim])
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
return image
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
# OFS embedding (CogVideoX 1.5 I2V), default 2.0 as used by SparkVSR
|
||||
if self.diffusion_model.ofs_proj_dim is not None:
|
||||
ofs = kwargs.get("ofs", None)
|
||||
if ofs is None:
|
||||
noise = kwargs.get("noise", None)
|
||||
ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype)
|
||||
out['ofs'] = comfy.conds.CONDRegular(ofs)
|
||||
return out
|
||||
|
||||
@ -490,6 +490,55 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}blocks.0.norm1.linear.weight'.format(key_prefix) in state_dict_keys: # CogVideoX
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "cogvideox"
|
||||
|
||||
# Extract config from weight shapes
|
||||
norm1_weight = state_dict['{}blocks.0.norm1.linear.weight'.format(key_prefix)]
|
||||
time_embed_dim = norm1_weight.shape[1]
|
||||
dim = norm1_weight.shape[0] // 6
|
||||
|
||||
dit_config["num_attention_heads"] = dim // 64
|
||||
dit_config["attention_head_dim"] = 64
|
||||
dit_config["time_embed_dim"] = time_embed_dim
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
|
||||
# Detect in_channels from patch_embed
|
||||
patch_proj_key = '{}patch_embed.proj.weight'.format(key_prefix)
|
||||
patch_proj_linear_key = '{}patch_embed.proj.weight'.format(key_prefix)
|
||||
if patch_proj_key in state_dict_keys:
|
||||
w = state_dict[patch_proj_key]
|
||||
if w.ndim == 4:
|
||||
# Conv2d: [out, in, kh, kw] — CogVideoX 1.0
|
||||
dit_config["in_channels"] = w.shape[1]
|
||||
dit_config["patch_size"] = w.shape[2]
|
||||
elif w.ndim == 2:
|
||||
# Linear: [out, in_channels * patch_size * patch_size * patch_size_t] — CogVideoX 1.5
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["patch_size_t"] = 2
|
||||
dit_config["in_channels"] = w.shape[1] // (2 * 2 * 2) # 256 // 8 = 32
|
||||
|
||||
text_proj_key = '{}patch_embed.text_proj.weight'.format(key_prefix)
|
||||
if text_proj_key in state_dict_keys:
|
||||
dit_config["text_embed_dim"] = state_dict[text_proj_key].shape[1]
|
||||
|
||||
# Detect OFS embedding
|
||||
ofs_key = '{}ofs_embedding_linear_1.weight'.format(key_prefix)
|
||||
if ofs_key in state_dict_keys:
|
||||
dit_config["ofs_embed_dim"] = state_dict[ofs_key].shape[1]
|
||||
|
||||
# Detect positional embedding type
|
||||
pos_key = '{}patch_embed.pos_embedding'.format(key_prefix)
|
||||
if pos_key in state_dict_keys:
|
||||
dit_config["use_learned_positional_embeddings"] = True
|
||||
dit_config["use_rotary_positional_embeddings"] = False
|
||||
else:
|
||||
dit_config["use_learned_positional_embeddings"] = False
|
||||
dit_config["use_rotary_positional_embeddings"] = True
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "wan2.1"
|
||||
|
||||
@ -54,6 +54,30 @@ class V_PREDICTION(EPS):
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
class V_PREDICTION_DDPM:
|
||||
"""CogVideoX v-prediction: model receives raw x_t (unscaled), predicts velocity v.
|
||||
x_0 = sqrt(alpha) * x_t - sqrt(1-alpha) * v
|
||||
= x_t / sqrt(sigma^2 + 1) - v * sigma / sqrt(sigma^2 + 1)
|
||||
"""
|
||||
def calculate_input(self, sigma, noise):
|
||||
return noise
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
return model_input / (sigma ** 2 + 1.0) ** 0.5 - model_output * sigma / (sigma ** 2 + 1.0) ** 0.5
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
sigma = reshape_sigma(sigma, noise.ndim)
|
||||
if max_denoise:
|
||||
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
|
||||
else:
|
||||
noise = noise * sigma
|
||||
noise += latent_image
|
||||
return noise
|
||||
|
||||
def inverse_noise_scaling(self, sigma, latent):
|
||||
return latent
|
||||
|
||||
class EDM(V_PREDICTION):
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = reshape_sigma(sigma, model_output.ndim)
|
||||
|
||||
12
comfy/sd.py
12
comfy/sd.py
@ -650,6 +650,18 @@ class VAE:
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
elif "decoder.conv_in.conv.weight" in sd and "decoder.mid_block.resnets.0.norm1.norm_layer.weight" in sd: # CogVideoX VAE
|
||||
import comfy.ldm.cogvideo.vae
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = sd["encoder.conv_out.conv.weight"].shape[0] // 2
|
||||
self.first_stage_model = comfy.ldm.cogvideo.vae.AutoencoderKLCogVideoX(latent_channels=self.latent_channels)
|
||||
self.memory_used_decode = lambda shape, dtype: (2800 * max(2, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * max(1, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
elif "decoder.conv_in.conv.weight" in sd:
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
ddconfig["conv3d"] = True
|
||||
|
||||
@ -1749,6 +1749,55 @@ class RT_DETR_v4(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
|
||||
class CogVideoX_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "cogvideox",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"linear_start": 0.00085,
|
||||
"linear_end": 0.012,
|
||||
"beta_schedule": "linear",
|
||||
"zsnr": True,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.CogVideoX
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
# CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE
|
||||
if self.unet_config.get("patch_size_t") is not None:
|
||||
self.unet_config.setdefault("sample_height", 96)
|
||||
self.unet_config.setdefault("sample_width", 170)
|
||||
self.unet_config.setdefault("sample_frames", 81)
|
||||
out = model_base.CogVideoX(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
|
||||
class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226)
|
||||
|
||||
return supported_models_base.ClipTarget(CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel)
|
||||
|
||||
class CogVideoX_I2V(CogVideoX_T2V):
|
||||
unet_config = {
|
||||
"image_model": "cogvideox",
|
||||
"in_channels": 32,
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.CogVideoX(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, CogVideoX_I2V, CogVideoX_T2V]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
137
comfy_extras/nodes_cogvideox.py
Normal file
137
comfy_extras/nodes_cogvideox.py
Normal file
@ -0,0 +1,137 @@
|
||||
import nodes
|
||||
import node_helpers
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
from comfy_api.latest import io, ComfyExtension
|
||||
from typing_extensions import override
|
||||
|
||||
class SparkVSRConditioning(io.ComfyNode):
|
||||
"""Conditioning node for SparkVSR video super-resolution.
|
||||
|
||||
Encodes LQ video and optional HR reference frames through the VAE,
|
||||
builds the concat conditioning for the CogVideoX I2V model.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SparkVSRConditioning",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("lq_video"),
|
||||
io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=8),
|
||||
io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=1),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||
io.Image.Input("ref_frames", optional=True),
|
||||
io.Combo.Input("ref_mode", options=["auto", "manual"], default="auto", optional=True),
|
||||
io.String.Input("ref_indices", default="", optional=True),
|
||||
io.Float.Input("ref_guidance_scale", default=1.0, min=0.0, max=10.0, step=0.1, optional=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, lq_video, width, height, length,
|
||||
batch_size, ref_frames=None, ref_mode="auto", ref_indices="",
|
||||
ref_guidance_scale=1.0) -> io.NodeOutput:
|
||||
|
||||
temporal_compression = 4
|
||||
latent_t = ((length - 1) // temporal_compression) + 1
|
||||
latent_h = height // 8
|
||||
latent_w = width // 8
|
||||
|
||||
# Base latent (noise will be added by KSampler)
|
||||
latent = torch.zeros(
|
||||
[batch_size, 16, latent_t, latent_h, latent_w],
|
||||
device=comfy.model_management.intermediate_device()
|
||||
)
|
||||
|
||||
# Encode LQ video → this becomes the base latent (KSampler adds noise to this)
|
||||
lq = lq_video[:length]
|
||||
lq = comfy.utils.common_upscale(lq.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
lq_latent = vae.encode(lq[:, :, :, :3])
|
||||
|
||||
# Ensure temporal dim matches
|
||||
if lq_latent.shape[2] > latent_t:
|
||||
lq_latent = lq_latent[:, :, :latent_t]
|
||||
elif lq_latent.shape[2] < latent_t:
|
||||
pad = latent_t - lq_latent.shape[2]
|
||||
lq_latent = torch.cat([lq_latent, lq_latent[:, :, -1:].repeat(1, 1, pad, 1, 1)], dim=2)
|
||||
|
||||
# Build reference latent (16ch) — goes as concat_latent_image
|
||||
# concat_cond in model_base will concatenate this with the noise (16ch) → 32ch total
|
||||
ref_latent = torch.zeros_like(lq_latent)
|
||||
|
||||
if ref_frames is not None:
|
||||
num_video_frames = lq_video.shape[0]
|
||||
|
||||
# Determine reference indices
|
||||
if ref_mode == "manual" and ref_indices.strip():
|
||||
indices = [int(x.strip()) for x in ref_indices.split(",") if x.strip()]
|
||||
else:
|
||||
indices = _select_indices(num_video_frames)
|
||||
|
||||
# Encode each reference frame and place at its temporal position.
|
||||
# SparkVSR places refs at specific latent indices, rest stays zeros.
|
||||
for ref_idx in indices:
|
||||
if ref_idx >= ref_frames.shape[0]:
|
||||
continue
|
||||
|
||||
frame = ref_frames[ref_idx:ref_idx + 1]
|
||||
frame = comfy.utils.common_upscale(frame.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
frame_latent = vae.encode(frame[:, :, :, :3])
|
||||
|
||||
target_lat_idx = ref_idx // temporal_compression
|
||||
if target_lat_idx < latent_t:
|
||||
ref_latent[:, :, target_lat_idx] = frame_latent[:, :, 0]
|
||||
|
||||
# Set ref latent as concat conditioning (16ch, model_base.concat_cond adds it to noise)
|
||||
if ref_guidance_scale != 1.0 and ref_frames is not None:
|
||||
# CFG: positive gets real refs, negative gets zero refs
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": ref_latent})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": torch.zeros_like(ref_latent)})
|
||||
else:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": ref_latent})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": ref_latent})
|
||||
|
||||
# LQ latent is the base — KSampler will noise it and denoise
|
||||
out_latent = {"samples": lq_latent}
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
def _select_indices(num_frames, max_refs=None):
|
||||
"""Auto-select reference frame indices (first, evenly spaced, last)."""
|
||||
if max_refs is None:
|
||||
max_refs = (num_frames - 1) // 4 + 1
|
||||
max_refs = min(max_refs, 3)
|
||||
|
||||
if num_frames <= 1:
|
||||
return [0]
|
||||
if max_refs == 1:
|
||||
return [0]
|
||||
if max_refs == 2:
|
||||
return [0, num_frames - 1]
|
||||
|
||||
mid = num_frames // 2
|
||||
return [0, mid, num_frames - 1]
|
||||
|
||||
|
||||
class CogVideoXExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
SparkVSRConditioning,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> CogVideoXExtension:
|
||||
return CogVideoXExtension()
|
||||
144
convert_sparkvsr_to_comfy.py
Normal file
144
convert_sparkvsr_to_comfy.py
Normal file
@ -0,0 +1,144 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Convert SparkVSR/CogVideoX diffusers checkpoint to ComfyUI format.
|
||||
|
||||
Usage:
|
||||
python convert_sparkvsr_to_comfy.py --model_dir path/to/sparkvsr-checkpoint \
|
||||
--output_dir ComfyUI/models/
|
||||
|
||||
This creates two files:
|
||||
- diffusion_models/cogvideox_sparkvsr.safetensors (transformer)
|
||||
- vae/cogvideox_vae.safetensors (VAE)
|
||||
|
||||
T5-XXL text encoder does not need conversion — use existing ComfyUI T5 weights.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
def remap_transformer_keys(state_dict):
|
||||
"""Remap diffusers transformer keys to ComfyUI CogVideoX naming."""
|
||||
new_sd = {}
|
||||
for k, v in state_dict.items():
|
||||
new_key = k
|
||||
|
||||
# Patch embedding
|
||||
new_key = new_key.replace("patch_embed.proj.", "patch_embed.proj.")
|
||||
new_key = new_key.replace("patch_embed.text_proj.", "patch_embed.text_proj.")
|
||||
new_key = new_key.replace("patch_embed.pos_embedding", "patch_embed.pos_embedding")
|
||||
|
||||
# Time embedding: diffusers uses time_embedding.linear_1/2, we use time_embedding_linear_1/2
|
||||
new_key = new_key.replace("time_embedding.linear_1.", "time_embedding_linear_1.")
|
||||
new_key = new_key.replace("time_embedding.linear_2.", "time_embedding_linear_2.")
|
||||
|
||||
# OFS embedding
|
||||
new_key = new_key.replace("ofs_embedding.linear_1.", "ofs_embedding_linear_1.")
|
||||
new_key = new_key.replace("ofs_embedding.linear_2.", "ofs_embedding_linear_2.")
|
||||
|
||||
# Transformer blocks: diffusers uses transformer_blocks, we use blocks
|
||||
new_key = new_key.replace("transformer_blocks.", "blocks.")
|
||||
|
||||
# Attention: diffusers uses attn1.to_q/k/v/out, we use q/k/v/attn_out
|
||||
new_key = new_key.replace(".attn1.to_q.", ".q.")
|
||||
new_key = new_key.replace(".attn1.to_k.", ".k.")
|
||||
new_key = new_key.replace(".attn1.to_v.", ".v.")
|
||||
new_key = new_key.replace(".attn1.to_out.0.", ".attn_out.")
|
||||
new_key = new_key.replace(".attn1.norm_q.", ".norm_q.")
|
||||
new_key = new_key.replace(".attn1.norm_k.", ".norm_k.")
|
||||
|
||||
# Feed-forward: diffusers uses ff.net.0.proj/ff.net.2, we use ff_proj/ff_out
|
||||
new_key = new_key.replace(".ff.net.0.proj.", ".ff_proj.")
|
||||
new_key = new_key.replace(".ff.net.2.", ".ff_out.")
|
||||
|
||||
# Output norms
|
||||
new_key = new_key.replace("norm_final.", "norm_final.")
|
||||
new_key = new_key.replace("norm_out.linear.", "norm_out.linear.")
|
||||
new_key = new_key.replace("norm_out.norm.", "norm_out.norm.")
|
||||
|
||||
new_sd[new_key] = v
|
||||
|
||||
return new_sd
|
||||
|
||||
|
||||
def remap_vae_keys(state_dict):
|
||||
"""Remap diffusers VAE keys to ComfyUI CogVideoX naming.
|
||||
|
||||
The VAE architecture is directly ported so most keys should match.
|
||||
Main differences are in block naming.
|
||||
"""
|
||||
new_sd = {}
|
||||
for k, v in state_dict.items():
|
||||
new_key = k
|
||||
|
||||
# Encoder blocks
|
||||
new_key = new_key.replace("encoder.down_blocks.", "encoder.down_blocks.")
|
||||
new_key = new_key.replace("encoder.mid_block.", "encoder.mid_block.")
|
||||
|
||||
# Decoder blocks
|
||||
new_key = new_key.replace("decoder.up_blocks.", "decoder.up_blocks.")
|
||||
new_key = new_key.replace("decoder.mid_block.", "decoder.mid_block.")
|
||||
|
||||
# Resnet blocks within down/up/mid
|
||||
new_key = new_key.replace(".resnets.", ".resnets.")
|
||||
|
||||
# CausalConv3d: diffusers stores as .conv.weight inside CausalConv3d
|
||||
# Our CausalConv3d also has .conv.weight, so this should match
|
||||
|
||||
# Downsamplers/Upsamplers
|
||||
new_key = new_key.replace(".downsamplers.0.", ".downsamplers.0.")
|
||||
new_key = new_key.replace(".upsamplers.0.", ".upsamplers.0.")
|
||||
|
||||
new_sd[new_key] = v
|
||||
|
||||
return new_sd
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Convert SparkVSR/CogVideoX to ComfyUI format")
|
||||
parser.add_argument("--model_dir", type=str, required=True,
|
||||
help="Path to diffusers pipeline directory (contains transformer/, vae/, etc.)")
|
||||
parser.add_argument("--output_dir", type=str, default=".",
|
||||
help="Output base directory (will create diffusion_models/ and vae/ subdirs)")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load transformer
|
||||
transformer_dir = os.path.join(args.model_dir, "transformer")
|
||||
print(f"Loading transformer from {transformer_dir}...")
|
||||
transformer_sd = {}
|
||||
for f in sorted(os.listdir(transformer_dir)):
|
||||
if f.endswith(".safetensors"):
|
||||
sd = load_file(os.path.join(transformer_dir, f))
|
||||
transformer_sd.update(sd)
|
||||
|
||||
transformer_sd = remap_transformer_keys(transformer_sd)
|
||||
|
||||
out_dir = os.path.join(args.output_dir, "diffusion_models")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
out_path = os.path.join(out_dir, "cogvideox_sparkvsr.safetensors")
|
||||
print(f"Saving transformer to {out_path} ({len(transformer_sd)} keys)")
|
||||
save_file(transformer_sd, out_path)
|
||||
|
||||
# Load VAE
|
||||
vae_dir = os.path.join(args.model_dir, "vae")
|
||||
print(f"Loading VAE from {vae_dir}...")
|
||||
vae_sd = {}
|
||||
for f in sorted(os.listdir(vae_dir)):
|
||||
if f.endswith(".safetensors"):
|
||||
sd = load_file(os.path.join(vae_dir, f))
|
||||
vae_sd.update(sd)
|
||||
|
||||
vae_sd = remap_vae_keys(vae_sd)
|
||||
|
||||
out_dir = os.path.join(args.output_dir, "vae")
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
out_path = os.path.join(out_dir, "cogvideox_vae.safetensors")
|
||||
print(f"Saving VAE to {out_path} ({len(vae_sd)} keys)")
|
||||
save_file(vae_sd, out_path)
|
||||
|
||||
print("Done! T5-XXL text encoder does not need conversion.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue
Block a user