From f3d4125e4904c427d26aa86c84d26b8dba48fe22 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Feb 2026 20:16:49 +0200 Subject: [PATCH] code rabbit suggestions --- comfy/image_encoders/dino3.py | 6 +++--- comfy/image_encoders/dino3_large.json | 4 ++-- comfy/ldm/trellis2/cumesh.py | 2 +- comfy/ldm/trellis2/model.py | 8 +++----- comfy/model_detection.py | 4 ++-- 5 files changed, 11 insertions(+), 13 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index ce6b2edd9..e009a7291 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -188,7 +188,7 @@ class DINOv3ViTLayer(nn.Module): device, dtype, operations): super().__init__() - self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps) + self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations) self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None) @@ -273,8 +273,8 @@ class DINOv3ViTModel(nn.Module): position_embeddings=position_embeddings, ) - self.norm = self.norm.to(hidden_states.device) - sequence_output = self.norm(hidden_states) + norm = self.norm.to(hidden_states.device) + sequence_output = norm(hidden_states) pooled_output = sequence_output[:, 0, :] return sequence_output, None, pooled_output, None diff --git a/comfy/image_encoders/dino3_large.json b/comfy/image_encoders/dino3_large.json index 53f761a25..b37b61dc8 100644 --- a/comfy/image_encoders/dino3_large.json +++ b/comfy/image_encoders/dino3_large.json @@ -18,6 +18,6 @@ "rope_theta": 100.0, "use_gated_mlp": false, "value_bias": true, - "mean": [0.485, 0.456, 0.406], - "std": [0.229, 0.224, 0.225] + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225] } diff --git a/comfy/ldm/trellis2/cumesh.py 
b/comfy/ldm/trellis2/cumesh.py index 972fb13c3..cb067a32f 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -6,7 +6,7 @@ from typing import Dict, Callable NO_TRITION = False try: - allow_tf32 = torch.cuda.is_tf32_supported + allow_tf32 = torch.cuda.is_tf32_supported() except Exception: allow_tf32 = False try: diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 5ff2a1ce0..07cf86d30 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -306,10 +306,7 @@ class ModulatedSparseTransformerBlock(nn.Module): return x def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: - if self.use_checkpoint: - return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) - else: - return self._forward(x, mod) + return self._forward(x, mod) class ModulatedSparseTransformerCrossBlock(nn.Module): @@ -486,6 +483,7 @@ class SLatFlowModel(nn.Module): x = x.to(dtype) h = self.input_layer(x) h = manual_cast(h, self.dtype) + t = t.to(dtype) t_emb = self.t_embedder(t, out_dtype = t.dtype) if self.share_mod: t_emb = self.adaLN_modulation(t_emb) @@ -790,7 +788,7 @@ class SparseStructureFlowModel(nn.Module): return h def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0): - t_shifted /= 1000.0 + t_shifted = t_shifted / 1000.0 t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1)) t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear) t_new *= 1000.0 diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 375cb87b1..6cadc8af6 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -117,12 +117,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config["image_model"] = "trellis2" unet_config["init_txt_model"] = False - if '{}model.shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: + if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in 
state_dict_keys: unet_config["init_txt_model"] = True unet_config["resolution"] = 64 if metadata is not None: - if "is_512" in metadata and metadata["metadata"]: + if metadata.get("is_512"): unet_config["resolution"] = 32 unet_config["num_heads"] = 12