From 3e4d15a4f70f5de24f426d9681d8505d42bf19d7 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 8 Apr 2026 12:04:51 +0300 Subject: [PATCH] Initial CogVideoX and SparkVSR support --- comfy/latent_formats.py | 7 + comfy/ldm/cogvideo/__init__.py | 0 comfy/ldm/cogvideo/model.py | 571 +++++++++++++++++++++++++++++++ comfy/ldm/cogvideo/vae.py | 570 ++++++++++++++++++++++++++++++ comfy/ldm/cogvideo/vae_backup.py | 485 ++++++++++++++++++++++++++ comfy/model_base.py | 42 +++ comfy/model_detection.py | 49 +++ comfy/model_sampling.py | 24 ++ comfy/sd.py | 12 + comfy/supported_models.py | 51 ++- comfy_extras/nodes_cogvideox.py | 137 ++++++++ convert_sparkvsr_to_comfy.py | 144 ++++++++ nodes.py | 3 +- 13 files changed, 2093 insertions(+), 2 deletions(-) create mode 100644 comfy/ldm/cogvideo/__init__.py create mode 100644 comfy/ldm/cogvideo/model.py create mode 100644 comfy/ldm/cogvideo/vae.py create mode 100644 comfy/ldm/cogvideo/vae_backup.py create mode 100644 comfy_extras/nodes_cogvideox.py create mode 100644 convert_sparkvsr_to_comfy.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 6a57bca1c..0f4059ebe 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -783,3 +783,10 @@ class ZImagePixelSpace(ChromaRadiance): No VAE encoding/decoding — the model operates directly on RGB pixels. """ pass + +class CogVideoX(LatentFormat): + latent_channels = 16 + latent_dimensions = 3 + + def __init__(self): + self.scale_factor = 1.15258426 diff --git a/comfy/ldm/cogvideo/__init__.py b/comfy/ldm/cogvideo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/cogvideo/model.py b/comfy/ldm/cogvideo/model.py new file mode 100644 index 000000000..a4e737d41 --- /dev/null +++ b/comfy/ldm/cogvideo/model.py @@ -0,0 +1,571 @@ +# CogVideoX 3D Transformer - ported to ComfyUI native ops +# Architecture reference: diffusers CogVideoXTransformer3DModel +# Style reference: comfy/ldm/wan/model.py + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention +import comfy.patcher_extension +import comfy.ldm.common_dit + + +def _get_1d_rotary_pos_embed(dim, pos, theta=10000.0): + """Returns (cos, sin) each with shape [seq_len, dim]. + + Frequencies are computed at dim//2 resolution then repeat_interleaved + to full dim, matching CogVideoX's interleaved (real, imag) pair format. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim)) + angles = torch.outer(pos.float(), freqs.float()) + cos = angles.cos().repeat_interleave(2, dim=-1).float() + sin = angles.sin().repeat_interleave(2, dim=-1).float() + return (cos, sin) + + +def apply_rotary_emb(x, freqs_cos_sin): + """Apply CogVideoX rotary embedding to query or key tensor. + + x: [B, heads, seq_len, head_dim] + freqs_cos_sin: (cos, sin) each [seq_len, head_dim//2] + + Uses interleaved pair rotation (same as diffusers CogVideoX/Flux). + head_dim is reshaped to (-1, 2) pairs, rotated, then flattened back. 
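+
+    Example (derived from the code below): for one interleaved pair (x0, x1)
+    and its repeated cos/sin value c, s, the result is (x0*c - x1*s, x1*c + x0*s),
+    i.e. a standard 2D rotation applied independently to every pair.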
+ """ + cos, sin = freqs_cos_sin + cos = cos[None, None, :, :].to(x.device) + sin = sin[None, None, :, :].to(x.device) + + # Interleaved pairs: [B, H, S, D] -> [B, H, S, D//2, 2] -> (real, imag) + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + return (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + +def get_timestep_embedding(timesteps, dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1, max_period=10000): + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half) + args = timesteps[:, None].float() * freqs[None] * scale + embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1) + if flip_sin_to_cos: + embedding = torch.cat([embedding[:, half:], embedding[:, :half]], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def get_3d_sincos_pos_embed(embed_dim, spatial_size, temporal_size, spatial_interpolation_scale=1.0, temporal_interpolation_scale=1.0, device=None): + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + + grid_w = torch.arange(spatial_size[0], dtype=torch.float32, device=device) / spatial_interpolation_scale + grid_h = torch.arange(spatial_size[1], dtype=torch.float32, device=device) / spatial_interpolation_scale + grid_t = torch.arange(temporal_size, dtype=torch.float32, device=device) / temporal_interpolation_scale + + grid_t, grid_h, grid_w = torch.meshgrid(grid_t, grid_h, grid_w, indexing="ij") + + embed_dim_spatial = 2 * (embed_dim // 3) + embed_dim_temporal = embed_dim // 3 + + pos_embed_spatial = _get_2d_sincos_pos_embed(embed_dim_spatial, grid_h, grid_w, device=device) + pos_embed_temporal = _get_1d_sincos_pos_embed(embed_dim_temporal, grid_t[:, 0, 0], device=device) + + T, H, W = grid_t.shape + pos_embed_temporal = pos_embed_temporal.unsqueeze(1).unsqueeze(1).expand(-1, H, W, -1) + pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=-1) + + return pos_embed + + +def _get_2d_sincos_pos_embed(embed_dim, grid_h, grid_w, device=None): + T, H, W = grid_h.shape + half_dim = embed_dim // 2 + pos_h = _get_1d_sincos_pos_embed(half_dim, grid_h.reshape(-1), device=device).reshape(T, H, W, half_dim) + pos_w = _get_1d_sincos_pos_embed(half_dim, grid_w.reshape(-1), device=device).reshape(T, H, W, half_dim) + return torch.cat([pos_h, pos_w], dim=-1) + + +def _get_1d_sincos_pos_embed(embed_dim, pos, device=None): + half = embed_dim // 2 + freqs = torch.exp(-math.log(10000.0) * torch.arange(start=0, end=half, dtype=torch.float32, device=device) / half) + args = pos.float().reshape(-1)[:, None] * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if embed_dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + + +class CogVideoXPatchEmbed(nn.Module): + def __init__(self, patch_size=2, patch_size_t=None, in_channels=16, dim=1920, + text_dim=4096, bias=True, sample_width=90, sample_height=60, + sample_frames=49, temporal_compression_ratio=4, + max_text_seq_length=226, spatial_interpolation_scale=1.875, + temporal_interpolation_scale=1.0, use_positional_embeddings=True, + use_learned_positional_embeddings=True, + device=None, dtype=None, operations=None): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.dim = dim + self.sample_height = 
sample_height + self.sample_width = sample_width + self.sample_frames = sample_frames + self.temporal_compression_ratio = temporal_compression_ratio + self.max_text_seq_length = max_text_seq_length + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.use_positional_embeddings = use_positional_embeddings + self.use_learned_positional_embeddings = use_learned_positional_embeddings + + if patch_size_t is None: + self.proj = operations.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=bias, device=device, dtype=dtype) + else: + self.proj = operations.Linear(in_channels * patch_size * patch_size * patch_size_t, dim, device=device, dtype=dtype) + + self.text_proj = operations.Linear(text_dim, dim, device=device, dtype=dtype) + + if use_positional_embeddings or use_learned_positional_embeddings: + persistent = use_learned_positional_embeddings + pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) + self.register_buffer("pos_embedding", pos_embedding, persistent=persistent) + + def _get_positional_embeddings(self, sample_height, sample_width, sample_frames, device=None): + post_patch_height = sample_height // self.patch_size + post_patch_width = sample_width // self.patch_size + post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 + if self.patch_size_t is not None: + post_time_compression_frames = post_time_compression_frames // self.patch_size_t + num_patches = post_patch_height * post_patch_width * post_time_compression_frames + + pos_embedding = get_3d_sincos_pos_embed( + self.dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + self.spatial_interpolation_scale, + self.temporal_interpolation_scale, + device=device, + ) + pos_embedding = pos_embedding.reshape(-1, self.dim) + joint_pos_embedding = pos_embedding.new_zeros( + 1, self.max_text_seq_length + num_patches, self.dim, requires_grad=False + ) + joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding) + return joint_pos_embedding + + def forward(self, text_embeds, image_embeds): + text_embeds = self.text_proj(text_embeds) + batch_size, num_frames, channels, height, width = image_embeds.shape + + if self.patch_size_t is None: + image_embeds = image_embeds.reshape(-1, channels, height, width) + image_embeds = self.proj(image_embeds) + image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose(2, 3) + image_embeds = image_embeds.flatten(1, 2) + else: + p = self.patch_size + p_t = self.patch_size_t + image_embeds = image_embeds.permute(0, 1, 3, 4, 2) + image_embeds = image_embeds.reshape( + batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels + ) + image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3) + image_embeds = self.proj(image_embeds) + + embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() + + if self.use_positional_embeddings or self.use_learned_positional_embeddings: + text_seq_length = text_embeds.shape[1] + num_image_patches = image_embeds.shape[1] + + # Compute sincos pos embedding for image patches + pos_embedding = get_3d_sincos_pos_embed( + self.dim, + (width // self.patch_size, height // self.patch_size), + num_image_patches // ((height // self.patch_size) * (width // self.patch_size)), + self.spatial_interpolation_scale, + 
self.temporal_interpolation_scale, + device=embeds.device, + ).reshape(-1, self.dim) + + # Build joint: zeros for text + sincos for image + joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype) + joint_pos[:, text_seq_length:] = pos_embedding.to(dtype=embeds.dtype) + embeds = embeds + joint_pos + + return embeds + + +class CogVideoXLayerNormZero(nn.Module): + def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, bias=True, + device=None, dtype=None, operations=None): + super().__init__() + self.silu = nn.SiLU() + self.linear = operations.Linear(time_dim, 6 * dim, bias=bias, device=device, dtype=dtype) + self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) + + def forward(self, hidden_states, encoder_hidden_states, temb): + shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] + return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] + + +class CogVideoXAdaLayerNorm(nn.Module): + def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, + device=None, dtype=None, operations=None): + super().__init__() + self.silu = nn.SiLU() + self.linear = operations.Linear(time_dim, 2 * dim, device=device, dtype=dtype) + self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) + + def forward(self, x, temb): + temb = self.linear(self.silu(temb)) + shift, scale = temb.chunk(2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class CogVideoXBlock(nn.Module): + def __init__(self, dim, num_heads, head_dim, time_dim, + eps=1e-5, ff_inner_dim=None, ff_bias=True, + device=None, dtype=None, operations=None): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.norm1 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations) + + # Self-attention (joint text + latent) + self.q = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + self.k = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + self.v = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + self.norm_q = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype) + self.norm_k = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype) + self.attn_out = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + + self.norm2 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations) + + # Feed-forward (GELU approximate) + inner_dim = ff_inner_dim or dim * 4 + self.ff_proj = operations.Linear(dim, inner_dim, bias=ff_bias, device=device, dtype=dtype) + self.ff_out = operations.Linear(inner_dim, dim, bias=ff_bias, device=device, dtype=dtype) + + def forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None, transformer_options={}): + text_seq_length = encoder_hidden_states.size(1) + + # Norm & modulate + norm_hidden, norm_encoder, gate_msa, enc_gate_msa = self.norm1(hidden_states, encoder_hidden_states, temb) + + # Joint self-attention + qkv_input = 
torch.cat([norm_encoder, norm_hidden], dim=1) + b, s, _ = qkv_input.shape + n, d = self.num_heads, self.head_dim + + q = self.q(qkv_input).view(b, s, n, d) + k = self.k(qkv_input).view(b, s, n, d) + v = self.v(qkv_input) + + q = self.norm_q(q).view(b, s, n, d) + k = self.norm_k(k).view(b, s, n, d) + + # Apply rotary embeddings to image tokens only (diffusers format: [B, heads, seq, head_dim]) + if image_rotary_emb is not None: + q_img = q[:, text_seq_length:].transpose(1, 2) # [B, heads, img_seq, head_dim] + k_img = k[:, text_seq_length:].transpose(1, 2) + q_img = apply_rotary_emb(q_img, image_rotary_emb) + k_img = apply_rotary_emb(k_img, image_rotary_emb) + q = torch.cat([q[:, :text_seq_length], q_img.transpose(1, 2)], dim=1) + k = torch.cat([k[:, :text_seq_length], k_img.transpose(1, 2)], dim=1) + + attn_out = optimized_attention( + q.reshape(b, s, n * d), + k.reshape(b, s, n * d), + v, + heads=self.num_heads, + transformer_options=transformer_options, + ) + + attn_out = self.attn_out(attn_out) + + attn_encoder, attn_hidden = attn_out.split([text_seq_length, s - text_seq_length], dim=1) + + hidden_states = hidden_states + gate_msa * attn_hidden + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder + + # Norm & modulate for FF + norm_hidden, norm_encoder, gate_ff, enc_gate_ff = self.norm2(hidden_states, encoder_hidden_states, temb) + + # Feed-forward (GELU on concatenated text + latent) + ff_input = torch.cat([norm_encoder, norm_hidden], dim=1) + ff_output = self.ff_out(F.gelu(self.ff_proj(ff_input), approximate="tanh")) + + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] + + return hidden_states, encoder_hidden_states + + +class CogVideoXTransformer3DModel(nn.Module): + def __init__(self, + num_attention_heads=30, + attention_head_dim=64, + in_channels=16, + out_channels=16, + flip_sin_to_cos=True, + freq_shift=0, + time_embed_dim=512, + ofs_embed_dim=None, + text_embed_dim=4096, + num_layers=30, + dropout=0.0, + attention_bias=True, + sample_width=90, + sample_height=60, + sample_frames=49, + patch_size=2, + patch_size_t=None, + temporal_compression_ratio=4, + max_text_seq_length=226, + spatial_interpolation_scale=1.875, + temporal_interpolation_scale=1.0, + use_rotary_positional_embeddings=False, + use_learned_positional_embeddings=False, + patch_bias=True, + image_model=None, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + dim = num_attention_heads * attention_head_dim + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.max_text_seq_length = max_text_seq_length + self.use_rotary_positional_embeddings = use_rotary_positional_embeddings + + # 1. 
Patch embedding + self.patch_embed = CogVideoXPatchEmbed( + patch_size=patch_size, + patch_size_t=patch_size_t, + in_channels=in_channels, + dim=dim, + text_dim=text_embed_dim, + bias=patch_bias, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, + use_learned_positional_embeddings=use_learned_positional_embeddings, + device=device, dtype=torch.float32, operations=operations, + ) + + # 2. Time embedding + self.time_proj_dim = dim + self.time_proj_flip = flip_sin_to_cos + self.time_proj_shift = freq_shift + self.time_embedding_linear_1 = operations.Linear(dim, time_embed_dim, device=device, dtype=dtype) + self.time_embedding_act = nn.SiLU() + self.time_embedding_linear_2 = operations.Linear(time_embed_dim, time_embed_dim, device=device, dtype=dtype) + + # Optional OFS embedding (CogVideoX 1.5 I2V) + self.ofs_proj_dim = ofs_embed_dim + if ofs_embed_dim: + self.ofs_embedding_linear_1 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype) + self.ofs_embedding_act = nn.SiLU() + self.ofs_embedding_linear_2 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype) + else: + self.ofs_embedding_linear_1 = None + + # 3. Transformer blocks + self.blocks = nn.ModuleList([ + CogVideoXBlock( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + time_dim=time_embed_dim, + eps=1e-5, + device=device, dtype=dtype, operations=operations, + ) + for _ in range(num_layers) + ]) + + self.norm_final = operations.LayerNorm(dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype) + + # 4. 
Output + self.norm_out = CogVideoXAdaLayerNorm( + time_dim=time_embed_dim, dim=dim, eps=1e-5, + device=device, dtype=dtype, operations=operations, + ) + + if patch_size_t is None: + output_dim = patch_size * patch_size * out_channels + else: + output_dim = patch_size * patch_size * patch_size_t * out_channels + + self.proj_out = operations.Linear(dim, output_dim, device=device, dtype=dtype) + + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.temporal_compression_ratio = temporal_compression_ratio + + def forward(self, x, timestep, context, ofs=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, ofs, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, ofs=None, transformer_options={}, **kwargs): + import logging + logger = logging.getLogger(__name__) + if x.shape[1] > 16: + lq_part = x[:, :16] + ref_part = x[:, 16:] + logger.warning(f"[CogVideoX] x: {x.shape}, t: {timestep.item():.0f}, ofs: {ofs}, LQ: mean={lq_part.float().mean():.4f} std={lq_part.float().std():.4f}, REF: mean={ref_part.float().mean():.4f} std={ref_part.float().std():.4f} nonzero={ref_part.count_nonzero().item()}") + else: + logger.warning(f"[CogVideoX] x: {x.shape}, t: {timestep.item():.0f}, ofs: {ofs}") + + # ComfyUI passes [B, C, T, H, W] + batch_size, channels, t, h, w = x.shape + + # Pad to patch size (temporal + spatial), same pattern as WAN + p_t = self.patch_size_t if self.patch_size_t is not None else 1 + x = comfy.ldm.common_dit.pad_to_patch_size(x, (p_t, self.patch_size, self.patch_size)) + + # CogVideoX expects [B, T, C, H, W] + x = x.permute(0, 2, 1, 3, 4) + batch_size, num_frames, channels, height, width = x.shape + + # Time embedding + t_emb = get_timestep_embedding(timestep, self.time_proj_dim, self.time_proj_flip, self.time_proj_shift) + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embedding_linear_2(self.time_embedding_act(self.time_embedding_linear_1(t_emb))) + + if self.ofs_embedding_linear_1 is not None and ofs is not None: + ofs_emb = get_timestep_embedding(ofs, self.ofs_proj_dim, self.time_proj_flip, self.time_proj_shift) + ofs_emb = ofs_emb.to(dtype=x.dtype) + ofs_emb = self.ofs_embedding_linear_2(self.ofs_embedding_act(self.ofs_embedding_linear_1(ofs_emb))) + emb = emb + ofs_emb + + # Patch embedding + hidden_states = self.patch_embed(context, x) + + text_seq_length = context.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + # Rotary embeddings (if used) + image_rotary_emb = None + if self.use_rotary_positional_embeddings: + post_patch_height = height // self.patch_size + post_patch_width = width // self.patch_size + if self.patch_size_t is None: + post_time = num_frames + else: + post_time = num_frames // self.patch_size_t + image_rotary_emb = self._get_rotary_emb(post_patch_height, post_patch_width, post_time, device=x.device) + + # Transformer blocks + for i, block in enumerate(self.blocks): + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, + ) + + hidden_states = self.norm_final(hidden_states) + + # Output 
projection + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # Unpatchify + p = self.patch_size + p_t = self.patch_size_t + + if p_t is None: + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + else: + output = hidden_states.reshape( + batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p + ) + output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) + + # Back to ComfyUI format [B, C, T, H, W] and crop padding + output = output.permute(0, 2, 1, 3, 4)[:, :, :t, :h, :w] + logger.warning(f"[CogVideoX] output: {output.shape}, mean={output.float().mean():.4f}, std={output.float().std():.4f}, min={output.float().min():.4f}, max={output.float().max():.4f}") + return output + + def _get_rotary_emb(self, h, w, t, device): + """Compute CogVideoX 3D rotary positional embeddings. + + For CogVideoX 1.5 (patch_size_t != None): uses "slice" mode — grid positions + are integer arange computed at max_size, then sliced to actual size. + For CogVideoX 1.0 (patch_size_t == None): uses "linspace" mode with crop coords + scaled by spatial_interpolation_scale. + """ + d = self.attention_head_dim + dim_t = d // 4 + dim_h = d // 8 * 3 + dim_w = d // 8 * 3 + + if self.patch_size_t is not None: + # CogVideoX 1.5: "slice" mode — positions are simple integer indices + # Compute at max(sample_size, actual_size) then slice to actual + base_h = self.patch_embed.sample_height // self.patch_size + base_w = self.patch_embed.sample_width // self.patch_size + max_h = max(base_h, h) + max_w = max(base_w, w) + + grid_h = torch.arange(max_h, device=device, dtype=torch.float32) + grid_w = torch.arange(max_w, device=device, dtype=torch.float32) + grid_t = torch.arange(t, device=device, dtype=torch.float32) + else: + # CogVideoX 1.0: "linspace" mode with interpolation scale + grid_h = torch.linspace(0, h - 1, h, device=device, dtype=torch.float32) * self.spatial_interpolation_scale + grid_w = torch.linspace(0, w - 1, w, device=device, dtype=torch.float32) * self.spatial_interpolation_scale + grid_t = torch.arange(t, device=device, dtype=torch.float32) + + freqs_t = _get_1d_rotary_pos_embed(dim_t, grid_t) + freqs_h = _get_1d_rotary_pos_embed(dim_h, grid_h) + freqs_w = _get_1d_rotary_pos_embed(dim_w, grid_w) + + t_cos, t_sin = freqs_t + h_cos, h_sin = freqs_h + w_cos, w_sin = freqs_w + + # Slice to actual size (for "slice" mode where grids may be larger) + t_cos, t_sin = t_cos[:t], t_sin[:t] + h_cos, h_sin = h_cos[:h], h_sin[:h] + w_cos, w_sin = w_cos[:w], w_sin[:w] + + # Broadcast and concatenate into [T*H*W, head_dim] + t_cos = t_cos[:, None, None, :].expand(-1, h, w, -1) + t_sin = t_sin[:, None, None, :].expand(-1, h, w, -1) + h_cos = h_cos[None, :, None, :].expand(t, -1, w, -1) + h_sin = h_sin[None, :, None, :].expand(t, -1, w, -1) + w_cos = w_cos[None, None, :, :].expand(t, h, -1, -1) + w_sin = w_sin[None, None, :, :].expand(t, h, -1, -1) + + cos = torch.cat([t_cos, h_cos, w_cos], dim=-1).reshape(t * h * w, -1) + sin = torch.cat([t_sin, h_sin, w_sin], dim=-1).reshape(t * h * w, -1) + return (cos, sin) diff --git a/comfy/ldm/cogvideo/vae.py b/comfy/ldm/cogvideo/vae.py new file mode 100644 index 000000000..fe97f362c --- /dev/null +++ b/comfy/ldm/cogvideo/vae.py @@ -0,0 +1,570 @@ +# CogVideoX VAE - ported to ComfyUI native ops +# Architecture reference: diffusers AutoencoderKLCogVideoX +# Style 
reference: comfy/ldm/wan/vae.py + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops +ops = comfy.ops.disable_weight_init + + +class CausalConv3d(nn.Module): + """Causal 3D convolution with temporal padding. + + Uses comfy.ops.Conv3d with autopad='causal_zero' fast path: when input has + a single temporal frame and no cache, the 3D conv weight is sliced to act + as a 2D conv, avoiding computation on zero-padded temporal dimensions. + """ + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, pad_mode="constant"): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + + time_kernel, height_kernel, width_kernel = kernel_size + self.time_kernel_size = time_kernel + self.pad_mode = pad_mode + + height_pad = (height_kernel - 1) // 2 + width_pad = (width_kernel - 1) // 2 + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_kernel - 1, 0) + + stride = stride if isinstance(stride, tuple) else (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = ops.Conv3d( + in_channels, out_channels, kernel_size, + stride=stride, dilation=dilation, + padding=(0, height_pad, width_pad), + ) + + def forward(self, x, conv_cache=None): + if self.pad_mode == "replicate": + x = F.pad(x, self.time_causal_padding, mode="replicate") + conv_cache = None + else: + kernel_t = self.time_kernel_size + if kernel_t > 1: + if conv_cache is None and x.shape[2] == 1: + # Fast path: single frame, no cache. All temporal padding + # frames are copies of the input (replicate-style), so the + # 3D conv reduces to a 2D conv with summed temporal kernel. + w = comfy.ops.cast_to_input(self.conv.weight, x) + b = comfy.ops.cast_to_input(self.conv.bias, x) if self.conv.bias is not None else None + w2d = w.sum(dim=2, keepdim=True) + out = F.conv3d(x, w2d, b, + self.conv.stride, self.conv.padding, + self.conv.dilation, self.conv.groups) + return out, None + cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_t - 1) + x = torch.cat(cached + [x], dim=2) + conv_cache = x[:, :, -self.time_kernel_size + 1:].clone() if self.time_kernel_size > 1 else None + + out = self.conv(x) + return out, conv_cache + + +def _interpolate_zq(zq, target_size): + """Interpolate latent z to target (T, H, W), matching CogVideoX's first-frame-special handling.""" + t = target_size[0] + if t > 1 and t % 2 == 1: + z_first = F.interpolate(zq[:, :, :1], size=(1, target_size[1], target_size[2])) + z_rest = F.interpolate(zq[:, :, 1:], size=(t - 1, target_size[1], target_size[2])) + return torch.cat([z_first, z_rest], dim=2) + return F.interpolate(zq, size=target_size) + + +class SpatialNorm3D(nn.Module): + """Spatially conditioned normalization.""" + def __init__(self, f_channels, zq_channels, groups=32): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) + self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + + def forward(self, f, zq, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + if zq.shape[-3:] != f.shape[-3:]: + zq = _interpolate_zq(zq, f.shape[-3:]) + + conv_y, new_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y")) + conv_b, new_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b")) + + return self.norm_layer(f) * conv_y + conv_b, new_cache + + +class 
ResnetBlock3D(nn.Module): + """3D ResNet block with optional spatial norm.""" + def __init__(self, in_channels, out_channels=None, temb_channels=512, groups=32, + eps=1e-6, act_fn="silu", spatial_norm_dim=None, pad_mode="first"): + super().__init__() + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial_norm_dim = spatial_norm_dim + + if act_fn == "silu": + self.nonlinearity = nn.SiLU() + elif act_fn == "swish": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = nn.SiLU() + + if spatial_norm_dim is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups) + self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups) + + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + + if in_channels != out_channels: + self.conv_shortcut = ops.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv_shortcut = None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + residual = x + + if zq is not None: + x, new_cache["norm1"] = self.norm1(x, zq, conv_cache=conv_cache.get("norm1")) + else: + x = self.norm1(x) + + x = self.nonlinearity(x) + x, new_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1")) + + if temb is not None and hasattr(self, "temb_proj"): + x = x + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if zq is not None: + x, new_cache["norm2"] = self.norm2(x, zq, conv_cache=conv_cache.get("norm2")) + else: + x = self.norm2(x) + + x = self.nonlinearity(x) + x, new_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2")) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return x + residual, new_cache + + +class Downsample3D(nn.Module): + """3D downsampling with optional temporal compression.""" + def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + b, c, t, h, w = x.shape + x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) + if t % 2 == 1: + x_first, x_rest = x[..., 0], x[..., 1:] + if x_rest.shape[-1] > 0: + x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + else: + x = F.avg_pool1d(x, kernel_size=2, stride=2) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + return x + + +class Upsample3D(nn.Module): + """3D upsampling with optional temporal decompression.""" + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False): + 
super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + if x.shape[2] > 1 and x.shape[2] % 2 == 1: + x_first, x_rest = x[:, :, 0], x[:, :, 1:] + x_first = F.interpolate(x_first, scale_factor=2.0) + x_rest = F.interpolate(x_rest, scale_factor=2.0) + x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) + elif x.shape[2] > 1: + x = F.interpolate(x, scale_factor=2.0) + else: + x = x.squeeze(2) + x = F.interpolate(x, scale_factor=2.0) + x = x[:, :, None, :, :] + else: + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = F.interpolate(x, scale_factor=2.0) + x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4) + return x + + +class DownBlock3D(nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, add_downsample=True, + compress_time=False, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + groups=groups, eps=eps, act_fn=act_fn, pad_mode=pad_mode, + ) + for i in range(num_layers) + ]) + self.downsamplers = nn.ModuleList([Downsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_downsample else None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + if self.downsamplers is not None: + for ds in self.downsamplers: + x = ds(x) + return x, new_cache + + +class MidBlock3D(nn.Module): + def __init__(self, in_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=None, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels, out_channels=in_channels, + temb_channels=temb_channels, groups=groups, eps=eps, + act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, + ) + for _ in range(num_layers) + ]) + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + return x, new_cache + + +class UpBlock3D(nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=16, + add_upsample=True, compress_time=False, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + temb_channels=temb_channels, groups=groups, eps=eps, + act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, + ) + for i in range(num_layers) + ]) + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_upsample else None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): 
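+            # Each resnet gets its own entry in the causal conv cache so temporal
+            # state carries over when the decoder is run chunk by chunk.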
+ x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + if self.upsamplers is not None: + for us in self.upsamplers: + x = us(x) + return x, new_cache + + +class Encoder3D(nn.Module): + def __init__(self, in_channels=3, out_channels=16, + block_out_channels=(128, 256, 256, 512), + layers_per_block=3, act_fn="silu", + eps=1e-6, groups=32, pad_mode="first", + temporal_compression_ratio=4): + super().__init__() + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) + + self.down_blocks = nn.ModuleList() + output_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + self.down_blocks.append(DownBlock3D( + in_channels=input_channel, out_channels=output_channel, + temb_channels=0, num_layers=layers_per_block, + eps=eps, act_fn=act_fn, groups=groups, + add_downsample=not is_final, compress_time=compress_time, + )) + + self.mid_block = MidBlock3D( + in_channels=block_out_channels[-1], temb_channels=0, + num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode, + ) + + self.norm_out = nn.GroupNorm(groups, block_out_channels[-1], eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode) + + def forward(self, x, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + x, new_cache["conv_in"] = self.conv_in(x, conv_cache=conv_cache.get("conv_in")) + + for i, block in enumerate(self.down_blocks): + key = f"down_block_{i}" + x, new_cache[key] = block(x, None, None, conv_cache.get(key)) + + x, new_cache["mid_block"] = self.mid_block(x, None, None, conv_cache=conv_cache.get("mid_block")) + + x = self.norm_out(x) + x = self.conv_act(x) + x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out")) + + return x, new_cache + + +class Decoder3D(nn.Module): + def __init__(self, in_channels=16, out_channels=3, + block_out_channels=(128, 256, 256, 512), + layers_per_block=3, act_fn="silu", + eps=1e-6, groups=32, pad_mode="first", + temporal_compression_ratio=4): + super().__init__() + reversed_channels = list(reversed(block_out_channels)) + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CausalConv3d(in_channels, reversed_channels[0], kernel_size=3, pad_mode=pad_mode) + + self.mid_block = MidBlock3D( + in_channels=reversed_channels[0], temb_channels=0, + num_layers=2, eps=eps, act_fn=act_fn, groups=groups, + spatial_norm_dim=in_channels, pad_mode=pad_mode, + ) + + self.up_blocks = nn.ModuleList() + output_channel = reversed_channels[0] + for i in range(len(block_out_channels)): + prev_channel = output_channel + output_channel = reversed_channels[i] + is_final = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + self.up_blocks.append(UpBlock3D( + in_channels=prev_channel, out_channels=output_channel, + temb_channels=0, num_layers=layers_per_block + 1, + eps=eps, act_fn=act_fn, groups=groups, + spatial_norm_dim=in_channels, + add_upsample=not is_final, compress_time=compress_time, + )) + + self.norm_out = SpatialNorm3D(reversed_channels[-1], in_channels, groups=groups) + self.conv_act = nn.SiLU() + self.conv_out = CausalConv3d(reversed_channels[-1], out_channels, 
kernel_size=3, pad_mode=pad_mode) + + def forward(self, sample, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + x, new_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) + + x, new_cache["mid_block"] = self.mid_block(x, None, sample, conv_cache=conv_cache.get("mid_block")) + + for i, block in enumerate(self.up_blocks): + key = f"up_block_{i}" + x, new_cache[key] = block(x, None, sample, conv_cache=conv_cache.get(key)) + + x, new_cache["norm_out"] = self.norm_out(x, sample, conv_cache=conv_cache.get("norm_out")) + x = self.conv_act(x) + x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out")) + + return x, new_cache + + + +class AutoencoderKLCogVideoX(nn.Module): + """CogVideoX VAE. Spatial tiling/slicing handled by ComfyUI's VAE wrapper. + + Uses rolling temporal decode: conv_in + mid_block + temporal up_blocks run + on the full (low-res) tensor, then the expensive spatial-only up_blocks + + norm_out + conv_out are processed in small temporal chunks with conv_cache + carrying causal state between chunks. This keeps peak VRAM proportional to + chunk_size rather than total frame count. + """ + + def __init__(self, + in_channels=3, out_channels=3, + block_out_channels=(128, 256, 256, 512), + latent_channels=16, layers_per_block=3, + act_fn="silu", eps=1e-6, groups=32, + temporal_compression_ratio=4, + ): + super().__init__() + self.latent_channels = latent_channels + self.temporal_compression_ratio = temporal_compression_ratio + + self.encoder = Encoder3D( + in_channels=in_channels, out_channels=latent_channels, + block_out_channels=block_out_channels, layers_per_block=layers_per_block, + act_fn=act_fn, eps=eps, groups=groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.decoder = Decoder3D( + in_channels=latent_channels, out_channels=out_channels, + block_out_channels=block_out_channels, layers_per_block=layers_per_block, + act_fn=act_fn, eps=eps, groups=groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + + self.num_latent_frames_batch_size = 2 + self.num_sample_frames_batch_size = 8 + + def encode(self, x): + t = x.shape[2] + frame_batch = self.num_sample_frames_batch_size + num_batches = max(t // frame_batch, 1) + conv_cache = None + enc = [] + for i in range(num_batches): + remaining = t % frame_batch + start = frame_batch * i + (0 if i == 0 else remaining) + end = frame_batch * (i + 1) + remaining + chunk, conv_cache = self.encoder(x[:, :, start:end], conv_cache=conv_cache) + enc.append(chunk.to(x.device)) + enc = torch.cat(enc, dim=2) + mean, _ = enc.chunk(2, dim=1) + return mean + + def decode(self, z): + return self._decode_rolling(z) + + def _decode_batched(self, z): + """Original batched decode - processes 2 latent frames through full decoder.""" + t = z.shape[2] + frame_batch = self.num_latent_frames_batch_size + num_batches = max(t // frame_batch, 1) + conv_cache = None + dec = [] + for i in range(num_batches): + remaining = t % frame_batch + start = frame_batch * i + (0 if i == 0 else remaining) + end = frame_batch * (i + 1) + remaining + chunk, conv_cache = self.decoder(z[:, :, start:end], conv_cache=conv_cache) + dec.append(chunk.cpu()) + return torch.cat(dec, dim=2).to(z.device) + + def _decode_rolling(self, z): + """Rolling decode - processes low-res layers on full tensor, then rolls + through expensive high-res layers in temporal chunks.""" + decoder = self.decoder + device = z.device + + # Determine which up_blocks have temporal upsample vs spatial-only. 
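+        # For the default config here (temporal_compression_ratio=4) this means log2(4)=2:
+        # the first two up_blocks also upsample time, and the remaining spatial-only
+        # blocks run at the full expanded frame count, which is what the chunking below targets.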
+ # Temporal up_blocks are cheap (low res), spatial-only are expensive. + temporal_compress_level = int(np.log2(self.temporal_compression_ratio)) + split_at = temporal_compress_level # first N up_blocks do temporal upsample + + # Phase 1: conv_in + mid_block + temporal up_blocks on full tensor (low/medium res) + x, _ = decoder.conv_in(z) + x, _ = decoder.mid_block(x, None, z) + + for i in range(split_at): + x, _ = decoder.up_blocks[i](x, None, z) + + # Phase 2: remaining spatial-only up_blocks + norm_out + conv_out in temporal chunks + remaining_blocks = list(range(split_at, len(decoder.up_blocks))) + chunk_size = 4 # pixel frames per chunk through high-res layers + t_expanded = x.shape[2] + + if t_expanded <= chunk_size or len(remaining_blocks) == 0: + # Small enough to process in one go + for i in remaining_blocks: + x, _ = decoder.up_blocks[i](x, None, z) + x, _ = decoder.norm_out(x, z) + x = decoder.conv_act(x) + x, _ = decoder.conv_out(x) + return x + + # Pre-interpolate z to each spatial resolution used by Phase 2 blocks. + # Uses the exact same interpolation logic as SpatialNorm3D so chunked + # output is identical to non-chunked. + # Determine spatial sizes: run a dummy pass to find feature map sizes, + # or compute from block structure. Simpler: compute from x's current size + # and the known upsample factor (2x per block with upsample). + z_at_res = {} # keyed by (h, w) → pre-interpolated z [B, C, t_expanded, h, w] + h, w = x.shape[3], x.shape[4] + for i in remaining_blocks: + block = decoder.up_blocks[i] + # Resnets operate at current h, w + target = (t_expanded, h, w) + if target not in z_at_res: + z_at_res[target] = _interpolate_zq(z, target) + # If block has upsample, next block's input is 2x spatial + if block.upsamplers is not None: + h, w = h * 2, w * 2 + # norm_out operates at final resolution + target = (t_expanded, h, w) + if target not in z_at_res: + z_at_res[target] = _interpolate_zq(z, target) + + # Process in temporal chunks + dec_out = [] + conv_caches = {} + + for chunk_start in range(0, t_expanded, chunk_size): + chunk_end = min(chunk_start + chunk_size, t_expanded) + x_chunk = x[:, :, chunk_start:chunk_end] + + for i in remaining_blocks: + block = decoder.up_blocks[i] + cache_key = f"up_block_{i}" + # Get pre-interpolated z at the block's input spatial resolution + res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4]) + z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end] + x_chunk, new_cache = block(x_chunk, None, z_chunk, conv_cache=conv_caches.get(cache_key)) + conv_caches[cache_key] = new_cache + + # norm_out at final resolution + res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4]) + z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end] + x_chunk, new_cache = decoder.norm_out(x_chunk, z_chunk, conv_cache=conv_caches.get("norm_out")) + conv_caches["norm_out"] = new_cache + x_chunk = decoder.conv_act(x_chunk) + x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out")) + conv_caches["conv_out"] = new_cache + + dec_out.append(x_chunk.cpu()) + + del x + return torch.cat(dec_out, dim=2).to(device) diff --git a/comfy/ldm/cogvideo/vae_backup.py b/comfy/ldm/cogvideo/vae_backup.py new file mode 100644 index 000000000..47254b672 --- /dev/null +++ b/comfy/ldm/cogvideo/vae_backup.py @@ -0,0 +1,485 @@ +# CogVideoX VAE - ported to ComfyUI native ops +# Architecture reference: diffusers AutoencoderKLCogVideoX +# Style reference: comfy/ldm/wan/vae.py + +import numpy as np + +import torch +import torch.nn as nn +import 
torch.nn.functional as F + +import comfy.ops +ops = comfy.ops.disable_weight_init + + +class SafeConv3d(nn.Conv3d): + """3D convolution that splits large inputs along temporal dim to avoid OOM.""" + def forward(self, x): + mem = x.shape[0] * x.shape[1] * x.shape[2] * x.shape[3] * x.shape[4] * 2 / 1024**3 + if mem > 2 and x.shape[2] >= self.kernel_size[0]: + kernel_t = self.kernel_size[0] + parts = int(mem / 2) + 1 + # Ensure each chunk has at least kernel_t frames + max_parts = max(1, x.shape[2] // kernel_t) + parts = min(parts, max_parts) + if parts <= 1: + return super().forward(x) + chunks = torch.chunk(x, parts, dim=2) + if kernel_t > 1: + chunks = [chunks[0]] + [ + torch.cat((chunks[i - 1][:, :, -kernel_t + 1:], chunks[i]), dim=2) + for i in range(1, len(chunks)) + ] + out = [] + for chunk in chunks: + out.append(super().forward(chunk)) + return torch.cat(out, dim=2) + return super().forward(x) + + +class CausalConv3d(nn.Module): + """Causal 3D convolution with temporal padding.""" + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, pad_mode="constant"): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + + time_kernel, height_kernel, width_kernel = kernel_size + time_pad = time_kernel - 1 + height_pad = (height_kernel - 1) // 2 + width_pad = (width_kernel - 1) // 2 + + self.pad_mode = pad_mode + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + self.const_padding = (0, width_pad, height_pad) + self.time_kernel_size = time_kernel + + stride = stride if isinstance(stride, tuple) else (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = SafeConv3d( + in_channels, out_channels, kernel_size, + stride=stride, dilation=dilation, + padding=0 if pad_mode == "replicate" else self.const_padding, + ) + + def forward(self, x, conv_cache=None): + if self.pad_mode == "replicate": + x = F.pad(x, self.time_causal_padding, mode="replicate") + conv_cache = None + else: + kernel_t = self.time_kernel_size + if kernel_t > 1: + cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_t - 1) + x = torch.cat(cached + [x], dim=2) + conv_cache = x[:, :, -self.time_kernel_size + 1:].clone() if self.time_kernel_size > 1 else None + + out = self.conv(x) + return out, conv_cache + + +class SpatialNorm3D(nn.Module): + """Spatially conditioned normalization.""" + def __init__(self, f_channels, zq_channels, groups=32): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) + self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + + def forward(self, f, zq, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + if f.shape[2] > 1 and f.shape[2] % 2 == 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] + z_first = F.interpolate(z_first, size=f_first.shape[-3:]) + z_rest = F.interpolate(z_rest, size=f_rest.shape[-3:]) + zq = torch.cat([z_first, z_rest], dim=2) + else: + zq = F.interpolate(zq, size=f.shape[-3:]) + + conv_y, new_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y")) + conv_b, new_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b")) + + return self.norm_layer(f) * conv_y + conv_b, new_cache + + +class ResnetBlock3D(nn.Module): + """3D ResNet block with optional spatial norm.""" + def 
__init__(self, in_channels, out_channels=None, temb_channels=512, groups=32, + eps=1e-6, act_fn="silu", spatial_norm_dim=None, pad_mode="first"): + super().__init__() + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial_norm_dim = spatial_norm_dim + + if act_fn == "silu": + self.nonlinearity = nn.SiLU() + elif act_fn == "swish": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = nn.SiLU() + + if spatial_norm_dim is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups) + self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups) + + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + + if in_channels != out_channels: + self.conv_shortcut = SafeConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv_shortcut = None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + residual = x + + if zq is not None: + x, new_cache["norm1"] = self.norm1(x, zq, conv_cache=conv_cache.get("norm1")) + else: + x = self.norm1(x) + + x = self.nonlinearity(x) + x, new_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1")) + + if temb is not None and hasattr(self, "temb_proj"): + x = x + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if zq is not None: + x, new_cache["norm2"] = self.norm2(x, zq, conv_cache=conv_cache.get("norm2")) + else: + x = self.norm2(x) + + x = self.nonlinearity(x) + x, new_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2")) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return x + residual, new_cache + + +class Downsample3D(nn.Module): + """3D downsampling with optional temporal compression.""" + def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + b, c, t, h, w = x.shape + x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) + if t % 2 == 1: + x_first, x_rest = x[..., 0], x[..., 1:] + if x_rest.shape[-1] > 0: + x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + else: + x = F.avg_pool1d(x, kernel_size=2, stride=2) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + return x + + +class Upsample3D(nn.Module): + """3D upsampling with optional temporal decompression.""" + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 
stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + if x.shape[2] > 1 and x.shape[2] % 2 == 1: + x_first, x_rest = x[:, :, 0], x[:, :, 1:] + x_first = F.interpolate(x_first, scale_factor=2.0) + x_rest = F.interpolate(x_rest, scale_factor=2.0) + x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) + elif x.shape[2] > 1: + x = F.interpolate(x, scale_factor=2.0) + else: + x = x.squeeze(2) + x = F.interpolate(x, scale_factor=2.0) + x = x[:, :, None, :, :] + else: + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = F.interpolate(x, scale_factor=2.0) + x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4) + return x + + +class DownBlock3D(nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, add_downsample=True, + compress_time=False, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + groups=groups, eps=eps, act_fn=act_fn, pad_mode=pad_mode, + ) + for i in range(num_layers) + ]) + self.downsamplers = nn.ModuleList([Downsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_downsample else None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + if self.downsamplers is not None: + for ds in self.downsamplers: + x = ds(x) + return x, new_cache + + +class MidBlock3D(nn.Module): + def __init__(self, in_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=None, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels, out_channels=in_channels, + temb_channels=temb_channels, groups=groups, eps=eps, + act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, + ) + for _ in range(num_layers) + ]) + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + return x, new_cache + + +class UpBlock3D(nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=16, + add_upsample=True, compress_time=False, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + temb_channels=temb_channels, groups=groups, eps=eps, + act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, + ) + for i in range(num_layers) + ]) + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_upsample else None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + 
if self.upsamplers is not None: + for us in self.upsamplers: + x = us(x) + return x, new_cache + + +class Encoder3D(nn.Module): + def __init__(self, in_channels=3, out_channels=16, + block_out_channels=(128, 256, 256, 512), + layers_per_block=3, act_fn="silu", + eps=1e-6, groups=32, pad_mode="first", + temporal_compression_ratio=4): + super().__init__() + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) + + self.down_blocks = nn.ModuleList() + output_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + self.down_blocks.append(DownBlock3D( + in_channels=input_channel, out_channels=output_channel, + temb_channels=0, num_layers=layers_per_block, + eps=eps, act_fn=act_fn, groups=groups, + add_downsample=not is_final, compress_time=compress_time, + )) + + self.mid_block = MidBlock3D( + in_channels=block_out_channels[-1], temb_channels=0, + num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode, + ) + + self.norm_out = nn.GroupNorm(groups, block_out_channels[-1], eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode) + + def forward(self, x, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + x, new_cache["conv_in"] = self.conv_in(x, conv_cache=conv_cache.get("conv_in")) + + for i, block in enumerate(self.down_blocks): + key = f"down_block_{i}" + x, new_cache[key] = block(x, None, None, conv_cache.get(key)) + + x, new_cache["mid_block"] = self.mid_block(x, None, None, conv_cache=conv_cache.get("mid_block")) + + x = self.norm_out(x) + x = self.conv_act(x) + x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out")) + + return x, new_cache + + +class Decoder3D(nn.Module): + def __init__(self, in_channels=16, out_channels=3, + block_out_channels=(128, 256, 256, 512), + layers_per_block=3, act_fn="silu", + eps=1e-6, groups=32, pad_mode="first", + temporal_compression_ratio=4): + super().__init__() + reversed_channels = list(reversed(block_out_channels)) + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CausalConv3d(in_channels, reversed_channels[0], kernel_size=3, pad_mode=pad_mode) + + self.mid_block = MidBlock3D( + in_channels=reversed_channels[0], temb_channels=0, + num_layers=2, eps=eps, act_fn=act_fn, groups=groups, + spatial_norm_dim=in_channels, pad_mode=pad_mode, + ) + + self.up_blocks = nn.ModuleList() + output_channel = reversed_channels[0] + for i in range(len(block_out_channels)): + prev_channel = output_channel + output_channel = reversed_channels[i] + is_final = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + self.up_blocks.append(UpBlock3D( + in_channels=prev_channel, out_channels=output_channel, + temb_channels=0, num_layers=layers_per_block + 1, + eps=eps, act_fn=act_fn, groups=groups, + spatial_norm_dim=in_channels, + add_upsample=not is_final, compress_time=compress_time, + )) + + self.norm_out = SpatialNorm3D(reversed_channels[-1], in_channels, groups=groups) + self.conv_act = nn.SiLU() + self.conv_out = CausalConv3d(reversed_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode) + + def forward(self, sample, conv_cache=None): + new_cache = {} + 
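# The raw latent `sample` is threaded to every block as the zq argument so SpatialNorm3D can condition its normalization on the latent. +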
conv_cache = conv_cache or {} + + x, new_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) + + x, new_cache["mid_block"] = self.mid_block(x, None, sample, conv_cache=conv_cache.get("mid_block")) + + for i, block in enumerate(self.up_blocks): + key = f"up_block_{i}" + x, new_cache[key] = block(x, None, sample, conv_cache=conv_cache.get(key)) + + x, new_cache["norm_out"] = self.norm_out(x, sample, conv_cache=conv_cache.get("norm_out")) + x = self.conv_act(x) + x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out")) + + return x, new_cache + + + +class AutoencoderKLCogVideoX(nn.Module): + """CogVideoX VAE. Spatial tiling/slicing handled by ComfyUI's VAE wrapper. + + Temporal frame batching with conv_cache is kept here since the causal 3D + convolutions need state passed between temporal chunks. + """ + + def __init__(self, + in_channels=3, out_channels=3, + block_out_channels=(128, 256, 256, 512), + latent_channels=16, layers_per_block=3, + act_fn="silu", eps=1e-6, groups=32, + temporal_compression_ratio=4, + ): + super().__init__() + self.latent_channels = latent_channels + + self.encoder = Encoder3D( + in_channels=in_channels, out_channels=latent_channels, + block_out_channels=block_out_channels, layers_per_block=layers_per_block, + act_fn=act_fn, eps=eps, groups=groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.decoder = Decoder3D( + in_channels=latent_channels, out_channels=out_channels, + block_out_channels=block_out_channels, layers_per_block=layers_per_block, + act_fn=act_fn, eps=eps, groups=groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + + self.num_latent_frames_batch_size = 2 + self.num_sample_frames_batch_size = 8 + + def encode(self, x): + t = x.shape[2] + frame_batch = self.num_sample_frames_batch_size + num_batches = max(t // frame_batch, 1) + conv_cache = None + enc = [] + for i in range(num_batches): + remaining = t % frame_batch + start = frame_batch * i + (0 if i == 0 else remaining) + end = frame_batch * (i + 1) + remaining + chunk, conv_cache = self.encoder(x[:, :, start:end], conv_cache=conv_cache) + enc.append(chunk.to(x.device)) + enc = torch.cat(enc, dim=2) + mean, _ = enc.chunk(2, dim=1) + return mean + + def decode(self, z): + t = z.shape[2] + frame_batch = self.num_latent_frames_batch_size + num_batches = max(t // frame_batch, 1) + conv_cache = None + dec = [] + for i in range(num_batches): + remaining = t % frame_batch + start = frame_batch * i + (0 if i == 0 else remaining) + end = frame_batch * (i + 1) + remaining + chunk, conv_cache = self.decoder(z[:, :, start:end], conv_cache=conv_cache) + dec.append(chunk.cpu()) + return torch.cat(dec, dim=2).to(z.device) diff --git a/comfy/model_base.py b/comfy/model_base.py index c2ae646aa..791116436 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -52,6 +52,7 @@ import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 +import comfy.ldm.cogvideo.model import comfy.ldm.rt_detr.rtdetr_v4 import comfy.model_management @@ -79,6 +80,7 @@ class ModelType(Enum): IMG_TO_IMG = 9 FLOW_COSMOS = 10 IMG_TO_IMG_FLOW = 11 + V_PREDICTION_DDPM = 12 def model_sampling(model_config, model_type): @@ -113,6 +115,8 @@ def model_sampling(model_config, model_type): s = comfy.model_sampling.ModelSamplingCosmosRFlow elif model_type == ModelType.IMG_TO_IMG_FLOW: c = comfy.model_sampling.IMG_TO_IMG_FLOW + elif model_type == ModelType.V_PREDICTION_DDPM: + c = 
comfy.model_sampling.V_PREDICTION_DDPM class ModelSampling(s, c): pass @@ -1962,3 +1966,41 @@ class Kandinsky5Image(Kandinsky5): class RT_DETR_v4(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4) + +class CogVideoX(BaseModel): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION_DDPM, image_to_video=False, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cogvideo.model.CogVideoXTransformer3DModel) + self.image_to_video = image_to_video + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + # Detect extra channels needed (e.g. 32 - 16 = 16 for ref latent) + extra_channels = self.diffusion_model.in_channels - noise.shape[1] + if extra_channels == 0: + return None + + image = kwargs.get("concat_latent_image", None) + device = kwargs["device"] + + if image is None: + shape = list(noise.shape) + shape[1] = extra_channels + return torch.zeros(shape, dtype=noise.dtype, layout=noise.layout, device=noise.device) + + latent_dim = self.latent_format.latent_channels + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + for i in range(0, image.shape[1], latent_dim): + image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim]) + image = utils.resize_to_batch_size(image, noise.shape[0]) + return image + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + # OFS embedding (CogVideoX 1.5 I2V), default 2.0 as used by SparkVSR + if self.diffusion_model.ofs_proj_dim is not None: + ofs = kwargs.get("ofs", None) + if ofs is None: + noise = kwargs.get("noise", None) + ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype) + out['ofs'] = comfy.conds.CONDRegular(ofs) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8bed6828d..a1df050ef 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -490,6 +490,55 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config + if '{}blocks.0.norm1.linear.weight'.format(key_prefix) in state_dict_keys: # CogVideoX + dit_config = {} + dit_config["image_model"] = "cogvideox" + + # Extract config from weight shapes + norm1_weight = state_dict['{}blocks.0.norm1.linear.weight'.format(key_prefix)] + time_embed_dim = norm1_weight.shape[1] + dim = norm1_weight.shape[0] // 6 + + dit_config["num_attention_heads"] = dim // 64 + dit_config["attention_head_dim"] = 64 + dit_config["time_embed_dim"] = time_embed_dim + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') + + # Detect in_channels from patch_embed + patch_proj_key = '{}patch_embed.proj.weight'.format(key_prefix) + patch_proj_linear_key = '{}patch_embed.proj.weight'.format(key_prefix) + if patch_proj_key in state_dict_keys: + w = state_dict[patch_proj_key] + if w.ndim == 4: + # Conv2d: [out, in, kh, kw] — CogVideoX 1.0 + dit_config["in_channels"] = w.shape[1] + dit_config["patch_size"] = w.shape[2] + elif w.ndim == 2: + # Linear: [out, in_channels * patch_size * patch_size * patch_size_t] — CogVideoX 1.5 + dit_config["patch_size"] = 2 + dit_config["patch_size_t"] = 2 + dit_config["in_channels"] = w.shape[1] // (2 * 2 * 2) # 256 // 8 = 32 + + text_proj_key = '{}patch_embed.text_proj.weight'.format(key_prefix) + if text_proj_key in state_dict_keys: + 
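# patch_embed.text_proj is a Linear from the text encoder hidden size into the transformer width, so its input dim (4096 for T5-XXL) gives text_embed_dim. +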
dit_config["text_embed_dim"] = state_dict[text_proj_key].shape[1] + + # Detect OFS embedding + ofs_key = '{}ofs_embedding_linear_1.weight'.format(key_prefix) + if ofs_key in state_dict_keys: + dit_config["ofs_embed_dim"] = state_dict[ofs_key].shape[1] + + # Detect positional embedding type + pos_key = '{}patch_embed.pos_embedding'.format(key_prefix) + if pos_key in state_dict_keys: + dit_config["use_learned_positional_embeddings"] = True + dit_config["use_rotary_positional_embeddings"] = False + else: + dit_config["use_learned_positional_embeddings"] = False + dit_config["use_rotary_positional_embeddings"] = True + + return dit_config + if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 dit_config = {} dit_config["image_model"] = "wan2.1" diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 13860e6a2..cf2b5db5f 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -54,6 +54,30 @@ class V_PREDICTION(EPS): sigma = reshape_sigma(sigma, model_output.ndim) return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 +class V_PREDICTION_DDPM: + """CogVideoX v-prediction: model receives raw x_t (unscaled), predicts velocity v. + x_0 = sqrt(alpha) * x_t - sqrt(1-alpha) * v + = x_t / sqrt(sigma^2 + 1) - v * sigma / sqrt(sigma^2 + 1) + """ + def calculate_input(self, sigma, noise): + return noise + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = reshape_sigma(sigma, model_output.ndim) + return model_input / (sigma ** 2 + 1.0) ** 0.5 - model_output * sigma / (sigma ** 2 + 1.0) ** 0.5 + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + sigma = reshape_sigma(sigma, noise.ndim) + if max_denoise: + noise = noise * torch.sqrt(1.0 + sigma ** 2.0) + else: + noise = noise * sigma + noise += latent_image + return noise + + def inverse_noise_scaling(self, sigma, latent): + return latent + class EDM(V_PREDICTION): def calculate_denoised(self, sigma, model_output, model_input): sigma = reshape_sigma(sigma, model_output.ndim) diff --git a/comfy/sd.py b/comfy/sd.py index f331feefb..6cbd1eae9 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -650,6 +650,18 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + elif "decoder.conv_in.conv.weight" in sd and "decoder.mid_block.resnets.0.norm1.norm_layer.weight" in sd: # CogVideoX VAE + import comfy.ldm.cogvideo.vae + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.latent_dim = 3 + self.latent_channels = sd["encoder.conv_out.conv.weight"].shape[0] // 2 + self.first_stage_model = comfy.ldm.cogvideo.vae.AutoencoderKLCogVideoX(latent_channels=self.latent_channels) + self.memory_used_decode = lambda shape, dtype: (2800 * max(2, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1400 * max(1, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 
'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 9a5612716..e6fbe2209 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1749,6 +1749,55 @@ class RT_DETR_v4(supported_models_base.BASE): def clip_target(self, state_dict={}): return None -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4] +class CogVideoX_T2V(supported_models_base.BASE): + unet_config = { + "image_model": "cogvideox", + } + + sampling_settings = { + "linear_start": 0.00085, + "linear_end": 0.012, + "beta_schedule": "linear", + "zsnr": True, + } + + unet_extra_config = {} + latent_format = latent_formats.CogVideoX + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + # CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE + if self.unet_config.get("patch_size_t") is not None: + self.unet_config.setdefault("sample_height", 96) + self.unet_config.setdefault("sample_width", 170) + self.unet_config.setdefault("sample_frames", 81) + out = model_base.CogVideoX(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + + class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226) + + return supported_models_base.ClipTarget(CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel) + +class CogVideoX_I2V(CogVideoX_T2V): + unet_config = { + "image_model": "cogvideox", + "in_channels": 32, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.CogVideoX(self, image_to_video=True, device=device) + return out + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, 
HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, CogVideoX_I2V, CogVideoX_T2V] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_cogvideox.py b/comfy_extras/nodes_cogvideox.py new file mode 100644 index 000000000..59aa74cee --- /dev/null +++ b/comfy_extras/nodes_cogvideox.py @@ -0,0 +1,137 @@ +import nodes +import node_helpers +import torch +import comfy.model_management +import comfy.utils +from comfy_api.latest import io, ComfyExtension +from typing_extensions import override + +class SparkVSRConditioning(io.ComfyNode): + """Conditioning node for SparkVSR video super-resolution. + + Encodes LQ video and optional HR reference frames through the VAE, + builds the concat conditioning for the CogVideoX I2V model. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SparkVSRConditioning", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Image.Input("lq_video"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("batch_size", default=1, min=1, max=64), + io.Image.Input("ref_frames", optional=True), + io.Combo.Input("ref_mode", options=["auto", "manual"], default="auto", optional=True), + io.String.Input("ref_indices", default="", optional=True), + io.Float.Input("ref_guidance_scale", default=1.0, min=0.0, max=10.0, step=0.1, optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, lq_video, width, height, length, + batch_size, ref_frames=None, ref_mode="auto", ref_indices="", + ref_guidance_scale=1.0) -> io.NodeOutput: + + temporal_compression = 4 + latent_t = ((length - 1) // temporal_compression) + 1 + latent_h = height // 8 + latent_w = width // 8 + + # Base latent (noise will be added by KSampler) + latent = torch.zeros( + [batch_size, 16, latent_t, latent_h, latent_w], + device=comfy.model_management.intermediate_device() + ) + + # Encode LQ video → this becomes the base latent (KSampler adds noise to this) + lq = lq_video[:length] + lq = comfy.utils.common_upscale(lq.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + lq_latent = vae.encode(lq[:, :, :, :3]) + + # Ensure temporal dim matches + if lq_latent.shape[2] > latent_t: + lq_latent = lq_latent[:, :, :latent_t] + elif lq_latent.shape[2] < latent_t: + pad = latent_t - lq_latent.shape[2] + lq_latent = torch.cat([lq_latent, lq_latent[:, :, -1:].repeat(1, 1, pad, 1, 1)], dim=2) + + # Build reference latent (16ch) — goes as concat_latent_image + # concat_cond in model_base will concatenate this with the noise (16ch) → 32ch total + ref_latent = torch.zeros_like(lq_latent) + + if ref_frames is not None: + num_video_frames = 
lq_video.shape[0] + + # Determine reference indices + if ref_mode == "manual" and ref_indices.strip(): + indices = [int(x.strip()) for x in ref_indices.split(",") if x.strip()] + else: + indices = _select_indices(num_video_frames) + + # Encode each reference frame and place at its temporal position. + # SparkVSR places refs at specific latent indices, rest stays zeros. + for ref_idx in indices: + if ref_idx >= ref_frames.shape[0]: + continue + + frame = ref_frames[ref_idx:ref_idx + 1] + frame = comfy.utils.common_upscale(frame.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + frame_latent = vae.encode(frame[:, :, :, :3]) + + target_lat_idx = ref_idx // temporal_compression + if target_lat_idx < latent_t: + ref_latent[:, :, target_lat_idx] = frame_latent[:, :, 0] + + # Set ref latent as concat conditioning (16ch, model_base.concat_cond adds it to noise) + if ref_guidance_scale != 1.0 and ref_frames is not None: + # CFG: positive gets real refs, negative gets zero refs + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": ref_latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": torch.zeros_like(ref_latent)}) + else: + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": ref_latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": ref_latent}) + + # LQ latent is the base — KSampler will noise it and denoise + out_latent = {"samples": lq_latent} + return io.NodeOutput(positive, negative, out_latent) + + +def _select_indices(num_frames, max_refs=None): + """Auto-select reference frame indices (first, evenly spaced, last).""" + if max_refs is None: + max_refs = (num_frames - 1) // 4 + 1 + max_refs = min(max_refs, 3) + + if num_frames <= 1: + return [0] + if max_refs == 1: + return [0] + if max_refs == 2: + return [0, num_frames - 1] + + mid = num_frames // 2 + return [0, mid, num_frames - 1] + + +class CogVideoXExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SparkVSRConditioning, + ] + + +async def comfy_entrypoint() -> CogVideoXExtension: + return CogVideoXExtension() diff --git a/convert_sparkvsr_to_comfy.py b/convert_sparkvsr_to_comfy.py new file mode 100644 index 000000000..891c14428 --- /dev/null +++ b/convert_sparkvsr_to_comfy.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +"""Convert SparkVSR/CogVideoX diffusers checkpoint to ComfyUI format. + +Usage: + python convert_sparkvsr_to_comfy.py --model_dir path/to/sparkvsr-checkpoint \ + --output_dir ComfyUI/models/ + +This creates two files: + - diffusion_models/cogvideox_sparkvsr.safetensors (transformer) + - vae/cogvideox_vae.safetensors (VAE) + +T5-XXL text encoder does not need conversion — use existing ComfyUI T5 weights. 
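+ +Expected input layout (a standard diffusers pipeline directory): + + sparkvsr-checkpoint/ + transformer/*.safetensors (sharded files are merged) + vae/*.safetensors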
+""" + +import argparse +import os +import torch +from safetensors.torch import load_file, save_file + + +def remap_transformer_keys(state_dict): + """Remap diffusers transformer keys to ComfyUI CogVideoX naming.""" + new_sd = {} + for k, v in state_dict.items(): + new_key = k + + # Patch embedding + new_key = new_key.replace("patch_embed.proj.", "patch_embed.proj.") + new_key = new_key.replace("patch_embed.text_proj.", "patch_embed.text_proj.") + new_key = new_key.replace("patch_embed.pos_embedding", "patch_embed.pos_embedding") + + # Time embedding: diffusers uses time_embedding.linear_1/2, we use time_embedding_linear_1/2 + new_key = new_key.replace("time_embedding.linear_1.", "time_embedding_linear_1.") + new_key = new_key.replace("time_embedding.linear_2.", "time_embedding_linear_2.") + + # OFS embedding + new_key = new_key.replace("ofs_embedding.linear_1.", "ofs_embedding_linear_1.") + new_key = new_key.replace("ofs_embedding.linear_2.", "ofs_embedding_linear_2.") + + # Transformer blocks: diffusers uses transformer_blocks, we use blocks + new_key = new_key.replace("transformer_blocks.", "blocks.") + + # Attention: diffusers uses attn1.to_q/k/v/out, we use q/k/v/attn_out + new_key = new_key.replace(".attn1.to_q.", ".q.") + new_key = new_key.replace(".attn1.to_k.", ".k.") + new_key = new_key.replace(".attn1.to_v.", ".v.") + new_key = new_key.replace(".attn1.to_out.0.", ".attn_out.") + new_key = new_key.replace(".attn1.norm_q.", ".norm_q.") + new_key = new_key.replace(".attn1.norm_k.", ".norm_k.") + + # Feed-forward: diffusers uses ff.net.0.proj/ff.net.2, we use ff_proj/ff_out + new_key = new_key.replace(".ff.net.0.proj.", ".ff_proj.") + new_key = new_key.replace(".ff.net.2.", ".ff_out.") + + # Output norms + new_key = new_key.replace("norm_final.", "norm_final.") + new_key = new_key.replace("norm_out.linear.", "norm_out.linear.") + new_key = new_key.replace("norm_out.norm.", "norm_out.norm.") + + new_sd[new_key] = v + + return new_sd + + +def remap_vae_keys(state_dict): + """Remap diffusers VAE keys to ComfyUI CogVideoX naming. + + The VAE architecture is directly ported so most keys should match. + Main differences are in block naming. 
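+ + Note: the replace() calls below currently map each prefix to itself; they document the expected key layout rather than performing actual renames.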
+ """ + new_sd = {} + for k, v in state_dict.items(): + new_key = k + + # Encoder blocks + new_key = new_key.replace("encoder.down_blocks.", "encoder.down_blocks.") + new_key = new_key.replace("encoder.mid_block.", "encoder.mid_block.") + + # Decoder blocks + new_key = new_key.replace("decoder.up_blocks.", "decoder.up_blocks.") + new_key = new_key.replace("decoder.mid_block.", "decoder.mid_block.") + + # Resnet blocks within down/up/mid + new_key = new_key.replace(".resnets.", ".resnets.") + + # CausalConv3d: diffusers stores as .conv.weight inside CausalConv3d + # Our CausalConv3d also has .conv.weight, so this should match + + # Downsamplers/Upsamplers + new_key = new_key.replace(".downsamplers.0.", ".downsamplers.0.") + new_key = new_key.replace(".upsamplers.0.", ".upsamplers.0.") + + new_sd[new_key] = v + + return new_sd + + +def main(): + parser = argparse.ArgumentParser(description="Convert SparkVSR/CogVideoX to ComfyUI format") + parser.add_argument("--model_dir", type=str, required=True, + help="Path to diffusers pipeline directory (contains transformer/, vae/, etc.)") + parser.add_argument("--output_dir", type=str, default=".", + help="Output base directory (will create diffusion_models/ and vae/ subdirs)") + args = parser.parse_args() + + # Load transformer + transformer_dir = os.path.join(args.model_dir, "transformer") + print(f"Loading transformer from {transformer_dir}...") + transformer_sd = {} + for f in sorted(os.listdir(transformer_dir)): + if f.endswith(".safetensors"): + sd = load_file(os.path.join(transformer_dir, f)) + transformer_sd.update(sd) + + transformer_sd = remap_transformer_keys(transformer_sd) + + out_dir = os.path.join(args.output_dir, "diffusion_models") + os.makedirs(out_dir, exist_ok=True) + out_path = os.path.join(out_dir, "cogvideox_sparkvsr.safetensors") + print(f"Saving transformer to {out_path} ({len(transformer_sd)} keys)") + save_file(transformer_sd, out_path) + + # Load VAE + vae_dir = os.path.join(args.model_dir, "vae") + print(f"Loading VAE from {vae_dir}...") + vae_sd = {} + for f in sorted(os.listdir(vae_dir)): + if f.endswith(".safetensors"): + sd = load_file(os.path.join(vae_dir, f)) + vae_sd.update(sd) + + vae_sd = remap_vae_keys(vae_sd) + + out_dir = os.path.join(args.output_dir, "vae") + os.makedirs(out_dir, exist_ok=True) + out_path = os.path.join(out_dir, "cogvideox_vae.safetensors") + print(f"Saving VAE to {out_path} ({len(vae_sd)} keys)") + save_file(vae_sd, out_path) + + print("Done! T5-XXL text encoder does not need conversion.") + + +if __name__ == "__main__": + main() diff --git a/nodes.py b/nodes.py index 299b3d758..f90cee732 100644 --- a/nodes.py +++ b/nodes.py @@ -2457,7 +2457,8 @@ async def init_builtin_extra_nodes(): "nodes_number_convert.py", "nodes_painter.py", "nodes_curve.py", - "nodes_rtdetr.py" + "nodes_rtdetr.py", + "nodes_cogvideox.py", ] import_failed = []