From f2b002372b71cf0671a4cf1fa539e1c386d727e4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 4 Jan 2026 22:58:59 -0800 Subject: [PATCH] Support the LTXV 2 model. (#11632) --- comfy/latent_formats.py | 3 + comfy/ldm/lightricks/av_model.py | 837 ++++++++++++++++ comfy/ldm/lightricks/embeddings_connector.py | 305 ++++++ comfy/ldm/lightricks/latent_upsampler.py | 292 ++++++ comfy/ldm/lightricks/model.py | 715 +++++++++++--- comfy/ldm/lightricks/symmetric_patchifier.py | 87 +- comfy/ldm/lightricks/vae/audio_vae.py | 286 ++++++ .../vae/causal_audio_autoencoder.py | 909 ++++++++++++++++++ comfy/ldm/lightricks/vocoders/vocoder.py | 213 ++++ comfy/model_base.py | 57 +- comfy/model_detection.py | 2 +- comfy/sd.py | 9 +- comfy/supported_models.py | 17 +- comfy/text_encoders/llama.py | 79 ++ comfy/text_encoders/lt.py | 111 +++ comfy/utils.py | 2 +- comfy_extras/nodes_audio.py | 2 +- comfy_extras/nodes_hunyuan.py | 15 +- comfy_extras/nodes_lt.py | 188 +++- comfy_extras/nodes_lt_audio.py | 183 ++++ comfy_extras/nodes_lt_upsampler.py | 75 ++ nodes.py | 10 +- pyproject.toml | 2 +- 23 files changed, 4214 insertions(+), 185 deletions(-) create mode 100644 comfy/ldm/lightricks/av_model.py create mode 100644 comfy/ldm/lightricks/embeddings_connector.py create mode 100644 comfy/ldm/lightricks/latent_upsampler.py create mode 100644 comfy/ldm/lightricks/vae/audio_vae.py create mode 100644 comfy/ldm/lightricks/vae/causal_audio_autoencoder.py create mode 100644 comfy/ldm/lightricks/vocoders/vocoder.py create mode 100644 comfy_extras/nodes_lt_audio.py create mode 100644 comfy_extras/nodes_lt_upsampler.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index f1ca0151e..9bbe30b53 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -407,6 +407,9 @@ class LTXV(LatentFormat): self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512] +class LTXAV(LTXV): + pass + class HunyuanVideo(LatentFormat): latent_channels = 16 latent_dimensions = 3 diff --git a/comfy/ldm/lightricks/av_model.py b/comfy/ldm/lightricks/av_model.py new file mode 100644 index 000000000..759535501 --- /dev/null +++ b/comfy/ldm/lightricks/av_model.py @@ -0,0 +1,837 @@ +from typing import Tuple +import torch +import torch.nn as nn +from comfy.ldm.lightricks.model import ( + CrossAttention, + FeedForward, + AdaLayerNormSingle, + PixArtAlphaTextProjection, + LTXVModel, +) +from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier +import comfy.ldm.common_dit + +class BasicAVTransformerBlock(nn.Module): + def __init__( + self, + v_dim, + a_dim, + v_heads, + a_heads, + vd_head, + ad_head, + v_context_dim=None, + a_context_dim=None, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + + self.attn_precision = attn_precision + + self.attn1 = CrossAttention( + query_dim=v_dim, + heads=v_heads, + dim_head=vd_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + self.audio_attn1 = CrossAttention( + query_dim=a_dim, + heads=a_heads, + dim_head=ad_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + self.attn2 = CrossAttention( + query_dim=v_dim, + context_dim=v_context_dim, + heads=v_heads, + dim_head=vd_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + self.audio_attn2 = CrossAttention( + query_dim=a_dim, + context_dim=a_context_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + # Q: Video, K,V: Audio + self.audio_to_video_attn = CrossAttention( + query_dim=v_dim, + context_dim=a_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + # Q: Audio, K,V: Video + self.video_to_audio_attn = CrossAttention( + query_dim=a_dim, + context_dim=v_dim, + heads=a_heads, + dim_head=ad_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) + + self.ff = FeedForward( + v_dim, dim_out=v_dim, glu=True, dtype=dtype, device=device, operations=operations + ) + self.audio_ff = FeedForward( + a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations + ) + + self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype)) + self.audio_scale_shift_table = nn.Parameter( + torch.empty(6, a_dim, device=device, dtype=dtype) + ) + + self.scale_shift_table_a2v_ca_audio = nn.Parameter( + torch.empty(5, a_dim, device=device, dtype=dtype) + ) + self.scale_shift_table_a2v_ca_video = nn.Parameter( + torch.empty(5, v_dim, device=device, dtype=dtype) + ) + + def get_ada_values( + self, scale_shift_table: torch.Tensor, batch_size: int, timestep: torch.Tensor, indices: slice = slice(None, None) + ): + num_ada_params = scale_shift_table.shape[0] + + ada_values = ( + scale_shift_table[indices].unsqueeze(0).unsqueeze(0).to(device=timestep.device, dtype=timestep.dtype) + + timestep.reshape(batch_size, timestep.shape[1], num_ada_params, -1)[:, :, indices, :] + ).unbind(dim=2) + return ada_values + + def get_av_ca_ada_values( + self, + scale_shift_table: torch.Tensor, + batch_size: int, + scale_shift_timestep: torch.Tensor, + gate_timestep: torch.Tensor, + num_scale_shift_values: int = 4, + ): + scale_shift_ada_values = self.get_ada_values( + scale_shift_table[:num_scale_shift_values, :], + batch_size, + scale_shift_timestep, + ) + gate_ada_values = self.get_ada_values( + scale_shift_table[num_scale_shift_values:, :], + batch_size, + gate_timestep, + ) + + scale_shift_chunks = [t.squeeze(2) for t in scale_shift_ada_values] + gate_ada_values = [t.squeeze(2) for t in gate_ada_values] + + return (*scale_shift_chunks, *gate_ada_values) + + def forward( + self, + x: Tuple[torch.Tensor, torch.Tensor], + v_context=None, + a_context=None, + attention_mask=None, + v_timestep=None, + a_timestep=None, + v_pe=None, + a_pe=None, + v_cross_pe=None, + a_cross_pe=None, + v_cross_scale_shift_timestep=None, + a_cross_scale_shift_timestep=None, + v_cross_gate_timestep=None, + a_cross_gate_timestep=None, + transformer_options=None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + run_vx = transformer_options.get("run_vx", True) + run_ax = transformer_options.get("run_ax", True) + + vx, ax = x + run_ax = run_ax and ax.numel() > 0 + run_a2v = run_vx and transformer_options.get("a2v_cross_attn", True) and ax.numel() > 0 + run_v2a = run_ax and transformer_options.get("v2a_cross_attn", True) + + if run_vx: + vshift_msa, vscale_msa, vgate_msa = ( + self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(0, 3)) + ) + + norm_vx = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_msa) + vshift_msa + vx += self.attn1(norm_vx, pe=v_pe, transformer_options=transformer_options) * vgate_msa + vx += self.attn2( + comfy.ldm.common_dit.rms_norm(vx), + context=v_context, + mask=attention_mask, + transformer_options=transformer_options, + ) + + del vshift_msa, vscale_msa, vgate_msa + + if run_ax: + ashift_msa, ascale_msa, agate_msa = ( + self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(0, 3)) + ) + + norm_ax = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_msa) + ashift_msa + ax += ( + self.audio_attn1(norm_ax, pe=a_pe, transformer_options=transformer_options) + * agate_msa + ) + ax += self.audio_attn2( + comfy.ldm.common_dit.rms_norm(ax), + context=a_context, + mask=attention_mask, + transformer_options=transformer_options, + ) + + del ashift_msa, ascale_msa, agate_msa + + # Audio - Video cross attention. + if run_a2v or run_v2a: + # norm3 + vx_norm3 = comfy.ldm.common_dit.rms_norm(vx) + ax_norm3 = comfy.ldm.common_dit.rms_norm(ax) + + ( + scale_ca_audio_hidden_states_a2v, + shift_ca_audio_hidden_states_a2v, + scale_ca_audio_hidden_states_v2a, + shift_ca_audio_hidden_states_v2a, + gate_out_v2a, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_audio, + ax.shape[0], + a_cross_scale_shift_timestep, + a_cross_gate_timestep, + ) + + ( + scale_ca_video_hidden_states_a2v, + shift_ca_video_hidden_states_a2v, + scale_ca_video_hidden_states_v2a, + shift_ca_video_hidden_states_v2a, + gate_out_a2v, + ) = self.get_av_ca_ada_values( + self.scale_shift_table_a2v_ca_video, + vx.shape[0], + v_cross_scale_shift_timestep, + v_cross_gate_timestep, + ) + + if run_a2v: + vx_scaled = ( + vx_norm3 * (1 + scale_ca_video_hidden_states_a2v) + + shift_ca_video_hidden_states_a2v + ) + ax_scaled = ( + ax_norm3 * (1 + scale_ca_audio_hidden_states_a2v) + + shift_ca_audio_hidden_states_a2v + ) + vx += ( + self.audio_to_video_attn( + vx_scaled, + context=ax_scaled, + pe=v_cross_pe, + k_pe=a_cross_pe, + transformer_options=transformer_options, + ) + * gate_out_a2v + ) + + del gate_out_a2v + del scale_ca_video_hidden_states_a2v,\ + shift_ca_video_hidden_states_a2v,\ + scale_ca_audio_hidden_states_a2v,\ + shift_ca_audio_hidden_states_a2v,\ + + if run_v2a: + ax_scaled = ( + ax_norm3 * (1 + scale_ca_audio_hidden_states_v2a) + + shift_ca_audio_hidden_states_v2a + ) + vx_scaled = ( + vx_norm3 * (1 + scale_ca_video_hidden_states_v2a) + + shift_ca_video_hidden_states_v2a + ) + ax += ( + self.video_to_audio_attn( + ax_scaled, + context=vx_scaled, + pe=a_cross_pe, + k_pe=v_cross_pe, + transformer_options=transformer_options, + ) + * gate_out_v2a + ) + + del gate_out_v2a + del scale_ca_video_hidden_states_v2a,\ + shift_ca_video_hidden_states_v2a,\ + scale_ca_audio_hidden_states_v2a,\ + shift_ca_audio_hidden_states_v2a + + if run_vx: + vshift_mlp, vscale_mlp, vgate_mlp = ( + self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(3, None)) + ) + + vx_scaled = comfy.ldm.common_dit.rms_norm(vx) * (1 + vscale_mlp) + vshift_mlp + vx += self.ff(vx_scaled) * vgate_mlp + del vshift_mlp, vscale_mlp, vgate_mlp + + if run_ax: + ashift_mlp, ascale_mlp, agate_mlp = ( + self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(3, None)) + ) + + ax_scaled = comfy.ldm.common_dit.rms_norm(ax) * (1 + ascale_mlp) + ashift_mlp + ax += self.audio_ff(ax_scaled) * agate_mlp + + del ashift_mlp, ascale_mlp, agate_mlp + + + return vx, ax + + +class LTXAVModel(LTXVModel): + """LTXAV model for audio-video generation.""" + + def __init__( + self, + in_channels=128, + audio_in_channels=128, + cross_attention_dim=4096, + audio_cross_attention_dim=2048, + attention_head_dim=128, + audio_attention_head_dim=64, + num_attention_heads=32, + audio_num_attention_heads=32, + caption_channels=3840, + num_layers=48, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + audio_positional_embedding_max_pos=[20], + causal_temporal_positioning=False, + vae_scale_factors=(8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier=1000.0, + av_ca_timestep_scale_multiplier=1.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + # Store audio-specific parameters + self.audio_in_channels = audio_in_channels + self.audio_cross_attention_dim = audio_cross_attention_dim + self.audio_attention_head_dim = audio_attention_head_dim + self.audio_num_attention_heads = audio_num_attention_heads + self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos + + # Calculate audio dimensions + self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim + self.audio_out_channels = audio_in_channels + + # Audio-specific constants + self.num_audio_channels = 8 + self.audio_frequency_bins = 16 + + self.av_ca_timestep_scale_multiplier = av_ca_timestep_scale_multiplier + + super().__init__( + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + caption_channels=caption_channels, + num_layers=num_layers, + positional_embedding_theta=positional_embedding_theta, + positional_embedding_max_pos=positional_embedding_max_pos, + causal_temporal_positioning=causal_temporal_positioning, + vae_scale_factors=vae_scale_factors, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + dtype=dtype, + device=device, + operations=operations, + **kwargs, + ) + + def _init_model_components(self, device, dtype, **kwargs): + """Initialize LTXAV-specific components.""" + # Audio-specific projections + self.audio_patchify_proj = self.operations.Linear( + self.audio_in_channels, self.audio_inner_dim, bias=True, dtype=dtype, device=device + ) + + # Audio-specific AdaLN + self.audio_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + dtype=dtype, + device=device, + operations=self.operations, + ) + + num_scale_shift_values = 4 + self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle( + self.inner_dim, + use_additional_conditions=False, + embedding_coefficient=num_scale_shift_values, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_a2v_gate_adaln_single = AdaLayerNormSingle( + self.inner_dim, + use_additional_conditions=False, + embedding_coefficient=1, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_audio_scale_shift_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + embedding_coefficient=num_scale_shift_values, + dtype=dtype, + device=device, + operations=self.operations, + ) + self.av_ca_v2a_gate_adaln_single = AdaLayerNormSingle( + self.audio_inner_dim, + use_additional_conditions=False, + embedding_coefficient=1, + dtype=dtype, + device=device, + operations=self.operations, + ) + + # Audio caption projection + self.audio_caption_projection = PixArtAlphaTextProjection( + in_features=self.caption_channels, + hidden_size=self.audio_inner_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks for LTXAV.""" + self.transformer_blocks = nn.ModuleList( + [ + BasicAVTransformerBlock( + v_dim=self.inner_dim, + a_dim=self.audio_inner_dim, + v_heads=self.num_attention_heads, + a_heads=self.audio_num_attention_heads, + vd_head=self.attention_head_dim, + ad_head=self.audio_attention_head_dim, + v_context_dim=self.cross_attention_dim, + a_context_dim=self.audio_cross_attention_dim, + dtype=dtype, + device=device, + operations=self.operations, + ) + for _ in range(self.num_layers) + ] + ) + + def _init_output_components(self, device, dtype): + """Initialize output components for LTXAV.""" + # Video output components + super()._init_output_components(device, dtype) + # Audio output components + self.audio_scale_shift_table = nn.Parameter( + torch.empty(2, self.audio_inner_dim, dtype=dtype, device=device) + ) + self.audio_norm_out = self.operations.LayerNorm( + self.audio_inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device + ) + self.audio_proj_out = self.operations.Linear( + self.audio_inner_dim, self.audio_out_channels, dtype=dtype, device=device + ) + self.a_patchifier = AudioPatchifier(1, start_end=True) + + def separate_audio_and_video_latents(self, x, audio_length): + """Separate audio and video latents from combined input.""" + # vx = x[:, : self.in_channels] + # ax = x[:, self.in_channels :] + # + # ax = ax.reshape(ax.shape[0], -1) + # ax = ax[:, : audio_length * self.num_audio_channels * self.audio_frequency_bins] + # + # ax = ax.reshape( + # ax.shape[0], self.num_audio_channels, audio_length, self.audio_frequency_bins + # ) + + vx = x[0] + ax = x[1] if len(x) > 1 else torch.zeros( + (vx.shape[0], self.num_audio_channels, 0, self.audio_frequency_bins), + device=vx.device, dtype=vx.dtype + ) + return vx, ax + + def recombine_audio_and_video_latents(self, vx, ax, target_shape=None): + if ax.numel() == 0: + return vx + else: + return [vx, ax] + """Recombine audio and video latents for output.""" + # if ax.device != vx.device or ax.dtype != vx.dtype: + # logging.warning("Audio and video latents are on different devices or dtypes.") + # ax = ax.to(device=vx.device, dtype=vx.dtype) + # logging.warning(f"Audio audio latent moved to device: {ax.device}, dtype: {ax.dtype}") + # + # ax = ax.reshape(ax.shape[0], -1) + # # pad to f x h x w of the video latents + # divisor = vx.shape[-1] * vx.shape[-2] * vx.shape[-3] + # if target_shape is None: + # repetitions = math.ceil(ax.shape[-1] / divisor) + # else: + # repetitions = target_shape[1] - vx.shape[1] + # padded_len = repetitions * divisor + # ax = F.pad(ax, (0, padded_len - ax.shape[-1])) + # ax = ax.reshape(ax.shape[0], -1, vx.shape[-3], vx.shape[-2], vx.shape[-1]) + # return torch.cat([vx, ax], dim=1) + + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input for LTXAV - separate audio and video, then patchify.""" + audio_length = kwargs.get("audio_length", 0) + # Separate audio and video latents + vx, ax = self.separate_audio_and_video_latents(x, audio_length) + [vx, v_pixel_coords, additional_args] = super()._process_input( + vx, keyframe_idxs, denoise_mask, **kwargs + ) + + ax, a_latent_coords = self.a_patchifier.patchify(ax) + ax = self.audio_patchify_proj(ax) + + # additional_args.update({"av_orig_shape": list(x.shape)}) + return [vx, ax], [v_pixel_coords, a_latent_coords], additional_args + + def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): + """Prepare timestep embeddings.""" + # TODO: some code reuse is needed here. + grid_mask = kwargs.get("grid_mask", None) + if grid_mask is not None: + timestep = timestep[:, grid_mask] + + timestep = timestep * self.timestep_scale_multiplier + v_timestep, v_embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + # Second dimension is 1 or number of tokens (if timestep_per_token) + v_timestep = v_timestep.view(batch_size, -1, v_timestep.shape[-1]) + v_embedded_timestep = v_embedded_timestep.view( + batch_size, -1, v_embedded_timestep.shape[-1] + ) + + # Prepare audio timestep + a_timestep = kwargs.get("a_timestep") + if a_timestep is not None: + a_timestep = a_timestep * self.timestep_scale_multiplier + av_ca_factor = self.av_ca_timestep_scale_multiplier / self.timestep_scale_multiplier + + av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single( + a_timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single( + timestep.flatten() * av_ca_factor, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single( + a_timestep.flatten() * av_ca_factor, + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + a_timestep, a_embedded_timestep = self.audio_adaln_single( + a_timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1]) + a_embedded_timestep = a_embedded_timestep.view( + batch_size, -1, a_embedded_timestep.shape[-1] + ) + cross_av_timestep_ss = [ + av_ca_audio_scale_shift_timestep, + av_ca_video_scale_shift_timestep, + av_ca_a2v_gate_noise_timestep, + av_ca_v2a_gate_noise_timestep, + ] + cross_av_timestep_ss = list( + [t.view(batch_size, -1, t.shape[-1]) for t in cross_av_timestep_ss] + ) + else: + a_timestep = timestep + a_embedded_timestep = kwargs.get("embedded_timestep") + cross_av_timestep_ss = [] + + return [v_timestep, a_timestep, cross_av_timestep_ss], [ + v_embedded_timestep, + a_embedded_timestep, + ] + + def _prepare_context(self, context, batch_size, x, attention_mask=None): + vx = x[0] + ax = x[1] + v_context, a_context = torch.split( + context, int(context.shape[-1] / 2), len(context.shape) - 1 + ) + + v_context, attention_mask = super()._prepare_context( + v_context, batch_size, vx, attention_mask + ) + if self.audio_caption_projection is not None: + a_context = self.audio_caption_projection(a_context) + a_context = a_context.view(batch_size, -1, ax.shape[-1]) + + return [v_context, a_context], attention_mask + + def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype): + v_pixel_coords = pixel_coords[0] + v_pe = super()._prepare_positional_embeddings(v_pixel_coords, frame_rate, x_dtype) + + a_latent_coords = pixel_coords[1] + a_pe = self._precompute_freqs_cis( + a_latent_coords, + dim=self.audio_inner_dim, + out_dtype=x_dtype, + max_pos=self.audio_positional_embedding_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.audio_num_attention_heads, + ) + + # calculate positional embeddings for the middle of the token duration, to use in av cross attention layers. + max_pos = max( + self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0] + ) + v_pixel_coords = v_pixel_coords.to(torch.float32) + v_pixel_coords[:, 0] = v_pixel_coords[:, 0] * (1.0 / frame_rate) + av_cross_video_freq_cis = self._precompute_freqs_cis( + v_pixel_coords[:, 0:1, :], + dim=self.audio_cross_attention_dim, + out_dtype=x_dtype, + max_pos=[max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.audio_num_attention_heads, + ) + av_cross_audio_freq_cis = self._precompute_freqs_cis( + a_latent_coords[:, 0:1, :], + dim=self.audio_cross_attention_dim, + out_dtype=x_dtype, + max_pos=[max_pos], + use_middle_indices_grid=True, + num_attention_heads=self.audio_num_attention_heads, + ) + + return [(v_pe, av_cross_video_freq_cis), (a_pe, av_cross_audio_freq_cis)] + + def _process_transformer_blocks( + self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs + ): + vx = x[0] + ax = x[1] + v_context = context[0] + a_context = context[1] + v_timestep = timestep[0] + a_timestep = timestep[1] + v_pe, av_cross_video_freq_cis = pe[0] + a_pe, av_cross_audio_freq_cis = pe[1] + + ( + av_ca_audio_scale_shift_timestep, + av_ca_video_scale_shift_timestep, + av_ca_a2v_gate_noise_timestep, + av_ca_v2a_gate_noise_timestep, + ) = timestep[2] + + """Process transformer blocks for LTXAV.""" + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + + # Process transformer blocks + for i, block in enumerate(self.transformer_blocks): + if ("double_block", i) in blocks_replace: + + def block_wrap(args): + out = {} + out["img"] = block( + args["img"], + v_context=args["v_context"], + a_context=args["a_context"], + attention_mask=args["attention_mask"], + v_timestep=args["v_timestep"], + a_timestep=args["a_timestep"], + v_pe=args["v_pe"], + a_pe=args["a_pe"], + v_cross_pe=args["v_cross_pe"], + a_cross_pe=args["a_cross_pe"], + v_cross_scale_shift_timestep=args["v_cross_scale_shift_timestep"], + a_cross_scale_shift_timestep=args["a_cross_scale_shift_timestep"], + v_cross_gate_timestep=args["v_cross_gate_timestep"], + a_cross_gate_timestep=args["a_cross_gate_timestep"], + transformer_options=args["transformer_options"], + ) + return out + + out = blocks_replace[("double_block", i)]( + { + "img": (vx, ax), + "v_context": v_context, + "a_context": a_context, + "attention_mask": attention_mask, + "v_timestep": v_timestep, + "a_timestep": a_timestep, + "v_pe": v_pe, + "a_pe": a_pe, + "v_cross_pe": av_cross_video_freq_cis, + "a_cross_pe": av_cross_audio_freq_cis, + "v_cross_scale_shift_timestep": av_ca_video_scale_shift_timestep, + "a_cross_scale_shift_timestep": av_ca_audio_scale_shift_timestep, + "v_cross_gate_timestep": av_ca_a2v_gate_noise_timestep, + "a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep, + "transformer_options": transformer_options, + }, + {"original_block": block_wrap}, + ) + vx, ax = out["img"] + else: + vx, ax = block( + (vx, ax), + v_context=v_context, + a_context=a_context, + attention_mask=attention_mask, + v_timestep=v_timestep, + a_timestep=a_timestep, + v_pe=v_pe, + a_pe=a_pe, + v_cross_pe=av_cross_video_freq_cis, + a_cross_pe=av_cross_audio_freq_cis, + v_cross_scale_shift_timestep=av_ca_video_scale_shift_timestep, + a_cross_scale_shift_timestep=av_ca_audio_scale_shift_timestep, + v_cross_gate_timestep=av_ca_a2v_gate_noise_timestep, + a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep, + transformer_options=transformer_options, + ) + + return [vx, ax] + + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + vx = x[0] + ax = x[1] + v_embedded_timestep = embedded_timestep[0] + a_embedded_timestep = embedded_timestep[1] + vx = super()._process_output(vx, v_embedded_timestep, keyframe_idxs, **kwargs) + + # Process audio output + a_scale_shift_values = ( + self.audio_scale_shift_table[None, None].to(device=a_embedded_timestep.device, dtype=a_embedded_timestep.dtype) + + a_embedded_timestep[:, :, None] + ) + a_shift, a_scale = a_scale_shift_values[:, :, 0], a_scale_shift_values[:, :, 1] + + ax = self.audio_norm_out(ax) + ax = ax * (1 + a_scale) + a_shift + ax = self.audio_proj_out(ax) + + # Unpatchify audio + ax = self.a_patchifier.unpatchify( + ax, channels=self.num_audio_channels, freq=self.audio_frequency_bins + ) + + # Recombine audio and video + original_shape = kwargs.get("av_orig_shape") + return self.recombine_audio_and_video_latents(vx, ax, original_shape) + + def forward( + self, + x, + timestep, + context, + attention_mask=None, + frame_rate=25, + transformer_options={}, + keyframe_idxs=None, + **kwargs, + ): + """ + Forward pass for LTXAV model. + + Args: + x: Combined audio-video input tensor + timestep: Tuple of (video_timestep, audio_timestep) or single timestep + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments including audio_length + + Returns: + Combined audio-video output tensor + """ + # Handle timestep format + if isinstance(timestep, (tuple, list)) and len(timestep) == 2: + v_timestep, a_timestep = timestep + kwargs["a_timestep"] = a_timestep + timestep = v_timestep + else: + kwargs["a_timestep"] = timestep + + # Call parent forward method + return super().forward( + x, + timestep, + context, + attention_mask, + frame_rate, + transformer_options, + keyframe_idxs, + **kwargs, + ) diff --git a/comfy/ldm/lightricks/embeddings_connector.py b/comfy/ldm/lightricks/embeddings_connector.py new file mode 100644 index 000000000..f7a43f3c3 --- /dev/null +++ b/comfy/ldm/lightricks/embeddings_connector.py @@ -0,0 +1,305 @@ +import math +from typing import Optional + +import comfy.ldm.common_dit +import torch +from comfy.ldm.lightricks.model import ( + CrossAttention, + FeedForward, + generate_freq_grid_np, + interleaved_freqs_cis, + split_freqs_cis, +) +from torch import nn + + +class BasicTransformerBlock1D(nn.Module): + r""" + A basic Transformer block. + + Parameters: + + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`. + norm_eps (`float`, *optional*, defaults to 1e-5): Epsilon value for normalization layers. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + ff_inner_dim (`int`, *optional*): Dimension of the inner feed-forward layer. If not provided, defaults to `dim * 4`. + ff_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the feed-forward layer. + attention_out_bias (`bool`, *optional*, defaults to `True`): Whether to use bias in the attention output layer. + use_rope (`bool`, *optional*, defaults to `False`): Whether to use Rotary Position Embeddings (RoPE). + ffn_dim_mult (`int`, *optional*, defaults to 4): Multiplier for the inner dimension of the feed-forward layer. + """ + + def __init__( + self, + dim, + n_heads, + d_head, + context_dim=None, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): + super().__init__() + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + context_dim=None, + dtype=dtype, + device=device, + operations=operations, + ) + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dim_out=dim, + glu=True, + dtype=dtype, + device=device, + operations=operations, + ) + + def forward(self, hidden_states, attention_mask=None, pe=None) -> torch.FloatTensor: + + # Notice that normalization is always applied before the real computation in the following blocks. + + # 1. Normalization Before Self-Attention + norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + norm_hidden_states = norm_hidden_states.squeeze(1) + + # 2. Self-Attention + attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe) + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Normalization before Feed-Forward + norm_hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + # 4. Feed-forward + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +class Embeddings1DConnector(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + in_channels=128, + cross_attention_dim=2048, + attention_head_dim=128, + num_attention_heads=30, + num_layers=2, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[4096], + causal_temporal_positioning=False, + num_learnable_registers: Optional[int] = 128, + dtype=None, + device=None, + operations=None, + split_rope=False, + double_precision_rope=False, + **kwargs, + ): + super().__init__() + self.dtype = dtype + self.out_channels = in_channels + self.num_attention_heads = num_attention_heads + self.inner_dim = num_attention_heads * attention_head_dim + self.causal_temporal_positioning = causal_temporal_positioning + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.split_rope = split_rope + self.double_precision_rope = double_precision_rope + self.transformer_1d_blocks = nn.ModuleList( + [ + BasicTransformerBlock1D( + self.inner_dim, + num_attention_heads, + attention_head_dim, + context_dim=cross_attention_dim, + dtype=dtype, + device=device, + operations=operations, + ) + for _ in range(num_layers) + ] + ) + + inner_dim = num_attention_heads * attention_head_dim + self.num_learnable_registers = num_learnable_registers + if self.num_learnable_registers: + self.learnable_registers = nn.Parameter( + torch.rand( + self.num_learnable_registers, inner_dim, dtype=dtype, device=device + ) + * 2.0 + - 1.0 + ) + + def get_fractional_positions(self, indices_grid): + fractional_positions = torch.stack( + [ + indices_grid[:, i] / self.positional_embedding_max_pos[i] + for i in range(1) + ], + dim=-1, + ) + return fractional_positions + + def precompute_freqs(self, indices_grid, spacing): + source_dtype = indices_grid.dtype + dtype = ( + torch.float32 + if source_dtype in (torch.bfloat16, torch.float16) + else source_dtype + ) + + fractional_positions = self.get_fractional_positions(indices_grid) + indices = ( + generate_freq_grid_np( + self.positional_embedding_theta, + indices_grid.shape[1], + self.inner_dim, + ) + if self.double_precision_rope + else self.generate_freq_grid(spacing, dtype, fractional_positions.device) + ).to(device=fractional_positions.device) + + if spacing == "exp_2": + freqs = ( + (indices * fractional_positions.unsqueeze(-1)) + .transpose(-1, -2) + .flatten(2) + ) + else: + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + return freqs + + def generate_freq_grid(self, spacing, dtype, device): + dim = self.inner_dim + theta = self.positional_embedding_theta + n_pos_dims = 1 + n_elem = 2 * n_pos_dims # 2 for cos and sin e.g. x 3 = 6 + start = 1 + end = theta + + if spacing == "exp": + indices = theta ** (torch.arange(0, dim, n_elem, device="cpu", dtype=torch.float32) / (dim - n_elem)) + indices = indices.to(dtype=dtype, device=device) + elif spacing == "exp_2": + indices = 1.0 / theta ** (torch.arange(0, dim, n_elem, device=device) / dim) + indices = indices.to(dtype=dtype) + elif spacing == "linear": + indices = torch.linspace( + start, end, dim // n_elem, device=device, dtype=dtype + ) + elif spacing == "sqrt": + indices = torch.linspace( + start**2, end**2, dim // n_elem, device=device, dtype=dtype + ).sqrt() + + indices = indices * math.pi / 2 + + return indices + + def precompute_freqs_cis(self, indices_grid, spacing="exp"): + dim = self.inner_dim + n_elem = 2 # 2 because of cos and sin + freqs = self.precompute_freqs(indices_grid, spacing) + if self.split_rope: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis( + freqs, pad_size, self.num_attention_heads + ) + else: + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(self.dtype), sin_freq.to(self.dtype), self.split_rope + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`): + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # 1. Input + + if self.num_learnable_registers: + num_registers_duplications = math.ceil( + max(1024, hidden_states.shape[1]) / self.num_learnable_registers + ) + learnable_registers = torch.tile( + self.learnable_registers, (num_registers_duplications, 1) + ) + + hidden_states = torch.cat((hidden_states, learnable_registers[hidden_states.shape[1]:].unsqueeze(0).repeat(hidden_states.shape[0], 1, 1)), dim=1) + + if attention_mask is not None: + attention_mask = torch.zeros([1, 1, 1, hidden_states.shape[1]], dtype=attention_mask.dtype, device=attention_mask.device) + + indices_grid = torch.arange( + hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device + ) + indices_grid = indices_grid[None, None, :] + freqs_cis = self.precompute_freqs_cis(indices_grid) + + # 2. Blocks + for block_idx, block in enumerate(self.transformer_1d_blocks): + hidden_states = block( + hidden_states, attention_mask=attention_mask, pe=freqs_cis + ) + + # 3. Output + # if self.output_scale is not None: + # hidden_states = hidden_states / self.output_scale + + hidden_states = comfy.ldm.common_dit.rms_norm(hidden_states) + + return hidden_states, attention_mask diff --git a/comfy/ldm/lightricks/latent_upsampler.py b/comfy/ldm/lightricks/latent_upsampler.py new file mode 100644 index 000000000..78ed7653f --- /dev/null +++ b/comfy/ldm/lightricks/latent_upsampler.py @@ -0,0 +1,292 @@ +from typing import Optional, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +def _rational_for_scale(scale: float) -> Tuple[int, int]: + mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)} + if float(scale) not in mapping: + raise ValueError( + f"Unsupported spatial_scale {scale}. Choose from {list(mapping.keys())}" + ) + return mapping[float(scale)] + + +class PixelShuffleND(nn.Module): + def __init__(self, dims, upscale_factors=(2, 2, 2)): + super().__init__() + assert dims in [1, 2, 3], "dims must be 1, 2, or 3" + self.dims = dims + self.upscale_factors = upscale_factors + + def forward(self, x): + if self.dims == 3: + return rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + p3=self.upscale_factors[2], + ) + elif self.dims == 2: + return rearrange( + x, + "b (c p1 p2) h w -> b c (h p1) (w p2)", + p1=self.upscale_factors[0], + p2=self.upscale_factors[1], + ) + elif self.dims == 1: + return rearrange( + x, + "b (c p1) f h w -> b c (f p1) h w", + p1=self.upscale_factors[0], + ) + + +class BlurDownsample(nn.Module): + """ + Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. + Applies only on H,W. Works for dims=2 or dims=3 (per-frame). + """ + + def __init__(self, dims: int, stride: int): + super().__init__() + assert dims in (2, 3) + assert stride >= 1 and isinstance(stride, int) + self.dims = dims + self.stride = stride + + # 5x5 separable binomial kernel [1,4,6,4,1] (outer product), normalized + k = torch.tensor([1.0, 4.0, 6.0, 4.0, 1.0]) + k2d = k[:, None] @ k[None, :] + k2d = (k2d / k2d.sum()).float() # shape (5,5) + self.register_buffer("kernel", k2d[None, None, :, :]) # (1,1,5,5) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.stride == 1: + return x + + def _apply_2d(x2d: torch.Tensor) -> torch.Tensor: + # x2d: (B, C, H, W) + B, C, H, W = x2d.shape + weight = self.kernel.expand(C, 1, 5, 5) # depthwise + x2d = F.conv2d( + x2d, weight=weight, bias=None, stride=self.stride, padding=2, groups=C + ) + return x2d + + if self.dims == 2: + return _apply_2d(x) + else: + # dims == 3: apply per-frame on H,W + b, c, f, h, w = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = _apply_2d(x) + h2, w2 = x.shape[-2:] + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2) + return x + + +class SpatialRationalResampler(nn.Module): + """ + Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased + downsample by 'den' using fixed blur + stride. Operates on H,W only. + + For dims==3, work per-frame for spatial scaling (temporal axis untouched). + """ + + def __init__(self, mid_channels: int, scale: float): + super().__init__() + self.scale = float(scale) + self.num, self.den = _rational_for_scale(self.scale) + self.conv = nn.Conv2d( + mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1 + ) + self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num)) + self.blur_down = BlurDownsample(dims=2, stride=self.den) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = x.shape + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.conv(x) + x = self.pixel_shuffle(x) + x = self.blur_down(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + return x + + +class ResBlock(nn.Module): + def __init__( + self, channels: int, mid_channels: Optional[int] = None, dims: int = 3 + ): + super().__init__() + if mid_channels is None: + mid_channels = channels + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1) + self.norm1 = nn.GroupNorm(32, mid_channels) + self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1) + self.norm2 = nn.GroupNorm(32, channels) + self.activation = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + x = self.conv2(x) + x = self.norm2(x) + x = self.activation(x + residual) + return x + + +class LatentUpsampler(nn.Module): + """ + Model to spatially upsample VAE latents. + + Args: + in_channels (`int`): Number of channels in the input latent + mid_channels (`int`): Number of channels in the middle layers + num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling) + dims (`int`): Number of dimensions for convolutions (2 or 3) + spatial_upsample (`bool`): Whether to spatially upsample the latent + temporal_upsample (`bool`): Whether to temporally upsample the latent + """ + + def __init__( + self, + in_channels: int = 128, + mid_channels: int = 512, + num_blocks_per_stage: int = 4, + dims: int = 3, + spatial_upsample: bool = True, + temporal_upsample: bool = False, + spatial_scale: float = 2.0, + rational_resampler: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.mid_channels = mid_channels + self.num_blocks_per_stage = num_blocks_per_stage + self.dims = dims + self.spatial_upsample = spatial_upsample + self.temporal_upsample = temporal_upsample + self.spatial_scale = float(spatial_scale) + self.rational_resampler = rational_resampler + + Conv = nn.Conv2d if dims == 2 else nn.Conv3d + + self.initial_conv = Conv(in_channels, mid_channels, kernel_size=3, padding=1) + self.initial_norm = nn.GroupNorm(32, mid_channels) + self.initial_activation = nn.SiLU() + + self.res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + if spatial_upsample and temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(3), + ) + elif spatial_upsample: + if rational_resampler: + self.upsampler = SpatialRationalResampler( + mid_channels=mid_channels, scale=self.spatial_scale + ) + else: + self.upsampler = nn.Sequential( + nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(2), + ) + elif temporal_upsample: + self.upsampler = nn.Sequential( + nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1), + PixelShuffleND(1), + ) + else: + raise ValueError( + "Either spatial_upsample or temporal_upsample must be True" + ) + + self.post_upsample_res_blocks = nn.ModuleList( + [ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)] + ) + + self.final_conv = Conv(mid_channels, in_channels, kernel_size=3, padding=1) + + def forward(self, latent: torch.Tensor) -> torch.Tensor: + b, c, f, h, w = latent.shape + + if self.dims == 2: + x = rearrange(latent, "b c f h w -> (b f) c h w") + x = self.initial_conv(x) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + x = self.upsampler(x) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + else: + x = self.initial_conv(latent) + x = self.initial_norm(x) + x = self.initial_activation(x) + + for block in self.res_blocks: + x = block(x) + + if self.temporal_upsample: + x = self.upsampler(x) + x = x[:, :, 1:, :, :] + else: + if isinstance(self.upsampler, SpatialRationalResampler): + x = self.upsampler(x) + else: + x = rearrange(x, "b c f h w -> (b f) c h w") + x = self.upsampler(x) + x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f) + + for block in self.post_upsample_res_blocks: + x = block(x) + + x = self.final_conv(x) + + return x + + @classmethod + def from_config(cls, config): + return cls( + in_channels=config.get("in_channels", 4), + mid_channels=config.get("mid_channels", 128), + num_blocks_per_stage=config.get("num_blocks_per_stage", 4), + dims=config.get("dims", 2), + spatial_upsample=config.get("spatial_upsample", True), + temporal_upsample=config.get("temporal_upsample", False), + spatial_scale=config.get("spatial_scale", 2.0), + rational_resampler=config.get("rational_resampler", False), + ) + + def config(self): + return { + "_class_name": "LatentUpsampler", + "in_channels": self.in_channels, + "mid_channels": self.mid_channels, + "num_blocks_per_stage": self.num_blocks_per_stage, + "dims": self.dims, + "spatial_upsample": self.spatial_upsample, + "temporal_upsample": self.temporal_upsample, + "spatial_scale": self.spatial_scale, + "rational_resampler": self.rational_resampler, + } diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 593f7940f..d61e19d6e 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,13 +1,47 @@ +from abc import ABC, abstractmethod +from enum import Enum +import functools +import math +from typing import Dict, Optional, Tuple + +from einops import rearrange +import numpy as np import torch from torch import nn import comfy.patcher_extension import comfy.ldm.modules.attention import comfy.ldm.common_dit -import math -from typing import Dict, Optional, Tuple from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords -from comfy.ldm.flux.math import apply_rope1 + +def _log_base(x, base): + return np.log(x) / np.log(base) + +class LTXRopeType(str, Enum): + INTERLEAVED = "interleaved" + SPLIT = "split" + + KEY = "rope_type" + + @classmethod + def from_dict(cls, kwargs, default=None): + if default is None: + default = cls.INTERLEAVED + return cls(kwargs.get(cls.KEY, default)) + + +class LTXFrequenciesPrecision(str, Enum): + FLOAT32 = "float32" + FLOAT64 = "float64" + + KEY = "frequencies_precision" + + @classmethod + def from_dict(cls, kwargs, default=None): + if default is None: + default = cls.FLOAT32 + return cls(kwargs.get(cls.KEY, default)) + def get_timestep_embedding( timesteps: torch.Tensor, @@ -39,9 +73,7 @@ def get_timestep_embedding( assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" half_dim = embedding_dim // 2 - exponent = -math.log(max_period) * torch.arange( - start=0, end=half_dim, dtype=torch.float32, device=timesteps.device - ) + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) exponent = exponent / (half_dim - downscale_freq_shift) emb = torch.exp(exponent) @@ -73,7 +105,9 @@ class TimestepEmbedding(nn.Module): post_act_fn: Optional[str] = None, cond_proj_dim=None, sample_proj_bias=True, - dtype=None, device=None, operations=None, + dtype=None, + device=None, + operations=None, ): super().__init__() @@ -90,7 +124,9 @@ class TimestepEmbedding(nn.Module): time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim - self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device) + self.linear_2 = operations.Linear( + time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device + ) if post_act_fn is None: self.post_act = None @@ -139,12 +175,22 @@ class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 """ - def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None): + def __init__( + self, + embedding_dim, + size_emb_dim, + use_additional_conditions: bool = False, + dtype=None, + device=None, + operations=None, + ): super().__init__() self.outdim = size_emb_dim self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations + ) def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): timesteps_proj = self.time_proj(timestep) @@ -163,15 +209,22 @@ class AdaLayerNormSingle(nn.Module): use_additional_conditions (`bool`): To use additional conditions for normalization or not. """ - def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None): + def __init__( + self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None + ): super().__init__() self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings( - embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations + embedding_dim, + size_emb_dim=embedding_dim // 3, + use_additional_conditions=use_additional_conditions, + dtype=dtype, + device=device, + operations=operations, ) self.silu = nn.SiLU() - self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device) + self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device) def forward( self, @@ -185,6 +238,7 @@ class AdaLayerNormSingle(nn.Module): embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) return self.linear(self.silu(embedded_timestep)), embedded_timestep + class PixArtAlphaTextProjection(nn.Module): """ Projects caption embeddings. Also handles dropout for classifier-free guidance. @@ -192,18 +246,24 @@ class PixArtAlphaTextProjection(nn.Module): Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py """ - def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None): + def __init__( + self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None + ): super().__init__() if out_features is None: out_features = hidden_size - self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device) + self.linear_1 = operations.Linear( + in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device + ) if act_fn == "gelu_tanh": self.act_1 = nn.GELU(approximate="tanh") elif act_fn == "silu": self.act_1 = nn.SiLU() else: raise ValueError(f"Unknown activation function: {act_fn}") - self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device) + self.linear_2 = operations.Linear( + in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device + ) def forward(self, caption): hidden_states = self.linear_1(caption) @@ -222,23 +282,68 @@ class GELU_approx(nn.Module): class FeedForward(nn.Module): - def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None): + def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None): super().__init__() inner_dim = int(dim * mult) project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations) self.net = nn.Sequential( - project_in, - nn.Dropout(dropout), - operations.Linear(inner_dim, dim_out, dtype=dtype, device=device) + project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device) ) def forward(self, x): return self.net(x) +def apply_rotary_emb(input_tensor, freqs_cis): + cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1] + split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False + return ( + apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs) + if split_pe else + apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs) + ) + +def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs): # TODO: remove duplicate funcs and pick the best/fastest one + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + +def apply_split_rotary_emb(input_tensor, cos, sin): + needs_reshape = False + if input_tensor.ndim != 4 and cos.ndim == 4: + B, H, T, _ = cos.shape + input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2) + needs_reshape = True + split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2) + first_half_input = split_input[..., :1, :] + second_half_input = split_input[..., 1:, :] + output = split_input * cos.unsqueeze(-2) + first_half_output = output[..., :1, :] + second_half_output = output[..., 1:, :] + first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input) + second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input) + output = rearrange(output, "... d r -> ... (d r)") + return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output + class CrossAttention(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None): + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + attn_precision=None, + dtype=None, + device=None, + operations=None, + ): super().__init__() inner_dim = dim_head * heads context_dim = query_dim if context_dim is None else context_dim @@ -254,9 +359,11 @@ class CrossAttention(nn.Module): self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device) - self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) + self.to_out = nn.Sequential( + operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout) + ) - def forward(self, x, context=None, mask=None, pe=None, transformer_options={}): + def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}): q = self.to_q(x) context = x if context is None else context k = self.to_k(context) @@ -266,8 +373,8 @@ class CrossAttention(nn.Module): k = self.k_norm(k) if pe is not None: - q = apply_rope1(q.unsqueeze(1), pe).squeeze(1) - k = apply_rope1(k.unsqueeze(1), pe).squeeze(1) + q = apply_rotary_emb(q, pe) + k = apply_rotary_emb(k, pe if k_pe is None else k_pe) if mask is None: out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) @@ -277,14 +384,34 @@ class CrossAttention(nn.Module): class BasicTransformerBlock(nn.Module): - def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None): + def __init__( + self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None + ): super().__init__() self.attn_precision = attn_precision - self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) + self.attn1 = CrossAttention( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + context_dim=None, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations) - self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + attn_precision=self.attn_precision, + dtype=dtype, + device=device, + operations=operations, + ) self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype)) @@ -306,116 +433,446 @@ class BasicTransformerBlock(nn.Module): return x def get_fractional_positions(indices_grid, max_pos): + n_pos_dims = indices_grid.shape[1] + assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})' fractional_positions = torch.stack( - [ - indices_grid[:, i] / max_pos[i] - for i in range(3) - ], - dim=-1, + [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)], + axis=-1, ) return fractional_positions -def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]): - dtype = torch.float32 - device = indices_grid.device +@functools.lru_cache(maxsize=5) +def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None): + theta = positional_embedding_theta + start = 1 + end = theta + + n_elem = 2 * positional_embedding_max_pos_count + pow_indices = np.power( + theta, + np.linspace( + _log_base(start, theta), + _log_base(end, theta), + inner_dim // n_elem, + dtype=np.float64, + ), + ) + return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32) + +def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device): + theta = positional_embedding_theta + start = 1 + end = theta + n_elem = 2 * positional_embedding_max_pos_count + + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + inner_dim // n_elem, + device=device, + dtype=torch.float32, + ) + ) + indices = indices.to(dtype=torch.float32) + + indices = indices * math.pi / 2 + + return indices + +def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid): + if use_middle_indices_grid: + assert(len(indices_grid.shape) == 4 and indices_grid.shape[-1] ==2) + indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1] + indices_grid = (indices_grid_start + indices_grid_end) / 2.0 + elif len(indices_grid.shape) == 4: + indices_grid = indices_grid[..., 0] # Get fractional positions and compute frequency indices fractional_positions = get_fractional_positions(indices_grid, max_pos) - indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2 + indices = indices.to(device=fractional_positions.device) - # Compute frequencies and apply cos/sin - freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) - cos_vals = freqs.cos().repeat_interleave(2, dim=-1) - sin_vals = freqs.sin().repeat_interleave(2, dim=-1) + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + return freqs - # Pad if dim is not divisible by 6 - if dim % 6 != 0: - padding_size = dim % 6 - cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1) - sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1) +def interleaved_freqs_cis(freqs, pad_size): + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, : pad_size]) + sin_padding = torch.zeros_like(cos_freq[:, :, : pad_size]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq, sin_freq - # Reshape and extract one value per pair (since repeat_interleave duplicates each value) - cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] - sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] +def split_freqs_cis(freqs, pad_size, num_attention_heads): + cos_freq = freqs.cos() + sin_freq = freqs.sin() - # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension - freqs_cis = torch.stack([ - torch.stack([cos_vals, -sin_vals], dim=-1), - torch.stack([sin_vals, cos_vals], dim=-1) - ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2] + if pad_size != 0: + cos_padding = torch.ones_like(cos_freq[:, :, :pad_size]) + sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size]) - return freqs_cis + cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1) + sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1) + # Reshape freqs to be compatible with multi-head attention + B , T, half_HD = cos_freq.shape -class LTXVModel(torch.nn.Module): - def __init__(self, - in_channels=128, - cross_attention_dim=2048, - attention_head_dim=64, - num_attention_heads=32, + cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads) + sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads) - caption_channels=4096, - num_layers=28, + cos_freq = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2) + sin_freq = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2) + return cos_freq, sin_freq +class LTXBaseModel(torch.nn.Module, ABC): + """ + Abstract base class for LTX models (Lightricks Transformer models). - positional_embedding_theta=10000.0, - positional_embedding_max_pos=[20, 2048, 2048], - causal_temporal_positioning=False, - vae_scale_factors=(8, 32, 32), - dtype=None, device=None, operations=None, **kwargs): + This class defines the common interface and shared functionality for all LTX models, + including LTXV (video) and LTXAV (audio-video) variants. + """ + + def __init__( + self, + in_channels: int, + cross_attention_dim: int, + attention_head_dim: int, + num_attention_heads: int, + caption_channels: int, + num_layers: int, + positional_embedding_theta: float = 10000.0, + positional_embedding_max_pos: list = [20, 2048, 2048], + causal_temporal_positioning: bool = False, + vae_scale_factors: tuple = (8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier = 1000.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): super().__init__() self.generator = None self.vae_scale_factors = vae_scale_factors + self.use_middle_indices_grid = use_middle_indices_grid self.dtype = dtype - self.out_channels = in_channels - self.inner_dim = num_attention_heads * attention_head_dim + self.in_channels = in_channels + self.cross_attention_dim = cross_attention_dim + self.attention_head_dim = attention_head_dim + self.num_attention_heads = num_attention_heads + self.caption_channels = caption_channels + self.num_layers = num_layers + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.split_positional_embedding = LTXRopeType.from_dict(kwargs) + self.freq_grid_generator = ( + generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64 + else generate_freq_grid_pytorch + ) self.causal_temporal_positioning = causal_temporal_positioning + self.operations = operations + self.timestep_scale_multiplier = timestep_scale_multiplier - self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device) + # Common dimensions + self.inner_dim = num_attention_heads * attention_head_dim + self.out_channels = in_channels + + # Initialize common components + self._init_common_components(device, dtype) + + # Initialize model-specific components + self._init_model_components(device, dtype, **kwargs) + + # Initialize transformer blocks + self._init_transformer_blocks(device, dtype, **kwargs) + + # Initialize output components + self._init_output_components(device, dtype) + + def _init_common_components(self, device, dtype): + """Initialize components common to all LTX models + - patchify_proj: Linear projection for patchifying input + - adaln_single: AdaLN layer for timestep embedding + - caption_projection: Linear projection for caption embedding + """ + self.patchify_proj = self.operations.Linear( + self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device + ) self.adaln_single = AdaLayerNormSingle( - self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations + self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations ) - # self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device) - self.caption_projection = PixArtAlphaTextProjection( - in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations + in_features=self.caption_channels, + hidden_size=self.inner_dim, + dtype=dtype, + device=device, + operations=self.operations, ) + @abstractmethod + def _init_model_components(self, device, dtype, **kwargs): + """Initialize model-specific components. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _init_output_components(self, device, dtype): + """Initialize output components. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input data. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs): + """Process transformer blocks. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + """Process output data. Must be implemented by subclasses.""" + pass + + def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): + """Prepare timestep embeddings.""" + grid_mask = kwargs.get("grid_mask", None) + if grid_mask is not None: + timestep = timestep[:, grid_mask] + + timestep = timestep * self.timestep_scale_multiplier + timestep, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_dtype, + ) + + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1]) + + return timestep, embedded_timestep + + def _prepare_context(self, context, batch_size, x, attention_mask=None): + """Prepare context for transformer blocks.""" + if self.caption_projection is not None: + context = self.caption_projection(context) + context = context.view(batch_size, -1, x.shape[-1]) + + return context, attention_mask + + def _precompute_freqs_cis( + self, + indices_grid, + dim, + out_dtype, + theta=10000.0, + max_pos=[20, 2048, 2048], + use_middle_indices_grid=False, + num_attention_heads=32, + ): + split_mode = self.split_positional_embedding == LTXRopeType.SPLIT + indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device) + freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid) + + if split_mode: + expected_freqs = dim // 2 + current_freqs = freqs.shape[-1] + pad_size = expected_freqs - current_freqs + cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads) + else: + # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only + n_elem = 2 * indices_grid.shape[1] + cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem) + return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode + + def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype): + """Prepare positional embeddings.""" + fractional_coords = pixel_coords.to(torch.float32) + fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + pe = self._precompute_freqs_cis( + fractional_coords, + dim=self.inner_dim, + out_dtype=x_dtype, + max_pos=self.positional_embedding_max_pos, + use_middle_indices_grid=self.use_middle_indices_grid, + num_attention_heads=self.num_attention_heads, + ) + return pe + + def _prepare_attention_mask(self, attention_mask, x_dtype): + """Prepare attention mask.""" + if attention_mask is not None and not torch.is_floating_point(attention_mask): + attention_mask = (attention_mask - 1).to(x_dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ) * torch.finfo(x_dtype).max + return attention_mask + + def forward( + self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs + ): + """ + Forward pass for LTX models. + + Args: + x: Input tensor + timestep: Timestep tensor + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments + + Returns: + Processed output tensor + """ + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers( + comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options + ), + ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs) + + def _forward( + self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs + ): + """ + Internal forward pass for LTX models. + + Args: + x: Input tensor + timestep: Timestep tensor + context: Context tensor (e.g., text embeddings) + attention_mask: Attention mask tensor + frame_rate: Frame rate for temporal processing + transformer_options: Additional options for transformer blocks + keyframe_idxs: Keyframe indices for temporal processing + **kwargs: Additional keyword arguments + + Returns: + Processed output tensor + """ + if isinstance(x, list): + input_dtype = x[0].dtype + batch_size = x[0].shape[0] + else: + input_dtype = x.dtype + batch_size = x.shape[0] + # Process input + merged_args = {**transformer_options, **kwargs} + x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args) + merged_args.update(additional_args) + + # Prepare timestep and context + timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args) + context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask) + + # Prepare attention mask and positional embeddings + attention_mask = self._prepare_attention_mask(attention_mask, input_dtype) + pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype) + + # Process transformer blocks + x = self._process_transformer_blocks( + x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args + ) + + # Process output + x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args) + return x + + +class LTXVModel(LTXBaseModel): + """LTXV model for video generation.""" + + def __init__( + self, + in_channels=128, + cross_attention_dim=2048, + attention_head_dim=64, + num_attention_heads=32, + caption_channels=4096, + num_layers=28, + positional_embedding_theta=10000.0, + positional_embedding_max_pos=[20, 2048, 2048], + causal_temporal_positioning=False, + vae_scale_factors=(8, 32, 32), + use_middle_indices_grid=False, + timestep_scale_multiplier = 1000.0, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + super().__init__( + in_channels=in_channels, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + caption_channels=caption_channels, + num_layers=num_layers, + positional_embedding_theta=positional_embedding_theta, + positional_embedding_max_pos=positional_embedding_max_pos, + causal_temporal_positioning=causal_temporal_positioning, + vae_scale_factors=vae_scale_factors, + use_middle_indices_grid=use_middle_indices_grid, + timestep_scale_multiplier=timestep_scale_multiplier, + dtype=dtype, + device=device, + operations=operations, + **kwargs, + ) + + def _init_model_components(self, device, dtype, **kwargs): + """Initialize LTXV-specific components.""" + # No additional components needed for LTXV beyond base class + pass + + def _init_transformer_blocks(self, device, dtype, **kwargs): + """Initialize transformer blocks for LTXV.""" self.transformer_blocks = nn.ModuleList( [ BasicTransformerBlock( self.inner_dim, - num_attention_heads, - attention_head_dim, - context_dim=cross_attention_dim, - # attn_precision=attn_precision, - dtype=dtype, device=device, operations=operations + self.num_attention_heads, + self.attention_head_dim, + context_dim=self.cross_attention_dim, + dtype=dtype, + device=device, + operations=self.operations, ) - for d in range(num_layers) + for _ in range(self.num_layers) ] ) + def _init_output_components(self, device, dtype): + """Initialize output components for LTXV.""" self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device)) - self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device) - - self.patchifier = SymmetricPatchifier(1) - - def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): - return comfy.patcher_extension.WrapperExecutor.new_class_executor( - self._forward, - self, - comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, **kwargs) - - def _forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs): - patches_replace = transformer_options.get("patches_replace", {}) - - orig_shape = list(x.shape) + self.norm_out = self.operations.LayerNorm( + self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device + ) + self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device) + self.patchifier = SymmetricPatchifier(1, start_end=True) + def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs): + """Process input for LTXV.""" + additional_args = {"orig_shape": list(x.shape)} x, latent_coords = self.patchifier.patchify(x) pixel_coords = latent_to_pixel_coords( latent_coords=latent_coords, @@ -423,44 +880,30 @@ class LTXVModel(torch.nn.Module): causal_fix=self.causal_temporal_positioning, ) + grid_mask = None if keyframe_idxs is not None: - pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs + additional_args.update({ "orig_patchified_shape": list(x.shape)}) + denoise_mask = self.patchifier.patchify(denoise_mask)[0] + grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0] + additional_args.update({"grid_mask": grid_mask}) + x = x[:, grid_mask, :] + pixel_coords = pixel_coords[:, :, grid_mask, ...] - fractional_coords = pixel_coords.to(torch.float32) - fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate) + kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:] + keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] + pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs x = self.patchify_proj(x) - timestep = timestep * 1000.0 - - if attention_mask is not None and not torch.is_floating_point(attention_mask): - attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max - - pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype) - - batch_size = x.shape[0] - timestep, embedded_timestep = self.adaln_single( - timestep.flatten(), - {"resolution": None, "aspect_ratio": None}, - batch_size=batch_size, - hidden_dtype=x.dtype, - ) - # Second dimension is 1 or number of tokens (if timestep_per_token) - timestep = timestep.view(batch_size, -1, timestep.shape[-1]) - embedded_timestep = embedded_timestep.view( - batch_size, -1, embedded_timestep.shape[-1] - ) - - # 2. Blocks - if self.caption_projection is not None: - batch_size = x.shape[0] - context = self.caption_projection(context) - context = context.view( - batch_size, -1, x.shape[-1] - ) + return x, pixel_coords, additional_args + def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs): + """Process transformer blocks for LTXV.""" + patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) + for i, block in enumerate(self.transformer_blocks): if ("double_block", i) in blocks_replace: + def block_wrap(args): out = {} out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"]) @@ -478,16 +921,28 @@ class LTXVModel(torch.nn.Module): transformer_options=transformer_options, ) - # 3. Output + return x + + def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs): + """Process output for LTXV.""" + # Apply scale-shift modulation scale_shift_values = ( self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None] ) shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + x = self.norm_out(x) - # Modulation - x = torch.addcmul(x, x, scale).add_(shift) + x = x * (1 + scale) + shift x = self.proj_out(x) + if keyframe_idxs is not None: + grid_mask = kwargs["grid_mask"] + orig_patchified_shape = kwargs["orig_patchified_shape"] + full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device) + full_x[:, grid_mask, :] = x + x = full_x + # Unpatchify to restore original dimensions + orig_shape = kwargs["orig_shape"] x = self.patchifier.unpatchify( latents=x, output_height=orig_shape[3], diff --git a/comfy/ldm/lightricks/symmetric_patchifier.py b/comfy/ldm/lightricks/symmetric_patchifier.py index 4b9972b9f..8f9a41186 100644 --- a/comfy/ldm/lightricks/symmetric_patchifier.py +++ b/comfy/ldm/lightricks/symmetric_patchifier.py @@ -21,20 +21,23 @@ def latent_to_pixel_coords( Returns: Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates. """ + shape = [1] * latent_coords.ndim + shape[1] = -1 pixel_coords = ( latent_coords - * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None] + * torch.tensor(scale_factors, device=latent_coords.device).view(*shape) ) if causal_fix: # Fix temporal scale for first frame to 1 due to causality - pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0) + pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0) return pixel_coords class Patchifier(ABC): - def __init__(self, patch_size: int): + def __init__(self, patch_size: int, start_end: bool=False): super().__init__() self._patch_size = (1, patch_size, patch_size) + self.start_end = start_end @abstractmethod def patchify( @@ -71,11 +74,23 @@ class Patchifier(ABC): torch.arange(0, latent_width, self._patch_size[2], device=device), indexing="ij", ) - latent_sample_coords = torch.stack(latent_sample_coords, dim=0) - latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) - latent_coords = rearrange( - latent_coords, "b c f h w -> b c (f h w)", b=batch_size + latent_sample_coords_start = torch.stack(latent_sample_coords, dim=0) + delta = torch.tensor(self._patch_size, device=latent_sample_coords_start.device, dtype=latent_sample_coords_start.dtype)[:, None, None, None] + latent_sample_coords_end = latent_sample_coords_start + delta + + latent_sample_coords_start = latent_sample_coords_start.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_sample_coords_start = rearrange( + latent_sample_coords_start, "b c f h w -> b c (f h w)", b=batch_size ) + if self.start_end: + latent_sample_coords_end = latent_sample_coords_end.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + latent_sample_coords_end = rearrange( + latent_sample_coords_end, "b c f h w -> b c (f h w)", b=batch_size + ) + + latent_coords = torch.stack((latent_sample_coords_start, latent_sample_coords_end), dim=-1) + else: + latent_coords = latent_sample_coords_start return latent_coords @@ -115,3 +130,61 @@ class SymmetricPatchifier(Patchifier): q=self._patch_size[2], ) return latents + + +class AudioPatchifier(Patchifier): + def __init__(self, patch_size: int, + sample_rate=16000, + hop_length=160, + audio_latent_downsample_factor=4, + is_causal=True, + start_end=False, + shift = 0 + ): + super().__init__(patch_size, start_end=start_end) + self.hop_length = hop_length + self.sample_rate = sample_rate + self.audio_latent_downsample_factor = audio_latent_downsample_factor + self.is_causal = is_causal + self.shift = shift + + def copy_with_shift(self, shift): + return AudioPatchifier( + self.patch_size, self.sample_rate, self.hop_length, self.audio_latent_downsample_factor, + self.is_causal, self.start_end, shift + ) + + def _get_audio_latent_time_in_sec(self, start_latent, end_latent: int, dtype: torch.dtype, device=torch.device): + audio_latent_frame = torch.arange(start_latent, end_latent, dtype=dtype, device=device) + audio_mel_frame = audio_latent_frame * self.audio_latent_downsample_factor + if self.is_causal: + audio_mel_frame = (audio_mel_frame + 1 - self.audio_latent_downsample_factor).clip(min=0) + return audio_mel_frame * self.hop_length / self.sample_rate + + + def patchify(self, audio_latents: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + # audio_latents: (batch, channels, time, freq) + b, _, t, _ = audio_latents.shape + audio_latents = rearrange( + audio_latents, + "b c t f -> b t (c f)", + ) + + audio_latents_start_timings = self._get_audio_latent_time_in_sec(self.shift, t + self.shift, torch.float32, audio_latents.device) + audio_latents_start_timings = audio_latents_start_timings.unsqueeze(0).expand(b, -1).unsqueeze(1) + + if self.start_end: + audio_latents_end_timings = self._get_audio_latent_time_in_sec(self.shift + 1, t + self.shift + 1, torch.float32, audio_latents.device) + audio_latents_end_timings = audio_latents_end_timings.unsqueeze(0).expand(b, -1).unsqueeze(1) + + audio_latents_timings = torch.stack([audio_latents_start_timings, audio_latents_end_timings], dim=-1) + else: + audio_latents_timings = audio_latents_start_timings + return audio_latents, audio_latents_timings + + def unpatchify(self, audio_latents: torch.Tensor, channels: int, freq: int) -> torch.Tensor: + # audio_latents: (batch, time, freq * channels) + audio_latents = rearrange( + audio_latents, "b t (c f) -> b c t f", c=channels, f=freq + ) + return audio_latents diff --git a/comfy/ldm/lightricks/vae/audio_vae.py b/comfy/ldm/lightricks/vae/audio_vae.py new file mode 100644 index 000000000..a9111d3bd --- /dev/null +++ b/comfy/ldm/lightricks/vae/audio_vae.py @@ -0,0 +1,286 @@ +import json +from dataclasses import dataclass +import math +import torch +import torchaudio + +import comfy.model_management +import comfy.model_patcher +import comfy.utils as utils +from comfy.ldm.mmaudio.vae.distributions import DiagonalGaussianDistribution +from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier +from comfy.ldm.lightricks.vae.causal_audio_autoencoder import ( + CausalityAxis, + CausalAudioAutoencoder, +) +from comfy.ldm.lightricks.vocoders.vocoder import Vocoder + +LATENT_DOWNSAMPLE_FACTOR = 4 + + +@dataclass(frozen=True) +class AudioVAEComponentConfig: + """Container for model component configuration extracted from metadata.""" + + autoencoder: dict + vocoder: dict + + @classmethod + def from_metadata(cls, metadata: dict) -> "AudioVAEComponentConfig": + assert metadata is not None and "config" in metadata, "Metadata is required for audio VAE" + + raw_config = metadata["config"] + if isinstance(raw_config, str): + parsed_config = json.loads(raw_config) + else: + parsed_config = raw_config + + audio_config = parsed_config.get("audio_vae") + vocoder_config = parsed_config.get("vocoder") + + assert audio_config is not None, "Audio VAE config is required for audio VAE" + assert vocoder_config is not None, "Vocoder config is required for audio VAE" + + return cls(autoencoder=audio_config, vocoder=vocoder_config) + + +class ModelDeviceManager: + """Manages device placement and GPU residency for the composed model.""" + + def __init__(self, module: torch.nn.Module): + load_device = comfy.model_management.get_torch_device() + offload_device = comfy.model_management.vae_offload_device() + self.patcher = comfy.model_patcher.ModelPatcher(module, load_device, offload_device) + + def ensure_model_loaded(self) -> None: + comfy.model_management.free_memory( + self.patcher.model_size(), + self.patcher.load_device, + ) + comfy.model_management.load_model_gpu(self.patcher) + + def move_to_load_device(self, tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(self.patcher.load_device) + + @property + def load_device(self): + return self.patcher.load_device + + +class AudioLatentNormalizer: + """Applies per-channel statistics in patch space and restores original layout.""" + + def __init__(self, patchfier: AudioPatchifier, statistics_processor: torch.nn.Module): + self.patchifier = patchfier + self.statistics = statistics_processor + + def normalize(self, latents: torch.Tensor) -> torch.Tensor: + channels = latents.shape[1] + freq = latents.shape[3] + patched, _ = self.patchifier.patchify(latents) + normalized = self.statistics.normalize(patched) + return self.patchifier.unpatchify(normalized, channels=channels, freq=freq) + + def denormalize(self, latents: torch.Tensor) -> torch.Tensor: + channels = latents.shape[1] + freq = latents.shape[3] + patched, _ = self.patchifier.patchify(latents) + denormalized = self.statistics.un_normalize(patched) + return self.patchifier.unpatchify(denormalized, channels=channels, freq=freq) + + +class AudioPreprocessor: + """Prepares raw waveforms for the autoencoder by matching training conditions.""" + + def __init__(self, target_sample_rate: int, mel_bins: int, mel_hop_length: int, n_fft: int): + self.target_sample_rate = target_sample_rate + self.mel_bins = mel_bins + self.mel_hop_length = mel_hop_length + self.n_fft = n_fft + + def resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor: + if source_rate == self.target_sample_rate: + return waveform + return torchaudio.functional.resample(waveform, source_rate, self.target_sample_rate) + + @staticmethod + def normalize_amplitude( + waveform: torch.Tensor, max_amplitude: float = 0.5, eps: float = 1e-5 + ) -> torch.Tensor: + waveform = waveform - waveform.mean(dim=2, keepdim=True) + peak = torch.max(torch.abs(waveform)) + eps + scale = peak.clamp(max=max_amplitude) / peak + return waveform * scale + + def waveform_to_mel( + self, waveform: torch.Tensor, waveform_sample_rate: int, device + ) -> torch.Tensor: + waveform = self.resample(waveform, waveform_sample_rate) + waveform = self.normalize_amplitude(waveform) + + mel_transform = torchaudio.transforms.MelSpectrogram( + sample_rate=self.target_sample_rate, + n_fft=self.n_fft, + win_length=self.n_fft, + hop_length=self.mel_hop_length, + f_min=0.0, + f_max=self.target_sample_rate / 2.0, + n_mels=self.mel_bins, + window_fn=torch.hann_window, + center=True, + pad_mode="reflect", + power=1.0, + mel_scale="slaney", + norm="slaney", + ).to(device) + + mel = mel_transform(waveform) + mel = torch.log(torch.clamp(mel, min=1e-5)) + return mel.permute(0, 1, 3, 2).contiguous() + + +class AudioVAE(torch.nn.Module): + """High-level Audio VAE wrapper exposing encode and decode entry points.""" + + def __init__(self, state_dict: dict, metadata: dict): + super().__init__() + + component_config = AudioVAEComponentConfig.from_metadata(metadata) + + vae_sd = utils.state_dict_prefix_replace(state_dict, {"audio_vae.": ""}, filter_keys=True) + vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True) + + self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder) + self.vocoder = Vocoder(config=component_config.vocoder) + + self.autoencoder.load_state_dict(vae_sd, strict=False) + self.vocoder.load_state_dict(vocoder_sd, strict=False) + + autoencoder_config = self.autoencoder.get_config() + self.normalizer = AudioLatentNormalizer( + AudioPatchifier( + patch_size=1, + audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR, + sample_rate=autoencoder_config["sampling_rate"], + hop_length=autoencoder_config["mel_hop_length"], + is_causal=autoencoder_config["is_causal"], + ), + self.autoencoder.per_channel_statistics, + ) + + self.preprocessor = AudioPreprocessor( + target_sample_rate=autoencoder_config["sampling_rate"], + mel_bins=autoencoder_config["mel_bins"], + mel_hop_length=autoencoder_config["mel_hop_length"], + n_fft=autoencoder_config["n_fft"], + ) + + self.device_manager = ModelDeviceManager(self) + + def encode(self, audio: dict) -> torch.Tensor: + """Encode a waveform dictionary into normalized latent tensors.""" + + waveform = audio["waveform"] + waveform_sample_rate = audio["sample_rate"] + input_device = waveform.device + # Ensure that Audio VAE is loaded on the correct device. + self.device_manager.ensure_model_loaded() + + waveform = self.device_manager.move_to_load_device(waveform) + expected_channels = self.autoencoder.encoder.in_channels + if waveform.shape[1] != expected_channels: + raise ValueError( + f"Input audio must have {expected_channels} channels, got {waveform.shape[1]}" + ) + + mel_spec = self.preprocessor.waveform_to_mel( + waveform, waveform_sample_rate, device=self.device_manager.load_device + ) + + latents = self.autoencoder.encode(mel_spec) + posterior = DiagonalGaussianDistribution(latents) + latent_mode = posterior.mode() + + normalized = self.normalizer.normalize(latent_mode) + return normalized.to(input_device) + + def decode(self, latents: torch.Tensor) -> torch.Tensor: + """Decode normalized latent tensors into an audio waveform.""" + original_shape = latents.shape + + # Ensure that Audio VAE is loaded on the correct device. + self.device_manager.ensure_model_loaded() + + latents = self.device_manager.move_to_load_device(latents) + latents = self.normalizer.denormalize(latents) + + target_shape = self.target_shape_from_latents(original_shape) + mel_spec = self.autoencoder.decode(latents, target_shape=target_shape) + + waveform = self.run_vocoder(mel_spec) + return self.device_manager.move_to_load_device(waveform) + + def target_shape_from_latents(self, latents_shape): + batch, _, time, _ = latents_shape + target_length = time * LATENT_DOWNSAMPLE_FACTOR + if self.autoencoder.causality_axis != CausalityAxis.NONE: + target_length -= LATENT_DOWNSAMPLE_FACTOR - 1 + return ( + batch, + self.autoencoder.decoder.out_ch, + target_length, + self.autoencoder.mel_bins, + ) + + def num_of_latents_from_frames(self, frames_number: int, frame_rate: int) -> int: + return math.ceil((float(frames_number) / frame_rate) * self.latents_per_second) + + def run_vocoder(self, mel_spec: torch.Tensor) -> torch.Tensor: + audio_channels = self.autoencoder.decoder.out_ch + vocoder_input = mel_spec.transpose(2, 3) + + if audio_channels == 1: + vocoder_input = vocoder_input.squeeze(1) + elif audio_channels != 2: + raise ValueError(f"Unsupported audio_channels: {audio_channels}") + + return self.vocoder(vocoder_input) + + @property + def sample_rate(self) -> int: + return int(self.autoencoder.sampling_rate) + + @property + def mel_hop_length(self) -> int: + return int(self.autoencoder.mel_hop_length) + + @property + def mel_bins(self) -> int: + return int(self.autoencoder.mel_bins) + + @property + def latent_channels(self) -> int: + return int(self.autoencoder.decoder.z_channels) + + @property + def latent_frequency_bins(self) -> int: + return int(self.mel_bins // LATENT_DOWNSAMPLE_FACTOR) + + @property + def latents_per_second(self) -> float: + return self.sample_rate / self.mel_hop_length / LATENT_DOWNSAMPLE_FACTOR + + @property + def output_sample_rate(self) -> int: + output_rate = getattr(self.vocoder, "output_sample_rate", None) + if output_rate is not None: + return int(output_rate) + upsample_factor = getattr(self.vocoder, "upsample_factor", None) + if upsample_factor is None: + raise AttributeError( + "Vocoder is missing upsample_factor; cannot infer output sample rate" + ) + return int(self.sample_rate * upsample_factor / self.mel_hop_length) + + def memory_required(self, input_shape): + return self.device_manager.patcher.model_size() diff --git a/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py new file mode 100644 index 000000000..f12b9bb53 --- /dev/null +++ b/comfy/ldm/lightricks/vae/causal_audio_autoencoder.py @@ -0,0 +1,909 @@ +from __future__ import annotations +import torch +from torch import nn +from torch.nn import functional as F +from typing import Optional +from enum import Enum +from .pixel_norm import PixelNorm +import comfy.ops +import logging + +ops = comfy.ops.disable_weight_init + + +class StringConvertibleEnum(Enum): + """ + Base enum class that provides string-to-enum conversion functionality. + + This mixin adds a str_to_enum() class method that handles conversion from + strings, None, or existing enum instances with case-insensitive matching. + """ + + @classmethod + def str_to_enum(cls, value): + """ + Convert a string, enum instance, or None to the appropriate enum member. + + Args: + value: Can be an enum instance of this class, a string, or None + + Returns: + Enum member of this class + + Raises: + ValueError: If the value cannot be converted to a valid enum member + """ + # Already an enum instance of this class + if isinstance(value, cls): + return value + + # None maps to NONE member if it exists + if value is None: + if hasattr(cls, "NONE"): + return cls.NONE + raise ValueError(f"{cls.__name__} does not have a NONE member to map None to") + + # String conversion (case-insensitive) + if isinstance(value, str): + value_lower = value.lower() + + # Try to match against enum values + for member in cls: + # Handle members with None values + if member.value is None: + if value_lower == "none": + return member + # Handle members with string values + elif isinstance(member.value, str) and member.value.lower() == value_lower: + return member + + # Build helpful error message with valid values + valid_values = [] + for member in cls: + if member.value is None: + valid_values.append("none") + elif isinstance(member.value, str): + valid_values.append(member.value) + + raise ValueError(f"Invalid {cls.__name__} string: '{value}'. " f"Valid values are: {valid_values}") + + raise ValueError( + f"Cannot convert type {type(value).__name__} to {cls.__name__} enum. " + f"Expected string, None, or {cls.__name__} instance." + ) + + +class AttentionType(StringConvertibleEnum): + """Enum for specifying the attention mechanism type.""" + + VANILLA = "vanilla" + LINEAR = "linear" + NONE = "none" + + +class CausalityAxis(StringConvertibleEnum): + """Enum for specifying the causality axis in causal convolutions.""" + + NONE = None + WIDTH = "width" + HEIGHT = "height" + WIDTH_COMPATIBILITY = "width-compatibility" + + +def Normalize(in_channels, *, num_groups=32, normtype="group"): + if normtype == "group": + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif normtype == "pixel": + return PixelNorm(dim=1, eps=1e-6) + else: + raise ValueError(f"Invalid normalization type: {normtype}") + + +class CausalConv2d(nn.Module): + """ + A causal 2D convolution. + + This layer ensures that the output at time `t` only depends on inputs + at time `t` and earlier. It achieves this by applying asymmetric padding + to the time dimension (width) before the convolution. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + groups=1, + bias=True, + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ): + super().__init__() + + self.causality_axis = causality_axis + + # Ensure kernel_size and dilation are tuples + kernel_size = nn.modules.utils._pair(kernel_size) + dilation = nn.modules.utils._pair(dilation) + + # Calculate padding dimensions + pad_h = (kernel_size[0] - 1) * dilation[0] + pad_w = (kernel_size[1] - 1) * dilation[1] + + # The padding tuple for F.pad is (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.WIDTH | CausalityAxis.WIDTH_COMPATIBILITY: + self.padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2) + case CausalityAxis.HEIGHT: + self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0) + case _: + raise ValueError(f"Invalid causality_axis: {causality_axis}") + + # The internal convolution layer uses no padding, as we handle it manually + self.conv = ops.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + ) + + def forward(self, x): + # Apply causal padding before convolution + x = F.pad(x, self.padding) + return self.conv(x) + + +def make_conv2d( + in_channels, + out_channels, + kernel_size, + stride=1, + padding=None, + dilation=1, + groups=1, + bias=True, + causality_axis: Optional[CausalityAxis] = None, +): + """ + Create a 2D convolution layer that can be either causal or non-causal. + + Args: + in_channels: Number of input channels + out_channels: Number of output channels + kernel_size: Size of the convolution kernel + stride: Convolution stride + padding: Padding (if None, will be calculated based on causal flag) + dilation: Dilation rate + groups: Number of groups for grouped convolution + bias: Whether to use bias + causality_axis: Dimension along which to apply causality. + + Returns: + Either a regular Conv2d or CausalConv2d layer + """ + if causality_axis is not None: + # For causal convolution, padding is handled internally by CausalConv2d + return CausalConv2d(in_channels, out_channels, kernel_size, stride, dilation, groups, bias, causality_axis) + else: + # For non-causal convolution, use symmetric padding if not specified + if padding is None: + if isinstance(kernel_size, int): + padding = kernel_size // 2 + else: + padding = tuple(k // 2 for k in kernel_size) + return ops.Conv2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.HEIGHT): + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + if self.with_conv: + self.conv = make_conv2d(in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + # Drop FIRST element in the causal axis to undo encoder's padding, while keeping the length 1 + 2 * n. + # For example, if the input is [0, 1, 2], after interpolation, the output is [0, 0, 1, 1, 2, 2]. + # The causal convolution will pad the first element as [-, -, 0, 0, 1, 1, 2, 2], + # So the output elements rely on the following windows: + # 0: [-,-,0] + # 1: [-,0,0] + # 2: [0,0,1] + # 3: [0,1,1] + # 4: [1,1,2] + # 5: [1,2,2] + # Notice that the first and second elements in the output rely only on the first element in the input, + # while all other elements rely on two elements in the input. + # So we can drop the first element to undo the padding (rather than the last element). + # This is a no-op for non-causal convolutions. + match self.causality_axis: + case CausalityAxis.NONE: + pass # x remains unchanged + case CausalityAxis.HEIGHT: + x = x[:, :, 1:, :] + case CausalityAxis.WIDTH: + x = x[:, :, :, 1:] + case CausalityAxis.WIDTH_COMPATIBILITY: + pass # x remains unchanged + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + return x + + +class Downsample(nn.Module): + """ + A downsampling layer that can use either a strided convolution + or average pooling. Supports standard and causal padding for the + convolutional mode. + """ + + def __init__(self, in_channels, with_conv, causality_axis: CausalityAxis = CausalityAxis.WIDTH): + super().__init__() + self.with_conv = with_conv + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and not self.with_conv: + raise ValueError("causality is only supported when `with_conv=True`.") + + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = ops.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + if self.with_conv: + # (pad_left, pad_right, pad_top, pad_bottom) + match self.causality_axis: + case CausalityAxis.NONE: + pad = (0, 1, 0, 1) + case CausalityAxis.WIDTH: + pad = (2, 0, 0, 1) + case CausalityAxis.HEIGHT: + pad = (0, 1, 2, 0) + case CausalityAxis.WIDTH_COMPATIBILITY: + pad = (1, 0, 0, 1) + case _: + raise ValueError(f"Invalid causality_axis: {self.causality_axis}") + + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + # This branch is only taken if with_conv=False, which implies causality_axis is NONE. + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + norm_type="group", + causality_axis: CausalityAxis = CausalityAxis.HEIGHT, + ): + super().__init__() + self.causality_axis = causality_axis + + if self.causality_axis != CausalityAxis.NONE and norm_type == "group": + raise ValueError("Causal ResnetBlock with GroupNorm is not supported.") + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, normtype=norm_type) + self.non_linearity = nn.SiLU() + self.conv1 = make_conv2d(in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if temb_channels > 0: + self.temb_proj = ops.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels, normtype=norm_type) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = make_conv2d(out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis + ) + else: + self.nin_shortcut = make_conv2d( + in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = self.non_linearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = self.non_linearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels, norm_type="group"): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels, normtype=norm_type) + self.q = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.k = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.v = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + self.proj_out = ops.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v, w_).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla", norm_type="group"): + # Convert string to enum if needed + attn_type = AttentionType.str_to_enum(attn_type) + + if attn_type != AttentionType.NONE: + logging.info(f"making attention of type '{attn_type.value}' with {in_channels} in_channels") + else: + logging.info(f"making identity attention with {in_channels} in_channels") + + match attn_type: + case AttentionType.VANILLA: + return AttnBlock(in_channels, norm_type=norm_type) + case AttentionType.NONE: + return nn.Identity(in_channels) + case AttentionType.LINEAR: + raise NotImplementedError(f"Attention type {attn_type.value} is not supported yet.") + case _: + raise ValueError(f"Unknown attention type: {attn_type}") + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + attn_type="vanilla", + mid_block_add_attention=True, + norm_type="group", + causality_axis=CausalityAxis.WIDTH.value, + **ignore_kwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.z_channels = z_channels + self.double_z = double_z + self.norm_type = norm_type + # Convert string to enum if needed (for config loading) + causality_axis = CausalityAxis.str_to_enum(causality_axis) + self.attn_type = AttentionType.str_to_enum(attn_type) + + # downsampling + self.conv_in = make_conv2d( + in_channels, + self.ch, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + + self.non_linearity = nn.SiLU() + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + + for _ in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)) + + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + + # end + self.norm_out = Normalize(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + causality_axis=causality_axis, + ) + + def forward(self, x): + """ + Forward pass through the encoder. + + Args: + x: Input tensor of shape [batch, channels, time, n_mels] + + Returns: + Encoded latent representation + """ + feature_maps = [self.conv_in(x)] + + # Process each resolution level (from high to low resolution) + for resolution_level in range(self.num_resolutions): + # Apply residual blocks at current resolution level + for block_idx in range(self.num_res_blocks): + # Apply ResNet block with optional timestep embedding + current_features = self.down[resolution_level].block[block_idx](feature_maps[-1], temb=None) + + # Apply attention if configured for this resolution level + if len(self.down[resolution_level].attn) > 0: + current_features = self.down[resolution_level].attn[block_idx](current_features) + + # Store processed features + feature_maps.append(current_features) + + # Downsample spatial dimensions (except at the final resolution level) + if resolution_level != self.num_resolutions - 1: + downsampled_features = self.down[resolution_level].downsample(feature_maps[-1]) + feature_maps.append(downsampled_features) + + # === MIDDLE PROCESSING PHASE === + # Take the lowest resolution features for middle processing + bottleneck_features = feature_maps[-1] + + # Apply first middle ResNet block + bottleneck_features = self.mid.block_1(bottleneck_features, temb=None) + + # Apply middle attention block + bottleneck_features = self.mid.attn_1(bottleneck_features) + + # Apply second middle ResNet block + bottleneck_features = self.mid.block_2(bottleneck_features, temb=None) + + # === OUTPUT PHASE === + # Normalize the bottleneck features + output_features = self.norm_out(bottleneck_features) + + # Apply non-linearity (SiLU activation) + output_features = self.non_linearity(output_features) + + # Final convolution to produce latent representation + # [batch, channels, time, n_mels] -> [batch, 2 * z_channels if double_z else z_channels, time, n_mels] + return self.conv_out(output_features) + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + attn_type="vanilla", + mid_block_add_attention=True, + norm_type="group", + causality_axis=CausalityAxis.WIDTH.value, + **ignorekwargs, + ): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.out_ch = out_ch + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.norm_type = norm_type + self.z_channels = z_channels + # Convert string to enum if needed (for config loading) + causality_axis = CausalityAxis.str_to_enum(causality_axis) + self.attn_type = AttentionType.str_to_enum(attn_type) + + # compute block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = make_conv2d(z_channels, block_in, kernel_size=3, stride=1, causality_axis=causality_axis) + + self.non_linearity = nn.SiLU() + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + if mid_block_add_attention: + self.mid.attn_1 = make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type) + else: + self.mid.attn_1 = nn.Identity() + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + norm_type=self.norm_type, + causality_axis=causality_axis, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=self.attn_type, norm_type=self.norm_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv, causality_axis=causality_axis) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in, normtype=self.norm_type) + self.conv_out = make_conv2d(block_in, out_ch, kernel_size=3, stride=1, causality_axis=causality_axis) + + def _adjust_output_shape(self, decoded_output, target_shape): + """ + Adjust output shape to match target dimensions for variable-length audio. + + This function handles the common case where decoded audio spectrograms need to be + resized to match a specific target shape. + + Args: + decoded_output: Tensor of shape (batch, channels, time, frequency) + target_shape: Target shape tuple (batch, channels, time, frequency) + + Returns: + Tensor adjusted to match target_shape exactly + """ + # Current output shape: (batch, channels, time, frequency) + _, _, current_time, current_freq = decoded_output.shape + _, target_channels, target_time, target_freq = target_shape + + # Step 1: Crop first to avoid exceeding target dimensions + decoded_output = decoded_output[ + :, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq) + ] + + # Step 2: Calculate padding needed for time and frequency dimensions + time_padding_needed = target_time - decoded_output.shape[2] + freq_padding_needed = target_freq - decoded_output.shape[3] + + # Step 3: Apply padding if needed + if time_padding_needed > 0 or freq_padding_needed > 0: + # PyTorch padding format: (pad_left, pad_right, pad_top, pad_bottom) + # For audio: pad_left/right = frequency, pad_top/bottom = time + padding = ( + 0, + max(freq_padding_needed, 0), # frequency padding (left, right) + 0, + max(time_padding_needed, 0), # time padding (top, bottom) + ) + decoded_output = F.pad(decoded_output, padding) + + # Step 4: Final safety crop to ensure exact target shape + decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq] + + return decoded_output + + def get_config(self): + return { + "ch": self.ch, + "out_ch": self.out_ch, + "ch_mult": self.ch_mult, + "num_res_blocks": self.num_res_blocks, + "in_channels": self.in_channels, + "resolution": self.resolution, + "z_channels": self.z_channels, + } + + def forward(self, latent_features, target_shape=None): + """ + Decode latent features back to audio spectrograms. + + Args: + latent_features: Encoded latent representation of shape (batch, channels, height, width) + target_shape: Optional target output shape (batch, channels, time, frequency) + If provided, output will be cropped/padded to match this shape + + Returns: + Reconstructed audio spectrogram of shape (batch, channels, time, frequency) + """ + assert target_shape is not None, "Target shape is required for CausalAudioAutoencoder Decoder" + + # Transform latent features to decoder's internal feature dimension + hidden_features = self.conv_in(latent_features) + + # Middle processing + hidden_features = self.mid.block_1(hidden_features, temb=None) + hidden_features = self.mid.attn_1(hidden_features) + hidden_features = self.mid.block_2(hidden_features, temb=None) + + # Upsampling + # Progressively increase spatial resolution from lowest to highest + for resolution_level in reversed(range(self.num_resolutions)): + # Apply residual blocks at current resolution level + for block_index in range(self.num_res_blocks + 1): + hidden_features = self.up[resolution_level].block[block_index](hidden_features, temb=None) + + if len(self.up[resolution_level].attn) > 0: + hidden_features = self.up[resolution_level].attn[block_index](hidden_features) + + if resolution_level != 0: + hidden_features = self.up[resolution_level].upsample(hidden_features) + + # Output + if self.give_pre_end: + # Return intermediate features before final processing (for debugging/analysis) + decoded_output = hidden_features + else: + # Standard output path: normalize, activate, and convert to output channels + # Final normalization layer + hidden_features = self.norm_out(hidden_features) + + # Apply SiLU (Swish) activation function + hidden_features = self.non_linearity(hidden_features) + + # Final convolution to map to output channels (typically 2 for stereo audio) + decoded_output = self.conv_out(hidden_features) + + # Optional tanh activation to bound output values to [-1, 1] range + if self.tanh_out: + decoded_output = torch.tanh(decoded_output) + + # Adjust shape for audio data + if target_shape is not None: + decoded_output = self._adjust_output_shape(decoded_output, target_shape) + + return decoded_output + + +class processor(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("std-of-means", torch.empty(128)) + self.register_buffer("mean-of-means", torch.empty(128)) + + def un_normalize(self, x): + return (x * self.get_buffer("std-of-means").to(x)) + self.get_buffer("mean-of-means").to(x) + + def normalize(self, x): + return (x - self.get_buffer("mean-of-means").to(x)) / self.get_buffer("std-of-means").to(x) + + +class CausalAudioAutoencoder(nn.Module): + def __init__(self, config=None): + super().__init__() + + if config is None: + config = self._guess_config() + + # Extract encoder and decoder configs from the new format + model_config = config.get("model", {}).get("params", {}) + variables_config = config.get("variables", {}) + + self.sampling_rate = variables_config.get( + "sampling_rate", + model_config.get("sampling_rate", config.get("sampling_rate", 16000)), + ) + encoder_config = model_config.get("encoder", model_config.get("ddconfig", {})) + decoder_config = model_config.get("decoder", encoder_config) + + # Load mel spectrogram parameters + self.mel_bins = encoder_config.get("mel_bins", 64) + self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160) + self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024) + + # Store causality configuration at VAE level (not just in encoder internals) + causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value) + self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value) + self.is_causal = self.causality_axis == CausalityAxis.HEIGHT + + self.encoder = Encoder(**encoder_config) + self.decoder = Decoder(**decoder_config) + + self.per_channel_statistics = processor() + + def _guess_config(self): + encoder_config = { + # Required parameters - based on ltx-video-av-1679000 model metadata + "ch": 128, + "out_ch": 8, + "ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8] + "num_res_blocks": 2, + "attn_resolutions": [], # Based on metadata: empty list, no attention + "dropout": 0.0, + "resamp_with_conv": True, + "in_channels": 2, # stereo + "resolution": 256, + "z_channels": 8, + "double_z": True, + "attn_type": "vanilla", + "mid_block_add_attention": False, # Based on metadata: false + "norm_type": "pixel", + "causality_axis": "height", # Based on metadata + "mel_bins": 64, # Based on metadata: mel_bins = 64 + } + + decoder_config = { + # Inherits encoder config, can override specific params + **encoder_config, + "out_ch": 2, # Stereo audio output (2 channels) + "give_pre_end": False, + "tanh_out": False, + } + + config = { + "_class_name": "CausalAudioAutoencoder", + "sampling_rate": 16000, + "model": { + "params": { + "encoder": encoder_config, + "decoder": decoder_config, + } + }, + } + + return config + + def get_config(self): + return { + "sampling_rate": self.sampling_rate, + "mel_bins": self.mel_bins, + "mel_hop_length": self.mel_hop_length, + "n_fft": self.n_fft, + "causality_axis": self.causality_axis.value, + "is_causal": self.is_causal, + } + + def encode(self, x): + return self.encoder(x) + + def decode(self, x, target_shape=None): + return self.decoder(x, target_shape=target_shape) diff --git a/comfy/ldm/lightricks/vocoders/vocoder.py b/comfy/ldm/lightricks/vocoders/vocoder.py new file mode 100644 index 000000000..b1f15f2c5 --- /dev/null +++ b/comfy/ldm/lightricks/vocoders/vocoder.py @@ -0,0 +1,213 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +import comfy.ops +import numpy as np + +ops = comfy.ops.disable_weight_init + +LRELU_SLOPE = 0.1 + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ), + ] + ) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList( + [ + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ), + ops.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ), + ] + ) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + +class Vocoder(torch.nn.Module): + """ + Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan. + + """ + + def __init__(self, config=None): + super(Vocoder, self).__init__() + + if config is None: + config = self.get_default_config() + + resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11]) + upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2]) + upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4]) + resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_initial_channel = config.get("upsample_initial_channel", 1024) + stereo = config.get("stereo", True) + resblock = config.get("resblock", "1") + + self.output_sample_rate = config.get("output_sample_rate") + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + in_channels = 128 if stereo else 64 + self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3) + resblock_class = ResBlock1 if resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + ops.ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock_class(ch, k, d)) + + out_channels = 2 if stereo else 1 + self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3) + + self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))]) + + def get_default_config(self): + """Generate default configuration for the vocoder.""" + + config = { + "resblock_kernel_sizes": [3, 7, 11], + "upsample_rates": [6, 5, 2, 2, 2], + "upsample_kernel_sizes": [16, 15, 8, 4, 4], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "upsample_initial_channel": 1024, + "stereo": True, + "resblock": "1", + } + + return config + + def forward(self, x): + """ + Forward pass of the vocoder. + + Args: + x: Input spectrogram tensor. Can be: + - 3D: (batch_size, channels, time_steps) for mono + - 4D: (batch_size, 2, channels, time_steps) for stereo + + Returns: + Audio tensor of shape (batch_size, out_channels, audio_length) + """ + if x.dim() == 4: # stereo + assert x.shape[1] == 2, "Input must have 2 channels for stereo" + x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1) + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index c4f3c0639..49efd700b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -20,6 +20,7 @@ import comfy.ldm.hunyuan3dv2_1 import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging +import comfy.ldm.lightricks.av_model from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_b import StageB @@ -946,7 +947,7 @@ class GenmoMochi(BaseModel): class LTXV(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): - super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) @@ -977,6 +978,60 @@ class LTXV(BaseModel): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class LTXAV(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.av_model.LTXAVModel) #TODO + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25)) + + denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + + audio_denoise_mask = None + if denoise_mask is not None and "latent_shapes" in kwargs: + denoise_mask = utils.unpack_latents(denoise_mask, kwargs["latent_shapes"]) + if len(denoise_mask) > 1: + audio_denoise_mask = denoise_mask[1] + denoise_mask = denoise_mask[0] + + if denoise_mask is not None: + out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask) + + if audio_denoise_mask is not None: + out["audio_denoise_mask"] = comfy.conds.CONDRegular(audio_denoise_mask) + + keyframe_idxs = kwargs.get("keyframe_idxs", None) + if keyframe_idxs is not None: + out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs) + + latent_shapes = kwargs.get("latent_shapes", None) + if latent_shapes is not None: + out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) + + return out + + def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): + v_timestep = timestep + a_timestep = timestep + + if denoise_mask is not None: + v_timestep = self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0] + if audio_denoise_mask is not None: + a_timestep = self.diffusion_model.a_patchifier.patchify(((audio_denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (audio_denoise_mask.ndim - 1)))[:, :1, :, :1])[0] + + return v_timestep, a_timestep + + def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): + return latent_image + class HunyuanVideo(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 539e296ed..0853b3aec 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -305,7 +305,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv dit_config = {} - dit_config["image_model"] = "ltxv" + dit_config["image_model"] = "ltxav" if f'{key_prefix}audio_adaln_single.linear.weight' in state_dict_keys else "ltxv" dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape dit_config["attention_head_dim"] = shape[0] // 32 diff --git a/comfy/sd.py b/comfy/sd.py index 7de7dd9c6..32157e18b 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1041,7 +1041,8 @@ class TEModel(Enum): MISTRAL3_24B_PRUNED_FLUX2 = 15 QWEN3_4B = 16 QWEN3_2B = 17 - JINA_CLIP_2 = 18 + GEMMA_3_12B = 18 + JINA_CLIP_2 = 19 def detect_te_model(sd): @@ -1067,6 +1068,8 @@ def detect_te_model(sd): return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE if 'model.layers.0.post_feedforward_layernorm.weight' in sd: + if 'model.layers.47.self_attn.q_norm.weight' in sd: + return TEModel.GEMMA_3_12B if 'model.layers.0.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B @@ -1271,6 +1274,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.KANDINSKY5_IMAGE: clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage + elif clip_type == CLIPType.LTXV: + clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.NEWBIE: clip_target.clip = comfy.text_encoders.newbie.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.newbie.NewBieTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1888f35ba..ee9a79001 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -836,6 +836,21 @@ class LTXV(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect)) +class LTXAV(LTXV): + unet_config = { + "image_model": "ltxav", + } + + latent_format = latent_formats.LTXAV + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = 0.055 # TODO + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.LTXAV(self, device=device) + return out + class HunyuanVideo(supported_models_base.BASE): unet_config = { "image_model": "hunyuan_video", @@ -1536,6 +1551,6 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] models += [SVD_img2vid] diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index faa4e1de8..76731576b 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -7,6 +7,7 @@ import math from comfy.ldm.modules.attention import optimized_attention_for_device import comfy.model_management import comfy.ldm.common_dit +import comfy.clip_model from . import qwen_vl @@ -188,6 +189,31 @@ class Gemma3_4B_Config: rope_scale = [8.0, 1.0] final_norm: bool = True +@dataclass +class Gemma3_12B_Config: + vocab_size: int = 262208 + hidden_size: int = 3840 + intermediate_size: int = 15360 + num_hidden_layers: int = 48 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 131072 + rms_norm_eps: float = 1e-6 + rope_theta = [1000000.0, 10000.0] + transformer_type: str = "gemma3" + head_dim = 256 + rms_norm_add = True + mlp_activation = "gelu_pytorch_tanh" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + sliding_attention = [1024, 1024, 1024, 1024, 1024, False] + rope_scale = [8.0, 1.0] + final_norm: bool = True + vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} + mm_tokens_per_image = 256 + class RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None): super().__init__() @@ -520,6 +546,41 @@ class Llama2_(nn.Module): return x, intermediate + +class Gemma3MultiModalProjector(torch.nn.Module): + def __init__(self, config, dtype, device, operations): + super().__init__() + + self.mm_input_projection_weight = nn.Parameter( + torch.empty(config.vision_config["hidden_size"], config.hidden_size, device=device, dtype=dtype) + ) + + self.mm_soft_emb_norm = RMSNorm(config.vision_config["hidden_size"], eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + + self.patches_per_image = int(config.vision_config["image_size"] // config.vision_config["patch_size"]) + self.tokens_per_side = int(config.mm_tokens_per_image**0.5) + self.kernel_size = self.patches_per_image // self.tokens_per_side + self.avg_pool = nn.AvgPool2d(kernel_size=self.kernel_size, stride=self.kernel_size) + + def forward(self, vision_outputs: torch.Tensor): + batch_size, _, seq_length = vision_outputs.shape + + reshaped_vision_outputs = vision_outputs.transpose(1, 2) + reshaped_vision_outputs = reshaped_vision_outputs.reshape( + batch_size, seq_length, self.patches_per_image, self.patches_per_image + ) + reshaped_vision_outputs = reshaped_vision_outputs.contiguous() + + pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs) + pooled_vision_outputs = pooled_vision_outputs.flatten(2) + pooled_vision_outputs = pooled_vision_outputs.transpose(1, 2) + + normed_vision_outputs = self.mm_soft_emb_norm(pooled_vision_outputs) + + projected_vision_outputs = torch.matmul(normed_vision_outputs, comfy.model_management.cast_to_device(self.mm_input_projection_weight, device=normed_vision_outputs.device, dtype=normed_vision_outputs.dtype)) + return projected_vision_outputs.type_as(vision_outputs) + + class BaseLlama: def get_input_embeddings(self): return self.model.embed_tokens @@ -636,3 +697,21 @@ class Gemma3_4B(BaseLlama, torch.nn.Module): self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.dtype = dtype + +class Gemma3_12B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Gemma3_12B_Config(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations) + self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations) + self.dtype = dtype + self.image_size = config.vision_config["image_size"] + + def preprocess_embed(self, embed, device): + if embed["type"] == "image": + image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True) + return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None + return None, None diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 48ea67e67..2c2d453e8 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -1,7 +1,11 @@ from comfy import sd1_clip import os from transformers import T5TokenizerFast +from .spiece_tokenizer import SPieceTokenizer import comfy.text_encoders.genmo +from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector +import torch +import comfy.utils class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): @@ -16,3 +20,110 @@ class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer): def ltxv_te(*args, **kwargs): return comfy.text_encoders.genmo.mochi_te(*args, **kwargs) + + +class Gemma3_12BTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer = tokenizer_data.get("spiece_model", None) + super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + +class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer) + +class Gemma3_12BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_scaled_fp8 = model_options.get("gemma_scaled_fp8", None) + if llama_scaled_fp8 is not None: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs): + text = llama_template.format(text) + text_tokens = super().tokenize_with_weights(text, return_word_ids) + embed_count = 0 + for k in text_tokens: + tt = text_tokens[k] + for r in tt: + for i in range(len(r)): + if r[i][0] == 262144: + if image_embeds is not None and embed_count < image_embeds.shape[0]: + r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:] + embed_count += 1 + return text_tokens + +class LTXAVTEModel(torch.nn.Module): + def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): + super().__init__() + self.dtypes = set() + self.dtypes.add(dtype) + + self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None) + self.dtypes.add(dtype_llama) + + operations = self.gemma3_12b.operations # TODO + self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device) + + self.audio_embeddings_connector = Embeddings1DConnector( + split_rope=True, + double_precision_rope=True, + dtype=dtype, + device=device, + operations=operations, + ) + + self.video_embeddings_connector = Embeddings1DConnector( + split_rope=True, + double_precision_rope=True, + dtype=dtype, + device=device, + operations=operations, + ) + + def set_clip_options(self, options): + self.gemma3_12b.set_clip_options(options) + + def reset_clip_options(self): + self.gemma3_12b.reset_clip_options() + + def encode_token_weights(self, token_weight_pairs): + token_weight_pairs = token_weight_pairs["gemma3_12b"] + + out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) + out_device = out.device + out = out.movedim(1, -1).to(self.text_embedding_projection.weight.device) + out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6) + out = out.reshape((out.shape[0], out.shape[1], -1)) + out = self.text_embedding_projection(out) + out_vid = self.video_embeddings_connector(out)[0] + out_audio = self.audio_embeddings_connector(out)[0] + out = torch.concat((out_vid, out_audio), dim=-1) + + return out.to(out_device), pooled + + def load_sd(self, sd): + if "model.layers.47.self_attn.q_norm.weight" in sd: + return self.gemma3_12b.load_sd(sd) + else: + sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "model.diffusion_model.video_embeddings_connector.": "video_embeddings_connector.", "model.diffusion_model.audio_embeddings_connector.": "audio_embeddings_connector."}, filter_keys=True) + if len(sdo) == 0: + sdo = sd + + return self.load_state_dict(sdo, strict=False) + + +def ltxav_te(dtype_llama=None, llama_scaled_fp8=None): + class LTXAVTEModel_(LTXAVTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["llama_scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) + return LTXAVTEModel_ diff --git a/comfy/utils.py b/comfy/utils.py index e4162d7ac..ffa98c9b1 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1198,7 +1198,7 @@ def unpack_latents(combined_latent, latent_shapes): combined_latent = combined_latent[:, :, cut:] output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:])) else: - output_tensors = combined_latent + output_tensors = [combined_latent] return output_tensors def detect_layer_quantization(state_dict, prefix): diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index c7916443c..94ad5e8a8 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -112,7 +112,7 @@ class VAEDecodeAudio(IO.ComfyNode): std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - return IO.NodeOutput({"waveform": audio, "sample_rate": 44100}) + return IO.NodeOutput({"waveform": audio, "sample_rate": 44100 if "sample_rate" not in samples else samples["sample_rate"]}) decode = execute # TODO: remove diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 32be182f1..ceff657d3 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -5,7 +5,9 @@ import comfy.model_management from typing_extensions import override from comfy_api.latest import ComfyExtension, io from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel +from comfy.ldm.lightricks.latent_upsampler import LatentUpsampler import folder_paths +import json class CLIPTextEncodeHunyuanDiT(io.ComfyNode): @classmethod @@ -186,7 +188,7 @@ class LatentUpscaleModelLoader(io.ComfyNode): @classmethod def execute(cls, model_name) -> io.NodeOutput: model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name) - sd = comfy.utils.load_torch_file(model_path, safe_load=True) + sd, metadata = comfy.utils.load_torch_file(model_path, safe_load=True, return_metadata=True) if "blocks.0.block.0.conv.weight" in sd: config = { @@ -197,6 +199,8 @@ class LatentUpscaleModelLoader(io.ComfyNode): "global_residual": False, } model_type = "720p" + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) elif "up.0.block.0.conv1.conv.weight" in sd: sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()} config = { @@ -205,9 +209,12 @@ class LatentUpscaleModelLoader(io.ComfyNode): "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))), } model_type = "1080p" - - model = HunyuanVideo15SRModel(model_type, config) - model.load_sd(sd) + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) + elif "post_upsample_res_blocks.0.conv2.bias" in sd: + config = json.loads(metadata["config"]) + model = LatentUpsampler.from_config(config).to(dtype=comfy.model_management.vae_dtype(allowed_dtypes=[torch.bfloat16, torch.float32])) + model.load_state_dict(sd) return io.NodeOutput(model) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 50da5f4eb..b91a22309 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -81,6 +81,59 @@ class LTXVImgToVideo(io.ComfyNode): generate = execute # TODO: remove +class LTXVImgToVideoInplace(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVImgToVideoInplace", + category="conditioning/video_models", + inputs=[ + io.Vae.Input("vae"), + io.Image.Input("image"), + io.Latent.Input("latent"), + io.Float.Input("strength", default=1.0, min=0.0, max=1.0), + io.Boolean.Input("bypass", default=False, tooltip="Bypass the conditioning.") + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput: + if bypass: + return (latent,) + + samples = latent["samples"] + _, height_scale_factor, width_scale_factor = ( + vae.downscale_index_formula + ) + + batch, _, latent_frames, latent_height, latent_width = samples.shape + width = latent_width * width_scale_factor + height = latent_height * height_scale_factor + + if image.shape[1] != height or image.shape[2] != width: + pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + else: + pixels = image + encode_pixels = pixels[:, :, :, :3] + t = vae.encode(encode_pixels) + + samples[:, :, :t.shape[2]] = t + + conditioning_latent_frames_mask = torch.ones( + (batch, 1, latent_frames, 1, 1), + dtype=torch.float32, + device=samples.device, + ) + conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength + + return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask}) + + generate = execute # TODO: remove + + def conditioning_get_any_value(conditioning, key, default=None): for t in conditioning: if key in t[1]: @@ -106,12 +159,12 @@ def get_keyframe_idxs(cond): keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None) if keyframe_idxs is None: return None, 0 - num_keyframes = torch.unique(keyframe_idxs[:, 0]).shape[0] + # keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start + num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0] return keyframe_idxs, num_keyframes class LTXVAddGuide(io.ComfyNode): - NUM_PREFIX_FRAMES = 2 - PATCHIFIER = SymmetricPatchifier(1) + PATCHIFIER = SymmetricPatchifier(1, start_end=True) @classmethod def define_schema(cls): @@ -182,26 +235,35 @@ class LTXVAddGuide(io.ComfyNode): return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs}) @classmethod - def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors): - _, latent_idx = cls.get_latent_index( - cond=positive, - latent_length=latent_image.shape[2], - guide_length=guiding_latent.shape[2], - frame_idx=frame_idx, - scale_factors=scale_factors, - ) - noise_mask[:, :, latent_idx:latent_idx + guiding_latent.shape[2]] = 1.0 + def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128): + if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels: + raise ValueError("Adding guide to a combined AV latent is not supported.") positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors) negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors) - mask = torch.full( - (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), - 1.0 - strength, - dtype=noise_mask.dtype, - device=noise_mask.device, - ) + if guide_mask is not None: + target_h = max(noise_mask.shape[3], guide_mask.shape[3]) + target_w = max(noise_mask.shape[4], guide_mask.shape[4]) + if noise_mask.shape[3] == 1 or noise_mask.shape[4] == 1: + noise_mask = noise_mask.expand(-1, -1, -1, target_h, target_w) + + if guide_mask.shape[3] == 1 or guide_mask.shape[4] == 1: + guide_mask = guide_mask.expand(-1, -1, -1, target_h, target_w) + mask = guide_mask - strength + else: + mask = torch.full( + (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), + 1.0 - strength, + dtype=noise_mask.dtype, + device=noise_mask.device, + ) + # This solves audio video combined latent case where latent_image has audio latent concatenated + # in channel dimension with video latent. The solution is to pad guiding latent accordingly. + if latent_image.shape[1] > guiding_latent.shape[1]: + pad_len = latent_image.shape[1] - guiding_latent.shape[1] + guiding_latent = torch.nn.functional.pad(guiding_latent, pad=(0, 0, 0, 0, 0, 0, 0, pad_len), value=0) latent_image = torch.cat([latent_image, guiding_latent], dim=2) noise_mask = torch.cat([noise_mask, mask], dim=2) return positive, negative, latent_image, noise_mask @@ -238,33 +300,17 @@ class LTXVAddGuide(io.ComfyNode): frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." - num_prefix_frames = min(cls.NUM_PREFIX_FRAMES, t.shape[2]) - positive, negative, latent_image, noise_mask = cls.append_keyframe( positive, negative, frame_idx, latent_image, noise_mask, - t[:, :, :num_prefix_frames], + t, strength, scale_factors, ) - latent_idx += num_prefix_frames - - t = t[:, :, num_prefix_frames:] - if t.shape[2] == 0: - return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) - - latent_image, noise_mask = cls.replace_latent_frames( - latent_image, - noise_mask, - t, - latent_idx, - strength, - ) - return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask}) generate = execute # TODO: remove @@ -507,18 +553,90 @@ class LTXVPreprocess(io.ComfyNode): preprocess = execute # TODO: remove + +import comfy.nested_tensor +class LTXVConcatAVLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVConcatAVLatent", + category="latent/video/ltxv", + inputs=[ + io.Latent.Input("video_latent"), + io.Latent.Input("audio_latent"), + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, video_latent, audio_latent) -> io.NodeOutput: + output = {} + output.update(video_latent) + output.update(audio_latent) + video_noise_mask = video_latent.get("noise_mask", None) + audio_noise_mask = audio_latent.get("noise_mask", None) + + if video_noise_mask is not None or audio_noise_mask is not None: + if video_noise_mask is None: + video_noise_mask = torch.ones_like(video_latent["samples"]) + if audio_noise_mask is None: + audio_noise_mask = torch.ones_like(audio_latent["samples"]) + output["noise_mask"] = comfy.nested_tensor.NestedTensor((video_noise_mask, audio_noise_mask)) + + output["samples"] = comfy.nested_tensor.NestedTensor((video_latent["samples"], audio_latent["samples"])) + + return io.NodeOutput(output) + + +class LTXVSeparateAVLatent(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LTXVSeparateAVLatent", + category="latent/video/ltxv", + description="LTXV Separate AV Latent", + inputs=[ + io.Latent.Input("av_latent"), + ], + outputs=[ + io.Latent.Output(display_name="video_latent"), + io.Latent.Output(display_name="audio_latent"), + ], + ) + + @classmethod + def execute(cls, av_latent) -> io.NodeOutput: + latents = av_latent["samples"].unbind() + video_latent = av_latent.copy() + video_latent["samples"] = latents[0] + audio_latent = av_latent.copy() + audio_latent["samples"] = latents[1] + if "noise_mask" in av_latent: + masks = av_latent["noise_mask"] + if masks is not None: + masks = masks.unbind() + video_latent["noise_mask"] = masks[0] + audio_latent["noise_mask"] = masks[1] + return io.NodeOutput(video_latent, audio_latent) + + class LtxvExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ EmptyLTXVLatentVideo, LTXVImgToVideo, + LTXVImgToVideoInplace, ModelSamplingLTXV, LTXVConditioning, LTXVScheduler, LTXVAddGuide, LTXVPreprocess, LTXVCropGuides, + LTXVConcatAVLatent, + LTXVSeparateAVLatent, ] diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py new file mode 100644 index 000000000..b0b7000ef --- /dev/null +++ b/comfy_extras/nodes_lt_audio.py @@ -0,0 +1,183 @@ +import folder_paths +import comfy.utils +import comfy.model_management +import torch + +from comfy.ldm.lightricks.vae.audio_vae import AudioVAE +from comfy_api.latest import ComfyExtension, io + + +class LTXVAudioVAELoader(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAELoader", + display_name="LTXV Audio VAE Loader", + category="audio", + inputs=[ + io.Combo.Input( + "ckpt_name", + options=folder_paths.get_filename_list("checkpoints"), + tooltip="Audio VAE checkpoint to load.", + ) + ], + outputs=[io.Vae.Output(display_name="Audio VAE")], + ) + + @classmethod + def execute(cls, ckpt_name: str) -> io.NodeOutput: + ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) + sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) + return io.NodeOutput(AudioVAE(sd, metadata)) + + +class LTXVAudioVAEEncode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAEEncode", + display_name="LTXV Audio VAE Encode", + category="audio", + inputs=[ + io.Audio.Input("audio", tooltip="The audio to be encoded."), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model to use for encoding.", + ), + ], + outputs=[io.Latent.Output(display_name="Audio Latent")], + ) + + @classmethod + def execute(cls, audio, audio_vae: AudioVAE) -> io.NodeOutput: + audio_latents = audio_vae.encode(audio) + return io.NodeOutput( + { + "samples": audio_latents, + "sample_rate": int(audio_vae.sample_rate), + "type": "audio", + } + ) + + +class LTXVAudioVAEDecode(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVAudioVAEDecode", + display_name="LTXV Audio VAE Decode", + category="audio", + inputs=[ + io.Latent.Input("samples", tooltip="The latent to be decoded."), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model used for decoding the latent.", + ), + ], + outputs=[io.Audio.Output(display_name="Audio")], + ) + + @classmethod + def execute(cls, samples, audio_vae: AudioVAE) -> io.NodeOutput: + audio_latent = samples["samples"] + if audio_latent.is_nested: + audio_latent = audio_latent.unbind()[-1] + audio = audio_vae.decode(audio_latent).to(audio_latent.device) + output_audio_sample_rate = audio_vae.output_sample_rate + return io.NodeOutput( + { + "waveform": audio, + "sample_rate": int(output_audio_sample_rate), + } + ) + + +class LTXVEmptyLatentAudio(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="LTXVEmptyLatentAudio", + display_name="LTXV Empty Latent Audio", + category="latent/audio", + inputs=[ + io.Int.Input( + "frames_number", + default=97, + min=1, + max=1000, + step=1, + display_mode=io.NumberDisplay.number, + tooltip="Number of frames.", + ), + io.Int.Input( + "frame_rate", + default=25, + min=1, + max=1000, + step=1, + display_mode=io.NumberDisplay.number, + tooltip="Number of frames per second.", + ), + io.Int.Input( + "batch_size", + default=1, + min=1, + max=4096, + display_mode=io.NumberDisplay.number, + tooltip="The number of latent audio samples in the batch.", + ), + io.Vae.Input( + id="audio_vae", + display_name="Audio VAE", + tooltip="The Audio VAE model to get configuration from.", + ), + ], + outputs=[io.Latent.Output(display_name="Latent")], + ) + + @classmethod + def execute( + cls, + frames_number: int, + frame_rate: int, + batch_size: int, + audio_vae: AudioVAE, + ) -> io.NodeOutput: + """Generate empty audio latents matching the reference pipeline structure.""" + + assert audio_vae is not None, "Audio VAE model is required" + + z_channels = audio_vae.latent_channels + audio_freq = audio_vae.latent_frequency_bins + sampling_rate = int(audio_vae.sample_rate) + + num_audio_latents = audio_vae.num_of_latents_from_frames(frames_number, frame_rate) + + audio_latents = torch.zeros( + (batch_size, z_channels, num_audio_latents, audio_freq), + device=comfy.model_management.intermediate_device(), + ) + + return io.NodeOutput( + { + "samples": audio_latents, + "sample_rate": sampling_rate, + "type": "audio", + } + ) + + +class LTXVAudioExtension(ComfyExtension): + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + LTXVAudioVAELoader, + LTXVAudioVAEEncode, + LTXVAudioVAEDecode, + LTXVEmptyLatentAudio, + ] + + +async def comfy_entrypoint() -> ComfyExtension: + return LTXVAudioExtension() diff --git a/comfy_extras/nodes_lt_upsampler.py b/comfy_extras/nodes_lt_upsampler.py new file mode 100644 index 000000000..f99ba13fb --- /dev/null +++ b/comfy_extras/nodes_lt_upsampler.py @@ -0,0 +1,75 @@ +from comfy import model_management +import math + +class LTXVLatentUpsampler: + """ + Upsamples a video latent by a factor of 2. + """ + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "samples": ("LATENT",), + "upscale_model": ("LATENT_UPSCALE_MODEL",), + "vae": ("VAE",), + } + } + + RETURN_TYPES = ("LATENT",) + FUNCTION = "upsample_latent" + CATEGORY = "latent/video" + EXPERIMENTAL = True + + def upsample_latent( + self, + samples: dict, + upscale_model, + vae, + ) -> tuple: + """ + Upsample the input latent using the provided model. + + Args: + samples (dict): Input latent samples + upscale_model (LatentUpsampler): Loaded upscale model + vae: VAE model for normalization + auto_tiling (bool): Whether to automatically tile the input for processing + + Returns: + tuple: Tuple containing the upsampled latent + """ + device = model_management.get_torch_device() + memory_required = model_management.module_size(upscale_model) + + model_dtype = next(upscale_model.parameters()).dtype + latents = samples["samples"] + input_dtype = latents.dtype + + memory_required += math.prod(latents.shape) * 3000.0 # TODO: more accurate + model_management.free_memory(memory_required, device) + + try: + upscale_model.to(device) # TODO: use the comfy model management system. + + latents = latents.to(dtype=model_dtype, device=device) + + """Upsample latents without tiling.""" + latents = vae.first_stage_model.per_channel_statistics.un_normalize(latents) + upsampled_latents = upscale_model(latents) + finally: + upscale_model.cpu() + + upsampled_latents = vae.first_stage_model.per_channel_statistics.normalize( + upsampled_latents + ) + upsampled_latents = upsampled_latents.to(dtype=input_dtype, device=model_management.intermediate_device()) + return_dict = samples.copy() + return_dict["samples"] = upsampled_latents + return_dict.pop("noise_mask", None) + return (return_dict,) + + +NODE_CLASS_MAPPINGS = { + "LTXVLatentUpsampler": LTXVLatentUpsampler, +} diff --git a/nodes.py b/nodes.py index 662907ae6..56b74ebe3 100644 --- a/nodes.py +++ b/nodes.py @@ -295,7 +295,11 @@ class VAEDecode: DESCRIPTION = "Decodes latent images back into pixel space images." def decode(self, vae, samples): - images = vae.decode(samples["samples"]) + latent = samples["samples"] + if latent.is_nested: + latent = latent.unbind()[0] + + images = vae.decode(latent) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) return (images, ) @@ -970,7 +974,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "newbie"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -2331,6 +2335,8 @@ async def init_builtin_extra_nodes(): "nodes_mochi.py", "nodes_slg.py", "nodes_mahiro.py", + "nodes_lt_upsampler.py", + "nodes_lt_audio.py", "nodes_lt.py", "nodes_hooks.py", "nodes_load_3d.py", diff --git a/pyproject.toml b/pyproject.toml index 60378de1e..a7d159be9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ name = "ComfyUI" version = "0.7.0" readme = "README.md" license = { file = "LICENSE" } -requires-python = ">=3.9" +requires-python = ">=3.10" [project.urls] homepage = "https://www.comfy.org/"