multiple fixes

This commit is contained in:
Yousef Rafat 2026-02-09 22:47:50 +02:00
parent 704e1b5462
commit 2eef826def
5 changed files with 60 additions and 30 deletions

View File

@ -9,6 +9,7 @@ import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
import comfy.image_encoders.dino3
class Output:
def __getitem__(self, key):
@ -23,6 +24,7 @@ IMAGE_ENCODERS = {
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
"dinov3": comfy.image_encoders.dino3.DINOv3ViTModel
}
class ClipVisionModel():
@ -134,6 +136,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
elif 'layer.9.attention.o_proj.bias' in sd: # dinov3
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino3_large.json")
else:
return None

View File

@ -4,7 +4,19 @@ import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.ldm.flux.math import apply_rope
from dino2 import Dinov2MLP as DINOv3ViTMLP, LayerScale as DINOv3ViTLayerScale
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
class DINOv3ViTMLP(nn.Module):
    """Plain two-layer feed-forward block: up-projection -> GELU -> down-projection.

    Args:
        hidden_size: model width (input and output feature size).
        intermediate_size: width of the expanded hidden layer.
        mlp_bias: whether both Linear layers carry a bias term.
        device/dtype: forwarded to the Linear constructors.
        operations: namespace providing ``Linear`` (e.g. a casting variant of
            ``torch.nn``) so weight creation can be customized by the caller.
    """

    def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
        self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
        self.act_fn = torch.nn.GELU()

    def forward(self, x):
        # (..., hidden_size) -> (..., intermediate_size) -> (..., hidden_size)
        return self.down_proj(self.act_fn(self.up_proj(x)))
class DINOv3ViTAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
@ -90,6 +102,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
self.head_dim = hidden_size // num_attention_heads
self.num_patches_h = image_size // patch_size
self.num_patches_w = image_size // patch_size
self.patch_size = patch_size
inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
@ -106,6 +119,7 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
num_patches_h, num_patches_w, dtype=torch.float32, device=device
)
self.inv_freq = self.inv_freq.to(device)
angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
angles = angles.flatten(1, 2)
angles = angles.tile(2)
@ -140,27 +154,30 @@ class DINOv3ViTEmbeddings(nn.Module):
cls_token = self.cls_token.expand(batch_size, -1, -1)
register_tokens = self.register_tokens.expand(batch_size, -1, -1)
device = patch_embeddings
cls_token = cls_token.to(device)
register_tokens = register_tokens.to(device)
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
return embeddings
class DINOv3ViTLayer(nn.Module):
def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size, num_attention_heads,
             device, dtype, operations):
    """Build one DINOv3 ViT encoder layer: pre-norm attention and an MLP,
    each followed by a per-channel LayerScale.

    ``use_gated_mlp`` selects the gated-MLP variant; otherwise the plain
    up/GELU/down MLP is used with ``intermediate_size`` and ``mlp_bias``.
    """
    super().__init__()
    # NOTE(review): pass device/dtype here too — the original built norm1
    # without them, inconsistent with norm2 below.
    self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
    self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
    self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
    self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
    if use_gated_mlp:
        self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations)
    else:
        self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations)
    self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
def forward(
self,
@ -188,7 +205,7 @@ class DINOv3ViTLayer(nn.Module):
class DINOv3ViTModel(nn.Module):
def __init__(self, config, device, dtype, operations):
def __init__(self, config, dtype, device, operations):
super().__init__()
num_hidden_layers = config["num_hidden_layers"]
hidden_size = config["hidden_size"]
@ -196,7 +213,6 @@ class DINOv3ViTModel(nn.Module):
num_register_tokens = config["num_register_tokens"]
intermediate_size = config["intermediate_size"]
layer_norm_eps = config["layer_norm_eps"]
layerscale_value = config["layerscale_value"]
num_channels = config["num_channels"]
patch_size = config["patch_size"]
rope_theta = config["rope_theta"]
@ -208,7 +224,7 @@ class DINOv3ViTModel(nn.Module):
rope_theta, hidden_size, num_attention_heads, image_size=512, patch_size=patch_size, dtype=dtype, device=device
)
self.layer = nn.ModuleList(
[DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, layerscale_value=layerscale_value, mlp_bias=True,
[DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, mlp_bias=True,
intermediate_size=intermediate_size,num_attention_heads = num_attention_heads,
dtype=dtype, device=device, operations=operations)
for _ in range(num_hidden_layers)])

View File

@ -1,16 +1,15 @@
{
"hidden_size": 384,
"model_type": "dinov3",
"hidden_size": 1024,
"image_size": 224,
"initializer_range": 0.02,
"intermediate_size": 1536,
"intermediate_size": 4096,
"key_bias": false,
"layer_norm_eps": 1e-05,
"layerscale_value": 1.0,
"mlp_bias": true,
"num_attention_heads": 6,
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 12,
"num_hidden_layers": 24,
"num_register_tokens": 4,
"patch_size": 16,
"pos_embed_rescale": 2.0,

View File

@ -1251,12 +1251,18 @@ class Trellis2(supported_models_base.BASE):
"shift": 3.0,
}
memory_usage_factor = 3.5
latent_format = latent_formats.Trellis2
vae_key_prefix = ["vae."]
clip_vision_prefix = "conditioner.main_image_encoder.model."
def get_model(self, state_dict, prefix="", device=None):
    """Instantiate the Trellis2 base model for this config.

    ``state_dict`` and ``prefix`` are accepted for interface compatibility
    with the other supported-model classes but are not used here.
    """
    return model_base.Trellis2(self, device=device)
def clip_target(self, state_dict={}):
    """Always return None: this model class declares no text-encoder (CLIP) target.

    The mutable ``{}`` default is kept for signature compatibility with the
    sibling model classes; it is never mutated here.
    """
    return None
class Hunyuan3Dv2(supported_models_base.BASE):
unet_config = {
"image_model": "hunyuan3d2",

View File

@ -3,10 +3,8 @@ from comfy_api.latest import ComfyExtension, IO
import torch
from comfy.ldm.trellis2.model import SparseTensor
import comfy.model_management
from PIL import Image
import PIL
import numpy as np
from comfy.nested_tensor import NestedTensor
from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode
shape_slat_normalization = {
"mean": torch.tensor([
@ -76,23 +74,30 @@ def run_conditioning(
# Convert image to PIL
if image.dim() == 4:
pil_image = (image[0] * 255).clip(0, 255).astype(torch.uint8)
pil_image = (image[0] * 255).clip(0, 255).to(torch.uint8)
else:
pil_image = (image * 255).clip(0, 255).astype(torch.uint8)
pil_image = (image * 255).clip(0, 255).to(torch.uint8)
pil_image = pil_image.movedim(-1, 0)
pil_image = smart_crop_square(pil_image, background_color=bg_color)
model.image_size = 512
def set_image_size(image, image_size=512):
image = PIL.from_array(image)
image = [i.resize((image_size, image_size), Image.LANCZOS) for i in image]
image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
image = torch.stack(image).to(torch_device)
return image
if image.ndim == 3:
image = image.unsqueeze(0)
pil_image = set_image_size(image, 512)
cond_512 = model([pil_image])
to_pil = ToPILImage()
to_tensor = ToTensor()
resizer = Resize((image_size, image_size), interpolation=InterpolationMode.LANCZOS)
pil_img = to_pil(image.squeeze(0))
resized_pil = resizer(pil_img)
image = to_tensor(resized_pil).unsqueeze(0)
return image.to(torch_device).float()
pil_image = set_image_size(pil_image, 512)
cond_512 = model(pil_image)
cond_1024 = None
if include_1024:
@ -267,7 +272,7 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
node_id="EmptyStructureLatentTrellis2",
category="latent/3d",
inputs=[
IO.Int.Input("resolution", default=3072, min=1, max=8192),
IO.Int.Input("resolution", default=256, min=1, max=8192),
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
],
outputs=[
@ -275,9 +280,9 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
]
)
@classmethod
def execute(cls, resolution, batch_size):
    """Create a batch of random Trellis2 structure latents.

    Returns a NodeOutput whose "samples" entry is a NestedTensor wrapping one
    gaussian tensor of shape (batch_size, 32, resolution, resolution, resolution).
    """
    in_channels = 32  # channel count of the Trellis2 structure latent space
    latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
    latent = NestedTensor([latent])
    return IO.NodeOutput({"samples": latent, "type": "trellis2"})