mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
The LTXBaseModel.forward and _forward methods required attention_mask as a positional argument with no default value. However, LTXV.extra_conds only conditionally adds attention_mask to model_conds when it is present in kwargs. If attention_mask is not provided by the text encoder, the diffusion_model forward call fails with: TypeError: LTXBaseModel.forward() missing 1 required positional argument: 'attention_mask' The model already handles attention_mask=None correctly in both _prepare_attention_mask and _prepare_context, so making the parameter optional is the minimal safe fix. This also aligns with how LTXAVDoubleStreamBlock.forward handles the parameter.
1335 lines
52 KiB
Python
1335 lines
52 KiB
Python
from abc import ABC, abstractmethod
|
|
from enum import Enum
|
|
import functools
|
|
import logging
|
|
import math
|
|
from typing import Dict, Optional, Tuple
|
|
|
|
from einops import rearrange
|
|
import numpy as np
|
|
import torch
|
|
from torch import nn
|
|
import comfy.patcher_extension
|
|
import comfy.ldm.modules.attention
|
|
import comfy.ldm.common_dit
|
|
|
|
from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
def _log_base(x, base):
|
|
return np.log(x) / np.log(base)
|
|
|
|
class LTXRopeType(str, Enum):
    """RoPE layout variants.

    Note: because this is a str-Enum, the ``KEY`` attribute is itself a
    member whose value ("rope_type") doubles as the config-dict key.
    """

    INTERLEAVED = "interleaved"
    SPLIT = "split"

    KEY = "rope_type"

    @classmethod
    def from_dict(cls, kwargs, default=None):
        """Read the rope type from a config dict, falling back to INTERLEAVED."""
        fallback = cls.INTERLEAVED if default is None else default
        return cls(kwargs.get(cls.KEY, fallback))
|
|
|
|
|
|
class LTXFrequenciesPrecision(str, Enum):
    """Precision used when generating the RoPE frequency grid.

    Note: because this is a str-Enum, ``KEY`` is itself a member whose
    value ("frequencies_precision") doubles as the config-dict key.
    """

    FLOAT32 = "float32"
    FLOAT64 = "float64"

    KEY = "frequencies_precision"

    @classmethod
    def from_dict(cls, kwargs, default=None):
        """Read the frequencies precision from a config dict, falling back to FLOAT32."""
        fallback = cls.FLOAT32 if default is None else default
        return cls(kwargs.get(cls.KEY, fallback))
|
|
|
|
|
|
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """Create sinusoidal timestep embeddings (DDPM-style).

    Args:
        timesteps (torch.Tensor): 1-D tensor of N (possibly fractional) indices.
        embedding_dim (int): dimension of the output.
        flip_sin_to_cos (bool): emit `cos, sin` order when True, `sin, cos` when False.
        downscale_freq_shift (float): controls the frequency delta between dimensions.
        scale (float): scaling factor applied to the embedding arguments.
        max_period (int): controls the maximum frequency of the embeddings.

    Returns:
        torch.Tensor: an [N x embedding_dim] tensor of positional embeddings.
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    # Geometric frequency ladder: exp(-log(max_period) * i / (half_dim - shift)).
    freq_exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    freq_exponent = freq_exponent / (half_dim - downscale_freq_shift)
    frequencies = torch.exp(freq_exponent)

    args = timesteps.float()[:, None] * frequencies[None, :]
    args = args * scale

    # Concatenate sine then cosine halves.
    emb = torch.cat([args.sin(), args.cos()], dim=-1)

    # Optionally swap to cosine-first ordering.
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # Zero-pad the last column for odd embedding dimensions.
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
|
|
|
|
|
class TimestepEmbedding(nn.Module):
    """Two-layer MLP that lifts a (sinusoidal) timestep embedding into the
    model's time-embedding space, with an optional conditioning projection.

    Note: ``act_fn`` is accepted for interface compatibility but the hidden
    activation is always SiLU in this port; ``post_act_fn`` is likewise not
    supported here (post-activation stays disabled).
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()

        self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)

        if cond_proj_dim is not None:
            # Projects an extra conditioning signal into the input space (no bias).
            self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
        else:
            self.cond_proj = None

        self.act = nn.SiLU()

        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
        self.linear_2 = operations.Linear(
            time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device
        )

        # Fix: always define post_act. The original only assigned it when
        # post_act_fn was None, so passing any post-activation name made
        # forward() fail with AttributeError. Post-activation remains
        # unsupported, so it is kept at None regardless of post_act_fn.
        self.post_act = None

    def forward(self, sample, condition=None):
        """Return the time embedding; adds cond_proj(condition) when given."""
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
|
|
|
|
|
|
class Timesteps(nn.Module):
    """Stateless wrapper that turns raw timesteps into sinusoidal embeddings."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        """Return a (N, num_channels) sinusoidal embedding of *timesteps*."""
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
|
|
|
|
|
|
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(
        self,
        embedding_dim,
        size_emb_dim,
        use_additional_conditions: bool = False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.outdim = size_emb_dim
        # 256-channel sinusoidal projection followed by an MLP into embedding_dim.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(
            in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
        )

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        """Embed *timestep*; resolution/aspect_ratio are accepted but unused here."""
        projected = self.time_proj(timestep)
        return self.timestep_embedder(projected.to(dtype=hidden_dtype))  # (N, D)
|
|
|
|
|
|
class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        embedding_coefficient (`int`): How many modulation vectors to emit per timestep.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(
        self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
    ):
        super().__init__()
        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim,
            size_emb_dim=embedding_dim // 3,
            use_additional_conditions=use_additional_conditions,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.silu = nn.SiLU()
        # Emits embedding_coefficient modulation vectors stacked on the last dim.
        self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (modulation parameters, raw embedded timestep).

        No modulation is applied here — callers split and apply the parameters.
        """
        # Falsy (empty) kwargs also fall back to the default placeholders.
        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep
|
|
|
|
|
|
class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(
        self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None
    ):
        super().__init__()
        out_features = hidden_size if out_features is None else out_features
        self.linear_1 = operations.Linear(
            in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device
        )
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = operations.Linear(
            in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device
        )

    def forward(self, caption):
        """Project caption embeddings: linear -> activation -> linear."""
        return self.linear_2(self.act_1(self.linear_1(caption)))
|
|
|
|
|
|
class NormSingleLinearTextProjection(nn.Module):
|
|
"""Text projection for 20B models - single linear with RMSNorm (no activation)."""
|
|
|
|
def __init__(
|
|
self, in_features, hidden_size, dtype=None, device=None, operations=None
|
|
):
|
|
super().__init__()
|
|
if operations is None:
|
|
operations = comfy.ops.disable_weight_init
|
|
self.in_norm = operations.RMSNorm(
|
|
in_features, eps=1e-6, elementwise_affine=False
|
|
)
|
|
self.linear_1 = operations.Linear(
|
|
in_features, hidden_size, bias=True, dtype=dtype, device=device
|
|
)
|
|
self.hidden_size = hidden_size
|
|
self.in_features = in_features
|
|
|
|
def forward(self, caption):
|
|
caption = self.in_norm(caption)
|
|
caption = caption * (self.hidden_size / self.in_features) ** 0.5
|
|
return self.linear_1(caption)
|
|
|
|
|
|
class GELU_approx(nn.Module):
    """Linear projection followed by tanh-approximate GELU."""

    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)

    def forward(self, x):
        projected = self.proj(x)
        return torch.nn.functional.gelu(projected, approximate="tanh")
|
|
|
|
|
|
class FeedForward(nn.Module):
    """Transformer MLP: GELU_approx expansion, dropout, then a linear back down.

    Note: the ``glu`` flag is accepted but unused (matches the original).
    """

    def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        hidden_dim = int(dim * mult)
        self.net = nn.Sequential(
            GELU_approx(dim, hidden_dim, dtype=dtype, device=device, operations=operations),
            nn.Dropout(dropout),
            operations.Linear(hidden_dim, dim_out, dtype=dtype, device=device),
        )

    def forward(self, x):
        return self.net(x)
|
|
|
|
def apply_rotary_emb(input_tensor, freqs_cis):
    """Apply RoPE; *freqs_cis* is (cos, sin[, split_pe_flag])."""
    cos_freqs = freqs_cis[0]
    sin_freqs = freqs_cis[1]
    # Missing third element means the interleaved layout.
    use_split = len(freqs_cis) > 2 and freqs_cis[2]
    if use_split:
        return apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
    return apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
|
|
|
|
def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs):  # TODO: remove duplicate funcs and pick the best/fastest one
    """Apply interleaved RoPE: adjacent pairs (x0, x1) are rotated by (cos, sin).

    Pure-torch equivalent of the einops formulation: unflatten the last dim
    into pairs, build the 90-degree-rotated tensor (-x1, x0), flatten back,
    and combine with the cos/sin tables.
    """
    pairs = input_tensor.unflatten(-1, (-1, 2))
    first, second = pairs.unbind(dim=-1)
    rotated = torch.stack((-second, first), dim=-1).flatten(-2)
    return input_tensor * cos_freqs + rotated * sin_freqs
|
|
|
|
def apply_split_rotary_emb(input_tensor, cos, sin):
    """Apply split-layout RoPE, where the rotated pair is (first half, second half)
    of the last dimension rather than interleaved adjacent elements.

    When *cos* is 4-D (B, H, T, D//2) but the input is not, the input is first
    reshaped to per-head layout and restored to (B, T, H*D) on return.
    """
    needs_reshape = False
    if input_tensor.ndim != 4 and cos.ndim == 4:
        B, H, T, _ = cos.shape
        # (B, T, H*D) -> (B, H, T, D) to match the frequency tables.
        input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
        needs_reshape = True
    # Split the last dim into two halves: (..., 2, half).
    split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
    first_half_input = split_input[..., :1, :]
    second_half_input = split_input[..., 1:, :]
    # Start from the cos term; `output`'s halves below are views into it,
    # so the in-place addcmul_ calls complete the rotation:
    #   out_first  = x_first  * cos - x_second * sin
    #   out_second = x_second * cos + x_first  * sin
    # NOTE: the order matters — both addcmul_ read the ORIGINAL input halves
    # (split_input), not the partially-updated output.
    output = split_input * cos.unsqueeze(-2)
    first_half_output = output[..., :1, :]
    second_half_output = output[..., 1:, :]
    first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
    second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)
    # Merge the two halves back into the last dimension.
    output = rearrange(output, "... d r -> ... (d r)")
    # B and T are only bound when needs_reshape is True (same branch).
    return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
|
|
|
|
|
|
class CrossAttention(nn.Module):
    """Multi-head attention with RMS-normalized Q/K, optional RoPE and optional
    per-head output gating. Acts as self-attention when no context is given.
    """

    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        attn_precision=None,
        apply_gated_attention=False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim
        self.attn_precision = attn_precision

        self.heads = heads
        self.dim_head = dim_head

        # Q/K are normalized over the full inner dimension (before RoPE).
        self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
        self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)

        self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

        # Optional per-head gating computed from the query input.
        self.to_gate_logits = (
            operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device)
            if apply_gated_attention
            else None
        )

        self.to_out = nn.Sequential(
            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
        """Attend *x* over *context* (or itself); pe/k_pe are rotary tables."""
        kv_input = x if context is None else context
        q = self.q_norm(self.to_q(x))
        k = self.k_norm(self.to_k(kv_input))
        v = self.to_v(kv_input)

        if pe is not None:
            q = apply_rotary_emb(q, pe)
            # Keys may carry their own rotary table (e.g. cross-attention).
            k = apply_rotary_emb(k, k_pe if k_pe is not None else pe)

        if mask is None:
            out = comfy.ldm.modules.attention.optimized_attention(
                q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options
            )
        else:
            out = comfy.ldm.modules.attention.optimized_attention_masked(
                q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options
            )

        if self.to_gate_logits is not None:
            # Per-head sigmoid gates in (0, 2); zero-init logits -> identity gate.
            b, t, _ = out.shape
            gates = 2.0 * torch.sigmoid(self.to_gate_logits(x))
            out = out.view(b, t, self.heads, self.dim_head) * gates.unsqueeze(-1)
            out = out.view(b, t, self.heads * self.dim_head)

        return self.to_out(out)
|
|
|
|
# ADaLN modulation parameter counts per transformer block:
# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q
# (shift/scale/gate) when cross-attention ADaLN is enabled.
ADALN_BASE_PARAMS_COUNT = 6
ADALN_CROSS_ATTN_PARAMS_COUNT = 9
|
|
|
|
class BasicTransformerBlock(nn.Module):
    """LTX transformer block: self-attention, cross-attention and feed-forward,
    each modulated by ADaLN shift/scale/gate parameters derived from the
    (per-token) timestep embedding.
    """

    def __init__(
        self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.attn_precision = attn_precision
        self.cross_attention_adaln = cross_attention_adaln
        # Self-attention (context_dim=None -> attends over x itself).
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            context_dim=None,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)

        # Cross-attention against the (projected) text context.
        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Rows 0..5: shift/scale/gate for MSA and MLP; rows 6..8 (when enabled)
        # modulate the cross-attention query path.
        num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
        self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype))

        if cross_attention_adaln:
            # Separate shift/scale pair applied to the prompt (K/V) side.
            self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))

    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None):
        # Combine the learned table with the per-token timestep modulation and
        # split into the six MSA/MLP parameters. In-place updates (`x +=`,
        # `addcmul_`) below keep peak memory down.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2)

        # Self-attention on the modulated, RMS-normalized input.
        x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa

        if self.cross_attention_adaln:
            # Rows 6:9 modulate the cross-attention query (shift/scale/gate).
            shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2)
            x += apply_cross_attention_adaln(
                x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca,
                self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options,
            )
        else:
            x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)

        # MLP branch: y = rms_norm(x) * (1 + scale_mlp) + shift_mlp,
        # then x += ff(y) * gate_mlp (all in place).
        y = comfy.ldm.common_dit.rms_norm(x)
        y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
        x.addcmul_(self.ff(y), gate_mlp)

        return x
|
|
|
|
def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
    """Compute a single global prompt timestep for cross-attention ADaLN.

    Uses the max across tokens (matching JAX max_per_segment) and broadcasts
    over text tokens. Returns None when *adaln_module* is None.
    """
    if adaln_module is None:
        return None
    if timestep_scaled.dim() > 1:
        # Per-token timesteps: collapse to one value per batch element.
        ts_input = timestep_scaled.max(dim=1, keepdim=True).values.flatten()
    else:
        ts_input = timestep_scaled.flatten()
    prompt_ts, _ = adaln_module(
        ts_input,
        {"resolution": None, "aspect_ratio": None},
        batch_size=batch_size,
        hidden_dtype=hidden_dtype,
    )
    # Broadcastable over the text-token axis.
    return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1])
|
|
|
|
|
|
def apply_cross_attention_adaln(
    x, context, attn, q_shift, q_scale, q_gate,
    prompt_scale_shift_table, prompt_timestep,
    attention_mask=None, transformer_options={},
):
    """Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV).

    Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so
    that both regular tensors and CompressedTimestep are supported.
    """
    batch_size = x.shape[0]
    # K/V shift/scale come from the prompt table plus the prompt timestep.
    kv_params = (
        prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
        + prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
    )
    shift_kv, scale_kv = kv_params.unbind(dim=2)
    queries = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
    keys_values = context * (1 + scale_kv) + shift_kv
    attended = attn(queries, context=keys_values, mask=attention_mask, transformer_options=transformer_options)
    return attended * q_gate
|
|
|
|
def get_fractional_positions(indices_grid, max_pos):
    """Normalize each positional axis of *indices_grid* by its maximum.

    Input is (B, n_axes, T); output is (B, T, n_axes) with values in
    units of the per-axis maximum position.
    """
    n_pos_dims = indices_grid.shape[1]
    assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
    normalized_axes = [indices_grid[:, axis] / max_pos[axis] for axis in range(n_pos_dims)]
    return torch.stack(normalized_axes, dim=-1)
|
|
|
|
|
|
@functools.lru_cache(maxsize=5)
def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _ = None):
    """Build the RoPE frequency ladder in float64 via numpy, then cast to float32.

    The trailing ``_`` parameter keeps the signature interchangeable with
    generate_freq_grid_pytorch (which takes a device there).
    """
    theta = positional_embedding_theta
    # 2 (cos/sin) entries per positional axis.
    n_elem = 2 * positional_embedding_max_pos_count
    # Exponents run from log_theta(1) == 0 to log_theta(theta) == 1.
    exponents = np.linspace(
        np.log(1) / np.log(theta),
        np.log(theta) / np.log(theta),
        inner_dim // n_elem,
        dtype=np.float64,
    )
    pow_indices = np.power(theta, exponents)
    return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)
|
|
|
|
def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
    """Build the RoPE frequency ladder directly in float32 torch on *device*."""
    theta = positional_embedding_theta
    # 2 (cos/sin) entries per positional axis.
    n_elem = 2 * positional_embedding_max_pos_count

    # Exponents run from log_theta(1) == 0 to log_theta(theta) == 1.
    exponents = torch.linspace(
        math.log(1, theta),
        math.log(theta, theta),
        inner_dim // n_elem,
        device=device,
        dtype=torch.float32,
    )
    indices = (theta ** exponents).to(dtype=torch.float32)
    return indices * math.pi / 2
|
|
|
|
def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
    """Expand the frequency ladder over (normalized) positions.

    A 4-D indices_grid carries (start, end) coordinate pairs in its last dim;
    either their midpoint is used or just the start coordinate.
    """
    if use_middle_indices_grid:
        assert len(indices_grid.shape) == 4 and indices_grid.shape[-1] == 2
        grid_start, grid_end = indices_grid[..., 0], indices_grid[..., 1]
        indices_grid = (grid_start + grid_end) / 2.0
    elif len(indices_grid.shape) == 4:
        indices_grid = indices_grid[..., 0]

    # Normalize positions per axis and move frequencies onto the same device.
    fractional_positions = get_fractional_positions(indices_grid, max_pos)
    indices = indices.to(device=fractional_positions.device)

    # Map positions into [-1, 1], scale by each frequency, then flatten the
    # (axis, freq) dimensions together.
    centered = fractional_positions.unsqueeze(-1) * 2 - 1
    return (indices * centered).transpose(-1, -2).flatten(2)
|
|
|
|
def interleaved_freqs_cis(freqs, pad_size):
    """Build interleaved cos/sin tables; each frequency is duplicated for its pair.

    Padded leading positions get the identity rotation (cos=1, sin=0).
    """
    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
    if pad_size != 0:
        cos_freq = torch.cat([torch.ones_like(cos_freq[:, :, :pad_size]), cos_freq], dim=-1)
        sin_freq = torch.cat([torch.zeros_like(sin_freq[:, :, :pad_size]), sin_freq], dim=-1)
    return cos_freq, sin_freq
|
|
|
|
def split_freqs_cis(freqs, pad_size, num_attention_heads):
    """Build split-layout cos/sin tables shaped for multi-head attention.

    Padded leading positions get the identity rotation (cos=1, sin=0); the
    result is reshaped to (B, H, T, head_dim // 2).
    """
    cos_freq, sin_freq = freqs.cos(), freqs.sin()

    if pad_size != 0:
        cos_freq = torch.cat([torch.ones_like(cos_freq[:, :, :pad_size]), cos_freq], dim=-1)
        sin_freq = torch.cat([torch.zeros_like(sin_freq[:, :, :pad_size]), sin_freq], dim=-1)

    # Reshape freqs to be compatible with multi-head attention.
    batch, tokens, half_total = cos_freq.shape
    head_half = half_total // num_attention_heads
    cos_freq = cos_freq.reshape(batch, tokens, num_attention_heads, head_half).transpose(1, 2)  # (B,H,T,D//2)
    sin_freq = sin_freq.reshape(batch, tokens, num_attention_heads, head_half).transpose(1, 2)  # (B,H,T,D//2)
    return cos_freq, sin_freq
|
|
|
|
class LTXBaseModel(torch.nn.Module, ABC):
|
|
"""
|
|
Abstract base class for LTX models (Lightricks Transformer models).
|
|
|
|
This class defines the common interface and shared functionality for all LTX models,
|
|
including LTXV (video) and LTXAV (audio-video) variants.
|
|
"""
|
|
|
|
def __init__(
    self,
    in_channels: int,
    cross_attention_dim: int,
    attention_head_dim: int,
    num_attention_heads: int,
    caption_channels: int,
    num_layers: int,
    positional_embedding_theta: float = 10000.0,
    positional_embedding_max_pos: list = [20, 2048, 2048],
    causal_temporal_positioning: bool = False,
    vae_scale_factors: tuple = (8, 32, 32),
    use_middle_indices_grid=False,
    timestep_scale_multiplier = 1000.0,
    caption_proj_before_connector=False,
    cross_attention_adaln=False,
    caption_projection_first_linear=True,
    dtype=None,
    device=None,
    operations=None,
    **kwargs,
):
    """Store shared configuration, then build components via the template
    methods (_init_common_components, then the subclass hooks
    _init_model_components / _init_transformer_blocks / _init_output_components).

    **kwargs may carry "rope_type" and "frequencies_precision" (consumed by
    LTXRopeType / LTXFrequenciesPrecision below) and is forwarded to the
    subclass initialization hooks.
    """
    super().__init__()
    self.generator = None
    self.vae_scale_factors = vae_scale_factors
    self.use_middle_indices_grid = use_middle_indices_grid
    self.dtype = dtype
    self.in_channels = in_channels
    self.cross_attention_dim = cross_attention_dim
    self.attention_head_dim = attention_head_dim
    self.num_attention_heads = num_attention_heads
    self.caption_channels = caption_channels
    self.num_layers = num_layers
    self.positional_embedding_theta = positional_embedding_theta
    self.positional_embedding_max_pos = positional_embedding_max_pos
    # RoPE layout (interleaved vs split) read from kwargs.
    self.split_positional_embedding = LTXRopeType.from_dict(kwargs)
    # Frequency-grid builder: float64 numpy path or float32 torch path.
    self.freq_grid_generator = (
        generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64
        else generate_freq_grid_pytorch
    )
    self.causal_temporal_positioning = causal_temporal_positioning
    self.operations = operations
    self.timestep_scale_multiplier = timestep_scale_multiplier
    self.caption_proj_before_connector = caption_proj_before_connector
    self.cross_attention_adaln = cross_attention_adaln
    self.caption_projection_first_linear = caption_projection_first_linear

    # Common dimensions
    self.inner_dim = num_attention_heads * attention_head_dim
    self.out_channels = in_channels

    # Initialize common components
    self._init_common_components(device, dtype)

    # Initialize model-specific components
    self._init_model_components(device, dtype, **kwargs)

    # Initialize transformer blocks
    self._init_transformer_blocks(device, dtype, **kwargs)

    # Initialize output components
    self._init_output_components(device, dtype)
|
|
|
|
def _init_common_components(self, device, dtype):
    """Initialize components common to all LTX models:
    - patchify_proj: linear projection for patchifying the input latents
    - adaln_single (and optional prompt_adaln_single): ADaLN timestep embedders
    - caption_projection: projection for the caption embedding
    """
    self.patchify_proj = self.operations.Linear(
        self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
    )

    # Wider modulation output when cross-attention ADaLN is enabled.
    if self.cross_attention_adaln:
        embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT
    else:
        embedding_coefficient = ADALN_BASE_PARAMS_COUNT
    self.adaln_single = AdaLayerNormSingle(
        self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
    )

    # Separate two-vector (shift/scale) embedder for the prompt side.
    self.prompt_adaln_single = (
        AdaLayerNormSingle(
            self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
        )
        if self.cross_attention_adaln
        else None
    )

    if not self.caption_proj_before_connector:
        self.caption_projection = PixArtAlphaTextProjection(
            in_features=self.caption_channels,
            hidden_size=self.inner_dim,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )
    elif self.caption_projection_first_linear:
        self.caption_projection = NormSingleLinearTextProjection(
            in_features=self.caption_channels,
            hidden_size=self.inner_dim,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )
    else:
        # Projection handled upstream; pass captions through unchanged.
        self.caption_projection = lambda a: a
|
|
|
|
@abstractmethod
def _init_model_components(self, device, dtype, **kwargs):
    """Initialize model-specific components. Must be implemented by subclasses
    (e.g. the LTXV / LTXAV variants mentioned in the class docstring).
    """
    pass
|
|
|
|
@abstractmethod
def _init_transformer_blocks(self, device, dtype, **kwargs):
    """Initialize the stack of transformer blocks. Must be implemented by subclasses."""
    pass
|
|
|
|
@abstractmethod
def _init_output_components(self, device, dtype):
    """Initialize output (unpatchify/projection) components. Must be implemented by subclasses."""
    pass
|
|
|
|
@abstractmethod
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
    """Process input data into token form; returns (x, pixel_coords, additional_args)
    as consumed by _forward. Must be implemented by subclasses.
    """
    pass
|
|
|
|
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
|
|
"""Build self-attention mask for per-guide attention attenuation.
|
|
|
|
Base implementation returns None (no attenuation). Subclasses that
|
|
support guide-based attention control should override this.
|
|
"""
|
|
return None
|
|
|
|
@abstractmethod
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, self_attention_mask=None, **kwargs):
    """Run the token stream through the transformer blocks. Must be implemented by subclasses."""
    pass
|
|
|
|
@abstractmethod
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
    """Project tokens back to output latents (unpatchify). Must be implemented by subclasses."""
    pass
|
|
|
|
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
    """Prepare timestep embeddings.

    Returns (timestep modulation params, embedded timestep, prompt timestep);
    the prompt timestep is None unless cross-attention ADaLN is enabled.
    """
    grid_mask = kwargs.get("grid_mask", None)
    if grid_mask is not None:
        # Keep only the tokens selected by the grid mask.
        timestep = timestep[:, grid_mask]

    timestep_scaled = timestep * self.timestep_scale_multiplier
    timestep, embedded_timestep = self.adaln_single(
        timestep_scaled.flatten(),
        {"resolution": None, "aspect_ratio": None},
        batch_size=batch_size,
        hidden_dtype=hidden_dtype,
    )

    # Second dimension is 1 or number of tokens (if timestep_per_token)
    timestep = timestep.view(batch_size, -1, timestep.shape[-1])
    embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])

    # None when self.prompt_adaln_single is None (no cross-attention ADaLN).
    prompt_timestep = compute_prompt_timestep(
        self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
    )

    return timestep, embedded_timestep, prompt_timestep
|
|
|
|
def _prepare_context(self, context, batch_size, x, attention_mask=None):
|
|
"""Prepare context for transformer blocks."""
|
|
if self.caption_proj_before_connector is False:
|
|
context = self.caption_projection(context)
|
|
|
|
context = context.view(batch_size, -1, x.shape[-1])
|
|
return context, attention_mask
|
|
|
|
def _precompute_freqs_cis(
    self,
    indices_grid,
    dim,
    out_dtype,
    theta=10000.0,
    max_pos=[20, 2048, 2048],
    use_middle_indices_grid=False,
    num_attention_heads=32,
):
    """Compute RoPE cos/sin tables for the given coordinate grid.

    Returns (cos, sin, split_mode) in *out_dtype*; split_mode tells
    apply_rotary_emb which layout to use.
    """
    split_mode = self.split_positional_embedding == LTXRopeType.SPLIT
    indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
    freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)

    if split_mode:
        # Pad up to half the head-dim budget before the per-head reshape.
        pad_size = dim // 2 - freqs.shape[-1]
        cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
    else:
        # 2 because of cos and sin by 3 for (t, x, y), 1 for temporal only.
        cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % (2 * indices_grid.shape[1]))
    return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode
|
|
|
|
def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
    """Build RoPE tables from patch pixel coordinates.

    NOTE(review): ``.to(torch.float32)`` only copies when the input is not
    already float32; in that case the temporal rescale below writes into the
    caller's tensor — behavior preserved as-is, confirm upstream dtype.
    """
    fractional_coords = pixel_coords.to(torch.float32)
    # Convert the temporal axis from frame index to seconds.
    fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
    return self._precompute_freqs_cis(
        fractional_coords,
        dim=self.inner_dim,
        out_dtype=x_dtype,
        max_pos=self.positional_embedding_max_pos,
        use_middle_indices_grid=self.use_middle_indices_grid,
        num_attention_heads=self.num_attention_heads,
    )
|
|
|
|
def _prepare_attention_mask(self, attention_mask, x_dtype):
|
|
"""Prepare attention mask."""
|
|
if attention_mask is not None and not torch.is_floating_point(attention_mask):
|
|
attention_mask = (attention_mask - 1).to(x_dtype).reshape(
|
|
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
|
) * torch.finfo(x_dtype).max
|
|
return attention_mask
|
|
|
|
def forward(
    self, x, timestep, context, attention_mask=None, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
):
    """
    Forward pass for LTX models.

    Args:
        x: Input tensor
        timestep: Timestep tensor
        context: Context tensor (e.g., text embeddings)
        attention_mask: Attention mask tensor (optional)
        frame_rate: Frame rate for temporal processing
        transformer_options: Additional options for transformer blocks
        keyframe_idxs: Keyframe indices for temporal processing
        **kwargs: Additional keyword arguments

    Returns:
        Processed output tensor
    """
    # Route through the patcher-extension wrapper machinery so registered
    # DIFFUSION_MODEL wrappers can intercept the call.
    wrappers = comfy.patcher_extension.get_all_wrappers(
        comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
    )
    executor = comfy.patcher_extension.WrapperExecutor.new_class_executor(
        self._forward, self, wrappers
    )
    return executor.execute(
        x, timestep, context, attention_mask, frame_rate,
        transformer_options, keyframe_idxs,
        denoise_mask=denoise_mask, **kwargs,
    )
|
|
|
|
def _forward(
    self, x, timestep, context, attention_mask=None, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
):
    """
    Internal forward pass for LTX models.

    Patchifies the input, prepares timestep/caption conditioning, builds the
    attention masks and rotary embeddings, runs the transformer stack, and
    unpatchifies the result.

    Args:
        x: Input tensor (or list of tensors for the AV variant)
        timestep: Timestep tensor
        context: Context tensor (e.g., text embeddings)
        attention_mask: Attention mask tensor
        frame_rate: Frame rate for temporal processing
        transformer_options: Additional options for transformer blocks
        keyframe_idxs: Keyframe indices for temporal processing
        **kwargs: Additional keyword arguments

    Returns:
        Processed output tensor
    """
    # The AV model passes x as a list; derive dtype/batch from the first entry.
    first = x[0] if isinstance(x, list) else x
    dtype_in = first.dtype
    bsz = first.shape[0]

    # Patchify and collect bookkeeping needed to restore the shape later.
    extra = {**transformer_options, **kwargs}
    x, pixel_coords, added_args = self._process_input(x, keyframe_idxs, denoise_mask, **extra)
    extra.update(added_args)

    # Timestep and caption conditioning.
    timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, bsz, dtype_in, **extra)
    extra["prompt_timestep"] = prompt_timestep
    context, attention_mask = self._prepare_context(context, bsz, x, attention_mask)

    # Cross-attention bias and rotary positional embeddings.
    attention_mask = self._prepare_attention_mask(attention_mask, dtype_in)
    pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, dtype_in)

    # Optional self-attention bias attenuating noisy <-> guide attention.
    self_attention_mask = self._build_guide_self_attention_mask(
        x, transformer_options, extra
    )

    # Transformer stack.
    x = self._process_transformer_blocks(
        x, context, attention_mask, timestep, pe,
        transformer_options=transformer_options,
        self_attention_mask=self_attention_mask,
        **extra,
    )

    # Final modulation, projection, and unpatchify.
    return self._process_output(x, embedded_timestep, keyframe_idxs, **extra)
|
|
|
|
|
|
class LTXVModel(LTXBaseModel):
|
|
"""LTXV model for video generation."""
|
|
|
|
def __init__(
    self,
    in_channels=128,
    cross_attention_dim=2048,
    attention_head_dim=64,
    num_attention_heads=32,
    caption_channels=4096,
    num_layers=28,
    positional_embedding_theta=10000.0,
    positional_embedding_max_pos=[20, 2048, 2048],
    causal_temporal_positioning=False,
    vae_scale_factors=(8, 32, 32),
    use_middle_indices_grid=False,
    timestep_scale_multiplier=1000.0,
    caption_proj_before_connector=False,
    cross_attention_adaln=False,
    dtype=None,
    device=None,
    operations=None,
    **kwargs,
):
    """Construct an LTXV model; all configuration is delegated to LTXBaseModel."""
    base_config = dict(
        in_channels=in_channels,
        cross_attention_dim=cross_attention_dim,
        attention_head_dim=attention_head_dim,
        num_attention_heads=num_attention_heads,
        caption_channels=caption_channels,
        num_layers=num_layers,
        positional_embedding_theta=positional_embedding_theta,
        positional_embedding_max_pos=positional_embedding_max_pos,
        causal_temporal_positioning=causal_temporal_positioning,
        vae_scale_factors=vae_scale_factors,
        use_middle_indices_grid=use_middle_indices_grid,
        timestep_scale_multiplier=timestep_scale_multiplier,
        caption_proj_before_connector=caption_proj_before_connector,
        cross_attention_adaln=cross_attention_adaln,
        dtype=dtype,
        device=device,
        operations=operations,
    )
    super().__init__(**base_config, **kwargs)
|
|
|
|
def _init_model_components(self, device, dtype, **kwargs):
    """Initialize LTXV-specific components.

    The base LTXV variant needs nothing beyond what LTXBaseModel already
    sets up, so this hook is intentionally a no-op; subclasses/variants
    may override it.
    """
    pass
|
|
|
|
def _init_transformer_blocks(self, device, dtype, **kwargs):
    """Create the stack of num_layers BasicTransformerBlock layers."""
    blocks = []
    for _ in range(self.num_layers):
        blocks.append(
            BasicTransformerBlock(
                self.inner_dim,
                self.num_attention_heads,
                self.attention_head_dim,
                context_dim=self.cross_attention_dim,
                cross_attention_adaln=self.cross_attention_adaln,
                dtype=dtype,
                device=device,
                operations=self.operations,
            )
        )
    self.transformer_blocks = nn.ModuleList(blocks)
|
|
|
|
def _init_output_components(self, device, dtype):
    """Set up the output head: AdaLN table, final norm, projection, patchifier."""
    ops = self.operations
    # Shift/scale table applied together with the embedded timestep in
    # _process_output.
    self.scale_shift_table = nn.Parameter(
        torch.empty(2, self.inner_dim, dtype=dtype, device=device)
    )
    self.norm_out = ops.LayerNorm(
        self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
    )
    self.proj_out = ops.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
    self.patchifier = SymmetricPatchifier(1, start_end=True)
|
|
|
|
def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
    """Patchify the input and resolve keyframe/guide token bookkeeping.

    Args:
        x: (B, C, F, H, W) latent tensor.
        keyframe_idxs: Optional keyframe coordinate tensor; when present,
            tokens fully covered by the denoise mask are filtered out and
            keyframe coordinates are written into pixel_coords.
        denoise_mask: Mask used to decide which patchified tokens survive.
        **kwargs: May contain "guide_attention_entries" describing guide
            references (each with a "pre_filter_count").

    Returns:
        Tuple (x, pixel_coords, additional_args) where additional_args
        carries the shapes/masks needed by _process_output and the
        resolved guide entries for _build_guide_self_attention_mask.
    """
    additional_args = {"orig_shape": list(x.shape)}
    x, latent_coords = self.patchifier.patchify(x)
    pixel_coords = latent_to_pixel_coords(
        latent_coords=latent_coords,
        scale_factors=self.vae_scale_factors,
        causal_fix=self.causal_temporal_positioning,
    )

    grid_mask = None
    if keyframe_idxs is not None:
        # Remember the pre-filter patchified shape so the output can be
        # scattered back into place.
        additional_args.update({ "orig_patchified_shape": list(x.shape)})
        denoise_mask = self.patchifier.patchify(denoise_mask)[0]
        # A token survives unless any of its mask values is negative.
        grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
        additional_args.update({"grid_mask": grid_mask})
        x = x[:, grid_mask, :]
        pixel_coords = pixel_coords[:, :, grid_mask, ...]

        # Portion of the grid mask that covers the keyframe (guide) tokens,
        # which sit at the end of the sequence.
        kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]

        # Compute per-guide surviving token counts from guide_attention_entries.
        # Each entry tracks one guide reference; they are appended in order and
        # their pre_filter_counts partition the kf_grid_mask.
        guide_entries = kwargs.get("guide_attention_entries", None)
        if guide_entries:
            total_pfc = sum(e["pre_filter_count"] for e in guide_entries)
            if total_pfc != len(kf_grid_mask):
                raise ValueError(
                    f"guide pre_filter_counts ({total_pfc}) != "
                    f"keyframe grid mask length ({len(kf_grid_mask)})"
                )
            resolved_entries = []
            offset = 0
            for entry in guide_entries:
                pfc = entry["pre_filter_count"]
                # Slice of the keyframe mask belonging to this guide.
                entry_mask = kf_grid_mask[offset:offset + pfc]
                surviving = int(entry_mask.sum().item())
                resolved_entries.append({
                    **entry,
                    "surviving_count": surviving,
                })
                offset += pfc
            additional_args["resolved_guide_entries"] = resolved_entries

        # Keep only surviving keyframe coordinates and overwrite the tail of
        # pixel_coords with them (in place).
        keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
        pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs

        # Total surviving guide tokens (all guides)
        additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]

    x = self.patchify_proj(x)
    return x, pixel_coords, additional_args
|
|
|
|
def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
    """Build self-attention mask for per-guide attention attenuation.

    Reads resolved_guide_entries from merged_args (computed in _process_input)
    to build a log-space additive bias mask that attenuates noisy ↔ guide
    attention for each guide reference independently.

    Args:
        x: Token tensor (B, tokens, dim), or a [video, audio] list for the
            AV model.
        transformer_options: Unused here; kept for signature parity.
        merged_args: Dict expected to carry "num_guide_tokens" and
            "resolved_guide_entries" (see _process_input).

    Returns None if no attenuation is needed (all strengths == 1.0 and no
    spatial masks, or no guide tokens).
    """
    if isinstance(x, list):
        # AV model: x = [vx, ax]; use vx for token count and device
        total_tokens = x[0].shape[1]
        device = x[0].device
        dtype = x[0].dtype
    else:
        total_tokens = x.shape[1]
        device = x.device
        dtype = x.dtype

    # Early-out cascade: no guide tokens, no entries, or no entry actually
    # requesting attenuation all mean "no mask".
    num_guide_tokens = merged_args.get("num_guide_tokens", 0)
    if num_guide_tokens == 0:
        return None

    resolved_entries = merged_args.get("resolved_guide_entries", None)
    if not resolved_entries:
        return None

    # Check if any attenuation is actually needed
    needs_attenuation = any(
        e["strength"] < 1.0 or e.get("pixel_mask") is not None
        for e in resolved_entries
    )
    if not needs_attenuation:
        return None

    # Build per-guide-token weights for all tracked guide tokens.
    # Guides are appended in order at the end of the sequence.
    guide_start = total_tokens - num_guide_tokens
    all_weights = []
    total_tracked = 0

    for entry in resolved_entries:
        surviving = entry["surviving_count"]
        if surviving == 0:
            continue

        strength = entry["strength"]
        pixel_mask = entry.get("pixel_mask")
        latent_shape = entry.get("latent_shape")

        if pixel_mask is not None and latent_shape is not None:
            # Spatially-varying weight: downsample the pixel mask to one
            # weight per latent token, then scale by the guide's strength.
            f_lat, h_lat, w_lat = latent_shape
            per_token = self._downsample_mask_to_latent(
                pixel_mask.to(device=device, dtype=dtype),
                f_lat, h_lat, w_lat,
            )
            # per_token shape: (B, f_lat*h_lat*w_lat).
            # Collapse batch dim — the mask is assumed identical across the
            # batch; validate and take the first element to get (1, tokens).
            if per_token.shape[0] > 1:
                ref = per_token[0]
                for bi in range(1, per_token.shape[0]):
                    if not torch.equal(ref, per_token[bi]):
                        logger.warning(
                            "pixel_mask differs across batch elements; "
                            "using first element only."
                        )
                        break
                per_token = per_token[:1]
            # `surviving` is the post-grid_mask token count.
            # Clamp to surviving to handle any mismatch safely.
            n_weights = min(per_token.shape[1], surviving)
            weights = per_token[:, :n_weights] * strength  # (1, n_weights)
        else:
            # Uniform weight for every surviving token of this guide.
            weights = torch.full(
                (1, surviving), strength, device=device, dtype=dtype
            )

        all_weights.append(weights)
        total_tracked += weights.shape[1]

    if not all_weights:
        return None

    # Concatenate per-token weights for all tracked guides
    tracked_weights = torch.cat(all_weights, dim=1)  # (1, total_tracked)

    # Check if any weight is actually < 1.0 (otherwise no attenuation needed)
    if (tracked_weights >= 1.0).all():
        return None

    # Build the mask: guide tokens are at the end of the sequence.
    # Tracked guides come first (in order), untracked follow.
    return self._build_self_attention_mask(
        total_tokens, num_guide_tokens, total_tracked,
        tracked_weights, guide_start, device, dtype,
    )
|
|
|
|
@staticmethod
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
    """Downsample a pixel-space mask to per-token latent weights.

    Args:
        mask: (B, 1, F_pix, H_pix, W_pix) pixel-space mask with values in [0, 1].
        f_lat: Number of latent frames (pre-dilation original count).
        h_lat: Latent height (pre-dilation original height).
        w_lat: Latent width (pre-dilation original width).

    Returns:
        (B, F_lat * H_lat * W_lat) flattened per-token weights.
    """
    batch = mask.shape[0]
    f_pix = mask.shape[2]

    # Spatial reduction: area-average each frame down to the latent grid.
    frames = rearrange(mask, "b 1 f h w -> (b f) 1 h w")
    frames = torch.nn.functional.interpolate(
        frames,
        size=(h_lat, w_lat),
        mode="area",
    )
    shrunk = rearrange(frames, "(b f) 1 h w -> b 1 f h w", b=batch)

    # Temporal reduction: the first pixel frame maps 1:1 to the first latent
    # frame (causal temporal structure); the remaining frames are handled in
    # groups below.
    head = shrunk[:, :, :1, :, :]
    if f_pix > 1 and f_lat > 1:
        remaining_pix = f_pix - 1
        remaining_lat = f_lat - 1
        group = remaining_pix // remaining_lat
        if group < 1:
            # Fewer pixel frames than latent frames — upsample by repeating
            # the available pixel frames via nearest interpolation.
            flat = rearrange(
                shrunk[:, :, 1:, :, :],
                "b 1 f h w -> (b h w) 1 f",
            )
            flat = torch.nn.functional.interpolate(
                flat, size=remaining_lat, mode="nearest",
            )
            tail = rearrange(
                flat, "(b h w) 1 f -> b 1 f h w",
                b=batch, h=h_lat, w=w_lat,
            )
        else:
            # Trim trailing pixel frames that don't fill a complete group,
            # then mean-pool each group into one latent frame.
            usable = remaining_lat * group
            tail = rearrange(
                shrunk[:, :, 1:1 + usable, :, :],
                "b 1 (f t) h w -> b 1 f t h w",
                t=group,
            )
            tail = tail.mean(dim=3)
        latent_mask = torch.cat([head, tail], dim=2)
    elif f_lat > 1:
        # Single pixel frame but multiple latent frames — repeat the
        # single frame across all latent frames.
        latent_mask = head.expand(-1, -1, f_lat, -1, -1)
    else:
        latent_mask = head

    return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
|
|
|
|
@staticmethod
|
|
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
|
|
tracked_weights, guide_start, device, dtype):
|
|
"""Build a log-space additive self-attention bias mask.
|
|
|
|
Attenuates attention between noisy tokens and tracked guide tokens.
|
|
Untracked guide tokens (at the end of the guide portion) keep full attention.
|
|
|
|
Args:
|
|
total_tokens: Total sequence length.
|
|
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
|
|
tracked_count: Number of tracked guide tokens (first in the guide portion).
|
|
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
|
|
guide_start: Index where guide tokens begin in the sequence.
|
|
device: Target device.
|
|
dtype: Target dtype.
|
|
|
|
Returns:
|
|
(1, 1, total_tokens, total_tokens) additive bias mask.
|
|
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
|
|
"""
|
|
finfo = torch.finfo(dtype)
|
|
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
|
|
tracked_end = guide_start + tracked_count
|
|
|
|
# Convert weights to log-space bias
|
|
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
|
|
log_w = torch.full_like(w, finfo.min)
|
|
positive_mask = w > 0
|
|
if positive_mask.any():
|
|
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
|
|
|
|
# noisy → tracked guides: each noisy row gets the same per-guide weight
|
|
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
|
|
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
|
|
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
|
|
|
|
return mask
|
|
|
|
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
    """Run x through every transformer block, honoring "dit" patch replacements."""
    replacements = transformer_options.get("patches_replace", {}).get("dit", {})
    prompt_timestep = kwargs.get("prompt_timestep", None)

    for i, block in enumerate(self.transformer_blocks):
        key = ("double_block", i)
        if key in replacements:
            # Patched path: hand the replacement a callable that runs the
            # original block with the packed argument dict.
            def run_original(args):
                result = block(
                    args["img"],
                    context=args["txt"],
                    attention_mask=args["attention_mask"],
                    timestep=args["vec"],
                    pe=args["pe"],
                    transformer_options=args["transformer_options"],
                    self_attention_mask=args.get("self_attention_mask"),
                    prompt_timestep=args.get("prompt_timestep"),
                )
                return {"img": result}

            packed = {
                "img": x,
                "txt": context,
                "attention_mask": attention_mask,
                "vec": timestep,
                "pe": pe,
                "transformer_options": transformer_options,
                "self_attention_mask": self_attention_mask,
                "prompt_timestep": prompt_timestep,
            }
            x = replacements[key](packed, {"original_block": run_original})["img"]
        else:
            x = block(
                x,
                context=context,
                attention_mask=attention_mask,
                timestep=timestep,
                pe=pe,
                transformer_options=transformer_options,
                self_attention_mask=self_attention_mask,
                prompt_timestep=prompt_timestep,
            )

    return x
|
|
|
|
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
    """Apply final AdaLN modulation, project to out_channels, and unpatchify."""
    # Combine the learned shift/scale table with the per-token timestep
    # embedding.
    table = self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
    modulation = table + embedded_timestep[:, :, None]
    shift = modulation[:, :, 0]
    scale = modulation[:, :, 1]

    x = self.norm_out(x)
    x = x * (1 + scale) + shift
    x = self.proj_out(x)

    if keyframe_idxs is not None:
        # Scatter surviving tokens back into the full pre-filter patchified
        # layout; filtered-out positions stay zero.
        grid_mask = kwargs["grid_mask"]
        restored = torch.zeros(
            kwargs["orig_patchified_shape"], dtype=x.dtype, device=x.device
        )
        restored[:, grid_mask, :] = x
        x = restored

    # Unpatchify to restore original dimensions
    orig_shape = kwargs["orig_shape"]
    return self.patchifier.unpatchify(
        latents=x,
        output_height=orig_shape[3],
        output_width=orig_shape[4],
        output_num_frames=orig_shape[2],
        out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
    )
|