diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 3ec7f8a04..40ece19ed 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -44,9 +44,6 @@ class DINOv3ViTAttention(nn.Module): self.num_heads = num_attention_heads self.head_dim = self.embed_dim // self.num_heads - self.scaling = self.head_dim**-0.5 - self.is_causal = False - self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) @@ -251,7 +248,7 @@ class DINOv3ViTModel(nn.Module): intermediate_size=intermediate_size,num_attention_heads = num_attention_heads, dtype=dtype, device=device, operations=operations) for _ in range(num_hidden_layers)]) - self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device) + self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device) def get_input_embeddings(self): return self.embeddings.patch_embeddings diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py index 1be8408c6..8f677ce24 100644 --- a/comfy/ldm/trellis2/cumesh.py +++ b/comfy/ldm/trellis2/cumesh.py @@ -4,7 +4,7 @@ import math import torch from typing import Dict, Callable -NO_TRITION = False +NO_TRITON = False try: allow_tf32 = torch.cuda.is_tf32_supported() except Exception: @@ -115,8 +115,8 @@ try: allow_tf32=allow_tf32, ) return output -except: - NO_TRITION = True +except Exception: + NO_TRITON = True def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device): # offsets in same order as CUDA kernel @@ -364,6 +364,8 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2( def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation): + if NO_TRITON: # TODO + raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.") if len(shape) == 5: N, C, W, H, D = shape else: diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 7f16c4d41..7cc1c1678 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -697,8 +697,6 @@ class SparseStructureFlowModel(nn.Module): def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3) - assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \ - f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}" h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous() @@ -746,7 +744,8 @@ class Trellis2(nn.Module): super().__init__() self.dtype = dtype # for some reason it passes num_heads = -1 - num_heads = 12 + if num_heads == -1: + num_heads = 12 args = { "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, @@ -763,8 +762,10 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): # FIXME: should find a way to distinguish between 512/1024 models # currently assumes 1024 - transformer_options = kwargs.get("transformer_options") + transformer_options = kwargs.get("transformer_options", {}) embeds = kwargs.get("embeds") + if embeds is None: + raise ValueError("Trellis2.forward requires 'embeds' in kwargs") #_, cond = context.chunk(2) # TODO cond = embeds.chunk(2)[0] context = torch.cat([torch.zeros_like(cond), cond]) @@ -807,6 +808,8 @@ class Trellis2(nn.Module): # TODO out = self.img2shape(x, timestep, context) elif mode == "texture_generation": + if self.shape2txt is None: + raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure timestep = timestep_reshift(timestep) diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 57bf78346..c6ea5deb2 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -1522,6 +1522,8 @@ class Vae(nn.Module): return self.shape_dec(slat, return_subs=True) def decode_tex_slat(self, slat, subs): + if self.txt_dec is None: + raise ValueError("Checkpoint doesn't include texture model") return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5 @torch.no_grad() diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 623430b9e..c688c343d 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,9 +1,11 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types -import torch +import torch.nn.functional as TF import comfy.model_management from PIL import Image import numpy as np +import torch +import copy shape_slat_normalization = { "mean": torch.tensor([ @@ -145,11 +147,11 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): def execute(cls, samples, vae, resolution): vae = vae.first_stage_model samples = samples["samples"] - std = shape_slat_normalization["std"] - mean = shape_slat_normalization["mean"] + std = shape_slat_normalization["std"].to(samples) + mean = shape_slat_normalization["mean"].to(samples) samples = samples * std + mean - mesh, subs = vae.decode_shape_slat(resolution, samples) + mesh, subs = vae.decode_shape_slat(samples, resolution) return IO.NodeOutput(mesh, subs) class VaeDecodeTextureTrellis(IO.ComfyNode): @@ -172,8 +174,8 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): def execute(cls, samples, vae, shape_subs): vae = vae.first_stage_model samples = samples["samples"] - std = tex_slat_normalization["std"] - mean = tex_slat_normalization["mean"] + std = tex_slat_normalization["std"].to(samples) + mean = tex_slat_normalization["mean"].to(samples) samples = samples * std + mean mesh = vae.decode_tex_slat(samples, shape_subs) @@ -239,6 +241,8 @@ class Trellis2Conditioning(IO.ComfyNode): scale = min(1, 1024 / max_size) if scale < 1: image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS) + new_h, new_w = int(mask.shape[-2] * scale), int(mask.shape[-1] * scale) + mask = TF.interpolate(mask.unsqueeze(0).float(), size=(new_h, new_w), mode='nearest').squeeze(0) image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255 @@ -510,6 +514,7 @@ class PostProcessMesh(IO.ComfyNode): ) @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): + mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces if fill_holes_perimeter != 0.0: