# ComfyUI/comfy/ldm/trellis2/model.py
import torch
import torch.nn.functional as F
import torch.nn as nn
from .vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
from typing import Optional, Tuple, Literal, Union, List
from .attention import sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
class SparseGELU(nn.GELU):
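    """GELU applied to the feature tensor of a VarLenTensor, returning a VarLenTensor with the same layout."""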
def forward(self, input: VarLenTensor) -> VarLenTensor:
return input.replace(super().forward(input.feats))
class SparseFeedForwardNet(nn.Module):
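    """Per-token MLP for sparse features: expand by `mlp_ratio`, tanh-approximated GELU, project back."""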
def __init__(self, channels: int, mlp_ratio: float = 4.0):
super().__init__()
self.mlp = nn.Sequential(
SparseLinear(channels, int(channels * mlp_ratio)),
SparseGELU(approximate="tanh"),
SparseLinear(int(channels * mlp_ratio), channels),
)
def forward(self, x: VarLenTensor) -> VarLenTensor:
return self.mlp(x)
def manual_cast(tensor, dtype):
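    # Cast only when autocast is disabled; under autocast, dtype handling is left to the autocast context.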
if not torch.is_autocast_enabled():
return tensor.type(dtype)
return tensor
class LayerNorm32(nn.LayerNorm):
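    """LayerNorm computed in float32 (when autocast is off) and cast back to the input dtype."""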
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_dtype = x.dtype
x = manual_cast(x, torch.float32)
o = super().forward(x)
return manual_cast(o, x_dtype)
class SparseMultiHeadRMSNorm(nn.Module):
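    """Per-head RMS normalization with a learnable per-head gain, for dense tensors or VarLenTensor features."""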
def __init__(self, dim: int, heads: int):
super().__init__()
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(heads, dim))
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
x_type = x.dtype
x = x.float()
if isinstance(x, VarLenTensor):
x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale)
else:
x = F.normalize(x, dim=-1) * self.gamma * self.scale
return x.to(x_type)
# TODO: replace with apply_rope1
class SparseRotaryPositionEmbedder(nn.Module):
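    """Rotary position embedding over the 3D voxel coordinates of a SparseTensor.

    The complex phase table is cached on the SparseTensor's spatial cache, so it is
    computed once per coordinate set and reused by every layer with the same config.
    """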
def __init__(
self,
head_dim: int,
dim: int = 3,
rope_freq: Tuple[float, float] = (1.0, 10000.0)
):
super().__init__()
assert head_dim % 2 == 0, "Head dim must be divisible by 2"
self.head_dim = head_dim
self.dim = dim
self.rope_freq = rope_freq
self.freq_dim = head_dim // 2 // dim
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
self.freqs = self.freqs.to(indices.device)
phases = torch.outer(indices, self.freqs)
phases = torch.polar(torch.ones_like(phases), phases)
return phases
def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
x_rotated = x_complex * phases.unsqueeze(-2)
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
return x_embed
    def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Union[SparseTensor, Tuple[SparseTensor, SparseTensor]]:
        """
        Args:
            q (SparseTensor): [..., N, H, D] tensor of queries
            k (SparseTensor, optional): [..., N, H, D] tensor of keys

        Returns:
            The rotary-embedded queries, or a (q, k) tuple when keys are given.
        """
assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1"
phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}'
phases = q.get_spatial_cache(phases_cache_name)
if phases is None:
coords = q.coords[..., 1:]
phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1)
if phases.shape[-1] < self.head_dim // 2:
padn = self.head_dim // 2 - phases.shape[-1]
phases = torch.cat([phases, torch.polar(
torch.ones(*phases.shape[:-1], padn, device=phases.device),
torch.zeros(*phases.shape[:-1], padn, device=phases.device)
)], dim=-1)
q.register_spatial_cache(phases_cache_name, phases)
q_embed = q.replace(self._rotary_embedding(q.feats, phases))
if k is None:
return q_embed
k_embed = k.replace(self._rotary_embedding(k.feats, phases))
return q_embed, k_embed
class SparseMultiHeadAttention(nn.Module):
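    """Multi-head attention on sparse tokens.

    Supports self- or cross-attention, full or (shifted-)windowed attention modes,
    optional rotary position embeddings, and optional RMS normalization of Q/K.
    """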
def __init__(
self,
channels: int,
num_heads: int,
ctx_channels: Optional[int] = None,
type: Literal["self", "cross"] = "self",
attn_mode: Literal["full", "windowed", "double_windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
qkv_bias: bool = True,
use_rope: bool = False,
        rope_freq: Tuple[float, float] = (1.0, 10000.0),
qk_rms_norm: bool = False,
):
super().__init__()
self.channels = channels
self.head_dim = channels // num_heads
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
self.num_heads = num_heads
self._type = type
self.attn_mode = attn_mode
self.window_size = window_size
self.shift_window = shift_window
self.use_rope = use_rope
self.qk_rms_norm = qk_rms_norm
if self._type == "self":
self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias)
else:
self.to_q = nn.Linear(channels, channels, bias=qkv_bias)
self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias)
if self.qk_rms_norm:
self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads)
self.to_out = nn.Linear(channels, channels)
if use_rope:
self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq)
@staticmethod
def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
if isinstance(x, VarLenTensor):
return x.replace(module(x.feats))
else:
return module(x)
@staticmethod
def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]:
if isinstance(x, VarLenTensor):
return x.reshape(*shape)
else:
return x.reshape(*x.shape[:2], *shape)
def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]:
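        """Reshape a fused projection (qkv or kv) to [..., num_fused, num_heads, head_dim]."""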
if isinstance(x, VarLenTensor):
x_feats = x.feats.unsqueeze(0)
else:
x_feats = x
x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats
def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor:
if self._type == "self":
qkv = self._linear(self.to_qkv, x)
qkv = self._fused_pre(qkv, num_fused=3)
if self.qk_rms_norm or self.use_rope:
q, k, v = qkv.unbind(dim=-3)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k = self.k_rms_norm(k)
if self.use_rope:
q, k = self.rope(q, k)
qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
if self.attn_mode == "full":
h = sparse_scaled_dot_product_attention(qkv)
elif self.attn_mode == "windowed":
h = sparse_windowed_scaled_dot_product_self_attention(
qkv, self.window_size, shift_window=self.shift_window
)
elif self.attn_mode == "double_windowed":
qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:])
qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2])
h0 = sparse_windowed_scaled_dot_product_self_attention(
qkv0, self.window_size, shift_window=(0, 0, 0)
)
h1 = sparse_windowed_scaled_dot_product_self_attention(
qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3)
)
h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1))
else:
q = self._linear(self.to_q, x)
q = self._reshape_chs(q, (self.num_heads, -1))
kv = self._linear(self.to_kv, context)
kv = self._fused_pre(kv, num_fused=2)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k, v = kv.unbind(dim=-3)
k = self.k_rms_norm(k)
h = sparse_scaled_dot_product_attention(q, k, v)
else:
h = sparse_scaled_dot_product_attention(q, kv)
h = self._reshape_chs(h, (-1,))
h = self._linear(self.to_out, h)
return h
class ModulatedSparseTransformerBlock(nn.Module):
"""
Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning.
"""
def __init__(
self,
channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed", "double_windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
rope_freq: Tuple[float, float] = (1.0, 10000.0),
qk_rms_norm: bool = False,
qkv_bias: bool = True,
share_mod: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.attn = SparseMultiHeadAttention(
channels,
num_heads=num_heads,
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm,
)
self.mlp = SparseFeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(channels, 6 * channels, bias=True)
)
else:
self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
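        # mod yields six per-sample vectors: shift/scale/gate for the attention branch and for the MLP branch.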
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
h = x.replace(self.norm1(x.feats))
h = h * (1 + scale_msa) + shift_msa
h = self.attn(h)
h = h * gate_msa
x = x + h
h = x.replace(self.norm2(x.feats))
h = h * (1 + scale_mlp) + shift_mlp
h = self.mlp(h)
h = h * gate_mlp
x = x + h
return x
def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor:
if self.use_checkpoint:
return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False)
else:
return self._forward(x, mod)
class ModulatedSparseTransformerCrossBlock(nn.Module):
"""
Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
"""
def __init__(
self,
channels: int,
ctx_channels: int,
num_heads: int,
mlp_ratio: float = 4.0,
        attn_mode: Literal["full", "windowed", "double_windowed"] = "full",
window_size: Optional[int] = None,
shift_window: Optional[Tuple[int, int, int]] = None,
use_checkpoint: bool = False,
use_rope: bool = False,
rope_freq: Tuple[float, float] = (1.0, 10000.0),
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
qkv_bias: bool = True,
share_mod: bool = False,
):
super().__init__()
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6)
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6)
self.self_attn = SparseMultiHeadAttention(
channels,
num_heads=num_heads,
type="self",
attn_mode=attn_mode,
window_size=window_size,
shift_window=shift_window,
qkv_bias=qkv_bias,
use_rope=use_rope,
rope_freq=rope_freq,
qk_rms_norm=qk_rms_norm,
)
self.cross_attn = SparseMultiHeadAttention(
channels,
ctx_channels=ctx_channels,
num_heads=num_heads,
type="cross",
attn_mode="full",
qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross,
)
self.mlp = SparseFeedForwardNet(
channels,
mlp_ratio=mlp_ratio,
)
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(channels, 6 * channels, bias=True)
)
else:
self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5)
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
h = x.replace(self.norm1(x.feats))
h = h * (1 + scale_msa) + shift_msa
h = self.self_attn(h)
h = h * gate_msa
x = x + h
h = x.replace(self.norm2(x.feats))
h = self.cross_attn(h, context)
x = x + h
h = x.replace(self.norm3(x.feats))
h = h * (1 + scale_mlp) + shift_mlp
h = self.mlp(h)
h = h * gate_mlp
x = x + h
return x
    def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
        if self.use_checkpoint:
            return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, use_reentrant=False)
        else:
            return self._forward(x, mod, context)
class SLatFlowModel(nn.Module):
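    """Sparse structured-latent (SLat) flow transformer: latent tokens are modulated by the
    timestep embedding (adaLN) and cross-attend to the conditioning tokens in every block."""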
def __init__(
self,
resolution: int,
in_channels: int,
model_channels: int,
cond_channels: int,
out_channels: int,
num_blocks: int,
num_heads: Optional[int] = None,
num_head_channels: Optional[int] = 64,
mlp_ratio: float = 4,
pe_mode: Literal["ape", "rope"] = "rope",
rope_freq: Tuple[float, float] = (1.0, 10000.0),
use_checkpoint: bool = False,
share_mod: bool = False,
initialization: str = 'vanilla',
qk_rms_norm: bool = False,
qk_rms_norm_cross: bool = False,
dtype = None,
device = None,
operations = None,
):
super().__init__()
self.resolution = resolution
self.in_channels = in_channels
self.model_channels = model_channels
self.cond_channels = cond_channels
self.out_channels = out_channels
self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.pe_mode = pe_mode
self.use_checkpoint = use_checkpoint
self.share_mod = share_mod
self.initialization = initialization
self.qk_rms_norm = qk_rms_norm
self.qk_rms_norm_cross = qk_rms_norm_cross
self.dtype = dtype
self.t_embedder = TimestepEmbedder(model_channels)
if share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(model_channels, 6 * model_channels, bias=True)
)
self.input_layer = SparseLinear(in_channels, model_channels)
self.blocks = nn.ModuleList([
ModulatedSparseTransformerCrossBlock(
model_channels,
cond_channels,
num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio,
attn_mode='full',
use_checkpoint=self.use_checkpoint,
use_rope=(pe_mode == "rope"),
rope_freq=rope_freq,
share_mod=self.share_mod,
qk_rms_norm=self.qk_rms_norm,
qk_rms_norm_cross=self.qk_rms_norm_cross,
)
for _ in range(num_blocks)
])
self.out_layer = SparseLinear(model_channels, out_channels)
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def forward(
self,
x: SparseTensor,
t: torch.Tensor,
cond: Union[torch.Tensor, List[torch.Tensor]],
concat_cond: Optional[SparseTensor] = None,
**kwargs
) -> SparseTensor:
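        """
        Args:
            x: sparse latent tokens to denoise.
            t: timestep values passed to the timestep embedder.
            cond: conditioning tokens (a padded tensor or a list of per-sample tensors).
            concat_cond: optional sparse tensor concatenated to x along the channel dimension.
        """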
if concat_cond is not None:
x = sparse_cat([x, concat_cond], dim=-1)
if isinstance(cond, list):
cond = VarLenTensor.from_tensor_list(cond)
h = self.input_layer(x)
h = manual_cast(h, self.dtype)
t_emb = self.t_embedder(t)
if self.share_mod:
t_emb = self.adaLN_modulation(t_emb)
t_emb = manual_cast(t_emb, self.dtype)
cond = manual_cast(cond, self.dtype)
        if self.pe_mode == "ape":
            # Absolute position embedding: expects self.pos_embedder, which is not constructed in __init__ (only "rope" is wired up here).
            pe = self.pos_embedder(h.coords[:, 1:])
            h = h + manual_cast(pe, self.dtype)
for block in self.blocks:
h = block(h, t_emb, cond)
h = manual_cast(h, x.dtype)
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
h = self.out_layer(h)
return h
class Trellis2(nn.Module):
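    """Container for the two SLatFlowModel stages (`img2shape` and `shape2txt`)."""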
def __init__(self, resolution,
in_channels = 32,
out_channels = 32,
model_channels = 1536,
cond_channels = 1024,
num_blocks = 30,
num_heads = 12,
mlp_ratio = 5.3334,
share_mod = True,
qk_rms_norm = True,
qk_rms_norm_cross = True,
dtype=None, device=None, operations=None):
        super().__init__()
        args = {
            "out_channels": out_channels, "num_blocks": num_blocks, "cond_channels": cond_channels,
            "model_channels": model_channels, "num_heads": num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
            "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations,
        }
        # TODO: update the names/checkpoints
        self.img2shape = SLatFlowModel(resolution, in_channels=in_channels, **args)
        self.shape2txt = SLatFlowModel(resolution, in_channels=in_channels * 2, **args)
        self.shape_generation = True
def forward(self, x, timestep, context):
pass