From 2eef826def23599d395bb1854d6b95cb657c8385 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 9 Feb 2026 22:47:50 +0200 Subject: [PATCH] multiple fixes --- comfy/clip_vision.py | 4 +++ comfy/image_encoders/dino3.py | 32 +++++++++++++++++------ comfy/image_encoders/dino3_large.json | 11 ++++---- comfy/supported_models.py | 6 +++++ comfy_extras/nodes_trellis2.py | 37 +++++++++++++++------------ 5 files changed, 60 insertions(+), 30 deletions(-) diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 1691fca81..71f2200b7 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -9,6 +9,7 @@ import comfy.model_management import comfy.utils import comfy.clip_model import comfy.image_encoders.dino2 +import comfy.image_encoders.dino3 class Output: def __getitem__(self, key): @@ -23,6 +24,7 @@ IMAGE_ENCODERS = { "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection, "dinov2": comfy.image_encoders.dino2.Dinov2Model, + "dinov3": comfy.image_encoders.dino3.DINOv3ViTModel } class ClipVisionModel(): @@ -134,6 +136,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") elif 'encoder.layer.23.layer_scale2.lambda1' in sd: json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json") + elif 'layer.9.attention.o_proj.bias' in sd: # dinov3 + json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino3_large.json") else: return None diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index d07c2c5b8..b27b95b5f 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -4,7 +4,19 @@ import torch.nn as nn from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.ldm.flux.math import apply_rope -from dino2 import Dinov2MLP as DINOv3ViTMLP, LayerScale as DINOv3ViTLayerScale +from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale + +class DINOv3ViTMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype) + self.act_fn = torch.nn.GELU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) class DINOv3ViTAttention(nn.Module): def __init__(self, hidden_size, num_attention_heads, device, dtype, operations): @@ -90,6 +102,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module): self.head_dim = hidden_size // num_attention_heads self.num_patches_h = image_size // patch_size self.num_patches_w = image_size // patch_size + self.patch_size = patch_size inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device) self.register_buffer("inv_freq", inv_freq, persistent=False) @@ -106,6 +119,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module): num_patches_h, num_patches_w, dtype=torch.float32, device=device ) + self.inv_freq = self.inv_freq.to(device) angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] angles = angles.flatten(1, 2) angles = angles.tile(2) @@ -140,27 +154,30 @@ class DINOv3ViTEmbeddings(nn.Module): cls_token = self.cls_token.expand(batch_size, -1, -1) register_tokens = self.register_tokens.expand(batch_size, -1, -1) + device = patch_embeddings + cls_token = cls_token.to(device) + register_tokens = register_tokens.to(device) embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) return embeddings class DINOv3ViTLayer(nn.Module): - def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, layerscale_value, mlp_bias, intermediate_size, num_attention_heads, + def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size, num_attention_heads, device, dtype, operations): super().__init__() self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps) self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations) - self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, layerscale_value, device=device, dtype=dtype) + self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None) self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) if use_gated_mlp: self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations) else: - self.mlp = DINOv3ViTMLP(hidden_size, device=device, dtype=dtype, operations=operations) - self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, layerscale_value, device=device, dtype=dtype) + self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations) + self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None) def forward( self, @@ -188,7 +205,7 @@ class DINOv3ViTLayer(nn.Module): class DINOv3ViTModel(nn.Module): - def __init__(self, config, device, dtype, operations): + def __init__(self, config, dtype, device, operations): super().__init__() num_hidden_layers = config["num_hidden_layers"] hidden_size = config["hidden_size"] @@ -196,7 +213,6 @@ class DINOv3ViTModel(nn.Module): num_register_tokens = config["num_register_tokens"] intermediate_size = config["intermediate_size"] layer_norm_eps = config["layer_norm_eps"] - layerscale_value = config["layerscale_value"] num_channels = config["num_channels"] patch_size = config["patch_size"] rope_theta = config["rope_theta"] @@ -208,7 +224,7 @@ class DINOv3ViTModel(nn.Module): rope_theta, hidden_size, num_attention_heads, image_size=512, patch_size=patch_size, dtype=dtype, device=device ) self.layer = nn.ModuleList( - [DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, layerscale_value=layerscale_value, mlp_bias=True, + [DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, mlp_bias=True, intermediate_size=intermediate_size,num_attention_heads = num_attention_heads, dtype=dtype, device=device, operations=operations) for _ in range(num_hidden_layers)]) diff --git a/comfy/image_encoders/dino3_large.json b/comfy/image_encoders/dino3_large.json index 96263f0d6..53f761a25 100644 --- a/comfy/image_encoders/dino3_large.json +++ b/comfy/image_encoders/dino3_large.json @@ -1,16 +1,15 @@ { - - "hidden_size": 384, + "model_type": "dinov3", + "hidden_size": 1024, "image_size": 224, "initializer_range": 0.02, - "intermediate_size": 1536, + "intermediate_size": 4096, "key_bias": false, "layer_norm_eps": 1e-05, - "layerscale_value": 1.0, "mlp_bias": true, - "num_attention_heads": 6, + "num_attention_heads": 16, "num_channels": 3, - "num_hidden_layers": 12, + "num_hidden_layers": 24, "num_register_tokens": 4, "patch_size": 16, "pos_embed_rescale": 2.0, diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 9e2f17149..3373f78a2 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1251,12 +1251,18 @@ class Trellis2(supported_models_base.BASE): "shift": 3.0, } + memory_usage_factor = 3.5 + latent_format = latent_formats.Trellis2 vae_key_prefix = ["vae."] + clip_vision_prefix = "conditioner.main_image_encoder.model." def get_model(self, state_dict, prefix="", device=None): return model_base.Trellis2(self, device=device) + def clip_target(self, state_dict={}): + return None + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 17ba94ec8..f53d36736 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -3,10 +3,8 @@ from comfy_api.latest import ComfyExtension, IO import torch from comfy.ldm.trellis2.model import SparseTensor import comfy.model_management -from PIL import Image -import PIL -import numpy as np from comfy.nested_tensor import NestedTensor +from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode shape_slat_normalization = { "mean": torch.tensor([ @@ -76,23 +74,30 @@ def run_conditioning( # Convert image to PIL if image.dim() == 4: - pil_image = (image[0] * 255).clip(0, 255).astype(torch.uint8) + pil_image = (image[0] * 255).clip(0, 255).to(torch.uint8) else: - pil_image = (image * 255).clip(0, 255).astype(torch.uint8) + pil_image = (image * 255).clip(0, 255).to(torch.uint8) + pil_image = pil_image.movedim(-1, 0) pil_image = smart_crop_square(pil_image, background_color=bg_color) model.image_size = 512 def set_image_size(image, image_size=512): - image = PIL.from_array(image) - image = [i.resize((image_size, image_size), Image.LANCZOS) for i in image] - image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image] - image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image] - image = torch.stack(image).to(torch_device) - return image + if image.ndim == 3: + image = image.unsqueeze(0) - pil_image = set_image_size(image, 512) - cond_512 = model([pil_image]) + to_pil = ToPILImage() + to_tensor = ToTensor() + resizer = Resize((image_size, image_size), interpolation=InterpolationMode.LANCZOS) + + pil_img = to_pil(image.squeeze(0)) + resized_pil = resizer(pil_img) + image = to_tensor(resized_pil).unsqueeze(0) + + return image.to(torch_device).float() + + pil_image = set_image_size(pil_image, 512) + cond_512 = model(pil_image) cond_1024 = None if include_1024: @@ -267,7 +272,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): node_id="EmptyStructureLatentTrellis2", category="latent/3d", inputs=[ - IO.Int.Input("resolution", default=3072, min=1, max=8192), + IO.Int.Input("resolution", default=256, min=1, max=8192), IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), ], outputs=[ @@ -275,9 +280,9 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): ] ) @classmethod - def execute(cls, res, batch_size): + def execute(cls, resolution, batch_size): in_channels = 32 - latent = torch.randn(batch_size, in_channels, res, res, res) + latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"})