mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-23 18:13:28 +08:00
Merge remote-tracking branch 'upstream/master' into qwen35
This commit is contained in:
commit
56708ac088
@ -2,11 +2,16 @@ from typing import Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from comfy.ldm.lightricks.model import (
|
from comfy.ldm.lightricks.model import (
|
||||||
|
ADALN_BASE_PARAMS_COUNT,
|
||||||
|
ADALN_CROSS_ATTN_PARAMS_COUNT,
|
||||||
CrossAttention,
|
CrossAttention,
|
||||||
FeedForward,
|
FeedForward,
|
||||||
AdaLayerNormSingle,
|
AdaLayerNormSingle,
|
||||||
PixArtAlphaTextProjection,
|
PixArtAlphaTextProjection,
|
||||||
|
NormSingleLinearTextProjection,
|
||||||
LTXVModel,
|
LTXVModel,
|
||||||
|
apply_cross_attention_adaln,
|
||||||
|
compute_prompt_timestep,
|
||||||
)
|
)
|
||||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||||
@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
v_context_dim=None,
|
v_context_dim=None,
|
||||||
a_context_dim=None,
|
a_context_dim=None,
|
||||||
attn_precision=None,
|
attn_precision=None,
|
||||||
|
apply_gated_attention=False,
|
||||||
|
cross_attention_adaln=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.attn_precision = attn_precision
|
self.attn_precision = attn_precision
|
||||||
|
self.cross_attention_adaln = cross_attention_adaln
|
||||||
|
|
||||||
self.attn1 = CrossAttention(
|
self.attn1 = CrossAttention(
|
||||||
query_dim=v_dim,
|
query_dim=v_dim,
|
||||||
@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
dim_head=vd_head,
|
dim_head=vd_head,
|
||||||
context_dim=None,
|
context_dim=None,
|
||||||
attn_precision=self.attn_precision,
|
attn_precision=self.attn_precision,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
dim_head=ad_head,
|
dim_head=ad_head,
|
||||||
context_dim=None,
|
context_dim=None,
|
||||||
attn_precision=self.attn_precision,
|
attn_precision=self.attn_precision,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
heads=v_heads,
|
heads=v_heads,
|
||||||
dim_head=vd_head,
|
dim_head=vd_head,
|
||||||
attn_precision=self.attn_precision,
|
attn_precision=self.attn_precision,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
heads=a_heads,
|
heads=a_heads,
|
||||||
dim_head=ad_head,
|
dim_head=ad_head,
|
||||||
attn_precision=self.attn_precision,
|
attn_precision=self.attn_precision,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
heads=a_heads,
|
heads=a_heads,
|
||||||
dim_head=ad_head,
|
dim_head=ad_head,
|
||||||
attn_precision=self.attn_precision,
|
attn_precision=self.attn_precision,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
heads=a_heads,
|
heads=a_heads,
|
||||||
dim_head=ad_head,
|
dim_head=ad_head,
|
||||||
attn_precision=self.attn_precision,
|
attn_precision=self.attn_precision,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
|
||||||
)
|
)
|
||||||
|
|
||||||
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
|
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||||
|
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype))
|
||||||
self.audio_scale_shift_table = nn.Parameter(
|
self.audio_scale_shift_table = nn.Parameter(
|
||||||
torch.empty(6, a_dim, device=device, dtype=dtype)
|
torch.empty(num_ada_params, a_dim, device=device, dtype=dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cross_attention_adaln:
|
||||||
|
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype))
|
||||||
|
self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
|
||||||
torch.empty(5, a_dim, device=device, dtype=dtype)
|
torch.empty(5, a_dim, device=device, dtype=dtype)
|
||||||
)
|
)
|
||||||
@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
return (*scale_shift_ada_values, *gate_ada_values)
|
return (*scale_shift_ada_values, *gate_ada_values)
|
||||||
|
|
||||||
|
def _apply_text_cross_attention(
|
||||||
|
self, x, context, attn, scale_shift_table, prompt_scale_shift_table,
|
||||||
|
timestep, prompt_timestep, attention_mask, transformer_options,
|
||||||
|
):
|
||||||
|
"""Apply text cross-attention, with optional ADaLN modulation."""
|
||||||
|
if self.cross_attention_adaln:
|
||||||
|
shift_q, scale_q, gate = self.get_ada_values(
|
||||||
|
scale_shift_table, x.shape[0], timestep, slice(6, 9)
|
||||||
|
)
|
||||||
|
return apply_cross_attention_adaln(
|
||||||
|
x, context, attn, shift_q, scale_q, gate,
|
||||||
|
prompt_scale_shift_table, prompt_timestep,
|
||||||
|
attention_mask, transformer_options,
|
||||||
|
)
|
||||||
|
return attn(
|
||||||
|
comfy.ldm.common_dit.rms_norm(x), context=context,
|
||||||
|
mask=attention_mask, transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
|
||||||
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
|
||||||
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
|
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
|
||||||
|
v_prompt_timestep=None, a_prompt_timestep=None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
run_vx = transformer_options.get("run_vx", True)
|
run_vx = transformer_options.get("run_vx", True)
|
||||||
run_ax = transformer_options.get("run_ax", True)
|
run_ax = transformer_options.get("run_ax", True)
|
||||||
@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
|
||||||
vx.addcmul_(attn1_out, vgate_msa)
|
vx.addcmul_(attn1_out, vgate_msa)
|
||||||
del vgate_msa, attn1_out
|
del vgate_msa, attn1_out
|
||||||
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
|
vx.add_(self._apply_text_cross_attention(
|
||||||
|
vx, v_context, self.attn2, self.scale_shift_table,
|
||||||
|
getattr(self, 'prompt_scale_shift_table', None),
|
||||||
|
v_timestep, v_prompt_timestep, attention_mask, transformer_options,)
|
||||||
|
)
|
||||||
|
|
||||||
# audio
|
# audio
|
||||||
if run_ax:
|
if run_ax:
|
||||||
@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module):
|
|||||||
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
|
||||||
ax.addcmul_(attn1_out, agate_msa)
|
ax.addcmul_(attn1_out, agate_msa)
|
||||||
del agate_msa, attn1_out
|
del agate_msa, attn1_out
|
||||||
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
|
ax.add_(self._apply_text_cross_attention(
|
||||||
|
ax, a_context, self.audio_attn2, self.audio_scale_shift_table,
|
||||||
|
getattr(self, 'audio_prompt_scale_shift_table', None),
|
||||||
|
a_timestep, a_prompt_timestep, attention_mask, transformer_options,)
|
||||||
|
)
|
||||||
|
|
||||||
# video - audio cross attention.
|
# video - audio cross attention.
|
||||||
if run_a2v or run_v2a:
|
if run_a2v or run_v2a:
|
||||||
@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel):
|
|||||||
use_middle_indices_grid=False,
|
use_middle_indices_grid=False,
|
||||||
timestep_scale_multiplier=1000.0,
|
timestep_scale_multiplier=1000.0,
|
||||||
av_ca_timestep_scale_multiplier=1.0,
|
av_ca_timestep_scale_multiplier=1.0,
|
||||||
|
apply_gated_attention=False,
|
||||||
|
caption_proj_before_connector=False,
|
||||||
|
cross_attention_adaln=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel):
|
|||||||
self.audio_attention_head_dim = audio_attention_head_dim
|
self.audio_attention_head_dim = audio_attention_head_dim
|
||||||
self.audio_num_attention_heads = audio_num_attention_heads
|
self.audio_num_attention_heads = audio_num_attention_heads
|
||||||
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
|
||||||
|
self.apply_gated_attention = apply_gated_attention
|
||||||
|
|
||||||
# Calculate audio dimensions
|
# Calculate audio dimensions
|
||||||
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
|
||||||
@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel):
|
|||||||
vae_scale_factors=vae_scale_factors,
|
vae_scale_factors=vae_scale_factors,
|
||||||
use_middle_indices_grid=use_middle_indices_grid,
|
use_middle_indices_grid=use_middle_indices_grid,
|
||||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||||
|
caption_proj_before_connector=caption_proj_before_connector,
|
||||||
|
cross_attention_adaln=cross_attention_adaln,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Audio-specific AdaLN
|
# Audio-specific AdaLN
|
||||||
|
audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||||
self.audio_adaln_single = AdaLayerNormSingle(
|
self.audio_adaln_single = AdaLayerNormSingle(
|
||||||
self.audio_inner_dim,
|
self.audio_inner_dim,
|
||||||
|
embedding_coefficient=audio_embedding_coefficient,
|
||||||
use_additional_conditions=False,
|
use_additional_conditions=False,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=self.operations,
|
operations=self.operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.cross_attention_adaln:
|
||||||
|
self.audio_prompt_adaln_single = AdaLayerNormSingle(
|
||||||
|
self.audio_inner_dim,
|
||||||
|
embedding_coefficient=2,
|
||||||
|
use_additional_conditions=False,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.audio_prompt_adaln_single = None
|
||||||
|
|
||||||
num_scale_shift_values = 4
|
num_scale_shift_values = 4
|
||||||
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
|
||||||
self.inner_dim,
|
self.inner_dim,
|
||||||
@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Audio caption projection
|
# Audio caption projection
|
||||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
if self.caption_proj_before_connector:
|
||||||
in_features=self.caption_channels,
|
if self.caption_projection_first_linear:
|
||||||
hidden_size=self.audio_inner_dim,
|
self.audio_caption_projection = NormSingleLinearTextProjection(
|
||||||
dtype=dtype,
|
in_features=self.caption_channels,
|
||||||
device=device,
|
hidden_size=self.audio_inner_dim,
|
||||||
operations=self.operations,
|
dtype=dtype,
|
||||||
)
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.audio_caption_projection = lambda a: a
|
||||||
|
else:
|
||||||
|
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||||
|
in_features=self.caption_channels,
|
||||||
|
hidden_size=self.audio_inner_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
|
||||||
|
connector_split_rope = kwargs.get("rope_type", "split") == "split"
|
||||||
|
connector_gated_attention = kwargs.get("connector_apply_gated_attention", False)
|
||||||
|
attention_head_dim = kwargs.get("connector_attention_head_dim", 128)
|
||||||
|
num_attention_heads = kwargs.get("connector_num_attention_heads", 30)
|
||||||
|
num_layers = kwargs.get("connector_num_layers", 2)
|
||||||
|
|
||||||
self.audio_embeddings_connector = Embeddings1DConnector(
|
self.audio_embeddings_connector = Embeddings1DConnector(
|
||||||
split_rope=True,
|
attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim),
|
||||||
|
num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads),
|
||||||
|
num_layers=num_layers,
|
||||||
|
split_rope=connector_split_rope,
|
||||||
double_precision_rope=True,
|
double_precision_rope=True,
|
||||||
|
apply_gated_attention=connector_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=self.operations,
|
operations=self.operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.video_embeddings_connector = Embeddings1DConnector(
|
self.video_embeddings_connector = Embeddings1DConnector(
|
||||||
split_rope=True,
|
attention_head_dim=attention_head_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
num_layers=num_layers,
|
||||||
|
split_rope=connector_split_rope,
|
||||||
double_precision_rope=True,
|
double_precision_rope=True,
|
||||||
|
apply_gated_attention=connector_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=self.operations,
|
operations=self.operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
def preprocess_text_embeds(self, context):
|
def preprocess_text_embeds(self, context, unprocessed=False):
|
||||||
if context.shape[-1] == self.caption_channels * 2:
|
# LTXv2 fully processed context has dimension of self.caption_channels * 2
|
||||||
return context
|
# LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim
|
||||||
out_vid = self.video_embeddings_connector(context)[0]
|
if not unprocessed:
|
||||||
out_audio = self.audio_embeddings_connector(context)[0]
|
if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2):
|
||||||
|
return context
|
||||||
|
if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim:
|
||||||
|
context_vid = context[:, :, :self.cross_attention_dim]
|
||||||
|
context_audio = context[:, :, self.cross_attention_dim:]
|
||||||
|
else:
|
||||||
|
context_vid = context
|
||||||
|
context_audio = context
|
||||||
|
if self.caption_proj_before_connector:
|
||||||
|
context_vid = self.caption_projection(context_vid)
|
||||||
|
context_audio = self.audio_caption_projection(context_audio)
|
||||||
|
out_vid = self.video_embeddings_connector(context_vid)[0]
|
||||||
|
out_audio = self.audio_embeddings_connector(context_audio)[0]
|
||||||
return torch.concat((out_vid, out_audio), dim=-1)
|
return torch.concat((out_vid, out_audio), dim=-1)
|
||||||
|
|
||||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||||
@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel):
|
|||||||
ad_head=self.audio_attention_head_dim,
|
ad_head=self.audio_attention_head_dim,
|
||||||
v_context_dim=self.cross_attention_dim,
|
v_context_dim=self.cross_attention_dim,
|
||||||
a_context_dim=self.audio_cross_attention_dim,
|
a_context_dim=self.audio_cross_attention_dim,
|
||||||
|
apply_gated_attention=self.apply_gated_attention,
|
||||||
|
cross_attention_adaln=self.cross_attention_adaln,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=self.operations,
|
operations=self.operations,
|
||||||
@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel):
|
|||||||
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
|
||||||
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
|
||||||
|
|
||||||
|
v_prompt_timestep = compute_prompt_timestep(
|
||||||
|
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare audio timestep
|
# Prepare audio timestep
|
||||||
a_timestep = kwargs.get("a_timestep")
|
a_timestep = kwargs.get("a_timestep")
|
||||||
if a_timestep is not None:
|
if a_timestep is not None:
|
||||||
@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel):
|
|||||||
|
|
||||||
# Cross-attention timesteps - compress these too
|
# Cross-attention timesteps - compress these too
|
||||||
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
|
||||||
a_timestep_flat,
|
timestep.max().expand_as(a_timestep_flat),
|
||||||
{"resolution": None, "aspect_ratio": None},
|
{"resolution": None, "aspect_ratio": None},
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
hidden_dtype=hidden_dtype,
|
hidden_dtype=hidden_dtype,
|
||||||
)
|
)
|
||||||
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
|
||||||
timestep_flat,
|
a_timestep.max().expand_as(timestep_flat),
|
||||||
{"resolution": None, "aspect_ratio": None},
|
{"resolution": None, "aspect_ratio": None},
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
hidden_dtype=hidden_dtype,
|
hidden_dtype=hidden_dtype,
|
||||||
)
|
)
|
||||||
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
|
||||||
timestep_flat * av_ca_factor,
|
a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
|
||||||
{"resolution": None, "aspect_ratio": None},
|
{"resolution": None, "aspect_ratio": None},
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
hidden_dtype=hidden_dtype,
|
hidden_dtype=hidden_dtype,
|
||||||
)
|
)
|
||||||
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
|
||||||
a_timestep_flat * av_ca_factor,
|
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
|
||||||
{"resolution": None, "aspect_ratio": None},
|
{"resolution": None, "aspect_ratio": None},
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
hidden_dtype=hidden_dtype,
|
hidden_dtype=hidden_dtype,
|
||||||
@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel):
|
|||||||
# Audio timesteps
|
# Audio timesteps
|
||||||
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
|
||||||
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
|
||||||
|
|
||||||
|
a_prompt_timestep = compute_prompt_timestep(
|
||||||
|
self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
a_timestep = timestep_scaled
|
a_timestep = timestep_scaled
|
||||||
a_embedded_timestep = kwargs.get("embedded_timestep")
|
a_embedded_timestep = kwargs.get("embedded_timestep")
|
||||||
cross_av_timestep_ss = []
|
cross_av_timestep_ss = []
|
||||||
|
a_prompt_timestep = None
|
||||||
|
|
||||||
return [v_timestep, a_timestep, cross_av_timestep_ss], [
|
return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [
|
||||||
v_embedded_timestep,
|
v_embedded_timestep,
|
||||||
a_embedded_timestep,
|
a_embedded_timestep,
|
||||||
]
|
], None
|
||||||
|
|
||||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||||
vx = x[0]
|
vx = x[0]
|
||||||
ax = x[1]
|
ax = x[1]
|
||||||
|
video_dim = vx.shape[-1]
|
||||||
|
audio_dim = ax.shape[-1]
|
||||||
|
|
||||||
|
v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim
|
||||||
|
a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim
|
||||||
|
|
||||||
v_context, a_context = torch.split(
|
v_context, a_context = torch.split(
|
||||||
context, int(context.shape[-1] / 2), len(context.shape) - 1
|
context, [v_context_dim, a_context_dim], len(context.shape) - 1
|
||||||
)
|
)
|
||||||
|
|
||||||
v_context, attention_mask = super()._prepare_context(
|
v_context, attention_mask = super()._prepare_context(
|
||||||
v_context, batch_size, vx, attention_mask
|
v_context, batch_size, vx, attention_mask
|
||||||
)
|
)
|
||||||
if self.audio_caption_projection is not None:
|
if self.caption_proj_before_connector is False:
|
||||||
a_context = self.audio_caption_projection(a_context)
|
a_context = self.audio_caption_projection(a_context)
|
||||||
a_context = a_context.view(batch_size, -1, ax.shape[-1])
|
a_context = a_context.view(batch_size, -1, audio_dim)
|
||||||
|
|
||||||
return [v_context, a_context], attention_mask
|
return [v_context, a_context], attention_mask
|
||||||
|
|
||||||
@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel):
|
|||||||
av_ca_v2a_gate_noise_timestep,
|
av_ca_v2a_gate_noise_timestep,
|
||||||
) = timestep[2]
|
) = timestep[2]
|
||||||
|
|
||||||
|
v_prompt_timestep = timestep[3]
|
||||||
|
a_prompt_timestep = timestep[4]
|
||||||
|
|
||||||
"""Process transformer blocks for LTXAV."""
|
"""Process transformer blocks for LTXAV."""
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel):
|
|||||||
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
a_cross_gate_timestep=args["a_cross_gate_timestep"],
|
||||||
transformer_options=args["transformer_options"],
|
transformer_options=args["transformer_options"],
|
||||||
self_attention_mask=args.get("self_attention_mask"),
|
self_attention_mask=args.get("self_attention_mask"),
|
||||||
|
v_prompt_timestep=args.get("v_prompt_timestep"),
|
||||||
|
a_prompt_timestep=args.get("a_prompt_timestep"),
|
||||||
)
|
)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel):
|
|||||||
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
|
||||||
"transformer_options": transformer_options,
|
"transformer_options": transformer_options,
|
||||||
"self_attention_mask": self_attention_mask,
|
"self_attention_mask": self_attention_mask,
|
||||||
|
"v_prompt_timestep": v_prompt_timestep,
|
||||||
|
"a_prompt_timestep": a_prompt_timestep,
|
||||||
},
|
},
|
||||||
{"original_block": block_wrap},
|
{"original_block": block_wrap},
|
||||||
)
|
)
|
||||||
@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel):
|
|||||||
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
self_attention_mask=self_attention_mask,
|
self_attention_mask=self_attention_mask,
|
||||||
|
v_prompt_timestep=v_prompt_timestep,
|
||||||
|
a_prompt_timestep=a_prompt_timestep,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [vx, ax]
|
return [vx, ax]
|
||||||
|
|||||||
@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module):
|
|||||||
d_head,
|
d_head,
|
||||||
context_dim=None,
|
context_dim=None,
|
||||||
attn_precision=None,
|
attn_precision=None,
|
||||||
|
apply_gated_attention=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module):
|
|||||||
heads=n_heads,
|
heads=n_heads,
|
||||||
dim_head=d_head,
|
dim_head=d_head,
|
||||||
context_dim=None,
|
context_dim=None,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -121,6 +123,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
positional_embedding_max_pos=[4096],
|
positional_embedding_max_pos=[4096],
|
||||||
causal_temporal_positioning=False,
|
causal_temporal_positioning=False,
|
||||||
num_learnable_registers: Optional[int] = 128,
|
num_learnable_registers: Optional[int] = 128,
|
||||||
|
apply_gated_attention=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module):
|
|||||||
num_attention_heads,
|
num_attention_heads,
|
||||||
attention_head_dim,
|
attention_head_dim,
|
||||||
context_dim=cross_attention_dim,
|
context_dim=cross_attention_dim,
|
||||||
|
apply_gated_attention=apply_gated_attention,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
|
|||||||
@ -275,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module):
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class NormSingleLinearTextProjection(nn.Module):
|
||||||
|
"""Text projection for 20B models - single linear with RMSNorm (no activation)."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, in_features, hidden_size, dtype=None, device=None, operations=None
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
if operations is None:
|
||||||
|
operations = comfy.ops.disable_weight_init
|
||||||
|
self.in_norm = operations.RMSNorm(
|
||||||
|
in_features, eps=1e-6, elementwise_affine=False
|
||||||
|
)
|
||||||
|
self.linear_1 = operations.Linear(
|
||||||
|
in_features, hidden_size, bias=True, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.in_features = in_features
|
||||||
|
|
||||||
|
def forward(self, caption):
|
||||||
|
caption = self.in_norm(caption)
|
||||||
|
caption = caption * (self.hidden_size / self.in_features) ** 0.5
|
||||||
|
return self.linear_1(caption)
|
||||||
|
|
||||||
|
|
||||||
class GELU_approx(nn.Module):
|
class GELU_approx(nn.Module):
|
||||||
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
|
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -343,6 +367,7 @@ class CrossAttention(nn.Module):
|
|||||||
dim_head=64,
|
dim_head=64,
|
||||||
dropout=0.0,
|
dropout=0.0,
|
||||||
attn_precision=None,
|
attn_precision=None,
|
||||||
|
apply_gated_attention=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -362,6 +387,12 @@ class CrossAttention(nn.Module):
|
|||||||
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||||
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
# Optional per-head gating
|
||||||
|
if apply_gated_attention:
|
||||||
|
self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device)
|
||||||
|
else:
|
||||||
|
self.to_gate_logits = None
|
||||||
|
|
||||||
self.to_out = nn.Sequential(
|
self.to_out = nn.Sequential(
|
||||||
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
|
||||||
)
|
)
|
||||||
@ -383,16 +414,30 @@ class CrossAttention(nn.Module):
|
|||||||
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||||
else:
|
else:
|
||||||
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
|
||||||
|
|
||||||
|
# Apply per-head gating if enabled
|
||||||
|
if self.to_gate_logits is not None:
|
||||||
|
gate_logits = self.to_gate_logits(x) # (B, T, H)
|
||||||
|
b, t, _ = out.shape
|
||||||
|
out = out.view(b, t, self.heads, self.dim_head)
|
||||||
|
gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity
|
||||||
|
out = out * gates.unsqueeze(-1)
|
||||||
|
out = out.view(b, t, self.heads * self.dim_head)
|
||||||
|
|
||||||
return self.to_out(out)
|
return self.to_out(out)
|
||||||
|
|
||||||
|
# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate)
|
||||||
|
ADALN_BASE_PARAMS_COUNT = 6
|
||||||
|
ADALN_CROSS_ATTN_PARAMS_COUNT = 9
|
||||||
|
|
||||||
class BasicTransformerBlock(nn.Module):
|
class BasicTransformerBlock(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
|
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.attn_precision = attn_precision
|
self.attn_precision = attn_precision
|
||||||
|
self.cross_attention_adaln = cross_attention_adaln
|
||||||
self.attn1 = CrossAttention(
|
self.attn1 = CrossAttention(
|
||||||
query_dim=dim,
|
query_dim=dim,
|
||||||
heads=n_heads,
|
heads=n_heads,
|
||||||
@ -416,18 +461,25 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
operations=operations,
|
operations=operations,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
|
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||||
|
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
|
if cross_attention_adaln:
|
||||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
|
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
attn1_input = comfy.ldm.common_dit.rms_norm(x)
|
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None):
|
||||||
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2)
|
||||||
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
|
|
||||||
x.addcmul_(attn1_input, gate_msa)
|
|
||||||
del attn1_input
|
|
||||||
|
|
||||||
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa
|
||||||
|
|
||||||
|
if self.cross_attention_adaln:
|
||||||
|
shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2)
|
||||||
|
x += apply_cross_attention_adaln(
|
||||||
|
x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca,
|
||||||
|
self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
|
||||||
|
|
||||||
y = comfy.ldm.common_dit.rms_norm(x)
|
y = comfy.ldm.common_dit.rms_norm(x)
|
||||||
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
|
||||||
@ -435,6 +487,47 @@ class BasicTransformerBlock(nn.Module):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
    """Build the prompt-side timestep embedding used by cross-attention ADaLN.

    The scaled timestep is collapsed to a single value per row by taking the
    maximum over dim 1 (noted upstream as matching JAX ``max_per_segment``);
    inputs with at most one dimension are only flattened. The first output of
    ``adaln_module`` is then viewed with a singleton axis at position 1 so it
    can broadcast over text tokens.

    Args:
        adaln_module: Callable invoked as ``adaln_module(ts, cond,
            batch_size=..., hidden_dtype=...)`` and returning a pair; only the
            first item is used. ``None`` disables the computation.
        timestep_scaled: Tensor of already-scaled timesteps.
        batch_size: Leading size of the returned view.
        hidden_dtype: Forwarded unchanged to ``adaln_module``.

    Returns:
        ``None`` when ``adaln_module`` is ``None``; otherwise the first output
        of ``adaln_module`` viewed as ``(batch_size, 1, last_dim)``.
    """
    if adaln_module is None:
        return None

    if timestep_scaled.dim() > 1:
        # One global timestep per row: keep only the maximum over dim 1.
        reduced = timestep_scaled.max(dim=1, keepdim=True).values
    else:
        reduced = timestep_scaled
    flat_timestep = reduced.flatten()

    # The module is called without resolution / aspect-ratio conditioning.
    no_extra_conditions = {"resolution": None, "aspect_ratio": None}
    embedding, _unused = adaln_module(
        flat_timestep,
        no_extra_conditions,
        batch_size=batch_size,
        hidden_dtype=hidden_dtype,
    )
    feature_dim = embedding.shape[-1]
    return embedding.view(batch_size, 1, feature_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_cross_attention_adaln(
    x, context, attn, q_shift, q_scale, q_gate,
    prompt_scale_shift_table, prompt_timestep,
    attention_mask=None, transformer_options={},
):
    """Run cross-attention with ADaLN modulation on both the query and the context.

    The query-side parameters (``q_shift``, ``q_scale``, ``q_gate``) are
    extracted by the caller before this function is invoked, so that both
    regular tensors and CompressedTimestep are supported. The context-side
    shift and scale are derived here from ``prompt_scale_shift_table`` plus
    ``prompt_timestep``.

    Args:
        x: Query-side hidden states; its device and dtype are used for the
            table, and ``x.shape[0]`` is taken as the batch size.
        context: Context hidden states passed to ``attn`` after modulation.
        attn: Attention callable invoked with ``context``, ``mask`` and
            ``transformer_options`` keyword arguments.
        q_shift, q_scale, q_gate: Modulation terms for the query input and
            the attention output.
        prompt_scale_shift_table: Table whose rows are read as (shift, scale)
            for the context -- assumes two rows, TODO confirm against caller.
        prompt_timestep: Tensor reshaped to ``(batch, shape[1], 2, -1)`` and
            added to the table.
        attention_mask: Forwarded to ``attn`` as ``mask``.
        transformer_options: Forwarded to ``attn`` unchanged.

    Returns:
        The output of ``attn`` multiplied by ``q_gate``.
    """
    table = prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
    per_prompt = prompt_timestep.reshape(x.shape[0], prompt_timestep.shape[1], 2, -1)
    # Index 0 along dim 2 is the shift, index 1 the scale.
    kv_shift, kv_scale = (table + per_prompt).unbind(dim=2)

    modulated_query = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
    modulated_context = context * (1 + kv_scale) + kv_shift

    attended = attn(
        modulated_query,
        context=modulated_context,
        mask=attention_mask,
        transformer_options=transformer_options,
    )
    return attended * q_gate
|
||||||
|
|
||||||
def get_fractional_positions(indices_grid, max_pos):
|
def get_fractional_positions(indices_grid, max_pos):
|
||||||
n_pos_dims = indices_grid.shape[1]
|
n_pos_dims = indices_grid.shape[1]
|
||||||
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
|
||||||
@ -556,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
vae_scale_factors: tuple = (8, 32, 32),
|
vae_scale_factors: tuple = (8, 32, 32),
|
||||||
use_middle_indices_grid=False,
|
use_middle_indices_grid=False,
|
||||||
timestep_scale_multiplier = 1000.0,
|
timestep_scale_multiplier = 1000.0,
|
||||||
|
caption_proj_before_connector=False,
|
||||||
|
cross_attention_adaln=False,
|
||||||
|
caption_projection_first_linear=True,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -582,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
self.causal_temporal_positioning = causal_temporal_positioning
|
self.causal_temporal_positioning = causal_temporal_positioning
|
||||||
self.operations = operations
|
self.operations = operations
|
||||||
self.timestep_scale_multiplier = timestep_scale_multiplier
|
self.timestep_scale_multiplier = timestep_scale_multiplier
|
||||||
|
self.caption_proj_before_connector = caption_proj_before_connector
|
||||||
|
self.cross_attention_adaln = cross_attention_adaln
|
||||||
|
self.caption_projection_first_linear = caption_projection_first_linear
|
||||||
|
|
||||||
# Common dimensions
|
# Common dimensions
|
||||||
self.inner_dim = num_attention_heads * attention_head_dim
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
@ -609,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
|
||||||
self.adaln_single = AdaLayerNormSingle(
|
self.adaln_single = AdaLayerNormSingle(
|
||||||
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||||
)
|
)
|
||||||
|
|
||||||
self.caption_projection = PixArtAlphaTextProjection(
|
if self.cross_attention_adaln:
|
||||||
in_features=self.caption_channels,
|
self.prompt_adaln_single = AdaLayerNormSingle(
|
||||||
hidden_size=self.inner_dim,
|
self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
|
||||||
dtype=dtype,
|
)
|
||||||
device=device,
|
else:
|
||||||
operations=self.operations,
|
self.prompt_adaln_single = None
|
||||||
)
|
|
||||||
|
if self.caption_proj_before_connector:
|
||||||
|
if self.caption_projection_first_linear:
|
||||||
|
self.caption_projection = NormSingleLinearTextProjection(
|
||||||
|
in_features=self.caption_channels,
|
||||||
|
hidden_size=self.inner_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.caption_projection = lambda a: a
|
||||||
|
else:
|
||||||
|
self.caption_projection = PixArtAlphaTextProjection(
|
||||||
|
in_features=self.caption_channels,
|
||||||
|
hidden_size=self.inner_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=self.operations,
|
||||||
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _init_model_components(self, device, dtype, **kwargs):
|
def _init_model_components(self, device, dtype, **kwargs):
|
||||||
@ -665,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
if grid_mask is not None:
|
if grid_mask is not None:
|
||||||
timestep = timestep[:, grid_mask]
|
timestep = timestep[:, grid_mask]
|
||||||
|
|
||||||
timestep = timestep * self.timestep_scale_multiplier
|
timestep_scaled = timestep * self.timestep_scale_multiplier
|
||||||
timestep, embedded_timestep = self.adaln_single(
|
timestep, embedded_timestep = self.adaln_single(
|
||||||
timestep.flatten(),
|
timestep_scaled.flatten(),
|
||||||
{"resolution": None, "aspect_ratio": None},
|
{"resolution": None, "aspect_ratio": None},
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
hidden_dtype=hidden_dtype,
|
hidden_dtype=hidden_dtype,
|
||||||
@ -677,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
|
||||||
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
|
||||||
|
|
||||||
return timestep, embedded_timestep
|
prompt_timestep = compute_prompt_timestep(
|
||||||
|
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return timestep, embedded_timestep, prompt_timestep
|
||||||
|
|
||||||
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
||||||
"""Prepare context for transformer blocks."""
|
"""Prepare context for transformer blocks."""
|
||||||
if self.caption_projection is not None:
|
if self.caption_proj_before_connector is False:
|
||||||
context = self.caption_projection(context)
|
context = self.caption_projection(context)
|
||||||
context = context.view(batch_size, -1, x.shape[-1])
|
|
||||||
|
|
||||||
|
context = context.view(batch_size, -1, x.shape[-1])
|
||||||
return context, attention_mask
|
return context, attention_mask
|
||||||
|
|
||||||
def _precompute_freqs_cis(
|
def _precompute_freqs_cis(
|
||||||
@ -792,7 +915,8 @@ class LTXBaseModel(torch.nn.Module, ABC):
|
|||||||
merged_args.update(additional_args)
|
merged_args.update(additional_args)
|
||||||
|
|
||||||
# Prepare timestep and context
|
# Prepare timestep and context
|
||||||
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
|
||||||
|
merged_args["prompt_timestep"] = prompt_timestep
|
||||||
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
|
||||||
|
|
||||||
# Prepare attention mask and positional embeddings
|
# Prepare attention mask and positional embeddings
|
||||||
@ -833,7 +957,9 @@ class LTXVModel(LTXBaseModel):
|
|||||||
causal_temporal_positioning=False,
|
causal_temporal_positioning=False,
|
||||||
vae_scale_factors=(8, 32, 32),
|
vae_scale_factors=(8, 32, 32),
|
||||||
use_middle_indices_grid=False,
|
use_middle_indices_grid=False,
|
||||||
timestep_scale_multiplier = 1000.0,
|
timestep_scale_multiplier=1000.0,
|
||||||
|
caption_proj_before_connector=False,
|
||||||
|
cross_attention_adaln=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -852,6 +978,8 @@ class LTXVModel(LTXBaseModel):
|
|||||||
vae_scale_factors=vae_scale_factors,
|
vae_scale_factors=vae_scale_factors,
|
||||||
use_middle_indices_grid=use_middle_indices_grid,
|
use_middle_indices_grid=use_middle_indices_grid,
|
||||||
timestep_scale_multiplier=timestep_scale_multiplier,
|
timestep_scale_multiplier=timestep_scale_multiplier,
|
||||||
|
caption_proj_before_connector=caption_proj_before_connector,
|
||||||
|
cross_attention_adaln=cross_attention_adaln,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations,
|
operations=operations,
|
||||||
@ -860,7 +988,6 @@ class LTXVModel(LTXBaseModel):
|
|||||||
|
|
||||||
def _init_model_components(self, device, dtype, **kwargs):
|
def _init_model_components(self, device, dtype, **kwargs):
|
||||||
"""Initialize LTXV-specific components."""
|
"""Initialize LTXV-specific components."""
|
||||||
# No additional components needed for LTXV beyond base class
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
|
||||||
@ -872,6 +999,7 @@ class LTXVModel(LTXBaseModel):
|
|||||||
self.num_attention_heads,
|
self.num_attention_heads,
|
||||||
self.attention_head_dim,
|
self.attention_head_dim,
|
||||||
context_dim=self.cross_attention_dim,
|
context_dim=self.cross_attention_dim,
|
||||||
|
cross_attention_adaln=self.cross_attention_adaln,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=self.operations,
|
operations=self.operations,
|
||||||
@ -1149,16 +1277,17 @@ class LTXVModel(LTXBaseModel):
|
|||||||
"""Process transformer blocks for LTXV."""
|
"""Process transformer blocks for LTXV."""
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
prompt_timestep = kwargs.get("prompt_timestep", None)
|
||||||
|
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
|
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
out = {}
|
out = {}
|
||||||
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
|
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
|
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap})
|
||||||
x = out["img"]
|
x = out["img"]
|
||||||
else:
|
else:
|
||||||
x = block(
|
x = block(
|
||||||
@ -1169,6 +1298,7 @@ class LTXVModel(LTXBaseModel):
|
|||||||
pe=pe,
|
pe=pe,
|
||||||
transformer_options=transformer_options,
|
transformer_options=transformer_options,
|
||||||
self_attention_mask=self_attention_mask,
|
self_attention_mask=self_attention_mask,
|
||||||
|
prompt_timestep=prompt_timestep,
|
||||||
)
|
)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|||||||
@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
|
|||||||
CausalityAxis,
|
CausalityAxis,
|
||||||
CausalAudioAutoencoder,
|
CausalAudioAutoencoder,
|
||||||
)
|
)
|
||||||
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
|
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE
|
||||||
|
|
||||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||||
|
|
||||||
@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module):
|
|||||||
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
|
||||||
|
|
||||||
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
|
||||||
self.vocoder = Vocoder(config=component_config.vocoder)
|
if "bwe" in component_config.vocoder:
|
||||||
|
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
|
||||||
|
else:
|
||||||
|
self.vocoder = Vocoder(config=component_config.vocoder)
|
||||||
|
|
||||||
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
self.autoencoder.load_state_dict(vae_sd, strict=False)
|
||||||
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
self.vocoder.load_state_dict(vocoder_sd, strict=False)
|
||||||
|
|||||||
@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = self._guess_config()
|
config = self.get_default_config()
|
||||||
|
|
||||||
# Extract encoder and decoder configs from the new format
|
|
||||||
model_config = config.get("model", {}).get("params", {})
|
model_config = config.get("model", {}).get("params", {})
|
||||||
variables_config = config.get("variables", {})
|
|
||||||
|
|
||||||
self.sampling_rate = variables_config.get(
|
self.sampling_rate = model_config.get(
|
||||||
"sampling_rate",
|
"sampling_rate", config.get("sampling_rate", 16000)
|
||||||
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
|
|
||||||
)
|
)
|
||||||
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
|
||||||
decoder_config = model_config.get("decoder", encoder_config)
|
decoder_config = model_config.get("decoder", encoder_config)
|
||||||
|
|
||||||
# Load mel spectrogram parameters
|
# Load mel spectrogram parameters
|
||||||
self.mel_bins = encoder_config.get("mel_bins", 64)
|
self.mel_bins = encoder_config.get("mel_bins", 64)
|
||||||
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
|
||||||
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
|
||||||
|
|
||||||
# Store causality configuration at VAE level (not just in encoder internals)
|
# Store causality configuration at VAE level (not just in encoder internals)
|
||||||
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
|
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
|
||||||
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
|
||||||
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
|
||||||
|
|
||||||
@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module):
|
|||||||
|
|
||||||
self.per_channel_statistics = processor()
|
self.per_channel_statistics = processor()
|
||||||
|
|
||||||
def _guess_config(self):
|
def get_default_config(self):
|
||||||
encoder_config = {
|
ddconfig = {
|
||||||
# Required parameters - based on ltx-video-av-1679000 model metadata
|
|
||||||
"ch": 128,
|
|
||||||
"out_ch": 8,
|
|
||||||
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
|
|
||||||
"num_res_blocks": 2,
|
|
||||||
"attn_resolutions": [], # Based on metadata: empty list, no attention
|
|
||||||
"dropout": 0.0,
|
|
||||||
"resamp_with_conv": True,
|
|
||||||
"in_channels": 2, # stereo
|
|
||||||
"resolution": 256,
|
|
||||||
"z_channels": 8,
|
|
||||||
"double_z": True,
|
"double_z": True,
|
||||||
"attn_type": "vanilla",
|
"mel_bins": 64,
|
||||||
"mid_block_add_attention": False, # Based on metadata: false
|
"z_channels": 8,
|
||||||
|
"resolution": 256,
|
||||||
|
"downsample_time": False,
|
||||||
|
"in_channels": 2,
|
||||||
|
"out_ch": 2,
|
||||||
|
"ch": 128,
|
||||||
|
"ch_mult": [1, 2, 4],
|
||||||
|
"num_res_blocks": 2,
|
||||||
|
"attn_resolutions": [],
|
||||||
|
"dropout": 0.0,
|
||||||
|
"mid_block_add_attention": False,
|
||||||
"norm_type": "pixel",
|
"norm_type": "pixel",
|
||||||
"causality_axis": "height", # Based on metadata
|
"causality_axis": "height",
|
||||||
"mel_bins": 64, # Based on metadata: mel_bins = 64
|
|
||||||
}
|
|
||||||
|
|
||||||
decoder_config = {
|
|
||||||
# Inherits encoder config, can override specific params
|
|
||||||
**encoder_config,
|
|
||||||
"out_ch": 2, # Stereo audio output (2 channels)
|
|
||||||
"give_pre_end": False,
|
|
||||||
"tanh_out": False,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"_class_name": "CausalAudioAutoencoder",
|
|
||||||
"sampling_rate": 16000,
|
|
||||||
"model": {
|
"model": {
|
||||||
"params": {
|
"params": {
|
||||||
"encoder": encoder_config,
|
"ddconfig": ddconfig,
|
||||||
"decoder": decoder_config,
|
"sampling_rate": 16000,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
"preprocessing": {
|
||||||
|
"stft": {
|
||||||
|
"filter_length": 1024,
|
||||||
|
"hop_length": 160,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|||||||
@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
|||||||
|
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
|
def in_meta_context():
    """Return True when a newly created tensor is placed on the ``meta`` device.

    A zero-element tensor is created with no explicit device and its device is
    compared with ``torch.device("meta")``.
    """
    probe = torch.empty(0)
    return probe.device == torch.device("meta")
|
||||||
|
|
||||||
def mark_conv3d_ended(module):
|
def mark_conv3d_ended(module):
|
||||||
tid = threading.get_ident()
|
tid = threading.get_ident()
|
||||||
for _, m in module.named_modules():
|
for _, m in module.named_modules():
|
||||||
@ -350,6 +353,10 @@ class Decoder(nn.Module):
|
|||||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||||
if block_name == "compress_all":
|
if block_name == "compress_all":
|
||||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||||
|
if block_name == "compress_space":
|
||||||
|
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||||
|
if block_name == "compress_time":
|
||||||
|
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||||
|
|
||||||
self.conv_in = make_conv_nd(
|
self.conv_in = make_conv_nd(
|
||||||
dims,
|
dims,
|
||||||
@ -395,17 +402,21 @@ class Decoder(nn.Module):
|
|||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
elif block_name == "compress_time":
|
elif block_name == "compress_time":
|
||||||
|
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||||
block = DepthToSpaceUpsample(
|
block = DepthToSpaceUpsample(
|
||||||
dims=dims,
|
dims=dims,
|
||||||
in_channels=input_channel,
|
in_channels=input_channel,
|
||||||
stride=(2, 1, 1),
|
stride=(2, 1, 1),
|
||||||
|
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
elif block_name == "compress_space":
|
elif block_name == "compress_space":
|
||||||
|
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||||
block = DepthToSpaceUpsample(
|
block = DepthToSpaceUpsample(
|
||||||
dims=dims,
|
dims=dims,
|
||||||
in_channels=input_channel,
|
in_channels=input_channel,
|
||||||
stride=(1, 2, 2),
|
stride=(1, 2, 2),
|
||||||
|
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
elif block_name == "compress_all":
|
elif block_name == "compress_all":
|
||||||
@ -455,6 +466,15 @@ class Decoder(nn.Module):
|
|||||||
output_channel * 2, 0, operations=ops,
|
output_channel * 2, 0, operations=ops,
|
||||||
)
|
)
|
||||||
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
||||||
|
else:
|
||||||
|
self.register_buffer(
|
||||||
|
"last_scale_shift_table",
|
||||||
|
torch.tensor(
|
||||||
|
[0.0, 0.0],
|
||||||
|
device="cpu" if in_meta_context() else None
|
||||||
|
).unsqueeze(1).expand(2, output_channel),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||||
@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module):
|
|||||||
self.scale_shift_table = nn.Parameter(
|
self.scale_shift_table = nn.Parameter(
|
||||||
torch.randn(4, in_channels) / in_channels**0.5
|
torch.randn(4, in_channels) / in_channels**0.5
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
self.register_buffer(
|
||||||
|
"scale_shift_table",
|
||||||
|
torch.tensor(
|
||||||
|
[0.0, 0.0, 0.0, 0.0],
|
||||||
|
device="cpu" if in_meta_context() else None
|
||||||
|
).unsqueeze(1).expand(4, in_channels),
|
||||||
|
persistent=False,
|
||||||
|
)
|
||||||
|
|
||||||
self.temporal_cache_state={}
|
self.temporal_cache_state={}
|
||||||
|
|
||||||
@ -1012,9 +1041,6 @@ class processor(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.register_buffer("std-of-means", torch.empty(128))
|
self.register_buffer("std-of-means", torch.empty(128))
|
||||||
self.register_buffer("mean-of-means", torch.empty(128))
|
self.register_buffer("mean-of-means", torch.empty(128))
|
||||||
self.register_buffer("mean-of-stds", torch.empty(128))
|
|
||||||
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
|
|
||||||
self.register_buffer("channel", torch.empty(128))
|
|
||||||
|
|
||||||
def un_normalize(self, x):
|
def un_normalize(self, x):
|
||||||
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
|
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||||
@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
config = self.guess_config(version)
|
config = self.get_default_config(version)
|
||||||
|
|
||||||
|
self.config = config
|
||||||
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||||
|
self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
|
||||||
|
self.decode_timestep = config.get("decode_timestep", 0.05)
|
||||||
double_z = config.get("double_z", True)
|
double_z = config.get("double_z", True)
|
||||||
latent_log_var = config.get(
|
latent_log_var = config.get(
|
||||||
"latent_log_var", "per_channel" if double_z else "none"
|
"latent_log_var", "per_channel" if double_z else "none"
|
||||||
@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module):
|
|||||||
latent_log_var=latent_log_var,
|
latent_log_var=latent_log_var,
|
||||||
norm_layer=config.get("norm_layer", "group_norm"),
|
norm_layer=config.get("norm_layer", "group_norm"),
|
||||||
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
|
||||||
|
base_channels=config.get("encoder_base_channels", 128),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.decoder = Decoder(
|
self.decoder = Decoder(
|
||||||
@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module):
|
|||||||
in_channels=config["latent_channels"],
|
in_channels=config["latent_channels"],
|
||||||
out_channels=config.get("out_channels", 3),
|
out_channels=config.get("out_channels", 3),
|
||||||
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||||
|
base_channels=config.get("decoder_base_channels", 128),
|
||||||
patch_size=config.get("patch_size", 1),
|
patch_size=config.get("patch_size", 1),
|
||||||
norm_layer=config.get("norm_layer", "group_norm"),
|
norm_layer=config.get("norm_layer", "group_norm"),
|
||||||
causal=config.get("causal_decoder", False),
|
causal=config.get("causal_decoder", False),
|
||||||
@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module):
|
|||||||
|
|
||||||
self.per_channel_statistics = processor()
|
self.per_channel_statistics = processor()
|
||||||
|
|
||||||
def guess_config(self, version):
|
def get_default_config(self, version):
|
||||||
if version == 0:
|
if version == 0:
|
||||||
config = {
|
config = {
|
||||||
"_class_name": "CausalVideoAutoencoder",
|
"_class_name": "CausalVideoAutoencoder",
|
||||||
@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module):
|
|||||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||||
return self.per_channel_statistics.normalize(means)
|
return self.per_channel_statistics.normalize(means)
|
||||||
|
|
||||||
def decode(self, x, timestep=0.05, noise_scale=0.025):
|
def decode(self, x):
|
||||||
if self.timestep_conditioning: #TODO: seed
|
if self.timestep_conditioning: #TODO: seed
|
||||||
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
|
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
|
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import torch.nn.functional as F
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import math
|
||||||
|
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
|
|
||||||
@ -12,6 +13,307 @@ def get_padding(kernel_size, dilation=1):
|
|||||||
return int((kernel_size * dilation - dilation) / 2)
|
return int((kernel_size * dilation - dilation) / 2)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
|
||||||
|
# Adopted from https://github.com/NVIDIA/BigVGAN
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _sinc(x: torch.Tensor):
    """Return the normalized sinc of ``x``: ``sin(pi * x) / (pi * x)``, and 1 where ``x == 0``.

    Delegates to ``torch.sinc``, which implements exactly this definition,
    instead of evaluating the quotient everywhere (0 / 0 at the origin) and
    masking the result afterwards with ``torch.where``.

    Args:
        x: Tensor of sample positions.

    Returns:
        Tensor with the same shape as ``x``.
    """
    return torch.sinc(x)
|
||||||
|
|
||||||
|
|
||||||
|
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
    """Build a Kaiser-windowed sinc low-pass kernel.

    Args:
        cutoff: Cutoff frequency. ``0`` produces an all-zero kernel.
        half_width: Transition half-width; it determines the Kaiser ``beta``.
        kernel_size: Number of taps.

    Returns:
        Floating-point tensor of shape ``(1, 1, kernel_size)``. For a non-zero
        cutoff the taps are normalized so that they sum to 1.
    """
    half_size = kernel_size // 2

    # Map the estimated attenuation to the Kaiser window's beta parameter.
    delta_f = 4 * half_width
    attenuation = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    if attenuation > 50.0:
        beta = 0.1102 * (attenuation - 8.7)
    elif attenuation >= 21.0:
        beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21.0)
    else:
        beta = 0.0
    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # Sample positions centred on zero; even kernels sit on half-integers.
    if kernel_size % 2 == 0:
        time = torch.arange(-half_size, half_size) + 0.5
    else:
        time = torch.arange(kernel_size) - half_size

    if cutoff == 0:
        # Take dtype from the window: for odd kernels `time` is an integer
        # tensor, so zeros_like(time) would yield an integer kernel.
        taps = torch.zeros_like(window)
    else:
        taps = 2 * cutoff * window * torch.sinc(2 * cutoff * time)
        # NOTE(review): indentation was lost in the extracted source; the
        # normalization is assumed to belong to this branch only, since
        # normalizing the all-zero kernel would divide by zero.
        taps /= taps.sum()

    return taps.view(1, 1, kernel_size)
|
||||||
|
|
||||||
|
|
||||||
|
class LowPassFilter1d(nn.Module):
    """Depthwise 1-D low-pass filter using a Kaiser-windowed sinc kernel.

    The kernel comes from ``kaiser_sinc_filter1d`` and is stored as the
    ``filter`` buffer. ``forward`` applies it to each channel independently.
    """

    def __init__(
        self,
        cutoff=0.5,
        half_width=0.6,
        stride=1,
        padding=True,
        padding_mode="replicate",
        kernel_size=12,
    ):
        """Validate ``cutoff``, derive the asymmetric padding and build the kernel.

        Raises:
            ValueError: If ``cutoff`` is negative or greater than 0.5.
        """
        super().__init__()
        if cutoff < 0.0:
            raise ValueError("Minimum cutoff must be larger than zero.")
        if cutoff > 0.5:
            raise ValueError("A cutoff above 0.5 does not make sense.")

        is_even = kernel_size % 2 == 0
        half = kernel_size // 2

        self.kernel_size = kernel_size
        self.even = is_even
        # Even kernels pad one sample fewer on the left than on the right.
        self.pad_left = half - int(is_even)
        self.pad_right = half
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.register_buffer(
            "filter", kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
        )

    def forward(self, x):
        """Filter ``x`` (unpacked as a 3-D tensor) channel by channel."""
        _batch, channels, _length = x.shape
        if self.padding:
            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
        # One shared kernel expanded to every channel; groups=channels keeps
        # the convolution depthwise.
        kernel = self.filter.expand(channels, -1, -1)
        return F.conv1d(x, kernel, stride=self.stride, groups=channels)
|
||||||
|
|
||||||
|
|
||||||
|
class UpSample1d(nn.Module):
    """Anti-aliased 1-D upsampler built on a depthwise transposed convolution.

    Args:
        ratio: Integer upsampling factor, also used as the stride.
        kernel_size: Tap count for the Kaiser filter. ``None`` selects
            ``int(6 * ratio // 2) * 2``. Ignored when ``window_type == "hann"``.
        persistent: Whether the ``filter`` buffer is saved in the state dict.
        window_type: ``"hann"`` for the Hann-windowed sinc filter; any other
            value selects the Kaiser-windowed sinc filter.
    """

    def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"):
        super().__init__()
        self.ratio = ratio
        self.stride = ratio

        if window_type == "hann":
            # Hann-windowed sinc filter — identical to torchaudio.functional.resample
            # with its default parameters (rolloff=0.99, lowpass_filter_width=6).
            # Uses replicate boundary padding, matching the reference resampler exactly.
            rolloff = 0.99
            lowpass_filter_width = 6
            width = math.ceil(lowpass_filter_width / rolloff)
            self.kernel_size = 2 * width * ratio + 1
            self.pad = width
            self.pad_left = 2 * width * ratio
            self.pad_right = self.kernel_size - ratio
            t = (torch.arange(self.kernel_size) / ratio - width) * rolloff
            t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width)
            window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2
            kernel = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1)
        else:
            # Kaiser-windowed sinc filter (BigVGAN default).
            self.kernel_size = (
                int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
            )
            self.pad = self.kernel_size // ratio - 1
            self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
            self.pad_right = (
                self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
            )
            kernel = kaiser_sinc_filter1d(
                cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
            )

        self.register_buffer("filter", kernel, persistent=persistent)

    def forward(self, x):
        """Upsample a 3-D tensor along its last dimension.

        The input is replicate-padded by ``self.pad`` on both sides and
        convolved with the transposed filter. ``pad_left`` / ``pad_right``
        samples are then trimmed from the ends.
        """
        _, C, _ = x.shape
        x = F.pad(x, (self.pad, self.pad), mode="replicate")
        x = self.ratio * F.conv_transpose1d(
            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
        )
        # Use an explicit end index: `-self.pad_right` with pad_right == 0
        # would be `-0` and produce an empty slice.
        return x[..., self.pad_left : x.shape[-1] - self.pad_right]
|
||||||
|
|
||||||
|
|
||||||
|
class DownSample1d(nn.Module):
    """Downsampler implemented as a strided ``LowPassFilter1d``.

    Args:
        ratio: Decimation factor, used as the low-pass filter's stride.
        kernel_size: Filter tap count; ``None`` selects ``int(6 * ratio // 2) * 2``.
    """

    def __init__(self, ratio=2, kernel_size=None):
        super().__init__()
        if kernel_size is None:
            kernel_size = int(6 * ratio // 2) * 2
        self.ratio = ratio
        self.kernel_size = kernel_size
        # Cutoff and transition width scale with 1 / ratio.
        self.lowpass = LowPassFilter1d(
            cutoff=0.5 / ratio,
            half_width=0.6 / ratio,
            stride=ratio,
            kernel_size=kernel_size,
        )

    def forward(self, x):
        """Apply the strided low-pass filter to ``x``."""
        filtered = self.lowpass(x)
        return filtered
|
||||||
|
|
||||||
|
|
||||||
|
class Activation1d(nn.Module):
    """Wrap an activation between an ``UpSample1d`` and a ``DownSample1d``.

    Args:
        activation: Callable applied to the upsampled signal.
        up_ratio: Ratio passed to ``UpSample1d``.
        down_ratio: Ratio passed to ``DownSample1d``.
        up_kernel_size: Kernel size passed to ``UpSample1d``.
        down_kernel_size: Kernel size passed to ``DownSample1d``.
    """

    def __init__(
        self,
        activation,
        up_ratio=2,
        down_ratio=2,
        up_kernel_size=12,
        down_kernel_size=12,
    ):
        super().__init__()
        self.act = activation
        self.upsample = UpSample1d(up_ratio, up_kernel_size)
        self.downsample = DownSample1d(down_ratio, down_kernel_size)

    def forward(self, x):
        """Return ``downsample(act(upsample(x)))``."""
        return self.downsample(self.act(self.upsample(x)))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# BigVGAN v2 activations (Snake / SnakeBeta)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class Snake(nn.Module):
    """Snake activation: ``x + sin(a * x) ** 2 / (a + eps)`` with one ``a`` per channel.

    Args:
        in_features: Number of channels (length of the ``alpha`` parameter).
        alpha: Initial value of ``alpha`` when ``alpha_logscale`` is False.
            In log-scale mode the parameter starts at zero instead.
        alpha_trainable: Value assigned to ``alpha.requires_grad``.
        alpha_logscale: If True, ``alpha`` is stored as a logarithm and
            exponentiated in ``forward``.
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
    ):
        super().__init__()
        self.alpha_logscale = alpha_logscale
        if alpha_logscale:
            initial = torch.zeros(in_features)
        else:
            initial = torch.ones(in_features) * alpha
        self.alpha = nn.Parameter(initial, requires_grad=alpha_trainable)
        self.eps = 1e-9

    def forward(self, x):
        """Apply the activation; ``alpha`` is broadcast along dim 1 of ``x``."""
        a = self.alpha[None, :, None]
        if self.alpha_logscale:
            a = torch.exp(a)
        # eps guards the reciprocal against a == 0.
        reciprocal = 1.0 / (a + self.eps)
        periodic = torch.sin(x * a).pow(2)
        return x + reciprocal * periodic
|
||||||
|
|
||||||
|
|
||||||
|
class SnakeBeta(nn.Module):
    """SnakeBeta activation: ``x + sin(a * x) ** 2 / (b + eps)``.

    ``alpha`` (``a``) scales the argument of the sine and ``beta`` (``b``)
    scales its magnitude; both are per-channel parameters.

    Args:
        in_features: Number of channels (length of ``alpha`` and ``beta``).
        alpha: Initial value of both parameters when ``alpha_logscale`` is
            False. In log-scale mode both start at zero instead.
        alpha_trainable: Value assigned to ``requires_grad`` of both parameters.
        alpha_logscale: If True, both parameters are stored as logarithms and
            exponentiated in ``forward``.
    """

    def __init__(
        self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
    ):
        super().__init__()
        self.alpha_logscale = alpha_logscale

        def initial_value():
            # Each parameter gets its own freshly allocated tensor.
            if alpha_logscale:
                return torch.zeros(in_features)
            return torch.ones(in_features) * alpha

        self.alpha = nn.Parameter(initial_value(), requires_grad=alpha_trainable)
        self.beta = nn.Parameter(initial_value(), requires_grad=alpha_trainable)
        self.eps = 1e-9

    def forward(self, x):
        """Apply the activation; parameters are broadcast along dim 1 of ``x``."""
        a = self.alpha[None, :, None]
        b = self.beta[None, :, None]
        if self.alpha_logscale:
            a, b = torch.exp(a), torch.exp(b)
        # eps guards the reciprocal against b == 0.
        reciprocal = 1.0 / (b + self.eps)
        periodic = torch.sin(x * a).pow(2)
        return x + reciprocal * periodic
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class AMPBlock1(torch.nn.Module):
    """Residual block with three stages of activation and convolution pairs.

    Each stage computes ``x + conv2(act2(conv1(act1(x))))``. ``conv1`` uses
    the stage's dilation and ``conv2`` uses dilation 1.

    Args:
        channels: Input and output channel count of every convolution.
        kernel_size: Kernel size of every convolution.
        dilation: Dilations of the three first-position convolutions; only
            indices 0, 1 and 2 are read.
        activation: ``"snakebeta"`` selects ``SnakeBeta``; any other value
            selects ``Snake``.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"):
        super().__init__()
        act_cls = SnakeBeta if activation == "snakebeta" else Snake

        def make_conv(rate):
            # Stride 1; padding comes from get_padding for this dilation.
            return ops.Conv1d(
                channels,
                channels,
                kernel_size,
                1,
                dilation=rate,
                padding=get_padding(kernel_size, rate),
            )

        def make_act():
            return Activation1d(act_cls(channels))

        self.convs1 = nn.ModuleList(
            [make_conv(dilation[0]), make_conv(dilation[1]), make_conv(dilation[2])]
        )
        self.convs2 = nn.ModuleList([make_conv(1), make_conv(1), make_conv(1)])
        self.acts1 = nn.ModuleList([make_act() for _ in range(len(self.convs1))])
        self.acts2 = nn.ModuleList([make_act() for _ in range(len(self.convs2))])

    def forward(self, x):
        """Run the stages in order, adding each stage's output to ``x``."""
        stages = zip(self.convs1, self.convs2, self.acts1, self.acts2)
        for conv_a, conv_b, act_a, act_b in stages:
            branch = conv_a(act_a(x))
            branch = conv_b(act_b(branch))
            x = x + branch
        return x
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HiFi-GAN residual blocks
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class ResBlock1(torch.nn.Module):
|
class ResBlock1(torch.nn.Module):
|
||||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||||
super(ResBlock1, self).__init__()
|
super(ResBlock1, self).__init__()
|
||||||
@ -119,6 +421,7 @@ class Vocoder(torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
|
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
|
||||||
|
|
||||||
|
Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1").
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config=None):
|
def __init__(self, config=None):
|
||||||
@ -128,19 +431,39 @@ class Vocoder(torch.nn.Module):
|
|||||||
config = self.get_default_config()
|
config = self.get_default_config()
|
||||||
|
|
||||||
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
|
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
|
||||||
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
|
upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2])
|
||||||
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
|
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4])
|
||||||
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
|
||||||
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
|
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
|
||||||
stereo = config.get("stereo", True)
|
stereo = config.get("stereo", True)
|
||||||
resblock = config.get("resblock", "1")
|
activation = config.get("activation", "snake")
|
||||||
|
use_bias_at_final = config.get("use_bias_at_final", True)
|
||||||
|
|
||||||
|
|
||||||
|
# "output_sample_rate" is not present in recent checkpoint configs.
|
||||||
|
# When absent (None), AudioVAE.output_sample_rate computes it as:
|
||||||
|
# sample_rate * vocoder.upsample_factor / mel_hop_length
|
||||||
|
# where upsample_factor = product of all upsample stride lengths,
|
||||||
|
# and mel_hop_length is loaded from the autoencoder config at
|
||||||
|
# preprocessing.stft.hop_length (see CausalAudioAutoencoder).
|
||||||
self.output_sample_rate = config.get("output_sample_rate")
|
self.output_sample_rate = config.get("output_sample_rate")
|
||||||
|
self.resblock = config.get("resblock", "1")
|
||||||
|
self.use_tanh_at_final = config.get("use_tanh_at_final", True)
|
||||||
|
self.apply_final_activation = config.get("apply_final_activation", True)
|
||||||
self.num_kernels = len(resblock_kernel_sizes)
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
self.num_upsamples = len(upsample_rates)
|
self.num_upsamples = len(upsample_rates)
|
||||||
|
|
||||||
in_channels = 128 if stereo else 64
|
in_channels = 128 if stereo else 64
|
||||||
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
|
||||||
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
|
|
||||||
|
if self.resblock == "1":
|
||||||
|
resblock_cls = ResBlock1
|
||||||
|
elif self.resblock == "2":
|
||||||
|
resblock_cls = ResBlock2
|
||||||
|
elif self.resblock == "AMP1":
|
||||||
|
resblock_cls = AMPBlock1
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown resblock type: {self.resblock}")
|
||||||
|
|
||||||
self.ups = nn.ModuleList()
|
self.ups = nn.ModuleList()
|
||||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||||
@ -157,25 +480,40 @@ class Vocoder(torch.nn.Module):
|
|||||||
self.resblocks = nn.ModuleList()
|
self.resblocks = nn.ModuleList()
|
||||||
for i in range(len(self.ups)):
|
for i in range(len(self.ups)):
|
||||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
|
||||||
self.resblocks.append(resblock_class(ch, k, d))
|
if self.resblock == "AMP1":
|
||||||
|
self.resblocks.append(resblock_cls(ch, k, d, activation=activation))
|
||||||
|
else:
|
||||||
|
self.resblocks.append(resblock_cls(ch, k, d))
|
||||||
|
|
||||||
out_channels = 2 if stereo else 1
|
out_channels = 2 if stereo else 1
|
||||||
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
|
if self.resblock == "AMP1":
|
||||||
|
act_cls = SnakeBeta if activation == "snakebeta" else Snake
|
||||||
|
self.act_post = Activation1d(act_cls(ch))
|
||||||
|
else:
|
||||||
|
self.act_post = nn.LeakyReLU()
|
||||||
|
|
||||||
|
self.conv_post = ops.Conv1d(
|
||||||
|
ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final
|
||||||
|
)
|
||||||
|
|
||||||
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
|
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
|
||||||
|
|
||||||
|
|
||||||
def get_default_config(self):
|
def get_default_config(self):
|
||||||
"""Generate default configuration for the vocoder."""
|
"""Generate default configuration for the vocoder."""
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"resblock_kernel_sizes": [3, 7, 11],
|
"resblock_kernel_sizes": [3, 7, 11],
|
||||||
"upsample_rates": [6, 5, 2, 2, 2],
|
"upsample_rates": [5, 4, 2, 2, 2],
|
||||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
"upsample_kernel_sizes": [16, 16, 8, 4, 4],
|
||||||
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||||
"upsample_initial_channel": 1024,
|
"upsample_initial_channel": 1024,
|
||||||
"stereo": True,
|
"stereo": True,
|
||||||
"resblock": "1",
|
"resblock": "1",
|
||||||
|
"activation": "snake",
|
||||||
|
"use_bias_at_final": True,
|
||||||
|
"use_tanh_at_final": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
return config
|
return config
|
||||||
@ -196,8 +534,10 @@ class Vocoder(torch.nn.Module):
|
|||||||
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
|
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
|
||||||
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
|
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
|
||||||
x = self.conv_pre(x)
|
x = self.conv_pre(x)
|
||||||
|
|
||||||
for i in range(self.num_upsamples):
|
for i in range(self.num_upsamples):
|
||||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
if self.resblock != "AMP1":
|
||||||
|
x = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
x = self.ups[i](x)
|
x = self.ups[i](x)
|
||||||
xs = None
|
xs = None
|
||||||
for j in range(self.num_kernels):
|
for j in range(self.num_kernels):
|
||||||
@ -206,8 +546,167 @@ class Vocoder(torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||||
x = xs / self.num_kernels
|
x = xs / self.num_kernels
|
||||||
x = F.leaky_relu(x)
|
|
||||||
|
x = self.act_post(x)
|
||||||
x = self.conv_post(x)
|
x = self.conv_post(x)
|
||||||
x = torch.tanh(x)
|
|
||||||
|
if self.apply_final_activation:
|
||||||
|
if self.use_tanh_at_final:
|
||||||
|
x = torch.tanh(x)
|
||||||
|
else:
|
||||||
|
x = torch.clamp(x, -1, 1)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class _STFTFn(nn.Module):
    """STFT implemented as a strided 1-D convolution with stored basis buffers.

    ``forward_basis`` has ``2 * n_freqs`` rows: the first ``n_freqs`` rows are
    treated as the real part and the last ``n_freqs`` rows as the imaginary
    part (stacked halves, not interleaved). Both buffers are zero-initialised
    here and are expected to be filled from the checkpoint. ``inverse_basis``
    is only registered; ``forward`` does not use it.

    Args:
        filter_length: FFT size; sets ``n_freqs = filter_length // 2 + 1`` and
            the basis kernel width.
        hop_length: Convolution stride.
        win_length: Window length, used only to derive the left padding.
    """

    def __init__(self, filter_length: int, hop_length: int, win_length: int):
        super().__init__()
        self.hop_length = hop_length
        self.win_length = win_length
        n_freqs = filter_length // 2 + 1
        self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
        self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))

    def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute magnitude and phase spectrograms from a batch of waveforms.

        Applies causal (left-only) padding of ``win_length - hop_length``
        samples, then convolves the padded signal with ``forward_basis``.

        Args:
            y: Waveform tensor of shape (B, T). Other ranks are passed to the
                convolution unchanged.

        Returns:
            magnitude: ``sqrt(real ** 2 + imag ** 2)``, shape (B, n_freqs, T_frames).
            phase: ``atan2(imag, real)`` in radians, same shape. Computed in
                float32 and cast back to the spectrogram dtype.
        """
        if y.dim() == 2:
            y = y.unsqueeze(1)  # (B, 1, T)
        left_pad = max(0, self.win_length - self.hop_length)  # causal: left-only
        y = F.pad(y, (left_pad, 0))
        # Cast the basis to the signal dtype (no-op when they already match),
        # as MelSTFT does for mel_basis; conv1d rejects mixed dtypes.
        basis = self.forward_basis.to(y.dtype)
        spec = F.conv1d(y, basis, stride=self.hop_length, padding=0)
        n_freqs = spec.shape[1] // 2
        real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
        magnitude = torch.sqrt(real ** 2 + imag ** 2)
        phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
        return magnitude, phase
|
||||||
|
|
||||||
|
|
||||||
|
class MelSTFT(nn.Module):
    """Log-mel spectrogram module whose buffers are loaded from the checkpoint.

    Runs ``_STFTFn`` on the input waveform and projects the linear magnitude
    spectrum onto the ``mel_basis`` filterbank.

    The module's state dict layout matches the 'mel_stft.*' keys stored in the
    checkpoint (mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).

    Args:
        filter_length: FFT size passed to ``_STFTFn``; sets the ``mel_basis``
            width to ``filter_length // 2 + 1``.
        hop_length: Hop length passed to ``_STFTFn``.
        win_length: Window length passed to ``_STFTFn``.
        n_mel_channels: Number of rows of ``mel_basis``.
        sampling_rate: Accepted for interface compatibility; not used here
            because ``mel_basis`` is a zero-initialised buffer, not computed.
        mel_fmin: Accepted for interface compatibility; not used here.
        mel_fmax: Accepted for interface compatibility; not used here.
    """

    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
        sampling_rate: int,
        mel_fmin: float,
        mel_fmax: float,
    ):
        super().__init__()
        self.stft_fn = _STFTFn(filter_length, hop_length, win_length)

        n_freqs = filter_length // 2 + 1
        self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))

    def mel_spectrogram(
        self, y: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute log-mel spectrogram and auxiliary spectral quantities.

        Args:
            y: Waveform tensor of shape (B, T).

        Returns:
            log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
                Computed as log(clamp(mel_basis @ magnitude, min=1e-5)).
            magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
            phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
            energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
        """
        magnitude, phase = self.stft_fn(y)
        # torch.norm is deprecated; vector_norm with the default ord=2 computes
        # the same L2 norm over the frequency dimension.
        energy = torch.linalg.vector_norm(magnitude, dim=1)
        mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
        log_mel = torch.log(torch.clamp(mel, min=1e-5))
        return log_mel, magnitude, phase, energy
|
||||||
|
|
||||||
|
|
||||||
|
class VocoderWithBWE(torch.nn.Module):
    """Base vocoder followed by a bandwidth-extension (BWE) stage.

    The base ``Vocoder`` produces a waveform. A second ``Vocoder`` (built with
    ``apply_final_activation=False``) generates a residual from the log-mel
    spectrogram of that waveform, and the residual is added to a resampled
    copy of the waveform.

    Args:
        config: Mapping with a ``"vocoder"`` config and a ``"bwe"`` config.
            The latter must provide ``input_sampling_rate``,
            ``output_sampling_rate``, ``hop_length``, ``n_fft`` and ``num_mels``.
    """

    def __init__(self, config):
        super().__init__()
        vocoder_config = config["vocoder"]
        bwe_config = config["bwe"]

        self.vocoder = Vocoder(config=vocoder_config)
        # Final activation disabled here; forward() clamps the sum instead.
        bwe_generator_config = {**bwe_config, "apply_final_activation": False}
        self.bwe_generator = Vocoder(config=bwe_generator_config)

        self.input_sample_rate = bwe_config["input_sampling_rate"]
        self.output_sample_rate = bwe_config["output_sampling_rate"]
        self.hop_length = bwe_config["hop_length"]

        self.mel_stft = MelSTFT(
            filter_length=bwe_config["n_fft"],
            hop_length=bwe_config["hop_length"],
            win_length=bwe_config["n_fft"],
            n_mel_channels=bwe_config["num_mels"],
            sampling_rate=bwe_config["input_sampling_rate"],
            mel_fmin=0.0,
            mel_fmax=bwe_config["input_sampling_rate"] / 2.0,
        )
        upsample_ratio = (
            bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"]
        )
        self.resampler = UpSample1d(
            ratio=upsample_ratio,
            persistent=False,
            window_type="hann",
        )

    def _compute_mel(self, audio):
        """Return the log-mel spectrogram of ``audio`` (B, C, T) as (B, C, n_mels, T_frames)."""
        batch, channels, _samples = audio.shape
        # Fold channels into the batch so each channel is processed as its own waveform.
        flat = audio.reshape(batch * channels, -1)
        mel = self.mel_stft.mel_spectrogram(flat)[0]
        return mel.reshape(batch, channels, mel.shape[1], mel.shape[2])

    def forward(self, mel_spec):
        """Synthesize the base waveform and add the BWE residual.

        Returns:
            ``residual + skip`` clamped to [-1, 1] and truncated to
            ``T_low * output_sample_rate // input_sample_rate`` samples.
        """
        x = self.vocoder(mel_spec)
        _, _, T_low = x.shape
        T_out = T_low * self.output_sample_rate // self.input_sample_rate

        # Right-pad to a multiple of hop_length; the padding is trimmed again
        # by the final [..., :T_out] slice.
        remainder = T_low % self.hop_length
        if remainder != 0:
            x = F.pad(x, (0, self.hop_length - remainder))

        residual = self.bwe_generator(self._compute_mel(x))
        skip = self.resampler(x)
        assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"

        mixed = torch.clamp(residual + skip, -1, 1)
        return mixed[..., :T_out]
|
||||||
|
|||||||
@ -1021,7 +1021,7 @@ class LTXAV(BaseModel):
|
|||||||
cross_attn = kwargs.get("cross_attn", None)
|
cross_attn = kwargs.get("cross_attn", None)
|
||||||
if cross_attn is not None:
|
if cross_attn is not None:
|
||||||
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
|
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
|
||||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
|
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False))
|
||||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||||
|
|
||||||
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||||
|
|||||||
@ -1666,12 +1666,16 @@ def lora_compute_dtype(device):
|
|||||||
return dtype
|
return dtype
|
||||||
|
|
||||||
def synchronize():
|
def synchronize():
|
||||||
|
if cpu_mode():
|
||||||
|
return
|
||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
torch.xpu.synchronize()
|
torch.xpu.synchronize()
|
||||||
elif torch.cuda.is_available():
|
elif torch.cuda.is_available():
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
def soft_empty_cache(force=False):
|
def soft_empty_cache(force=False):
|
||||||
|
if cpu_mode():
|
||||||
|
return
|
||||||
global cpu_state
|
global cpu_state
|
||||||
if cpu_state == CPUState.MPS:
|
if cpu_state == CPUState.MPS:
|
||||||
torch.mps.empty_cache()
|
torch.mps.empty_cache()
|
||||||
|
|||||||
18
comfy/ops.py
18
comfy/ops.py
@ -660,23 +660,29 @@ class fp8_ops(manual_cast):
|
|||||||
|
|
||||||
CUBLAS_IS_AVAILABLE = False
|
CUBLAS_IS_AVAILABLE = False
|
||||||
try:
|
try:
|
||||||
from cublas_ops import CublasLinear
|
from cublas_ops import CublasLinear, cublas_half_matmul
|
||||||
CUBLAS_IS_AVAILABLE = True
|
CUBLAS_IS_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if CUBLAS_IS_AVAILABLE:
|
if CUBLAS_IS_AVAILABLE:
|
||||||
class cublas_ops(disable_weight_init):
|
class cublas_ops(manual_cast):
|
||||||
class Linear(CublasLinear, disable_weight_init.Linear):
|
class Linear(CublasLinear, manual_cast.Linear):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input):
|
def forward_comfy_cast_weights(self, input):
|
||||||
return super().forward(input)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
|
x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
return super().forward(*args, **kwargs)
|
run_every_op()
|
||||||
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Mixed Precision Operations
|
# Mixed Precision Operations
|
||||||
|
|||||||
@ -1490,7 +1490,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
|||||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||||
elif clip_type == CLIPType.LTXV:
|
elif clip_type == CLIPType.LTXV:
|
||||||
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
|
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data))
|
||||||
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
|
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
|
||||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||||
elif clip_type == CLIPType.NEWBIE:
|
elif clip_type == CLIPType.NEWBIE:
|
||||||
|
|||||||
@ -97,18 +97,39 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
|
|||||||
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
|
||||||
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
|
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
|
||||||
|
|
||||||
|
class DualLinearProjection(torch.nn.Module):
    """Project normalized, flattened features into concatenated video and audio embeddings.

    Args:
        in_dim: Input feature size of both linear layers.
        out_dim_video: Output size of ``video_aggregate_embed``.
        out_dim_audio: Output size of ``audio_aggregate_embed``.
        dtype: dtype forwarded to both linear layers.
        device: device forwarded to both linear layers.
        operations: Namespace providing the ``Linear`` class to instantiate.
    """

    def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
        super().__init__()
        linear_kwargs = dict(bias=True, dtype=dtype, device=device)
        self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, **linear_kwargs)
        self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, **linear_kwargs)

    def forward(self, x):
        """Normalize ``x``, flatten it and return ``cat(video, audio)`` on the last dim.

        Dim 1 of ``x`` is moved to the end, the result is RMS-normalized over
        dim 2 and flattened from dim 2 onwards. Each head receives the
        features scaled by ``sqrt(out_features / source_dim)``.
        """
        source_dim = x.shape[-1]
        moved = x.movedim(1, -1)
        inv_rms = torch.rsqrt(torch.mean(moved**2, dim=2, keepdim=True) + 1e-6)
        features = (moved * inv_rms).flatten(start_dim=2)

        video_head = self.video_aggregate_embed
        audio_head = self.audio_aggregate_embed
        video = video_head(features * math.sqrt(video_head.out_features / source_dim))
        audio = audio_head(features * math.sqrt(audio_head.out_features / source_dim))
        return torch.cat((video, audio), dim=-1)
|
||||||
|
|
||||||
class LTXAVTEModel(torch.nn.Module):
|
class LTXAVTEModel(torch.nn.Module):
|
||||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dtypes = set()
|
self.dtypes = set()
|
||||||
self.dtypes.add(dtype)
|
self.dtypes.add(dtype)
|
||||||
self.compat_mode = False
|
self.compat_mode = False
|
||||||
|
self.text_projection_type = text_projection_type
|
||||||
|
|
||||||
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
|
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
|
||||||
self.dtypes.add(dtype_llama)
|
self.dtypes.add(dtype_llama)
|
||||||
|
|
||||||
operations = self.gemma3_12b.operations # TODO
|
operations = self.gemma3_12b.operations # TODO
|
||||||
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
|
||||||
|
if self.text_projection_type == "single_linear":
|
||||||
|
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
|
||||||
|
elif self.text_projection_type == "dual_linear":
|
||||||
|
self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations)
|
||||||
|
|
||||||
|
|
||||||
def enable_compat_mode(self): # TODO: remove
|
def enable_compat_mode(self): # TODO: remove
|
||||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||||
@ -148,18 +169,25 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
out_device = out.device
|
out_device = out.device
|
||||||
if comfy.model_management.should_use_bf16(self.execution_device):
|
if comfy.model_management.should_use_bf16(self.execution_device):
|
||||||
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
||||||
out = out.movedim(1, -1).to(self.execution_device)
|
|
||||||
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
|
|
||||||
out = out.reshape((out.shape[0], out.shape[1], -1))
|
|
||||||
out = self.text_embedding_projection(out)
|
|
||||||
out = out.float()
|
|
||||||
|
|
||||||
if self.compat_mode:
|
if self.text_projection_type == "single_linear":
|
||||||
out_vid = self.video_embeddings_connector(out)[0]
|
out = out.movedim(1, -1).to(self.execution_device)
|
||||||
out_audio = self.audio_embeddings_connector(out)[0]
|
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
|
||||||
out = torch.concat((out_vid, out_audio), dim=-1)
|
out = out.reshape((out.shape[0], out.shape[1], -1))
|
||||||
|
out = self.text_embedding_projection(out)
|
||||||
|
|
||||||
return out.to(out_device), pooled
|
if self.compat_mode:
|
||||||
|
out_vid = self.video_embeddings_connector(out)[0]
|
||||||
|
out_audio = self.audio_embeddings_connector(out)[0]
|
||||||
|
out = torch.concat((out_vid, out_audio), dim=-1)
|
||||||
|
extra = {}
|
||||||
|
else:
|
||||||
|
extra = {"unprocessed_ltxav_embeds": True}
|
||||||
|
elif self.text_projection_type == "dual_linear":
|
||||||
|
out = self.text_embedding_projection(out)
|
||||||
|
extra = {"unprocessed_ltxav_embeds": True}
|
||||||
|
|
||||||
|
return out.to(device=out_device, dtype=torch.float), pooled, extra
|
||||||
|
|
||||||
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
|
||||||
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
|
||||||
@ -168,7 +196,7 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
if "model.layers.47.self_attn.q_norm.weight" in sd:
|
||||||
return self.gemma3_12b.load_sd(sd)
|
return self.gemma3_12b.load_sd(sd)
|
||||||
else:
|
else:
|
||||||
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
|
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True)
|
||||||
if len(sdo) == 0:
|
if len(sdo) == 0:
|
||||||
sdo = sd
|
sdo = sd
|
||||||
|
|
||||||
@ -206,7 +234,7 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
num_tokens = max(num_tokens, 642)
|
num_tokens = max(num_tokens, 642)
|
||||||
return num_tokens * constant * 1024 * 1024
|
return num_tokens * constant * 1024 * 1024
|
||||||
|
|
||||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"):
|
||||||
class LTXAVTEModel_(LTXAVTEModel):
|
class LTXAVTEModel_(LTXAVTEModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
if llama_quantization_metadata is not None:
|
if llama_quantization_metadata is not None:
|
||||||
@ -214,9 +242,19 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
|||||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||||
if dtype_llama is not None:
|
if dtype_llama is not None:
|
||||||
dtype = dtype_llama
|
dtype = dtype_llama
|
||||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options)
|
||||||
return LTXAVTEModel_
|
return LTXAVTEModel_
|
||||||
|
|
||||||
|
|
||||||
|
def sd_detect(state_dict_list, prefix=""):
    """Pick the text projection type from the keys present in state dicts.

    The state dicts are scanned in order and the first one containing a
    recognized ``text_embedding_projection`` key decides the result. Within a
    single state dict the dual-linear key is checked before the single-linear
    keys.

    Args:
        state_dict_list: Iterable of mappings (anything supporting ``in``).
        prefix: String prepended to every key that is looked up.

    Returns:
        ``{"text_projection_type": "dual_linear"}``,
        ``{"text_projection_type": "single_linear"}``, or ``{}`` when no
        state dict contains a recognized key.
    """
    # Build the lookup keys once instead of re-formatting them per state dict.
    base = f"{prefix}text_embedding_projection."
    dual_key = base + "audio_aggregate_embed.bias"
    single_keys = (base + "weight", base + "aggregate_embed.weight")

    for sd in state_dict_list:
        if dual_key in sd:
            return {"text_projection_type": "dual_linear"}
        if any(key in sd for key in single_keys):
            return {"text_projection_type": "single_linear"}
    return {}
|
||||||
|
|
||||||
|
|
||||||
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
|
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
class Gemma3_12BModel_(Gemma3_12BModel):
|
class Gemma3_12BModel_(Gemma3_12BModel):
|
||||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||||
|
|||||||
@ -7,7 +7,8 @@ class ImageGenerationRequest(BaseModel):
|
|||||||
aspect_ratio: str = Field(...)
|
aspect_ratio: str = Field(...)
|
||||||
n: int = Field(...)
|
n: int = Field(...)
|
||||||
seed: int = Field(...)
|
seed: int = Field(...)
|
||||||
response_for: str = Field("url")
|
response_format: str = Field("url")
|
||||||
|
resolution: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class InputUrlObject(BaseModel):
|
class InputUrlObject(BaseModel):
|
||||||
@ -16,12 +17,13 @@ class InputUrlObject(BaseModel):
|
|||||||
|
|
||||||
class ImageEditRequest(BaseModel):
|
class ImageEditRequest(BaseModel):
|
||||||
model: str = Field(...)
|
model: str = Field(...)
|
||||||
image: InputUrlObject = Field(...)
|
images: list[InputUrlObject] = Field(...)
|
||||||
prompt: str = Field(...)
|
prompt: str = Field(...)
|
||||||
resolution: str = Field(...)
|
resolution: str = Field(...)
|
||||||
n: int = Field(...)
|
n: int = Field(...)
|
||||||
seed: int = Field(...)
|
seed: int = Field(...)
|
||||||
response_for: str = Field("url")
|
response_format: str = Field("url")
|
||||||
|
aspect_ratio: str | None = Field(...)
|
||||||
|
|
||||||
|
|
||||||
class VideoGenerationRequest(BaseModel):
|
class VideoGenerationRequest(BaseModel):
|
||||||
@ -47,8 +49,13 @@ class ImageResponseObject(BaseModel):
|
|||||||
revised_prompt: str | None = Field(None)
|
revised_prompt: str | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class UsageObject(BaseModel):
    """Optional usage block attached to generation responses."""

    # Reported cost as an integer tick count; None when not provided.
    # NOTE(review): the tick-to-USD ratio is not defined here -- confirm with the consumer.
    cost_in_usd_ticks: int | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class ImageGenerationResponse(BaseModel):
|
class ImageGenerationResponse(BaseModel):
|
||||||
data: list[ImageResponseObject] = Field(...)
|
data: list[ImageResponseObject] = Field(...)
|
||||||
|
usage: UsageObject | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class VideoGenerationResponse(BaseModel):
|
class VideoGenerationResponse(BaseModel):
|
||||||
@ -65,3 +72,4 @@ class VideoStatusResponse(BaseModel):
|
|||||||
status: str | None = Field(None)
|
status: str | None = Field(None)
|
||||||
video: VideoResponseObject | None = Field(None)
|
video: VideoResponseObject | None = Field(None)
|
||||||
model: str | None = Field(None)
|
model: str | None = Field(None)
|
||||||
|
usage: UsageObject | None = Field(None)
|
||||||
|
|||||||
@ -27,6 +27,12 @@ from comfy_api_nodes.util import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_grok_price(response) -> float | None:
|
||||||
|
if response.usage and response.usage.cost_in_usd_ticks is not None:
|
||||||
|
return response.usage.cost_in_usd_ticks / 10_000_000_000
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class GrokImageNode(IO.ComfyNode):
|
class GrokImageNode(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -37,7 +43,10 @@ class GrokImageNode(IO.ComfyNode):
|
|||||||
category="api node/image/Grok",
|
category="api node/image/Grok",
|
||||||
description="Generate images using Grok based on a text prompt",
|
description="Generate images using Grok based on a text prompt",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
|
IO.Combo.Input(
|
||||||
|
"model",
|
||||||
|
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
|
||||||
|
),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
@ -81,6 +90,7 @@ class GrokImageNode(IO.ComfyNode):
|
|||||||
tooltip="Seed to determine if node should re-run; "
|
tooltip="Seed to determine if node should re-run; "
|
||||||
"actual results are nondeterministic regardless of seed.",
|
"actual results are nondeterministic regardless of seed.",
|
||||||
),
|
),
|
||||||
|
IO.Combo.Input("resolution", options=["1K", "2K"], optional=True),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Image.Output(),
|
IO.Image.Output(),
|
||||||
@ -92,8 +102,13 @@ class GrokImageNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
|
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
|
||||||
expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""",
|
expr="""
|
||||||
|
(
|
||||||
|
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
|
||||||
|
{"type":"usd","usd": $rate * widgets.number_of_images}
|
||||||
|
)
|
||||||
|
""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -105,6 +120,7 @@ class GrokImageNode(IO.ComfyNode):
|
|||||||
aspect_ratio: str,
|
aspect_ratio: str,
|
||||||
number_of_images: int,
|
number_of_images: int,
|
||||||
seed: int,
|
seed: int,
|
||||||
|
resolution: str = "1K",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
response = await sync_op(
|
response = await sync_op(
|
||||||
@ -116,8 +132,10 @@ class GrokImageNode(IO.ComfyNode):
|
|||||||
aspect_ratio=aspect_ratio,
|
aspect_ratio=aspect_ratio,
|
||||||
n=number_of_images,
|
n=number_of_images,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
resolution=resolution.lower(),
|
||||||
),
|
),
|
||||||
response_model=ImageGenerationResponse,
|
response_model=ImageGenerationResponse,
|
||||||
|
price_extractor=_extract_grok_price,
|
||||||
)
|
)
|
||||||
if len(response.data) == 1:
|
if len(response.data) == 1:
|
||||||
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
|
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
|
||||||
@ -138,14 +156,17 @@ class GrokImageEditNode(IO.ComfyNode):
|
|||||||
category="api node/image/Grok",
|
category="api node/image/Grok",
|
||||||
description="Modify an existing image based on a text prompt",
|
description="Modify an existing image based on a text prompt",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
|
IO.Combo.Input(
|
||||||
IO.Image.Input("image"),
|
"model",
|
||||||
|
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
|
||||||
|
),
|
||||||
|
IO.Image.Input("image", display_name="images"),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
tooltip="The text prompt used to generate the image",
|
tooltip="The text prompt used to generate the image",
|
||||||
),
|
),
|
||||||
IO.Combo.Input("resolution", options=["1K"]),
|
IO.Combo.Input("resolution", options=["1K", "2K"]),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
"number_of_images",
|
"number_of_images",
|
||||||
default=1,
|
default=1,
|
||||||
@ -166,6 +187,27 @@ class GrokImageEditNode(IO.ComfyNode):
|
|||||||
tooltip="Seed to determine if node should re-run; "
|
tooltip="Seed to determine if node should re-run; "
|
||||||
"actual results are nondeterministic regardless of seed.",
|
"actual results are nondeterministic regardless of seed.",
|
||||||
),
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
options=[
|
||||||
|
"auto",
|
||||||
|
"1:1",
|
||||||
|
"2:3",
|
||||||
|
"3:2",
|
||||||
|
"3:4",
|
||||||
|
"4:3",
|
||||||
|
"9:16",
|
||||||
|
"16:9",
|
||||||
|
"9:19.5",
|
||||||
|
"19.5:9",
|
||||||
|
"9:20",
|
||||||
|
"20:9",
|
||||||
|
"1:2",
|
||||||
|
"2:1",
|
||||||
|
],
|
||||||
|
optional=True,
|
||||||
|
tooltip="Only allowed when multiple images are connected to the image input.",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Image.Output(),
|
IO.Image.Output(),
|
||||||
@ -177,8 +219,13 @@ class GrokImageEditNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
|
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
|
||||||
expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""",
|
expr="""
|
||||||
|
(
|
||||||
|
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
|
||||||
|
{"type":"usd","usd": 0.002 + $rate * widgets.number_of_images}
|
||||||
|
)
|
||||||
|
""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -191,22 +238,32 @@ class GrokImageEditNode(IO.ComfyNode):
|
|||||||
resolution: str,
|
resolution: str,
|
||||||
number_of_images: int,
|
number_of_images: int,
|
||||||
seed: int,
|
seed: int,
|
||||||
|
aspect_ratio: str = "auto",
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||||
if get_number_of_images(image) != 1:
|
if model == "grok-imagine-image-pro":
|
||||||
raise ValueError("Only one input image is supported.")
|
if get_number_of_images(image) > 1:
|
||||||
|
raise ValueError("The pro model supports only 1 input image.")
|
||||||
|
elif get_number_of_images(image) > 3:
|
||||||
|
raise ValueError("A maximum of 3 input images is supported.")
|
||||||
|
if aspect_ratio != "auto" and get_number_of_images(image) == 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Custom aspect ratio is only allowed when multiple images are connected to the image input."
|
||||||
|
)
|
||||||
response = await sync_op(
|
response = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"),
|
ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"),
|
||||||
data=ImageEditRequest(
|
data=ImageEditRequest(
|
||||||
model=model,
|
model=model,
|
||||||
image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"),
|
images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image],
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
resolution=resolution.lower(),
|
resolution=resolution.lower(),
|
||||||
n=number_of_images,
|
n=number_of_images,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio,
|
||||||
),
|
),
|
||||||
response_model=ImageGenerationResponse,
|
response_model=ImageGenerationResponse,
|
||||||
|
price_extractor=_extract_grok_price,
|
||||||
)
|
)
|
||||||
if len(response.data) == 1:
|
if len(response.data) == 1:
|
||||||
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
|
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
|
||||||
@ -227,7 +284,7 @@ class GrokVideoNode(IO.ComfyNode):
|
|||||||
category="api node/video/Grok",
|
category="api node/video/Grok",
|
||||||
description="Generate video from a prompt or an image",
|
description="Generate video from a prompt or an image",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
|
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
@ -275,10 +332,11 @@ class GrokVideoNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]),
|
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]),
|
||||||
expr="""
|
expr="""
|
||||||
(
|
(
|
||||||
$base := 0.181 * widgets.duration;
|
$rate := widgets.resolution = "720p" ? 0.07 : 0.05;
|
||||||
|
$base := $rate * widgets.duration;
|
||||||
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
|
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
@ -321,6 +379,7 @@ class GrokVideoNode(IO.ComfyNode):
|
|||||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||||
response_model=VideoStatusResponse,
|
response_model=VideoStatusResponse,
|
||||||
|
price_extractor=_extract_grok_price,
|
||||||
)
|
)
|
||||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||||
|
|
||||||
@ -335,7 +394,7 @@ class GrokVideoEditNode(IO.ComfyNode):
|
|||||||
category="api node/video/Grok",
|
category="api node/video/Grok",
|
||||||
description="Edit an existing video based on a text prompt.",
|
description="Edit an existing video based on a text prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
|
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
multiline=True,
|
multiline=True,
|
||||||
@ -364,7 +423,7 @@ class GrokVideoEditNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""",
|
expr="""{"type":"usd","usd": 0.06, "format": {"suffix": "/sec", "approximate": true}}""",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -398,6 +457,7 @@ class GrokVideoEditNode(IO.ComfyNode):
|
|||||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||||
response_model=VideoStatusResponse,
|
response_model=VideoStatusResponse,
|
||||||
|
price_extractor=_extract_grok_price,
|
||||||
)
|
)
|
||||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.15.1"
|
__version__ = "0.16.0"
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.15.1"
|
version = "0.16.0"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.39.19
|
comfyui-frontend-package==1.39.19
|
||||||
comfyui-workflow-templates==0.9.5
|
comfyui-workflow-templates==0.9.7
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user