From 66249395056ab311f4459fdec138fc49e080702d Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 3 Feb 2026 22:40:54 +0200 Subject: [PATCH] structure model --- comfy/ldm/trellis2/attention.py | 67 +++++++ comfy/ldm/trellis2/model.py | 316 +++++++++++++++++++++++++++++++- comfy/supported_models.py | 4 + 3 files changed, 382 insertions(+), 5 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 9cd7d4995..6c912c8d9 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -4,6 +4,73 @@ from comfy.ldm.modules.attention import optimized_attention from typing import Tuple, Union, List from vae import VarLenTensor +FLASH_ATTN_3_AVA = True +try: + import flash_attn_interface as flash_attn_3 +except: + FLASH_ATTN_3_AVA = False + +# TODO repalce with optimized attention +def scaled_dot_product_attention(*args, **kwargs): + num_all_args = len(args) + len(kwargs) + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + + if optimized_attention.__name__ == 'attention_xformers': + if 'xops' not in globals(): + import xformers.ops as xops + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + out = xops.memory_efficient_attention(q, k, v) + elif optimized_attention.__name__ == 'attention_flash' and not FLASH_ATTN_3_AVA: + if 'flash_attn' not in globals(): + import flash_attn + if num_all_args == 2: + out = flash_attn.flash_attn_kvpacked_func(q, kv) + elif num_all_args == 3: + out = flash_attn.flash_attn_func(q, k, v) + elif optimized_attention.__name__ == 'attention_flash': # TODO + if 'flash_attn_3' not in globals(): + import flash_attn_interface as flash_attn_3 + if num_all_args == 2: + k, v = kv.unbind(dim=2) + out = flash_attn_3.flash_attn_func(q, k, v) + elif num_all_args == 3: + out = flash_attn_3.flash_attn_func(q, k, v) + elif optimized_attention.__name__ == 'attention_pytorch': + if 'sdpa' not in globals(): + from torch.nn.functional import scaled_dot_product_attention as sdpa + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.permute(0, 2, 1, 3) # [N, H, L, C] + k = k.permute(0, 2, 1, 3) # [N, H, L, C] + v = v.permute(0, 2, 1, 3) # [N, H, L, C] + out = sdpa(q, k, v) # [N, H, L, C] + out = out.permute(0, 2, 1, 3) # [N, L, H, C] + elif optimized_attention.__name__ == 'attention_basic': + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + q = q.shape[2] # TODO + out = optimized_attention(q, k, v) + + return out + def sparse_windowed_scaled_dot_product_self_attention( qkv, window_size: int, diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 9d1a8fdb4..add3f21ce 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -3,7 +3,9 @@ import torch.nn.functional as F import torch.nn as nn from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor from typing import Optional, Tuple, Literal, Union, List -from comfy.ldm.trellis2.attention import sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention +from comfy.ldm.trellis2.attention import ( + sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention +) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder class SparseGELU(nn.GELU): @@ -103,6 +105,18 @@ class SparseRotaryPositionEmbedder(nn.Module): k_embed = k.replace(self._rotary_embedding(k.feats, phases)) return q_embed, k_embed +class RotaryPositionEmbedder(SparseRotaryPositionEmbedder): + def forward(self, indices: torch.Tensor) -> torch.Tensor: + assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}" + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], padn, device=phases.device), + torch.zeros(*phases.shape[:-1], padn, device=phases.device) + )], dim=-1) + return phases + class SparseMultiHeadAttention(nn.Module): def __init__( self, @@ -472,6 +486,292 @@ class SLatFlowModel(nn.Module): h = self.out_layer(h) return h +class FeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.GELU(approximate="tanh"), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int]=None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + ): + super().__init__() + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + B, L, C = x.shape + if self._type == "self": + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + + if self.attn_mode == "full": + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + assert phases is not None, "Phases must be provided for RoPE" + q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases) + k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + else: + Lkv = context.shape[1] + q = self.to_q(x) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h + +class ModulatedTransformerCrossBlock(nn.Module): + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.self_attn(h, phases=phases) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, context, phases, use_reentrant=False) + else: + return self._forward(x, mod, context, phases) + + +class SparseStructureFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "ape", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + dtype: str = 'float32', + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + **kwargs + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.initialization = initialization + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = dtype + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3) + coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + rope_phases = pos_embedder(coords) + self.register_buffer("rope_phases", rope_phases) + + if pe_mode != "rope": + self.rope_phases = None + + self.input_layer = nn.Linear(in_channels, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + rope_freq=rope_freq, + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = nn.Linear(model_channels, out_channels) + + self.initialize_weights() + self.convert_to(self.dtype) + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ + f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" + + h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous() + + h = self.input_layer(h) + if self.pe_mode == "ape": + h = h + self.pos_emb[None] + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = manual_cast(t_emb, self.dtype) + h = manual_cast(h, self.dtype) + cond = manual_cast(cond, self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond, self.rope_phases) + h = manual_cast(h, x.dtype) + h = F.layer_norm(h, h.shape[-1:]) + h = self.out_layer(h) + + h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous() + + return h + class Trellis2(nn.Module): def __init__(self, resolution, in_channels = 32, @@ -492,18 +792,24 @@ class Trellis2(nn.Module): "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations } - # TODO: update the names/checkpoints self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) - self.shape_generation = True + args.pop("out_channels") + args.pop("in_channels") + self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) def forward(self, x, timestep, context, **kwargs): # TODO add mode mode = kwargs.get("mode", "shape_generation") - mode = "texture_generation" if mode == 1 else "shape_generation" + if mode != 0: + mode = "texture_generation" if mode == 2 else "shape_generation" + else: + mode = "structure_generation" if mode == "shape_generation": out = self.img2shape(x, timestep, context) - if mode == "texture_generation": + elif mode == "texture_generation": out = self.shape2txt(x, timestep, context) + else: + out = self.structure_model(x, timestep, context) return out diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 6c2725b9f..9e2f17149 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1247,6 +1247,10 @@ class Trellis2(supported_models_base.BASE): "image_model": "trellis2" } + sampling_settings = { + "shift": 3.0, + } + latent_format = latent_formats.Trellis2 vae_key_prefix = ["vae."]