import torch import torch.nn.functional as F import torch.nn as nn from vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor from typing import Optional, Tuple, Literal, Union, List from attention import sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder class SparseGELU(nn.GELU): def forward(self, input: VarLenTensor) -> VarLenTensor: return input.replace(super().forward(input.feats)) class SparseFeedForwardNet(nn.Module): def __init__(self, channels: int, mlp_ratio: float = 4.0): super().__init__() self.mlp = nn.Sequential( SparseLinear(channels, int(channels * mlp_ratio)), SparseGELU(approximate="tanh"), SparseLinear(int(channels * mlp_ratio), channels), ) def forward(self, x: VarLenTensor) -> VarLenTensor: return self.mlp(x) def manual_cast(tensor, dtype): if not torch.is_autocast_enabled(): return tensor.type(dtype) return tensor class LayerNorm32(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: x_dtype = x.dtype x = manual_cast(x, torch.float32) o = super().forward(x) return manual_cast(o, x_dtype) class SparseMultiHeadRMSNorm(nn.Module): def __init__(self, dim: int, heads: int): super().__init__() self.scale = dim ** 0.5 self.gamma = nn.Parameter(torch.ones(heads, dim)) def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: x_type = x.dtype x = x.float() if isinstance(x, VarLenTensor): x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale) else: x = F.normalize(x, dim=-1) * self.gamma * self.scale return x.to(x_type) # TODO: replace with apply_rope1 class SparseRotaryPositionEmbedder(nn.Module): def __init__( self, head_dim: int, dim: int = 3, rope_freq: Tuple[float, float] = (1.0, 10000.0) ): super().__init__() assert head_dim % 2 == 0, "Head dim must be divisible by 2" self.head_dim = head_dim self.dim = dim self.rope_freq = rope_freq self.freq_dim = head_dim // 2 // dim self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: self.freqs = self.freqs.to(indices.device) phases = torch.outer(indices, self.freqs) phases = torch.polar(torch.ones_like(phases), phases) return phases def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) x_rotated = x_complex * phases.unsqueeze(-2) x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) return x_embed def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: q (SparseTensor): [..., N, H, D] tensor of queries k (SparseTensor): [..., N, H, D] tensor of keys """ assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1" phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}' phases = q.get_spatial_cache(phases_cache_name) if phases is None: coords = q.coords[..., 1:] phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1) if phases.shape[-1] < self.head_dim // 2: padn = self.head_dim // 2 - phases.shape[-1] phases = torch.cat([phases, torch.polar( torch.ones(*phases.shape[:-1], padn, device=phases.device), torch.zeros(*phases.shape[:-1], padn, device=phases.device) )], dim=-1) q.register_spatial_cache(phases_cache_name, phases) q_embed = q.replace(self._rotary_embedding(q.feats, phases)) if k is None: return q_embed k_embed = k.replace(self._rotary_embedding(k.feats, phases)) return q_embed, k_embed class SparseMultiHeadAttention(nn.Module): def __init__( self, channels: int, num_heads: int, ctx_channels: Optional[int] = None, type: Literal["self", "cross"] = "self", attn_mode: Literal["full", "windowed", "double_windowed"] = "full", window_size: Optional[int] = None, shift_window: Optional[Tuple[int, int, int]] = None, qkv_bias: bool = True, use_rope: bool = False, rope_freq: Tuple[int, int] = (1.0, 10000.0), qk_rms_norm: bool = False, ): super().__init__() self.channels = channels self.head_dim = channels // num_heads self.ctx_channels = ctx_channels if ctx_channels is not None else channels self.num_heads = num_heads self._type = type self.attn_mode = attn_mode self.window_size = window_size self.shift_window = shift_window self.use_rope = use_rope self.qk_rms_norm = qk_rms_norm if self._type == "self": self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) else: self.to_q = nn.Linear(channels, channels, bias=qkv_bias) self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) if self.qk_rms_norm: self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) self.to_out = nn.Linear(channels, channels) if use_rope: self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq) @staticmethod def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: if isinstance(x, VarLenTensor): return x.replace(module(x.feats)) else: return module(x) @staticmethod def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]: if isinstance(x, VarLenTensor): return x.reshape(*shape) else: return x.reshape(*x.shape[:2], *shape) def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]: if isinstance(x, VarLenTensor): x_feats = x.feats.unsqueeze(0) else: x_feats = x x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor: if self._type == "self": qkv = self._linear(self.to_qkv, x) qkv = self._fused_pre(qkv, num_fused=3) if self.qk_rms_norm or self.use_rope: q, k, v = qkv.unbind(dim=-3) if self.qk_rms_norm: q = self.q_rms_norm(q) k = self.k_rms_norm(k) if self.use_rope: q, k = self.rope(q, k) qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) if self.attn_mode == "full": h = sparse_scaled_dot_product_attention(qkv) elif self.attn_mode == "windowed": h = sparse_windowed_scaled_dot_product_self_attention( qkv, self.window_size, shift_window=self.shift_window ) elif self.attn_mode == "double_windowed": qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:]) qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2]) h0 = sparse_windowed_scaled_dot_product_self_attention( qkv0, self.window_size, shift_window=(0, 0, 0) ) h1 = sparse_windowed_scaled_dot_product_self_attention( qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3) ) h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1)) else: q = self._linear(self.to_q, x) q = self._reshape_chs(q, (self.num_heads, -1)) kv = self._linear(self.to_kv, context) kv = self._fused_pre(kv, num_fused=2) if self.qk_rms_norm: q = self.q_rms_norm(q) k, v = kv.unbind(dim=-3) k = self.k_rms_norm(k) h = sparse_scaled_dot_product_attention(q, k, v) else: h = sparse_scaled_dot_product_attention(q, kv) h = self._reshape_chs(h, (-1,)) h = self._linear(self.to_out, h) return h class ModulatedSparseTransformerBlock(nn.Module): """ Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. """ def __init__( self, channels: int, num_heads: int, mlp_ratio: float = 4.0, attn_mode: Literal["full", "swin"] = "full", window_size: Optional[int] = None, shift_window: Optional[Tuple[int, int, int]] = None, use_checkpoint: bool = False, use_rope: bool = False, rope_freq: Tuple[float, float] = (1.0, 10000.0), qk_rms_norm: bool = False, qkv_bias: bool = True, share_mod: bool = False, ): super().__init__() self.use_checkpoint = use_checkpoint self.share_mod = share_mod self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) self.attn = SparseMultiHeadAttention( channels, num_heads=num_heads, attn_mode=attn_mode, window_size=window_size, shift_window=shift_window, qkv_bias=qkv_bias, use_rope=use_rope, rope_freq=rope_freq, qk_rms_norm=qk_rms_norm, ) self.mlp = SparseFeedForwardNet( channels, mlp_ratio=mlp_ratio, ) if not share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True) ) else: self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: if self.share_mod: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) else: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) h = x.replace(self.norm1(x.feats)) h = h * (1 + scale_msa) + shift_msa h = self.attn(h) h = h * gate_msa x = x + h h = x.replace(self.norm2(x.feats)) h = h * (1 + scale_mlp) + shift_mlp h = self.mlp(h) h = h * gate_mlp x = x + h return x def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: if self.use_checkpoint: return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) else: return self._forward(x, mod) class ModulatedSparseTransformerCrossBlock(nn.Module): """ Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. """ def __init__( self, channels: int, ctx_channels: int, num_heads: int, mlp_ratio: float = 4.0, attn_mode: Literal["full", "swin"] = "full", window_size: Optional[int] = None, shift_window: Optional[Tuple[int, int, int]] = None, use_checkpoint: bool = False, use_rope: bool = False, rope_freq: Tuple[float, float] = (1.0, 10000.0), qk_rms_norm: bool = False, qk_rms_norm_cross: bool = False, qkv_bias: bool = True, share_mod: bool = False, ): super().__init__() self.use_checkpoint = use_checkpoint self.share_mod = share_mod self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) self.self_attn = SparseMultiHeadAttention( channels, num_heads=num_heads, type="self", attn_mode=attn_mode, window_size=window_size, shift_window=shift_window, qkv_bias=qkv_bias, use_rope=use_rope, rope_freq=rope_freq, qk_rms_norm=qk_rms_norm, ) self.cross_attn = SparseMultiHeadAttention( channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross", attn_mode="full", qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, ) self.mlp = SparseFeedForwardNet( channels, mlp_ratio=mlp_ratio, ) if not share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(channels, 6 * channels, bias=True) ) else: self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: if self.share_mod: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) else: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) h = x.replace(self.norm1(x.feats)) h = h * (1 + scale_msa) + shift_msa h = self.self_attn(h) h = h * gate_msa x = x + h h = x.replace(self.norm2(x.feats)) h = self.cross_attn(h, context) x = x + h h = x.replace(self.norm3(x.feats)) h = h * (1 + scale_mlp) + shift_mlp h = self.mlp(h) h = h * gate_mlp x = x + h return x def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: return self._forward(x, mod, context) class SLatFlowModel(nn.Module): def __init__( self, resolution: int, in_channels: int, model_channels: int, cond_channels: int, out_channels: int, num_blocks: int, num_heads: Optional[int] = None, num_head_channels: Optional[int] = 64, mlp_ratio: float = 4, pe_mode: Literal["ape", "rope"] = "rope", rope_freq: Tuple[float, float] = (1.0, 10000.0), use_checkpoint: bool = False, share_mod: bool = False, initialization: str = 'vanilla', qk_rms_norm: bool = False, qk_rms_norm_cross: bool = False, dtype = None, device = None, operations = None, ): super().__init__() self.resolution = resolution self.in_channels = in_channels self.model_channels = model_channels self.cond_channels = cond_channels self.out_channels = out_channels self.num_blocks = num_blocks self.num_heads = num_heads or model_channels // num_head_channels self.mlp_ratio = mlp_ratio self.pe_mode = pe_mode self.use_checkpoint = use_checkpoint self.share_mod = share_mod self.initialization = initialization self.qk_rms_norm = qk_rms_norm self.qk_rms_norm_cross = qk_rms_norm_cross self.dtype = dtype self.t_embedder = TimestepEmbedder(model_channels) if share_mod: self.adaLN_modulation = nn.Sequential( nn.SiLU(), nn.Linear(model_channels, 6 * model_channels, bias=True) ) self.input_layer = SparseLinear(in_channels, model_channels) self.blocks = nn.ModuleList([ ModulatedSparseTransformerCrossBlock( model_channels, cond_channels, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, attn_mode='full', use_checkpoint=self.use_checkpoint, use_rope=(pe_mode == "rope"), rope_freq=rope_freq, share_mod=self.share_mod, qk_rms_norm=self.qk_rms_norm, qk_rms_norm_cross=self.qk_rms_norm_cross, ) for _ in range(num_blocks) ]) self.out_layer = SparseLinear(model_channels, out_channels) @property def device(self) -> torch.device: return next(self.parameters()).device def forward( self, x: SparseTensor, t: torch.Tensor, cond: Union[torch.Tensor, List[torch.Tensor]], concat_cond: Optional[SparseTensor] = None, **kwargs ) -> SparseTensor: if concat_cond is not None: x = sparse_cat([x, concat_cond], dim=-1) if isinstance(cond, list): cond = VarLenTensor.from_tensor_list(cond) h = self.input_layer(x) h = manual_cast(h, self.dtype) t_emb = self.t_embedder(t) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) t_emb = manual_cast(t_emb, self.dtype) cond = manual_cast(cond, self.dtype) if self.pe_mode == "ape": pe = self.pos_embedder(h.coords[:, 1:]) h = h + manual_cast(pe, self.dtype) for block in self.blocks: h = block(h, t_emb, cond) h = manual_cast(h, x.dtype) h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) h = self.out_layer(h) return h class Trellis2(nn.Module): def __init__(self, resolution, in_channels = 32, out_channels = 32, model_channels = 1536, cond_channels = 1024, num_blocks = 30, num_heads = 12, mlp_ratio = 5.3334, share_mod = True, qk_rms_norm = True, qk_rms_norm_cross = True, dtype=None, device=None, operations=None): args = { "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations } # TODO: update the names/checkpoints self.img2shape = SLatFlowModel(resolution, in_channels=in_channels, *args) self.shape2txt = SLatFlowModel(resolution, in_channels=in_channels*2, *args) self.shape_generation = True def forward(self, x, timestep, context): pass