from abc import ABC, abstractmethod
from enum import Enum
import functools
import math
from typing import Dict, Optional, Tuple

from einops import rearrange
import numpy as np
import torch
from torch import nn

import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit

from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords


def _log_base(x, base):
    return np.log(x) / np.log(base)


class LTXRopeType(str, Enum):
    INTERLEAVED = "interleaved"
    SPLIT = "split"

    KEY = "rope_type"

    @classmethod
    def from_dict(cls, kwargs, default=None):
        if default is None:
            default = cls.INTERLEAVED
        return cls(kwargs.get(cls.KEY, default))


class LTXFrequenciesPrecision(str, Enum):
    FLOAT32 = "float32"
    FLOAT64 = "float64"

    KEY = "frequencies_precision"

    @classmethod
    def from_dict(cls, kwargs, default=None):
        if default is None:
            default = cls.FLOAT32
        return cls(kwargs.get(cls.KEY, default))


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Create sinusoidal timestep embeddings, matching the implementation in
    Denoising Diffusion Probabilistic Models.

    Args:
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions.
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the maximum frequency of the embeddings.

    Returns:
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1-D tensor"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


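# Illustrative usage (a sketch, not part of the original module):
#   t = torch.tensor([0.0, 500.0, 999.0])
#   emb = get_timestep_embedding(t, embedding_dim=256)
#   emb.shape  # -> torch.Size([3, 256]); first half sin, second half cos

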
class TimestepEmbedding(nn.Module):
    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()

        self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)

        if cond_proj_dim is not None:
            self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
        else:
            self.cond_proj = None

        self.act = nn.SiLU()

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = operations.Linear(
            time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device
        )

        if post_act_fn is None:
            self.post_act = None
        # else:
        #     self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
        return t_emb


class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
    """
    For PixArt-Alpha.

    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
    """

    def __init__(
        self,
        embedding_dim,
        size_emb_dim,
        use_additional_conditions: bool = False,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(
            in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations
        )

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
        return timesteps_emb


class AdaLayerNormSingle(nn.Module):
    r"""
    Norm layer adaptive layer norm single (adaLN-single).

    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
    """

    def __init__(
        self, embedding_dim: int, embedding_coefficient: int = 6, use_additional_conditions: bool = False, dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
            embedding_dim,
            size_emb_dim=embedding_dim // 3,
            use_additional_conditions=use_additional_conditions,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.silu = nn.SiLU()
        self.linear = operations.Linear(embedding_dim, embedding_coefficient * embedding_dim, bias=True, dtype=dtype, device=device)

    def forward(
        self,
        timestep: torch.Tensor,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        batch_size: Optional[int] = None,
        hidden_dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # No modulation happening here.
        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
        return self.linear(self.silu(embedded_timestep)), embedded_timestep


class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(
        self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None
    ):
        super().__init__()
        if out_features is None:
            out_features = hidden_size
        self.linear_1 = operations.Linear(
            in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device
        )
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = operations.Linear(
            in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device
        )

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class GELU_approx(nn.Module):
    def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
        super().__init__()
        self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)

    def forward(self, x):
        return torch.nn.functional.gelu(self.proj(x), approximate="tanh")


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        inner_dim = int(dim * mult)
        project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)

        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
        )

    def forward(self, x):
        return self.net(x)


def apply_rotary_emb(input_tensor, freqs_cis):
    cos_freqs, sin_freqs = freqs_cis[0], freqs_cis[1]
    split_pe = freqs_cis[2] if len(freqs_cis) > 2 else False
    return (
        apply_split_rotary_emb(input_tensor, cos_freqs, sin_freqs)
        if split_pe else
        apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs)
    )


def apply_interleaved_rotary_emb(input_tensor, cos_freqs, sin_freqs):  # TODO: remove duplicate funcs and pick the best/fastest one
    t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
    t1, t2 = t_dup.unbind(dim=-1)
    t_dup = torch.stack((-t2, t1), dim=-1)
    input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")

    out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs

    return out


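# Worked example (illustrative): for one interleaved pair (x1, x2) with angle a,
# the code above computes (x1 * cos a - x2 * sin a, x2 * cos a + x1 * sin a),
# i.e. a 2-D rotation by a, which is the standard RoPE formulation.

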
def apply_split_rotary_emb(input_tensor, cos, sin):
    needs_reshape = False
    if input_tensor.ndim != 4 and cos.ndim == 4:
        B, H, T, _ = cos.shape
        input_tensor = input_tensor.reshape(B, T, H, -1).swapaxes(1, 2)
        needs_reshape = True
    split_input = rearrange(input_tensor, "... (d r) -> ... d r", d=2)
    first_half_input = split_input[..., :1, :]
    second_half_input = split_input[..., 1:, :]
    output = split_input * cos.unsqueeze(-2)
    first_half_output = output[..., :1, :]
    second_half_output = output[..., 1:, :]
    first_half_output.addcmul_(-sin.unsqueeze(-2), second_half_input)
    second_half_output.addcmul_(sin.unsqueeze(-2), first_half_input)
    output = rearrange(output, "... d r -> ... (d r)")
    return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output


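# Illustrative usage (a sketch): with split-style frequencies of shape (B, H, T, D // 2),
# a flat (B, T, H * D) tensor is reshaped to (B, H, T, D), rotated per half of D as
# (first * cos - sin * second, second * cos + sin * first), then flattened back.
#   q = torch.randn(1, 64, 32 * 64)     # (B, T, H * D), hypothetical sizes
#   cos = torch.ones(1, 32, 64, 32)     # (B, H, T, D // 2)
#   sin = torch.zeros_like(cos)
#   assert torch.allclose(apply_split_rotary_emb(q, cos, sin), q)  # zero angle is identity

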
class CrossAttention(nn.Module):
    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        attn_precision=None,
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = query_dim if context_dim is None else context_dim
        self.attn_precision = attn_precision

        self.heads = heads
        self.dim_head = dim_head

        self.q_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)
        self.k_norm = operations.RMSNorm(inner_dim, eps=1e-5, dtype=dtype, device=device)

        self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
        self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)

        self.to_out = nn.Sequential(
            operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_options={}):
        q = self.to_q(x)
        context = x if context is None else context
        k = self.to_k(context)
        v = self.to_v(context)

        q = self.q_norm(q)
        k = self.k_norm(k)

        if pe is not None:
            q = apply_rotary_emb(q, pe)
            k = apply_rotary_emb(k, pe if k_pe is None else k_pe)

        if mask is None:
            out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
        else:
            out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    def __init__(
        self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
    ):
        super().__init__()

        self.attn_precision = attn_precision
        self.attn1 = CrossAttention(
            query_dim=dim,
            heads=n_heads,
            dim_head=d_head,
            context_dim=None,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)

        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            attn_precision=self.attn_precision,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))

    def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)

        attn1_input = comfy.ldm.common_dit.rms_norm(x)
        attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
        attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
        x.addcmul_(attn1_input, gate_msa)
        del attn1_input

        x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)

        y = comfy.ldm.common_dit.rms_norm(x)
        y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
        x.addcmul_(self.ff(y), gate_mlp)

        return x


def get_fractional_positions(indices_grid, max_pos):
    n_pos_dims = indices_grid.shape[1]
    assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
    fractional_positions = torch.stack(
        [indices_grid[:, i] / max_pos[i] for i in range(n_pos_dims)],
        dim=-1,
    )
    return fractional_positions


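# Illustrative example (a sketch): a coordinate grid of shape (B, 3, T) holding
# per-axis positions (time axis first) is normalized per axis into [0, 1] and
# stacked to shape (B, T, 3):
#   grid = torch.tensor([[[10.0], [1024.0], [512.0]]])   # (1, 3, 1)
#   get_fractional_positions(grid, [20, 2048, 2048])     # -> [[[0.5, 0.5, 0.25]]]

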
@functools.lru_cache(maxsize=5)
def generate_freq_grid_np(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, _=None):
    # The unused trailing argument keeps the signature compatible with
    # generate_freq_grid_pytorch, which takes a device in the same position.
    theta = positional_embedding_theta
    start = 1
    end = theta

    n_elem = 2 * positional_embedding_max_pos_count
    pow_indices = np.power(
        theta,
        np.linspace(
            _log_base(start, theta),
            _log_base(end, theta),
            inner_dim // n_elem,
            dtype=np.float64,
        ),
    )
    return torch.tensor(pow_indices * math.pi / 2, dtype=torch.float32)


def generate_freq_grid_pytorch(positional_embedding_theta, positional_embedding_max_pos_count, inner_dim, device):
    theta = positional_embedding_theta
    start = 1
    end = theta
    n_elem = 2 * positional_embedding_max_pos_count

    indices = theta ** (
        torch.linspace(
            math.log(start, theta),
            math.log(end, theta),
            inner_dim // n_elem,
            device=device,
            dtype=torch.float32,
        )
    )
    indices = indices.to(dtype=torch.float32)

    indices = indices * math.pi / 2

    return indices


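# Note (illustrative): both generators produce the same geometric frequency ladder
# (inner_dim // (2 * n_pos_dims) values spanning [1, theta], scaled by pi / 2); they
# differ only in intermediate precision, numpy float64 vs. torch float32. A sketch:
#   a = generate_freq_grid_np(10000.0, 3, 2048)
#   b = generate_freq_grid_pytorch(10000.0, 3, 2048, device="cpu")
#   torch.allclose(a, b, rtol=1e-4)  # -> True (up to float32 rounding)

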
def generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid):
    if use_middle_indices_grid:
        assert len(indices_grid.shape) == 4 and indices_grid.shape[-1] == 2, \
            "indices_grid must hold (start, end) coordinate pairs in its last dimension"
        indices_grid_start, indices_grid_end = indices_grid[..., 0], indices_grid[..., 1]
        indices_grid = (indices_grid_start + indices_grid_end) / 2.0
    elif len(indices_grid.shape) == 4:
        indices_grid = indices_grid[..., 0]

    # Get fractional positions and compute frequency indices
    fractional_positions = get_fractional_positions(indices_grid, max_pos)
    indices = indices.to(device=fractional_positions.device)

    freqs = (
        (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
        .transpose(-1, -2)
        .flatten(2)
    )
    return freqs


def interleaved_freqs_cis(freqs, pad_size):
    cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
    sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
    if pad_size != 0:
        cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
        sin_padding = torch.zeros_like(cos_freq[:, :, :pad_size])
        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
    return cos_freq, sin_freq


def split_freqs_cis(freqs, pad_size, num_attention_heads):
    cos_freq = freqs.cos()
    sin_freq = freqs.sin()

    if pad_size != 0:
        cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
        sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])

        cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
        sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)

    # Reshape freqs to be compatible with multi-head attention
    B, T, half_HD = cos_freq.shape

    cos_freq = cos_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)
    sin_freq = sin_freq.reshape(B, T, num_attention_heads, half_HD // num_attention_heads)

    cos_freq = torch.swapaxes(cos_freq, 1, 2)  # (B, H, T, D//2)
    sin_freq = torch.swapaxes(sin_freq, 1, 2)  # (B, H, T, D//2)
    return cos_freq, sin_freq


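# Shape walkthrough (illustrative): for inner_dim = H * D, split_freqs_cis maps
# per-token frequencies (B, T, H * D // 2) to per-head tensors (B, H, T, D // 2),
# padding the leading channels with (cos=1, sin=0), i.e. leaving them unrotated:
#   freqs = torch.zeros(1, 64, 1000)   # hypothetical: B=1, T=64
#   cos, sin = split_freqs_cis(freqs, pad_size=24, num_attention_heads=32)
#   cos.shape  # -> torch.Size([1, 32, 64, 32])

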
class LTXBaseModel(torch.nn.Module, ABC):
    """
    Abstract base class for LTX models (Lightricks Transformer models).

    This class defines the common interface and shared functionality for all LTX models,
    including LTXV (video) and LTXAV (audio-video) variants.
    """

    def __init__(
        self,
        in_channels: int,
        cross_attention_dim: int,
        attention_head_dim: int,
        num_attention_heads: int,
        caption_channels: int,
        num_layers: int,
        positional_embedding_theta: float = 10000.0,
        positional_embedding_max_pos: list = [20, 2048, 2048],
        causal_temporal_positioning: bool = False,
        vae_scale_factors: tuple = (8, 32, 32),
        use_middle_indices_grid=False,
        timestep_scale_multiplier=1000.0,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        self.generator = None
        self.vae_scale_factors = vae_scale_factors
        self.use_middle_indices_grid = use_middle_indices_grid
        self.dtype = dtype
        self.in_channels = in_channels
        self.cross_attention_dim = cross_attention_dim
        self.attention_head_dim = attention_head_dim
        self.num_attention_heads = num_attention_heads
        self.caption_channels = caption_channels
        self.num_layers = num_layers
        self.positional_embedding_theta = positional_embedding_theta
        self.positional_embedding_max_pos = positional_embedding_max_pos
        self.split_positional_embedding = LTXRopeType.from_dict(kwargs)
        self.freq_grid_generator = (
            generate_freq_grid_np if LTXFrequenciesPrecision.from_dict(kwargs) == LTXFrequenciesPrecision.FLOAT64
            else generate_freq_grid_pytorch
        )
        self.causal_temporal_positioning = causal_temporal_positioning
        self.operations = operations
        self.timestep_scale_multiplier = timestep_scale_multiplier

        # Common dimensions
        self.inner_dim = num_attention_heads * attention_head_dim
        self.out_channels = in_channels

        # Initialize common components
        self._init_common_components(device, dtype)

        # Initialize model-specific components
        self._init_model_components(device, dtype, **kwargs)

        # Initialize transformer blocks
        self._init_transformer_blocks(device, dtype, **kwargs)

        # Initialize output components
        self._init_output_components(device, dtype)

    def _init_common_components(self, device, dtype):
        """Initialize components common to all LTX models:
        - patchify_proj: Linear projection for patchifying input
        - adaln_single: AdaLN layer for timestep embedding
        - caption_projection: Linear projection for caption embedding
        """
        self.patchify_proj = self.operations.Linear(
            self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
        )

        self.adaln_single = AdaLayerNormSingle(
            self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
        )

        self.caption_projection = PixArtAlphaTextProjection(
            in_features=self.caption_channels,
            hidden_size=self.inner_dim,
            dtype=dtype,
            device=device,
            operations=self.operations,
        )

    @abstractmethod
    def _init_model_components(self, device, dtype, **kwargs):
        """Initialize model-specific components. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def _init_transformer_blocks(self, device, dtype, **kwargs):
        """Initialize transformer blocks. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def _init_output_components(self, device, dtype):
        """Initialize output components. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
        """Process input data. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, **kwargs):
        """Process transformer blocks. Must be implemented by subclasses."""
        pass

    @abstractmethod
    def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
        """Process output data. Must be implemented by subclasses."""
        pass

    def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
        """Prepare timestep embeddings."""
        grid_mask = kwargs.get("grid_mask", None)
        if grid_mask is not None:
            timestep = timestep[:, grid_mask]

        timestep = timestep * self.timestep_scale_multiplier
        timestep, embedded_timestep = self.adaln_single(
            timestep.flatten(),
            {"resolution": None, "aspect_ratio": None},
            batch_size=batch_size,
            hidden_dtype=hidden_dtype,
        )

        # Second dimension is 1 or number of tokens (if timestep_per_token)
        timestep = timestep.view(batch_size, -1, timestep.shape[-1])
        embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])

        return timestep, embedded_timestep

    def _prepare_context(self, context, batch_size, x, attention_mask=None):
        """Prepare context for transformer blocks."""
        if self.caption_projection is not None:
            context = self.caption_projection(context)
            context = context.view(batch_size, -1, x.shape[-1])

        return context, attention_mask

    def _precompute_freqs_cis(
        self,
        indices_grid,
        dim,
        out_dtype,
        theta=10000.0,
        max_pos=[20, 2048, 2048],
        use_middle_indices_grid=False,
        num_attention_heads=32,
    ):
        split_mode = self.split_positional_embedding == LTXRopeType.SPLIT
        indices = self.freq_grid_generator(theta, indices_grid.shape[1], dim, indices_grid.device)
        freqs = generate_freqs(indices, indices_grid, max_pos, use_middle_indices_grid)

        if split_mode:
            expected_freqs = dim // 2
            current_freqs = freqs.shape[-1]
            pad_size = expected_freqs - current_freqs
            cos_freq, sin_freq = split_freqs_cis(freqs, pad_size, num_attention_heads)
        else:
            # n_elem = 2 (cos and sin) times the number of position dimensions:
            # 3 for (t, x, y), or 1 for temporal-only positioning.
            n_elem = 2 * indices_grid.shape[1]
            cos_freq, sin_freq = interleaved_freqs_cis(freqs, dim % n_elem)
        return cos_freq.to(out_dtype), sin_freq.to(out_dtype), split_mode

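    # Note (illustrative): the (cos, sin, split_mode) triple returned above is exactly
    # the `pe` tuple consumed by apply_rotary_emb() in CrossAttention.forward(), so the
    # split/interleaved layout choice travels alongside the precomputed frequencies.
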
    def _prepare_positional_embeddings(self, pixel_coords, frame_rate, x_dtype):
        """Prepare positional embeddings."""
        fractional_coords = pixel_coords.to(torch.float32)
        fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
        pe = self._precompute_freqs_cis(
            fractional_coords,
            dim=self.inner_dim,
            out_dtype=x_dtype,
            max_pos=self.positional_embedding_max_pos,
            use_middle_indices_grid=self.use_middle_indices_grid,
            num_attention_heads=self.num_attention_heads,
        )
        return pe

    def _prepare_attention_mask(self, attention_mask, x_dtype):
        """Prepare attention mask."""
        if attention_mask is not None and not torch.is_floating_point(attention_mask):
            # Convert a {0, 1} keep-mask into an additive bias: kept positions map
            # to 0, masked positions to -finfo(x_dtype).max.
            attention_mask = (attention_mask - 1).to(x_dtype).reshape(
                (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
            ) * torch.finfo(x_dtype).max
        return attention_mask

    def forward(
        self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
    ):
        """
        Forward pass for LTX models.

        Args:
            x: Input tensor
            timestep: Timestep tensor
            context: Context tensor (e.g., text embeddings)
            attention_mask: Attention mask tensor
            frame_rate: Frame rate for temporal processing
            transformer_options: Additional options for transformer blocks
            keyframe_idxs: Keyframe indices for temporal processing
            denoise_mask: Optional mask selecting which tokens are denoised
            **kwargs: Additional keyword arguments

        Returns:
            Processed output tensor
        """
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward,
            self,
            comfy.patcher_extension.get_all_wrappers(
                comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options
            ),
        ).execute(x, timestep, context, attention_mask, frame_rate, transformer_options, keyframe_idxs, denoise_mask=denoise_mask, **kwargs)

    def _forward(
        self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, denoise_mask=None, **kwargs
    ):
        """
        Internal forward pass for LTX models. See forward() for argument details.
        """
        if isinstance(x, list):
            input_dtype = x[0].dtype
            batch_size = x[0].shape[0]
        else:
            input_dtype = x.dtype
            batch_size = x.shape[0]

        # Process input
        merged_args = {**transformer_options, **kwargs}
        x, pixel_coords, additional_args = self._process_input(x, keyframe_idxs, denoise_mask, **merged_args)
        merged_args.update(additional_args)

        # Prepare timestep and context
        timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
        context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)

        # Prepare attention mask and positional embeddings
        attention_mask = self._prepare_attention_mask(attention_mask, input_dtype)
        pe = self._prepare_positional_embeddings(pixel_coords, frame_rate, input_dtype)

        # Process transformer blocks
        x = self._process_transformer_blocks(
            x, context, attention_mask, timestep, pe, transformer_options=transformer_options, **merged_args
        )

        # Process output
        x = self._process_output(x, embedded_timestep, keyframe_idxs, **merged_args)
        return x


class LTXVModel(LTXBaseModel):
    """LTXV model for video generation."""

    def __init__(
        self,
        in_channels=128,
        cross_attention_dim=2048,
        attention_head_dim=64,
        num_attention_heads=32,
        caption_channels=4096,
        num_layers=28,
        positional_embedding_theta=10000.0,
        positional_embedding_max_pos=[20, 2048, 2048],
        causal_temporal_positioning=False,
        vae_scale_factors=(8, 32, 32),
        use_middle_indices_grid=False,
        timestep_scale_multiplier=1000.0,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__(
            in_channels=in_channels,
            cross_attention_dim=cross_attention_dim,
            attention_head_dim=attention_head_dim,
            num_attention_heads=num_attention_heads,
            caption_channels=caption_channels,
            num_layers=num_layers,
            positional_embedding_theta=positional_embedding_theta,
            positional_embedding_max_pos=positional_embedding_max_pos,
            causal_temporal_positioning=causal_temporal_positioning,
            vae_scale_factors=vae_scale_factors,
            use_middle_indices_grid=use_middle_indices_grid,
            timestep_scale_multiplier=timestep_scale_multiplier,
            dtype=dtype,
            device=device,
            operations=operations,
            **kwargs,
        )

    def _init_model_components(self, device, dtype, **kwargs):
        """Initialize LTXV-specific components."""
        # No additional components needed for LTXV beyond base class
        pass

    def _init_transformer_blocks(self, device, dtype, **kwargs):
        """Initialize transformer blocks for LTXV."""
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    self.inner_dim,
                    self.num_attention_heads,
                    self.attention_head_dim,
                    context_dim=self.cross_attention_dim,
                    dtype=dtype,
                    device=device,
                    operations=self.operations,
                )
                for _ in range(self.num_layers)
            ]
        )

    def _init_output_components(self, device, dtype):
        """Initialize output components for LTXV."""
        self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
        self.norm_out = self.operations.LayerNorm(
            self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device
        )
        self.proj_out = self.operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
        self.patchifier = SymmetricPatchifier(1, start_end=True)

    def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
        """Process input for LTXV."""
        additional_args = {"orig_shape": list(x.shape)}
        x, latent_coords = self.patchifier.patchify(x)
        pixel_coords = latent_to_pixel_coords(
            latent_coords=latent_coords,
            scale_factors=self.vae_scale_factors,
            causal_fix=self.causal_temporal_positioning,
        )

        grid_mask = None
        if keyframe_idxs is not None:
            additional_args.update({"orig_patchified_shape": list(x.shape)})
            # Drop tokens whose denoise mask is negative (they are not processed),
            # remembering the mask so outputs can be scattered back to the full grid.
            denoise_mask = self.patchifier.patchify(denoise_mask)[0]
            grid_mask = ~torch.any(denoise_mask < 0, dim=-1)[0]
            additional_args.update({"grid_mask": grid_mask})
            x = x[:, grid_mask, :]
            pixel_coords = pixel_coords[:, :, grid_mask, ...]

            # Keyframe tokens sit at the end of the sequence; overwrite their
            # positional coordinates with the provided keyframe indices.
            kf_grid_mask = grid_mask[-keyframe_idxs.shape[2]:]
            keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
            pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs

        x = self.patchify_proj(x)
        return x, pixel_coords, additional_args

    def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, **kwargs):
        """Process transformer blocks for LTXV."""
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})

        for i, block in enumerate(self.transformer_blocks):
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
                    out = {}
                    out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"])
                    return out

                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options}, {"original_block": block_wrap})
                x = out["img"]
            else:
                x = block(
                    x,
                    context=context,
                    attention_mask=attention_mask,
                    timestep=timestep,
                    pe=pe,
                    transformer_options=transformer_options,
                )

        return x

    def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
        """Process output for LTXV."""
        # Apply scale-shift modulation
        scale_shift_values = (
            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
        )
        shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]

        x = self.norm_out(x)
        x = x * (1 + scale) + shift
        x = self.proj_out(x)

        if keyframe_idxs is not None:
            # Scatter the processed tokens back into the full (pre-mask) token grid
            grid_mask = kwargs["grid_mask"]
            orig_patchified_shape = kwargs["orig_patchified_shape"]
            full_x = torch.zeros(orig_patchified_shape, dtype=x.dtype, device=x.device)
            full_x[:, grid_mask, :] = x
            x = full_x

        # Unpatchify to restore original dimensions
        orig_shape = kwargs["orig_shape"]
        x = self.patchifier.unpatchify(
            latents=x,
            output_height=orig_shape[3],
            output_width=orig_shape[4],
            output_num_frames=orig_shape[2],
            out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
        )

        return x
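
# Illustrative construction (a sketch; `comfy.ops.disable_weight_init` is the usual
# operations namespace in ComfyUI, but any object exposing .Linear/.RMSNorm/.LayerNorm
# suffices):
#   import comfy.ops
#   model = LTXVModel(
#       in_channels=128, cross_attention_dim=2048, attention_head_dim=64,
#       num_attention_heads=32, caption_channels=4096, num_layers=28,
#       dtype=torch.bfloat16, device="cpu", operations=comfy.ops.disable_weight_init,
#   )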