From e785f0d212731e7f0f4b8c1638c58ab7df6f16b7 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 1 Jun 2026 14:35:26 -0700 Subject: [PATCH] Some cast/dtype fixes for the birefnet and dino3 models. (#14217) --- comfy/background_removal/birefnet.py | 2 +- comfy/clip_vision.py | 5 ----- comfy/image_encoders/dino3.py | 4 +--- 3 files changed, 2 insertions(+), 9 deletions(-) diff --git a/comfy/background_removal/birefnet.py b/comfy/background_removal/birefnet.py index df54b2b90..78a80246e 100644 --- a/comfy/background_removal/birefnet.py +++ b/comfy/background_removal/birefnet.py @@ -105,7 +105,7 @@ class WindowAttention(nn.Module): relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = comfy.ops.cast_to_input(relative_position_bias.permute(2, 0, 1).contiguous(), attn) # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 337575191..ce8924a11 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -2,7 +2,6 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_repl import os import json import logging -import torch import comfy.ops import comfy.model_patcher @@ -50,10 +49,6 @@ class ClipVisionModel(): self.load_device = comfy.model_management.text_encoder_device() offload_device = comfy.model_management.text_encoder_offload_device() self.dtype = comfy.model_management.text_encoder_dtype(self.load_device) - if self.model_type == "dinov3" and self.dtype == torch.float16: - # DINOv3's activations borderline fits fp16, preferring bf16 if available for better stability #TODO: further fp16 tests in practice - if comfy.model_management.should_use_bf16(self.load_device, prioritize_performance=True): - self.dtype = torch.bfloat16 self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast) self.model.eval() diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 9bd42a66b..014d1d29a 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -166,9 +166,8 @@ class DINOv3ViTEmbeddings(nn.Module): def forward(self, pixel_values, bool_masked_pos=None): batch_size = pixel_values.shape[0] - target_dtype = self.patch_embeddings.weight.dtype - patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + patch_embeddings = self.patch_embeddings(pixel_values) patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) if bool_masked_pos is not None: @@ -244,7 +243,6 @@ class DINOv3ViTModel(nn.Module): return self.embeddings.patch_embeddings def forward(self, pixel_values, bool_masked_pos=None, **kwargs): - pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype) hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) position_embeddings = self.rope_embeddings(pixel_values)