mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-06 14:27:24 +08:00
Some cast/dtype fixes for the birefnet and dino3 models. (#14217)
This commit is contained in:
parent
a88e02b185
commit
e785f0d212
@ -105,7 +105,7 @@ class WindowAttention(nn.Module):
|
|||||||
|
|
||||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
|
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
|
||||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
||||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
relative_position_bias = comfy.ops.cast_to_input(relative_position_bias.permute(2, 0, 1).contiguous(), attn) # nH, Wh*Ww, Wh*Ww
|
||||||
attn = attn + relative_position_bias.unsqueeze(0)
|
attn = attn + relative_position_bias.unsqueeze(0)
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
|
|||||||
@ -2,7 +2,6 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_repl
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import torch
|
|
||||||
|
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
@ -50,10 +49,6 @@ class ClipVisionModel():
|
|||||||
self.load_device = comfy.model_management.text_encoder_device()
|
self.load_device = comfy.model_management.text_encoder_device()
|
||||||
offload_device = comfy.model_management.text_encoder_offload_device()
|
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||||
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||||
if self.model_type == "dinov3" and self.dtype == torch.float16:
|
|
||||||
# DINOv3's activations borderline fits fp16, preferring bf16 if available for better stability #TODO: further fp16 tests in practice
|
|
||||||
if comfy.model_management.should_use_bf16(self.load_device, prioritize_performance=True):
|
|
||||||
self.dtype = torch.bfloat16
|
|
||||||
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
|
|||||||
@ -166,9 +166,8 @@ class DINOv3ViTEmbeddings(nn.Module):
|
|||||||
|
|
||||||
def forward(self, pixel_values, bool_masked_pos=None):
|
def forward(self, pixel_values, bool_masked_pos=None):
|
||||||
batch_size = pixel_values.shape[0]
|
batch_size = pixel_values.shape[0]
|
||||||
target_dtype = self.patch_embeddings.weight.dtype
|
|
||||||
|
|
||||||
patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
|
patch_embeddings = self.patch_embeddings(pixel_values)
|
||||||
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
|
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
if bool_masked_pos is not None:
|
if bool_masked_pos is not None:
|
||||||
@ -244,7 +243,6 @@ class DINOv3ViTModel(nn.Module):
|
|||||||
return self.embeddings.patch_embeddings
|
return self.embeddings.patch_embeddings
|
||||||
|
|
||||||
def forward(self, pixel_values, bool_masked_pos=None, **kwargs):
|
def forward(self, pixel_values, bool_masked_pos=None, **kwargs):
|
||||||
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
|
|
||||||
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||||
position_embeddings = self.rope_embeddings(pixel_values)
|
position_embeddings = self.rope_embeddings(pixel_values)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user