mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-03 04:47:29 +08:00
Do tripo dinov3 inference in fp32. (#14221)
This commit is contained in:
parent
06b710aa68
commit
4b48535a7d
@ -3,6 +3,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import comfy.ops
|
||||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
|
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
|
||||||
|
|
||||||
@ -171,11 +172,11 @@ class DINOv3ViTEmbeddings(nn.Module):
|
|||||||
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
|
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
if bool_masked_pos is not None:
|
if bool_masked_pos is not None:
|
||||||
mask_token = self.mask_token.to(patch_embeddings.dtype)
|
mask_token = comfy.ops.cast_to_input(self.mask_token, patch_embeddings)
|
||||||
patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
|
patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
|
||||||
|
|
||||||
cls_token = self.cls_token.expand(batch_size, -1, -1).to(patch_embeddings.device)
|
cls_token = comfy.ops.cast_to_input(self.cls_token.expand(batch_size, -1, -1), patch_embeddings)
|
||||||
register_tokens = self.register_tokens.expand(batch_size, -1, -1).to(patch_embeddings.device)
|
register_tokens = comfy.ops.cast_to_input(self.register_tokens.expand(batch_size, -1, -1), patch_embeddings)
|
||||||
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
|
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
|||||||
@ -115,12 +115,11 @@ class TripoSplatConditioning(IO.ComfyNode):
|
|||||||
# feature1: DINOv3 token sequence (cls + registers + patches), ImageNet-normalized, with a final non-affine layer norm on top
|
# feature1: DINOv3 token sequence (cls + registers + patches), ImageNet-normalized, with a final non-affine layer norm on top
|
||||||
comfy.model_management.load_model_gpu(clip_vision.patcher)
|
comfy.model_management.load_model_gpu(clip_vision.patcher)
|
||||||
device = clip_vision.load_device
|
device = clip_vision.load_device
|
||||||
model_dtype = next(clip_vision.model.parameters()).dtype
|
|
||||||
img = image.movedim(-1, 1).to(device) # (B,3,H,W) in [0,1]
|
img = image.movedim(-1, 1).to(device) # (B,3,H,W) in [0,1]
|
||||||
mean = torch.tensor(_DINOV3_MEAN, device=device).view(1, 3, 1, 1)
|
mean = torch.tensor(_DINOV3_MEAN, device=device).view(1, 3, 1, 1)
|
||||||
std = torch.tensor(_DINOV3_STD, device=device).view(1, 3, 1, 1)
|
std = torch.tensor(_DINOV3_STD, device=device).view(1, 3, 1, 1)
|
||||||
img = (img - mean) / std
|
img = (img - mean) / std
|
||||||
seq = clip_vision.model(pixel_values=img.to(model_dtype))[0]
|
seq = clip_vision.model(pixel_values=img.float())[0]
|
||||||
feature1 = F.layer_norm(seq.float(), seq.shape[-1:]).to(comfy.model_management.intermediate_device())
|
feature1 = F.layer_norm(seq.float(), seq.shape[-1:]).to(comfy.model_management.intermediate_device())
|
||||||
|
|
||||||
# Second conditioning: the Flux2 VAE latent of the image, carried as a standard reference_latents entry
|
# Second conditioning: the Flux2 VAE latent of the image, carried as a standard reference_latents entry
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user