diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 6a57bca1c..0f4059ebe 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -783,3 +783,10 @@ class ZImagePixelSpace(ChromaRadiance): No VAE encoding/decoding — the model operates directly on RGB pixels. """ pass + +class CogVideoX(LatentFormat): + latent_channels = 16 + latent_dimensions = 3 + + def __init__(self): + self.scale_factor = 1.15258426 diff --git a/comfy/ldm/cogvideo/__init__.py b/comfy/ldm/cogvideo/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/comfy/ldm/cogvideo/model.py b/comfy/ldm/cogvideo/model.py new file mode 100644 index 000000000..fb475ed53 --- /dev/null +++ b/comfy/ldm/cogvideo/model.py @@ -0,0 +1,573 @@ +# CogVideoX 3D Transformer - ported to ComfyUI native ops +# Architecture reference: diffusers CogVideoXTransformer3DModel +# Style reference: comfy/ldm/wan/model.py + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from comfy.ldm.modules.attention import optimized_attention +import comfy.patcher_extension +import comfy.ldm.common_dit + + +def _get_1d_rotary_pos_embed(dim, pos, theta=10000.0): + """Returns (cos, sin) each with shape [seq_len, dim]. + + Frequencies are computed at dim//2 resolution then repeat_interleaved + to full dim, matching CogVideoX's interleaved (real, imag) pair format. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device) / dim)) + angles = torch.outer(pos.float(), freqs.float()) + cos = angles.cos().repeat_interleave(2, dim=-1).float() + sin = angles.sin().repeat_interleave(2, dim=-1).float() + return (cos, sin) + + +def apply_rotary_emb(x, freqs_cos_sin): + """Apply CogVideoX rotary embedding to query or key tensor. + + x: [B, heads, seq_len, head_dim] + freqs_cos_sin: (cos, sin) each [seq_len, head_dim//2] + + Uses interleaved pair rotation (same as diffusers CogVideoX/Flux). + head_dim is reshaped to (-1, 2) pairs, rotated, then flattened back. + """ + cos, sin = freqs_cos_sin + cos = cos[None, None, :, :].to(x.device) + sin = sin[None, None, :, :].to(x.device) + + # Interleaved pairs: [B, H, S, D] -> [B, H, S, D//2, 2] -> (real, imag) + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + return (x.float() * cos + x_rotated.float() * sin).to(x.dtype) + + +def get_timestep_embedding(timesteps, dim, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1, max_period=10000): + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device) / half) + args = timesteps[:, None].float() * freqs[None] * scale + embedding = torch.cat([torch.sin(args), torch.cos(args)], dim=-1) + if flip_sin_to_cos: + embedding = torch.cat([embedding[:, half:], embedding[:, :half]], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +def get_3d_sincos_pos_embed(embed_dim, spatial_size, temporal_size, spatial_interpolation_scale=1.0, temporal_interpolation_scale=1.0, device=None): + if isinstance(spatial_size, int): + spatial_size = (spatial_size, spatial_size) + + grid_w = torch.arange(spatial_size[0], dtype=torch.float32, device=device) / spatial_interpolation_scale + grid_h = torch.arange(spatial_size[1], dtype=torch.float32, device=device) / spatial_interpolation_scale + grid_t = torch.arange(temporal_size, dtype=torch.float32, device=device) / temporal_interpolation_scale + + grid_t, grid_h, grid_w = torch.meshgrid(grid_t, grid_h, grid_w, indexing="ij") + + embed_dim_spatial = 2 * (embed_dim // 3) + embed_dim_temporal = embed_dim // 3 + + pos_embed_spatial = _get_2d_sincos_pos_embed(embed_dim_spatial, grid_h, grid_w, device=device) + pos_embed_temporal = _get_1d_sincos_pos_embed(embed_dim_temporal, grid_t[:, 0, 0], device=device) + + T, H, W = grid_t.shape + pos_embed_temporal = pos_embed_temporal.unsqueeze(1).unsqueeze(1).expand(-1, H, W, -1) + pos_embed = torch.cat([pos_embed_temporal, pos_embed_spatial], dim=-1) + + return pos_embed + + +def _get_2d_sincos_pos_embed(embed_dim, grid_h, grid_w, device=None): + T, H, W = grid_h.shape + half_dim = embed_dim // 2 + pos_h = _get_1d_sincos_pos_embed(half_dim, grid_h.reshape(-1), device=device).reshape(T, H, W, half_dim) + pos_w = _get_1d_sincos_pos_embed(half_dim, grid_w.reshape(-1), device=device).reshape(T, H, W, half_dim) + return torch.cat([pos_h, pos_w], dim=-1) + + +def _get_1d_sincos_pos_embed(embed_dim, pos, device=None): + half = embed_dim // 2 + freqs = torch.exp(-math.log(10000.0) * torch.arange(start=0, end=half, dtype=torch.float32, device=device) / half) + args = pos.float().reshape(-1)[:, None] * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if embed_dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + + +class CogVideoXPatchEmbed(nn.Module): + def __init__(self, patch_size=2, patch_size_t=None, in_channels=16, dim=1920, + text_dim=4096, bias=True, sample_width=90, sample_height=60, + sample_frames=49, temporal_compression_ratio=4, + max_text_seq_length=226, spatial_interpolation_scale=1.875, + temporal_interpolation_scale=1.0, use_positional_embeddings=True, + use_learned_positional_embeddings=True, + device=None, dtype=None, operations=None): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.dim = dim + self.sample_height = sample_height + self.sample_width = sample_width + self.sample_frames = sample_frames + self.temporal_compression_ratio = temporal_compression_ratio + self.max_text_seq_length = max_text_seq_length + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.use_positional_embeddings = use_positional_embeddings + self.use_learned_positional_embeddings = use_learned_positional_embeddings + + if patch_size_t is None: + self.proj = operations.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=bias, device=device, dtype=dtype) + else: + self.proj = operations.Linear(in_channels * patch_size * patch_size * patch_size_t, dim, device=device, dtype=dtype) + + self.text_proj = operations.Linear(text_dim, dim, device=device, dtype=dtype) + + if use_positional_embeddings or use_learned_positional_embeddings: + persistent = use_learned_positional_embeddings + pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames) + self.register_buffer("pos_embedding", pos_embedding, persistent=persistent) + + def _get_positional_embeddings(self, sample_height, sample_width, sample_frames, device=None): + post_patch_height = sample_height // self.patch_size + post_patch_width = sample_width // self.patch_size + post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1 + if self.patch_size_t is not None: + post_time_compression_frames = post_time_compression_frames // self.patch_size_t + num_patches = post_patch_height * post_patch_width * post_time_compression_frames + + pos_embedding = get_3d_sincos_pos_embed( + self.dim, + (post_patch_width, post_patch_height), + post_time_compression_frames, + self.spatial_interpolation_scale, + self.temporal_interpolation_scale, + device=device, + ) + pos_embedding = pos_embedding.reshape(-1, self.dim) + joint_pos_embedding = pos_embedding.new_zeros( + 1, self.max_text_seq_length + num_patches, self.dim, requires_grad=False + ) + joint_pos_embedding.data[:, self.max_text_seq_length:].copy_(pos_embedding) + return joint_pos_embedding + + def forward(self, text_embeds, image_embeds): + input_dtype = text_embeds.dtype + text_embeds = self.text_proj(text_embeds.to(self.text_proj.weight.dtype)).to(input_dtype) + batch_size, num_frames, channels, height, width = image_embeds.shape + + proj_dtype = self.proj.weight.dtype + if self.patch_size_t is None: + image_embeds = image_embeds.reshape(-1, channels, height, width) + image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype) + image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:]) + image_embeds = image_embeds.flatten(3).transpose(2, 3) + image_embeds = image_embeds.flatten(1, 2) + else: + p = self.patch_size + p_t = self.patch_size_t + image_embeds = image_embeds.permute(0, 1, 3, 4, 2) + image_embeds = image_embeds.reshape( + batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels + ) + image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3) + image_embeds = self.proj(image_embeds.to(proj_dtype)).to(input_dtype) + + embeds = torch.cat([text_embeds, image_embeds], dim=1).contiguous() + + if self.use_positional_embeddings or self.use_learned_positional_embeddings: + text_seq_length = text_embeds.shape[1] + num_image_patches = image_embeds.shape[1] + + if self.use_learned_positional_embeddings: + image_pos = self.pos_embedding[ + :, self.max_text_seq_length:self.max_text_seq_length + num_image_patches + ].to(device=embeds.device, dtype=embeds.dtype) + else: + image_pos = get_3d_sincos_pos_embed( + self.dim, + (width // self.patch_size, height // self.patch_size), + num_image_patches // ((height // self.patch_size) * (width // self.patch_size)), + self.spatial_interpolation_scale, + self.temporal_interpolation_scale, + device=embeds.device, + ).reshape(1, num_image_patches, self.dim).to(dtype=embeds.dtype) + + # Build joint: zeros for text + sincos for image + joint_pos = torch.zeros(1, text_seq_length + num_image_patches, self.dim, device=embeds.device, dtype=embeds.dtype) + joint_pos[:, text_seq_length:] = image_pos + embeds = embeds + joint_pos + + return embeds + + +class CogVideoXLayerNormZero(nn.Module): + def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, bias=True, + device=None, dtype=None, operations=None): + super().__init__() + self.silu = nn.SiLU() + self.linear = operations.Linear(time_dim, 6 * dim, bias=bias, device=device, dtype=dtype) + self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) + + def forward(self, hidden_states, encoder_hidden_states, temb): + shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) + hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] + encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] + return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] + + +class CogVideoXAdaLayerNorm(nn.Module): + def __init__(self, time_dim, dim, elementwise_affine=True, eps=1e-5, + device=None, dtype=None, operations=None): + super().__init__() + self.silu = nn.SiLU() + self.linear = operations.Linear(time_dim, 2 * dim, device=device, dtype=dtype) + self.norm = operations.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) + + def forward(self, x, temb): + temb = self.linear(self.silu(temb)) + shift, scale = temb.chunk(2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +class CogVideoXBlock(nn.Module): + def __init__(self, dim, num_heads, head_dim, time_dim, + eps=1e-5, ff_inner_dim=None, ff_bias=True, + device=None, dtype=None, operations=None): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.norm1 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations) + + # Self-attention (joint text + latent) + self.q = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + self.k = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + self.v = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + self.norm_q = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype) + self.norm_k = operations.LayerNorm(head_dim, eps=1e-6, elementwise_affine=True, device=device, dtype=dtype) + self.attn_out = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype) + + self.norm2 = CogVideoXLayerNormZero(time_dim, dim, eps=eps, device=device, dtype=dtype, operations=operations) + + # Feed-forward (GELU approximate) + inner_dim = ff_inner_dim or dim * 4 + self.ff_proj = operations.Linear(dim, inner_dim, bias=ff_bias, device=device, dtype=dtype) + self.ff_out = operations.Linear(inner_dim, dim, bias=ff_bias, device=device, dtype=dtype) + + def forward(self, hidden_states, encoder_hidden_states, temb, image_rotary_emb=None, transformer_options=None): + if transformer_options is None: + transformer_options = {} + text_seq_length = encoder_hidden_states.size(1) + + # Norm & modulate + norm_hidden, norm_encoder, gate_msa, enc_gate_msa = self.norm1(hidden_states, encoder_hidden_states, temb) + + # Joint self-attention + qkv_input = torch.cat([norm_encoder, norm_hidden], dim=1) + b, s, _ = qkv_input.shape + n, d = self.num_heads, self.head_dim + + q = self.q(qkv_input).view(b, s, n, d) + k = self.k(qkv_input).view(b, s, n, d) + v = self.v(qkv_input) + + q = self.norm_q(q).view(b, s, n, d) + k = self.norm_k(k).view(b, s, n, d) + + # Apply rotary embeddings to image tokens only (diffusers format: [B, heads, seq, head_dim]) + if image_rotary_emb is not None: + q_img = q[:, text_seq_length:].transpose(1, 2) # [B, heads, img_seq, head_dim] + k_img = k[:, text_seq_length:].transpose(1, 2) + q_img = apply_rotary_emb(q_img, image_rotary_emb) + k_img = apply_rotary_emb(k_img, image_rotary_emb) + q = torch.cat([q[:, :text_seq_length], q_img.transpose(1, 2)], dim=1) + k = torch.cat([k[:, :text_seq_length], k_img.transpose(1, 2)], dim=1) + + attn_out = optimized_attention( + q.reshape(b, s, n * d), + k.reshape(b, s, n * d), + v, + heads=self.num_heads, + transformer_options=transformer_options, + ) + + attn_out = self.attn_out(attn_out) + + attn_encoder, attn_hidden = attn_out.split([text_seq_length, s - text_seq_length], dim=1) + + hidden_states = hidden_states + gate_msa * attn_hidden + encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder + + # Norm & modulate for FF + norm_hidden, norm_encoder, gate_ff, enc_gate_ff = self.norm2(hidden_states, encoder_hidden_states, temb) + + # Feed-forward (GELU on concatenated text + latent) + ff_input = torch.cat([norm_encoder, norm_hidden], dim=1) + ff_output = self.ff_out(F.gelu(self.ff_proj(ff_input), approximate="tanh")) + + hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:] + encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length] + + return hidden_states, encoder_hidden_states + + +class CogVideoXTransformer3DModel(nn.Module): + def __init__(self, + num_attention_heads=30, + attention_head_dim=64, + in_channels=16, + out_channels=16, + flip_sin_to_cos=True, + freq_shift=0, + time_embed_dim=512, + ofs_embed_dim=None, + text_embed_dim=4096, + num_layers=30, + dropout=0.0, + attention_bias=True, + sample_width=90, + sample_height=60, + sample_frames=49, + patch_size=2, + patch_size_t=None, + temporal_compression_ratio=4, + max_text_seq_length=226, + spatial_interpolation_scale=1.875, + temporal_interpolation_scale=1.0, + use_rotary_positional_embeddings=False, + use_learned_positional_embeddings=False, + patch_bias=True, + image_model=None, + device=None, + dtype=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + dim = num_attention_heads * attention_head_dim + self.dim = dim + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + self.in_channels = in_channels + self.out_channels = out_channels + self.patch_size = patch_size + self.patch_size_t = patch_size_t + self.max_text_seq_length = max_text_seq_length + self.use_rotary_positional_embeddings = use_rotary_positional_embeddings + + # 1. Patch embedding + self.patch_embed = CogVideoXPatchEmbed( + patch_size=patch_size, + patch_size_t=patch_size_t, + in_channels=in_channels, + dim=dim, + text_dim=text_embed_dim, + bias=patch_bias, + sample_width=sample_width, + sample_height=sample_height, + sample_frames=sample_frames, + temporal_compression_ratio=temporal_compression_ratio, + max_text_seq_length=max_text_seq_length, + spatial_interpolation_scale=spatial_interpolation_scale, + temporal_interpolation_scale=temporal_interpolation_scale, + use_positional_embeddings=not use_rotary_positional_embeddings, + use_learned_positional_embeddings=use_learned_positional_embeddings, + device=device, dtype=torch.float32, operations=operations, + ) + + # 2. Time embedding + self.time_proj_dim = dim + self.time_proj_flip = flip_sin_to_cos + self.time_proj_shift = freq_shift + self.time_embedding_linear_1 = operations.Linear(dim, time_embed_dim, device=device, dtype=dtype) + self.time_embedding_act = nn.SiLU() + self.time_embedding_linear_2 = operations.Linear(time_embed_dim, time_embed_dim, device=device, dtype=dtype) + + # Optional OFS embedding (CogVideoX 1.5 I2V) + self.ofs_proj_dim = ofs_embed_dim + if ofs_embed_dim: + self.ofs_embedding_linear_1 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype) + self.ofs_embedding_act = nn.SiLU() + self.ofs_embedding_linear_2 = operations.Linear(ofs_embed_dim, ofs_embed_dim, device=device, dtype=dtype) + else: + self.ofs_embedding_linear_1 = None + + # 3. Transformer blocks + self.blocks = nn.ModuleList([ + CogVideoXBlock( + dim=dim, + num_heads=num_attention_heads, + head_dim=attention_head_dim, + time_dim=time_embed_dim, + eps=1e-5, + device=device, dtype=dtype, operations=operations, + ) + for _ in range(num_layers) + ]) + + self.norm_final = operations.LayerNorm(dim, eps=1e-5, elementwise_affine=True, device=device, dtype=dtype) + + # 4. Output + self.norm_out = CogVideoXAdaLayerNorm( + time_dim=time_embed_dim, dim=dim, eps=1e-5, + device=device, dtype=dtype, operations=operations, + ) + + if patch_size_t is None: + output_dim = patch_size * patch_size * out_channels + else: + output_dim = patch_size * patch_size * patch_size_t * out_channels + + self.proj_out = operations.Linear(dim, output_dim, device=device, dtype=dtype) + + self.spatial_interpolation_scale = spatial_interpolation_scale + self.temporal_interpolation_scale = temporal_interpolation_scale + self.temporal_compression_ratio = temporal_compression_ratio + + def forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs): + if transformer_options is None: + transformer_options = {} + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, ofs, transformer_options, **kwargs) + + def _forward(self, x, timestep, context, ofs=None, transformer_options=None, **kwargs): + if transformer_options is None: + transformer_options = {} + # ComfyUI passes [B, C, T, H, W] + batch_size, channels, t, h, w = x.shape + + # Pad to patch size (temporal + spatial), same pattern as WAN + p_t = self.patch_size_t if self.patch_size_t is not None else 1 + x = comfy.ldm.common_dit.pad_to_patch_size(x, (p_t, self.patch_size, self.patch_size)) + + # CogVideoX expects [B, T, C, H, W] + x = x.permute(0, 2, 1, 3, 4) + batch_size, num_frames, channels, height, width = x.shape + + # Time embedding + t_emb = get_timestep_embedding(timestep, self.time_proj_dim, self.time_proj_flip, self.time_proj_shift) + t_emb = t_emb.to(dtype=x.dtype) + emb = self.time_embedding_linear_2(self.time_embedding_act(self.time_embedding_linear_1(t_emb))) + + if self.ofs_embedding_linear_1 is not None and ofs is not None: + ofs_emb = get_timestep_embedding(ofs, self.ofs_proj_dim, self.time_proj_flip, self.time_proj_shift) + ofs_emb = ofs_emb.to(dtype=x.dtype) + ofs_emb = self.ofs_embedding_linear_2(self.ofs_embedding_act(self.ofs_embedding_linear_1(ofs_emb))) + emb = emb + ofs_emb + + # Patch embedding + hidden_states = self.patch_embed(context, x) + + text_seq_length = context.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] + hidden_states = hidden_states[:, text_seq_length:] + + # Rotary embeddings (if used) + image_rotary_emb = None + if self.use_rotary_positional_embeddings: + post_patch_height = height // self.patch_size + post_patch_width = width // self.patch_size + if self.patch_size_t is None: + post_time = num_frames + else: + post_time = num_frames // self.patch_size_t + image_rotary_emb = self._get_rotary_emb(post_patch_height, post_patch_width, post_time, device=x.device) + + # Transformer blocks + for i, block in enumerate(self.blocks): + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, + ) + + hidden_states = self.norm_final(hidden_states) + + # Output projection + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # Unpatchify + p = self.patch_size + p_t = self.patch_size_t + + if p_t is None: + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + else: + output = hidden_states.reshape( + batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p + ) + output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2) + + # Back to ComfyUI format [B, C, T, H, W] and crop padding + output = output.permute(0, 2, 1, 3, 4)[:, :, :t, :h, :w] + return output + + def _get_rotary_emb(self, h, w, t, device): + """Compute CogVideoX 3D rotary positional embeddings. + + For CogVideoX 1.5 (patch_size_t != None): uses "slice" mode — grid positions + are integer arange computed at max_size, then sliced to actual size. + For CogVideoX 1.0 (patch_size_t == None): uses "linspace" mode with crop coords + scaled by spatial_interpolation_scale. + """ + d = self.attention_head_dim + dim_t = d // 4 + dim_h = d // 8 * 3 + dim_w = d // 8 * 3 + + if self.patch_size_t is not None: + # CogVideoX 1.5: "slice" mode — positions are simple integer indices + # Compute at max(sample_size, actual_size) then slice to actual + base_h = self.patch_embed.sample_height // self.patch_size + base_w = self.patch_embed.sample_width // self.patch_size + max_h = max(base_h, h) + max_w = max(base_w, w) + + grid_h = torch.arange(max_h, device=device, dtype=torch.float32) + grid_w = torch.arange(max_w, device=device, dtype=torch.float32) + grid_t = torch.arange(t, device=device, dtype=torch.float32) + else: + # CogVideoX 1.0: "linspace" mode with interpolation scale + grid_h = torch.linspace(0, h - 1, h, device=device, dtype=torch.float32) * self.spatial_interpolation_scale + grid_w = torch.linspace(0, w - 1, w, device=device, dtype=torch.float32) * self.spatial_interpolation_scale + grid_t = torch.arange(t, device=device, dtype=torch.float32) + + freqs_t = _get_1d_rotary_pos_embed(dim_t, grid_t) + freqs_h = _get_1d_rotary_pos_embed(dim_h, grid_h) + freqs_w = _get_1d_rotary_pos_embed(dim_w, grid_w) + + t_cos, t_sin = freqs_t + h_cos, h_sin = freqs_h + w_cos, w_sin = freqs_w + + # Slice to actual size (for "slice" mode where grids may be larger) + t_cos, t_sin = t_cos[:t], t_sin[:t] + h_cos, h_sin = h_cos[:h], h_sin[:h] + w_cos, w_sin = w_cos[:w], w_sin[:w] + + # Broadcast and concatenate into [T*H*W, head_dim] + t_cos = t_cos[:, None, None, :].expand(-1, h, w, -1) + t_sin = t_sin[:, None, None, :].expand(-1, h, w, -1) + h_cos = h_cos[None, :, None, :].expand(t, -1, w, -1) + h_sin = h_sin[None, :, None, :].expand(t, -1, w, -1) + w_cos = w_cos[None, None, :, :].expand(t, h, -1, -1) + w_sin = w_sin[None, None, :, :].expand(t, h, -1, -1) + + cos = torch.cat([t_cos, h_cos, w_cos], dim=-1).reshape(t * h * w, -1) + sin = torch.cat([t_sin, h_sin, w_sin], dim=-1).reshape(t * h * w, -1) + return (cos, sin) diff --git a/comfy/ldm/cogvideo/vae.py b/comfy/ldm/cogvideo/vae.py new file mode 100644 index 000000000..d4e6f321e --- /dev/null +++ b/comfy/ldm/cogvideo/vae.py @@ -0,0 +1,566 @@ +# CogVideoX VAE - ported to ComfyUI native ops +# Architecture reference: diffusers AutoencoderKLCogVideoX +# Style reference: comfy/ldm/wan/vae.py + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops +ops = comfy.ops.disable_weight_init + + +class CausalConv3d(nn.Module): + """Causal 3D convolution with temporal padding. + + Uses comfy.ops.Conv3d with autopad='causal_zero' fast path: when input has + a single temporal frame and no cache, the 3D conv weight is sliced to act + as a 2D conv, avoiding computation on zero-padded temporal dimensions. + """ + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, pad_mode="constant"): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + + time_kernel, height_kernel, width_kernel = kernel_size + self.time_kernel_size = time_kernel + self.pad_mode = pad_mode + + height_pad = (height_kernel - 1) // 2 + width_pad = (width_kernel - 1) // 2 + self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_kernel - 1, 0) + + stride = stride if isinstance(stride, tuple) else (stride, 1, 1) + dilation = (dilation, 1, 1) + self.conv = ops.Conv3d( + in_channels, out_channels, kernel_size, + stride=stride, dilation=dilation, + padding=(0, height_pad, width_pad), + ) + + def forward(self, x, conv_cache=None): + if self.pad_mode == "replicate": + x = F.pad(x, self.time_causal_padding, mode="replicate") + conv_cache = None + else: + kernel_t = self.time_kernel_size + if kernel_t > 1: + if conv_cache is None and x.shape[2] == 1: + # Fast path: single frame, no cache. All temporal padding + # frames are copies of the input (replicate-style), so the + # 3D conv reduces to a 2D conv with summed temporal kernel. + w = comfy.ops.cast_to_input(self.conv.weight, x) + b = comfy.ops.cast_to_input(self.conv.bias, x) if self.conv.bias is not None else None + w2d = w.sum(dim=2, keepdim=True) + out = F.conv3d(x, w2d, b, + self.conv.stride, self.conv.padding, + self.conv.dilation, self.conv.groups) + return out, None + cached = [conv_cache] if conv_cache is not None else [x[:, :, :1]] * (kernel_t - 1) + x = torch.cat(cached + [x], dim=2) + conv_cache = x[:, :, -self.time_kernel_size + 1:].clone() if self.time_kernel_size > 1 else None + + out = self.conv(x) + return out, conv_cache + + +def _interpolate_zq(zq, target_size): + """Interpolate latent z to target (T, H, W), matching CogVideoX's first-frame-special handling.""" + t = target_size[0] + if t > 1 and t % 2 == 1: + z_first = F.interpolate(zq[:, :, :1], size=(1, target_size[1], target_size[2])) + z_rest = F.interpolate(zq[:, :, 1:], size=(t - 1, target_size[1], target_size[2])) + return torch.cat([z_first, z_rest], dim=2) + return F.interpolate(zq, size=target_size) + + +class SpatialNorm3D(nn.Module): + """Spatially conditioned normalization.""" + def __init__(self, f_channels, zq_channels, groups=32): + super().__init__() + self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True) + self.conv_y = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + self.conv_b = CausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1) + + def forward(self, f, zq, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + if zq.shape[-3:] != f.shape[-3:]: + zq = _interpolate_zq(zq, f.shape[-3:]) + + conv_y, new_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y")) + conv_b, new_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b")) + + return self.norm_layer(f) * conv_y + conv_b, new_cache + + +class ResnetBlock3D(nn.Module): + """3D ResNet block with optional spatial norm.""" + def __init__(self, in_channels, out_channels=None, temb_channels=512, groups=32, + eps=1e-6, act_fn="silu", spatial_norm_dim=None, pad_mode="first"): + super().__init__() + out_channels = out_channels or in_channels + self.in_channels = in_channels + self.out_channels = out_channels + self.spatial_norm_dim = spatial_norm_dim + + if act_fn == "silu": + self.nonlinearity = nn.SiLU() + elif act_fn == "swish": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = nn.SiLU() + + if spatial_norm_dim is None: + self.norm1 = ops.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps) + self.norm2 = ops.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps) + else: + self.norm1 = SpatialNorm3D(in_channels, spatial_norm_dim, groups=groups) + self.norm2 = SpatialNorm3D(out_channels, spatial_norm_dim, groups=groups) + + self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + + if temb_channels > 0: + self.temb_proj = ops.Linear(temb_channels, out_channels) + + self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, pad_mode=pad_mode) + + if in_channels != out_channels: + self.conv_shortcut = ops.Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + else: + self.conv_shortcut = None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + residual = x + + if zq is not None: + x, new_cache["norm1"] = self.norm1(x, zq, conv_cache=conv_cache.get("norm1")) + else: + x = self.norm1(x) + + x = self.nonlinearity(x) + x, new_cache["conv1"] = self.conv1(x, conv_cache=conv_cache.get("conv1")) + + if temb is not None and hasattr(self, "temb_proj"): + x = x + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if zq is not None: + x, new_cache["norm2"] = self.norm2(x, zq, conv_cache=conv_cache.get("norm2")) + else: + x = self.norm2(x) + + x = self.nonlinearity(x) + x, new_cache["conv2"] = self.conv2(x, conv_cache=conv_cache.get("conv2")) + + if self.conv_shortcut is not None: + residual = self.conv_shortcut(residual) + + return x + residual, new_cache + + +class Downsample3D(nn.Module): + """3D downsampling with optional temporal compression.""" + def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=0, compress_time=False): + super().__init__() + self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + b, c, t, h, w = x.shape + x = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t) + if t % 2 == 1: + x_first, x_rest = x[..., 0], x[..., 1:] + if x_rest.shape[-1] > 0: + x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2) + x = torch.cat([x_first[..., None], x_rest], dim=-1) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + else: + x = F.avg_pool1d(x, kernel_size=2, stride=2) + x = x.reshape(b, h, w, c, x.shape[-1]).permute(0, 3, 4, 1, 2) + + pad = (0, 1, 0, 1) + x = F.pad(x, pad, mode="constant", value=0) + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4) + return x + + +class Upsample3D(nn.Module): + """3D upsampling with optional temporal decompression.""" + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, compress_time=False): + super().__init__() + self.conv = ops.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) + self.compress_time = compress_time + + def forward(self, x): + if self.compress_time: + if x.shape[2] > 1 and x.shape[2] % 2 == 1: + x_first, x_rest = x[:, :, 0], x[:, :, 1:] + x_first = F.interpolate(x_first, scale_factor=2.0) + x_rest = F.interpolate(x_rest, scale_factor=2.0) + x = torch.cat([x_first[:, :, None, :, :], x_rest], dim=2) + elif x.shape[2] > 1: + x = F.interpolate(x, scale_factor=2.0) + else: + x = x.squeeze(2) + x = F.interpolate(x, scale_factor=2.0) + x = x[:, :, None, :, :] + else: + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = F.interpolate(x, scale_factor=2.0) + x = x.reshape(b, t, c, *x.shape[2:]).permute(0, 2, 1, 3, 4) + + b, c, t, h, w = x.shape + x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w) + x = self.conv(x) + x = x.reshape(b, t, *x.shape[1:]).permute(0, 2, 1, 3, 4) + return x + + +class DownBlock3D(nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, add_downsample=True, + compress_time=False, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + groups=groups, eps=eps, act_fn=act_fn, pad_mode=pad_mode, + ) + for i in range(num_layers) + ]) + self.downsamplers = nn.ModuleList([Downsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_downsample else None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + if self.downsamplers is not None: + for ds in self.downsamplers: + x = ds(x) + return x, new_cache + + +class MidBlock3D(nn.Module): + def __init__(self, in_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=None, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels, out_channels=in_channels, + temb_channels=temb_channels, groups=groups, eps=eps, + act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, + ) + for _ in range(num_layers) + ]) + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + return x, new_cache + + +class UpBlock3D(nn.Module): + def __init__(self, in_channels, out_channels, temb_channels=0, num_layers=1, + eps=1e-6, act_fn="silu", groups=32, spatial_norm_dim=16, + add_upsample=True, compress_time=False, pad_mode="first"): + super().__init__() + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels if i == 0 else out_channels, + out_channels=out_channels, + temb_channels=temb_channels, groups=groups, eps=eps, + act_fn=act_fn, spatial_norm_dim=spatial_norm_dim, pad_mode=pad_mode, + ) + for i in range(num_layers) + ]) + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels, compress_time=compress_time)]) if add_upsample else None + + def forward(self, x, temb=None, zq=None, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + for i, resnet in enumerate(self.resnets): + x, new_cache[f"resnet_{i}"] = resnet(x, temb, zq, conv_cache=conv_cache.get(f"resnet_{i}")) + if self.upsamplers is not None: + for us in self.upsamplers: + x = us(x) + return x, new_cache + + +class Encoder3D(nn.Module): + def __init__(self, in_channels=3, out_channels=16, + block_out_channels=(128, 256, 256, 512), + layers_per_block=3, act_fn="silu", + eps=1e-6, groups=32, pad_mode="first", + temporal_compression_ratio=4): + super().__init__() + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode) + + self.down_blocks = nn.ModuleList() + output_channel = block_out_channels[0] + for i in range(len(block_out_channels)): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + self.down_blocks.append(DownBlock3D( + in_channels=input_channel, out_channels=output_channel, + temb_channels=0, num_layers=layers_per_block, + eps=eps, act_fn=act_fn, groups=groups, + add_downsample=not is_final, compress_time=compress_time, + )) + + self.mid_block = MidBlock3D( + in_channels=block_out_channels[-1], temb_channels=0, + num_layers=2, eps=eps, act_fn=act_fn, groups=groups, pad_mode=pad_mode, + ) + + self.norm_out = ops.GroupNorm(groups, block_out_channels[-1], eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode) + + def forward(self, x, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + x, new_cache["conv_in"] = self.conv_in(x, conv_cache=conv_cache.get("conv_in")) + + for i, block in enumerate(self.down_blocks): + key = f"down_block_{i}" + x, new_cache[key] = block(x, None, None, conv_cache.get(key)) + + x, new_cache["mid_block"] = self.mid_block(x, None, None, conv_cache=conv_cache.get("mid_block")) + + x = self.norm_out(x) + x = self.conv_act(x) + x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out")) + + return x, new_cache + + +class Decoder3D(nn.Module): + def __init__(self, in_channels=16, out_channels=3, + block_out_channels=(128, 256, 256, 512), + layers_per_block=3, act_fn="silu", + eps=1e-6, groups=32, pad_mode="first", + temporal_compression_ratio=4): + super().__init__() + reversed_channels = list(reversed(block_out_channels)) + temporal_compress_level = int(np.log2(temporal_compression_ratio)) + + self.conv_in = CausalConv3d(in_channels, reversed_channels[0], kernel_size=3, pad_mode=pad_mode) + + self.mid_block = MidBlock3D( + in_channels=reversed_channels[0], temb_channels=0, + num_layers=2, eps=eps, act_fn=act_fn, groups=groups, + spatial_norm_dim=in_channels, pad_mode=pad_mode, + ) + + self.up_blocks = nn.ModuleList() + output_channel = reversed_channels[0] + for i in range(len(block_out_channels)): + prev_channel = output_channel + output_channel = reversed_channels[i] + is_final = i == len(block_out_channels) - 1 + compress_time = i < temporal_compress_level + + self.up_blocks.append(UpBlock3D( + in_channels=prev_channel, out_channels=output_channel, + temb_channels=0, num_layers=layers_per_block + 1, + eps=eps, act_fn=act_fn, groups=groups, + spatial_norm_dim=in_channels, + add_upsample=not is_final, compress_time=compress_time, + )) + + self.norm_out = SpatialNorm3D(reversed_channels[-1], in_channels, groups=groups) + self.conv_act = nn.SiLU() + self.conv_out = CausalConv3d(reversed_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode) + + def forward(self, sample, conv_cache=None): + new_cache = {} + conv_cache = conv_cache or {} + + x, new_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in")) + + x, new_cache["mid_block"] = self.mid_block(x, None, sample, conv_cache=conv_cache.get("mid_block")) + + for i, block in enumerate(self.up_blocks): + key = f"up_block_{i}" + x, new_cache[key] = block(x, None, sample, conv_cache=conv_cache.get(key)) + + x, new_cache["norm_out"] = self.norm_out(x, sample, conv_cache=conv_cache.get("norm_out")) + x = self.conv_act(x) + x, new_cache["conv_out"] = self.conv_out(x, conv_cache=conv_cache.get("conv_out")) + + return x, new_cache + + + +class AutoencoderKLCogVideoX(nn.Module): + """CogVideoX VAE. Spatial tiling/slicing handled by ComfyUI's VAE wrapper. + + Uses rolling temporal decode: conv_in + mid_block + temporal up_blocks run + on the full (low-res) tensor, then the expensive spatial-only up_blocks + + norm_out + conv_out are processed in small temporal chunks with conv_cache + carrying causal state between chunks. This keeps peak VRAM proportional to + chunk_size rather than total frame count. + """ + + def __init__(self, + in_channels=3, out_channels=3, + block_out_channels=(128, 256, 256, 512), + latent_channels=16, layers_per_block=3, + act_fn="silu", eps=1e-6, groups=32, + temporal_compression_ratio=4, + ): + super().__init__() + self.latent_channels = latent_channels + self.temporal_compression_ratio = temporal_compression_ratio + + self.encoder = Encoder3D( + in_channels=in_channels, out_channels=latent_channels, + block_out_channels=block_out_channels, layers_per_block=layers_per_block, + act_fn=act_fn, eps=eps, groups=groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + self.decoder = Decoder3D( + in_channels=latent_channels, out_channels=out_channels, + block_out_channels=block_out_channels, layers_per_block=layers_per_block, + act_fn=act_fn, eps=eps, groups=groups, + temporal_compression_ratio=temporal_compression_ratio, + ) + + self.num_latent_frames_batch_size = 2 + self.num_sample_frames_batch_size = 8 + + def encode(self, x): + t = x.shape[2] + frame_batch = self.num_sample_frames_batch_size + remainder = t % frame_batch + conv_cache = None + enc = [] + + # Process remainder frames first so only the first chunk can have an + # odd temporal dimension — where Downsample3D's first-frame-special + # handling in temporal compression is actually correct. + if remainder > 0: + chunk, conv_cache = self.encoder(x[:, :, :remainder], conv_cache=conv_cache) + enc.append(chunk.to(x.device)) + + for start in range(remainder, t, frame_batch): + chunk, conv_cache = self.encoder(x[:, :, start:start + frame_batch], conv_cache=conv_cache) + enc.append(chunk.to(x.device)) + + enc = torch.cat(enc, dim=2) + mean, _ = enc.chunk(2, dim=1) + return mean + + def decode(self, z): + return self._decode_rolling(z) + + def _decode_batched(self, z): + """Original batched decode - processes 2 latent frames through full decoder.""" + t = z.shape[2] + frame_batch = self.num_latent_frames_batch_size + num_batches = max(t // frame_batch, 1) + conv_cache = None + dec = [] + for i in range(num_batches): + remaining = t % frame_batch + start = frame_batch * i + (0 if i == 0 else remaining) + end = frame_batch * (i + 1) + remaining + chunk, conv_cache = self.decoder(z[:, :, start:end], conv_cache=conv_cache) + dec.append(chunk.cpu()) + return torch.cat(dec, dim=2).to(z.device) + + def _decode_rolling(self, z): + """Rolling decode - processes low-res layers on full tensor, then rolls + through expensive high-res layers in temporal chunks.""" + decoder = self.decoder + device = z.device + + # Determine which up_blocks have temporal upsample vs spatial-only. + # Temporal up_blocks are cheap (low res), spatial-only are expensive. + temporal_compress_level = int(np.log2(self.temporal_compression_ratio)) + split_at = temporal_compress_level # first N up_blocks do temporal upsample + + # Phase 1: conv_in + mid_block + temporal up_blocks on full tensor (low/medium res) + x, _ = decoder.conv_in(z) + x, _ = decoder.mid_block(x, None, z) + + for i in range(split_at): + x, _ = decoder.up_blocks[i](x, None, z) + + # Phase 2: remaining spatial-only up_blocks + norm_out + conv_out in temporal chunks + remaining_blocks = list(range(split_at, len(decoder.up_blocks))) + chunk_size = 4 # pixel frames per chunk through high-res layers + t_expanded = x.shape[2] + + if t_expanded <= chunk_size or len(remaining_blocks) == 0: + # Small enough to process in one go + for i in remaining_blocks: + x, _ = decoder.up_blocks[i](x, None, z) + x, _ = decoder.norm_out(x, z) + x = decoder.conv_act(x) + x, _ = decoder.conv_out(x) + return x + + # Expand z temporally once to match Phase 2's time dimension. + # z stays at latent spatial resolution so this is small (~16 MB vs ~1.3 GB + # for the old approach of pre-interpolating to every pixel resolution). + z_time_expanded = _interpolate_zq(z, (t_expanded, z.shape[3], z.shape[4])) + + # Process in temporal chunks, interpolating spatially per-chunk to avoid + # allocating full [B, C, t_expanded, H, W] tensors at each resolution. + dec_out = [] + conv_caches = {} + + for chunk_start in range(0, t_expanded, chunk_size): + chunk_end = min(chunk_start + chunk_size, t_expanded) + x_chunk = x[:, :, chunk_start:chunk_end] + z_t_chunk = z_time_expanded[:, :, chunk_start:chunk_end] + z_spatial_cache = {} + + for i in remaining_blocks: + block = decoder.up_blocks[i] + cache_key = f"up_block_{i}" + hw_key = (x_chunk.shape[3], x_chunk.shape[4]) + if hw_key not in z_spatial_cache: + if z_t_chunk.shape[3] == hw_key[0] and z_t_chunk.shape[4] == hw_key[1]: + z_spatial_cache[hw_key] = z_t_chunk + else: + z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1])) + x_chunk, new_cache = block(x_chunk, None, z_spatial_cache[hw_key], conv_cache=conv_caches.get(cache_key)) + conv_caches[cache_key] = new_cache + + hw_key = (x_chunk.shape[3], x_chunk.shape[4]) + if hw_key not in z_spatial_cache: + z_spatial_cache[hw_key] = F.interpolate(z_t_chunk, size=(z_t_chunk.shape[2], hw_key[0], hw_key[1])) + x_chunk, new_cache = decoder.norm_out(x_chunk, z_spatial_cache[hw_key], conv_cache=conv_caches.get("norm_out")) + conv_caches["norm_out"] = new_cache + x_chunk = decoder.conv_act(x_chunk) + x_chunk, new_cache = decoder.conv_out(x_chunk, conv_cache=conv_caches.get("conv_out")) + conv_caches["conv_out"] = new_cache + + dec_out.append(x_chunk.cpu()) + del z_spatial_cache + + del x, z_time_expanded + return torch.cat(dec_out, dim=2).to(device) diff --git a/comfy/model_base.py b/comfy/model_base.py index 5c2668ba9..054853288 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -52,6 +52,7 @@ import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 +import comfy.ldm.cogvideo.model import comfy.ldm.rt_detr.rtdetr_v4 import comfy.ldm.ernie.model @@ -80,6 +81,7 @@ class ModelType(Enum): IMG_TO_IMG = 9 FLOW_COSMOS = 10 IMG_TO_IMG_FLOW = 11 + V_PREDICTION_DDPM = 12 def model_sampling(model_config, model_type): @@ -114,6 +116,8 @@ def model_sampling(model_config, model_type): s = comfy.model_sampling.ModelSamplingCosmosRFlow elif model_type == ModelType.IMG_TO_IMG_FLOW: c = comfy.model_sampling.IMG_TO_IMG_FLOW + elif model_type == ModelType.V_PREDICTION_DDPM: + c = comfy.model_sampling.V_PREDICTION_DDPM class ModelSampling(s, c): pass @@ -1974,3 +1978,59 @@ class ErnieImage(BaseModel): if cross_attn is not None: out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out + +class CogVideoX(BaseModel): + def __init__(self, model_config, model_type=ModelType.V_PREDICTION_DDPM, image_to_video=False, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cogvideo.model.CogVideoXTransformer3DModel) + self.image_to_video = image_to_video + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + # Detect extra channels needed (e.g. 32 - 16 = 16 for ref latent) + extra_channels = self.diffusion_model.in_channels - noise.shape[1] + if extra_channels == 0: + return None + + image = kwargs.get("concat_latent_image", None) + device = kwargs["device"] + + if image is None: + shape = list(noise.shape) + shape[1] = extra_channels + return torch.zeros(shape, dtype=noise.dtype, layout=noise.layout, device=noise.device) + + latent_dim = self.latent_format.latent_channels + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + + if noise.ndim == 5 and image.ndim == 5: + if image.shape[-3] < noise.shape[-3]: + image = torch.nn.functional.pad(image, (0, 0, 0, 0, 0, noise.shape[-3] - image.shape[-3]), "constant", 0) + elif image.shape[-3] > noise.shape[-3]: + image = image[:, :, :noise.shape[-3]] + + for i in range(0, image.shape[1], latent_dim): + image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim]) + image = utils.resize_to_batch_size(image, noise.shape[0]) + + if image.shape[1] > extra_channels: + image = image[:, :extra_channels] + elif image.shape[1] < extra_channels: + repeats = extra_channels // image.shape[1] + remainder = extra_channels % image.shape[1] + parts = [image] * repeats + if remainder > 0: + parts.append(image[:, :remainder]) + image = torch.cat(parts, dim=1) + + return image + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + # OFS embedding (CogVideoX 1.5 I2V), default 2.0 as used by SparkVSR + if self.diffusion_model.ofs_proj_dim is not None: + ofs = kwargs.get("ofs", None) + if ofs is None: + noise = kwargs.get("noise", None) + ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype) + out['ofs'] = comfy.conds.CONDRegular(ofs) + return out diff --git a/comfy/model_detection.py b/comfy/model_detection.py index ca06cdd1e..f6095581b 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -490,6 +490,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config + if '{}blocks.0.norm1.linear.weight'.format(key_prefix) in state_dict_keys: # CogVideoX + dit_config = {} + dit_config["image_model"] = "cogvideox" + + # Extract config from weight shapes + norm1_weight = state_dict['{}blocks.0.norm1.linear.weight'.format(key_prefix)] + time_embed_dim = norm1_weight.shape[1] + dim = norm1_weight.shape[0] // 6 + + dit_config["num_attention_heads"] = dim // 64 + dit_config["attention_head_dim"] = 64 + dit_config["time_embed_dim"] = time_embed_dim + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') + + # Detect in_channels from patch_embed + patch_proj_key = '{}patch_embed.proj.weight'.format(key_prefix) + if patch_proj_key in state_dict_keys: + w = state_dict[patch_proj_key] + if w.ndim == 4: + # Conv2d: [out, in, kh, kw] — CogVideoX 1.0 + dit_config["in_channels"] = w.shape[1] + dit_config["patch_size"] = w.shape[2] + elif w.ndim == 2: + # Linear: [out, in_channels * patch_size * patch_size * patch_size_t] — CogVideoX 1.5 + dit_config["patch_size"] = 2 + dit_config["patch_size_t"] = 2 + dit_config["in_channels"] = w.shape[1] // (2 * 2 * 2) # 256 // 8 = 32 + + text_proj_key = '{}patch_embed.text_proj.weight'.format(key_prefix) + if text_proj_key in state_dict_keys: + dit_config["text_embed_dim"] = state_dict[text_proj_key].shape[1] + + # Detect OFS embedding + ofs_key = '{}ofs_embedding_linear_1.weight'.format(key_prefix) + if ofs_key in state_dict_keys: + dit_config["ofs_embed_dim"] = state_dict[ofs_key].shape[1] + + # Detect positional embedding type + pos_key = '{}patch_embed.pos_embedding'.format(key_prefix) + if pos_key in state_dict_keys: + dit_config["use_learned_positional_embeddings"] = True + dit_config["use_rotary_positional_embeddings"] = False + else: + dit_config["use_learned_positional_embeddings"] = False + dit_config["use_rotary_positional_embeddings"] = True + + return dit_config + if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 dit_config = {} dit_config["image_model"] = "wan2.1" diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 13860e6a2..cf2b5db5f 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -54,6 +54,30 @@ class V_PREDICTION(EPS): sigma = reshape_sigma(sigma, model_output.ndim) return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5 +class V_PREDICTION_DDPM: + """CogVideoX v-prediction: model receives raw x_t (unscaled), predicts velocity v. + x_0 = sqrt(alpha) * x_t - sqrt(1-alpha) * v + = x_t / sqrt(sigma^2 + 1) - v * sigma / sqrt(sigma^2 + 1) + """ + def calculate_input(self, sigma, noise): + return noise + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = reshape_sigma(sigma, model_output.ndim) + return model_input / (sigma ** 2 + 1.0) ** 0.5 - model_output * sigma / (sigma ** 2 + 1.0) ** 0.5 + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + sigma = reshape_sigma(sigma, noise.ndim) + if max_denoise: + noise = noise * torch.sqrt(1.0 + sigma ** 2.0) + else: + noise = noise * sigma + noise += latent_image + return noise + + def inverse_noise_scaling(self, sigma, latent): + return latent + class EDM(V_PREDICTION): def calculate_denoised(self, sigma, model_output, model_input): sigma = reshape_sigma(sigma, model_output.ndim) diff --git a/comfy/sd.py b/comfy/sd.py index e573804a5..42e3b1e41 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -17,6 +17,7 @@ import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.ace.vae.music_dcae_pipeline +import comfy.ldm.cogvideo.vae import comfy.ldm.hunyuan_video.vae import comfy.ldm.mmaudio.vae.autoencoder import comfy.pixel_space_convert @@ -651,6 +652,17 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + elif "decoder.conv_in.conv.weight" in sd and "decoder.mid_block.resnets.0.norm1.norm_layer.weight" in sd: # CogVideoX VAE + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.latent_dim = 3 + self.latent_channels = sd["encoder.conv_out.conv.weight"].shape[0] // 2 + self.first_stage_model = comfy.ldm.cogvideo.vae.AutoencoderKLCogVideoX(latent_channels=self.latent_channels) + self.memory_used_decode = lambda shape, dtype: (2800 * max(2, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1400 * max(1, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 58d4ce731..33b268ac1 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -27,6 +27,7 @@ import comfy.text_encoders.anima import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image import comfy.text_encoders.ernie +import comfy.text_encoders.cogvideo from . import supported_models_base from . import latent_formats @@ -1781,6 +1782,52 @@ class ErnieImage(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage] +class CogVideoX_T2V(supported_models_base.BASE): + unet_config = { + "image_model": "cogvideox", + } + + sampling_settings = { + "linear_start": 0.00085, + "linear_end": 0.012, + "beta_schedule": "linear", + "zsnr": True, + } + + unet_extra_config = {} + latent_format = latent_formats.CogVideoX + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + # CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE + if self.unet_config.get("patch_size_t") is not None: + self.unet_config.setdefault("sample_height", 96) + self.unet_config.setdefault("sample_width", 170) + self.unet_config.setdefault("sample_frames", 81) + out = model_base.CogVideoX(self, device=device) + return out + + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget(comfy.text_encoders.cogvideo.CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel) + +class CogVideoX_I2V(CogVideoX_T2V): + unet_config = { + "image_model": "cogvideox", + "in_channels": 32, + } + + def get_model(self, state_dict, prefix="", device=None): + if self.unet_config.get("patch_size_t") is not None: + self.unet_config.setdefault("sample_height", 96) + self.unet_config.setdefault("sample_width", 170) + self.unet_config.setdefault("sample_frames", 81) + out = model_base.CogVideoX(self, image_to_video=True, device=device) + return out + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, CogVideoX_I2V, CogVideoX_T2V] models += [SVD_img2vid] diff --git a/comfy/text_encoders/cogvideo.py b/comfy/text_encoders/cogvideo.py new file mode 100644 index 000000000..f1e8e3f5d --- /dev/null +++ b/comfy/text_encoders/cogvideo.py @@ -0,0 +1,6 @@ +import comfy.text_encoders.sd3_clip + + +class CogVideoXT5Tokenizer(comfy.text_encoders.sd3_clip.T5XXLTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, min_length=226) diff --git a/nodes.py b/nodes.py index 299b3d758..ba2fa0246 100644 --- a/nodes.py +++ b/nodes.py @@ -2457,7 +2457,7 @@ async def init_builtin_extra_nodes(): "nodes_number_convert.py", "nodes_painter.py", "nodes_curve.py", - "nodes_rtdetr.py" + "nodes_rtdetr.py", ] import_failed = []