diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f59999af6..6a57bca1c 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -776,3 +776,10 @@ class ChromaRadiance(LatentFormat): def process_out(self, latent): return latent + + +class ZImagePixelSpace(ChromaRadiance): + """Pixel-space latent format for ZImage DCT variant. + No VAE encoding/decoding — the model operates directly on RGB pixels. + """ + pass diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py index 553fd5b38..08d686b7b 100644 --- a/comfy/ldm/lightricks/av_model.py +++ b/comfy/ldm/lightricks/av_model.py @@ -2,11 +2,16 @@ from typing import Tuple import torch import torch.nn as nn from comfy.ldm.lightricks.model import ( + ADALN_BASE_PARAMS_COUNT, + ADALN_CROSS_ATTN_PARAMS_COUNT, CrossAttention, FeedForward, AdaLayerNormSingle, PixArtAlphaTextProjection, + NormSingleLinearTextProjection, LTXVModel, + apply_cross_attention_adaln, + compute_prompt_timestep, ) from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector @@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module): v_context_dim=None, a_context_dim=None, attn_precision=None, + apply_gated_attention=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module): super().__init__() self.attn_precision = attn_precision + self.cross_attention_adaln = cross_attention_adaln self.attn1 = CrossAttention( query_dim=v_dim, @@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=vd_head, context_dim=None, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module): dim_head=ad_head, context_dim=None, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module): heads=v_heads, dim_head=vd_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module): heads=a_heads, dim_head=ad_head, attn_precision=self.attn_precision, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module): a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations ) - self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype)) + num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype)) self.audio_scale_shift_table = nn.Parameter( - torch.empty(6, a_dim, device=device, dtype=dtype) + torch.empty(num_ada_params, a_dim, device=device, dtype=dtype) ) + if cross_attention_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype)) + self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype)) + self.scale_shift_table_a2v_ca_audio = nn.Parameter( torch.empty(5, a_dim, device=device, dtype=dtype) ) @@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module): return (*scale_shift_ada_values, *gate_ada_values) + def _apply_text_cross_attention( + self, x, context, attn, scale_shift_table, prompt_scale_shift_table, + timestep, prompt_timestep, attention_mask, transformer_options, + ): + """Apply text cross-attention, with optional ADaLN modulation.""" + if self.cross_attention_adaln: + shift_q, scale_q, gate = self.get_ada_values( + scale_shift_table, x.shape[0], timestep, slice(6, 9) + ) + return apply_cross_attention_adaln( + x, context, attn, shift_q, scale_q, gate, + prompt_scale_shift_table, prompt_timestep, + attention_mask, transformer_options, + ) + return attn( + comfy.ldm.common_dit.rms_norm(x), context=context, + mask=attention_mask, transformer_options=transformer_options, + ) + def forward( self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None, v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None, v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None, + v_prompt_timestep=None, a_prompt_timestep=None, ) -> Tuple[torch.Tensor, torch.Tensor]: run_vx = transformer_options.get("run_vx", True) run_ax = transformer_options.get("run_ax", True) @@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module): vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0] vx.addcmul_(attn1_out, vgate_msa) del vgate_msa, attn1_out - vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options)) + vx.add_(self._apply_text_cross_attention( + vx, v_context, self.attn2, self.scale_shift_table, + getattr(self, 'prompt_scale_shift_table', None), + v_timestep, v_prompt_timestep, attention_mask, transformer_options,) + ) # audio if run_ax: @@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module): agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0] ax.addcmul_(attn1_out, agate_msa) del agate_msa, attn1_out - ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options)) + ax.add_(self._apply_text_cross_attention( + ax, a_context, self.audio_attn2, self.audio_scale_shift_table, + getattr(self, 'audio_prompt_scale_shift_table', None), + a_timestep, a_prompt_timestep, attention_mask, transformer_options,) + ) # video - audio cross attention. if run_a2v or run_v2a: @@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel): use_middle_indices_grid=False, timestep_scale_multiplier=1000.0, av_ca_timestep_scale_multiplier=1.0, + apply_gated_attention=False, + caption_proj_before_connector=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel): self.audio_attention_head_dim = audio_attention_head_dim self.audio_num_attention_heads = audio_num_attention_heads self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + self.apply_gated_attention = apply_gated_attention # Calculate audio dimensions self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim @@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel): vae_scale_factors=vae_scale_factors, use_middle_indices_grid=use_middle_indices_grid, timestep_scale_multiplier=timestep_scale_multiplier, + caption_proj_before_connector=caption_proj_before_connector, + cross_attention_adaln=cross_attention_adaln, dtype=dtype, device=device, operations=operations, @@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel): ) # Audio-specific AdaLN + audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT self.audio_adaln_single = AdaLayerNormSingle( self.audio_inner_dim, + embedding_coefficient=audio_embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations, ) + if self.cross_attention_adaln: + self.audio_prompt_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + embedding_coefficient=2, + use_additional_conditions=False, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.audio_prompt_adaln_single = None + num_scale_shift_values = 4 self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( self.inner_dim, @@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel): ) # Audio caption projection - self.audio_caption_projection = PixArtAlphaTextProjection( - in_features=self.caption_channels, - hidden_size=self.audio_inner_dim, - dtype=dtype, - device=device, - operations=self.operations, - ) + if self.caption_proj_before_connector: + if self.caption_projection_first_linear: + self.audio_caption_projection = NormSingleLinearTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.audio_caption_projection = lambda a: a + else: + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + + connector_split_rope = kwargs.get("rope_type", "split") == "split" + connector_gated_attention = kwargs.get("connector_apply_gated_attention", False) + attention_head_dim = kwargs.get("connector_attention_head_dim", 128) + num_attention_heads = kwargs.get("connector_num_attention_heads", 30) + num_layers = kwargs.get("connector_num_layers", 2) self.audio_embeddings_connector = Embeddings1DConnector( - split_rope=True, + attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim), + num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads), + num_layers=num_layers, + split_rope=connector_split_rope, double_precision_rope=True, + apply_gated_attention=connector_gated_attention, dtype=dtype, device=device, operations=self.operations, ) self.video_embeddings_connector = Embeddings1DConnector( - split_rope=True, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + num_layers=num_layers, + split_rope=connector_split_rope, double_precision_rope=True, + apply_gated_attention=connector_gated_attention, dtype=dtype, device=device, operations=self.operations, ) - def preprocess_text_embeds(self, context): - if context.shape[-1] == self.caption_channels * 2: - return context - out_vid = self.video_embeddings_connector(context)[0] - out_audio = self.audio_embeddings_connector(context)[0] + def preprocess_text_embeds(self, context, unprocessed=False): + # LTXv2 fully processed context has dimension of self.caption_channels * 2 + # LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim + if not unprocessed: + if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2): + return context + if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim: + context_vid = context[:, :, :self.cross_attention_dim] + context_audio = context[:, :, self.cross_attention_dim:] + else: + context_vid = context + context_audio = context + if self.caption_proj_before_connector: + context_vid = self.caption_projection(context_vid) + context_audio = self.audio_caption_projection(context_audio) + out_vid = self.video_embeddings_connector(context_vid)[0] + out_audio = self.audio_embeddings_connector(context_audio)[0] return torch.concat((out_vid, out_audio), dim=-1) def _init_transformer_blocks(self, device, dtype, **kwargs): @@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel): ad_head=self.audio_attention_head_dim, v_context_dim=self.cross_attention_dim, a_context_dim=self.audio_cross_attention_dim, + apply_gated_attention=self.apply_gated_attention, + cross_attention_adaln=self.cross_attention_adaln, dtype=dtype, device=device, operations=self.operations, @@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel): v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame) v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame) + v_prompt_timestep = compute_prompt_timestep( + self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype + ) + # Prepare audio timestep a_timestep = kwargs.get("a_timestep") if a_timestep is not None: @@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel): # Cross-attention timesteps - compress these too av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( - a_timestep_flat, + timestep.max().expand_as(a_timestep_flat), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( - timestep_flat, + a_timestep.max().expand_as(timestep_flat), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( - timestep_flat * av_ca_factor, + a_timestep.max().expand_as(timestep_flat) * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, ) av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( - a_timestep_flat * av_ca_factor, + timestep.max().expand_as(a_timestep_flat) * av_ca_factor, {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, @@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel): # Audio timesteps a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1]) + + a_prompt_timestep = compute_prompt_timestep( + self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype + ) else: a_timestep = timestep_scaled a_embedded_timestep = kwargs.get("embedded_timestep") cross_av_timestep_ss = [] + a_prompt_timestep = None - return [v_timestep, a_timestep, cross_av_timestep_ss], [ + return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [ v_embedded_timestep, a_embedded_timestep, - ] + ], None def _prepare_context(self, context, batch_size, x, attention_mask=None): vx = x[0] ax = x[1] + video_dim = vx.shape[-1] + audio_dim = ax.shape[-1] + + v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim + a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim + v_context, a_context = torch.split( - context, int(context.shape[-1] / 2), len(context.shape) - 1 + context, [v_context_dim, a_context_dim], len(context.shape) - 1 ) v_context, attention_mask = super()._prepare_context( v_context, batch_size, vx, attention_mask ) - if self.audio_caption_projection is not None: + if self.caption_proj_before_connector is False: a_context = self.audio_caption_projection(a_context) - a_context = a_context.view(batch_size, -1, ax.shape[-1]) + a_context = a_context.view(batch_size, -1, audio_dim) return [v_context, a_context], attention_mask @@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel): av_ca_v2a_gate_noise_timestep, ) = timestep[2] + v_prompt_timestep = timestep[3] + a_prompt_timestep = timestep[4] + """Process transformer blocks for LTXAV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) @@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel): a_cross_gate_timestep=args["a_cross_gate_timestep"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), + v_prompt_timestep=args.get("v_prompt_timestep"), + a_prompt_timestep=args.get("a_prompt_timestep"), ) return out @@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel): "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, + "v_prompt_timestep": v_prompt_timestep, + "a_prompt_timestep": a_prompt_timestep, }, {"original_block": block_wrap}, ) @@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel): a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, transformer_options=transformer_options, self_attention_mask=self_attention_mask, + v_prompt_timestep=v_prompt_timestep, + a_prompt_timestep=a_prompt_timestep, ) return [vx, ax] diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py index 33adb9671..2811080be 100644 --- a/comfy/ldm/lightricks/embeddings_connector.py +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module): d_head, context_dim=None, attn_precision=None, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module): heads=n_heads, dim_head=d_head, context_dim=None, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, @@ -121,6 +123,7 @@ class Embeddings1DConnector(nn.Module): positional_embedding_max_pos=[4096], causal_temporal_positioning=False, num_learnable_registers: Optional[int] = 128, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module): num_attention_heads, attention_head_dim, context_dim=cross_attention_dim, + apply_gated_attention=apply_gated_attention, dtype=dtype, device=device, operations=operations, diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 60d760d29..bfbc08357 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -275,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module): return hidden_states +class NormSingleLinearTextProjection(nn.Module): + """Text projection for 20B models - single linear with RMSNorm (no activation).""" + + def __init__( + self, in_features, hidden_size, dtype=None, device=None, operations=None + ): + super().__init__() + if operations is None: + operations = comfy.ops.disable_weight_init + self.in_norm = operations.RMSNorm( + in_features, eps=1e-6, elementwise_affine=False + ) + self.linear_1 = operations.Linear( + in_features, hidden_size, bias=True, dtype=dtype, device=device + ) + self.hidden_size = hidden_size + self.in_features = in_features + + def forward(self, caption): + caption = self.in_norm(caption) + caption = caption * (self.hidden_size / self.in_features) ** 0.5 + return self.linear_1(caption) + + class GELU_approx(nn.Module): def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None): super().__init__() @@ -343,6 +367,7 @@ class CrossAttention(nn.Module): dim_head=64, dropout=0.0, attn_precision=None, + apply_gated_attention=False, dtype=None, device=None, operations=None, @@ -362,6 +387,12 @@ class CrossAttention(nn.Module): self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) + # Optional per-head gating + if apply_gated_attention: + self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device) + else: + self.to_gate_logits = None + self.to_out = nn.Sequential( operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) ) @@ -383,16 +414,30 @@ class CrossAttention(nn.Module): out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) else: out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) + + # Apply per-head gating if enabled + if self.to_gate_logits is not None: + gate_logits = self.to_gate_logits(x) # (B, T, H) + b, t, _ = out.shape + out = out.view(b, t, self.heads, self.dim_head) + gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity + out = out * gates.unsqueeze(-1) + out = out.view(b, t, self.heads * self.dim_head) + return self.to_out(out) +# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate) +ADALN_BASE_PARAMS_COUNT = 6 +ADALN_CROSS_ATTN_PARAMS_COUNT = 9 class BasicTransformerBlock(nn.Module): def __init__( - self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None + self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None ): super().__init__() self.attn_precision = attn_precision + self.cross_attention_adaln = cross_attention_adaln self.attn1 = CrossAttention( query_dim=dim, heads=n_heads, @@ -416,18 +461,25 @@ class BasicTransformerBlock(nn.Module): operations=operations, ) - self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) + num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT + self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype)) - def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None): - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) + if cross_attention_adaln: + self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype)) - attn1_input = comfy.ldm.common_dit.rms_norm(x) - attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) - attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) - x.addcmul_(attn1_input, gate_msa) - del attn1_input + def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2) - x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) + x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa + + if self.cross_attention_adaln: + shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2) + x += apply_cross_attention_adaln( + x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca, + self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options, + ) + else: + x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) y = comfy.ldm.common_dit.rms_norm(x) y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp) @@ -435,6 +487,47 @@ class BasicTransformerBlock(nn.Module): return x +def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype): + """Compute a single global prompt timestep for cross-attention ADaLN. + + Uses the max across tokens (matching JAX max_per_segment) and broadcasts + over text tokens. Returns None when *adaln_module* is None. + """ + if adaln_module is None: + return None + ts_input = ( + timestep_scaled.max(dim=1, keepdim=True).values.flatten() + if timestep_scaled.dim() > 1 + else timestep_scaled.flatten() + ) + prompt_ts, _ = adaln_module( + ts_input, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1]) + + +def apply_cross_attention_adaln( + x, context, attn, q_shift, q_scale, q_gate, + prompt_scale_shift_table, prompt_timestep, + attention_mask=None, transformer_options={}, +): + """Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV). + + Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so + that both regular tensors and CompressedTimestep are supported. + """ + batch_size = x.shape[0] + shift_kv, scale_kv = ( + prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1) + ).unbind(dim=2) + attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift + encoder_hidden_states = context * (1 + scale_kv) + shift_kv + return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate + def get_fractional_positions(indices_grid, max_pos): n_pos_dims = indices_grid.shape[1] assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})' @@ -556,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC): vae_scale_factors: tuple = (8, 32, 32), use_middle_indices_grid=False, timestep_scale_multiplier = 1000.0, + caption_proj_before_connector=False, + cross_attention_adaln=False, + caption_projection_first_linear=True, dtype=None, device=None, operations=None, @@ -582,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC): self.causal_temporal_positioning = causal_temporal_positioning self.operations = operations self.timestep_scale_multiplier = timestep_scale_multiplier + self.caption_proj_before_connector = caption_proj_before_connector + self.cross_attention_adaln = cross_attention_adaln + self.caption_projection_first_linear = caption_projection_first_linear # Common dimensions self.inner_dim = num_attention_heads * attention_head_dim @@ -609,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC): self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device ) + embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations + self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations ) - self.caption_projection = PixArtAlphaTextProjection( - in_features=self.caption_channels, - hidden_size=self.inner_dim, - dtype=dtype, - device=device, - operations=self.operations, - ) + if self.cross_attention_adaln: + self.prompt_adaln_single = AdaLayerNormSingle( + self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations + ) + else: + self.prompt_adaln_single = None + + if self.caption_proj_before_connector: + if self.caption_projection_first_linear: + self.caption_projection = NormSingleLinearTextProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + else: + self.caption_projection = lambda a: a + else: + self.caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) @abstractmethod def _init_model_components(self, device, dtype, **kwargs): @@ -665,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC): if grid_mask is not None: timestep = timestep[:, grid_mask] - timestep = timestep * self.timestep_scale_multiplier + timestep_scaled = timestep * self.timestep_scale_multiplier timestep, embedded_timestep = self.adaln_single( - timestep.flatten(), + timestep_scaled.flatten(), {"resolution": None, "aspect_ratio": None}, batch_size=batch_size, hidden_dtype=hidden_dtype, @@ -677,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC): timestep = timestep.view(batch_size, -1, timestep.shape[-1]) embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) - return timestep, embedded_timestep + prompt_timestep = compute_prompt_timestep( + self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype + ) + + return timestep, embedded_timestep, prompt_timestep def _prepare_context(self, context, batch_size, x, attention_mask=None): """Prepare context for transformer blocks.""" - if self.caption_projection is not None: + if self.caption_proj_before_connector is False: context = self.caption_projection(context) - context = context.view(batch_size, -1, x.shape[-1]) + context = context.view(batch_size, -1, x.shape[-1]) return context, attention_mask def _precompute_freqs_cis( @@ -792,7 +915,8 @@ class LTXBaseModel(torch.nn.Module, ABC): merged_args.update(additional_args) # Prepare timestep and context - timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + merged_args["prompt_timestep"] = prompt_timestep context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask) # Prepare attention mask and positional embeddings @@ -833,7 +957,9 @@ class LTXVModel(LTXBaseModel): causal_temporal_positioning=False, vae_scale_factors=(8, 32, 32), use_middle_indices_grid=False, - timestep_scale_multiplier = 1000.0, + timestep_scale_multiplier=1000.0, + caption_proj_before_connector=False, + cross_attention_adaln=False, dtype=None, device=None, operations=None, @@ -852,6 +978,8 @@ class LTXVModel(LTXBaseModel): vae_scale_factors=vae_scale_factors, use_middle_indices_grid=use_middle_indices_grid, timestep_scale_multiplier=timestep_scale_multiplier, + caption_proj_before_connector=caption_proj_before_connector, + cross_attention_adaln=cross_attention_adaln, dtype=dtype, device=device, operations=operations, @@ -860,7 +988,6 @@ class LTXVModel(LTXBaseModel): def _init_model_components(self, device, dtype, **kwargs): """Initialize LTXV-specific components.""" - # No additional components needed for LTXV beyond base class pass def _init_transformer_blocks(self, device, dtype, **kwargs): @@ -872,6 +999,7 @@ class LTXVModel(LTXBaseModel): self.num_attention_heads, self.attention_head_dim, context_dim=self.cross_attention_dim, + cross_attention_adaln=self.cross_attention_adaln, dtype=dtype, device=device, operations=self.operations, @@ -1149,16 +1277,17 @@ class LTXVModel(LTXBaseModel): """Process transformer blocks for LTXV.""" patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + prompt_timestep = kwargs.get("prompt_timestep", None) for i, block in enumerate(self.transformer_blocks): if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} - out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask")) + out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep")) return out - out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap}) + out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap}) x = out["img"] else: x = block( @@ -1169,6 +1298,7 @@ class LTXVModel(LTXBaseModel): pe=pe, transformer_options=transformer_options, self_attention_mask=self_attention_mask, + prompt_timestep=prompt_timestep, ) return x diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py index 55a074661..fa0a00748 100644 --- a/comfy/ldm/lightricks/vae/audio_vae.py +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( CausalityAxis, CausalAudioAutoencoder, ) -from comfy.ldm.lightricks.vocoders.vocoder import Vocoder +from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE LATENT_DOWNSAMPLE_FACTOR = 4 @@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module): vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) - self.vocoder = Vocoder(config=component_config.vocoder) + if "bwe" in component_config.vocoder: + self.vocoder = VocoderWithBWE(config=component_config.vocoder) + else: + self.vocoder = Vocoder(config=component_config.vocoder) self.autoencoder.load_state_dict(vae_sd, strict=False) self.vocoder.load_state_dict(vocoder_sd, strict=False) diff --git a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py index f12b9bb53..b556b128f 100644 --- a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py @@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module): super().__init__() if config is None: - config = self._guess_config() + config = self.get_default_config() - # Extract encoder and decoder configs from the new format model_config = config.get("model", {}).get("params", {}) - variables_config = config.get("variables", {}) - self.sampling_rate = variables_config.get( - "sampling_rate", - model_config.get("sampling_rate", config.get("sampling_rate", 16000)), + self.sampling_rate = model_config.get( + "sampling_rate", config.get("sampling_rate", 16000) ) encoder_config = model_config.get("encoder", model_config.get("ddconfig", {})) decoder_config = model_config.get("decoder", encoder_config) # Load mel spectrogram parameters self.mel_bins = encoder_config.get("mel_bins", 64) - self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) - self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) + self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) + self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) # Store causality configuration at VAE level (not just in encoder internals) - causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value) + causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value) self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value) self.is_causal = self.causality_axis == CausalityAxis.HEIGHT @@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module): self.per_channel_statistics = processor() - def _guess_config(self): - encoder_config = { - # Required parameters - based on ltx-video-av-1679000 model metadata - "ch": 128, - "out_ch": 8, - "ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8] - "num_res_blocks": 2, - "attn_resolutions": [], # Based on metadata: empty list, no attention - "dropout": 0.0, - "resamp_with_conv": True, - "in_channels": 2, # stereo - "resolution": 256, - "z_channels": 8, + def get_default_config(self): + ddconfig = { "double_z": True, - "attn_type": "vanilla", - "mid_block_add_attention": False, # Based on metadata: false + "mel_bins": 64, + "z_channels": 8, + "resolution": 256, + "downsample_time": False, + "in_channels": 2, + "out_ch": 2, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + "mid_block_add_attention": False, "norm_type": "pixel", - "causality_axis": "height", # Based on metadata - "mel_bins": 64, # Based on metadata: mel_bins = 64 - } - - decoder_config = { - # Inherits encoder config, can override specific params - **encoder_config, - "out_ch": 2, # Stereo audio output (2 channels) - "give_pre_end": False, - "tanh_out": False, + "causality_axis": "height", } config = { - "_class_name": "CausalAudioAutoencoder", - "sampling_rate": 16000, "model": { "params": { - "encoder": encoder_config, - "decoder": decoder_config, + "ddconfig": ddconfig, + "sampling_rate": 16000, } }, + "preprocessing": { + "stft": { + "filter_length": 1024, + "hop_length": 160, + }, + }, } return config diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index cbfdf412d..5b57dfc5e 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init +def in_meta_context(): + return torch.device("meta") == torch.empty(0).device + def mark_conv3d_ended(module): tid = threading.get_ident() for _, m in module.named_modules(): @@ -350,6 +353,10 @@ class Decoder(nn.Module): output_channel = output_channel * block_params.get("multiplier", 2) if block_name == "compress_all": output_channel = output_channel * block_params.get("multiplier", 1) + if block_name == "compress_space": + output_channel = output_channel * block_params.get("multiplier", 1) + if block_name == "compress_time": + output_channel = output_channel * block_params.get("multiplier", 1) self.conv_in = make_conv_nd( dims, @@ -395,17 +402,21 @@ class Decoder(nn.Module): spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_time": + output_channel = output_channel // block_params.get("multiplier", 1) block = DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(2, 1, 1), + out_channels_reduction_factor=block_params.get("multiplier", 1), spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_space": + output_channel = output_channel // block_params.get("multiplier", 1) block = DepthToSpaceUpsample( dims=dims, in_channels=input_channel, stride=(1, 2, 2), + out_channels_reduction_factor=block_params.get("multiplier", 1), spatial_padding_mode=spatial_padding_mode, ) elif block_name == "compress_all": @@ -455,6 +466,15 @@ class Decoder(nn.Module): output_channel * 2, 0, operations=ops, ) self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel)) + else: + self.register_buffer( + "last_scale_shift_table", + torch.tensor( + [0.0, 0.0], + device="cpu" if in_meta_context() else None + ).unsqueeze(1).expand(2, output_channel), + persistent=False, + ) # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: @@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module): self.scale_shift_table = nn.Parameter( torch.randn(4, in_channels) / in_channels**0.5 ) + else: + self.register_buffer( + "scale_shift_table", + torch.tensor( + [0.0, 0.0, 0.0, 0.0], + device="cpu" if in_meta_context() else None + ).unsqueeze(1).expand(4, in_channels), + persistent=False, + ) self.temporal_cache_state={} @@ -1012,9 +1041,6 @@ class processor(nn.Module): super().__init__() self.register_buffer("std-of-means", torch.empty(128)) self.register_buffer("mean-of-means", torch.empty(128)) - self.register_buffer("mean-of-stds", torch.empty(128)) - self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128)) - self.register_buffer("channel", torch.empty(128)) def un_normalize(self, x): return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x) @@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module): super().__init__() if config is None: - config = self.guess_config(version) + config = self.get_default_config(version) + self.config = config self.timestep_conditioning = config.get("timestep_conditioning", False) + self.decode_noise_scale = config.get("decode_noise_scale", 0.025) + self.decode_timestep = config.get("decode_timestep", 0.05) double_z = config.get("double_z", True) latent_log_var = config.get( "latent_log_var", "per_channel" if double_z else "none" @@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module): latent_log_var=latent_log_var, norm_layer=config.get("norm_layer", "group_norm"), spatial_padding_mode=config.get("spatial_padding_mode", "zeros"), + base_channels=config.get("encoder_base_channels", 128), ) self.decoder = Decoder( @@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module): in_channels=config["latent_channels"], out_channels=config.get("out_channels", 3), blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))), + base_channels=config.get("decoder_base_channels", 128), patch_size=config.get("patch_size", 1), norm_layer=config.get("norm_layer", "group_norm"), causal=config.get("causal_decoder", False), @@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module): self.per_channel_statistics = processor() - def guess_config(self, version): + def get_default_config(self, version): if version == 0: config = { "_class_name": "CausalVideoAutoencoder", @@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module): means, logvar = torch.chunk(self.encoder(x), 2, dim=1) return self.per_channel_statistics.normalize(means) - def decode(self, x, timestep=0.05, noise_scale=0.025): + def decode(self, x): if self.timestep_conditioning: #TODO: seed - x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x - return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep) - + x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x + return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py index b1f15f2c5..6c4028aa8 100644 --- a/comfy/ldm/lightricks/vocoders/vocoder.py +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -3,6 +3,7 @@ import torch.nn.functional as F import torch.nn as nn import comfy.ops import numpy as np +import math ops = comfy.ops.disable_weight_init @@ -12,6 +13,307 @@ def get_padding(kernel_size, dilation=1): return int((kernel_size * dilation - dilation) / 2) +# --------------------------------------------------------------------------- +# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2 +# Adopted from https://github.com/NVIDIA/BigVGAN +# --------------------------------------------------------------------------- + + +def _sinc(x: torch.Tensor): + return torch.where( + x == 0, + torch.tensor(1.0, device=x.device, dtype=x.dtype), + torch.sin(math.pi * x) / math.pi / x, + ) + + +def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): + even = kernel_size % 2 == 0 + half_size = kernel_size // 2 + delta_f = 4 * half_width + A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95 + if A > 50.0: + beta = 0.1102 * (A - 8.7) + elif A >= 21.0: + beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0) + else: + beta = 0.0 + window = torch.kaiser_window(kernel_size, beta=beta, periodic=False) + if even: + time = torch.arange(-half_size, half_size) + 0.5 + else: + time = torch.arange(kernel_size) - half_size + if cutoff == 0: + filter_ = torch.zeros_like(time) + else: + filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time) + filter_ /= filter_.sum() + filter = filter_.view(1, 1, kernel_size) + return filter + + +class LowPassFilter1d(nn.Module): + def __init__( + self, + cutoff=0.5, + half_width=0.6, + stride=1, + padding=True, + padding_mode="replicate", + kernel_size=12, + ): + super().__init__() + if cutoff < -0.0: + raise ValueError("Minimum cutoff must be larger than zero.") + if cutoff > 0.5: + raise ValueError("A cutoff above 0.5 does not make sense.") + self.kernel_size = kernel_size + self.even = kernel_size % 2 == 0 + self.pad_left = kernel_size // 2 - int(self.even) + self.pad_right = kernel_size // 2 + self.stride = stride + self.padding = padding + self.padding_mode = padding_mode + filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size) + self.register_buffer("filter", filter) + + def forward(self, x): + _, C, _ = x.shape + if self.padding: + x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode) + return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C) + + +class UpSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"): + super().__init__() + self.ratio = ratio + self.stride = ratio + + if window_type == "hann": + # Hann-windowed sinc filter — identical to torchaudio.functional.resample + # with its default parameters (rolloff=0.99, lowpass_filter_width=6). + # Uses replicate boundary padding, matching the reference resampler exactly. + rolloff = 0.99 + lowpass_filter_width = 6 + width = math.ceil(lowpass_filter_width / rolloff) + self.kernel_size = 2 * width * ratio + 1 + self.pad = width + self.pad_left = 2 * width * ratio + self.pad_right = self.kernel_size - ratio + t = (torch.arange(self.kernel_size) / ratio - width) * rolloff + t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width) + window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2 + filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1) + else: + # Kaiser-windowed sinc filter (BigVGAN default). + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.pad = self.kernel_size // ratio - 1 + self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2 + self.pad_right = ( + self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2 + ) + filter = kaiser_sinc_filter1d( + cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size + ) + + self.register_buffer("filter", filter, persistent=persistent) + + def forward(self, x): + _, C, _ = x.shape + x = F.pad(x, (self.pad, self.pad), mode="replicate") + x = self.ratio * F.conv_transpose1d( + x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C + ) + x = x[..., self.pad_left : -self.pad_right] + return x + + +class DownSample1d(nn.Module): + def __init__(self, ratio=2, kernel_size=None): + super().__init__() + self.ratio = ratio + self.kernel_size = ( + int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size + ) + self.lowpass = LowPassFilter1d( + cutoff=0.5 / ratio, + half_width=0.6 / ratio, + stride=ratio, + kernel_size=self.kernel_size, + ) + + def forward(self, x): + return self.lowpass(x) + + +class Activation1d(nn.Module): + def __init__( + self, + activation, + up_ratio=2, + down_ratio=2, + up_kernel_size=12, + down_kernel_size=12, + ): + super().__init__() + self.act = activation + self.upsample = UpSample1d(up_ratio, up_kernel_size) + self.downsample = DownSample1d(down_ratio, down_kernel_size) + + def forward(self, x): + x = self.upsample(x) + x = self.act(x) + x = self.downsample(x) + return x + + +# --------------------------------------------------------------------------- +# BigVGAN v2 activations (Snake / SnakeBeta) +# --------------------------------------------------------------------------- + + +class Snake(nn.Module): + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True + ): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.alpha.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x): + a = self.alpha.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + a = torch.exp(a) + return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2) + + +class SnakeBeta(nn.Module): + def __init__( + self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True + ): + super().__init__() + self.alpha_logscale = alpha_logscale + self.alpha = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.alpha.requires_grad = alpha_trainable + self.beta = nn.Parameter( + torch.zeros(in_features) + if alpha_logscale + else torch.ones(in_features) * alpha + ) + self.beta.requires_grad = alpha_trainable + self.eps = 1e-9 + + def forward(self, x): + a = self.alpha.unsqueeze(0).unsqueeze(-1) + b = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + a = torch.exp(a) + b = torch.exp(b) + return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2) + + +# --------------------------------------------------------------------------- +# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity) +# --------------------------------------------------------------------------- + + +class AMPBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"): + super().__init__() + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.convs1 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ] + ) + + self.acts1 = nn.ModuleList( + [Activation1d(act_cls(channels)) for _ in range(len(self.convs1))] + ) + self.acts2 = nn.ModuleList( + [Activation1d(act_cls(channels)) for _ in range(len(self.convs2))] + ) + + def forward(self, x): + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = x + xt + return x + + +# --------------------------------------------------------------------------- +# HiFi-GAN residual blocks +# --------------------------------------------------------------------------- + + class ResBlock1(torch.nn.Module): def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): super(ResBlock1, self).__init__() @@ -119,6 +421,7 @@ class Vocoder(torch.nn.Module): """ Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan. + Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1"). """ def __init__(self, config=None): @@ -128,19 +431,39 @@ class Vocoder(torch.nn.Module): config = self.get_default_config() resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11]) - upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2]) - upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]) + upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2]) + upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4]) resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) upsample_initial_channel = config.get("upsample_initial_channel", 1024) stereo = config.get("stereo", True) - resblock = config.get("resblock", "1") + activation = config.get("activation", "snake") + use_bias_at_final = config.get("use_bias_at_final", True) + + # "output_sample_rate" is not present in recent checkpoint configs. + # When absent (None), AudioVAE.output_sample_rate computes it as: + # sample_rate * vocoder.upsample_factor / mel_hop_length + # where upsample_factor = product of all upsample stride lengths, + # and mel_hop_length is loaded from the autoencoder config at + # preprocessing.stft.hop_length (see CausalAudioAutoencoder). self.output_sample_rate = config.get("output_sample_rate") + self.resblock = config.get("resblock", "1") + self.use_tanh_at_final = config.get("use_tanh_at_final", True) + self.apply_final_activation = config.get("apply_final_activation", True) self.num_kernels = len(resblock_kernel_sizes) self.num_upsamples = len(upsample_rates) + in_channels = 128 if stereo else 64 self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) - resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + if self.resblock == "1": + resblock_cls = ResBlock1 + elif self.resblock == "2": + resblock_cls = ResBlock2 + elif self.resblock == "AMP1": + resblock_cls = AMPBlock1 + else: + raise ValueError(f"Unknown resblock type: {self.resblock}") self.ups = nn.ModuleList() for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): @@ -157,25 +480,40 @@ class Vocoder(torch.nn.Module): self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = upsample_initial_channel // (2 ** (i + 1)) - for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): - self.resblocks.append(resblock_class(ch, k, d)) + for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes): + if self.resblock == "AMP1": + self.resblocks.append(resblock_cls(ch, k, d, activation=activation)) + else: + self.resblocks.append(resblock_cls(ch, k, d)) out_channels = 2 if stereo else 1 - self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3) + if self.resblock == "AMP1": + act_cls = SnakeBeta if activation == "snakebeta" else Snake + self.act_post = Activation1d(act_cls(ch)) + else: + self.act_post = nn.LeakyReLU() + + self.conv_post = ops.Conv1d( + ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final + ) self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))]) + def get_default_config(self): """Generate default configuration for the vocoder.""" config = { "resblock_kernel_sizes": [3, 7, 11], - "upsample_rates": [6, 5, 2, 2, 2], - "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], "upsample_initial_channel": 1024, "stereo": True, "resblock": "1", + "activation": "snake", + "use_bias_at_final": True, + "use_tanh_at_final": True, } return config @@ -196,8 +534,10 @@ class Vocoder(torch.nn.Module): assert x.shape[1] == 2, "Input must have 2 channels for stereo" x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1) x = self.conv_pre(x) + for i in range(self.num_upsamples): - x = F.leaky_relu(x, LRELU_SLOPE) + if self.resblock != "AMP1": + x = F.leaky_relu(x, LRELU_SLOPE) x = self.ups[i](x) xs = None for j in range(self.num_kernels): @@ -206,8 +546,167 @@ class Vocoder(torch.nn.Module): else: xs += self.resblocks[i * self.num_kernels + j](x) x = xs / self.num_kernels - x = F.leaky_relu(x) + + x = self.act_post(x) x = self.conv_post(x) - x = torch.tanh(x) + + if self.apply_final_activation: + if self.use_tanh_at_final: + x = torch.tanh(x) + else: + x = torch.clamp(x, -1, 1) return x + + +class _STFTFn(nn.Module): + """Implements STFT as a convolution with precomputed DFT × Hann-window bases. + + The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal + Hann window are stored as buffers and loaded from the checkpoint. Using the exact + bfloat16 bases from training ensures the mel values fed to the BWE generator are + bit-identical to what it was trained on. + """ + + def __init__(self, filter_length: int, hop_length: int, win_length: int): + super().__init__() + self.hop_length = hop_length + self.win_length = win_length + n_freqs = filter_length // 2 + 1 + self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length)) + + def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Compute magnitude and phase spectrogram from a batch of waveforms. + + Applies causal (left-only) padding of win_length - hop_length samples so that + each output frame depends only on past and present input — no lookahead. + The STFT is computed by convolving the padded signal with forward_basis. + + Args: + y: Waveform tensor of shape (B, T). + + Returns: + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + Computed in float32 for numerical stability, then cast back to + the input dtype. + """ + if y.dim() == 2: + y = y.unsqueeze(1) # (B, 1, T) + left_pad = max(0, self.win_length - self.hop_length) # causal: left-only + y = F.pad(y, (left_pad, 0)) + spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0) + n_freqs = spec.shape[1] // 2 + real, imag = spec[:, :n_freqs], spec[:, n_freqs:] + magnitude = torch.sqrt(real ** 2 + imag ** 2) + phase = torch.atan2(imag.float(), real.float()).to(real.dtype) + return magnitude, phase + + +class MelSTFT(nn.Module): + """Causal log-mel spectrogram module whose buffers are loaded from the checkpoint. + + Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input + waveform and projecting the linear magnitude spectrum onto the mel filterbank. + + The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint + (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis). + """ + + def __init__( + self, + filter_length: int, + hop_length: int, + win_length: int, + n_mel_channels: int, + sampling_rate: int, + mel_fmin: float, + mel_fmax: float, + ): + super().__init__() + self.stft_fn = _STFTFn(filter_length, hop_length, win_length) + + n_freqs = filter_length // 2 + 1 + self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs)) + + def mel_spectrogram( + self, y: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute log-mel spectrogram and auxiliary spectral quantities. + + Args: + y: Waveform tensor of shape (B, T). + + Returns: + log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames). + Computed as log(clamp(mel_basis @ magnitude, min=1e-5)). + magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames). + phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames). + energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames). + """ + magnitude, phase = self.stft_fn(y) + energy = torch.norm(magnitude, dim=1) + mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude) + log_mel = torch.log(torch.clamp(mel, min=1e-5)) + return log_mel, magnitude, phase, energy + + +class VocoderWithBWE(torch.nn.Module): + """Vocoder with bandwidth extension (BWE) for higher sample rate output. + + Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples + to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform. + """ + + def __init__(self, config): + super().__init__() + vocoder_config = config["vocoder"] + bwe_config = config["bwe"] + + self.vocoder = Vocoder(config=vocoder_config) + self.bwe_generator = Vocoder( + config={**bwe_config, "apply_final_activation": False} + ) + + self.input_sample_rate = bwe_config["input_sampling_rate"] + self.output_sample_rate = bwe_config["output_sampling_rate"] + self.hop_length = bwe_config["hop_length"] + + self.mel_stft = MelSTFT( + filter_length=bwe_config["n_fft"], + hop_length=bwe_config["hop_length"], + win_length=bwe_config["n_fft"], + n_mel_channels=bwe_config["num_mels"], + sampling_rate=bwe_config["input_sampling_rate"], + mel_fmin=0.0, + mel_fmax=bwe_config["input_sampling_rate"] / 2.0, + ) + self.resampler = UpSample1d( + ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"], + persistent=False, + window_type="hann", + ) + + def _compute_mel(self, audio): + """Compute log-mel spectrogram from waveform using causal STFT bases.""" + B, C, T = audio.shape + flat = audio.reshape(B * C, -1) # (B*C, T) + mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames) + return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames) + + def forward(self, mel_spec): + x = self.vocoder(mel_spec) + _, _, T_low = x.shape + T_out = T_low * self.output_sample_rate // self.input_sample_rate + + remainder = T_low % self.hop_length + if remainder != 0: + x = F.pad(x, (0, self.hop_length - remainder)) + + mel = self._compute_mel(x) + residual = self.bwe_generator(mel) + skip = self.resampler(x) + assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}" + + return torch.clamp(residual + skip, -1, 1)[..., :T_out] diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 77d1abc97..9e432d5c0 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -14,6 +14,7 @@ from comfy.ldm.flux.layers import EmbedND from comfy.ldm.flux.math import apply_rope import comfy.patcher_extension import comfy.utils +from comfy.ldm.chroma_radiance.layers import NerfEmbedder def invert_slices(slices, length): @@ -858,3 +859,267 @@ class NextDiT(nn.Module): img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] return -img + +############################################################################# +# Pixel Space Decoder Components # +############################################################################# + +def _modulate_shift_scale(x, shift, scale): + return x * (1 + scale) + shift + + +class PixelResBlock(nn.Module): + """ + Residual block with AdaLN modulation, zero-initialised so it starts as + an identity at the beginning of training. + """ + + def __init__(self, channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device) + self.mlp = nn.Sequential( + operations.Linear(channels, channels, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(channels, channels, bias=True, dtype=dtype, device=device), + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1) + h = _modulate_shift_scale(self.in_ln(x), shift, scale) + h = self.mlp(h) + return x + gate * h + + +class DCTFinalLayer(nn.Module): + """Zero-initialised output projection (adopted from DiT).""" + + def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(self.norm_final(x)) + + +class SimpleMLPAdaLN(nn.Module): + """ + Small MLP decoder head for the pixel-space variant. + + Takes per-patch pixel values and a per-patch conditioning vector from the + transformer backbone and predicts the denoised pixel values. + + x : [B*N, P^2, C] – noisy pixel values per patch position + c : [B*N, dim] – backbone hidden state per patch (conditioning) + → [B*N, P^2, C] + """ + + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + z_channels: int, + num_res_blocks: int, + max_freqs: int = 8, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + + # Project backbone hidden state → per-patch conditioning + self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device) + + # Input projection with DCT positional encoding + self.input_embedder = NerfEmbedder( + in_channels=in_channels, + hidden_size_input=model_channels, + max_freqs=max_freqs, + dtype=dtype, + device=device, + operations=operations, + ) + + # Residual blocks + self.res_blocks = nn.ModuleList([ + PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks) + ]) + + # Output projection + self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + # x: [B*N, 1, P^2*C], c: [B*N, dim] + original_dtype = x.dtype + weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype) + x = self.input_embedder(x) # [B*N, 1, model_channels] + y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels] + x = x.to(weight_dtype) + for block in self.res_blocks: + x = block(x, y) + return self.final_layer(x).to(original_dtype) # [B*N, 1, P^2*C] + + +############################################################################# +# NextDiT – Pixel Space # +############################################################################# + +class NextDiTPixelSpace(NextDiT): + """ + Pixel-space variant of NextDiT. + + Identical transformer backbone to NextDiT, but the output head is replaced + with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values + per patch rather than a single affine projection. + + Key differences vs NextDiT: + • ``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead. + • ``_forward`` stores the raw patchified pixel values before the backbone + embedding and feeds them to ``dec_net`` together with the per-patch + backbone hidden states. + • Supports optional x0 prediction via ``use_x0``. + """ + + def __init__( + self, + # decoder-specific + decoder_hidden_size: int = 3840, + decoder_num_res_blocks: int = 4, + decoder_max_freqs: int = 8, + decoder_in_channels: int = None, # full flattened patch size (patch_size^2 * in_channels) + use_x0: bool = False, + # all NextDiT args forwarded unchanged + **kwargs, + ): + super().__init__(**kwargs) + + # Remove the latent-space final layer – not used in pixel space + del self.final_layer + + patch_size = kwargs.get("patch_size", 2) + in_channels = kwargs.get("in_channels", 4) + dim = kwargs.get("dim", 4096) + + # decoder_in_channels is the full flattened patch: patch_size^2 * in_channels + dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels + + self.dec_net = SimpleMLPAdaLN( + in_channels=dec_in_ch, + model_channels=decoder_hidden_size, + out_channels=dec_in_ch, + z_channels=dim, + num_res_blocks=decoder_num_res_blocks, + max_freqs=decoder_max_freqs, + dtype=kwargs.get("dtype"), + device=kwargs.get("device"), + operations=kwargs.get("operations"), + ) + + if use_x0: + self.register_buffer("__x0__", torch.tensor([])) + + # ------------------------------------------------------------------ + # Forward — mirrors NextDiT._forward exactly, replacing final_layer + # with the pixel-space dec_net decoder. + # ------------------------------------------------------------------ + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs): + omni = len(ref_latents) > 0 + if omni: + timesteps = torch.cat([timesteps * 0, timesteps], dim=0) + + t = 1.0 - timesteps + cap_feats = context + cap_mask = attention_mask + bs, c, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + + t = self.t_embedder(t * self.time_scale, dtype=x.dtype) + adaln_input = t + + if self.clip_text_pooled_proj is not None: + pooled = kwargs.get("clip_text_pooled", None) + if pooled is not None: + pooled = self.clip_text_pooled_proj(pooled) + else: + pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype) + adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) + + # ---- capture raw pixel patches before patchify_and_embed embeds them ---- + pH = pW = self.patch_size + B, C, H, W = x.shape + pixel_patches = ( + x.view(B, C, H // pH, pH, W // pW, pW) + .permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C] + .flatten(3) # [B, Ht, Wt, pH*pW*C] + .flatten(1, 2) # [B, N, pH*pW*C] + ) + N = pixel_patches.shape[1] + # decoder sees one token per patch: [B*N, 1, P^2*C] + pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C) + + patches = transformer_options.get("patches", {}) + x_is_tensor = isinstance(x, torch.Tensor) + img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed( + x, cap_feats, cap_mask, adaln_input, num_tokens, + ref_latents=ref_latents, ref_contexts=ref_contexts, + siglip_feats=siglip_feats, transformer_options=transformer_options + ) + freqs_cis = freqs_cis.to(img.device) + + transformer_options["total_blocks"] = len(self.layers) + transformer_options["block_type"] = "double" + img_input = img + for i, layer in enumerate(self.layers): + transformer_options["block_index"] = i + img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + if "img" in out: + img[:, cap_size[0]:] = out["img"] + if "txt" in out: + img[:, :cap_size[0]] = out["txt"] + + # ---- pixel-space decoder (replaces final_layer + unpatchify) ---- + # img may have padding tokens beyond N; only the first N are real image patches + img_hidden = img[:, cap_size[0]:cap_size[0] + N, :] # [B, N, dim] + decoder_cond = img_hidden.reshape(B * N, self.dim) # [B*N, dim] + + output = self.dec_net(pixel_values, decoder_cond) # [B*N, 1, P^2*C] + output = output.reshape(B, N, -1) # [B, N, P^2*C] + + # prepend zero cap placeholder so unpatchify indexing works unchanged + cap_placeholder = torch.zeros( + B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype + ) + img_out = self.unpatchify( + torch.cat([cap_placeholder, output], dim=1), + img_size, cap_size, return_tensor=x_is_tensor + )[:, :, :h, :w] + + return -img_out + + def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + # _forward returns neg_x0 = -x0 (negated decoder output). + # + # Reference inference (working_inference_reference.py): + # out = _forward(img, t) # = -x0 + # pred = (img - out) / t # = (img + x0) / t [_apply_x0_residual] + # img += (t_prev - t_curr) * pred # Euler step + # + # ComfyUI's Euler sampler does the same: + # x_next = x + (sigma_next - sigma) * model_output + # So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t + neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {})) + ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) + + return (x - neg_x0) / timesteps.view(-1, 1, 1, 1) diff --git a/comfy/model_base.py b/comfy/model_base.py index a1c690b9b..d9d5a9293 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1021,7 +1021,7 @@ class LTXAV(BaseModel): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: if hasattr(self.diffusion_model, "preprocess_text_embeds"): - cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference())) + cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False)) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) @@ -1263,6 +1263,11 @@ class Lumina2(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) return out +class ZImagePixelSpace(Lumina2): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) + self.memory_usage_factor_conds = ("ref_latents",) + class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 3faa950ca..6eace4628 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -423,7 +423,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" return dit_config - if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 + if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 dit_config = {} dit_config["image_model"] = "lumina2" dit_config["patch_size"] = 2 @@ -464,6 +464,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if sig_weight is not None: dit_config["siglip_feat_dim"] = sig_weight.shape[0] + dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix) + if dec_cond_key in state_dict_keys: # pixel-space variant + dit_config["image_model"] = "zimage_pixel" + # patch_size and in_channels are derived from x_embedder: + # x_embedder: Linear(patch_size * patch_size * in_channels, dim) + # The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim. + x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1] + dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0] + # patch_size: infer from decoder final layer output matching x_embedder input + # in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2) + embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)] + dec_in_ch = dec_out # decoder in == decoder out (same pixel space) + dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3) + dit_config["in_channels"] = 3 + dit_config["decoder_in_channels"] = dec_in_ch + dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0] + dit_config["decoder_num_res_blocks"] = count_blocks( + state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.' + ) + dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5) + if '{}__x0__'.format(key_prefix) in state_dict_keys: + dit_config["use_x0"] = True + return dit_config if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 @@ -533,8 +556,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys return dit_config - if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys: # Hunyuan 3D 2.1 - + if f"{key_prefix}t_embedder.mlp.2.weight" in state_dict_keys and f"{key_prefix}blocks.0.attn1.k_norm.weight" in state_dict_keys: # Hunyuan 3D 2.1 dit_config = {} dit_config["image_model"] = "hunyuan3d2_1" dit_config["in_channels"] = state_dict[f"{key_prefix}x_embedder.weight"].shape[1] @@ -1055,6 +1077,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix) + elif 'noise_refiner.0.attention.norm_k.weight' in state_dict: + n_layers = count_blocks(state_dict, 'layers.{}.') + dim = state_dict['noise_refiner.0.attention.to_k.weight'].shape[0] + sd_map = comfy.utils.z_image_to_diffusers({"n_layers": n_layers, "dim": dim}, output_prefix=output_prefix) + for k in state_dict: # For zeta chroma + if k not in sd_map: + sd_map[k] = k elif 'x_embedder.weight' in state_dict: #Flux depth = count_blocks(state_dict, 'transformer_blocks.{}.') depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') diff --git a/comfy/model_management.py b/comfy/model_management.py index c817d43b5..ee28ea107 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -32,9 +32,6 @@ import comfy.memory_management import comfy.utils import comfy.quant_ops -import comfy_aimdo.torch -import comfy_aimdo.model_vbar - class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram NO_VRAM = 1 #Very low vram: enable all the options to save vram @@ -799,6 +796,8 @@ def archive_model_dtypes(model): for name, module in model.named_modules(): for param_name, param in module.named_parameters(recurse=False): setattr(module, f"{param_name}_comfy_model_dtype", param.dtype) + for buf_name, buf in module.named_buffers(recurse=False): + setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype) def cleanup_models(): @@ -831,11 +830,14 @@ def unet_offload_device(): return torch.device("cpu") def unet_inital_load_device(parameters, dtype): + cpu_dev = torch.device("cpu") + if comfy.memory_management.aimdo_enabled: + return cpu_dev + torch_dev = get_torch_device() if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED: return torch_dev - cpu_dev = torch.device("cpu") if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM: return cpu_dev @@ -843,7 +845,7 @@ def unet_inital_load_device(parameters, dtype): mem_dev = get_free_memory(torch_dev) mem_cpu = get_free_memory(cpu_dev) - if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled: + if mem_dev > mem_cpu and model_size < mem_dev: return torch_dev else: return cpu_dev @@ -946,6 +948,9 @@ def text_encoder_device(): return torch.device("cpu") def text_encoder_initial_device(load_device, offload_device, model_size=0): + if comfy.memory_management.aimdo_enabled: + return offload_device + if load_device == offload_device or model_size <= 1024 * 1024 * 1024: return offload_device @@ -1206,43 +1211,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None): def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): - if hasattr(weight, "_v"): - #Unexpected usage patterns. There is no reason these don't work but they - #have no testing and no callers do this. - assert r is None - assert stream is None - - cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ]) - - if dtype is None: - dtype = weight._model_dtype - - signature = comfy_aimdo.model_vbar.vbar_fault(weight._v) - if signature is not None: - if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature): - v_tensor = weight._v_tensor - else: - raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device) - v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0] - weight._v_tensor = v_tensor - weight._v_signature = signature - #Send it over - v_tensor.copy_(weight, non_blocking=non_blocking) - return v_tensor.to(dtype=dtype) - - r = torch.empty_like(weight, dtype=dtype, device=device) - - if weight.dtype != r.dtype and weight.dtype != weight._model_dtype: - #Offloaded casting could skip this, however it would make the quantizations - #inconsistent between loaded and offloaded weights. So force the double casting - #that would happen in regular flow to make offload deterministic. - cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device) - cast_buffer.copy_(weight, non_blocking=non_blocking) - weight = cast_buffer - r.copy_(weight, non_blocking=non_blocking) - - return r - if device is None or weight.device == device: if not copy: if dtype is None or weight.dtype == dtype: @@ -1698,12 +1666,16 @@ def lora_compute_dtype(device): return dtype def synchronize(): + if cpu_mode(): + return if is_intel_xpu(): torch.xpu.synchronize() elif torch.cuda.is_available(): torch.cuda.synchronize() def soft_empty_cache(force=False): + if cpu_mode(): + return global cpu_state if cpu_state == CPUState.MPS: torch.mps.empty_cache() diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3fc76d9db..7e5ad7aa4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -241,6 +241,7 @@ class ModelPatcher: self.patches = {} self.backup = {} + self.backup_buffers = {} self.object_patches = {} self.object_patches_backup = {} self.weight_wrapper_patches = {} @@ -306,10 +307,16 @@ class ModelPatcher: return self.model.lowvram_patch_counter def get_free_memory(self, device): - return comfy.model_management.get_free_memory(device) + #Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In + #the vast majority of setups a little bit of offloading on the giant model more + #than pays for CFG. So return everything both torch and Aimdo could give us + aimdo_mem = 0 + if comfy.memory_management.aimdo_enabled: + aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze() + return comfy.model_management.get_free_memory(device) + aimdo_mem def get_clone_model_override(self): - return self.model, (self.backup, self.object_patches_backup, self.pinned) + return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned) def clone(self, disable_dynamic=False, model_override=None): class_ = self.__class__ @@ -336,7 +343,7 @@ class ModelPatcher: n.force_cast_weights = self.force_cast_weights - n.backup, n.object_patches_backup, n.pinned = model_override[1] + n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1] # attachments n.attachments = {} @@ -698,7 +705,7 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self, prio_comfy_cast_weights=False, default_device=None): + def _load_list(self, for_dynamic=False, default_device=None): loading = [] for n, m in self.model.named_modules(): default = False @@ -726,8 +733,13 @@ class ModelPatcher: return 0 module_offload_mem += check_module_offload_mem("{}.weight".format(n)) module_offload_mem += check_module_offload_mem("{}.bias".format(n)) - prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else () - loading.append(prepend + (module_offload_mem, module_mem, n, m, params)) + # Dynamic: small weights (<64KB) first, then larger weights prioritized by size. + # Non-dynamic: prioritize by module offload cost. + if for_dynamic: + sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem) + else: + sort_criteria = (module_offload_mem,) + loading.append(sort_criteria + (module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -1435,10 +1447,6 @@ class ModelPatcherDynamic(ModelPatcher): def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): super().__init__(model, load_device, offload_device, size, weight_inplace_update) - #this is now way more dynamic and we dont support the same base model for both Dynamic - #and non-dynamic patchers. - if hasattr(self.model, "model_loaded_weight_memory"): - del self.model.model_loaded_weight_memory if not hasattr(self.model, "dynamic_vbars"): self.model.dynamic_vbars = {} self.non_dynamic_delegate_model = None @@ -1461,15 +1469,7 @@ class ModelPatcherDynamic(ModelPatcher): def loaded_size(self): vbar = self._vbar_get() - if vbar is None: - return 0 - return vbar.loaded_size() - - def get_free_memory(self, device): - #NOTE: on high condition / batch counts, estimate should have already vacated - #all non-dynamic models so this is safe even if its not 100% true that this - #would all be avaiable for inference use. - return comfy.model_management.get_total_memory(device) - self.model_size() + return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory #Pinning is deferred to ops time. Assert against this API to avoid pin leaks. @@ -1504,6 +1504,7 @@ class ModelPatcherDynamic(ModelPatcher): num_patches = 0 allocated_size = 0 + self.model.model_loaded_weight_memory = 0 with self.use_ejected(): self.unpatch_hooks() @@ -1512,15 +1513,11 @@ class ModelPatcherDynamic(ModelPatcher): if vbar is not None: vbar.prioritize() - #We force reserve VRAM for the non comfy-weight so we dont have to deal - #with pin and unpin syncrhonization which can be expensive for small weights - #with a high layer rate (e.g. autoregressive LLMs). - #prioritize the non-comfy weights (note the order reverse). - loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to) - loading.sort(reverse=True) + loading = self._load_list(for_dynamic=True, default_device=device_to) + loading.sort() for x in loading: - _, _, _, n, m, params = x + *_, module_mem, n, m, params = x def set_dirty(item, dirty): if dirty or not hasattr(item, "_v_signature"): @@ -1558,6 +1555,9 @@ class ModelPatcherDynamic(ModelPatcher): if key in self.backup: comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) self.patch_weight_to_device(key, device_to=device_to) + weight, _, _ = get_key_weight(self.model, key) + if weight is not None: + self.model.model_loaded_weight_memory += weight.numel() * weight.element_size() if hasattr(m, "comfy_cast_weights"): m.comfy_cast_weights = True @@ -1583,21 +1583,26 @@ class ModelPatcherDynamic(ModelPatcher): for param in params: key = key_param_name_to_key(n, param) weight, _, _ = get_key_weight(self.model, key) - weight.seed_key = key - set_dirty(weight, dirty) - geometry = weight - model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype - geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) - weight_size = geometry.numel() * geometry.element_size() - if vbar is not None and not hasattr(weight, "_v"): - weight._v = vbar.alloc(weight_size) - weight._model_dtype = model_dtype - allocated_size += weight_size - vbar.set_watermark_limit(allocated_size) + if key not in self.backup: + self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False) + model_dtype = getattr(m, param + "_comfy_model_dtype", None) + casted_weight = weight.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_param(self.model, key, casted_weight) + self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size() move_weight_functions(m, device_to) - logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.") + for key, buf in self.model.named_buffers(recurse=True): + if key not in self.backup_buffers: + self.backup_buffers[key] = buf + module, buf_name = comfy.utils.resolve_attr(self.model, key) + model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None) + casted_buf = buf.to(dtype=model_dtype, device=device_to) + comfy.utils.set_attr_buffer(self.model, key, casted_buf) + self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size() + + force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else "" + logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}") self.model.device = device_to self.model.current_weight_patches_uuid = self.patches_uuid @@ -1613,12 +1618,23 @@ class ModelPatcherDynamic(ModelPatcher): assert self.load_device != torch.device("cpu") vbar = self._vbar_get() - return 0 if vbar is None else vbar.free_memory(memory_to_free) + freed = 0 if vbar is None else vbar.free_memory(memory_to_free) + + if freed < memory_to_free: + for key in list(self.backup.keys()): + bk = self.backup.pop(key) + comfy.utils.set_attr_param(self.model, key, bk.weight) + for key in list(self.backup_buffers.keys()): + comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key)) + freed += self.model.model_loaded_weight_memory + self.model.model_loaded_weight_memory = 0 + + return freed def partially_unload_ram(self, ram_to_unload): - loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device) + loading = self._load_list(for_dynamic=True, default_device=self.offload_device) for x in loading: - _, _, _, _, m, _ = x + *_, m, _ = x ram_to_unload -= comfy.pinned_memory.unpin_memory(m) if ram_to_unload <= 0: return @@ -1640,11 +1656,6 @@ class ModelPatcherDynamic(ModelPatcher): for m in self.model.modules(): move_weight_functions(m, device_to) - keys = list(self.backup.keys()) - for k in keys: - bk = self.backup[k] - comfy.utils.set_attr_param(self.model, k, bk.weight) - def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): assert not force_patch_weights #See above with self.use_ejected(skip_and_inject_on_exit_only=True): diff --git a/comfy/ops.py b/comfy/ops.py index 6ee6075fb..3e19cd1b6 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -269,8 +269,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream): return os, weight_a, bias_a = offload_stream device=None - #FIXME: This is not good RTTI - if not isinstance(weight_a, torch.Tensor): + #FIXME: This is really bad RTTI + if weight_a is not None and not isinstance(weight_a, torch.Tensor): comfy_aimdo.model_vbar.vbar_unpin(s._v) device = weight_a if os is None: @@ -660,23 +660,29 @@ class fp8_ops(manual_cast): CUBLAS_IS_AVAILABLE = False try: - from cublas_ops import CublasLinear + from cublas_ops import CublasLinear, cublas_half_matmul CUBLAS_IS_AVAILABLE = True except ImportError: pass if CUBLAS_IS_AVAILABLE: - class cublas_ops(disable_weight_init): - class Linear(CublasLinear, disable_weight_init.Linear): + class cublas_ops(manual_cast): + class Linear(CublasLinear, manual_cast.Linear): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): - return super().forward(input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): - return super().forward(*args, **kwargs) - + run_every_op() + if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(*args, **kwargs) + else: + return super().forward(*args, **kwargs) # ============================================================================== # Mixed Precision Operations diff --git a/comfy/sd.py b/comfy/sd.py index a9ad7c2d2..888ef1e77 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -428,7 +428,7 @@ class CLIP: def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None): self.cond_stage_model.reset_clip_options() - self.load_model() + self.load_model(tokens) self.cond_stage_model.set_clip_options({"layer": None}) self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed) @@ -1467,7 +1467,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage elif clip_type == CLIPType.LTXV: - clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data)) + clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.NEWBIE: diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 4f63e8327..07feb31b3 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1118,6 +1118,20 @@ class ZImage(Lumina2): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect)) +class ZImagePixelSpace(ZImage): + unet_config = { + "image_model": "zimage_pixel", + } + + # Pixel-space model: no spatial compression, operates on raw RGB patches. + latent_format = latent_formats.ZImagePixelSpace + + # Much lower memory than latent-space models (no VAE, small patches). + memory_usage_factor = 0.03 # TODO: figure out the optimal value for this. + + def get_model(self, state_dict, prefix="", device=None): + return model_base.ZImagePixelSpace(self, device=device) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -1720,6 +1734,6 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima] models += [SVD_img2vid] diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index e86ea9f4e..5e1273c6e 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -97,18 +97,39 @@ class Gemma3_12BModel(sd1_clip.SDClipModel): comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5) return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is +class DualLinearProjection(torch.nn.Module): + def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None): + super().__init__() + self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device) + self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device) + + def forward(self, x): + source_dim = x.shape[-1] + x = x.movedim(1, -1) + x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2) + + video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim)) + audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim)) + return torch.cat((video, audio), dim=-1) + class LTXAVTEModel(torch.nn.Module): - def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): + def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}): super().__init__() self.dtypes = set() self.dtypes.add(dtype) self.compat_mode = False + self.text_projection_type = text_projection_type self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None) self.dtypes.add(dtype_llama) operations = self.gemma3_12b.operations # TODO - self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + + if self.text_projection_type == "single_linear": + self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + elif self.text_projection_type == "dual_linear": + self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations) + def enable_compat_mode(self): # TODO: remove from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector @@ -148,18 +169,25 @@ class LTXAVTEModel(torch.nn.Module): out_device = out.device if comfy.model_management.should_use_bf16(self.execution_device): out = out.to(device=self.execution_device, dtype=torch.bfloat16) - out = out.movedim(1, -1).to(self.execution_device) - out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) - out = out.reshape((out.shape[0], out.shape[1], -1)) - out = self.text_embedding_projection(out) - out = out.float() - if self.compat_mode: - out_vid = self.video_embeddings_connector(out)[0] - out_audio = self.audio_embeddings_connector(out)[0] - out = torch.concat((out_vid, out_audio), dim=-1) + if self.text_projection_type == "single_linear": + out = out.movedim(1, -1).to(self.execution_device) + out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) + out = out.reshape((out.shape[0], out.shape[1], -1)) + out = self.text_embedding_projection(out) - return out.to(out_device), pooled + if self.compat_mode: + out_vid = self.video_embeddings_connector(out)[0] + out_audio = self.audio_embeddings_connector(out)[0] + out = torch.concat((out_vid, out_audio), dim=-1) + extra = {} + else: + extra = {"unprocessed_ltxav_embeds": True} + elif self.text_projection_type == "dual_linear": + out = self.text_embedding_projection(out) + extra = {"unprocessed_ltxav_embeds": True} + + return out.to(device=out_device, dtype=torch.float), pooled, extra def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed): return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed) @@ -168,7 +196,7 @@ class LTXAVTEModel(torch.nn.Module): if "model.layers.47.self_attn.q_norm.weight" in sd: return self.gemma3_12b.load_sd(sd) else: - sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True) + sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True) if len(sdo) == 0: sdo = sd @@ -206,7 +234,7 @@ class LTXAVTEModel(torch.nn.Module): num_tokens = max(num_tokens, 642) return num_tokens * constant * 1024 * 1024 -def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): +def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"): class LTXAVTEModel_(LTXAVTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): if llama_quantization_metadata is not None: @@ -214,9 +242,19 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama - super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options) return LTXAVTEModel_ + +def sd_detect(state_dict_list, prefix=""): + for sd in state_dict_list: + if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd: + return {"text_projection_type": "dual_linear"} + if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd: + return {"text_projection_type": "single_linear"} + return {} + + def gemma3_te(dtype_llama=None, llama_quantization_metadata=None): class Gemma3_12BModel_(Gemma3_12BModel): def __init__(self, device="cpu", dtype=None, model_options={}): diff --git a/comfy/utils.py b/comfy/utils.py index 0769cef44..6e1d14419 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024): ATTR_UNSET={} -def set_attr(obj, attr, value): +def resolve_attr(obj, attr): attrs = attr.split(".") for name in attrs[:-1]: obj = getattr(obj, name) - prev = getattr(obj, attrs[-1], ATTR_UNSET) + return obj, attrs[-1] + +def set_attr(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) if value is ATTR_UNSET: - delattr(obj, attrs[-1]) + delattr(obj, name) else: - setattr(obj, attrs[-1], value) + setattr(obj, name, value) return prev def set_attr_param(obj, attr, value): return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) +def set_attr_buffer(obj, attr, value): + obj, name = resolve_attr(obj, attr) + prev = getattr(obj, name, ATTR_UNSET) + persistent = name not in getattr(obj, "_non_persistent_buffers_set", set()) + obj.register_buffer(name, value, persistent=persistent) + return prev + def copy_to_param(obj, attr, value): # inplace update tensor instead of replacing it attrs = attr.split(".") diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index a3d48c87f..58a37c9e8 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -401,6 +401,7 @@ class VideoFromComponents(VideoInput): codec: VideoCodec = VideoCodec.AUTO, metadata: Optional[dict] = None, ): + """Save the video to a file path or BytesIO buffer.""" if format != VideoContainer.AUTO and format != VideoContainer.MP4: raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and codec != VideoCodec.H264: @@ -408,6 +409,10 @@ class VideoFromComponents(VideoInput): extra_kwargs = {} if isinstance(format, VideoContainer) and format != VideoContainer.AUTO: extra_kwargs["format"] = format.value + elif isinstance(path, io.BytesIO): + # BytesIO has no file extension, so av.open can't infer the format. + # Default to mp4 since that's the only supported format anyway. + extra_kwargs["format"] = "mp4" with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output: # Add metadata before writing any streams if metadata is not None: diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 189d7d9bc..050031dc0 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -1240,6 +1240,19 @@ class BoundingBox(ComfyTypeIO): return d +@comfytype(io_type="CURVE") +class Curve(ComfyTypeIO): + CurvePoint = tuple[float, float] + Type = list[CurvePoint] + + class Input(WidgetInput): + def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, + socketless: bool=True, default: list[tuple[float, float]]=None, advanced: bool=None): + super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced) + if default is None: + self.default = [(0.0, 0.0), (1.0, 1.0)] + + DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): DYNAMIC_INPUT_LOOKUP[io_type] = func @@ -2226,5 +2239,6 @@ __all__ = [ "PriceBadgeDepends", "PriceBadge", "BoundingBox", + "Curve", "NodeReplace", ] diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 43df0512f..5d8d9bf6f 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -96,7 +96,7 @@ class VAEEncodeAudio(IO.ComfyNode): def vae_decode_audio(vae, samples, tile=None, overlap=None): if tile is not None: - audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1) + audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1) else: audio = vae.decode(samples["samples"]).movedim(-1, 1) diff --git a/comfyui_version.py b/comfyui_version.py index 6a35c6de3..0aea18d3a 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.15.1" +__version__ = "0.16.0" diff --git a/execution.py b/execution.py index 75b021892..7ccdbf93e 100644 --- a/execution.py +++ b/execution.py @@ -876,12 +876,14 @@ async def validate_inputs(prompt_id, prompt, item, validated): continue else: try: - # Unwraps values wrapped in __value__ key. This is used to pass - # list widget value to execution, as by default list value is - # reserved to represent the connection between nodes. - if isinstance(val, dict) and "__value__" in val: - val = val["__value__"] - inputs[x] = val + # Unwraps values wrapped in __value__ key or typed wrapper. + # This is used to pass list widget values to execution, + # as by default list value is reserved to represent the + # connection between nodes. + if isinstance(val, dict): + if "__value__" in val: + val = val["__value__"] + inputs[x] = val if input_type == "INT": val = int(val) diff --git a/main.py b/main.py index af701f8df..0f58d57b8 100644 --- a/main.py +++ b/main.py @@ -16,11 +16,6 @@ from comfy_execution.progress import get_progress_state from comfy_execution.utils import get_executing_context from comfy_api import feature_flags -import comfy_aimdo.control - -if enables_dynamic_vram(): - comfy_aimdo.control.init() - if __name__ == "__main__": #NOTE: These do not do anything on core ComfyUI, they are for custom nodes. os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' @@ -28,6 +23,11 @@ if __name__ == "__main__": setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) +import comfy_aimdo.control + +if enables_dynamic_vram(): + comfy_aimdo.control.init() + if os.name == "nt": os.environ['MIMALLOC_PURGE_DELAY'] = '0' diff --git a/pyproject.toml b/pyproject.toml index 1b2318273..f2133d99c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.15.1" +version = "0.16.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" diff --git a/requirements.txt b/requirements.txt index 71019c16f..866818e08 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.39.19 -comfyui-workflow-templates==0.9.4 +comfyui-workflow-templates==0.9.7 comfyui-embedded-docs==0.4.3 torch torchsde @@ -22,7 +22,7 @@ alembic SQLAlchemy av>=14.2.0 comfy-kitchen>=0.2.7 -comfy-aimdo>=0.2.4 +comfy-aimdo>=0.2.6 requests #non essential dependencies: