From 4b48535a7d66b89a4314e087e70fb7051e54eaaa Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 1 Jun 2026 18:08:20 -0700 Subject: [PATCH] Do tripo dinov3 inference in fp32. (#14221) --- comfy/image_encoders/dino3.py | 7 ++++--- comfy_extras/nodes_triposplat.py | 3 +-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 014d1d29a..ad29b06f8 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import comfy.ops from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale @@ -171,11 +172,11 @@ class DINOv3ViTEmbeddings(nn.Module): patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) if bool_masked_pos is not None: - mask_token = self.mask_token.to(patch_embeddings.dtype) + mask_token = comfy.ops.cast_to_input(self.mask_token, patch_embeddings) patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings) - cls_token = self.cls_token.expand(batch_size, -1, -1).to(patch_embeddings.device) - register_tokens = self.register_tokens.expand(batch_size, -1, -1).to(patch_embeddings.device) + cls_token = comfy.ops.cast_to_input(self.cls_token.expand(batch_size, -1, -1), patch_embeddings) + register_tokens = comfy.ops.cast_to_input(self.register_tokens.expand(batch_size, -1, -1), patch_embeddings) embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) return embeddings diff --git a/comfy_extras/nodes_triposplat.py b/comfy_extras/nodes_triposplat.py index 5646d611b..1848ad31a 100644 --- a/comfy_extras/nodes_triposplat.py +++ b/comfy_extras/nodes_triposplat.py @@ -115,12 +115,11 @@ class TripoSplatConditioning(IO.ComfyNode): # feature1: DINOv3 token sequence (cls + registers + patches), ImageNet-normalized, with a final non-affine layer norm on top comfy.model_management.load_model_gpu(clip_vision.patcher) device = clip_vision.load_device - model_dtype = next(clip_vision.model.parameters()).dtype img = image.movedim(-1, 1).to(device) # (B,3,H,W) in [0,1] mean = torch.tensor(_DINOV3_MEAN, device=device).view(1, 3, 1, 1) std = torch.tensor(_DINOV3_STD, device=device).view(1, 3, 1, 1) img = (img - mean) / std - seq = clip_vision.model(pixel_values=img.to(model_dtype))[0] + seq = clip_vision.model(pixel_values=img.float())[0] feature1 = F.layer_norm(seq.float(), seq.shape[-1:]).to(comfy.model_management.intermediate_device()) # Second conditioning: the Flux2 VAE latent of the image, carried as a standard reference_latents entry