mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-06 09:47:35 +08:00
Support the LTXAV 2.3 model. (#12773)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
This commit is contained in:
parent
ac4a943ff3
commit
43c64b6308
@ -2,11 +2,16 @@ from typing import Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from comfy.ldm.lightricks.model import (
|
from comfy.ldm.lightricks.model import (
|
||||||
|
ADALN_BASE_PARAMS_COUNT,
|
||||||
|
ADALN_CROSS_ATTN_PARAMS_COUNT,
|
||||||
CrossAttention,
|
CrossAttention,
|
||||||
FeedForward,
|
FeedForward,
|
||||||
AdaLayerNormSingle,
|
AdaLayerNormSingle,
|
||||||
PixArtAlphaTextProjection,
|
PixArtAlphaTextProjection,
|
||||||
|
NormSingleLinearTextProjection,
|
||||||
LTXVModel,
|
LTXVModel,
|
||||||
|
apply_cross_attention_adaln,
|
||||||
|
compute_prompt_timestep,
|
||||||
)
|
)
|
||||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||||
@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
v_context_dim=None,
|
v_context_dim=None,
|
||||||
a_context_dim=None,
|
a_context_dim=None,
|
||||||
attn_precision=None,
|
attn_precision=None,
|
||||||
|
apply_gated_attention=False,
|
||||||
|
cross_attention_adaln=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.attn_precision = attn_precision
|
self.attn_precision = attn_precision
|
||||||
|
self.cross_attention_adaln = cross_attention_adaln
|
||||||
|
|
||||||
self.attn1 = CrossAttention(
|
self.attn1 = CrossAttention(
|
||||||
query_dim=v_dim,
|
query_dim=v_dim,
|
||||||
@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
dim_head=vd_head,
|
dim_head=vd_head,
|
||||||
context_dim=None,
|
context_dim=None,
|
||||||
attn_precision=self.attn_precision,
|
attn_precision=self.attn_precision,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
dim_head=ad_head,
|
dim_head=ad_head,
|
||||||
context_dim=None,
|
context_dim=None,
|
||||||
attn_precision=self.attn_precision,
|
attn_precision=self.attn_precision,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
heads=v_heads,
|
heads=v_heads,
|
||||||
dim_head=vd_head,
|
dim_head=vd_head,
|
||||||
attn_precision=self.attn_precision,
|
attn_precision=self.attn_precision,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
heads=a_heads,
|
heads=a_heads,
|
||||||
dim_head=ad_head,
|
dim_head=ad_head,
|
||||||
attn_precision=self.attn_precision,
|
attn_precision=self.attn_precision,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
heads=a_heads,
|
heads=a_heads,
|
||||||
dim_head=ad_head,
|
dim_head=ad_head,
|
||||||
attn_precision=self.attn_precision,
|
attn_precision=self.attn_precision,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
heads=a_heads,
|
heads=a_heads,
|
||||||
dim_head=ad_head,
|
dim_head=ad_head,
|
||||||
attn_precision=self.attn_precision,
|
attn_precision=self.attn_precision,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
|
|
||||||
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
|
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||||
|
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype))
|
||||||
self.audio_scale_shift_table = nn.Parameter(
|
self.audio_scale_shift_table = nn.Parameter(
|
||||||
torch.empty(6, a_dim, device=device, dtype=dtype)
|
torch.empty(num_ada_params, a_dim, device=device, dtype=dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cross_attention_adaln:
|
||||||
|
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype))
|
||||||
|
self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
||||||
torch.empty(5, a_dim, device=device, dtype=dtype)
|
torch.empty(5, a_dim, device=device, dtype=dtype)
|
||||||
)
|
)
|
||||||
@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
return (*scale_shift_ada_values, *gate_ada_values)
|
return (*scale_shift_ada_values, *gate_ada_values)
|
||||||
|
|
||||||
|
def _apply_text_cross_attention(
|
||||||
|
self, x, context, attn, scale_shift_table, prompt_scale_shift_table,
|
||||||
|
timestep, prompt_timestep, attention_mask, transformer_options,
|
||||||
|
):
|
||||||
|
"""Apply text cross-attention, with optional ADaLN modulation."""
|
||||||
|
if self.cross_attention_adaln:
|
||||||
|
shift_q, scale_q, gate = self.get_ada_values(
|
||||||
|
scale_shift_table, x.shape[0], timestep, slice(6, 9)
|
||||||
|
)
|
||||||
|
return apply_cross_attention_adaln(
|
||||||
|
x, context, attn, shift_q, scale_q, gate,
|
||||||
|
prompt_scale_shift_table, prompt_timestep,
|
||||||
|
attention_mask, transformer_options,
|
||||||
|
)
|
||||||
|
return attn(
|
||||||
|
comfy.ldm.common_dit.rms_norm(x), context=context,
|
||||||
|
mask=attention_mask, transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
|
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
|
||||||
|
v_prompt_timestep=None, a_prompt_timestep=None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
run_vx = transformer_options.get("run_vx", True)
|
run_vx = transformer_options.get("run_vx", True)
|
||||||
run_ax = transformer_options.get("run_ax", True)
|
run_ax = transformer_options.get("run_ax", True)
|
||||||
@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||||
vx.addcmul_(attn1_out, vgate_msa)
|
vx.addcmul_(attn1_out, vgate_msa)
|
||||||
del vgate_msa, attn1_out
|
del vgate_msa, attn1_out
|
||||||
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
|
vx.add_(self._apply_text_cross_attention(
|
||||||
|
vx, v_context, self.attn2, self.scale_shift_table,
|
||||||
|
getattr(self, 'prompt_scale_shift_table', None),
|
||||||
|
v_timestep, v_prompt_timestep, attention_mask, transformer_options,)
|
||||||
|
)
|
||||||
|
|
||||||
# audio
|
# audio
|
||||||
if run_ax:
|
if run_ax:
|
||||||
@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
||||||
ax.addcmul_(attn1_out, agate_msa)
|
ax.addcmul_(attn1_out, agate_msa)
|
||||||
del agate_msa, attn1_out
|
del agate_msa, attn1_out
|
||||||
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
|
ax.add_(self._apply_text_cross_attention(
|
||||||
|
ax, a_context, self.audio_attn2, self.audio_scale_shift_table,
|
||||||
|
getattr(self, 'audio_prompt_scale_shift_table', None),
|
||||||
|
a_timestep, a_prompt_timestep, attention_mask, transformer_options,)
|
||||||
|
)
|
||||||
|
|
||||||
# video - audio cross attention.
|
# video - audio cross attention.
|
||||||
if run_a2v or run_v2a:
|
if run_a2v or run_v2a:
|
||||||
@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel):
|
|||||||
use_middle_indices_grid=False,
|
use_middle_indices_grid=False,
|
||||||
timestep_scale_multiplier=1000.0,
|
timestep_scale_multiplier=1000.0,
|
||||||
av_ca_timestep_scale_multiplier=1.0,
|
av_ca_timestep_scale_multiplier=1.0,
|
||||||
|
apply_gated_attention=False,
|
||||||
|
caption_proj_before_connector=False,
|
||||||
|
cross_attention_adaln=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel):
|
|||||||
self.audio_attention_head_dim = audio_attention_head_dim
|
self.audio_attention_head_dim = audio_attention_head_dim
|
||||||
self.audio_num_attention_heads = audio_num_attention_heads
|
self.audio_num_attention_heads = audio_num_attention_heads
|
||||||
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
||||||
|
self.apply_gated_attention = apply_gated_attention
|
||||||
|
|
||||||
# Calculate audio dimensions
|
# Calculate audio dimensions
|
||||||
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
||||||
@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel):
|
|||||||
vae_scale_factors=vae_scale_factors,
|
vae_scale_factors=vae_scale_factors,
|
||||||
use_middle_indices_grid=use_middle_indices_grid,
|
use_middle_indices_grid=use_middle_indices_grid,
|
||||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||||
|
caption_proj_before_connector=caption_proj_before_connector,
|
||||||
|
cross_attention_adaln=cross_attention_adaln,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Audio-specific AdaLN
|
# Audio-specific AdaLN
|
||||||
|
audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||||
self.audio_adaln_single = AdaLayerNormSingle(
|
self.audio_adaln_single = AdaLayerNormSingle(
|
||||||
self.audio_inner_dim,
|
self.audio_inner_dim,
|
||||||
|
embedding_coefficient=audio_embedding_coefficient,
|
||||||
use_additional_conditions=False,
|
use_additional_conditions=False,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=self.operations,
|
operations=self.operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.cross_attention_adaln:
|
||||||
|
self.audio_prompt_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.audio_inner_dim,
|
||||||
|
embedding_coefficient=2,
|
||||||
|
use_additional_conditions=False,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.audio_prompt_adaln_single = None
|
||||||
|
|
||||||
num_scale_shift_values = 4
|
num_scale_shift_values = 4
|
||||||
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||||
self.inner_dim,
|
self.inner_dim,
|
||||||
@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Audio caption projection
|
# Audio caption projection
|
||||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
if self.caption_proj_before_connector:
|
||||||
in_features=self.caption_channels,
|
if self.caption_projection_first_linear:
|
||||||
hidden_size=self.audio_inner_dim,
|
self.audio_caption_projection = NormSingleLinearTextProjection(
|
||||||
dtype=dtype,
|
in_features=self.caption_channels,
|
||||||
device=device,
|
hidden_size=self.audio_inner_dim,
|
||||||
operations=self.operations,
|
dtype=dtype,
|
||||||
)
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.audio_caption_projection = lambda a: a
|
||||||
|
else:
|
||||||
|
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||||
|
in_features=self.caption_channels,
|
||||||
|
hidden_size=self.audio_inner_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
connector_split_rope = kwargs.get("rope_type", "split") == "split"
|
||||||
|
connector_gated_attention = kwargs.get("connector_apply_gated_attention", False)
|
||||||
|
attention_head_dim = kwargs.get("connector_attention_head_dim", 128)
|
||||||
|
num_attention_heads = kwargs.get("connector_num_attention_heads", 30)
|
||||||
|
num_layers = kwargs.get("connector_num_layers", 2)
|
||||||
|
|
||||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||||
split_rope=True,
|
attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim),
|
||||||
|
num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads),
|
||||||
|
num_layers=num_layers,
|
||||||
|
split_rope=connector_split_rope,
|
||||||
double_precision_rope=True,
|
double_precision_rope=True,
|
||||||
|
apply_gated_attention=connector_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=self.operations,
|
operations=self.operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.video_embeddings_connector = Embeddings1DConnector(
|
self.video_embeddings_connector = Embeddings1DConnector(
|
||||||
split_rope=True,
|
attention_head_dim=attention_head_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
num_layers=num_layers,
|
||||||
|
split_rope=connector_split_rope,
|
||||||
double_precision_rope=True,
|
double_precision_rope=True,
|
||||||
|
apply_gated_attention=connector_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=self.operations,
|
operations=self.operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
def preprocess_text_embeds(self, context):
|
def preprocess_text_embeds(self, context, unprocessed=False):
|
||||||
if context.shape[-1] == self.caption_channels * 2:
|
# LTXv2 fully processed context has dimension of self.caption_channels * 2
|
||||||
return context
|
# LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim
|
||||||
out_vid = self.video_embeddings_connector(context)[0]
|
if not unprocessed:
|
||||||
out_audio = self.audio_embeddings_connector(context)[0]
|
if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2):
|
||||||
|
return context
|
||||||
|
if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim:
|
||||||
|
context_vid = context[:, :, :self.cross_attention_dim]
|
||||||
|
context_audio = context[:, :, self.cross_attention_dim:]
|
||||||
|
else:
|
||||||
|
context_vid = context
|
||||||
|
context_audio = context
|
||||||
|
if self.caption_proj_before_connector:
|
||||||
|
context_vid = self.caption_projection(context_vid)
|
||||||
|
context_audio = self.audio_caption_projection(context_audio)
|
||||||
|
out_vid = self.video_embeddings_connector(context_vid)[0]
|
||||||
|
out_audio = self.audio_embeddings_connector(context_audio)[0]
|
||||||
return torch.concat((out_vid, out_audio), dim=-1)
|
return torch.concat((out_vid, out_audio), dim=-1)
|
||||||
|
|
||||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||||
@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel):
|
|||||||
ad_head=self.audio_attention_head_dim,
|
ad_head=self.audio_attention_head_dim,
|
||||||
v_context_dim=self.cross_attention_dim,
|
v_context_dim=self.cross_attention_dim,
|
||||||
a_context_dim=self.audio_cross_attention_dim,
|
a_context_dim=self.audio_cross_attention_dim,
|
||||||
|
apply_gated_attention=self.apply_gated_attention,
|
||||||
|
cross_attention_adaln=self.cross_attention_adaln,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=self.operations,
|
operations=self.operations,
|
||||||
@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel):
|
|||||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||||
|
|
||||||
|
v_prompt_timestep = compute_prompt_timestep(
|
||||||
|
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare audio timestep
|
# Prepare audio timestep
|
||||||
a_timestep = kwargs.get("a_timestep")
|
a_timestep = kwargs.get("a_timestep")
|
||||||
if a_timestep is not None:
|
if a_timestep is not None:
|
||||||
@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel):
|
|||||||
|
|
||||||
# Cross-attention timesteps - compress these too
|
# Cross-attention timesteps - compress these too
|
||||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||||
a_timestep_flat,
|
timestep.max().expand_as(a_timestep_flat),
|
||||||
{"resolution": None, "aspect_ratio": None},
|
{"resolution": None, "aspect_ratio": None},
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
hidden_dtype=hidden_dtype,
|
hidden_dtype=hidden_dtype,
|
||||||
)
|
)
|
||||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||||
timestep_flat,
|
a_timestep.max().expand_as(timestep_flat),
|
||||||
{"resolution": None, "aspect_ratio": None},
|
{"resolution": None, "aspect_ratio": None},
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
hidden_dtype=hidden_dtype,
|
hidden_dtype=hidden_dtype,
|
||||||
)
|
)
|
||||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||||
timestep_flat * av_ca_factor,
|
a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
|
||||||
{"resolution": None, "aspect_ratio": None},
|
{"resolution": None, "aspect_ratio": None},
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
hidden_dtype=hidden_dtype,
|
hidden_dtype=hidden_dtype,
|
||||||
)
|
)
|
||||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||||
a_timestep_flat * av_ca_factor,
|
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
|
||||||
{"resolution": None, "aspect_ratio": None},
|
{"resolution": None, "aspect_ratio": None},
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
hidden_dtype=hidden_dtype,
|
hidden_dtype=hidden_dtype,
|
||||||
@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel):
|
|||||||
# Audio timesteps
|
# Audio timesteps
|
||||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||||
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
||||||
|
|
||||||
|
a_prompt_timestep = compute_prompt_timestep(
|
||||||
|
self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
a_timestep = timestep_scaled
|
a_timestep = timestep_scaled
|
||||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||||
cross_av_timestep_ss = []
|
cross_av_timestep_ss = []
|
||||||
|
a_prompt_timestep = None
|
||||||
|
|
||||||
return [v_timestep, a_timestep, cross_av_timestep_ss], [
|
return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [
|
||||||
v_embedded_timestep,
|
v_embedded_timestep,
|
||||||
a_embedded_timestep,
|
a_embedded_timestep,
|
||||||
]
|
], None
|
||||||
|
|
||||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||||
vx = x[0]
|
vx = x[0]
|
||||||
ax = x[1]
|
ax = x[1]
|
||||||
|
video_dim = vx.shape[-1]
|
||||||
|
audio_dim = ax.shape[-1]
|
||||||
|
|
||||||
|
v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim
|
||||||
|
a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim
|
||||||
|
|
||||||
v_context, a_context = torch.split(
|
v_context, a_context = torch.split(
|
||||||
context, int(context.shape[-1] / 2), len(context.shape) - 1
|
context, [v_context_dim, a_context_dim], len(context.shape) - 1
|
||||||
)
|
)
|
||||||
|
|
||||||
v_context, attention_mask = super()._prepare_context(
|
v_context, attention_mask = super()._prepare_context(
|
||||||
v_context, batch_size, vx, attention_mask
|
v_context, batch_size, vx, attention_mask
|
||||||
)
|
)
|
||||||
if self.audio_caption_projection is not None:
|
if self.caption_proj_before_connector is False:
|
||||||
a_context = self.audio_caption_projection(a_context)
|
a_context = self.audio_caption_projection(a_context)
|
||||||
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
a_context = a_context.view(batch_size, -1, audio_dim)
|
||||||
|
|
||||||
return [v_context, a_context], attention_mask
|
return [v_context, a_context], attention_mask
|
||||||
|
|
||||||
@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel):
|
|||||||
av_ca_v2a_gate_noise_timestep,
|
av_ca_v2a_gate_noise_timestep,
|
||||||
) = timestep[2]
|
) = timestep[2]
|
||||||
|
|
||||||
|
v_prompt_timestep = timestep[3]
|
||||||
|
a_prompt_timestep = timestep[4]
|
||||||
|
|
||||||
"""Process transformer blocks for LTXAV."""
|
"""Process transformer blocks for LTXAV."""
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel):
|
|||||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||||
transformer_options=args["transformer_options"],
|
transformer_options=args["transformer_options"],
|
||||||
self_attention_mask=args.get("self_attention_mask"),
|
self_attention_mask=args.get("self_attention_mask"),
|
||||||
|
v_prompt_timestep=args.get("v_prompt_timestep"),
|
||||||
|
a_prompt_timestep=args.get("a_prompt_timestep"),
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel):
|
|||||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||||
"transformer_options": transformer_options,
|
"transformer_options": transformer_options,
|
||||||
"self_attention_mask": self_attention_mask,
|
"self_attention_mask": self_attention_mask,
|
||||||
|
"v_prompt_timestep": v_prompt_timestep,
|
||||||
|
"a_prompt_timestep": a_prompt_timestep,
|
||||||
},
|
},
|
||||||
{"original_block": block_wrap},
|
{"original_block": block_wrap},
|
||||||
)
|
)
|
||||||
@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel):
|
|||||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
self_attention_mask=self_attention_mask,
|
self_attention_mask=self_attention_mask,
|
||||||
|
v_prompt_timestep=v_prompt_timestep,
|
||||||
|
a_prompt_timestep=a_prompt_timestep,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [vx, ax]
|
return [vx, ax]
|
||||||
|
|||||||
@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module):
|
|||||||
d_head,
|
d_head,
|
||||||
context_dim=None,
|
context_dim=None,
|
||||||
attn_precision=None,
|
attn_precision=None,
|
||||||
|
apply_gated_attention=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module):
|
|||||||
heads=n_heads,
|
heads=n_heads,
|
||||||
dim_head=d_head,
|
dim_head=d_head,
|
||||||
context_dim=None,
|
context_dim=None,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -121,6 +123,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
positional_embedding_max_pos=[4096],
|
positional_embedding_max_pos=[4096],
|
||||||
causal_temporal_positioning=False,
|
causal_temporal_positioning=False,
|
||||||
num_learnable_registers: Optional[int] = 128,
|
num_learnable_registers: Optional[int] = 128,
|
||||||
|
apply_gated_attention=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
num_attention_heads,
|
num_attention_heads,
|
||||||
attention_head_dim,
|
attention_head_dim,
|
||||||
context_dim=cross_attention_dim,
|
context_dim=cross_attention_dim,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
|
|||||||
@ -275,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class NormSingleLinearTextProjection(nn.Module):
|
||||||
|
"""Text projection for 20B models - single linear with RMSNorm (no activation)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, in_features, hidden_size, dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if operations is None:
|
||||||
|
operations = comfy.ops.disable_weight_init
|
||||||
|
self.in_norm = operations.RMSNorm(
|
||||||
|
in_features, eps=1e-6, elementwise_affine=False
|
||||||
|
)
|
||||||
|
self.linear_1 = operations.Linear(
|
||||||
|
in_features, hidden_size, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.in_features = in_features
|
||||||
|
|
||||||
|
def forward(self, caption):
|
||||||
|
caption = self.in_norm(caption)
|
||||||
|
caption = caption * (self.hidden_size / self.in_features) ** 0.5
|
||||||
|
return self.linear_1(caption)
|
||||||
|
|
||||||
|
|
||||||
class GELU_approx(nn.Module):
|
class GELU_approx(nn.Module):
|
||||||
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
|
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -343,6 +367,7 @@ class CrossAttention(nn.Module):
|
|||||||
dim_head=64,
|
dim_head=64,
|
||||||
dropout=0.0,
|
dropout=0.0,
|
||||||
attn_precision=None,
|
attn_precision=None,
|
||||||
|
apply_gated_attention=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -362,6 +387,12 @@ class CrossAttention(nn.Module):
|
|||||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
# Optional per-head gating
|
||||||
|
if apply_gated_attention:
|
||||||
|
self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
self.to_gate_logits = None
|
||||||
|
|
||||||
self.to_out = nn.Sequential(
|
self.to_out = nn.Sequential(
|
||||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
@ -383,16 +414,30 @@ class CrossAttention(nn.Module):
|
|||||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||||
else:
|
else:
|
||||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
# Apply per-head gating if enabled
|
||||||
|
if self.to_gate_logits is not None:
|
||||||
|
gate_logits = self.to_gate_logits(x) # (B, T, H)
|
||||||
|
b, t, _ = out.shape
|
||||||
|
out = out.view(b, t, self.heads, self.dim_head)
|
||||||
|
gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity
|
||||||
|
out = out * gates.unsqueeze(-1)
|
||||||
|
out = out.view(b, t, self.heads * self.dim_head)
|
||||||
|
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
|
# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate)
|
||||||
|
ADALN_BASE_PARAMS_COUNT = 6
|
||||||
|
ADALN_CROSS_ATTN_PARAMS_COUNT = 9
|
||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
|
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.attn_precision = attn_precision
|
self.attn_precision = attn_precision
|
||||||
|
self.cross_attention_adaln = cross_attention_adaln
|
||||||
self.attn1 = CrossAttention(
|
self.attn1 = CrossAttention(
|
||||||
query_dim=dim,
|
query_dim=dim,
|
||||||
heads=n_heads,
|
heads=n_heads,
|
||||||
@ -416,18 +461,25 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
operations=operations,
|
operations=operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||||
|
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
|
if cross_attention_adaln:
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None):
|
||||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2)
|
||||||
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
|
|
||||||
x.addcmul_(attn1_input, gate_msa)
|
|
||||||
del attn1_input
|
|
||||||
|
|
||||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa
|
||||||
|
|
||||||
|
if self.cross_attention_adaln:
|
||||||
|
shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2)
|
||||||
|
x += apply_cross_attention_adaln(
|
||||||
|
x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca,
|
||||||
|
self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||||
|
|
||||||
y = comfy.ldm.common_dit.rms_norm(x)
|
y = comfy.ldm.common_dit.rms_norm(x)
|
||||||
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
||||||
@ -435,6 +487,47 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
|
||||||
|
"""Compute a single global prompt timestep for cross-attention ADaLN.
|
||||||
|
|
||||||
|
Uses the max across tokens (matching JAX max_per_segment) and broadcasts
|
||||||
|
over text tokens. Returns None when *adaln_module* is None.
|
||||||
|
"""
|
||||||
|
if adaln_module is None:
|
||||||
|
return None
|
||||||
|
ts_input = (
|
||||||
|
timestep_scaled.max(dim=1, keepdim=True).values.flatten()
|
||||||
|
if timestep_scaled.dim() > 1
|
||||||
|
else timestep_scaled.flatten()
|
||||||
|
)
|
||||||
|
prompt_ts, _ = adaln_module(
|
||||||
|
ts_input,
|
||||||
|
{"resolution": None, "aspect_ratio": None},
|
||||||
|
batch_size=batch_size,
|
||||||
|
hidden_dtype=hidden_dtype,
|
||||||
|
)
|
||||||
|
return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1])
|
||||||
|
|
||||||
|
|
||||||
|
def apply_cross_attention_adaln(
|
||||||
|
x, context, attn, q_shift, q_scale, q_gate,
|
||||||
|
prompt_scale_shift_table, prompt_timestep,
|
||||||
|
attention_mask=None, transformer_options={},
|
||||||
|
):
|
||||||
|
"""Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV).
|
||||||
|
|
||||||
|
Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so
|
||||||
|
that both regular tensors and CompressedTimestep are supported.
|
||||||
|
"""
|
||||||
|
batch_size = x.shape[0]
|
||||||
|
shift_kv, scale_kv = (
|
||||||
|
prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
|
||||||
|
+ prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
|
||||||
|
).unbind(dim=2)
|
||||||
|
attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
|
||||||
|
encoder_hidden_states = context * (1 + scale_kv) + shift_kv
|
||||||
|
return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate
|
||||||
|
|
||||||
def get_fractional_positions(indices_grid, max_pos):
|
def get_fractional_positions(indices_grid, max_pos):
|
||||||
n_pos_dims = indices_grid.shape[1]
|
n_pos_dims = indices_grid.shape[1]
|
||||||
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
||||||
@ -556,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
vae_scale_factors: tuple = (8, 32, 32),
|
vae_scale_factors: tuple = (8, 32, 32),
|
||||||
use_middle_indices_grid=False,
|
use_middle_indices_grid=False,
|
||||||
timestep_scale_multiplier = 1000.0,
|
timestep_scale_multiplier = 1000.0,
|
||||||
|
caption_proj_before_connector=False,
|
||||||
|
cross_attention_adaln=False,
|
||||||
|
caption_projection_first_linear=True,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -582,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
self.causal_temporal_positioning = causal_temporal_positioning
|
self.causal_temporal_positioning = causal_temporal_positioning
|
||||||
self.operations = operations
|
self.operations = operations
|
||||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||||
|
self.caption_proj_before_connector = caption_proj_before_connector
|
||||||
|
self.cross_attention_adaln = cross_attention_adaln
|
||||||
|
self.caption_projection_first_linear = caption_projection_first_linear
|
||||||
|
|
||||||
# Common dimensions
|
# Common dimensions
|
||||||
self.inner_dim = num_attention_heads * attention_head_dim
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
@ -609,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||||
self.adaln_single = AdaLayerNormSingle(
|
self.adaln_single = AdaLayerNormSingle(
|
||||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||||
)
|
)
|
||||||
|
|
||||||
self.caption_projection = PixArtAlphaTextProjection(
|
if self.cross_attention_adaln:
|
||||||
in_features=self.caption_channels,
|
self.prompt_adaln_single = AdaLayerNormSingle(
|
||||||
hidden_size=self.inner_dim,
|
self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||||
dtype=dtype,
|
)
|
||||||
device=device,
|
else:
|
||||||
operations=self.operations,
|
self.prompt_adaln_single = None
|
||||||
)
|
|
||||||
|
if self.caption_proj_before_connector:
|
||||||
|
if self.caption_projection_first_linear:
|
||||||
|
self.caption_projection = NormSingleLinearTextProjection(
|
||||||
|
in_features=self.caption_channels,
|
||||||
|
hidden_size=self.inner_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.caption_projection = lambda a: a
|
||||||
|
else:
|
||||||
|
self.caption_projection = PixArtAlphaTextProjection(
|
||||||
|
in_features=self.caption_channels,
|
||||||
|
hidden_size=self.inner_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _init_model_components(self, device, dtype, **kwargs):
|
def _init_model_components(self, device, dtype, **kwargs):
|
||||||
@ -665,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
if grid_mask is not None:
|
if grid_mask is not None:
|
||||||
timestep = timestep[:, grid_mask]
|
timestep = timestep[:, grid_mask]
|
||||||
|
|
||||||
timestep = timestep * self.timestep_scale_multiplier
|
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||||
timestep, embedded_timestep = self.adaln_single(
|
timestep, embedded_timestep = self.adaln_single(
|
||||||
timestep.flatten(),
|
timestep_scaled.flatten(),
|
||||||
{"resolution": None, "aspect_ratio": None},
|
{"resolution": None, "aspect_ratio": None},
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
hidden_dtype=hidden_dtype,
|
hidden_dtype=hidden_dtype,
|
||||||
@ -677,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||||
|
|
||||||
return timestep, embedded_timestep
|
prompt_timestep = compute_prompt_timestep(
|
||||||
|
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return timestep, embedded_timestep, prompt_timestep
|
||||||
|
|
||||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||||
"""Prepare context for transformer blocks."""
|
"""Prepare context for transformer blocks."""
|
||||||
if self.caption_projection is not None:
|
if self.caption_proj_before_connector is False:
|
||||||
context = self.caption_projection(context)
|
context = self.caption_projection(context)
|
||||||
context = context.view(batch_size, -1, x.shape[-1])
|
|
||||||
|
|
||||||
|
context = context.view(batch_size, -1, x.shape[-1])
|
||||||
return context, attention_mask
|
return context, attention_mask
|
||||||
|
|
||||||
def _precompute_freqs_cis(
|
def _precompute_freqs_cis(
|
||||||
@ -792,7 +915,8 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
merged_args.update(additional_args)
|
merged_args.update(additional_args)
|
||||||
|
|
||||||
# Prepare timestep and context
|
# Prepare timestep and context
|
||||||
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||||
|
merged_args["prompt_timestep"] = prompt_timestep
|
||||||
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
||||||
|
|
||||||
# Prepare attention mask and positional embeddings
|
# Prepare attention mask and positional embeddings
|
||||||
@ -833,7 +957,9 @@ class LTXVModel(LTXBaseModel):
|
|||||||
causal_temporal_positioning=False,
|
causal_temporal_positioning=False,
|
||||||
vae_scale_factors=(8, 32, 32),
|
vae_scale_factors=(8, 32, 32),
|
||||||
use_middle_indices_grid=False,
|
use_middle_indices_grid=False,
|
||||||
timestep_scale_multiplier = 1000.0,
|
timestep_scale_multiplier=1000.0,
|
||||||
|
caption_proj_before_connector=False,
|
||||||
|
cross_attention_adaln=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -852,6 +978,8 @@ class LTXVModel(LTXBaseModel):
|
|||||||
vae_scale_factors=vae_scale_factors,
|
vae_scale_factors=vae_scale_factors,
|
||||||
use_middle_indices_grid=use_middle_indices_grid,
|
use_middle_indices_grid=use_middle_indices_grid,
|
||||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||||
|
caption_proj_before_connector=caption_proj_before_connector,
|
||||||
|
cross_attention_adaln=cross_attention_adaln,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -860,7 +988,6 @@ class LTXVModel(LTXBaseModel):
|
|||||||
|
|
||||||
def _init_model_components(self, device, dtype, **kwargs):
|
def _init_model_components(self, device, dtype, **kwargs):
|
||||||
"""Initialize LTXV-specific components."""
|
"""Initialize LTXV-specific components."""
|
||||||
# No additional components needed for LTXV beyond base class
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||||
@ -872,6 +999,7 @@ class LTXVModel(LTXBaseModel):
|
|||||||
self.num_attention_heads,
|
self.num_attention_heads,
|
||||||
self.attention_head_dim,
|
self.attention_head_dim,
|
||||||
context_dim=self.cross_attention_dim,
|
context_dim=self.cross_attention_dim,
|
||||||
|
cross_attention_adaln=self.cross_attention_adaln,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=self.operations,
|
operations=self.operations,
|
||||||
@ -1149,16 +1277,17 @@ class LTXVModel(LTXBaseModel):
|
|||||||
"""Process transformer blocks for LTXV."""
|
"""Process transformer blocks for LTXV."""
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
prompt_timestep = kwargs.get("prompt_timestep", None)
|
||||||
|
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
|
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
|
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap})
|
||||||
x = out["img"]
|
x = out["img"]
|
||||||
else:
|
else:
|
||||||
x = block(
|
x = block(
|
||||||
@ -1169,6 +1298,7 @@ class LTXVModel(LTXBaseModel):
|
|||||||
pe=pe,
|
pe=pe,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
self_attention_mask=self_attention_mask,
|
self_attention_mask=self_attention_mask,
|
||||||
|
prompt_timestep=prompt_timestep,
|
||||||
)
|
)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
|||||||
CausalityAxis,
|
CausalityAxis,
|
||||||
CausalAudioAutoencoder,
|
CausalAudioAutoencoder,
|
||||||
)
|
)
|
||||||
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
|
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE
|
||||||
|
|
||||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||||
|
|
||||||
@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module):
|
|||||||
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||||
|
|
||||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
if "bwe" in component_config.vocoder:
|
||||||
|
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
|
||||||
|
else:
|
||||||
|
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||||
|
|
||||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||||
|
|||||||
@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = self._guess_config()
|
config = self.get_default_config()
|
||||||
|
|
||||||
# Extract encoder and decoder configs from the new format
|
|
||||||
model_config = config.get("model", {}).get("params", {})
|
model_config = config.get("model", {}).get("params", {})
|
||||||
variables_config = config.get("variables", {})
|
|
||||||
|
|
||||||
self.sampling_rate = variables_config.get(
|
self.sampling_rate = model_config.get(
|
||||||
"sampling_rate",
|
"sampling_rate", config.get("sampling_rate", 16000)
|
||||||
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
|
|
||||||
)
|
)
|
||||||
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
||||||
decoder_config = model_config.get("decoder", encoder_config)
|
decoder_config = model_config.get("decoder", encoder_config)
|
||||||
|
|
||||||
# Load mel spectrogram parameters
|
# Load mel spectrogram parameters
|
||||||
self.mel_bins = encoder_config.get("mel_bins", 64)
|
self.mel_bins = encoder_config.get("mel_bins", 64)
|
||||||
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||||
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||||
|
|
||||||
# Store causality configuration at VAE level (not just in encoder internals)
|
# Store causality configuration at VAE level (not just in encoder internals)
|
||||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
|
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
|
||||||
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
||||||
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
||||||
|
|
||||||
@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module):
|
|||||||
|
|
||||||
self.per_channel_statistics = processor()
|
self.per_channel_statistics = processor()
|
||||||
|
|
||||||
def _guess_config(self):
|
def get_default_config(self):
|
||||||
encoder_config = {
|
ddconfig = {
|
||||||
# Required parameters - based on ltx-video-av-1679000 model metadata
|
|
||||||
"ch": 128,
|
|
||||||
"out_ch": 8,
|
|
||||||
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
|
|
||||||
"num_res_blocks": 2,
|
|
||||||
"attn_resolutions": [], # Based on metadata: empty list, no attention
|
|
||||||
"dropout": 0.0,
|
|
||||||
"resamp_with_conv": True,
|
|
||||||
"in_channels": 2, # stereo
|
|
||||||
"resolution": 256,
|
|
||||||
"z_channels": 8,
|
|
||||||
"double_z": True,
|
"double_z": True,
|
||||||
"attn_type": "vanilla",
|
"mel_bins": 64,
|
||||||
"mid_block_add_attention": False, # Based on metadata: false
|
"z_channels": 8,
|
||||||
|
"resolution": 256,
|
||||||
|
"downsample_time": False,
|
||||||
|
"in_channels": 2,
|
||||||
|
"out_ch": 2,
|
||||||
|
"ch": 128,
|
||||||
|
"ch_mult": [1, 2, 4],
|
||||||
|
"num_res_blocks": 2,
|
||||||
|
"attn_resolutions": [],
|
||||||
|
"dropout": 0.0,
|
||||||
|
"mid_block_add_attention": False,
|
||||||
"norm_type": "pixel",
|
"norm_type": "pixel",
|
||||||
"causality_axis": "height", # Based on metadata
|
"causality_axis": "height",
|
||||||
"mel_bins": 64, # Based on metadata: mel_bins = 64
|
|
||||||
}
|
|
||||||
|
|
||||||
decoder_config = {
|
|
||||||
# Inherits encoder config, can override specific params
|
|
||||||
**encoder_config,
|
|
||||||
"out_ch": 2, # Stereo audio output (2 channels)
|
|
||||||
"give_pre_end": False,
|
|
||||||
"tanh_out": False,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"_class_name": "CausalAudioAutoencoder",
|
|
||||||
"sampling_rate": 16000,
|
|
||||||
"model": {
|
"model": {
|
||||||
"params": {
|
"params": {
|
||||||
"encoder": encoder_config,
|
"ddconfig": ddconfig,
|
||||||
"decoder": decoder_config,
|
"sampling_rate": 16000,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"preprocessing": {
|
||||||
|
"stft": {
|
||||||
|
"filter_length": 1024,
|
||||||
|
"hop_length": 160,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|||||||
@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
|||||||
|
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
def in_meta_context():
|
||||||
|
return torch.device("meta") == torch.empty(0).device
|
||||||
|
|
||||||
def mark_conv3d_ended(module):
|
def mark_conv3d_ended(module):
|
||||||
tid = threading.get_ident()
|
tid = threading.get_ident()
|
||||||
for _, m in module.named_modules():
|
for _, m in module.named_modules():
|
||||||
@ -350,6 +353,10 @@ class Decoder(nn.Module):
|
|||||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||||
if block_name == "compress_all":
|
if block_name == "compress_all":
|
||||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||||
|
if block_name == "compress_space":
|
||||||
|
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||||
|
if block_name == "compress_time":
|
||||||
|
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||||
|
|
||||||
self.conv_in = make_conv_nd(
|
self.conv_in = make_conv_nd(
|
||||||
dims,
|
dims,
|
||||||
@ -395,17 +402,21 @@ class Decoder(nn.Module):
|
|||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
elif block_name == "compress_time":
|
elif block_name == "compress_time":
|
||||||
|
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||||
block = DepthToSpaceUpsample(
|
block = DepthToSpaceUpsample(
|
||||||
dims=dims,
|
dims=dims,
|
||||||
in_channels=input_channel,
|
in_channels=input_channel,
|
||||||
stride=(2, 1, 1),
|
stride=(2, 1, 1),
|
||||||
|
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
elif block_name == "compress_space":
|
elif block_name == "compress_space":
|
||||||
|
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||||
block = DepthToSpaceUpsample(
|
block = DepthToSpaceUpsample(
|
||||||
dims=dims,
|
dims=dims,
|
||||||
in_channels=input_channel,
|
in_channels=input_channel,
|
||||||
stride=(1, 2, 2),
|
stride=(1, 2, 2),
|
||||||
|
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
elif block_name == "compress_all":
|
elif block_name == "compress_all":
|
||||||
@ -455,6 +466,15 @@ class Decoder(nn.Module):
|
|||||||
output_channel * 2, 0, operations=ops,
|
output_channel * 2, 0, operations=ops,
|
||||||
)
|
)
|
||||||
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
||||||
|
else:
|
||||||
|
self.register_buffer(
|
||||||
|
"last_scale_shift_table",
|
||||||
|
torch.tensor(
|
||||||
|
[0.0, 0.0],
|
||||||
|
device="cpu" if in_meta_context() else None
|
||||||
|
).unsqueeze(1).expand(2, output_channel),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||||
@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module):
|
|||||||
self.scale_shift_table = nn.Parameter(
|
self.scale_shift_table = nn.Parameter(
|
||||||
torch.randn(4, in_channels) / in_channels**0.5
|
torch.randn(4, in_channels) / in_channels**0.5
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.register_buffer(
|
||||||
|
"scale_shift_table",
|
||||||
|
torch.tensor(
|
||||||
|
[0.0, 0.0, 0.0, 0.0],
|
||||||
|
device="cpu" if in_meta_context() else None
|
||||||
|
).unsqueeze(1).expand(4, in_channels),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
|
||||||
self.temporal_cache_state={}
|
self.temporal_cache_state={}
|
||||||
|
|
||||||
@ -1012,9 +1041,6 @@ class processor(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.register_buffer("std-of-means", torch.empty(128))
|
self.register_buffer("std-of-means", torch.empty(128))
|
||||||
self.register_buffer("mean-of-means", torch.empty(128))
|
self.register_buffer("mean-of-means", torch.empty(128))
|
||||||
self.register_buffer("mean-of-stds", torch.empty(128))
|
|
||||||
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
|
|
||||||
self.register_buffer("channel", torch.empty(128))
|
|
||||||
|
|
||||||
def un_normalize(self, x):
|
def un_normalize(self, x):
|
||||||
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
|
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||||
@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = self.guess_config(version)
|
config = self.get_default_config(version)
|
||||||
|
|
||||||
|
self.config = config
|
||||||
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||||
|
self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
|
||||||
|
self.decode_timestep = config.get("decode_timestep", 0.05)
|
||||||
double_z = config.get("double_z", True)
|
double_z = config.get("double_z", True)
|
||||||
latent_log_var = config.get(
|
latent_log_var = config.get(
|
||||||
"latent_log_var", "per_channel" if double_z else "none"
|
"latent_log_var", "per_channel" if double_z else "none"
|
||||||
@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module):
|
|||||||
latent_log_var=latent_log_var,
|
latent_log_var=latent_log_var,
|
||||||
norm_layer=config.get("norm_layer", "group_norm"),
|
norm_layer=config.get("norm_layer", "group_norm"),
|
||||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||||
|
base_channels=config.get("encoder_base_channels", 128),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.decoder = Decoder(
|
self.decoder = Decoder(
|
||||||
@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module):
|
|||||||
in_channels=config["latent_channels"],
|
in_channels=config["latent_channels"],
|
||||||
out_channels=config.get("out_channels", 3),
|
out_channels=config.get("out_channels", 3),
|
||||||
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||||
|
base_channels=config.get("decoder_base_channels", 128),
|
||||||
patch_size=config.get("patch_size", 1),
|
patch_size=config.get("patch_size", 1),
|
||||||
norm_layer=config.get("norm_layer", "group_norm"),
|
norm_layer=config.get("norm_layer", "group_norm"),
|
||||||
causal=config.get("causal_decoder", False),
|
causal=config.get("causal_decoder", False),
|
||||||
@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module):
|
|||||||
|
|
||||||
self.per_channel_statistics = processor()
|
self.per_channel_statistics = processor()
|
||||||
|
|
||||||
def guess_config(self, version):
|
def get_default_config(self, version):
|
||||||
if version == 0:
|
if version == 0:
|
||||||
config = {
|
config = {
|
||||||
"_class_name": "CausalVideoAutoencoder",
|
"_class_name": "CausalVideoAutoencoder",
|
||||||
@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module):
|
|||||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||||
return self.per_channel_statistics.normalize(means)
|
return self.per_channel_statistics.normalize(means)
|
||||||
|
|
||||||
def decode(self, x, timestep=0.05, noise_scale=0.025):
|
def decode(self, x):
|
||||||
if self.timestep_conditioning: #TODO: seed
|
if self.timestep_conditioning: #TODO: seed
|
||||||
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
|
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
|
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import torch.nn.functional as F
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import math
|
||||||
|
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
@ -12,6 +13,307 @@ def get_padding(kernel_size, dilation=1):
|
|||||||
return int((kernel_size * dilation - dilation) / 2)
|
return int((kernel_size * dilation - dilation) / 2)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
|
||||||
|
# Adopted from https://github.com/NVIDIA/BigVGAN
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _sinc(x: torch.Tensor):
|
||||||
|
return torch.where(
|
||||||
|
x == 0,
|
||||||
|
torch.tensor(1.0, device=x.device, dtype=x.dtype),
|
||||||
|
torch.sin(math.pi * x) / math.pi / x,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
|
||||||
|
even = kernel_size % 2 == 0
|
||||||
|
half_size = kernel_size // 2
|
||||||
|
delta_f = 4 * half_width
|
||||||
|
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||||
|
if A > 50.0:
|
||||||
|
beta = 0.1102 * (A - 8.7)
|
||||||
|
elif A >= 21.0:
|
||||||
|
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
|
||||||
|
else:
|
||||||
|
beta = 0.0
|
||||||
|
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||||
|
if even:
|
||||||
|
time = torch.arange(-half_size, half_size) + 0.5
|
||||||
|
else:
|
||||||
|
time = torch.arange(kernel_size) - half_size
|
||||||
|
if cutoff == 0:
|
||||||
|
filter_ = torch.zeros_like(time)
|
||||||
|
else:
|
||||||
|
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
|
||||||
|
filter_ /= filter_.sum()
|
||||||
|
filter = filter_.view(1, 1, kernel_size)
|
||||||
|
return filter
|
||||||
|
|
||||||
|
|
||||||
|
class LowPassFilter1d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
cutoff=0.5,
|
||||||
|
half_width=0.6,
|
||||||
|
stride=1,
|
||||||
|
padding=True,
|
||||||
|
padding_mode="replicate",
|
||||||
|
kernel_size=12,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if cutoff < -0.0:
|
||||||
|
raise ValueError("Minimum cutoff must be larger than zero.")
|
||||||
|
if cutoff > 0.5:
|
||||||
|
raise ValueError("A cutoff above 0.5 does not make sense.")
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
self.even = kernel_size % 2 == 0
|
||||||
|
self.pad_left = kernel_size // 2 - int(self.even)
|
||||||
|
self.pad_right = kernel_size // 2
|
||||||
|
self.stride = stride
|
||||||
|
self.padding = padding
|
||||||
|
self.padding_mode = padding_mode
|
||||||
|
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
|
||||||
|
self.register_buffer("filter", filter)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
_, C, _ = x.shape
|
||||||
|
if self.padding:
|
||||||
|
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||||
|
return F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||||
|
|
||||||
|
|
||||||
|
class UpSample1d(nn.Module):
|
||||||
|
def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"):
|
||||||
|
super().__init__()
|
||||||
|
self.ratio = ratio
|
||||||
|
self.stride = ratio
|
||||||
|
|
||||||
|
if window_type == "hann":
|
||||||
|
# Hann-windowed sinc filter — identical to torchaudio.functional.resample
|
||||||
|
# with its default parameters (rolloff=0.99, lowpass_filter_width=6).
|
||||||
|
# Uses replicate boundary padding, matching the reference resampler exactly.
|
||||||
|
rolloff = 0.99
|
||||||
|
lowpass_filter_width = 6
|
||||||
|
width = math.ceil(lowpass_filter_width / rolloff)
|
||||||
|
self.kernel_size = 2 * width * ratio + 1
|
||||||
|
self.pad = width
|
||||||
|
self.pad_left = 2 * width * ratio
|
||||||
|
self.pad_right = self.kernel_size - ratio
|
||||||
|
t = (torch.arange(self.kernel_size) / ratio - width) * rolloff
|
||||||
|
t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width)
|
||||||
|
window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2
|
||||||
|
filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1)
|
||||||
|
else:
|
||||||
|
# Kaiser-windowed sinc filter (BigVGAN default).
|
||||||
|
self.kernel_size = (
|
||||||
|
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||||
|
)
|
||||||
|
self.pad = self.kernel_size // ratio - 1
|
||||||
|
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||||
|
self.pad_right = (
|
||||||
|
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||||
|
)
|
||||||
|
filter = kaiser_sinc_filter1d(
|
||||||
|
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
||||||
|
)
|
||||||
|
|
||||||
|
self.register_buffer("filter", filter, persistent=persistent)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
_, C, _ = x.shape
|
||||||
|
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||||
|
x = self.ratio * F.conv_transpose1d(
|
||||||
|
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
||||||
|
)
|
||||||
|
x = x[..., self.pad_left : -self.pad_right]
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DownSample1d(nn.Module):
|
||||||
|
def __init__(self, ratio=2, kernel_size=None):
|
||||||
|
super().__init__()
|
||||||
|
self.ratio = ratio
|
||||||
|
self.kernel_size = (
|
||||||
|
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||||
|
)
|
||||||
|
self.lowpass = LowPassFilter1d(
|
||||||
|
cutoff=0.5 / ratio,
|
||||||
|
half_width=0.6 / ratio,
|
||||||
|
stride=ratio,
|
||||||
|
kernel_size=self.kernel_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.lowpass(x)
|
||||||
|
|
||||||
|
|
||||||
|
class Activation1d(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
activation,
|
||||||
|
up_ratio=2,
|
||||||
|
down_ratio=2,
|
||||||
|
up_kernel_size=12,
|
||||||
|
down_kernel_size=12,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.act = activation
|
||||||
|
self.upsample = UpSample1d(up_ratio, up_kernel_size)
|
||||||
|
self.downsample = DownSample1d(down_ratio, down_kernel_size)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.upsample(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.downsample(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# BigVGAN v2 activations (Snake / SnakeBeta)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class Snake(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.alpha_logscale = alpha_logscale
|
||||||
|
self.alpha = nn.Parameter(
|
||||||
|
torch.zeros(in_features)
|
||||||
|
if alpha_logscale
|
||||||
|
else torch.ones(in_features) * alpha
|
||||||
|
)
|
||||||
|
self.alpha.requires_grad = alpha_trainable
|
||||||
|
self.eps = 1e-9
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
a = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||||
|
if self.alpha_logscale:
|
||||||
|
a = torch.exp(a)
|
||||||
|
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
|
||||||
|
|
||||||
|
|
||||||
|
class SnakeBeta(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.alpha_logscale = alpha_logscale
|
||||||
|
self.alpha = nn.Parameter(
|
||||||
|
torch.zeros(in_features)
|
||||||
|
if alpha_logscale
|
||||||
|
else torch.ones(in_features) * alpha
|
||||||
|
)
|
||||||
|
self.alpha.requires_grad = alpha_trainable
|
||||||
|
self.beta = nn.Parameter(
|
||||||
|
torch.zeros(in_features)
|
||||||
|
if alpha_logscale
|
||||||
|
else torch.ones(in_features) * alpha
|
||||||
|
)
|
||||||
|
self.beta.requires_grad = alpha_trainable
|
||||||
|
self.eps = 1e-9
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
a = self.alpha.unsqueeze(0).unsqueeze(-1)
|
||||||
|
b = self.beta.unsqueeze(0).unsqueeze(-1)
|
||||||
|
if self.alpha_logscale:
|
||||||
|
a = torch.exp(a)
|
||||||
|
b = torch.exp(b)
|
||||||
|
return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class AMPBlock1(torch.nn.Module):
|
||||||
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"):
|
||||||
|
super().__init__()
|
||||||
|
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||||
|
self.convs1 = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[0],
|
||||||
|
padding=get_padding(kernel_size, dilation[0]),
|
||||||
|
),
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[1],
|
||||||
|
padding=get_padding(kernel_size, dilation[1]),
|
||||||
|
),
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=dilation[2],
|
||||||
|
padding=get_padding(kernel_size, dilation[2]),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.convs2 = nn.ModuleList(
|
||||||
|
[
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
|
),
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
|
),
|
||||||
|
ops.Conv1d(
|
||||||
|
channels,
|
||||||
|
channels,
|
||||||
|
kernel_size,
|
||||||
|
1,
|
||||||
|
dilation=1,
|
||||||
|
padding=get_padding(kernel_size, 1),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.acts1 = nn.ModuleList(
|
||||||
|
[Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]
|
||||||
|
)
|
||||||
|
self.acts2 = nn.ModuleList(
|
||||||
|
[Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2):
|
||||||
|
xt = a1(x)
|
||||||
|
xt = c1(xt)
|
||||||
|
xt = a2(xt)
|
||||||
|
xt = c2(xt)
|
||||||
|
x = x + xt
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HiFi-GAN residual blocks
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class ResBlock1(torch.nn.Module):
|
class ResBlock1(torch.nn.Module):
|
||||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||||
super(ResBlock1, self).__init__()
|
super(ResBlock1, self).__init__()
|
||||||
@ -119,6 +421,7 @@ class Vocoder(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
|
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
|
||||||
|
|
||||||
|
Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1").
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config=None):
|
def __init__(self, config=None):
|
||||||
@ -128,19 +431,39 @@ class Vocoder(torch.nn.Module):
|
|||||||
config = self.get_default_config()
|
config = self.get_default_config()
|
||||||
|
|
||||||
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
|
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
|
||||||
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
|
upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2])
|
||||||
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
|
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4])
|
||||||
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||||
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
|
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
|
||||||
stereo = config.get("stereo", True)
|
stereo = config.get("stereo", True)
|
||||||
resblock = config.get("resblock", "1")
|
activation = config.get("activation", "snake")
|
||||||
|
use_bias_at_final = config.get("use_bias_at_final", True)
|
||||||
|
|
||||||
|
|
||||||
|
# "output_sample_rate" is not present in recent checkpoint configs.
|
||||||
|
# When absent (None), AudioVAE.output_sample_rate computes it as:
|
||||||
|
# sample_rate * vocoder.upsample_factor / mel_hop_length
|
||||||
|
# where upsample_factor = product of all upsample stride lengths,
|
||||||
|
# and mel_hop_length is loaded from the autoencoder config at
|
||||||
|
# preprocessing.stft.hop_length (see CausalAudioAutoencoder).
|
||||||
self.output_sample_rate = config.get("output_sample_rate")
|
self.output_sample_rate = config.get("output_sample_rate")
|
||||||
|
self.resblock = config.get("resblock", "1")
|
||||||
|
self.use_tanh_at_final = config.get("use_tanh_at_final", True)
|
||||||
|
self.apply_final_activation = config.get("apply_final_activation", True)
|
||||||
self.num_kernels = len(resblock_kernel_sizes)
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
self.num_upsamples = len(upsample_rates)
|
self.num_upsamples = len(upsample_rates)
|
||||||
|
|
||||||
in_channels = 128 if stereo else 64
|
in_channels = 128 if stereo else 64
|
||||||
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
|
||||||
|
if self.resblock == "1":
|
||||||
|
resblock_cls = ResBlock1
|
||||||
|
elif self.resblock == "2":
|
||||||
|
resblock_cls = ResBlock2
|
||||||
|
elif self.resblock == "AMP1":
|
||||||
|
resblock_cls = AMPBlock1
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown resblock type: {self.resblock}")
|
||||||
|
|
||||||
self.ups = nn.ModuleList()
|
self.ups = nn.ModuleList()
|
||||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||||
@ -157,25 +480,40 @@ class Vocoder(torch.nn.Module):
|
|||||||
self.resblocks = nn.ModuleList()
|
self.resblocks = nn.ModuleList()
|
||||||
for i in range(len(self.ups)):
|
for i in range(len(self.ups)):
|
||||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
||||||
self.resblocks.append(resblock_class(ch, k, d))
|
if self.resblock == "AMP1":
|
||||||
|
self.resblocks.append(resblock_cls(ch, k, d, activation=activation))
|
||||||
|
else:
|
||||||
|
self.resblocks.append(resblock_cls(ch, k, d))
|
||||||
|
|
||||||
out_channels = 2 if stereo else 1
|
out_channels = 2 if stereo else 1
|
||||||
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
|
if self.resblock == "AMP1":
|
||||||
|
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||||
|
self.act_post = Activation1d(act_cls(ch))
|
||||||
|
else:
|
||||||
|
self.act_post = nn.LeakyReLU()
|
||||||
|
|
||||||
|
self.conv_post = ops.Conv1d(
|
||||||
|
ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final
|
||||||
|
)
|
||||||
|
|
||||||
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
|
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
|
||||||
|
|
||||||
|
|
||||||
def get_default_config(self):
|
def get_default_config(self):
|
||||||
"""Generate default configuration for the vocoder."""
|
"""Generate default configuration for the vocoder."""
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"resblock_kernel_sizes": [3, 7, 11],
|
"resblock_kernel_sizes": [3, 7, 11],
|
||||||
"upsample_rates": [6, 5, 2, 2, 2],
|
"upsample_rates": [5, 4, 2, 2, 2],
|
||||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
"upsample_kernel_sizes": [16, 16, 8, 4, 4],
|
||||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||||
"upsample_initial_channel": 1024,
|
"upsample_initial_channel": 1024,
|
||||||
"stereo": True,
|
"stereo": True,
|
||||||
"resblock": "1",
|
"resblock": "1",
|
||||||
|
"activation": "snake",
|
||||||
|
"use_bias_at_final": True,
|
||||||
|
"use_tanh_at_final": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
return config
|
return config
|
||||||
@ -196,8 +534,10 @@ class Vocoder(torch.nn.Module):
|
|||||||
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
|
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
|
||||||
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
|
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
|
||||||
x = self.conv_pre(x)
|
x = self.conv_pre(x)
|
||||||
|
|
||||||
for i in range(self.num_upsamples):
|
for i in range(self.num_upsamples):
|
||||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
if self.resblock != "AMP1":
|
||||||
|
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
x = self.ups[i](x)
|
x = self.ups[i](x)
|
||||||
xs = None
|
xs = None
|
||||||
for j in range(self.num_kernels):
|
for j in range(self.num_kernels):
|
||||||
@ -206,8 +546,167 @@ class Vocoder(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||||
x = xs / self.num_kernels
|
x = xs / self.num_kernels
|
||||||
x = F.leaky_relu(x)
|
|
||||||
|
x = self.act_post(x)
|
||||||
x = self.conv_post(x)
|
x = self.conv_post(x)
|
||||||
x = torch.tanh(x)
|
|
||||||
|
if self.apply_final_activation:
|
||||||
|
if self.use_tanh_at_final:
|
||||||
|
x = torch.tanh(x)
|
||||||
|
else:
|
||||||
|
x = torch.clamp(x, -1, 1)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class _STFTFn(nn.Module):
|
||||||
|
"""Implements STFT as a convolution with precomputed DFT × Hann-window bases.
|
||||||
|
|
||||||
|
The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
|
||||||
|
Hann window are stored as buffers and loaded from the checkpoint. Using the exact
|
||||||
|
bfloat16 bases from training ensures the mel values fed to the BWE generator are
|
||||||
|
bit-identical to what it was trained on.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, filter_length: int, hop_length: int, win_length: int):
|
||||||
|
super().__init__()
|
||||||
|
self.hop_length = hop_length
|
||||||
|
self.win_length = win_length
|
||||||
|
n_freqs = filter_length // 2 + 1
|
||||||
|
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||||
|
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
|
||||||
|
|
||||||
|
def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Compute magnitude and phase spectrogram from a batch of waveforms.
|
||||||
|
|
||||||
|
Applies causal (left-only) padding of win_length - hop_length samples so that
|
||||||
|
each output frame depends only on past and present input — no lookahead.
|
||||||
|
The STFT is computed by convolving the padded signal with forward_basis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y: Waveform tensor of shape (B, T).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
||||||
|
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
||||||
|
Computed in float32 for numerical stability, then cast back to
|
||||||
|
the input dtype.
|
||||||
|
"""
|
||||||
|
if y.dim() == 2:
|
||||||
|
y = y.unsqueeze(1) # (B, 1, T)
|
||||||
|
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
|
||||||
|
y = F.pad(y, (left_pad, 0))
|
||||||
|
spec = F.conv1d(y, self.forward_basis, stride=self.hop_length, padding=0)
|
||||||
|
n_freqs = spec.shape[1] // 2
|
||||||
|
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
|
||||||
|
magnitude = torch.sqrt(real ** 2 + imag ** 2)
|
||||||
|
phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
|
||||||
|
return magnitude, phase
|
||||||
|
|
||||||
|
|
||||||
|
class MelSTFT(nn.Module):
|
||||||
|
"""Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
|
||||||
|
|
||||||
|
Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
|
||||||
|
waveform and projecting the linear magnitude spectrum onto the mel filterbank.
|
||||||
|
|
||||||
|
The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
|
||||||
|
(mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
filter_length: int,
|
||||||
|
hop_length: int,
|
||||||
|
win_length: int,
|
||||||
|
n_mel_channels: int,
|
||||||
|
sampling_rate: int,
|
||||||
|
mel_fmin: float,
|
||||||
|
mel_fmax: float,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
|
||||||
|
|
||||||
|
n_freqs = filter_length // 2 + 1
|
||||||
|
self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
|
||||||
|
|
||||||
|
def mel_spectrogram(
|
||||||
|
self, y: torch.Tensor
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""Compute log-mel spectrogram and auxiliary spectral quantities.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y: Waveform tensor of shape (B, T).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
|
||||||
|
Computed as log(clamp(mel_basis @ magnitude, min=1e-5)).
|
||||||
|
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
|
||||||
|
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
|
||||||
|
energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
|
||||||
|
"""
|
||||||
|
magnitude, phase = self.stft_fn(y)
|
||||||
|
energy = torch.norm(magnitude, dim=1)
|
||||||
|
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
|
||||||
|
log_mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||||
|
return log_mel, magnitude, phase, energy
|
||||||
|
|
||||||
|
|
||||||
|
class VocoderWithBWE(torch.nn.Module):
|
||||||
|
"""Vocoder with bandwidth extension (BWE) for higher sample rate output.
|
||||||
|
|
||||||
|
Chains a base vocoder (mel → low-rate waveform) with a BWE stage that upsamples
|
||||||
|
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
vocoder_config = config["vocoder"]
|
||||||
|
bwe_config = config["bwe"]
|
||||||
|
|
||||||
|
self.vocoder = Vocoder(config=vocoder_config)
|
||||||
|
self.bwe_generator = Vocoder(
|
||||||
|
config={**bwe_config, "apply_final_activation": False}
|
||||||
|
)
|
||||||
|
|
||||||
|
self.input_sample_rate = bwe_config["input_sampling_rate"]
|
||||||
|
self.output_sample_rate = bwe_config["output_sampling_rate"]
|
||||||
|
self.hop_length = bwe_config["hop_length"]
|
||||||
|
|
||||||
|
self.mel_stft = MelSTFT(
|
||||||
|
filter_length=bwe_config["n_fft"],
|
||||||
|
hop_length=bwe_config["hop_length"],
|
||||||
|
win_length=bwe_config["n_fft"],
|
||||||
|
n_mel_channels=bwe_config["num_mels"],
|
||||||
|
sampling_rate=bwe_config["input_sampling_rate"],
|
||||||
|
mel_fmin=0.0,
|
||||||
|
mel_fmax=bwe_config["input_sampling_rate"] / 2.0,
|
||||||
|
)
|
||||||
|
self.resampler = UpSample1d(
|
||||||
|
ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"],
|
||||||
|
persistent=False,
|
||||||
|
window_type="hann",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _compute_mel(self, audio):
|
||||||
|
"""Compute log-mel spectrogram from waveform using causal STFT bases."""
|
||||||
|
B, C, T = audio.shape
|
||||||
|
flat = audio.reshape(B * C, -1) # (B*C, T)
|
||||||
|
mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
|
||||||
|
return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames)
|
||||||
|
|
||||||
|
def forward(self, mel_spec):
|
||||||
|
x = self.vocoder(mel_spec)
|
||||||
|
_, _, T_low = x.shape
|
||||||
|
T_out = T_low * self.output_sample_rate // self.input_sample_rate
|
||||||
|
|
||||||
|
remainder = T_low % self.hop_length
|
||||||
|
if remainder != 0:
|
||||||
|
x = F.pad(x, (0, self.hop_length - remainder))
|
||||||
|
|
||||||
|
mel = self._compute_mel(x)
|
||||||
|
residual = self.bwe_generator(mel)
|
||||||
|
skip = self.resampler(x)
|
||||||
|
assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
|
||||||
|
|
||||||
|
return torch.clamp(residual + skip, -1, 1)[..., :T_out]
|
||||||
|
|||||||
@ -1021,7 +1021,7 @@ class LTXAV(BaseModel):
|
|||||||
cross_attn = kwargs.get("cross_attn", None)
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
if cross_attn is not None:
|
if cross_attn is not None:
|
||||||
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
|
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
|
||||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
|
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False))
|
||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
|
||||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||||
|
|||||||
@ -1467,7 +1467,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||||
elif clip_type == CLIPType.LTXV:
|
elif clip_type == CLIPType.LTXV:
|
||||||
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
|
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
|
||||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
elif clip_type == CLIPType.NEWBIE:
|
elif clip_type == CLIPType.NEWBIE:
|
||||||
|
|||||||
@ -97,18 +97,39 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
|
|||||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
|
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
|
||||||
|
|
||||||
|
class DualLinearProjection(torch.nn.Module):
|
||||||
|
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device)
|
||||||
|
self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
source_dim = x.shape[-1]
|
||||||
|
x = x.movedim(1, -1)
|
||||||
|
x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2)
|
||||||
|
|
||||||
|
video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim))
|
||||||
|
audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim))
|
||||||
|
return torch.cat((video, audio), dim=-1)
|
||||||
|
|
||||||
class LTXAVTEModel(torch.nn.Module):
|
class LTXAVTEModel(torch.nn.Module):
|
||||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
self.dtypes.add(dtype)
|
self.dtypes.add(dtype)
|
||||||
self.compat_mode = False
|
self.compat_mode = False
|
||||||
|
self.text_projection_type = text_projection_type
|
||||||
|
|
||||||
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
|
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
|
||||||
self.dtypes.add(dtype_llama)
|
self.dtypes.add(dtype_llama)
|
||||||
|
|
||||||
operations = self.gemma3_12b.operations # TODO
|
operations = self.gemma3_12b.operations # TODO
|
||||||
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
|
||||||
|
if self.text_projection_type == "single_linear":
|
||||||
|
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
||||||
|
elif self.text_projection_type == "dual_linear":
|
||||||
|
self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
|
||||||
def enable_compat_mode(self): # TODO: remove
|
def enable_compat_mode(self): # TODO: remove
|
||||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||||
@ -148,18 +169,25 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
out_device = out.device
|
out_device = out.device
|
||||||
if comfy.model_management.should_use_bf16(self.execution_device):
|
if comfy.model_management.should_use_bf16(self.execution_device):
|
||||||
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
||||||
out = out.movedim(1, -1).to(self.execution_device)
|
|
||||||
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
|
|
||||||
out = out.reshape((out.shape[0], out.shape[1], -1))
|
|
||||||
out = self.text_embedding_projection(out)
|
|
||||||
out = out.float()
|
|
||||||
|
|
||||||
if self.compat_mode:
|
if self.text_projection_type == "single_linear":
|
||||||
out_vid = self.video_embeddings_connector(out)[0]
|
out = out.movedim(1, -1).to(self.execution_device)
|
||||||
out_audio = self.audio_embeddings_connector(out)[0]
|
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
|
||||||
out = torch.concat((out_vid, out_audio), dim=-1)
|
out = out.reshape((out.shape[0], out.shape[1], -1))
|
||||||
|
out = self.text_embedding_projection(out)
|
||||||
|
|
||||||
return out.to(out_device), pooled
|
if self.compat_mode:
|
||||||
|
out_vid = self.video_embeddings_connector(out)[0]
|
||||||
|
out_audio = self.audio_embeddings_connector(out)[0]
|
||||||
|
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||||
|
extra = {}
|
||||||
|
else:
|
||||||
|
extra = {"unprocessed_ltxav_embeds": True}
|
||||||
|
elif self.text_projection_type == "dual_linear":
|
||||||
|
out = self.text_embedding_projection(out)
|
||||||
|
extra = {"unprocessed_ltxav_embeds": True}
|
||||||
|
|
||||||
|
return out.to(device=out_device, dtype=torch.float), pooled, extra
|
||||||
|
|
||||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||||
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
||||||
@ -168,7 +196,7 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||||
return self.gemma3_12b.load_sd(sd)
|
return self.gemma3_12b.load_sd(sd)
|
||||||
else:
|
else:
|
||||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
|
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True)
|
||||||
if len(sdo) == 0:
|
if len(sdo) == 0:
|
||||||
sdo = sd
|
sdo = sd
|
||||||
|
|
||||||
@ -206,7 +234,7 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
num_tokens = max(num_tokens, 642)
|
num_tokens = max(num_tokens, 642)
|
||||||
return num_tokens * constant * 1024 * 1024
|
return num_tokens * constant * 1024 * 1024
|
||||||
|
|
||||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"):
|
||||||
class LTXAVTEModel_(LTXAVTEModel):
|
class LTXAVTEModel_(LTXAVTEModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
if llama_quantization_metadata is not None:
|
if llama_quantization_metadata is not None:
|
||||||
@ -214,9 +242,19 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
|||||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||||
if dtype_llama is not None:
|
if dtype_llama is not None:
|
||||||
dtype = dtype_llama
|
dtype = dtype_llama
|
||||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options)
|
||||||
return LTXAVTEModel_
|
return LTXAVTEModel_
|
||||||
|
|
||||||
|
|
||||||
|
def sd_detect(state_dict_list, prefix=""):
|
||||||
|
for sd in state_dict_list:
|
||||||
|
if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd:
|
||||||
|
return {"text_projection_type": "dual_linear"}
|
||||||
|
if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd:
|
||||||
|
return {"text_projection_type": "single_linear"}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
|
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
class Gemma3_12BModel_(Gemma3_12BModel):
|
class Gemma3_12BModel_(Gemma3_12BModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user