From 1112d597c8c3c3f7d916e663c78fee846761ebbc Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Wed, 8 Apr 2026 12:04:51 +0300 Subject: [PATCH 01/15] Initial CogVideoX and SparkVSR support --- comfy/latent_formats.py | 7 + comfy/ldm/cogvideo/__init__.py | 0 comfy/ldm/cogvideo/model.py | 571 +++++++++++++++++++++++++++++++ comfy/ldm/cogvideo/vae.py | 570 ++++++++++++++++++++++++++++++ comfy/ldm/cogvideo/vae_backup.py | 485 ++++++++++++++++++++++++++ comfy/model_base.py | 42 +++ comfy/model_detection.py | 49 +++ comfy/model_sampling.py | 24 ++ comfy/sd.py | 12 + comfy/supported_models.py | 51 ++- comfy_extras/nodes_cogvideox.py | 137 ++++++++ convert_sparkvsr_to_comfy.py | 144 ++++++++ nodes.py | 3 +- 13 files changed, 2093 insertions(+), 2 deletions(-) create mode 100644 comfy/ldm/cogvideo/__init__.py create mode 100644 comfy/ldm/cogvideo/model.py create mode 100644 comfy/ldm/cogvideo/vae.py create mode 100644 comfy/ldm/cogvideo/vae_backup.py create mode 100644 comfy_extras/nodes_cogvideox.py create mode 100644 convert_sparkvsr_to_comfy.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 6a57bca1c..0f4059ebe 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -783,3 +783,10 @@ class ZImagePixelSpace(ChromaRadiance): No VAE encoding/decoding — the model operates directly on RGB pixels. """ pass + +class CogVideoX(LatentFormat): + latent_channels = 16 + latent_dimensions = 3 + + def __init__(self): + self.scale_factor = 1.15258426 diff --git a/comfy/ldm/cogvideo/__init__.py b/comfy/ldm/cogvideo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/cogvideo/model.py b/comfy/ldm/cogvideo/model.py new file mode 100644 index 000000000..a4e737d41 --- /dev/null +++ b/comfy/ldm/cogvideo/model.py @@ -0,0 +1,571 @@ +# CogVideoX 3D Transformer - ported to ComfyUI native ops +# Architecture reference: diffusers CogVideoXTransformer3DModel +# Style reference: comfy/ldm/wan/model.py + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention +import comfy.patcher_extension +import comfy.ldm.common_dit + + +def _get_1d_rotary_pos_embed(dim, pos, theta=10000.0): + """Returns (cos, sin) each with shape [seq_len, dim]. + + Frequencies are computed at dim//2 resolution then repeat_interleaved + to full dim, matching CogVideoX's interleaved (real, imag) pair format. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim)) + angles = torch.outer(pos.float(), freqs.float()) + cos = angles.cos().repeat_interleave(2, dim=-1).float() + sin = angles.sin().repeat_interleave(2, dim=-1).float() + return (cos, sin) + + +def apply_rotary_emb(x, freqs_cos_sin): + """Apply CogVideoX rotary embedding to query or key tensor. + + x: [B, heads, seq_len, head_dim] + freqs_cos_sin: (cos, sin) each [seq_len, head_dim//2] + + Uses interleaved pair rotation (same as diffusers CogVideoX/Flux). + head_dim is reshaped to (-1, 2) pairs, rotated, then flattened back. + """ + cos, sin = freqs_cos_sin + cos = cos[None, None, :, :].to(x.device) + sin = sin[None, None, :, :].to(x.device) + + # Interleaved pairs: [B, H, S, D] -> [B, H, S, D//2, 2] -> (real, imag) + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + return (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + +def get_timestep_embedding(timesteps, dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1, max_period=10000): + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half) + args = timesteps[:, None].float() * freqs[None] * scale + embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1) + if flip_sin_to_cos: + embedding = torch.cat([embedding[:, half:], embedding[:, :half]], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def get_3d_sincos_pos_embed(embed_dim, spatial_size, temporal_size, spatial_interpolation_scale=1.0, temporal_interpolation_scale=1.0, device=None): + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + + grid_w = torch.arange(spatial_size[0], dtype=torch.float32, device=device) / spatial_interpolation_scale + grid_h = torch.arange(spatial_size[1], dtype=torch.float32, device=device) / spatial_interpolation_scale + grid_t = torch.arange(temporal_size, dtype=torch.float32, device=device) / temporal_interpolation_scale + + grid_t, grid_h, grid_w = torch.meshgrid(grid_t, grid_h, grid_w, indexing="ij") + + embed_dim_spatial = 2 * (embed_dim // 3) + embed_dim_temporal = embed_dim // 3 + + pos_embed_spatial = _get_2d_sincos_pos_embed(embed_dim_spatial, grid_h, grid_w, device=device) + pos_embed_temporal = _get_1d_sincos_pos_embed(embed_dim_temporal, grid_t[:, 0, 0], device=device) + + T, H, W = grid_t.shape + pos_embed_temporal = pos_embed_temporal.unsqueeze(1).unsqueeze(1).expand(-1, H, W, -1) + pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=-1) + + return pos_embed + + +def _get_2d_sincos_pos_embed(embed_dim, grid_h, grid_w, device=None): + T, H, W = grid_h.shape + half_dim = embed_dim // 2 + pos_h = _get_1d_sincos_pos_embed(half_dim, grid_h.reshape(-1), device=device).reshape(T, H, W, half_dim) + pos_w = _get_1d_sincos_pos_embed(half_dim, grid_w.reshape(-1), device=device).reshape(T, H, W, half_dim) + return torch.cat([pos_h, pos_w], dim=-1) + + +def _get_1d_sincos_pos_embed(embed_dim, pos, device=None): + half = embed_dim // 2 + freqs = torch.exp(-math.log(10000.0) * torch.arange(start=0, end=half, dtype=torch.float32, device=device) / half) + args = pos.float().reshape(-1)[:, None] * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if embed_dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + + +class CogVideoXPatchEmbed(nn.Module): + def __init__(self, patch_size=2, patch_size_t=None, in_channels=16, dim=1920, + text_dim=4096, bias=True, sample_width=90, sample_height=60, + sample_frames=49, temporal_compression_ratio=4, + max_text_seq_length=226, spatial_interpolation_scale=1.875, + temporal_interpolation_scale=1.0, use_positional_embeddings=True, + use_learned_positional_embeddings=True, + device=None, dtype=None, operations=None): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.dim = dim + self.sample_height = sample_height + self.sample_width = sample_width + self.sample_frames = sample_frames + self.temporal_compression_ratio = temporal_compression_ratio + self.max_text_seq_length = max_text_seq_length + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.use_positional_embeddings = use_positional_embeddings + self.use_learned_positional_embeddings = use_learned_positional_embeddings + + if patch_size_t is None: + self.proj = operations.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=bias, device=device, dtype=dtype) + else: + self.proj = operations.Linear(in_channels * patch_size * patch_size * patch_size_t, dim, device=device, dtype=dtype) + + self.text_proj = operations.Linear(text_dim, dim, device=device, dtype=dtype) + + if use_positional_embeddings or use_learned_positional_embeddings: + persistent = use_learned_positional_embeddings + pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) + self.register_buffer("pos_embedding", pos_embedding, persistent=persistent) + + def _get_positional_embeddings(self, sample_height, sample_width, sample_frames, device=None): + post_patch_height = sample_height // self.patch_size + post_patch_width = sample_width // self.patch_size + post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 + if self.patch_size_t is not None: + post_time_compression_frames = post_time_compression_frames // self.patch_size_t + num_patches = post_patch_height * post_patch_width * post_time_compression_frames + + pos_embedding = get_3d_sincos_pos_embed( + self.dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + self.spatial_interpolation_scale, + self.temporal_interpolation_scale, + device=device, + ) + pos_embedding = pos_embedding.reshape(-1, self.dim) + joint_pos_embedding = pos_embedding.new_zeros( + 1, self.max_text_seq_length + num_patches, self.dim, requires_grad=False + ) + joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding) + return joint_pos_embedding + + def forward(self, text_embeds, image_embeds): + text_embeds = self.text_proj(text_embeds) + batch_size, num_frames, channels, height, width = image_embeds.shape + + if self.patch_size_t is None: + image_embeds = image_embeds.reshape(-1, channels, height, width) + image_embeds = self.proj(image_embeds) + image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose(2, 3) + image_embeds = image_embeds.flatten(1, 2) + else: + p = self.patch_size + p_t = self.patch_size_t + image_embeds = image_embeds.permute(0, 1, 3, 4, 2) + image_embeds = image_embeds.reshape( + batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels + ) + image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3) + image_embeds = self.proj(image_embeds) + + embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() + + if self.use_positional_embeddings or self.use_learned_positional_embeddings: + text_seq_length = text_embeds.shape[1] + num_image_patches = image_embeds.shape[1] + + # Compute sincos pos embedding for image patches + pos_embedding = get_3d_sincos_pos_embed( + self.dim, + (width // self.patch_size, height // self.patch_size), + num_image_patches // ((height // self.patch_size) * (width // self.patch_size)), + self.spatial_interpolation_scale, + self.temporal_interpolation_scale, + device=embeds.device, + ).reshape(-1, self.dim) + + # Build joint: zeros for text + sincos for image + joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype) + joint_pos[:, text_seq_length:] = pos_embedding.to(dtype=embeds.dtype) + embeds = embeds + joint_pos + + return embeds + + +class CogVideoXLayerNormZero(nn.Module): + def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, bias=True, + device=None, dtype=None, operations=None): + super().__init__() + self.silu = nn.SiLU() + self.linear = operations.Linear(time_dim, 6 * dim, bias=bias, device=device, dtype=dtype) + self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) + + def forward(self, hidden_states, encoder_hidden_states, temb): + shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] + return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] + + +class CogVideoXAdaLayerNorm(nn.Module): + def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, + device=None, dtype=None, operations=None): + super().__init__() + self.silu = nn.SiLU() + self.linear = operations.Linear(time_dim, 2 * dim, device=device, dtype=dtype) + self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) + + def forward(self, x, temb): + temb = self.linear(self.silu(temb)) + shift, scale = temb.chunk(2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class CogVideoXBlock(nn.Module): + def __init__(self, dim, num_heads, head_dim, time_dim, + eps=1e-5, ff_inner_dim=None, ff_bias=True, + device=None, dtype=None, operations=None): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.norm1 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations) + + # Self-attention (joint text + latent) + self.q = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + self.k = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + self.v = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + self.norm_q = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype) + self.norm_k = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype) + self.attn_out = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + + self.norm2 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations) + + # Feed-forward (GELU approximate) + inner_dim = ff_inner_dim or dim * 4 + self.ff_proj = operations.Linear(dim, inner_dim, bias=ff_bias, device=device, dtype=dtype) + self.ff_out = operations.Linear(inner_dim, dim, bias=ff_bias, device=device, dtype=dtype) + + def forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None, transformer_options={}): + text_seq_length = encoder_hidden_states.size(1) + + # Norm & modulate + norm_hidden, norm_encoder, gate_msa, enc_gate_msa = self.norm1(hidden_states, encoder_hidden_states, temb) + + # Joint self-attention + qkv_input = torch.cat([norm_encoder, norm_hidden], dim=1) + b, s, _ = qkv_input.shape + n, d = self.num_heads, self.head_dim + + q = self.q(qkv_input).view(b, s, n, d) + k = self.k(qkv_input).view(b, s, n, d) + v = self.v(qkv_input) + + q = self.norm_q(q).view(b, s, n, d) + k = self.norm_k(k).view(b, s, n, d) + + # Apply rotary embeddings to image tokens only (diffusers format: [B, heads, seq, head_dim]) + if image_rotary_emb is not None: + q_img = q[:, text_seq_length:].transpose(1, 2) # [B, heads, img_seq, head_dim] + k_img = k[:, text_seq_length:].transpose(1, 2) + q_img = apply_rotary_emb(q_img, image_rotary_emb) + k_img = apply_rotary_emb(k_img, image_rotary_emb) + q = torch.cat([q[:, :text_seq_length], q_img.transpose(1, 2)], dim=1) + k = torch.cat([k[:, :text_seq_length], k_img.transpose(1, 2)], dim=1) + + attn_out = optimized_attention( + q.reshape(b, s, n * d), + k.reshape(b, s, n * d), + v, + heads=self.num_heads, + transformer_options=transformer_options, + ) + + attn_out = self.attn_out(attn_out) + + attn_encoder, attn_hidden = attn_out.split([text_seq_length, s - text_seq_length], dim=1) + + hidden_states = hidden_states + gate_msa * attn_hidden + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder + + # Norm & modulate for FF + norm_hidden, norm_encoder, gate_ff, enc_gate_ff = self.norm2(hidden_states, encoder_hidden_states, temb) + + # Feed-forward (GELU on concatenated text + latent) + ff_input = torch.cat([norm_encoder, norm_hidden], dim=1) + ff_output = self.ff_out(F.gelu(self.ff_proj(ff_input), approximate="tanh")) + + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] + + return hidden_states, encoder_hidden_states + + +class CogVideoXTransformer3DModel(nn.Module): + def __init__(self, + num_attention_heads=30, + attention_head_dim=64, + in_channels=16, + out_channels=16, + flip_sin_to_cos=True, + freq_shift=0, + time_embed_dim=512, + ofs_embed_dim=None, + text_embed_dim=4096, + num_layers=30, + dropout=0.0, + attention_bias=True, + sample_width=90, + sample_height=60, + sample_frames=49, + patch_size=2, + patch_size_t=None, + temporal_compression_ratio=4, + max_text_seq_length=226, + spatial_interpolation_scale=1.875, + temporal_interpolation_scale=1.0, + use_rotary_positional_embeddings=False, + use_learned_positional_embeddings=False, + patch_bias=True, + image_model=None, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + dim = num_attention_heads * attention_head_dim + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.max_text_seq_length = max_text_seq_length + self.use_rotary_positional_embeddings = use_rotary_positional_embeddings + + # 1. Patch embedding + self.patch_embed = CogVideoXPatchEmbed( + patch_size=patch_size, + patch_size_t=patch_size_t, + in_channels=in_channels, + dim=dim, + text_dim=text_embed_dim, + bias=patch_bias, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, + use_learned_positional_embeddings=use_learned_positional_embeddings, + device=device, dtype=torch.float32, operations=operations, + ) + + # 2. Time embedding + self.time_proj_dim = dim + self.time_proj_flip = flip_sin_to_cos + self.time_proj_shift = freq_shift + self.time_embedding_linear_1 = operations.Linear(dim, time_embed_dim, device=device, dtype=dtype) + self.time_embedding_act = nn.SiLU() + self.time_embedding_linear_2 = operations.Linear(time_embed_dim, time_embed_dim, device=device, dtype=dtype) + + # Optional OFS embedding (CogVideoX 1.5 I2V) + self.ofs_proj_dim = ofs_embed_dim + if ofs_embed_dim: + self.ofs_embedding_linear_1 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype) + self.ofs_embedding_act = nn.SiLU() + self.ofs_embedding_linear_2 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype) + else: + self.ofs_embedding_linear_1 = None + + # 3. Transformer blocks + self.blocks = nn.ModuleList([ + CogVideoXBlock( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + time_dim=time_embed_dim, + eps=1e-5, + device=device, dtype=dtype, operations=operations, + ) + for _ in range(num_layers) + ]) + + self.norm_final = operations.LayerNorm(dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype) + + # 4. Output + self.norm_out = CogVideoXAdaLayerNorm( + time_dim=time_embed_dim, dim=dim, eps=1e-5, + device=device, dtype=dtype, operations=operations, + ) + + if patch_size_t is None: + output_dim = patch_size * patch_size * out_channels + else: + output_dim = patch_size * patch_size * patch_size_t * out_channels + + self.proj_out = operations.Linear(dim, output_dim, device=device, dtype=dtype) + + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.temporal_compression_ratio = temporal_compression_ratio + + def forward(self, x, timestep, context, ofs=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, ofs, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, ofs=None, transformer_options={}, **kwargs): + import logging + logger = logging.getLogger(__name__) + if x.shape[1] > 16: + lq_part = x[:, :16] + ref_part = x[:, 16:] + logger.warning(f"[CogVideoX] x: {x.shape}, t: {timestep.item():.0f}, ofs: {ofs}, LQ: mean={lq_part.float().mean():.4f} std={lq_part.float().std():.4f}, REF: mean={ref_part.float().mean():.4f} std={ref_part.float().std():.4f} nonzero={ref_part.count_nonzero().item()}") + else: + logger.warning(f"[CogVideoX] x: {x.shape}, t: {timestep.item():.0f}, ofs: {ofs}") + + # ComfyUI passes [B, C, T, H, W] + batch_size, channels, t, h, w = x.shape + + # Pad to patch size (temporal + spatial), same pattern as WAN + p_t = self.patch_size_t if self.patch_size_t is not None else 1 + x = comfy.ldm.common_dit.pad_to_patch_size(x, (p_t, self.patch_size, self.patch_size)) + + # CogVideoX expects [B, T, C, H, W] + x = x.permute(0, 2, 1, 3, 4) + batch_size, num_frames, channels, height, width = x.shape + + # Time embedding + t_emb = get_timestep_embedding(timestep, self.time_proj_dim, self.time_proj_flip, self.time_proj_shift) + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embedding_linear_2(self.time_embedding_act(self.time_embedding_linear_1(t_emb))) + + if self.ofs_embedding_linear_1 is not None and ofs is not None: + ofs_emb = get_timestep_embedding(ofs, self.ofs_proj_dim, self.time_proj_flip, self.time_proj_shift) + ofs_emb = ofs_emb.to(dtype=x.dtype) + ofs_emb = self.ofs_embedding_linear_2(self.ofs_embedding_act(self.ofs_embedding_linear_1(ofs_emb))) + emb = emb + ofs_emb + + # Patch embedding + hidden_states = self.patch_embed(context, x) + + text_seq_length = context.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + # Rotary embeddings (if used) + image_rotary_emb = None + if self.use_rotary_positional_embeddings: + post_patch_height = height // self.patch_size + post_patch_width = width // self.patch_size + if self.patch_size_t is None: + post_time = num_frames + else: + post_time = num_frames // self.patch_size_t + image_rotary_emb = self._get_rotary_emb(post_patch_height, post_patch_width, post_time, device=x.device) + + # Transformer blocks + for i, block in enumerate(self.blocks): + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, + ) + + hidden_states = self.norm_final(hidden_states) + + # Output projection + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # Unpatchify + p = self.patch_size + p_t = self.patch_size_t + + if p_t is None: + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + else: + output = hidden_states.reshape( + batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p + ) + output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) + + # Back to ComfyUI format [B, C, T, H, W] and crop padding + output = output.permute(0, 2, 1, 3, 4)[:, :, :t, :h, :w] + logger.warning(f"[CogVideoX] output: {output.shape}, mean={output.float().mean():.4f}, std={output.float().std():.4f}, min={output.float().min():.4f}, max={output.float().max():.4f}") + return output + + def _get_rotary_emb(self, h, w, t, device): + """Compute CogVideoX 3D rotary positional embeddings. + + For CogVideoX 1.5 (patch_size_t != None): uses "slice" mode — grid positions + are integer arange computed at max_size, then sliced to actual size. + For CogVideoX 1.0 (patch_size_t == None): uses "linspace" mode with crop coords + scaled by spatial_interpolation_scale. + """ + d = self.attention_head_dim + dim_t = d // 4 + dim_h = d // 8 * 3 + dim_w = d // 8 * 3 + + if self.patch_size_t is not None: + # CogVideoX 1.5: "slice" mode — positions are simple integer indices + # Compute at max(sample_size, actual_size) then slice to actual + base_h = self.patch_embed.sample_height // self.patch_size + base_w = self.patch_embed.sample_width // self.patch_size + max_h = max(base_h, h) + max_w = max(base_w, w) + + grid_h = torch.arange(max_h, device=device, dtype=torch.float32) + grid_w = torch.arange(max_w, device=device, dtype=torch.float32) + grid_t = torch.arange(t, device=device, dtype=torch.float32) + else: + # CogVideoX 1.0: "linspace" mode with interpolation scale + grid_h = torch.linspace(0, h - 1, h, device=device, dtype=torch.float32) * self.spatial_interpolation_scale + grid_w = torch.linspace(0, w - 1, w, device=device, dtype=torch.float32) * self.spatial_interpolation_scale + grid_t = torch.arange(t, device=device, dtype=torch.float32) + + freqs_t = _get_1d_rotary_pos_embed(dim_t, grid_t) + freqs_h = _get_1d_rotary_pos_embed(dim_h, grid_h) + freqs_w = _get_1d_rotary_pos_embed(dim_w, grid_w) + + t_cos, t_sin = freqs_t + h_cos, h_sin = freqs_h + w_cos, w_sin = freqs_w + + # Slice to actual size (for "slice" mode where grids may be larger) + t_cos, t_sin = t_cos[:t], t_sin[:t] + h_cos, h_sin = h_cos[:h], h_sin[:h] + w_cos, w_sin = w_cos[:w], w_sin[:w] + + # Broadcast and concatenate into [T*H*W, head_dim] + t_cos = t_cos[:, None, None, :].expand(-1, h, w, -1) + t_sin = t_sin[:, None, None, :].expand(-1, h, w, -1) + h_cos = h_cos[None, :, None, :].expand(t, -1, w, -1) + h_sin = h_sin[None, :, None, :].expand(t, -1, w, -1) + w_cos = w_cos[None, None, :, :].expand(t, h, -1, -1) + w_sin = w_sin[None, None, :, :].expand(t, h, -1, -1) + + cos = torch.cat([t_cos, h_cos, w_cos], dim=-1).reshape(t * h * w, -1) + sin = torch.cat([t_sin, h_sin, w_sin], dim=-1).reshape(t * h * w, -1) + return (cos, sin) diff --git a/comfy/ldm/cogvideo/vae.py b/comfy/ldm/cogvideo/vae.py new file mode 100644 index 000000000..fe97f362c --- /dev/null +++ b/comfy/ldm/cogvideo/vae.py @@ -0,0 +1,570 @@ +# CogVideoX VAE - ported to ComfyUI native ops +# Architecture reference: diffusers AutoencoderKLCogVideoX +# Style reference: comfy/ldm/wan/vae.py + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops +ops = comfy.ops.disable_weight_init + + +class CausalConv3d(nn.Module): + """Causal 3D convolution with temporal padding. + + Uses comfy.ops.Conv3d with autopad='causal_zero' fast path: when input has + a single temporal frame and no cache, the 3D conv weight is sliced to act + as a 2D conv, avoiding computation on zero-padded temporal dimensions. + """ + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, pad_mode="constant"): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + + time_kernel, height_kernel, width_kernel = kernel_size + self.time_kernel_size = time_kernel + self.pad_mode = pad_mode + + height_pad = (height_kernel - 1) // 2 + width_pad = (width_kernel - 1) // 2 + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_kernel - 1, 0) + + stride = stride if isinstance(stride, tuple) else (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = ops.Conv3d( + in_channels, out_channels, kernel_size, + stride=stride, dilation=dilation, + padding=(0, height_pad, width_pad), + ) + + def forward(self, x, conv_cache=None): + if self.pad_mode == "replicate": + x = F.pad(x, self.time_causal_padding, mode="replicate") + conv_cache = None + else: + kernel_t = self.time_kernel_size + if kernel_t > 1: + if conv_cache is None and x.shape[2] == 1: + # Fast path: single frame, no cache. All temporal padding + # frames are copies of the input (replicate-style), so the + # 3D conv reduces to a 2D conv with summed temporal kernel. + w = comfy.ops.cast_to_input(self.conv.weight, x) + b = comfy.ops.cast_to_input(self.conv.bias, x) if self.conv.bias is not None else None + w2d = w.sum(dim=2, keepdim=True) + out = F.conv3d(x, w2d, b, + self.conv.stride, self.conv.padding, + self.conv.dilation, self.conv.groups) + return out, None + cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_t - 1) + x = torch.cat(cached + [x], dim=2) + conv_cache = x[:, :, -self.time_kernel_size + 1:].clone() if self.time_kernel_size > 1 else None + + out = self.conv(x) + return out, conv_cache + + +def _interpolate_zq(zq, target_size): + """Interpolate latent z to target (T, H, W), matching CogVideoX's first-frame-special handling.""" + t = target_size[0] + if t > 1 and t % 2 == 1: + z_first = F.interpolate(zq[:, :, :1], size=(1, target_size[1], target_size[2])) + z_rest = F.interpolate(zq[:, :, 1:], size=(t - 1, target_size[1], target_size[2])) + return torch.cat([z_first, z_rest], dim=2) + return F.interpolate(zq, size=target_size) + + +class SpatialNorm3D(nn.Module): + """Spatially conditioned normalization.""" + def __init__(self, f_channels, zq_channels, groups=32): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) + self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + + def forward(self, f, zq, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + if zq.shape[-3:] != f.shape[-3:]: + zq = _interpolate_zq(zq, f.shape[-3:]) + + conv_y, new_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y")) + conv_b, new_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b")) + + return self.norm_layer(f) * conv_y + conv_b, new_cache + + +class ResnetBlock3D(nn.Module): + """3D ResNet block with optional spatial norm.""" + def __init__(self, in_channels, out_channels=None, temb_channels=512, groups=32, + eps=1e-6, act_fn="silu", spatial_norm_dim=None, pad_mode="first"): + super().__init__() + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial_norm_dim = spatial_norm_dim + + if act_fn == "silu": + self.nonlinearity = nn.SiLU() + elif act_fn == "swish": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = nn.SiLU() + + if spatial_norm_dim is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups) + self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups) + + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + + if in_channels != out_channels: + self.conv_shortcut = ops.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv_shortcut = None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + residual = x + + if zq is not None: + x, new_cache["norm1"] = self.norm1(x, zq, conv_cache=conv_cache.get("norm1")) + else: + x = self.norm1(x) + + x = self.nonlinearity(x) + x, new_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1")) + + if temb is not None and hasattr(self, "temb_proj"): + x = x + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if zq is not None: + x, new_cache["norm2"] = self.norm2(x, zq, conv_cache=conv_cache.get("norm2")) + else: + x = self.norm2(x) + + x = self.nonlinearity(x) + x, new_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2")) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return x + residual, new_cache + + +class Downsample3D(nn.Module): + """3D downsampling with optional temporal compression.""" + def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + b, c, t, h, w = x.shape + x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) + if t % 2 == 1: + x_first, x_rest = x[..., 0], x[..., 1:] + if x_rest.shape[-1] > 0: + x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + else: + x = F.avg_pool1d(x, kernel_size=2, stride=2) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + return x + + +class Upsample3D(nn.Module): + """3D upsampling with optional temporal decompression.""" + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + if x.shape[2] > 1 and x.shape[2] % 2 == 1: + x_first, x_rest = x[:, :, 0], x[:, :, 1:] + x_first = F.interpolate(x_first, scale_factor=2.0) + x_rest = F.interpolate(x_rest, scale_factor=2.0) + x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) + elif x.shape[2] > 1: + x = F.interpolate(x, scale_factor=2.0) + else: + x = x.squeeze(2) + x = F.interpolate(x, scale_factor=2.0) + x = x[:, :, None, :, :] + else: + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = F.interpolate(x, scale_factor=2.0) + x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4) + return x + + +class DownBlock3D(nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, add_downsample=True, + compress_time=False, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + groups=groups, eps=eps, act_fn=act_fn, pad_mode=pad_mode, + ) + for i in range(num_layers) + ]) + self.downsamplers = nn.ModuleList([Downsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_downsample else None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + if self.downsamplers is not None: + for ds in self.downsamplers: + x = ds(x) + return x, new_cache + + +class MidBlock3D(nn.Module): + def __init__(self, in_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=None, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels, out_channels=in_channels, + temb_channels=temb_channels, groups=groups, eps=eps, + act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, + ) + for _ in range(num_layers) + ]) + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + return x, new_cache + + +class UpBlock3D(nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=16, + add_upsample=True, compress_time=False, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + temb_channels=temb_channels, groups=groups, eps=eps, + act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, + ) + for i in range(num_layers) + ]) + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_upsample else None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + if self.upsamplers is not None: + for us in self.upsamplers: + x = us(x) + return x, new_cache + + +class Encoder3D(nn.Module): + def __init__(self, in_channels=3, out_channels=16, + block_out_channels=(128, 256, 256, 512), + layers_per_block=3, act_fn="silu", + eps=1e-6, groups=32, pad_mode="first", + temporal_compression_ratio=4): + super().__init__() + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) + + self.down_blocks = nn.ModuleList() + output_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + self.down_blocks.append(DownBlock3D( + in_channels=input_channel, out_channels=output_channel, + temb_channels=0, num_layers=layers_per_block, + eps=eps, act_fn=act_fn, groups=groups, + add_downsample=not is_final, compress_time=compress_time, + )) + + self.mid_block = MidBlock3D( + in_channels=block_out_channels[-1], temb_channels=0, + num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode, + ) + + self.norm_out = nn.GroupNorm(groups, block_out_channels[-1], eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode) + + def forward(self, x, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + x, new_cache["conv_in"] = self.conv_in(x, conv_cache=conv_cache.get("conv_in")) + + for i, block in enumerate(self.down_blocks): + key = f"down_block_{i}" + x, new_cache[key] = block(x, None, None, conv_cache.get(key)) + + x, new_cache["mid_block"] = self.mid_block(x, None, None, conv_cache=conv_cache.get("mid_block")) + + x = self.norm_out(x) + x = self.conv_act(x) + x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out")) + + return x, new_cache + + +class Decoder3D(nn.Module): + def __init__(self, in_channels=16, out_channels=3, + block_out_channels=(128, 256, 256, 512), + layers_per_block=3, act_fn="silu", + eps=1e-6, groups=32, pad_mode="first", + temporal_compression_ratio=4): + super().__init__() + reversed_channels = list(reversed(block_out_channels)) + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CausalConv3d(in_channels, reversed_channels[0], kernel_size=3, pad_mode=pad_mode) + + self.mid_block = MidBlock3D( + in_channels=reversed_channels[0], temb_channels=0, + num_layers=2, eps=eps, act_fn=act_fn, groups=groups, + spatial_norm_dim=in_channels, pad_mode=pad_mode, + ) + + self.up_blocks = nn.ModuleList() + output_channel = reversed_channels[0] + for i in range(len(block_out_channels)): + prev_channel = output_channel + output_channel = reversed_channels[i] + is_final = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + self.up_blocks.append(UpBlock3D( + in_channels=prev_channel, out_channels=output_channel, + temb_channels=0, num_layers=layers_per_block + 1, + eps=eps, act_fn=act_fn, groups=groups, + spatial_norm_dim=in_channels, + add_upsample=not is_final, compress_time=compress_time, + )) + + self.norm_out = SpatialNorm3D(reversed_channels[-1], in_channels, groups=groups) + self.conv_act = nn.SiLU() + self.conv_out = CausalConv3d(reversed_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode) + + def forward(self, sample, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + x, new_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) + + x, new_cache["mid_block"] = self.mid_block(x, None, sample, conv_cache=conv_cache.get("mid_block")) + + for i, block in enumerate(self.up_blocks): + key = f"up_block_{i}" + x, new_cache[key] = block(x, None, sample, conv_cache=conv_cache.get(key)) + + x, new_cache["norm_out"] = self.norm_out(x, sample, conv_cache=conv_cache.get("norm_out")) + x = self.conv_act(x) + x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out")) + + return x, new_cache + + + +class AutoencoderKLCogVideoX(nn.Module): + """CogVideoX VAE. Spatial tiling/slicing handled by ComfyUI's VAE wrapper. + + Uses rolling temporal decode: conv_in + mid_block + temporal up_blocks run + on the full (low-res) tensor, then the expensive spatial-only up_blocks + + norm_out + conv_out are processed in small temporal chunks with conv_cache + carrying causal state between chunks. This keeps peak VRAM proportional to + chunk_size rather than total frame count. + """ + + def __init__(self, + in_channels=3, out_channels=3, + block_out_channels=(128, 256, 256, 512), + latent_channels=16, layers_per_block=3, + act_fn="silu", eps=1e-6, groups=32, + temporal_compression_ratio=4, + ): + super().__init__() + self.latent_channels = latent_channels + self.temporal_compression_ratio = temporal_compression_ratio + + self.encoder = Encoder3D( + in_channels=in_channels, out_channels=latent_channels, + block_out_channels=block_out_channels, layers_per_block=layers_per_block, + act_fn=act_fn, eps=eps, groups=groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.decoder = Decoder3D( + in_channels=latent_channels, out_channels=out_channels, + block_out_channels=block_out_channels, layers_per_block=layers_per_block, + act_fn=act_fn, eps=eps, groups=groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + + self.num_latent_frames_batch_size = 2 + self.num_sample_frames_batch_size = 8 + + def encode(self, x): + t = x.shape[2] + frame_batch = self.num_sample_frames_batch_size + num_batches = max(t // frame_batch, 1) + conv_cache = None + enc = [] + for i in range(num_batches): + remaining = t % frame_batch + start = frame_batch * i + (0 if i == 0 else remaining) + end = frame_batch * (i + 1) + remaining + chunk, conv_cache = self.encoder(x[:, :, start:end], conv_cache=conv_cache) + enc.append(chunk.to(x.device)) + enc = torch.cat(enc, dim=2) + mean, _ = enc.chunk(2, dim=1) + return mean + + def decode(self, z): + return self._decode_rolling(z) + + def _decode_batched(self, z): + """Original batched decode - processes 2 latent frames through full decoder.""" + t = z.shape[2] + frame_batch = self.num_latent_frames_batch_size + num_batches = max(t // frame_batch, 1) + conv_cache = None + dec = [] + for i in range(num_batches): + remaining = t % frame_batch + start = frame_batch * i + (0 if i == 0 else remaining) + end = frame_batch * (i + 1) + remaining + chunk, conv_cache = self.decoder(z[:, :, start:end], conv_cache=conv_cache) + dec.append(chunk.cpu()) + return torch.cat(dec, dim=2).to(z.device) + + def _decode_rolling(self, z): + """Rolling decode - processes low-res layers on full tensor, then rolls + through expensive high-res layers in temporal chunks.""" + decoder = self.decoder + device = z.device + + # Determine which up_blocks have temporal upsample vs spatial-only. + # Temporal up_blocks are cheap (low res), spatial-only are expensive. + temporal_compress_level = int(np.log2(self.temporal_compression_ratio)) + split_at = temporal_compress_level # first N up_blocks do temporal upsample + + # Phase 1: conv_in + mid_block + temporal up_blocks on full tensor (low/medium res) + x, _ = decoder.conv_in(z) + x, _ = decoder.mid_block(x, None, z) + + for i in range(split_at): + x, _ = decoder.up_blocks[i](x, None, z) + + # Phase 2: remaining spatial-only up_blocks + norm_out + conv_out in temporal chunks + remaining_blocks = list(range(split_at, len(decoder.up_blocks))) + chunk_size = 4 # pixel frames per chunk through high-res layers + t_expanded = x.shape[2] + + if t_expanded <= chunk_size or len(remaining_blocks) == 0: + # Small enough to process in one go + for i in remaining_blocks: + x, _ = decoder.up_blocks[i](x, None, z) + x, _ = decoder.norm_out(x, z) + x = decoder.conv_act(x) + x, _ = decoder.conv_out(x) + return x + + # Pre-interpolate z to each spatial resolution used by Phase 2 blocks. + # Uses the exact same interpolation logic as SpatialNorm3D so chunked + # output is identical to non-chunked. + # Determine spatial sizes: run a dummy pass to find feature map sizes, + # or compute from block structure. Simpler: compute from x's current size + # and the known upsample factor (2x per block with upsample). + z_at_res = {} # keyed by (h, w) → pre-interpolated z [B, C, t_expanded, h, w] + h, w = x.shape[3], x.shape[4] + for i in remaining_blocks: + block = decoder.up_blocks[i] + # Resnets operate at current h, w + target = (t_expanded, h, w) + if target not in z_at_res: + z_at_res[target] = _interpolate_zq(z, target) + # If block has upsample, next block's input is 2x spatial + if block.upsamplers is not None: + h, w = h * 2, w * 2 + # norm_out operates at final resolution + target = (t_expanded, h, w) + if target not in z_at_res: + z_at_res[target] = _interpolate_zq(z, target) + + # Process in temporal chunks + dec_out = [] + conv_caches = {} + + for chunk_start in range(0, t_expanded, chunk_size): + chunk_end = min(chunk_start + chunk_size, t_expanded) + x_chunk = x[:, :, chunk_start:chunk_end] + + for i in remaining_blocks: + block = decoder.up_blocks[i] + cache_key = f"up_block_{i}" + # Get pre-interpolated z at the block's input spatial resolution + res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4]) + z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end] + x_chunk, new_cache = block(x_chunk, None, z_chunk, conv_cache=conv_caches.get(cache_key)) + conv_caches[cache_key] = new_cache + + # norm_out at final resolution + res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4]) + z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end] + x_chunk, new_cache = decoder.norm_out(x_chunk, z_chunk, conv_cache=conv_caches.get("norm_out")) + conv_caches["norm_out"] = new_cache + x_chunk = decoder.conv_act(x_chunk) + x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out")) + conv_caches["conv_out"] = new_cache + + dec_out.append(x_chunk.cpu()) + + del x + return torch.cat(dec_out, dim=2).to(device) diff --git a/comfy/ldm/cogvideo/vae_backup.py b/comfy/ldm/cogvideo/vae_backup.py new file mode 100644 index 000000000..47254b672 --- /dev/null +++ b/comfy/ldm/cogvideo/vae_backup.py @@ -0,0 +1,485 @@ +# CogVideoX VAE - ported to ComfyUI native ops +# Architecture reference: diffusers AutoencoderKLCogVideoX +# Style reference: comfy/ldm/wan/vae.py + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops +ops = comfy.ops.disable_weight_init + + +class SafeConv3d(nn.Conv3d): + """3D convolution that splits large inputs along temporal dim to avoid OOM.""" + def forward(self, x): + mem = x.shape[0] * x.shape[1] * x.shape[2] * x.shape[3] * x.shape[4] * 2 / 1024**3 + if mem > 2 and x.shape[2] >= self.kernel_size[0]: + kernel_t = self.kernel_size[0] + parts = int(mem / 2) + 1 + # Ensure each chunk has at least kernel_t frames + max_parts = max(1, x.shape[2] // kernel_t) + parts = min(parts, max_parts) + if parts <= 1: + return super().forward(x) + chunks = torch.chunk(x, parts, dim=2) + if kernel_t > 1: + chunks = [chunks[0]] + [ + torch.cat((chunks[i - 1][:, :, -kernel_t + 1:], chunks[i]), dim=2) + for i in range(1, len(chunks)) + ] + out = [] + for chunk in chunks: + out.append(super().forward(chunk)) + return torch.cat(out, dim=2) + return super().forward(x) + + +class CausalConv3d(nn.Module): + """Causal 3D convolution with temporal padding.""" + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, pad_mode="constant"): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + + time_kernel, height_kernel, width_kernel = kernel_size + time_pad = time_kernel - 1 + height_pad = (height_kernel - 1) // 2 + width_pad = (width_kernel - 1) // 2 + + self.pad_mode = pad_mode + self.time_pad = time_pad + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) + self.const_padding = (0, width_pad, height_pad) + self.time_kernel_size = time_kernel + + stride = stride if isinstance(stride, tuple) else (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = SafeConv3d( + in_channels, out_channels, kernel_size, + stride=stride, dilation=dilation, + padding=0 if pad_mode == "replicate" else self.const_padding, + ) + + def forward(self, x, conv_cache=None): + if self.pad_mode == "replicate": + x = F.pad(x, self.time_causal_padding, mode="replicate") + conv_cache = None + else: + kernel_t = self.time_kernel_size + if kernel_t > 1: + cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_t - 1) + x = torch.cat(cached + [x], dim=2) + conv_cache = x[:, :, -self.time_kernel_size + 1:].clone() if self.time_kernel_size > 1 else None + + out = self.conv(x) + return out, conv_cache + + +class SpatialNorm3D(nn.Module): + """Spatially conditioned normalization.""" + def __init__(self, f_channels, zq_channels, groups=32): + super().__init__() + self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) + self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + + def forward(self, f, zq, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + if f.shape[2] > 1 and f.shape[2] % 2 == 1: + f_first, f_rest = f[:, :, :1], f[:, :, 1:] + z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] + z_first = F.interpolate(z_first, size=f_first.shape[-3:]) + z_rest = F.interpolate(z_rest, size=f_rest.shape[-3:]) + zq = torch.cat([z_first, z_rest], dim=2) + else: + zq = F.interpolate(zq, size=f.shape[-3:]) + + conv_y, new_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y")) + conv_b, new_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b")) + + return self.norm_layer(f) * conv_y + conv_b, new_cache + + +class ResnetBlock3D(nn.Module): + """3D ResNet block with optional spatial norm.""" + def __init__(self, in_channels, out_channels=None, temb_channels=512, groups=32, + eps=1e-6, act_fn="silu", spatial_norm_dim=None, pad_mode="first"): + super().__init__() + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial_norm_dim = spatial_norm_dim + + if act_fn == "silu": + self.nonlinearity = nn.SiLU() + elif act_fn == "swish": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = nn.SiLU() + + if spatial_norm_dim is None: + self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups) + self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups) + + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + + if temb_channels > 0: + self.temb_proj = nn.Linear(temb_channels, out_channels) + + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + + if in_channels != out_channels: + self.conv_shortcut = SafeConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv_shortcut = None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + residual = x + + if zq is not None: + x, new_cache["norm1"] = self.norm1(x, zq, conv_cache=conv_cache.get("norm1")) + else: + x = self.norm1(x) + + x = self.nonlinearity(x) + x, new_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1")) + + if temb is not None and hasattr(self, "temb_proj"): + x = x + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if zq is not None: + x, new_cache["norm2"] = self.norm2(x, zq, conv_cache=conv_cache.get("norm2")) + else: + x = self.norm2(x) + + x = self.nonlinearity(x) + x, new_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2")) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return x + residual, new_cache + + +class Downsample3D(nn.Module): + """3D downsampling with optional temporal compression.""" + def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + b, c, t, h, w = x.shape + x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) + if t % 2 == 1: + x_first, x_rest = x[..., 0], x[..., 1:] + if x_rest.shape[-1] > 0: + x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + else: + x = F.avg_pool1d(x, kernel_size=2, stride=2) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + return x + + +class Upsample3D(nn.Module): + """3D upsampling with optional temporal decompression.""" + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + if x.shape[2] > 1 and x.shape[2] % 2 == 1: + x_first, x_rest = x[:, :, 0], x[:, :, 1:] + x_first = F.interpolate(x_first, scale_factor=2.0) + x_rest = F.interpolate(x_rest, scale_factor=2.0) + x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) + elif x.shape[2] > 1: + x = F.interpolate(x, scale_factor=2.0) + else: + x = x.squeeze(2) + x = F.interpolate(x, scale_factor=2.0) + x = x[:, :, None, :, :] + else: + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = F.interpolate(x, scale_factor=2.0) + x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4) + return x + + +class DownBlock3D(nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, add_downsample=True, + compress_time=False, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + groups=groups, eps=eps, act_fn=act_fn, pad_mode=pad_mode, + ) + for i in range(num_layers) + ]) + self.downsamplers = nn.ModuleList([Downsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_downsample else None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + if self.downsamplers is not None: + for ds in self.downsamplers: + x = ds(x) + return x, new_cache + + +class MidBlock3D(nn.Module): + def __init__(self, in_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=None, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels, out_channels=in_channels, + temb_channels=temb_channels, groups=groups, eps=eps, + act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, + ) + for _ in range(num_layers) + ]) + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + return x, new_cache + + +class UpBlock3D(nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=16, + add_upsample=True, compress_time=False, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + temb_channels=temb_channels, groups=groups, eps=eps, + act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, + ) + for i in range(num_layers) + ]) + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_upsample else None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + if self.upsamplers is not None: + for us in self.upsamplers: + x = us(x) + return x, new_cache + + +class Encoder3D(nn.Module): + def __init__(self, in_channels=3, out_channels=16, + block_out_channels=(128, 256, 256, 512), + layers_per_block=3, act_fn="silu", + eps=1e-6, groups=32, pad_mode="first", + temporal_compression_ratio=4): + super().__init__() + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) + + self.down_blocks = nn.ModuleList() + output_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + self.down_blocks.append(DownBlock3D( + in_channels=input_channel, out_channels=output_channel, + temb_channels=0, num_layers=layers_per_block, + eps=eps, act_fn=act_fn, groups=groups, + add_downsample=not is_final, compress_time=compress_time, + )) + + self.mid_block = MidBlock3D( + in_channels=block_out_channels[-1], temb_channels=0, + num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode, + ) + + self.norm_out = nn.GroupNorm(groups, block_out_channels[-1], eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode) + + def forward(self, x, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + x, new_cache["conv_in"] = self.conv_in(x, conv_cache=conv_cache.get("conv_in")) + + for i, block in enumerate(self.down_blocks): + key = f"down_block_{i}" + x, new_cache[key] = block(x, None, None, conv_cache.get(key)) + + x, new_cache["mid_block"] = self.mid_block(x, None, None, conv_cache=conv_cache.get("mid_block")) + + x = self.norm_out(x) + x = self.conv_act(x) + x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out")) + + return x, new_cache + + +class Decoder3D(nn.Module): + def __init__(self, in_channels=16, out_channels=3, + block_out_channels=(128, 256, 256, 512), + layers_per_block=3, act_fn="silu", + eps=1e-6, groups=32, pad_mode="first", + temporal_compression_ratio=4): + super().__init__() + reversed_channels = list(reversed(block_out_channels)) + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CausalConv3d(in_channels, reversed_channels[0], kernel_size=3, pad_mode=pad_mode) + + self.mid_block = MidBlock3D( + in_channels=reversed_channels[0], temb_channels=0, + num_layers=2, eps=eps, act_fn=act_fn, groups=groups, + spatial_norm_dim=in_channels, pad_mode=pad_mode, + ) + + self.up_blocks = nn.ModuleList() + output_channel = reversed_channels[0] + for i in range(len(block_out_channels)): + prev_channel = output_channel + output_channel = reversed_channels[i] + is_final = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + self.up_blocks.append(UpBlock3D( + in_channels=prev_channel, out_channels=output_channel, + temb_channels=0, num_layers=layers_per_block + 1, + eps=eps, act_fn=act_fn, groups=groups, + spatial_norm_dim=in_channels, + add_upsample=not is_final, compress_time=compress_time, + )) + + self.norm_out = SpatialNorm3D(reversed_channels[-1], in_channels, groups=groups) + self.conv_act = nn.SiLU() + self.conv_out = CausalConv3d(reversed_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode) + + def forward(self, sample, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + x, new_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) + + x, new_cache["mid_block"] = self.mid_block(x, None, sample, conv_cache=conv_cache.get("mid_block")) + + for i, block in enumerate(self.up_blocks): + key = f"up_block_{i}" + x, new_cache[key] = block(x, None, sample, conv_cache=conv_cache.get(key)) + + x, new_cache["norm_out"] = self.norm_out(x, sample, conv_cache=conv_cache.get("norm_out")) + x = self.conv_act(x) + x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out")) + + return x, new_cache + + + +class AutoencoderKLCogVideoX(nn.Module): + """CogVideoX VAE. Spatial tiling/slicing handled by ComfyUI's VAE wrapper. + + Temporal frame batching with conv_cache is kept here since the causal 3D + convolutions need state passed between temporal chunks. + """ + + def __init__(self, + in_channels=3, out_channels=3, + block_out_channels=(128, 256, 256, 512), + latent_channels=16, layers_per_block=3, + act_fn="silu", eps=1e-6, groups=32, + temporal_compression_ratio=4, + ): + super().__init__() + self.latent_channels = latent_channels + + self.encoder = Encoder3D( + in_channels=in_channels, out_channels=latent_channels, + block_out_channels=block_out_channels, layers_per_block=layers_per_block, + act_fn=act_fn, eps=eps, groups=groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.decoder = Decoder3D( + in_channels=latent_channels, out_channels=out_channels, + block_out_channels=block_out_channels, layers_per_block=layers_per_block, + act_fn=act_fn, eps=eps, groups=groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + + self.num_latent_frames_batch_size = 2 + self.num_sample_frames_batch_size = 8 + + def encode(self, x): + t = x.shape[2] + frame_batch = self.num_sample_frames_batch_size + num_batches = max(t // frame_batch, 1) + conv_cache = None + enc = [] + for i in range(num_batches): + remaining = t % frame_batch + start = frame_batch * i + (0 if i == 0 else remaining) + end = frame_batch * (i + 1) + remaining + chunk, conv_cache = self.encoder(x[:, :, start:end], conv_cache=conv_cache) + enc.append(chunk.to(x.device)) + enc = torch.cat(enc, dim=2) + mean, _ = enc.chunk(2, dim=1) + return mean + + def decode(self, z): + t = z.shape[2] + frame_batch = self.num_latent_frames_batch_size + num_batches = max(t // frame_batch, 1) + conv_cache = None + dec = [] + for i in range(num_batches): + remaining = t % frame_batch + start = frame_batch * i + (0 if i == 0 else remaining) + end = frame_batch * (i + 1) + remaining + chunk, conv_cache = self.decoder(z[:, :, start:end], conv_cache=conv_cache) + dec.append(chunk.cpu()) + return torch.cat(dec, dim=2).to(z.device) diff --git a/comfy/model_base.py b/comfy/model_base.py index 5c2668ba9..0f7e9f158 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -52,6 +52,7 @@ import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 +import comfy.ldm.cogvideo.model import comfy.ldm.rt_detr.rtdetr_v4 import comfy.ldm.ernie.model @@ -80,6 +81,7 @@ class ModelType(Enum): IMG_TO_IMG = 9 FLOW_COSMOS = 10 IMG_TO_IMG_FLOW = 11 + V_PREDICTION_DDPM = 12 def model_sampling(model_config, model_type): @@ -114,6 +116,8 @@ def model_sampling(model_config, model_type): s = comfy.model_sampling.ModelSamplingCosmosRFlow elif model_type == ModelType.IMG_TO_IMG_FLOW: c = comfy.model_sampling.IMG_TO_IMG_FLOW + elif model_type == ModelType.V_PREDICTION_DDPM: + c = comfy.model_sampling.V_PREDICTION_DDPM class ModelSampling(s, c): pass @@ -1974,3 +1978,41 @@ class ErnieImage(BaseModel): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out + +class CogVideoX(BaseModel): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION_DDPM, image_to_video=False, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cogvideo.model.CogVideoXTransformer3DModel) + self.image_to_video = image_to_video + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + # Detect extra channels needed (e.g. 32 - 16 = 16 for ref latent) + extra_channels = self.diffusion_model.in_channels - noise.shape[1] + if extra_channels == 0: + return None + + image = kwargs.get("concat_latent_image", None) + device = kwargs["device"] + + if image is None: + shape = list(noise.shape) + shape[1] = extra_channels + return torch.zeros(shape, dtype=noise.dtype, layout=noise.layout, device=noise.device) + + latent_dim = self.latent_format.latent_channels + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + for i in range(0, image.shape[1], latent_dim): + image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim]) + image = utils.resize_to_batch_size(image, noise.shape[0]) + return image + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + # OFS embedding (CogVideoX 1.5 I2V), default 2.0 as used by SparkVSR + if self.diffusion_model.ofs_proj_dim is not None: + ofs = kwargs.get("ofs", None) + if ofs is None: + noise = kwargs.get("noise", None) + ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype) + out['ofs'] = comfy.conds.CONDRegular(ofs) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index ca06cdd1e..62681881c 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -490,6 +490,55 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config + if '{}blocks.0.norm1.linear.weight'.format(key_prefix) in state_dict_keys: # CogVideoX + dit_config = {} + dit_config["image_model"] = "cogvideox" + + # Extract config from weight shapes + norm1_weight = state_dict['{}blocks.0.norm1.linear.weight'.format(key_prefix)] + time_embed_dim = norm1_weight.shape[1] + dim = norm1_weight.shape[0] // 6 + + dit_config["num_attention_heads"] = dim // 64 + dit_config["attention_head_dim"] = 64 + dit_config["time_embed_dim"] = time_embed_dim + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') + + # Detect in_channels from patch_embed + patch_proj_key = '{}patch_embed.proj.weight'.format(key_prefix) + patch_proj_linear_key = '{}patch_embed.proj.weight'.format(key_prefix) + if patch_proj_key in state_dict_keys: + w = state_dict[patch_proj_key] + if w.ndim == 4: + # Conv2d: [out, in, kh, kw] — CogVideoX 1.0 + dit_config["in_channels"] = w.shape[1] + dit_config["patch_size"] = w.shape[2] + elif w.ndim == 2: + # Linear: [out, in_channels * patch_size * patch_size * patch_size_t] — CogVideoX 1.5 + dit_config["patch_size"] = 2 + dit_config["patch_size_t"] = 2 + dit_config["in_channels"] = w.shape[1] // (2 * 2 * 2) # 256 // 8 = 32 + + text_proj_key = '{}patch_embed.text_proj.weight'.format(key_prefix) + if text_proj_key in state_dict_keys: + dit_config["text_embed_dim"] = state_dict[text_proj_key].shape[1] + + # Detect OFS embedding + ofs_key = '{}ofs_embedding_linear_1.weight'.format(key_prefix) + if ofs_key in state_dict_keys: + dit_config["ofs_embed_dim"] = state_dict[ofs_key].shape[1] + + # Detect positional embedding type + pos_key = '{}patch_embed.pos_embedding'.format(key_prefix) + if pos_key in state_dict_keys: + dit_config["use_learned_positional_embeddings"] = True + dit_config["use_rotary_positional_embeddings"] = False + else: + dit_config["use_learned_positional_embeddings"] = False + dit_config["use_rotary_positional_embeddings"] = True + + return dit_config + if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 dit_config = {} dit_config["image_model"] = "wan2.1" diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 13860e6a2..cf2b5db5f 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -54,6 +54,30 @@ class V_PREDICTION(EPS): sigma = reshape_sigma(sigma, model_output.ndim) return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 +class V_PREDICTION_DDPM: + """CogVideoX v-prediction: model receives raw x_t (unscaled), predicts velocity v. + x_0 = sqrt(alpha) * x_t - sqrt(1-alpha) * v + = x_t / sqrt(sigma^2 + 1) - v * sigma / sqrt(sigma^2 + 1) + """ + def calculate_input(self, sigma, noise): + return noise + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = reshape_sigma(sigma, model_output.ndim) + return model_input / (sigma ** 2 + 1.0) ** 0.5 - model_output * sigma / (sigma ** 2 + 1.0) ** 0.5 + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + sigma = reshape_sigma(sigma, noise.ndim) + if max_denoise: + noise = noise * torch.sqrt(1.0 + sigma ** 2.0) + else: + noise = noise * sigma + noise += latent_image + return noise + + def inverse_noise_scaling(self, sigma, latent): + return latent + class EDM(V_PREDICTION): def calculate_denoised(self, sigma, model_output, model_input): sigma = reshape_sigma(sigma, model_output.ndim) diff --git a/comfy/sd.py b/comfy/sd.py index e573804a5..7f730c2ed 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -651,6 +651,18 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + elif "decoder.conv_in.conv.weight" in sd and "decoder.mid_block.resnets.0.norm1.norm_layer.weight" in sd: # CogVideoX VAE + import comfy.ldm.cogvideo.vae + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.latent_dim = 3 + self.latent_channels = sd["encoder.conv_out.conv.weight"].shape[0] // 2 + self.first_stage_model = comfy.ldm.cogvideo.vae.AutoencoderKLCogVideoX(latent_channels=self.latent_channels) + self.memory_used_decode = lambda shape, dtype: (2800 * max(2, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1400 * max(1, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 58d4ce731..8fb3967ac 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1781,6 +1781,55 @@ class ErnieImage(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage] +class CogVideoX_T2V(supported_models_base.BASE): + unet_config = { + "image_model": "cogvideox", + } + + sampling_settings = { + "linear_start": 0.00085, + "linear_end": 0.012, + "beta_schedule": "linear", + "zsnr": True, + } + + unet_extra_config = {} + latent_format = latent_formats.CogVideoX + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + # CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE + if self.unet_config.get("patch_size_t") is not None: + self.unet_config.setdefault("sample_height", 96) + self.unet_config.setdefault("sample_width", 170) + self.unet_config.setdefault("sample_frames", 81) + out = model_base.CogVideoX(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + + class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226) + + return supported_models_base.ClipTarget(CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel) + +class CogVideoX_I2V(CogVideoX_T2V): + unet_config = { + "image_model": "cogvideox", + "in_channels": 32, + } + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.CogVideoX(self, image_to_video=True, device=device) + return out + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, CogVideoX_I2V, CogVideoX_T2V] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_cogvideox.py b/comfy_extras/nodes_cogvideox.py new file mode 100644 index 000000000..59aa74cee --- /dev/null +++ b/comfy_extras/nodes_cogvideox.py @@ -0,0 +1,137 @@ +import nodes +import node_helpers +import torch +import comfy.model_management +import comfy.utils +from comfy_api.latest import io, ComfyExtension +from typing_extensions import override + +class SparkVSRConditioning(io.ComfyNode): + """Conditioning node for SparkVSR video super-resolution. + + Encodes LQ video and optional HR reference frames through the VAE, + builds the concat conditioning for the CogVideoX I2V model. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SparkVSRConditioning", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Image.Input("lq_video"), + io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=8), + io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=1), + io.Int.Input("batch_size", default=1, min=1, max=64), + io.Image.Input("ref_frames", optional=True), + io.Combo.Input("ref_mode", options=["auto", "manual"], default="auto", optional=True), + io.String.Input("ref_indices", default="", optional=True), + io.Float.Input("ref_guidance_scale", default=1.0, min=0.0, max=10.0, step=0.1, optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, lq_video, width, height, length, + batch_size, ref_frames=None, ref_mode="auto", ref_indices="", + ref_guidance_scale=1.0) -> io.NodeOutput: + + temporal_compression = 4 + latent_t = ((length - 1) // temporal_compression) + 1 + latent_h = height // 8 + latent_w = width // 8 + + # Base latent (noise will be added by KSampler) + latent = torch.zeros( + [batch_size, 16, latent_t, latent_h, latent_w], + device=comfy.model_management.intermediate_device() + ) + + # Encode LQ video → this becomes the base latent (KSampler adds noise to this) + lq = lq_video[:length] + lq = comfy.utils.common_upscale(lq.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + lq_latent = vae.encode(lq[:, :, :, :3]) + + # Ensure temporal dim matches + if lq_latent.shape[2] > latent_t: + lq_latent = lq_latent[:, :, :latent_t] + elif lq_latent.shape[2] < latent_t: + pad = latent_t - lq_latent.shape[2] + lq_latent = torch.cat([lq_latent, lq_latent[:, :, -1:].repeat(1, 1, pad, 1, 1)], dim=2) + + # Build reference latent (16ch) — goes as concat_latent_image + # concat_cond in model_base will concatenate this with the noise (16ch) → 32ch total + ref_latent = torch.zeros_like(lq_latent) + + if ref_frames is not None: + num_video_frames = lq_video.shape[0] + + # Determine reference indices + if ref_mode == "manual" and ref_indices.strip(): + indices = [int(x.strip()) for x in ref_indices.split(",") if x.strip()] + else: + indices = _select_indices(num_video_frames) + + # Encode each reference frame and place at its temporal position. + # SparkVSR places refs at specific latent indices, rest stays zeros. + for ref_idx in indices: + if ref_idx >= ref_frames.shape[0]: + continue + + frame = ref_frames[ref_idx:ref_idx + 1] + frame = comfy.utils.common_upscale(frame.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + frame_latent = vae.encode(frame[:, :, :, :3]) + + target_lat_idx = ref_idx // temporal_compression + if target_lat_idx < latent_t: + ref_latent[:, :, target_lat_idx] = frame_latent[:, :, 0] + + # Set ref latent as concat conditioning (16ch, model_base.concat_cond adds it to noise) + if ref_guidance_scale != 1.0 and ref_frames is not None: + # CFG: positive gets real refs, negative gets zero refs + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": ref_latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": torch.zeros_like(ref_latent)}) + else: + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": ref_latent}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": ref_latent}) + + # LQ latent is the base — KSampler will noise it and denoise + out_latent = {"samples": lq_latent} + return io.NodeOutput(positive, negative, out_latent) + + +def _select_indices(num_frames, max_refs=None): + """Auto-select reference frame indices (first, evenly spaced, last).""" + if max_refs is None: + max_refs = (num_frames - 1) // 4 + 1 + max_refs = min(max_refs, 3) + + if num_frames <= 1: + return [0] + if max_refs == 1: + return [0] + if max_refs == 2: + return [0, num_frames - 1] + + mid = num_frames // 2 + return [0, mid, num_frames - 1] + + +class CogVideoXExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SparkVSRConditioning, + ] + + +async def comfy_entrypoint() -> CogVideoXExtension: + return CogVideoXExtension() diff --git a/convert_sparkvsr_to_comfy.py b/convert_sparkvsr_to_comfy.py new file mode 100644 index 000000000..891c14428 --- /dev/null +++ b/convert_sparkvsr_to_comfy.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +"""Convert SparkVSR/CogVideoX diffusers checkpoint to ComfyUI format. + +Usage: + python convert_sparkvsr_to_comfy.py --model_dir path/to/sparkvsr-checkpoint \ + --output_dir ComfyUI/models/ + +This creates two files: + - diffusion_models/cogvideox_sparkvsr.safetensors (transformer) + - vae/cogvideox_vae.safetensors (VAE) + +T5-XXL text encoder does not need conversion — use existing ComfyUI T5 weights. +""" + +import argparse +import os +import torch +from safetensors.torch import load_file, save_file + + +def remap_transformer_keys(state_dict): + """Remap diffusers transformer keys to ComfyUI CogVideoX naming.""" + new_sd = {} + for k, v in state_dict.items(): + new_key = k + + # Patch embedding + new_key = new_key.replace("patch_embed.proj.", "patch_embed.proj.") + new_key = new_key.replace("patch_embed.text_proj.", "patch_embed.text_proj.") + new_key = new_key.replace("patch_embed.pos_embedding", "patch_embed.pos_embedding") + + # Time embedding: diffusers uses time_embedding.linear_1/2, we use time_embedding_linear_1/2 + new_key = new_key.replace("time_embedding.linear_1.", "time_embedding_linear_1.") + new_key = new_key.replace("time_embedding.linear_2.", "time_embedding_linear_2.") + + # OFS embedding + new_key = new_key.replace("ofs_embedding.linear_1.", "ofs_embedding_linear_1.") + new_key = new_key.replace("ofs_embedding.linear_2.", "ofs_embedding_linear_2.") + + # Transformer blocks: diffusers uses transformer_blocks, we use blocks + new_key = new_key.replace("transformer_blocks.", "blocks.") + + # Attention: diffusers uses attn1.to_q/k/v/out, we use q/k/v/attn_out + new_key = new_key.replace(".attn1.to_q.", ".q.") + new_key = new_key.replace(".attn1.to_k.", ".k.") + new_key = new_key.replace(".attn1.to_v.", ".v.") + new_key = new_key.replace(".attn1.to_out.0.", ".attn_out.") + new_key = new_key.replace(".attn1.norm_q.", ".norm_q.") + new_key = new_key.replace(".attn1.norm_k.", ".norm_k.") + + # Feed-forward: diffusers uses ff.net.0.proj/ff.net.2, we use ff_proj/ff_out + new_key = new_key.replace(".ff.net.0.proj.", ".ff_proj.") + new_key = new_key.replace(".ff.net.2.", ".ff_out.") + + # Output norms + new_key = new_key.replace("norm_final.", "norm_final.") + new_key = new_key.replace("norm_out.linear.", "norm_out.linear.") + new_key = new_key.replace("norm_out.norm.", "norm_out.norm.") + + new_sd[new_key] = v + + return new_sd + + +def remap_vae_keys(state_dict): + """Remap diffusers VAE keys to ComfyUI CogVideoX naming. + + The VAE architecture is directly ported so most keys should match. + Main differences are in block naming. + """ + new_sd = {} + for k, v in state_dict.items(): + new_key = k + + # Encoder blocks + new_key = new_key.replace("encoder.down_blocks.", "encoder.down_blocks.") + new_key = new_key.replace("encoder.mid_block.", "encoder.mid_block.") + + # Decoder blocks + new_key = new_key.replace("decoder.up_blocks.", "decoder.up_blocks.") + new_key = new_key.replace("decoder.mid_block.", "decoder.mid_block.") + + # Resnet blocks within down/up/mid + new_key = new_key.replace(".resnets.", ".resnets.") + + # CausalConv3d: diffusers stores as .conv.weight inside CausalConv3d + # Our CausalConv3d also has .conv.weight, so this should match + + # Downsamplers/Upsamplers + new_key = new_key.replace(".downsamplers.0.", ".downsamplers.0.") + new_key = new_key.replace(".upsamplers.0.", ".upsamplers.0.") + + new_sd[new_key] = v + + return new_sd + + +def main(): + parser = argparse.ArgumentParser(description="Convert SparkVSR/CogVideoX to ComfyUI format") + parser.add_argument("--model_dir", type=str, required=True, + help="Path to diffusers pipeline directory (contains transformer/, vae/, etc.)") + parser.add_argument("--output_dir", type=str, default=".", + help="Output base directory (will create diffusion_models/ and vae/ subdirs)") + args = parser.parse_args() + + # Load transformer + transformer_dir = os.path.join(args.model_dir, "transformer") + print(f"Loading transformer from {transformer_dir}...") + transformer_sd = {} + for f in sorted(os.listdir(transformer_dir)): + if f.endswith(".safetensors"): + sd = load_file(os.path.join(transformer_dir, f)) + transformer_sd.update(sd) + + transformer_sd = remap_transformer_keys(transformer_sd) + + out_dir = os.path.join(args.output_dir, "diffusion_models") + os.makedirs(out_dir, exist_ok=True) + out_path = os.path.join(out_dir, "cogvideox_sparkvsr.safetensors") + print(f"Saving transformer to {out_path} ({len(transformer_sd)} keys)") + save_file(transformer_sd, out_path) + + # Load VAE + vae_dir = os.path.join(args.model_dir, "vae") + print(f"Loading VAE from {vae_dir}...") + vae_sd = {} + for f in sorted(os.listdir(vae_dir)): + if f.endswith(".safetensors"): + sd = load_file(os.path.join(vae_dir, f)) + vae_sd.update(sd) + + vae_sd = remap_vae_keys(vae_sd) + + out_dir = os.path.join(args.output_dir, "vae") + os.makedirs(out_dir, exist_ok=True) + out_path = os.path.join(out_dir, "cogvideox_vae.safetensors") + print(f"Saving VAE to {out_path} ({len(vae_sd)} keys)") + save_file(vae_sd, out_path) + + print("Done! T5-XXL text encoder does not need conversion.") + + +if __name__ == "__main__": + main() diff --git a/nodes.py b/nodes.py index 299b3d758..f90cee732 100644 --- a/nodes.py +++ b/nodes.py @@ -2457,7 +2457,8 @@ async def init_builtin_extra_nodes(): "nodes_number_convert.py", "nodes_painter.py", "nodes_curve.py", - "nodes_rtdetr.py" + "nodes_rtdetr.py", + "nodes_cogvideox.py", ] import_failed = [] From 6841484cde5f61ce761de256dbc24826f9512397 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 10 Apr 2026 11:09:51 +0200 Subject: [PATCH 02/15] Remove breaking code, logging etc. --- comfy/ldm/cogvideo/model.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/comfy/ldm/cogvideo/model.py b/comfy/ldm/cogvideo/model.py index a4e737d41..eaf242411 100644 --- a/comfy/ldm/cogvideo/model.py +++ b/comfy/ldm/cogvideo/model.py @@ -432,15 +432,6 @@ class CogVideoXTransformer3DModel(nn.Module): ).execute(x, timestep, context, ofs, transformer_options, **kwargs) def _forward(self, x, timestep, context, ofs=None, transformer_options={}, **kwargs): - import logging - logger = logging.getLogger(__name__) - if x.shape[1] > 16: - lq_part = x[:, :16] - ref_part = x[:, 16:] - logger.warning(f"[CogVideoX] x: {x.shape}, t: {timestep.item():.0f}, ofs: {ofs}, LQ: mean={lq_part.float().mean():.4f} std={lq_part.float().std():.4f}, REF: mean={ref_part.float().mean():.4f} std={ref_part.float().std():.4f} nonzero={ref_part.count_nonzero().item()}") - else: - logger.warning(f"[CogVideoX] x: {x.shape}, t: {timestep.item():.0f}, ofs: {ofs}") - # ComfyUI passes [B, C, T, H, W] batch_size, channels, t, h, w = x.shape @@ -512,7 +503,6 @@ class CogVideoXTransformer3DModel(nn.Module): # Back to ComfyUI format [B, C, T, H, W] and crop padding output = output.permute(0, 2, 1, 3, 4)[:, :, :t, :h, :w] - logger.warning(f"[CogVideoX] output: {output.shape}, mean={output.float().mean():.4f}, std={output.float().std():.4f}, min={output.float().min():.4f}, max={output.float().max():.4f}") return output def _get_rotary_emb(self, h, w, t, device): From 92571c7fe528a072d9a55671f12108c888e7e544 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 10 Apr 2026 15:03:05 +0200 Subject: [PATCH 03/15] Remove sparkvsr related code. --- comfy_extras/nodes_cogvideox.py | 137 ------------------------------ convert_sparkvsr_to_comfy.py | 144 -------------------------------- nodes.py | 1 - 3 files changed, 282 deletions(-) delete mode 100644 comfy_extras/nodes_cogvideox.py delete mode 100644 convert_sparkvsr_to_comfy.py diff --git a/comfy_extras/nodes_cogvideox.py b/comfy_extras/nodes_cogvideox.py deleted file mode 100644 index 59aa74cee..000000000 --- a/comfy_extras/nodes_cogvideox.py +++ /dev/null @@ -1,137 +0,0 @@ -import nodes -import node_helpers -import torch -import comfy.model_management -import comfy.utils -from comfy_api.latest import io, ComfyExtension -from typing_extensions import override - -class SparkVSRConditioning(io.ComfyNode): - """Conditioning node for SparkVSR video super-resolution. - - Encodes LQ video and optional HR reference frames through the VAE, - builds the concat conditioning for the CogVideoX I2V model. - """ - - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SparkVSRConditioning", - category="conditioning/video_models", - inputs=[ - io.Conditioning.Input("positive"), - io.Conditioning.Input("negative"), - io.Vae.Input("vae"), - io.Image.Input("lq_video"), - io.Int.Input("width", default=832, min=16, max=nodes.MAX_RESOLUTION, step=8), - io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=8), - io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=1), - io.Int.Input("batch_size", default=1, min=1, max=64), - io.Image.Input("ref_frames", optional=True), - io.Combo.Input("ref_mode", options=["auto", "manual"], default="auto", optional=True), - io.String.Input("ref_indices", default="", optional=True), - io.Float.Input("ref_guidance_scale", default=1.0, min=0.0, max=10.0, step=0.1, optional=True), - ], - outputs=[ - io.Conditioning.Output(display_name="positive"), - io.Conditioning.Output(display_name="negative"), - io.Latent.Output(display_name="latent"), - ], - ) - - @classmethod - def execute(cls, positive, negative, vae, lq_video, width, height, length, - batch_size, ref_frames=None, ref_mode="auto", ref_indices="", - ref_guidance_scale=1.0) -> io.NodeOutput: - - temporal_compression = 4 - latent_t = ((length - 1) // temporal_compression) + 1 - latent_h = height // 8 - latent_w = width // 8 - - # Base latent (noise will be added by KSampler) - latent = torch.zeros( - [batch_size, 16, latent_t, latent_h, latent_w], - device=comfy.model_management.intermediate_device() - ) - - # Encode LQ video → this becomes the base latent (KSampler adds noise to this) - lq = lq_video[:length] - lq = comfy.utils.common_upscale(lq.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - lq_latent = vae.encode(lq[:, :, :, :3]) - - # Ensure temporal dim matches - if lq_latent.shape[2] > latent_t: - lq_latent = lq_latent[:, :, :latent_t] - elif lq_latent.shape[2] < latent_t: - pad = latent_t - lq_latent.shape[2] - lq_latent = torch.cat([lq_latent, lq_latent[:, :, -1:].repeat(1, 1, pad, 1, 1)], dim=2) - - # Build reference latent (16ch) — goes as concat_latent_image - # concat_cond in model_base will concatenate this with the noise (16ch) → 32ch total - ref_latent = torch.zeros_like(lq_latent) - - if ref_frames is not None: - num_video_frames = lq_video.shape[0] - - # Determine reference indices - if ref_mode == "manual" and ref_indices.strip(): - indices = [int(x.strip()) for x in ref_indices.split(",") if x.strip()] - else: - indices = _select_indices(num_video_frames) - - # Encode each reference frame and place at its temporal position. - # SparkVSR places refs at specific latent indices, rest stays zeros. - for ref_idx in indices: - if ref_idx >= ref_frames.shape[0]: - continue - - frame = ref_frames[ref_idx:ref_idx + 1] - frame = comfy.utils.common_upscale(frame.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) - frame_latent = vae.encode(frame[:, :, :, :3]) - - target_lat_idx = ref_idx // temporal_compression - if target_lat_idx < latent_t: - ref_latent[:, :, target_lat_idx] = frame_latent[:, :, 0] - - # Set ref latent as concat conditioning (16ch, model_base.concat_cond adds it to noise) - if ref_guidance_scale != 1.0 and ref_frames is not None: - # CFG: positive gets real refs, negative gets zero refs - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": ref_latent}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": torch.zeros_like(ref_latent)}) - else: - positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": ref_latent}) - negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": ref_latent}) - - # LQ latent is the base — KSampler will noise it and denoise - out_latent = {"samples": lq_latent} - return io.NodeOutput(positive, negative, out_latent) - - -def _select_indices(num_frames, max_refs=None): - """Auto-select reference frame indices (first, evenly spaced, last).""" - if max_refs is None: - max_refs = (num_frames - 1) // 4 + 1 - max_refs = min(max_refs, 3) - - if num_frames <= 1: - return [0] - if max_refs == 1: - return [0] - if max_refs == 2: - return [0, num_frames - 1] - - mid = num_frames // 2 - return [0, mid, num_frames - 1] - - -class CogVideoXExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[io.ComfyNode]]: - return [ - SparkVSRConditioning, - ] - - -async def comfy_entrypoint() -> CogVideoXExtension: - return CogVideoXExtension() diff --git a/convert_sparkvsr_to_comfy.py b/convert_sparkvsr_to_comfy.py deleted file mode 100644 index 891c14428..000000000 --- a/convert_sparkvsr_to_comfy.py +++ /dev/null @@ -1,144 +0,0 @@ -#!/usr/bin/env python3 -"""Convert SparkVSR/CogVideoX diffusers checkpoint to ComfyUI format. - -Usage: - python convert_sparkvsr_to_comfy.py --model_dir path/to/sparkvsr-checkpoint \ - --output_dir ComfyUI/models/ - -This creates two files: - - diffusion_models/cogvideox_sparkvsr.safetensors (transformer) - - vae/cogvideox_vae.safetensors (VAE) - -T5-XXL text encoder does not need conversion — use existing ComfyUI T5 weights. -""" - -import argparse -import os -import torch -from safetensors.torch import load_file, save_file - - -def remap_transformer_keys(state_dict): - """Remap diffusers transformer keys to ComfyUI CogVideoX naming.""" - new_sd = {} - for k, v in state_dict.items(): - new_key = k - - # Patch embedding - new_key = new_key.replace("patch_embed.proj.", "patch_embed.proj.") - new_key = new_key.replace("patch_embed.text_proj.", "patch_embed.text_proj.") - new_key = new_key.replace("patch_embed.pos_embedding", "patch_embed.pos_embedding") - - # Time embedding: diffusers uses time_embedding.linear_1/2, we use time_embedding_linear_1/2 - new_key = new_key.replace("time_embedding.linear_1.", "time_embedding_linear_1.") - new_key = new_key.replace("time_embedding.linear_2.", "time_embedding_linear_2.") - - # OFS embedding - new_key = new_key.replace("ofs_embedding.linear_1.", "ofs_embedding_linear_1.") - new_key = new_key.replace("ofs_embedding.linear_2.", "ofs_embedding_linear_2.") - - # Transformer blocks: diffusers uses transformer_blocks, we use blocks - new_key = new_key.replace("transformer_blocks.", "blocks.") - - # Attention: diffusers uses attn1.to_q/k/v/out, we use q/k/v/attn_out - new_key = new_key.replace(".attn1.to_q.", ".q.") - new_key = new_key.replace(".attn1.to_k.", ".k.") - new_key = new_key.replace(".attn1.to_v.", ".v.") - new_key = new_key.replace(".attn1.to_out.0.", ".attn_out.") - new_key = new_key.replace(".attn1.norm_q.", ".norm_q.") - new_key = new_key.replace(".attn1.norm_k.", ".norm_k.") - - # Feed-forward: diffusers uses ff.net.0.proj/ff.net.2, we use ff_proj/ff_out - new_key = new_key.replace(".ff.net.0.proj.", ".ff_proj.") - new_key = new_key.replace(".ff.net.2.", ".ff_out.") - - # Output norms - new_key = new_key.replace("norm_final.", "norm_final.") - new_key = new_key.replace("norm_out.linear.", "norm_out.linear.") - new_key = new_key.replace("norm_out.norm.", "norm_out.norm.") - - new_sd[new_key] = v - - return new_sd - - -def remap_vae_keys(state_dict): - """Remap diffusers VAE keys to ComfyUI CogVideoX naming. - - The VAE architecture is directly ported so most keys should match. - Main differences are in block naming. - """ - new_sd = {} - for k, v in state_dict.items(): - new_key = k - - # Encoder blocks - new_key = new_key.replace("encoder.down_blocks.", "encoder.down_blocks.") - new_key = new_key.replace("encoder.mid_block.", "encoder.mid_block.") - - # Decoder blocks - new_key = new_key.replace("decoder.up_blocks.", "decoder.up_blocks.") - new_key = new_key.replace("decoder.mid_block.", "decoder.mid_block.") - - # Resnet blocks within down/up/mid - new_key = new_key.replace(".resnets.", ".resnets.") - - # CausalConv3d: diffusers stores as .conv.weight inside CausalConv3d - # Our CausalConv3d also has .conv.weight, so this should match - - # Downsamplers/Upsamplers - new_key = new_key.replace(".downsamplers.0.", ".downsamplers.0.") - new_key = new_key.replace(".upsamplers.0.", ".upsamplers.0.") - - new_sd[new_key] = v - - return new_sd - - -def main(): - parser = argparse.ArgumentParser(description="Convert SparkVSR/CogVideoX to ComfyUI format") - parser.add_argument("--model_dir", type=str, required=True, - help="Path to diffusers pipeline directory (contains transformer/, vae/, etc.)") - parser.add_argument("--output_dir", type=str, default=".", - help="Output base directory (will create diffusion_models/ and vae/ subdirs)") - args = parser.parse_args() - - # Load transformer - transformer_dir = os.path.join(args.model_dir, "transformer") - print(f"Loading transformer from {transformer_dir}...") - transformer_sd = {} - for f in sorted(os.listdir(transformer_dir)): - if f.endswith(".safetensors"): - sd = load_file(os.path.join(transformer_dir, f)) - transformer_sd.update(sd) - - transformer_sd = remap_transformer_keys(transformer_sd) - - out_dir = os.path.join(args.output_dir, "diffusion_models") - os.makedirs(out_dir, exist_ok=True) - out_path = os.path.join(out_dir, "cogvideox_sparkvsr.safetensors") - print(f"Saving transformer to {out_path} ({len(transformer_sd)} keys)") - save_file(transformer_sd, out_path) - - # Load VAE - vae_dir = os.path.join(args.model_dir, "vae") - print(f"Loading VAE from {vae_dir}...") - vae_sd = {} - for f in sorted(os.listdir(vae_dir)): - if f.endswith(".safetensors"): - sd = load_file(os.path.join(vae_dir, f)) - vae_sd.update(sd) - - vae_sd = remap_vae_keys(vae_sd) - - out_dir = os.path.join(args.output_dir, "vae") - os.makedirs(out_dir, exist_ok=True) - out_path = os.path.join(out_dir, "cogvideox_vae.safetensors") - print(f"Saving VAE to {out_path} ({len(vae_sd)} keys)") - save_file(vae_sd, out_path) - - print("Done! T5-XXL text encoder does not need conversion.") - - -if __name__ == "__main__": - main() diff --git a/nodes.py b/nodes.py index f90cee732..ba2fa0246 100644 --- a/nodes.py +++ b/nodes.py @@ -2458,7 +2458,6 @@ async def init_builtin_extra_nodes(): "nodes_painter.py", "nodes_curve.py", "nodes_rtdetr.py", - "nodes_cogvideox.py", ] import_failed = [] From 220a044fabbe784cb32e3304e0d663003e70cb82 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 10 Apr 2026 15:12:32 +0200 Subject: [PATCH 04/15] Utilize use_learned_positional_embeddings in forward pass of CogVideoX. --- comfy/ldm/cogvideo/model.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/comfy/ldm/cogvideo/model.py b/comfy/ldm/cogvideo/model.py index eaf242411..af64f342a 100644 --- a/comfy/ldm/cogvideo/model.py +++ b/comfy/ldm/cogvideo/model.py @@ -182,19 +182,23 @@ class CogVideoXPatchEmbed(nn.Module): text_seq_length = text_embeds.shape[1] num_image_patches = image_embeds.shape[1] - # Compute sincos pos embedding for image patches - pos_embedding = get_3d_sincos_pos_embed( - self.dim, - (width // self.patch_size, height // self.patch_size), - num_image_patches // ((height // self.patch_size) * (width // self.patch_size)), - self.spatial_interpolation_scale, - self.temporal_interpolation_scale, - device=embeds.device, - ).reshape(-1, self.dim) + if self.use_learned_positional_embeddings: + image_pos = self.pos_embedding[ + :, self.max_text_seq_length:self.max_text_seq_length + num_image_patches + ].to(device=embeds.device, dtype=embeds.dtype) + else: + image_pos = get_3d_sincos_pos_embed( + self.dim, + (width // self.patch_size, height // self.patch_size), + num_image_patches // ((height // self.patch_size) * (width // self.patch_size)), + self.spatial_interpolation_scale, + self.temporal_interpolation_scale, + device=embeds.device, + ).reshape(1, num_image_patches, self.dim).to(dtype=embeds.dtype) # Build joint: zeros for text + sincos for image joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype) - joint_pos[:, text_seq_length:] = pos_embedding.to(dtype=embeds.dtype) + joint_pos[:, text_seq_length:] = image_pos embeds = embeds + joint_pos return embeds From 73bd1dd2c80769c561330ca33212963399caf5b1 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 10 Apr 2026 15:14:42 +0200 Subject: [PATCH 05/15] Fix mutable input parameter. --- comfy/ldm/cogvideo/model.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/cogvideo/model.py b/comfy/ldm/cogvideo/model.py index af64f342a..c79883fb3 100644 --- a/comfy/ldm/cogvideo/model.py +++ b/comfy/ldm/cogvideo/model.py @@ -260,7 +260,9 @@ class CogVideoXBlock(nn.Module): self.ff_proj = operations.Linear(dim, inner_dim, bias=ff_bias, device=device, dtype=dtype) self.ff_out = operations.Linear(inner_dim, dim, bias=ff_bias, device=device, dtype=dtype) - def forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None, transformer_options={}): + def forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None, transformer_options=None): + if transformer_options is None: + transformer_options = {} text_seq_length = encoder_hidden_states.size(1) # Norm & modulate @@ -428,14 +430,18 @@ class CogVideoXTransformer3DModel(nn.Module): self.temporal_interpolation_scale = temporal_interpolation_scale self.temporal_compression_ratio = temporal_compression_ratio - def forward(self, x, timestep, context, ofs=None, transformer_options={}, **kwargs): + def forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs): + if transformer_options is None: + transformer_options = {} return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) ).execute(x, timestep, context, ofs, transformer_options, **kwargs) - def _forward(self, x, timestep, context, ofs=None, transformer_options={}, **kwargs): + def _forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs): + if transformer_options is None: + transformer_options = {} # ComfyUI passes [B, C, T, H, W] batch_size, channels, t, h, w = x.shape From 9904f4d73fac256e39061fc6b20f6af5baba9420 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 10 Apr 2026 19:47:18 +0200 Subject: [PATCH 06/15] Fix CogVideoX concat_cond to handle temporal dimension and normalize channel count --- comfy/model_base.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 0f7e9f158..054853288 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -2001,9 +2001,27 @@ class CogVideoX(BaseModel): latent_dim = self.latent_format.latent_channels image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + + if noise.ndim == 5 and image.ndim == 5: + if image.shape[-3] < noise.shape[-3]: + image = torch.nn.functional.pad(image, (0, 0, 0, 0, 0, noise.shape[-3] - image.shape[-3]), "constant", 0) + elif image.shape[-3] > noise.shape[-3]: + image = image[:, :, :noise.shape[-3]] + for i in range(0, image.shape[1], latent_dim): image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim]) image = utils.resize_to_batch_size(image, noise.shape[0]) + + if image.shape[1] > extra_channels: + image = image[:, :extra_channels] + elif image.shape[1] < extra_channels: + repeats = extra_channels // image.shape[1] + remainder = extra_channels % image.shape[1] + parts = [image] * repeats + if remainder > 0: + parts.append(image[:, :remainder]) + image = torch.cat(parts, dim=1) + return image def extra_conds(self, **kwargs): From cee57f6827462867db37d84bce8fcec1c9a46524 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 10 Apr 2026 19:50:16 +0200 Subject: [PATCH 07/15] Add CogVideoX 1.5 geometry defaults to I2V path --- comfy/supported_models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 8fb3967ac..1c8a09fa1 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1827,6 +1827,10 @@ class CogVideoX_I2V(CogVideoX_T2V): } def get_model(self, state_dict, prefix="", device=None): + if self.unet_config.get("patch_size_t") is not None: + self.unet_config.setdefault("sample_height", 96) + self.unet_config.setdefault("sample_width", 170) + self.unet_config.setdefault("sample_frames", 81) out = model_base.CogVideoX(self, image_to_video=True, device=device) return out From 541f26ae237e076727943d4bad954fa329928f48 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Fri, 10 Apr 2026 20:24:57 +0200 Subject: [PATCH 08/15] Fixup ruff. --- comfy/model_detection.py | 1 - comfy/sd.py | 2 +- comfy/supported_models.py | 3 --- 3 files changed, 1 insertion(+), 5 deletions(-) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 62681881c..f6095581b 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -506,7 +506,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): # Detect in_channels from patch_embed patch_proj_key = '{}patch_embed.proj.weight'.format(key_prefix) - patch_proj_linear_key = '{}patch_embed.proj.weight'.format(key_prefix) if patch_proj_key in state_dict_keys: w = state_dict[patch_proj_key] if w.ndim == 4: diff --git a/comfy/sd.py b/comfy/sd.py index 7f730c2ed..42e3b1e41 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -17,6 +17,7 @@ import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.ace.vae.music_dcae_pipeline +import comfy.ldm.cogvideo.vae import comfy.ldm.hunyuan_video.vae import comfy.ldm.mmaudio.vae.autoencoder import comfy.pixel_space_convert @@ -652,7 +653,6 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) elif "decoder.conv_in.conv.weight" in sd and "decoder.mid_block.resnets.0.norm1.norm_layer.weight" in sd: # CogVideoX VAE - import comfy.ldm.cogvideo.vae self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) self.upscale_index_formula = (4, 8, 8) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1c8a09fa1..c96168b84 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1811,9 +1811,6 @@ class CogVideoX_T2V(supported_models_base.BASE): return out def clip_target(self, state_dict={}): - pref = self.text_encoder_key_prefix[0] - t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) - class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226) From 174f68885c8b78d278c7c860a0e215d8db1cf7de Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 14 Apr 2026 14:58:06 +0200 Subject: [PATCH 09/15] Remove vae_backup.py --- comfy/ldm/cogvideo/vae_backup.py | 485 ------------------------------- 1 file changed, 485 deletions(-) delete mode 100644 comfy/ldm/cogvideo/vae_backup.py diff --git a/comfy/ldm/cogvideo/vae_backup.py b/comfy/ldm/cogvideo/vae_backup.py deleted file mode 100644 index 47254b672..000000000 --- a/comfy/ldm/cogvideo/vae_backup.py +++ /dev/null @@ -1,485 +0,0 @@ -# CogVideoX VAE - ported to ComfyUI native ops -# Architecture reference: diffusers AutoencoderKLCogVideoX -# Style reference: comfy/ldm/wan/vae.py - -import numpy as np - -import torch -import torch.nn as nn -import torch.nn.functional as F - -import comfy.ops -ops = comfy.ops.disable_weight_init - - -class SafeConv3d(nn.Conv3d): - """3D convolution that splits large inputs along temporal dim to avoid OOM.""" - def forward(self, x): - mem = x.shape[0] * x.shape[1] * x.shape[2] * x.shape[3] * x.shape[4] * 2 / 1024**3 - if mem > 2 and x.shape[2] >= self.kernel_size[0]: - kernel_t = self.kernel_size[0] - parts = int(mem / 2) + 1 - # Ensure each chunk has at least kernel_t frames - max_parts = max(1, x.shape[2] // kernel_t) - parts = min(parts, max_parts) - if parts <= 1: - return super().forward(x) - chunks = torch.chunk(x, parts, dim=2) - if kernel_t > 1: - chunks = [chunks[0]] + [ - torch.cat((chunks[i - 1][:, :, -kernel_t + 1:], chunks[i]), dim=2) - for i in range(1, len(chunks)) - ] - out = [] - for chunk in chunks: - out.append(super().forward(chunk)) - return torch.cat(out, dim=2) - return super().forward(x) - - -class CausalConv3d(nn.Module): - """Causal 3D convolution with temporal padding.""" - def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, pad_mode="constant"): - super().__init__() - if isinstance(kernel_size, int): - kernel_size = (kernel_size,) * 3 - - time_kernel, height_kernel, width_kernel = kernel_size - time_pad = time_kernel - 1 - height_pad = (height_kernel - 1) // 2 - width_pad = (width_kernel - 1) // 2 - - self.pad_mode = pad_mode - self.time_pad = time_pad - self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0) - self.const_padding = (0, width_pad, height_pad) - self.time_kernel_size = time_kernel - - stride = stride if isinstance(stride, tuple) else (stride, 1, 1) - dilation = (dilation, 1, 1) - self.conv = SafeConv3d( - in_channels, out_channels, kernel_size, - stride=stride, dilation=dilation, - padding=0 if pad_mode == "replicate" else self.const_padding, - ) - - def forward(self, x, conv_cache=None): - if self.pad_mode == "replicate": - x = F.pad(x, self.time_causal_padding, mode="replicate") - conv_cache = None - else: - kernel_t = self.time_kernel_size - if kernel_t > 1: - cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_t - 1) - x = torch.cat(cached + [x], dim=2) - conv_cache = x[:, :, -self.time_kernel_size + 1:].clone() if self.time_kernel_size > 1 else None - - out = self.conv(x) - return out, conv_cache - - -class SpatialNorm3D(nn.Module): - """Spatially conditioned normalization.""" - def __init__(self, f_channels, zq_channels, groups=32): - super().__init__() - self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) - self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) - self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) - - def forward(self, f, zq, conv_cache=None): - new_cache = {} - conv_cache = conv_cache or {} - - if f.shape[2] > 1 and f.shape[2] % 2 == 1: - f_first, f_rest = f[:, :, :1], f[:, :, 1:] - z_first, z_rest = zq[:, :, :1], zq[:, :, 1:] - z_first = F.interpolate(z_first, size=f_first.shape[-3:]) - z_rest = F.interpolate(z_rest, size=f_rest.shape[-3:]) - zq = torch.cat([z_first, z_rest], dim=2) - else: - zq = F.interpolate(zq, size=f.shape[-3:]) - - conv_y, new_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y")) - conv_b, new_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b")) - - return self.norm_layer(f) * conv_y + conv_b, new_cache - - -class ResnetBlock3D(nn.Module): - """3D ResNet block with optional spatial norm.""" - def __init__(self, in_channels, out_channels=None, temb_channels=512, groups=32, - eps=1e-6, act_fn="silu", spatial_norm_dim=None, pad_mode="first"): - super().__init__() - out_channels = out_channels or in_channels - self.in_channels = in_channels - self.out_channels = out_channels - self.spatial_norm_dim = spatial_norm_dim - - if act_fn == "silu": - self.nonlinearity = nn.SiLU() - elif act_fn == "swish": - self.nonlinearity = nn.SiLU() - else: - self.nonlinearity = nn.SiLU() - - if spatial_norm_dim is None: - self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) - self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) - else: - self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups) - self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups) - - self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) - - if temb_channels > 0: - self.temb_proj = nn.Linear(temb_channels, out_channels) - - self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) - - if in_channels != out_channels: - self.conv_shortcut = SafeConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - else: - self.conv_shortcut = None - - def forward(self, x, temb=None, zq=None, conv_cache=None): - new_cache = {} - conv_cache = conv_cache or {} - residual = x - - if zq is not None: - x, new_cache["norm1"] = self.norm1(x, zq, conv_cache=conv_cache.get("norm1")) - else: - x = self.norm1(x) - - x = self.nonlinearity(x) - x, new_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1")) - - if temb is not None and hasattr(self, "temb_proj"): - x = x + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] - - if zq is not None: - x, new_cache["norm2"] = self.norm2(x, zq, conv_cache=conv_cache.get("norm2")) - else: - x = self.norm2(x) - - x = self.nonlinearity(x) - x, new_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2")) - - if self.conv_shortcut is not None: - residual = self.conv_shortcut(residual) - - return x + residual, new_cache - - -class Downsample3D(nn.Module): - """3D downsampling with optional temporal compression.""" - def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False): - super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) - self.compress_time = compress_time - - def forward(self, x): - if self.compress_time: - b, c, t, h, w = x.shape - x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) - if t % 2 == 1: - x_first, x_rest = x[..., 0], x[..., 1:] - if x_rest.shape[-1] > 0: - x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) - x = torch.cat([x_first[..., None], x_rest], dim=-1) - x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) - else: - x = F.avg_pool1d(x, kernel_size=2, stride=2) - x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) - - pad = (0, 1, 0, 1) - x = F.pad(x, pad, mode="constant", value=0) - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = self.conv(x) - x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) - return x - - -class Upsample3D(nn.Module): - """3D upsampling with optional temporal decompression.""" - def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False): - super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) - self.compress_time = compress_time - - def forward(self, x): - if self.compress_time: - if x.shape[2] > 1 and x.shape[2] % 2 == 1: - x_first, x_rest = x[:, :, 0], x[:, :, 1:] - x_first = F.interpolate(x_first, scale_factor=2.0) - x_rest = F.interpolate(x_rest, scale_factor=2.0) - x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) - elif x.shape[2] > 1: - x = F.interpolate(x, scale_factor=2.0) - else: - x = x.squeeze(2) - x = F.interpolate(x, scale_factor=2.0) - x = x[:, :, None, :, :] - else: - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = F.interpolate(x, scale_factor=2.0) - x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4) - - b, c, t, h, w = x.shape - x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) - x = self.conv(x) - x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4) - return x - - -class DownBlock3D(nn.Module): - def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1, - eps=1e-6, act_fn="silu", groups=32, add_downsample=True, - compress_time=False, pad_mode="first"): - super().__init__() - self.resnets = nn.ModuleList([ - ResnetBlock3D( - in_channels=in_channels if i == 0 else out_channels, - out_channels=out_channels, - temb_channels=temb_channels, - groups=groups, eps=eps, act_fn=act_fn, pad_mode=pad_mode, - ) - for i in range(num_layers) - ]) - self.downsamplers = nn.ModuleList([Downsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_downsample else None - - def forward(self, x, temb=None, zq=None, conv_cache=None): - new_cache = {} - conv_cache = conv_cache or {} - for i, resnet in enumerate(self.resnets): - x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) - if self.downsamplers is not None: - for ds in self.downsamplers: - x = ds(x) - return x, new_cache - - -class MidBlock3D(nn.Module): - def __init__(self, in_channels, temb_channels=0, num_layers=1, - eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=None, pad_mode="first"): - super().__init__() - self.resnets = nn.ModuleList([ - ResnetBlock3D( - in_channels=in_channels, out_channels=in_channels, - temb_channels=temb_channels, groups=groups, eps=eps, - act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, - ) - for _ in range(num_layers) - ]) - - def forward(self, x, temb=None, zq=None, conv_cache=None): - new_cache = {} - conv_cache = conv_cache or {} - for i, resnet in enumerate(self.resnets): - x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) - return x, new_cache - - -class UpBlock3D(nn.Module): - def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1, - eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=16, - add_upsample=True, compress_time=False, pad_mode="first"): - super().__init__() - self.resnets = nn.ModuleList([ - ResnetBlock3D( - in_channels=in_channels if i == 0 else out_channels, - out_channels=out_channels, - temb_channels=temb_channels, groups=groups, eps=eps, - act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, - ) - for i in range(num_layers) - ]) - self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_upsample else None - - def forward(self, x, temb=None, zq=None, conv_cache=None): - new_cache = {} - conv_cache = conv_cache or {} - for i, resnet in enumerate(self.resnets): - x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) - if self.upsamplers is not None: - for us in self.upsamplers: - x = us(x) - return x, new_cache - - -class Encoder3D(nn.Module): - def __init__(self, in_channels=3, out_channels=16, - block_out_channels=(128, 256, 256, 512), - layers_per_block=3, act_fn="silu", - eps=1e-6, groups=32, pad_mode="first", - temporal_compression_ratio=4): - super().__init__() - temporal_compress_level = int(np.log2(temporal_compression_ratio)) - - self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) - - self.down_blocks = nn.ModuleList() - output_channel = block_out_channels[0] - for i in range(len(block_out_channels)): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final = i == len(block_out_channels) - 1 - compress_time = i < temporal_compress_level - - self.down_blocks.append(DownBlock3D( - in_channels=input_channel, out_channels=output_channel, - temb_channels=0, num_layers=layers_per_block, - eps=eps, act_fn=act_fn, groups=groups, - add_downsample=not is_final, compress_time=compress_time, - )) - - self.mid_block = MidBlock3D( - in_channels=block_out_channels[-1], temb_channels=0, - num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode, - ) - - self.norm_out = nn.GroupNorm(groups, block_out_channels[-1], eps=1e-6) - self.conv_act = nn.SiLU() - self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode) - - def forward(self, x, conv_cache=None): - new_cache = {} - conv_cache = conv_cache or {} - - x, new_cache["conv_in"] = self.conv_in(x, conv_cache=conv_cache.get("conv_in")) - - for i, block in enumerate(self.down_blocks): - key = f"down_block_{i}" - x, new_cache[key] = block(x, None, None, conv_cache.get(key)) - - x, new_cache["mid_block"] = self.mid_block(x, None, None, conv_cache=conv_cache.get("mid_block")) - - x = self.norm_out(x) - x = self.conv_act(x) - x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out")) - - return x, new_cache - - -class Decoder3D(nn.Module): - def __init__(self, in_channels=16, out_channels=3, - block_out_channels=(128, 256, 256, 512), - layers_per_block=3, act_fn="silu", - eps=1e-6, groups=32, pad_mode="first", - temporal_compression_ratio=4): - super().__init__() - reversed_channels = list(reversed(block_out_channels)) - temporal_compress_level = int(np.log2(temporal_compression_ratio)) - - self.conv_in = CausalConv3d(in_channels, reversed_channels[0], kernel_size=3, pad_mode=pad_mode) - - self.mid_block = MidBlock3D( - in_channels=reversed_channels[0], temb_channels=0, - num_layers=2, eps=eps, act_fn=act_fn, groups=groups, - spatial_norm_dim=in_channels, pad_mode=pad_mode, - ) - - self.up_blocks = nn.ModuleList() - output_channel = reversed_channels[0] - for i in range(len(block_out_channels)): - prev_channel = output_channel - output_channel = reversed_channels[i] - is_final = i == len(block_out_channels) - 1 - compress_time = i < temporal_compress_level - - self.up_blocks.append(UpBlock3D( - in_channels=prev_channel, out_channels=output_channel, - temb_channels=0, num_layers=layers_per_block + 1, - eps=eps, act_fn=act_fn, groups=groups, - spatial_norm_dim=in_channels, - add_upsample=not is_final, compress_time=compress_time, - )) - - self.norm_out = SpatialNorm3D(reversed_channels[-1], in_channels, groups=groups) - self.conv_act = nn.SiLU() - self.conv_out = CausalConv3d(reversed_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode) - - def forward(self, sample, conv_cache=None): - new_cache = {} - conv_cache = conv_cache or {} - - x, new_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) - - x, new_cache["mid_block"] = self.mid_block(x, None, sample, conv_cache=conv_cache.get("mid_block")) - - for i, block in enumerate(self.up_blocks): - key = f"up_block_{i}" - x, new_cache[key] = block(x, None, sample, conv_cache=conv_cache.get(key)) - - x, new_cache["norm_out"] = self.norm_out(x, sample, conv_cache=conv_cache.get("norm_out")) - x = self.conv_act(x) - x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out")) - - return x, new_cache - - - -class AutoencoderKLCogVideoX(nn.Module): - """CogVideoX VAE. Spatial tiling/slicing handled by ComfyUI's VAE wrapper. - - Temporal frame batching with conv_cache is kept here since the causal 3D - convolutions need state passed between temporal chunks. - """ - - def __init__(self, - in_channels=3, out_channels=3, - block_out_channels=(128, 256, 256, 512), - latent_channels=16, layers_per_block=3, - act_fn="silu", eps=1e-6, groups=32, - temporal_compression_ratio=4, - ): - super().__init__() - self.latent_channels = latent_channels - - self.encoder = Encoder3D( - in_channels=in_channels, out_channels=latent_channels, - block_out_channels=block_out_channels, layers_per_block=layers_per_block, - act_fn=act_fn, eps=eps, groups=groups, - temporal_compression_ratio=temporal_compression_ratio, - ) - self.decoder = Decoder3D( - in_channels=latent_channels, out_channels=out_channels, - block_out_channels=block_out_channels, layers_per_block=layers_per_block, - act_fn=act_fn, eps=eps, groups=groups, - temporal_compression_ratio=temporal_compression_ratio, - ) - - self.num_latent_frames_batch_size = 2 - self.num_sample_frames_batch_size = 8 - - def encode(self, x): - t = x.shape[2] - frame_batch = self.num_sample_frames_batch_size - num_batches = max(t // frame_batch, 1) - conv_cache = None - enc = [] - for i in range(num_batches): - remaining = t % frame_batch - start = frame_batch * i + (0 if i == 0 else remaining) - end = frame_batch * (i + 1) + remaining - chunk, conv_cache = self.encoder(x[:, :, start:end], conv_cache=conv_cache) - enc.append(chunk.to(x.device)) - enc = torch.cat(enc, dim=2) - mean, _ = enc.chunk(2, dim=1) - return mean - - def decode(self, z): - t = z.shape[2] - frame_batch = self.num_latent_frames_batch_size - num_batches = max(t // frame_batch, 1) - conv_cache = None - dec = [] - for i in range(num_batches): - remaining = t % frame_batch - start = frame_batch * i + (0 if i == 0 else remaining) - end = frame_batch * (i + 1) + remaining - chunk, conv_cache = self.decoder(z[:, :, start:end], conv_cache=conv_cache) - dec.append(chunk.cpu()) - return torch.cat(dec, dim=2).to(z.device) From 3e961f9960b9a919de5769123872da42c0d8b25f Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 14 Apr 2026 14:58:37 +0200 Subject: [PATCH 10/15] Move cogvideo text encoder into a dedicated module. --- comfy/supported_models.py | 7 ++----- comfy/text_encoders/cogvideo.py | 6 ++++++ 2 files changed, 8 insertions(+), 5 deletions(-) create mode 100644 comfy/text_encoders/cogvideo.py diff --git a/comfy/supported_models.py b/comfy/supported_models.py index c96168b84..33b268ac1 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -27,6 +27,7 @@ import comfy.text_encoders.anima import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image import comfy.text_encoders.ernie +import comfy.text_encoders.cogvideo from . import supported_models_base from . import latent_formats @@ -1811,11 +1812,7 @@ class CogVideoX_T2V(supported_models_base.BASE): return out def clip_target(self, state_dict={}): - class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): - super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226) - - return supported_models_base.ClipTarget(CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel) + return supported_models_base.ClipTarget(comfy.text_encoders.cogvideo.CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel) class CogVideoX_I2V(CogVideoX_T2V): unet_config = { diff --git a/comfy/text_encoders/cogvideo.py b/comfy/text_encoders/cogvideo.py new file mode 100644 index 000000000..f1e8e3f5d --- /dev/null +++ b/comfy/text_encoders/cogvideo.py @@ -0,0 +1,6 @@ +import comfy.text_encoders.sd3_clip + + +class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226) From e962c3f846feb3717d18760ca4f39baa60d30722 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 14 Apr 2026 15:05:27 +0200 Subject: [PATCH 11/15] Cap encode chunks at the configured frame batch size. --- comfy/ldm/cogvideo/vae.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/cogvideo/vae.py b/comfy/ldm/cogvideo/vae.py index fe97f362c..dd63522b1 100644 --- a/comfy/ldm/cogvideo/vae.py +++ b/comfy/ldm/cogvideo/vae.py @@ -453,13 +453,13 @@ class AutoencoderKLCogVideoX(nn.Module): def encode(self, x): t = x.shape[2] frame_batch = self.num_sample_frames_batch_size - num_batches = max(t // frame_batch, 1) + # ceil so remainder frames get their own chunk instead of inflating the first one + num_batches = max(-(-t // frame_batch), 1) conv_cache = None enc = [] for i in range(num_batches): - remaining = t % frame_batch - start = frame_batch * i + (0 if i == 0 else remaining) - end = frame_batch * (i + 1) + remaining + start = i * frame_batch + end = min((i + 1) * frame_batch, t) chunk, conv_cache = self.encoder(x[:, :, start:end], conv_cache=conv_cache) enc.append(chunk.to(x.device)) enc = torch.cat(enc, dim=2) From 9ca7cdb17ea690ebdd96c559c1559734f3c27584 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 14 Apr 2026 17:00:28 +0200 Subject: [PATCH 12/15] Cap encode chunks fix. --- comfy/ldm/cogvideo/vae.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/comfy/ldm/cogvideo/vae.py b/comfy/ldm/cogvideo/vae.py index dd63522b1..d9672f1da 100644 --- a/comfy/ldm/cogvideo/vae.py +++ b/comfy/ldm/cogvideo/vae.py @@ -453,15 +453,21 @@ class AutoencoderKLCogVideoX(nn.Module): def encode(self, x): t = x.shape[2] frame_batch = self.num_sample_frames_batch_size - # ceil so remainder frames get their own chunk instead of inflating the first one - num_batches = max(-(-t // frame_batch), 1) + remainder = t % frame_batch conv_cache = None enc = [] - for i in range(num_batches): - start = i * frame_batch - end = min((i + 1) * frame_batch, t) - chunk, conv_cache = self.encoder(x[:, :, start:end], conv_cache=conv_cache) + + # Process remainder frames first so only the first chunk can have an + # odd temporal dimension — where Downsample3D's first-frame-special + # handling in temporal compression is actually correct. + if remainder > 0: + chunk, conv_cache = self.encoder(x[:, :, :remainder], conv_cache=conv_cache) enc.append(chunk.to(x.device)) + + for start in range(remainder, t, frame_batch): + chunk, conv_cache = self.encoder(x[:, :, start:start + frame_batch], conv_cache=conv_cache) + enc.append(chunk.to(x.device)) + enc = torch.cat(enc, dim=2) mean, _ = enc.chunk(2, dim=1) return mean From c8a843e240feca1af5436c082fc1aa53bc680e4a Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 14 Apr 2026 15:09:44 +0200 Subject: [PATCH 13/15] Avoid pre-interpolating z for the full clip at every high-res stage. --- comfy/ldm/cogvideo/vae.py | 52 ++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/comfy/ldm/cogvideo/vae.py b/comfy/ldm/cogvideo/vae.py index d9672f1da..4f1f92d9f 100644 --- a/comfy/ldm/cogvideo/vae.py +++ b/comfy/ldm/cogvideo/vae.py @@ -522,55 +522,45 @@ class AutoencoderKLCogVideoX(nn.Module): x, _ = decoder.conv_out(x) return x - # Pre-interpolate z to each spatial resolution used by Phase 2 blocks. - # Uses the exact same interpolation logic as SpatialNorm3D so chunked - # output is identical to non-chunked. - # Determine spatial sizes: run a dummy pass to find feature map sizes, - # or compute from block structure. Simpler: compute from x's current size - # and the known upsample factor (2x per block with upsample). - z_at_res = {} # keyed by (h, w) → pre-interpolated z [B, C, t_expanded, h, w] - h, w = x.shape[3], x.shape[4] - for i in remaining_blocks: - block = decoder.up_blocks[i] - # Resnets operate at current h, w - target = (t_expanded, h, w) - if target not in z_at_res: - z_at_res[target] = _interpolate_zq(z, target) - # If block has upsample, next block's input is 2x spatial - if block.upsamplers is not None: - h, w = h * 2, w * 2 - # norm_out operates at final resolution - target = (t_expanded, h, w) - if target not in z_at_res: - z_at_res[target] = _interpolate_zq(z, target) + # Expand z temporally once to match Phase 2's time dimension. + # z stays at latent spatial resolution so this is small (~16 MB vs ~1.3 GB + # for the old approach of pre-interpolating to every pixel resolution). + z_time_expanded = _interpolate_zq(z, (t_expanded, z.shape[3], z.shape[4])) - # Process in temporal chunks + # Process in temporal chunks, interpolating spatially per-chunk to avoid + # allocating full [B, C, t_expanded, H, W] tensors at each resolution. dec_out = [] conv_caches = {} for chunk_start in range(0, t_expanded, chunk_size): chunk_end = min(chunk_start + chunk_size, t_expanded) x_chunk = x[:, :, chunk_start:chunk_end] + z_t_chunk = z_time_expanded[:, :, chunk_start:chunk_end] + z_spatial_cache = {} for i in remaining_blocks: block = decoder.up_blocks[i] cache_key = f"up_block_{i}" - # Get pre-interpolated z at the block's input spatial resolution - res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4]) - z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end] - x_chunk, new_cache = block(x_chunk, None, z_chunk, conv_cache=conv_caches.get(cache_key)) + hw_key = (x_chunk.shape[3], x_chunk.shape[4]) + if hw_key not in z_spatial_cache: + if z_t_chunk.shape[3] == hw_key[0] and z_t_chunk.shape[4] == hw_key[1]: + z_spatial_cache[hw_key] = z_t_chunk + else: + z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1])) + x_chunk, new_cache = block(x_chunk, None, z_spatial_cache[hw_key], conv_cache=conv_caches.get(cache_key)) conv_caches[cache_key] = new_cache - # norm_out at final resolution - res_key = (t_expanded, x_chunk.shape[3], x_chunk.shape[4]) - z_chunk = z_at_res[res_key][:, :, chunk_start:chunk_end] - x_chunk, new_cache = decoder.norm_out(x_chunk, z_chunk, conv_cache=conv_caches.get("norm_out")) + hw_key = (x_chunk.shape[3], x_chunk.shape[4]) + if hw_key not in z_spatial_cache: + z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1])) + x_chunk, new_cache = decoder.norm_out(x_chunk, z_spatial_cache[hw_key], conv_cache=conv_caches.get("norm_out")) conv_caches["norm_out"] = new_cache x_chunk = decoder.conv_act(x_chunk) x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out")) conv_caches["conv_out"] = new_cache dec_out.append(x_chunk.cpu()) + del z_spatial_cache - del x + del x, z_time_expanded return torch.cat(dec_out, dim=2).to(device) From dff15d7e5f180eadb486b3a97f68996861e6b328 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 14 Apr 2026 15:58:34 +0200 Subject: [PATCH 14/15] Fix cogvideox dtypes and ops. --- comfy/ldm/cogvideo/model.py | 2 +- comfy/ldm/cogvideo/vae.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/comfy/ldm/cogvideo/model.py b/comfy/ldm/cogvideo/model.py index c79883fb3..797eb9449 100644 --- a/comfy/ldm/cogvideo/model.py +++ b/comfy/ldm/cogvideo/model.py @@ -378,7 +378,7 @@ class CogVideoXTransformer3DModel(nn.Module): temporal_interpolation_scale=temporal_interpolation_scale, use_positional_embeddings=not use_rotary_positional_embeddings, use_learned_positional_embeddings=use_learned_positional_embeddings, - device=device, dtype=torch.float32, operations=operations, + device=device, dtype=dtype, operations=operations, ) # 2. Time embedding diff --git a/comfy/ldm/cogvideo/vae.py b/comfy/ldm/cogvideo/vae.py index 4f1f92d9f..d4e6f321e 100644 --- a/comfy/ldm/cogvideo/vae.py +++ b/comfy/ldm/cogvideo/vae.py @@ -80,7 +80,7 @@ class SpatialNorm3D(nn.Module): """Spatially conditioned normalization.""" def __init__(self, f_channels, zq_channels, groups=32): super().__init__() - self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) + self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) @@ -115,8 +115,8 @@ class ResnetBlock3D(nn.Module): self.nonlinearity = nn.SiLU() if spatial_norm_dim is None: - self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) - self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + self.norm1 = ops.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = ops.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) else: self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups) self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups) @@ -124,7 +124,7 @@ class ResnetBlock3D(nn.Module): self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) if temb_channels > 0: - self.temb_proj = nn.Linear(temb_channels, out_channels) + self.temb_proj = ops.Linear(temb_channels, out_channels) self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) @@ -167,7 +167,7 @@ class Downsample3D(nn.Module): """3D downsampling with optional temporal compression.""" def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False): super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) self.compress_time = compress_time def forward(self, x): @@ -197,7 +197,7 @@ class Upsample3D(nn.Module): """3D upsampling with optional temporal decompression.""" def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False): super().__init__() - self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) self.compress_time = compress_time def forward(self, x): @@ -332,7 +332,7 @@ class Encoder3D(nn.Module): num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode, ) - self.norm_out = nn.GroupNorm(groups, block_out_channels[-1], eps=1e-6) + self.norm_out = ops.GroupNorm(groups, block_out_channels[-1], eps=1e-6) self.conv_act = nn.SiLU() self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode) From 52156edbeea7024badee84ca3bca6363e5bb5e96 Mon Sep 17 00:00:00 2001 From: Talmaj Marinc Date: Tue, 14 Apr 2026 16:51:00 +0200 Subject: [PATCH 15/15] Revert dtype to float32 to increase quality of video output. --- comfy/ldm/cogvideo/model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/comfy/ldm/cogvideo/model.py b/comfy/ldm/cogvideo/model.py index 797eb9449..fb475ed53 100644 --- a/comfy/ldm/cogvideo/model.py +++ b/comfy/ldm/cogvideo/model.py @@ -157,12 +157,14 @@ class CogVideoXPatchEmbed(nn.Module): return joint_pos_embedding def forward(self, text_embeds, image_embeds): - text_embeds = self.text_proj(text_embeds) + input_dtype = text_embeds.dtype + text_embeds = self.text_proj(text_embeds.to(self.text_proj.weight.dtype)).to(input_dtype) batch_size, num_frames, channels, height, width = image_embeds.shape + proj_dtype = self.proj.weight.dtype if self.patch_size_t is None: image_embeds = image_embeds.reshape(-1, channels, height, width) - image_embeds = self.proj(image_embeds) + image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype) image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:]) image_embeds = image_embeds.flatten(3).transpose(2, 3) image_embeds = image_embeds.flatten(1, 2) @@ -174,7 +176,7 @@ class CogVideoXPatchEmbed(nn.Module): batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels ) image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3) - image_embeds = self.proj(image_embeds) + image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype) embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() @@ -378,7 +380,7 @@ class CogVideoXTransformer3DModel(nn.Module): temporal_interpolation_scale=temporal_interpolation_scale, use_positional_embeddings=not use_rotary_positional_embeddings, use_learned_positional_embeddings=use_learned_positional_embeddings, - device=device, dtype=dtype, operations=operations, + device=device, dtype=torch.float32, operations=operations, ) # 2. Time embedding