mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-06 01:37:45 +08:00
Support the LTXAV 2.3 model. (#12773)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
This commit is contained in:
parent
ac4a943ff3
commit
43c64b6308
@ -2,11 +2,16 @@ from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.lightricks.model import (
|
||||
ADALN_BASE_PARAMS_COUNT,
|
||||
ADALN_CROSS_ATTN_PARAMS_COUNT,
|
||||
CrossAttention,
|
||||
FeedForward,
|
||||
AdaLayerNormSingle,
|
||||
PixArtAlphaTextProjection,
|
||||
NormSingleLinearTextProjection,
|
||||
LTXVModel,
|
||||
apply_cross_attention_adaln,
|
||||
compute_prompt_timestep,
|
||||
)
|
||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
v_context_dim=None,
|
||||
a_context_dim=None,
|
||||
attn_precision=None,
|
||||
apply_gated_attention=False,
|
||||
cross_attention_adaln=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.cross_attention_adaln = cross_attention_adaln
|
||||
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=v_dim,
|
||||
@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
dim_head=vd_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
dim_head=ad_head,
|
||||
context_dim=None,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=v_heads,
|
||||
dim_head=vd_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
heads=a_heads,
|
||||
dim_head=ad_head,
|
||||
attn_precision=self.attn_precision,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
|
||||
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype))
|
||||
self.audio_scale_shift_table = nn.Parameter(
|
||||
torch.empty(6, a_dim, device=device, dtype=dtype)
|
||||
torch.empty(num_ada_params, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
if cross_attention_adaln:
|
||||
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype))
|
||||
self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype))
|
||||
|
||||
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
||||
torch.empty(5, a_dim, device=device, dtype=dtype)
|
||||
)
|
||||
@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
|
||||
return (*scale_shift_ada_values, *gate_ada_values)
|
||||
|
||||
def _apply_text_cross_attention(
|
||||
self, x, context, attn, scale_shift_table, prompt_scale_shift_table,
|
||||
timestep, prompt_timestep, attention_mask, transformer_options,
|
||||
):
|
||||
"""Apply text cross-attention, with optional ADaLN modulation."""
|
||||
if self.cross_attention_adaln:
|
||||
shift_q, scale_q, gate = self.get_ada_values(
|
||||
scale_shift_table, x.shape[0], timestep, slice(6, 9)
|
||||
)
|
||||
return apply_cross_attention_adaln(
|
||||
x, context, attn, shift_q, scale_q, gate,
|
||||
prompt_scale_shift_table, prompt_timestep,
|
||||
attention_mask, transformer_options,
|
||||
)
|
||||
return attn(
|
||||
comfy.ldm.common_dit.rms_norm(x), context=context,
|
||||
mask=attention_mask, transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
|
||||
v_prompt_timestep=None, a_prompt_timestep=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
run_vx = transformer_options.get("run_vx", True)
|
||||
run_ax = transformer_options.get("run_ax", True)
|
||||
@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||
vx.addcmul_(attn1_out, vgate_msa)
|
||||
del vgate_msa, attn1_out
|
||||
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
vx.add_(self._apply_text_cross_attention(
|
||||
vx, v_context, self.attn2, self.scale_shift_table,
|
||||
getattr(self, 'prompt_scale_shift_table', None),
|
||||
v_timestep, v_prompt_timestep, attention_mask, transformer_options,)
|
||||
)
|
||||
|
||||
# audio
|
||||
if run_ax:
|
||||
@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module):
|
||||
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
||||
ax.addcmul_(attn1_out, agate_msa)
|
||||
del agate_msa, attn1_out
|
||||
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
|
||||
ax.add_(self._apply_text_cross_attention(
|
||||
ax, a_context, self.audio_attn2, self.audio_scale_shift_table,
|
||||
getattr(self, 'audio_prompt_scale_shift_table', None),
|
||||
a_timestep, a_prompt_timestep, attention_mask, transformer_options,)
|
||||
)
|
||||
|
||||
# video - audio cross attention.
|
||||
if run_a2v or run_v2a:
|
||||
@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel):
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier=1000.0,
|
||||
av_ca_timestep_scale_multiplier=1.0,
|
||||
apply_gated_attention=False,
|
||||
caption_proj_before_connector=False,
|
||||
cross_attention_adaln=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel):
|
||||
self.audio_attention_head_dim = audio_attention_head_dim
|
||||
self.audio_num_attention_heads = audio_num_attention_heads
|
||||
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
||||
self.apply_gated_attention = apply_gated_attention
|
||||
|
||||
# Calculate audio dimensions
|
||||
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
||||
@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel):
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
caption_proj_before_connector=caption_proj_before_connector,
|
||||
cross_attention_adaln=cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel):
|
||||
)
|
||||
|
||||
# Audio-specific AdaLN
|
||||
audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.audio_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
embedding_coefficient=audio_embedding_coefficient,
|
||||
use_additional_conditions=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
if self.cross_attention_adaln:
|
||||
self.audio_prompt_adaln_single = AdaLayerNormSingle(
|
||||
self.audio_inner_dim,
|
||||
embedding_coefficient=2,
|
||||
use_additional_conditions=False,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
else:
|
||||
self.audio_prompt_adaln_single = None
|
||||
|
||||
num_scale_shift_values = 4
|
||||
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim,
|
||||
@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel):
|
||||
)
|
||||
|
||||
# Audio caption projection
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
if self.caption_proj_before_connector:
|
||||
if self.caption_projection_first_linear:
|
||||
self.audio_caption_projection = NormSingleLinearTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
else:
|
||||
self.audio_caption_projection = lambda a: a
|
||||
else:
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.audio_inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
connector_split_rope = kwargs.get("rope_type", "split") == "split"
|
||||
connector_gated_attention = kwargs.get("connector_apply_gated_attention", False)
|
||||
attention_head_dim = kwargs.get("connector_attention_head_dim", 128)
|
||||
num_attention_heads = kwargs.get("connector_num_attention_heads", 30)
|
||||
num_layers = kwargs.get("connector_num_layers", 2)
|
||||
|
||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim),
|
||||
num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads),
|
||||
num_layers=num_layers,
|
||||
split_rope=connector_split_rope,
|
||||
double_precision_rope=True,
|
||||
apply_gated_attention=connector_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
self.video_embeddings_connector = Embeddings1DConnector(
|
||||
split_rope=True,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_layers=num_layers,
|
||||
split_rope=connector_split_rope,
|
||||
double_precision_rope=True,
|
||||
apply_gated_attention=connector_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
def preprocess_text_embeds(self, context):
|
||||
if context.shape[-1] == self.caption_channels * 2:
|
||||
return context
|
||||
out_vid = self.video_embeddings_connector(context)[0]
|
||||
out_audio = self.audio_embeddings_connector(context)[0]
|
||||
def preprocess_text_embeds(self, context, unprocessed=False):
|
||||
# LTXv2 fully processed context has dimension of self.caption_channels * 2
|
||||
# LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim
|
||||
if not unprocessed:
|
||||
if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2):
|
||||
return context
|
||||
if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim:
|
||||
context_vid = context[:, :, :self.cross_attention_dim]
|
||||
context_audio = context[:, :, self.cross_attention_dim:]
|
||||
else:
|
||||
context_vid = context
|
||||
context_audio = context
|
||||
if self.caption_proj_before_connector:
|
||||
context_vid = self.caption_projection(context_vid)
|
||||
context_audio = self.audio_caption_projection(context_audio)
|
||||
out_vid = self.video_embeddings_connector(context_vid)[0]
|
||||
out_audio = self.audio_embeddings_connector(context_audio)[0]
|
||||
return torch.concat((out_vid, out_audio), dim=-1)
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel):
|
||||
ad_head=self.audio_attention_head_dim,
|
||||
v_context_dim=self.cross_attention_dim,
|
||||
a_context_dim=self.audio_cross_attention_dim,
|
||||
apply_gated_attention=self.apply_gated_attention,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel):
|
||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||
|
||||
v_prompt_timestep = compute_prompt_timestep(
|
||||
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||
)
|
||||
|
||||
# Prepare audio timestep
|
||||
a_timestep = kwargs.get("a_timestep")
|
||||
if a_timestep is not None:
|
||||
@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel):
|
||||
|
||||
# Cross-attention timesteps - compress these too
|
||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||
a_timestep_flat,
|
||||
timestep.max().expand_as(a_timestep_flat),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||
timestep_flat,
|
||||
a_timestep.max().expand_as(timestep_flat),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||
timestep_flat * av_ca_factor,
|
||||
a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||
a_timestep_flat * av_ca_factor,
|
||||
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel):
|
||||
# Audio timesteps
|
||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
||||
|
||||
a_prompt_timestep = compute_prompt_timestep(
|
||||
self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype
|
||||
)
|
||||
else:
|
||||
a_timestep = timestep_scaled
|
||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||
cross_av_timestep_ss = []
|
||||
a_prompt_timestep = None
|
||||
|
||||
return [v_timestep, a_timestep, cross_av_timestep_ss], [
|
||||
return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [
|
||||
v_embedded_timestep,
|
||||
a_embedded_timestep,
|
||||
]
|
||||
], None
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
vx = x[0]
|
||||
ax = x[1]
|
||||
video_dim = vx.shape[-1]
|
||||
audio_dim = ax.shape[-1]
|
||||
|
||||
v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim
|
||||
a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim
|
||||
|
||||
v_context, a_context = torch.split(
|
||||
context, int(context.shape[-1] / 2), len(context.shape) - 1
|
||||
context, [v_context_dim, a_context_dim], len(context.shape) - 1
|
||||
)
|
||||
|
||||
v_context, attention_mask = super()._prepare_context(
|
||||
v_context, batch_size, vx, attention_mask
|
||||
)
|
||||
if self.audio_caption_projection is not None:
|
||||
if self.caption_proj_before_connector is False:
|
||||
a_context = self.audio_caption_projection(a_context)
|
||||
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
||||
a_context = a_context.view(batch_size, -1, audio_dim)
|
||||
|
||||
return [v_context, a_context], attention_mask
|
||||
|
||||
@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel):
|
||||
av_ca_v2a_gate_noise_timestep,
|
||||
) = timestep[2]
|
||||
|
||||
v_prompt_timestep = timestep[3]
|
||||
a_prompt_timestep = timestep[4]
|
||||
|
||||
"""Process transformer blocks for LTXAV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel):
|
||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||
transformer_options=args["transformer_options"],
|
||||
self_attention_mask=args.get("self_attention_mask"),
|
||||
v_prompt_timestep=args.get("v_prompt_timestep"),
|
||||
a_prompt_timestep=args.get("a_prompt_timestep"),
|
||||
)
|
||||
return out
|
||||
|
||||
@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel):
|
||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||
"transformer_options": transformer_options,
|
||||
"self_attention_mask": self_attention_mask,
|
||||
"v_prompt_timestep": v_prompt_timestep,
|
||||
"a_prompt_timestep": a_prompt_timestep,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel):
|
||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
v_prompt_timestep=v_prompt_timestep,
|
||||
a_prompt_timestep=a_prompt_timestep,
|
||||
)
|
||||
|
||||
return [vx, ax]
|
||||
|
||||
@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module):
|
||||
d_head,
|
||||
context_dim=None,
|
||||
attn_precision=None,
|
||||
apply_gated_attention=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module):
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
context_dim=None,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@ -121,6 +123,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
positional_embedding_max_pos=[4096],
|
||||
causal_temporal_positioning=False,
|
||||
num_learnable_registers: Optional[int] = 128,
|
||||
apply_gated_attention=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module):
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
context_dim=cross_attention_dim,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
|
||||
@ -275,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NormSingleLinearTextProjection(nn.Module):
|
||||
"""Text projection for 20B models - single linear with RMSNorm (no activation)."""
|
||||
|
||||
def __init__(
|
||||
self, in_features, hidden_size, dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
if operations is None:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
self.in_norm = operations.RMSNorm(
|
||||
in_features, eps=1e-6, elementwise_affine=False
|
||||
)
|
||||
self.linear_1 = operations.Linear(
|
||||
in_features, hidden_size, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.hidden_size = hidden_size
|
||||
self.in_features = in_features
|
||||
|
||||
def forward(self, caption):
|
||||
caption = self.in_norm(caption)
|
||||
caption = caption * (self.hidden_size / self.in_features) ** 0.5
|
||||
return self.linear_1(caption)
|
||||
|
||||
|
||||
class GELU_approx(nn.Module):
|
||||
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
@ -343,6 +367,7 @@ class CrossAttention(nn.Module):
|
||||
dim_head=64,
|
||||
dropout=0.0,
|
||||
attn_precision=None,
|
||||
apply_gated_attention=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@ -362,6 +387,12 @@ class CrossAttention(nn.Module):
|
||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
# Optional per-head gating
|
||||
if apply_gated_attention:
|
||||
self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_gate_logits = None
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
||||
)
|
||||
@ -383,16 +414,30 @@ class CrossAttention(nn.Module):
|
||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
else:
|
||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||
|
||||
# Apply per-head gating if enabled
|
||||
if self.to_gate_logits is not None:
|
||||
gate_logits = self.to_gate_logits(x) # (B, T, H)
|
||||
b, t, _ = out.shape
|
||||
out = out.view(b, t, self.heads, self.dim_head)
|
||||
gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity
|
||||
out = out * gates.unsqueeze(-1)
|
||||
out = out.view(b, t, self.heads * self.dim_head)
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate)
|
||||
ADALN_BASE_PARAMS_COUNT = 6
|
||||
ADALN_CROSS_ATTN_PARAMS_COUNT = 9
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
|
||||
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn_precision = attn_precision
|
||||
self.cross_attention_adaln = cross_attention_adaln
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
@ -416,18 +461,25 @@ class BasicTransformerBlock(nn.Module):
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
||||
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
||||
if cross_attention_adaln:
|
||||
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
|
||||
|
||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
||||
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
|
||||
x.addcmul_(attn1_input, gate_msa)
|
||||
del attn1_input
|
||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None):
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2)
|
||||
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa
|
||||
|
||||
if self.cross_attention_adaln:
|
||||
shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2)
|
||||
x += apply_cross_attention_adaln(
|
||||
x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca,
|
||||
self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options,
|
||||
)
|
||||
else:
|
||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||
|
||||
y = comfy.ldm.common_dit.rms_norm(x)
|
||||
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
||||
@ -435,6 +487,47 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
return x
|
||||
|
||||
def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
|
||||
"""Compute a single global prompt timestep for cross-attention ADaLN.
|
||||
|
||||
Uses the max across tokens (matching JAX max_per_segment) and broadcasts
|
||||
over text tokens. Returns None when *adaln_module* is None.
|
||||
"""
|
||||
if adaln_module is None:
|
||||
return None
|
||||
ts_input = (
|
||||
timestep_scaled.max(dim=1, keepdim=True).values.flatten()
|
||||
if timestep_scaled.dim() > 1
|
||||
else timestep_scaled.flatten()
|
||||
)
|
||||
prompt_ts, _ = adaln_module(
|
||||
ts_input,
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
)
|
||||
return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1])
|
||||
|
||||
|
||||
def apply_cross_attention_adaln(
|
||||
x, context, attn, q_shift, q_scale, q_gate,
|
||||
prompt_scale_shift_table, prompt_timestep,
|
||||
attention_mask=None, transformer_options={},
|
||||
):
|
||||
"""Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV).
|
||||
|
||||
Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so
|
||||
that both regular tensors and CompressedTimestep are supported.
|
||||
"""
|
||||
batch_size = x.shape[0]
|
||||
shift_kv, scale_kv = (
|
||||
prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
|
||||
+ prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
|
||||
).unbind(dim=2)
|
||||
attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
|
||||
encoder_hidden_states = context * (1 + scale_kv) + shift_kv
|
||||
return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate
|
||||
|
||||
def get_fractional_positions(indices_grid, max_pos):
|
||||
n_pos_dims = indices_grid.shape[1]
|
||||
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
||||
@ -556,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
vae_scale_factors: tuple = (8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
caption_proj_before_connector=False,
|
||||
cross_attention_adaln=False,
|
||||
caption_projection_first_linear=True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@ -582,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.operations = operations
|
||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||
self.caption_proj_before_connector = caption_proj_before_connector
|
||||
self.cross_attention_adaln = cross_attention_adaln
|
||||
self.caption_projection_first_linear = caption_projection_first_linear
|
||||
|
||||
# Common dimensions
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
@ -609,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
|
||||
embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
if self.cross_attention_adaln:
|
||||
self.prompt_adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||
)
|
||||
else:
|
||||
self.prompt_adaln_single = None
|
||||
|
||||
if self.caption_proj_before_connector:
|
||||
if self.caption_projection_first_linear:
|
||||
self.caption_projection = NormSingleLinearTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
else:
|
||||
self.caption_projection = lambda a: a
|
||||
else:
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=self.caption_channels,
|
||||
hidden_size=self.inner_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
@ -665,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
if grid_mask is not None:
|
||||
timestep = timestep[:, grid_mask]
|
||||
|
||||
timestep = timestep * self.timestep_scale_multiplier
|
||||
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep.flatten(),
|
||||
timestep_scaled.flatten(),
|
||||
{"resolution": None, "aspect_ratio": None},
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_dtype,
|
||||
@ -677,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||
|
||||
return timestep, embedded_timestep
|
||||
prompt_timestep = compute_prompt_timestep(
|
||||
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||
)
|
||||
|
||||
return timestep, embedded_timestep, prompt_timestep
|
||||
|
||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||
"""Prepare context for transformer blocks."""
|
||||
if self.caption_projection is not None:
|
||||
if self.caption_proj_before_connector is False:
|
||||
context = self.caption_projection(context)
|
||||
context = context.view(batch_size, -1, x.shape[-1])
|
||||
|
||||
context = context.view(batch_size, -1, x.shape[-1])
|
||||
return context, attention_mask
|
||||
|
||||
def _precompute_freqs_cis(
|
||||
@ -792,7 +915,8 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
||||
merged_args.update(additional_args)
|
||||
|
||||
# Prepare timestep and context
|
||||
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||
timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||
merged_args["prompt_timestep"] = prompt_timestep
|
||||
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
||||
|
||||
# Prepare attention mask and positional embeddings
|
||||
@ -833,7 +957,9 @@ class LTXVModel(LTXBaseModel):
|
||||
causal_temporal_positioning=False,
|
||||
vae_scale_factors=(8, 32, 32),
|
||||
use_middle_indices_grid=False,
|
||||
timestep_scale_multiplier = 1000.0,
|
||||
timestep_scale_multiplier=1000.0,
|
||||
caption_proj_before_connector=False,
|
||||
cross_attention_adaln=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
@ -852,6 +978,8 @@ class LTXVModel(LTXBaseModel):
|
||||
vae_scale_factors=vae_scale_factors,
|
||||
use_middle_indices_grid=use_middle_indices_grid,
|
||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||
caption_proj_before_connector=caption_proj_before_connector,
|
||||
cross_attention_adaln=cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
@ -860,7 +988,6 @@ class LTXVModel(LTXBaseModel):
|
||||
|
||||
def _init_model_components(self, device, dtype, **kwargs):
|
||||
"""Initialize LTXV-specific components."""
|
||||
# No additional components needed for LTXV beyond base class
|
||||
pass
|
||||
|
||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||
@ -872,6 +999,7 @@ class LTXVModel(LTXBaseModel):
|
||||
self.num_attention_heads,
|
||||
self.attention_head_dim,
|
||||
context_dim=self.cross_attention_dim,
|
||||
cross_attention_adaln=self.cross_attention_adaln,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=self.operations,
|
||||
@ -1149,16 +1277,17 @@ class LTXVModel(LTXBaseModel):
|
||||
"""Process transformer blocks for LTXV."""
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
prompt_timestep = kwargs.get("prompt_timestep", None)
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
|
||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep"))
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
|
||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap})
|
||||
x = out["img"]
|
||||
else:
|
||||
x = block(
|
||||
@ -1169,6 +1298,7 @@ class LTXVModel(LTXBaseModel):
|
||||
pe=pe,
|
||||
transformer_options=transformer_options,
|
||||
self_attention_mask=self_attention_mask,
|
||||
prompt_timestep=prompt_timestep,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
||||
CausalityAxis,
|
||||
CausalAudioAutoencoder,
|
||||
)
|
||||
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
|
||||
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module):
|
||||
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||
|
||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
if "bwe" in component_config.vocoder:
|
||||
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
|
||||
else:
|
||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||
|
||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||
|
||||
@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self._guess_config()
|
||||
config = self.get_default_config()
|
||||
|
||||
# Extract encoder and decoder configs from the new format
|
||||
model_config = config.get("model", {}).get("params", {})
|
||||
variables_config = config.get("variables", {})
|
||||
|
||||
self.sampling_rate = variables_config.get(
|
||||
"sampling_rate",
|
||||
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
|
||||
self.sampling_rate = model_config.get(
|
||||
"sampling_rate", config.get("sampling_rate", 16000)
|
||||
)
|
||||
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
||||
decoder_config = model_config.get("decoder", encoder_config)
|
||||
|
||||
# Load mel spectrogram parameters
|
||||
self.mel_bins = encoder_config.get("mel_bins", 64)
|
||||
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||
self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||
self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||
|
||||
# Store causality configuration at VAE level (not just in encoder internals)
|
||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
|
||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
|
||||
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
||||
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
||||
|
||||
@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module):
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def _guess_config(self):
|
||||
encoder_config = {
|
||||
# Required parameters - based on ltx-video-av-1679000 model metadata
|
||||
"ch": 128,
|
||||
"out_ch": 8,
|
||||
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": [], # Based on metadata: empty list, no attention
|
||||
"dropout": 0.0,
|
||||
"resamp_with_conv": True,
|
||||
"in_channels": 2, # stereo
|
||||
"resolution": 256,
|
||||
"z_channels": 8,
|
||||
def get_default_config(self):
|
||||
ddconfig = {
|
||||
"double_z": True,
|
||||
"attn_type": "vanilla",
|
||||
"mid_block_add_attention": False, # Based on metadata: false
|
||||
"mel_bins": 64,
|
||||
"z_channels": 8,
|
||||
"resolution": 256,
|
||||
"downsample_time": False,
|
||||
"in_channels": 2,
|
||||
"out_ch": 2,
|
||||
"ch": 128,
|
||||
"ch_mult": [1, 2, 4],
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": [],
|
||||
"dropout": 0.0,
|
||||
"mid_block_add_attention": False,
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height", # Based on metadata
|
||||
"mel_bins": 64, # Based on metadata: mel_bins = 64
|
||||
}
|
||||
|
||||
decoder_config = {
|
||||
# Inherits encoder config, can override specific params
|
||||
**encoder_config,
|
||||
"out_ch": 2, # Stereo audio output (2 channels)
|
||||
"give_pre_end": False,
|
||||
"tanh_out": False,
|
||||
"causality_axis": "height",
|
||||
}
|
||||
|
||||
config = {
|
||||
"_class_name": "CausalAudioAutoencoder",
|
||||
"sampling_rate": 16000,
|
||||
"model": {
|
||||
"params": {
|
||||
"encoder": encoder_config,
|
||||
"decoder": decoder_config,
|
||||
"ddconfig": ddconfig,
|
||||
"sampling_rate": 16000,
|
||||
}
|
||||
},
|
||||
"preprocessing": {
|
||||
"stft": {
|
||||
"filter_length": 1024,
|
||||
"hop_length": 160,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return config
|
||||
|
||||
@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
def in_meta_context():
|
||||
return torch.device("meta") == torch.empty(0).device
|
||||
|
||||
def mark_conv3d_ended(module):
|
||||
tid = threading.get_ident()
|
||||
for _, m in module.named_modules():
|
||||
@ -350,6 +353,10 @@ class Decoder(nn.Module):
|
||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||
if block_name == "compress_all":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
if block_name == "compress_space":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
if block_name == "compress_time":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims,
|
||||
@ -395,17 +402,21 @@ class Decoder(nn.Module):
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(2, 1, 1),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_space":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(1, 2, 2),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
@ -455,6 +466,15 @@ class Decoder(nn.Module):
|
||||
output_channel * 2, 0, operations=ops,
|
||||
)
|
||||
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
||||
else:
|
||||
self.register_buffer(
|
||||
"last_scale_shift_table",
|
||||
torch.tensor(
|
||||
[0.0, 0.0],
|
||||
device="cpu" if in_meta_context() else None
|
||||
).unsqueeze(1).expand(2, output_channel),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module):
|
||||
self.scale_shift_table = nn.Parameter(
|
||||
torch.randn(4, in_channels) / in_channels**0.5
|
||||
)
|
||||
else:
|
||||
self.register_buffer(
|
||||
"scale_shift_table",
|
||||
torch.tensor(
|
||||
[0.0, 0.0, 0.0, 0.0],
|
||||
device="cpu" if in_meta_context() else None
|
||||
).unsqueeze(1).expand(4, in_channels),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
self.temporal_cache_state={}
|
||||
|
||||
@ -1012,9 +1041,6 @@ class processor(nn.Module):
|
||||
super().__init__()
|
||||
self.register_buffer("std-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-means", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds", torch.empty(128))
|
||||
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
|
||||
self.register_buffer("channel", torch.empty(128))
|
||||
|
||||
def un_normalize(self, x):
|
||||
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
if config is None:
|
||||
config = self.guess_config(version)
|
||||
config = self.get_default_config(version)
|
||||
|
||||
self.config = config
|
||||
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||
self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
|
||||
self.decode_timestep = config.get("decode_timestep", 0.05)
|
||||
double_z = config.get("double_z", True)
|
||||
latent_log_var = config.get(
|
||||
"latent_log_var", "per_channel" if double_z else "none"
|
||||
@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module):
|
||||
latent_log_var=latent_log_var,
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||
base_channels=config.get("encoder_base_channels", 128),
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module):
|
||||
in_channels=config["latent_channels"],
|
||||
out_channels=config.get("out_channels", 3),
|
||||
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||
base_channels=config.get("decoder_base_channels", 128),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module):
|
||||
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def guess_config(self, version):
|
||||
def get_default_config(self, version):
|
||||
if version == 0:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module):
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
def decode(self, x, timestep=0.05, noise_scale=0.025):
|
||||
def decode(self, x):
|
||||
if self.timestep_conditioning: #TODO: seed
|
||||
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
|
||||
|
||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
||||
|
||||
@ -3,6 +3,7 @@ import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import comfy.ops
|
||||
import numpy as np
|
||||
import math
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
@ -12,6 +13,307 @@ def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
|
||||
# Adopted from https://github.com/NVIDIA/BigVGAN
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sinc(x: torch.Tensor):
|
||||
return torch.where(
|
||||
x == 0,
|
||||
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
||||
torch.sin(math.pi * x) / math.pi / x,
|
||||
)
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
delta_f = 4 * half_width
|
||||
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if A > 50.0:
|
||||
beta = 0.1102 * (A - 8.7)
|
||||
elif A >= 21.0:
|
||||
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
||||
else:
|
||||
beta = 0.0
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
if even:
|
||||
time = torch.arange(-half_size, half_size) + 0.5
|
||||
else:
|
||||
time = torch.arange(kernel_size) - half_size
|
||||
if cutoff == 0:
|
||||
filter_ = torch.zeros_like(time)
|
||||
else:
|
||||
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
|
||||
filter_ /= filter_.sum()
|
||||
filter = filter_.view(1, 1, kernel_size)
|
||||
return filter
|
||||
|
||||
|
||||
class LowPassFilter1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cutoff=0.5,
|
||||
half_width=0.6,
|
||||
stride=1,
|
||||
padding=True,
|
||||
padding_mode="replicate",
|
||||
kernel_size=12,
|
||||
):
|
||||
super().__init__()
|
||||
if cutoff < -0.0:
|
||||
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||
if cutoff > 0.5:
|
||||
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||
self.kernel_size = kernel_size
|
||||
self.even = kernel_size % 2 == 0
|
||||
self.pad_left = kernel_size // 2 - int(self.even)
|
||||
self.pad_right = kernel_size // 2
|
||||
self.stride = stride
|
||||
self.padding = padding
|
||||
self.padding_mode = padding_mode
|
||||
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
if self.padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||
return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.stride = ratio
|
||||
|
||||
if window_type == "hann":
|
||||
# Hann-windowed sinc filter — identical to torchaudio.functional.resample
|
||||
# with its default parameters (rolloff=0.99, lowpass_filter_width=6).
|
||||
# Uses replicate boundary padding, matching the reference resampler exactly.
|
||||
rolloff = 0.99
|
||||
lowpass_filter_width = 6
|
||||
width = math.ceil(lowpass_filter_width / rolloff)
|
||||
self.kernel_size = 2 * width * ratio + 1
|
||||
self.pad = width
|
||||
self.pad_left = 2 * width * ratio
|
||||
self.pad_right = self.kernel_size - ratio
|
||||
t = (torch.arange(self.kernel_size) / ratio - width) * rolloff
|
||||
t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width)
|
||||
window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2
|
||||
filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1)
|
||||
else:
|
||||
# Kaiser-windowed sinc filter (BigVGAN default).
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = (
|
||||
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
)
|
||||
filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
||||
)
|
||||
|
||||
self.register_buffer("filter", filter, persistent=persistent)
|
||||
|
||||
def forward(self, x):
|
||||
_, C, _ = x.shape
|
||||
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
||||
)
|
||||
x = x[..., self.pad_left : -self.pad_right]
|
||||
return x
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.lowpass = LowPassFilter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
stride=ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.lowpass(x)
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
activation,
|
||||
up_ratio=2,
|
||||
down_ratio=2,
|
||||
up_kernel_size=12,
|
||||
down_kernel_size=12,
|
||||
):
|
||||
super().__init__()
|
||||
self.act = activation
|
||||
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
x = self.downsample(x)
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BigVGAN v2 activations (Snake / SnakeBeta)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Snake(nn.Module):
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
|
||||
):
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = nn.Parameter(
|
||||
torch.zeros(in_features)
|
||||
if alpha_logscale
|
||||
else torch.ones(in_features) * alpha
|
||||
)
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x):
|
||||
a = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
a = torch.exp(a)
|
||||
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
|
||||
):
|
||||
super().__init__()
|
||||
self.alpha_logscale = alpha_logscale
|
||||
self.alpha = nn.Parameter(
|
||||
torch.zeros(in_features)
|
||||
if alpha_logscale
|
||||
else torch.ones(in_features) * alpha
|
||||
)
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
self.beta = nn.Parameter(
|
||||
torch.zeros(in_features)
|
||||
if alpha_logscale
|
||||
else torch.ones(in_features) * alpha
|
||||
)
|
||||
self.beta.requires_grad = alpha_trainable
|
||||
self.eps = 1e-9
|
||||
|
||||
def forward(self, x):
|
||||
a = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||
b = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||
if self.alpha_logscale:
|
||||
a = torch.exp(a)
|
||||
b = torch.exp(b)
|
||||
return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class AMPBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"):
|
||||
super().__init__()
|
||||
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[0],
|
||||
padding=get_padding(kernel_size, dilation[0]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[1],
|
||||
padding=get_padding(kernel_size, dilation[1]),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[2],
|
||||
padding=get_padding(kernel_size, dilation[2]),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
ops.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
self.acts1 = nn.ModuleList(
|
||||
[Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]
|
||||
)
|
||||
self.acts2 = nn.ModuleList(
|
||||
[Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2):
|
||||
xt = a1(x)
|
||||
xt = c1(xt)
|
||||
xt = a2(xt)
|
||||
xt = c2(xt)
|
||||
x = x + xt
|
||||
return x
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HiFi-GAN residual blocks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super(ResBlock1, self).__init__()
|
||||
@ -119,6 +421,7 @@ class Vocoder(torch.nn.Module):
|
||||
"""
|
||||
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
|
||||
|
||||
Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1").
|
||||
"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
@ -128,19 +431,39 @@ class Vocoder(torch.nn.Module):
|
||||
config = self.get_default_config()
|
||||
|
||||
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
|
||||
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
|
||||
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
|
||||
upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2])
|
||||
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4])
|
||||
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
|
||||
stereo = config.get("stereo", True)
|
||||
resblock = config.get("resblock", "1")
|
||||
activation = config.get("activation", "snake")
|
||||
use_bias_at_final = config.get("use_bias_at_final", True)
|
||||
|
||||
|
||||
# "output_sample_rate" is not present in recent checkpoint configs.
|
||||
# When absent (None), AudioVAE.output_sample_rate computes it as:
|
||||
# sample_rate * vocoder.upsample_factor / mel_hop_length
|
||||
# where upsample_factor = product of all upsample stride lengths,
|
||||
# and mel_hop_length is loaded from the autoencoder config at
|
||||
# preprocessing.stft.hop_length (see CausalAudioAutoencoder).
|
||||
self.output_sample_rate = config.get("output_sample_rate")
|
||||
self.resblock = config.get("resblock", "1")
|
||||
self.use_tanh_at_final = config.get("use_tanh_at_final", True)
|
||||
self.apply_final_activation = config.get("apply_final_activation", True)
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
|
||||
in_channels = 128 if stereo else 64
|
||||
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
||||
|
||||
if self.resblock == "1":
|
||||
resblock_cls = ResBlock1
|
||||
elif self.resblock == "2":
|
||||
resblock_cls = ResBlock2
|
||||
elif self.resblock == "AMP1":
|
||||
resblock_cls = AMPBlock1
|
||||
else:
|
||||
raise ValueError(f"Unknown resblock type: {self.resblock}")
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
@ -157,25 +480,40 @@ class Vocoder(torch.nn.Module):
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock_class(ch, k, d))
|
||||
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
||||
if self.resblock == "AMP1":
|
||||
self.resblocks.append(resblock_cls(ch, k, d, activation=activation))
|
||||
else:
|
||||
self.resblocks.append(resblock_cls(ch, k, d))
|
||||
|
||||
out_channels = 2 if stereo else 1
|
||||
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
|
||||
if self.resblock == "AMP1":
|
||||
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||
self.act_post = Activation1d(act_cls(ch))
|
||||
else:
|
||||
self.act_post = nn.LeakyReLU()
|
||||
|
||||
self.conv_post = ops.Conv1d(
|
||||
ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final
|
||||
)
|
||||
|
||||
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
|
||||
|
||||
|
||||
def get_default_config(self):
|
||||
"""Generate default configuration for the vocoder."""
|
||||
|
||||
config = {
|
||||
"resblock_kernel_sizes": [3, 7, 11],
|
||||
"upsample_rates": [6, 5, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"upsample_rates": [5, 4, 2, 2, 2],
|
||||
"upsample_kernel_sizes": [16, 16, 8, 4, 4],
|
||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"upsample_initial_channel": 1024,
|
||||
"stereo": True,
|
||||
"resblock": "1",
|
||||
"activation": "snake",
|
||||
"use_bias_at_final": True,
|
||||
"use_tanh_at_final": True,
|
||||
}
|
||||
|
||||
return config
|
||||
@ -196,8 +534,10 @@ class Vocoder(torch.nn.Module):
|
||||
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
|
||||
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
|
||||
x = self.conv_pre(x)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
if self.resblock != "AMP1":
|
||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x = self.ups[i](x)
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
@ -206,8 +546,167 @@ class Vocoder(torch.nn.Module):
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
|
||||
x = self.act_post(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
if self.apply_final_activation:
|
||||
if self.use_tanh_at_final:
|
||||
x = torch.tanh(x)
|
||||
else:
|
||||
x = torch.clamp(x, -1, 1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class _STFTFn(nn.Module):
|
||||
"""Implements STFT as a convolution with precomputed DFT × Hann-window bases.
|
||||
|
||||
The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
|
||||
Hann window are stored as buffers and loaded from the checkpoint. Using the exact
|
||||
bfloat16 bases from training ensures the mel values fed to the BWE generator are
|
||||
bit-identical to what it was trained on.
|
||||
"""
|
||||
|
||||
def __init__(self, filter_length: int, hop_length: int, win_length: int):
|
||||
super().__init__()
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
n_freqs = filter_length // 2 + 1
|
||||
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||
|
||||
def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute magnitude and phase spectrogram from a batch of waveforms.
|
||||
|
||||
Applies causal (left-only) padding of win_length - hop_length samples so that
|
||||
each output frame depends only on past and present input — no lookahead.
|
||||
The STFT is computed by convolving the padded signal with forward_basis.
|
||||
|
||||
Args:
|
||||
y: Waveform tensor of shape (B, T).
|
||||
|
||||
Returns:
|
||||
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
||||
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
||||
Computed in float32 for numerical stability, then cast back to
|
||||
the input dtype.
|
||||
"""
|
||||
if y.dim() == 2:
|
||||
y = y.unsqueeze(1) # (B, 1, T)
|
||||
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
|
||||
y = F.pad(y, (left_pad, 0))
|
||||
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
|
||||
n_freqs = spec.shape[1] // 2
|
||||
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
|
||||
magnitude = torch.sqrt(real ** 2 + imag ** 2)
|
||||
phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
|
||||
return magnitude, phase
|
||||
|
||||
|
||||
class MelSTFT(nn.Module):
|
||||
"""Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
|
||||
|
||||
Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
|
||||
waveform and projecting the linear magnitude spectrum onto the mel filterbank.
|
||||
|
||||
The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
|
||||
(mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter_length: int,
|
||||
hop_length: int,
|
||||
win_length: int,
|
||||
n_mel_channels: int,
|
||||
sampling_rate: int,
|
||||
mel_fmin: float,
|
||||
mel_fmax: float,
|
||||
):
|
||||
super().__init__()
|
||||
self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
|
||||
|
||||
n_freqs = filter_length // 2 + 1
|
||||
self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
|
||||
|
||||
def mel_spectrogram(
|
||||
self, y: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Compute log-mel spectrogram and auxiliary spectral quantities.
|
||||
|
||||
Args:
|
||||
y: Waveform tensor of shape (B, T).
|
||||
|
||||
Returns:
|
||||
log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
|
||||
Computed as log(clamp(mel_basis @ magnitude, min=1e-5)).
|
||||
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
||||
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
||||
energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
|
||||
"""
|
||||
magnitude, phase = self.stft_fn(y)
|
||||
energy = torch.norm(magnitude, dim=1)
|
||||
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
|
||||
log_mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
return log_mel, magnitude, phase, energy
|
||||
|
||||
|
||||
class VocoderWithBWE(torch.nn.Module):
|
||||
"""Vocoder with bandwidth extension (BWE) for higher sample rate output.
|
||||
|
||||
Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples
|
||||
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
vocoder_config = config["vocoder"]
|
||||
bwe_config = config["bwe"]
|
||||
|
||||
self.vocoder = Vocoder(config=vocoder_config)
|
||||
self.bwe_generator = Vocoder(
|
||||
config={**bwe_config, "apply_final_activation": False}
|
||||
)
|
||||
|
||||
self.input_sample_rate = bwe_config["input_sampling_rate"]
|
||||
self.output_sample_rate = bwe_config["output_sampling_rate"]
|
||||
self.hop_length = bwe_config["hop_length"]
|
||||
|
||||
self.mel_stft = MelSTFT(
|
||||
filter_length=bwe_config["n_fft"],
|
||||
hop_length=bwe_config["hop_length"],
|
||||
win_length=bwe_config["n_fft"],
|
||||
n_mel_channels=bwe_config["num_mels"],
|
||||
sampling_rate=bwe_config["input_sampling_rate"],
|
||||
mel_fmin=0.0,
|
||||
mel_fmax=bwe_config["input_sampling_rate"] / 2.0,
|
||||
)
|
||||
self.resampler = UpSample1d(
|
||||
ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"],
|
||||
persistent=False,
|
||||
window_type="hann",
|
||||
)
|
||||
|
||||
def _compute_mel(self, audio):
|
||||
"""Compute log-mel spectrogram from waveform using causal STFT bases."""
|
||||
B, C, T = audio.shape
|
||||
flat = audio.reshape(B * C, -1) # (B*C, T)
|
||||
mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
|
||||
return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames)
|
||||
|
||||
def forward(self, mel_spec):
|
||||
x = self.vocoder(mel_spec)
|
||||
_, _, T_low = x.shape
|
||||
T_out = T_low * self.output_sample_rate // self.input_sample_rate
|
||||
|
||||
remainder = T_low % self.hop_length
|
||||
if remainder != 0:
|
||||
x = F.pad(x, (0, self.hop_length - remainder))
|
||||
|
||||
mel = self._compute_mel(x)
|
||||
residual = self.bwe_generator(mel)
|
||||
skip = self.resampler(x)
|
||||
assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
|
||||
|
||||
return torch.clamp(residual + skip, -1, 1)[..., :T_out]
|
||||
|
||||
@ -1021,7 +1021,7 @@ class LTXAV(BaseModel):
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
|
||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||
|
||||
@ -1467,7 +1467,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||
elif clip_type == CLIPType.LTXV:
|
||||
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
|
||||
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif clip_type == CLIPType.NEWBIE:
|
||||
|
||||
@ -97,18 +97,39 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
|
||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
|
||||
|
||||
class DualLinearProjection(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device)
|
||||
self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
source_dim = x.shape[-1]
|
||||
x = x.movedim(1, -1)
|
||||
x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2)
|
||||
|
||||
video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim))
|
||||
audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim))
|
||||
return torch.cat((video, audio), dim=-1)
|
||||
|
||||
class LTXAVTEModel(torch.nn.Module):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}):
|
||||
super().__init__()
|
||||
self.dtypes = set()
|
||||
self.dtypes.add(dtype)
|
||||
self.compat_mode = False
|
||||
self.text_projection_type = text_projection_type
|
||||
|
||||
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
|
||||
self.dtypes.add(dtype_llama)
|
||||
|
||||
operations = self.gemma3_12b.operations # TODO
|
||||
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
||||
|
||||
if self.text_projection_type == "single_linear":
|
||||
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
||||
elif self.text_projection_type == "dual_linear":
|
||||
self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
|
||||
def enable_compat_mode(self): # TODO: remove
|
||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||
@ -148,18 +169,25 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
out_device = out.device
|
||||
if comfy.model_management.should_use_bf16(self.execution_device):
|
||||
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
||||
out = out.movedim(1, -1).to(self.execution_device)
|
||||
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
|
||||
out = out.reshape((out.shape[0], out.shape[1], -1))
|
||||
out = self.text_embedding_projection(out)
|
||||
out = out.float()
|
||||
|
||||
if self.compat_mode:
|
||||
out_vid = self.video_embeddings_connector(out)[0]
|
||||
out_audio = self.audio_embeddings_connector(out)[0]
|
||||
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||
if self.text_projection_type == "single_linear":
|
||||
out = out.movedim(1, -1).to(self.execution_device)
|
||||
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
|
||||
out = out.reshape((out.shape[0], out.shape[1], -1))
|
||||
out = self.text_embedding_projection(out)
|
||||
|
||||
return out.to(out_device), pooled
|
||||
if self.compat_mode:
|
||||
out_vid = self.video_embeddings_connector(out)[0]
|
||||
out_audio = self.audio_embeddings_connector(out)[0]
|
||||
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||
extra = {}
|
||||
else:
|
||||
extra = {"unprocessed_ltxav_embeds": True}
|
||||
elif self.text_projection_type == "dual_linear":
|
||||
out = self.text_embedding_projection(out)
|
||||
extra = {"unprocessed_ltxav_embeds": True}
|
||||
|
||||
return out.to(device=out_device, dtype=torch.float), pooled, extra
|
||||
|
||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
||||
@ -168,7 +196,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||
return self.gemma3_12b.load_sd(sd)
|
||||
else:
|
||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
|
||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True)
|
||||
if len(sdo) == 0:
|
||||
sdo = sd
|
||||
|
||||
@ -206,7 +234,7 @@ class LTXAVTEModel(torch.nn.Module):
|
||||
num_tokens = max(num_tokens, 642)
|
||||
return num_tokens * constant * 1024 * 1024
|
||||
|
||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"):
|
||||
class LTXAVTEModel_(LTXAVTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
@ -214,9 +242,19 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options)
|
||||
return LTXAVTEModel_
|
||||
|
||||
|
||||
def sd_detect(state_dict_list, prefix=""):
|
||||
for sd in state_dict_list:
|
||||
if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd:
|
||||
return {"text_projection_type": "dual_linear"}
|
||||
if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd:
|
||||
return {"text_projection_type": "single_linear"}
|
||||
return {}
|
||||
|
||||
|
||||
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class Gemma3_12BModel_(Gemma3_12BModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user