From 1fde60b2bc67d39ab4177c1aabc828350245d5f9 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Feb 2026 22:04:37 +0200 Subject: [PATCH] debugging --- comfy/image_encoders/dino3.py | 2 +- comfy/ldm/trellis2/model.py | 94 +++++++++------------------------- comfy_extras/nodes_trellis2.py | 4 +- 3 files changed, 27 insertions(+), 73 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index e009a7291..3ec7f8a04 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -175,7 +175,7 @@ class DINOv3ViTEmbeddings(nn.Module): cls_token = self.cls_token.expand(batch_size, -1, -1) register_tokens = self.register_tokens.expand(batch_size, -1, -1) - device = patch_embeddings + device = patch_embeddings.device cls_token = cls_token.to(device) register_tokens = register_tokens.to(device) embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index ef1c25d33..7f16c4d41 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -201,6 +201,8 @@ class SparseMultiHeadAttention(nn.Module): def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor: if self._type == "self": + dtype = next(self.to_qkv.parameters()).dtype + x = x.to(dtype) qkv = self._linear(self.to_qkv, x) qkv = self._fused_pre(qkv, num_fused=3) if self.qk_rms_norm or self.use_rope: @@ -243,71 +245,6 @@ class SparseMultiHeadAttention(nn.Module): h = self._linear(self.to_out, h) return h -class ModulatedSparseTransformerBlock(nn.Module): - def __init__( - self, - channels: int, - num_heads: int, - mlp_ratio: float = 4.0, - attn_mode: Literal["full", "swin"] = "full", - window_size: Optional[int] = None, - shift_window: Optional[Tuple[int, int, int]] = None, - use_checkpoint: bool = False, - use_rope: bool = False, - rope_freq: Tuple[float, float] = (1.0, 10000.0), - qk_rms_norm: bool = False, - qkv_bias: bool = True, - share_mod: bool = False, - ): - super().__init__() - self.use_checkpoint = use_checkpoint - self.share_mod = share_mod - self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) - self.attn = SparseMultiHeadAttention( - channels, - num_heads=num_heads, - attn_mode=attn_mode, - window_size=window_size, - shift_window=shift_window, - qkv_bias=qkv_bias, - use_rope=use_rope, - rope_freq=rope_freq, - qk_rms_norm=qk_rms_norm, - ) - self.mlp = SparseFeedForwardNet( - channels, - mlp_ratio=mlp_ratio, - ) - if not share_mod: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(channels, 6 * channels, bias=True) - ) - else: - self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) - - def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: - if self.share_mod: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) - else: - shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) - h = x.replace(self.norm1(x.feats)) - h = h * (1 + scale_msa) + shift_msa - h = self.attn(h) - h = h * gate_msa - x = x + h - h = x.replace(self.norm2(x.feats)) - h = h * (1 + scale_mlp) + shift_mlp - h = self.mlp(h) - h = h * gate_mlp - x = x + h - return x - - def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: - return self._forward(x, mod) - - class ModulatedSparseTransformerCrossBlock(nn.Module): """ Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. @@ -483,15 +420,13 @@ class SLatFlowModel(nn.Module): h = self.input_layer(x) h = manual_cast(h, self.dtype) t = t.to(dtype) - t_emb = self.t_embedder(t, out_dtype = t.dtype) + t_embedder = self.t_embedder.to(dtype) + t_emb = t_embedder(t, out_dtype = t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) t_emb = manual_cast(t_emb, self.dtype) cond = manual_cast(cond, self.dtype) - if self.pe_mode == "ape": - pe = self.pos_embedder(h.coords[:, 1:]) - h = h + manual_cast(pe, self.dtype) for block in self.blocks: h = block(h, t_emb, cond) @@ -849,7 +784,24 @@ class Trellis2(nn.Module): txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] if not_struct_mode: - x = SparseTensor(feats=x, coords=coords) + B, N, C = x.shape + + if mode == "shape_generation": + feats_flat = x.reshape(-1, C) + + # 3. inflate coords [N, 4] -> [B*N, 4] + coords_list = [] + for i in range(B): + c = coords.clone() + c[:, 0] = i + coords_list.append(c) + + batched_coords = torch.cat(coords_list, dim=0) + else: # TODO: texture + # may remove the else if texture doesn't require special handling + batched_coords = coords + feats_flat = x + x = SparseTensor(feats=feats_flat, coords=batched_coords) if mode == "shape_generation": # TODO @@ -868,4 +820,6 @@ class Trellis2(nn.Module): if not_struct_mode: out = out.feats + if mode == "shape_generation": + out = out.view(B, N, -1) return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 14f5484d6..623430b9e 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -238,9 +238,9 @@ class Trellis2Conditioning(IO.ComfyNode): max_size = max(image.size) scale = min(1, 1024 / max_size) if scale < 1: - image = image.resize((int(image.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) + image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS) - image = torch.tensor(np.array(image)).unsqueeze(0) + image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255 # could make 1024 an option conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)