diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 68f2d7989..4d5ce024f 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -163,6 +163,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): naf_sd = {k[len("naf."):]: sd.pop(k) for k in naf_keys} naf = NAF().eval() naf.load_state_dict(naf_sd, strict=False) + naf.to(comfy.model_management.text_encoder_dtype(clip.load_device)) clip.naf = comfy.model_patcher.CoreModelPatcher(naf, load_device=clip.load_device, offload_device=comfy.model_management.text_encoder_offload_device()) return clip diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 2beb389e7..4a0deeacc 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -55,10 +55,12 @@ class SparseFeedForwardNet(nn.Module): def forward(self, x: VarLenTensor) -> VarLenTensor: return self.mlp(x) -class SparseMultiHeadRMSNorm(nn.Module): - def __init__(self, dim: int, heads: int, device, dtype): +class MultiHeadRMSNorm(nn.Module): + # Per-head qk-norm for both sparse (VarLenTensor) and dense inputs. gamma is [heads, dim] + # (per-head), so it's a broadcast multiply rather than F.rms_norm's 1-D weight + def __init__(self, dim: int, heads: int, device=None, dtype=None): super().__init__() - self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) + self.gamma = nn.Parameter(torch.empty(heads, dim, device=device, dtype=dtype)) def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: if isinstance(x, VarLenTensor): @@ -147,8 +149,8 @@ class SparseMultiHeadAttention(nn.Module): self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype) if self.qk_rms_norm: - self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) - self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype) @@ -307,7 +309,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module): operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype) ) else: - self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) + self.modulation = nn.Parameter(torch.empty(6 * channels, device=device, dtype=dtype)) def _forward(self, x: SparseTensor, mod: torch.Tensor, context, transformer_options=None) -> SparseTensor: if self.share_mod: @@ -444,24 +446,6 @@ class FeedForwardNet(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: return self.mlp(x) -# class MultiHeadRMSNorm(nn.Module): -# def __init__(self, dim: int, heads: int, device=None, dtype=None): -# super().__init__() -# self.scale = dim ** 0.5 -# self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) - -# def forward(self, x: torch.Tensor) -> torch.Tensor: -# return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) - -class MultiHeadRMSNorm(nn.Module): - def __init__(self, dim: int, heads: int, device=None, dtype=None): - super().__init__() - self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return (F.rms_norm(x.float(), (x.shape[-1],)) * self.gamma).to(x.dtype) - - class MultiHeadAttention(nn.Module): def __init__( self, @@ -580,7 +564,7 @@ class ModulatedTransformerCrossBlock(nn.Module): if not share_mod: self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device)) else: - self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) + self.modulation = nn.Parameter(torch.empty(6 * channels, device=device, dtype=dtype)) def _forward(self, x: torch.Tensor, mod: torch.Tensor, context, phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor: @@ -631,7 +615,7 @@ class SparseStructureFlowModel(nn.Module): proj_in_channels: Optional[int] = None, operations=None, device = None, - dtype = torch.float32, + dtype = None, **kwargs ): super().__init__() diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 7f765fc6e..d43776ce2 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -918,17 +918,14 @@ def flexible_dual_grid_to_mesh( ): device = coords.device - if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset") \ - or flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset.device != device: - flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([ - [[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis - [[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis - [[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis - ], dtype=torch.int, device=device).unsqueeze(0) - if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1") or flexible_dual_grid_to_mesh.quad_split_1.device != device: - flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) - if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2") or flexible_dual_grid_to_mesh.quad_split_2.device != device: - flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + # Small constant index tables — built per call (stateless), not cached on the function. + edge_neighbor_voxel_offset = torch.tensor([ + [[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis + [[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis + [[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis + ], dtype=torch.int, device=device).unsqueeze(0) + quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device) + quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device) aabb = torch.tensor(aabb, dtype=torch.float32, device=device) @@ -954,7 +951,7 @@ def flexible_dual_grid_to_mesh( # Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3] n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,) - offsets_per_axis = flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset[0] # (3, 4, 3) + offsets_per_axis = edge_neighbor_voxel_offset[0] # (3, 4, 3) connected_voxel = coords[n_idx].unsqueeze(1) + offsets_per_axis[axis_idx] # (M, 4, 3) M = connected_voxel.shape[0] # flatten connected voxel coords and lookup. In-place to avoid extra memory allocation. @@ -973,12 +970,12 @@ def flexible_dual_grid_to_mesh( mesh_vertices.add_(dual_vertices).mul_(voxel_size).add_(aabb[0].reshape(1, 3)) if split_weight is None: # if split 1 - atempt_triangles_0 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1] + atempt_triangles_0 = quad_indices[:, quad_split_1] normals0 = torch.cross(mesh_vertices[atempt_triangles_0[:, 1]] - mesh_vertices[atempt_triangles_0[:, 0]], mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 0]]) normals1 = torch.cross(mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 1]], mesh_vertices[atempt_triangles_0[:, 3]] - mesh_vertices[atempt_triangles_0[:, 1]]) align0 = (normals0 * normals1).sum(dim=1, keepdim=True).abs() # if split 2 - atempt_triangles_1 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2] + atempt_triangles_1 = quad_indices[:, quad_split_2] normals0 = torch.cross(mesh_vertices[atempt_triangles_1[:, 1]] - mesh_vertices[atempt_triangles_1[:, 0]], mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 0]]) normals1 = torch.cross(mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 1]], mesh_vertices[atempt_triangles_1[:, 3]] - mesh_vertices[atempt_triangles_1[:, 1]]) align1 = (normals0 * normals1).sum(dim=1, keepdim=True).abs() @@ -990,8 +987,8 @@ def flexible_dual_grid_to_mesh( split_weight_ws_13 = split_weight_ws[:, 1] * split_weight_ws[:, 3] mesh_triangles = torch.where( split_weight_ws_02 > split_weight_ws_13, - quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1], - quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2] + quad_indices[:, quad_split_1], + quad_indices[:, quad_split_2] ).reshape(-1, 3) return mesh_vertices, mesh_triangles diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py index 22e3df3ac..64f49a97b 100644 --- a/comfy_api/latest/_util/geometry_types.py +++ b/comfy_api/latest/_util/geometry_types.py @@ -39,7 +39,12 @@ class MESH: vertex_counts: torch.Tensor | None = None, face_counts: torch.Tensor | None = None, unlit: bool = False, - normals: torch.Tensor | None = None): + normals: torch.Tensor | None = None, + tangents: torch.Tensor | None = None, + normal_map: torch.Tensor | None = None, + occlusion_in_mr: bool = False, + material: dict | None = None, + emissive: torch.Tensor | None = None): assert (vertex_counts is None) == (face_counts is None), \ "vertex_counts and face_counts must be provided together (both or neither)" @@ -59,6 +64,13 @@ class MESH: self.face_counts = face_counts # Render flat / emissive (no scene lighting) when saved, e.g. for gaussian-splat-derived meshes. self.unlit = unlit + # Extra maps / material overrides attached by bake, normal/AO, and SetMeshMaterial nodes; + # consumed by SaveGLB. Declared here (with defaults) so consumers read them directly. + self.tangents = tangents # (B, N, 4) per-vertex tangents for normal mapping + self.normal_map = normal_map # tangent-space normal map: (B, H, W, 3) + self.occlusion_in_mr = occlusion_in_mr # True = R channel of metallic_roughness holds AO (ORM) + self.material = material # SetMeshMaterial scalar/factor overrides + self.emissive = emissive # emissive map: (B, H, W, 3) class File3D: diff --git a/comfy_extras/mesh3d/postprocess/qem_decimate.py b/comfy_extras/mesh3d/postprocess/qem_decimate.py index eae272ef5..34b7c77fe 100644 --- a/comfy_extras/mesh3d/postprocess/qem_decimate.py +++ b/comfy_extras/mesh3d/postprocess/qem_decimate.py @@ -1282,25 +1282,17 @@ def qem_simplify( iteration = 0 total_collapses = 0 - # progress bars (tqdm + optional comfy ProgressBar), best-effort + # progress bars (tqdm + comfy ProgressBar) _start_faces = num_faces _prog_total = max(1, _start_faces - int(target_faces)) - try: - _qtq = _tqdm(total=100, desc="QEM simplify", leave=False) - except Exception: - _qtq = None - try: - _qpbar = _comfy_utils.ProgressBar(100) - except Exception: - _qpbar = None + _qtq = _tqdm(total=100, desc="QEM simplify", leave=False) + _qpbar = _comfy_utils.ProgressBar(100) def _qreport(): pct = min(100, max(0, int(100 * (_start_faces - py_n_faces) / _prog_total))) - if _qtq is not None: - _qtq.n = pct - _qtq.refresh() - if _qpbar is not None: - _qpbar.update_absolute(pct, 100) + _qtq.n = pct + _qtq.refresh() + _qpbar.update_absolute(pct, 100) while True: if py_n_faces <= target_faces: @@ -1523,8 +1515,7 @@ def qem_simplify( break _qreport() - if _qtq is not None: - _qtq.close() + _qtq.close() # finalize: compact verts and faces final_v = verts[v_alive] diff --git a/comfy_extras/mesh3d/postprocess/remesh.py b/comfy_extras/mesh3d/postprocess/remesh.py index 4cbb867fe..fc09b05ec 100644 --- a/comfy_extras/mesh3d/postprocess/remesh.py +++ b/comfy_extras/mesh3d/postprocess/remesh.py @@ -522,7 +522,6 @@ def _build_mdc_lut() -> Tuple[torch.Tensor, torch.Tensor]: return K, group -@functools.lru_cache(maxsize=None) def _mdc_lut(device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: K, g = _build_mdc_lut() return K.to(device), g.to(device) @@ -967,15 +966,11 @@ def remesh_narrow_band_dc( n_levels += 1 _total_ticks = n_levels + 3 + int(smooth_iters) _pbar = comfy.utils.ProgressBar(_total_ticks) - try: - _tq = _tqdm(total=_total_ticks, desc="Remesh DC", leave=False) - except Exception: - _tq = None + _tq = _tqdm(total=_total_ticks, desc="Remesh DC", leave=False) def tick(): _pbar.update(1) - if _tq is not None: - _tq.update(1) + _tq.update(1) # Step 1: sparse narrow-band voxel grid (coarse-to-fine) voxel_coords, _band_tree = _build_narrow_band_voxels( diff --git a/comfy_extras/mesh3d/uv_unwrap/parameterize.py b/comfy_extras/mesh3d/uv_unwrap/parameterize.py index 966eaf9e3..8b59575a0 100644 --- a/comfy_extras/mesh3d/uv_unwrap/parameterize.py +++ b/comfy_extras/mesh3d/uv_unwrap/parameterize.py @@ -366,8 +366,8 @@ def lscm_chart( x_free = solve_least_squares(A_free, b) if not np.all(np.isfinite(x_free)): fallback_to_ortho = True - except Exception: - fallback_to_ortho = True + except (sp.linalg.MatrixRankWarning, RuntimeError): + fallback_to_ortho = True # singular / under-constrained system if fallback_to_ortho: if pin_positions is not None and pin_positions.shape == (Vc, 2): diff --git a/comfy_extras/nodes_save_3d.py b/comfy_extras/nodes_save_3d.py index 47d4664b3..a7d307163 100644 --- a/comfy_extras/nodes_save_3d.py +++ b/comfy_extras/nodes_save_3d.py @@ -74,31 +74,22 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non ) packed_tangents[i, :tn.shape[0]] = tn - out = Types.MESH(packed_vertices, packed_faces, - uvs=packed_uvs, vertex_colors=packed_colors, texture=texture, - metallic_roughness=metallic_roughness, - vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit, - normals=packed_normals) - if packed_tangents is not None: - out.tangents = packed_tangents - if normal_map is not None: - out.normal_map = normal_map - if occlusion_in_mr: - out.occlusion_in_mr = True - if material is not None: - out.material = material - if emissive is not None: - out.emissive = emissive - return out + return Types.MESH(packed_vertices, packed_faces, + uvs=packed_uvs, vertex_colors=packed_colors, texture=texture, + metallic_roughness=metallic_roughness, + vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit, + normals=packed_normals, tangents=packed_tangents, + normal_map=normal_map, occlusion_in_mr=occlusion_in_mr, + material=material, emissive=emissive) def get_mesh_batch_item(mesh, index): # Returns (vertices, faces, colors, uvs) for batch index, slicing to real lengths # if the mesh carries per-item counts (variable-size batch). - v_colors = getattr(mesh, "vertex_colors", None) - v_uvs = getattr(mesh, "uvs", None) - v_normals = getattr(mesh, "normals", None) - if getattr(mesh, "vertex_counts", None) is not None: + v_colors = mesh.vertex_colors + v_uvs = mesh.uvs + v_normals = mesh.normals + if mesh.vertex_counts is not None: vertex_count = int(mesh.vertex_counts[index].item()) face_count = int(mesh.face_counts[index].item()) vertices = mesh.vertices[index, :vertex_count] @@ -482,6 +473,7 @@ def save_glb(vertices, faces, filepath=None, metadata=None, if (texture_png_bytes is not None and has_uv) or "COLOR_0" in primitive_attributes: pbr["baseColorFactor"] = [1.0, 1.0, 1.0, 1.0] + pbr["roughnessFactor"] = 1.0 if mr_png_bytes is not None and has_uv: mr_texture_index = add_image_texture(mr_byte_offset, len(mr_buffer)) @@ -599,7 +591,7 @@ def mesh_item_to_glb_bytes(mesh, index, metadata=None): assert a.ndim == 3 and a.shape[-1] == 3, f"{attr} must be (B, H, W, 3), got {tuple(t.shape)}" return Image.fromarray(a, mode="RGB") - tangents_b = getattr(mesh, "tangents", None) + tangents_b = mesh.tangents tangents_i = tangents_b[index, :vertices_i.shape[0]] if tangents_b is not None else None return save_glb( vertices_i, faces_i, None, metadata, @@ -607,12 +599,12 @@ def mesh_item_to_glb_bytes(mesh, index, metadata=None): vertex_colors=v_colors, texture_image=_img("texture"), metallic_roughness_image=_img("metallic_roughness"), - unlit=getattr(mesh, "unlit", False), + unlit=mesh.unlit, normals=normals_i, normal_map_image=_img("normal_map"), tangents=tangents_i, - occlusion_in_mr=getattr(mesh, "occlusion_in_mr", False), - material=getattr(mesh, "material", None), + occlusion_in_mr=mesh.occlusion_in_mr, + material=mesh.material, emissive_image=_img("emissive"), ) @@ -810,7 +802,7 @@ class RotateMesh(IO.ComfyNode): else: out.vertices = rotate(mesh.vertices) # Normals are directions; rotate them too (R is orthogonal) so they stay valid. - nrm = getattr(mesh, "normals", None) + nrm = mesh.normals if nrm is not None: out.normals = [rotate(n) for n in nrm] if isinstance(nrm, list) else rotate(nrm) return IO.NodeOutput(out) @@ -860,7 +852,7 @@ class MeshSmoothNormals(IO.ComfyNode): return IO.NodeOutput(out) # Crease split changes per-item vertex counts -> rebuild as a variable-size batch. - tangents_b = getattr(mesh, "tangents", None) + tangents_b = mesh.tangents v_list, f_list, n_list = [], [], [] c_list = [] if mesh.vertex_colors is not None else None u_list = [] if mesh.uvs is not None else None @@ -889,11 +881,11 @@ class MeshSmoothNormals(IO.ComfyNode): return IO.NodeOutput(mesh) out = pack_variable_mesh_batch( v_list, f_list, colors=c_list, uvs=u_list, - texture=mesh.texture, unlit=getattr(mesh, "unlit", False), - normals=n_list, metallic_roughness=getattr(mesh, "metallic_roughness", None), - tangents=t_list, normal_map=getattr(mesh, "normal_map", None), - occlusion_in_mr=getattr(mesh, "occlusion_in_mr", False), - material=getattr(mesh, "material", None), emissive=getattr(mesh, "emissive", None)) + texture=mesh.texture, unlit=mesh.unlit, + normals=n_list, metallic_roughness=mesh.metallic_roughness, + tangents=t_list, normal_map=mesh.normal_map, + occlusion_in_mr=mesh.occlusion_in_mr, + material=mesh.material, emissive=mesh.emissive) return IO.NodeOutput(out) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 2fad6cb81..3a77c475c 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -8,9 +8,7 @@ from server import PromptServer import comfy.latent_formats import comfy.model_management import comfy.utils -from PIL import Image import logging -import numpy as np import math import torch @@ -425,7 +423,6 @@ def _dinov3_encode(model, image_bchw, image_size, want_patches=False): mean = torch.tensor(model.image_mean or [0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1) std = torch.tensor(model.image_std or [0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1) img_t = (img_t - mean) / std - model_internal.image_size = image_size tokens = model_internal(img_t, skip_norm_elementwise=True)[0] if not want_patches: return tokens @@ -434,20 +431,6 @@ def _dinov3_encode(model, image_bchw, image_size, want_patches=False): return {"tokens": tokens[:, :1 + n_reg], "patches_2d": _dinov3_patches_to_2d(tokens, image_size)} -def run_conditioning(model, cropped_pil_img): - device = comfy.model_management.intermediate_device() - - img_np = np.array(cropped_pil_img).astype(np.float32) / 255.0 - image_bchw = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).contiguous() - - cond_512 = _dinov3_encode(model, image_bchw, 512) - cond_1024 = _dinov3_encode(model, image_bchw, 1024) - return { - "cond_512": cond_512.to(device), - "neg_cond": torch.zeros_like(cond_512).to(device), - "cond_1024": cond_1024.to(device), - } - class Trellis2Conditioning(IO.ComfyNode): @classmethod def define_schema(cls): @@ -467,124 +450,10 @@ class Trellis2Conditioning(IO.ComfyNode): @classmethod def execute(cls, clip_vision_model, image, mask) -> IO.NodeOutput: - # Normalize to batched form so per-image conditioning loop below is uniform. - if image.ndim == 3: - image = image.unsqueeze(0) - elif image.ndim == 4: - if image.shape[1] in [1, 3, 4] and image.shape[-1] not in [1, 3, 4]: - image = image.permute(0, 2, 3, 1) - - # normalize mask to standard [B, H, W] (handling 2D, 3D, and 4D variants) - if mask.ndim == 4: - if mask.shape[1] == 1: - mask = mask.squeeze(1) - elif mask.shape[-1] == 1: - mask = mask.squeeze(-1) - else: - mask = mask[:, :, :, 0] # take first channel as fallback - - if mask.ndim == 3: - if mask.shape[-1] == 1: - mask = mask.squeeze(-1).unsqueeze(0) - elif mask.ndim == 2: - mask = mask.unsqueeze(0) - - batch_size = image.shape[0] - if mask.shape[0] == 1 and batch_size > 1: - mask = mask.expand(batch_size, -1, -1) - elif mask.shape[0] != batch_size: - raise ValueError(f"Trellis2Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}") - - cond_512_list = [] - cond_1024_list = [] - - for b in range(batch_size): - item_image = image[b] - item_mask = mask[b] if mask.size(0) > 1 else mask[0] - - img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) - mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) - - # Ensure img_np is either 2D (grayscale) or 3D (RGB/RGBA) - if img_np.ndim == 3 and img_np.shape[-1] == 1: - img_np = img_np.squeeze(-1) - - mask_np = mask_np.squeeze() - - # detect inverted mask - border_pixels = np.concatenate([ - mask_np[0, :], mask_np[-1, :], mask_np[:, 0], mask_np[:, -1] - ]) - if np.mean(border_pixels) > 127: - mask_np = 255 - mask_np - - mask_np[mask_np < 35] = 0 - - border_shave = 4 - mask_np[:border_shave, :] = 0 - mask_np[-border_shave:, :] = 0 - mask_np[:, :border_shave] = 0 - mask_np[:, -border_shave:] = 0 - - pil_img = Image.fromarray(img_np) - pil_mask = Image.fromarray(mask_np) - - max_size = max(pil_img.size) - scale = min(1.0, 1024 / max_size) - if scale < 1.0: - new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale) - pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS) - pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST) - - rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8) - rgba_np[:, :, :3] = np.array(pil_img.convert("RGB")) - rgba_np[:, :, 3] = np.array(pil_mask) - - alpha = rgba_np[:, :, 3] - bbox_coords = np.argwhere(alpha > 0.8 * 255) - - if len(bbox_coords) > 0: - y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1]) - y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1]) - - center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0 - size = max(y_max - y_min, x_max - x_min) - - crop_x1 = int(center_x - size // 2) - crop_y1 = int(center_y - size // 2) - crop_x2 = int(center_x + size // 2) - crop_y2 = int(center_y + size // 2) - - rgba_pil = Image.fromarray(rgba_np) - cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)) - cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0 - else: - logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.") - cropped_np = rgba_np.astype(np.float32) / 255.0 - - bg_rgb = np.array([0.0, 0.0, 0.0], dtype=np.float32) - - fg = cropped_np[:, :, :3] - alpha_float = cropped_np[:, :, 3:4] - composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float) - - # Keep the image as 4-channel RGBA to force TRELLIS to bypass its internal background remover - rgb_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) - alpha_uint8 = (alpha_float.squeeze(-1) * 255.0).round().clip(0, 255).astype(np.uint8) - - rgba_composite = np.zeros((cropped_np.shape[0], cropped_np.shape[1], 4), dtype=np.uint8) - rgba_composite[:, :, :3] = rgb_uint8 - rgba_composite[:, :, 3] = alpha_uint8 - - cropped_pil = Image.fromarray(rgba_composite, mode="RGBA") - - # Convert to RGB to ensure the CLIP/DINO model receives a 3-channel image - item_conditioning = run_conditioning(clip_vision_model, cropped_pil.convert("RGB")) - cond_512_list.append(item_conditioning["cond_512"]) - cond_1024_list.append(item_conditioning["cond_1024"]) - - cond_512_batched = torch.cat(cond_512_list, dim=0) - cond_1024_batched = torch.cat(cond_1024_list, dim=0) + out_device = comfy.model_management.intermediate_device() + cond = _dino_condition_batch(clip_vision_model, image, mask, out_device, + pad_factor=1.0, mask_threshold=35.0 / 255.0, border_shave=4) + cond_512_batched, cond_1024_batched = cond["global_512"], cond["global_1024"] neg_cond_batched = torch.zeros_like(cond_512_batched) neg_embeds_batched = torch.zeros_like(cond_1024_batched) @@ -781,12 +650,11 @@ def _dinov3_patches_to_2d(tokens, image_size, patch_size=16): return patches.transpose(1, 2).reshape(tokens.shape[0], -1, h_p, w_p).contiguous() -def _crop_image_with_mask(item_image, item_mask, max_image_size=1024): - img = item_image.permute(2, 0, 1).unsqueeze(0).cpu().float() - mask = item_mask.unsqueeze(0).unsqueeze(0).cpu().float() - # Upstream went float→PIL uint8 implicitly; match that to keep composite bit-exact. - img = (img.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0 - mask = (mask.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0 +def _crop_image_with_mask(item_image, item_mask, max_image_size=1024, pad_factor=1.1, + mask_threshold=0.0, border_shave=0): + img = item_image[..., :3] if item_image.shape[-1] >= 3 else item_image[..., :1].repeat(1, 1, 3) + img = img.permute(2, 0, 1).unsqueeze(0).cpu().float().clamp(0, 1) + mask = item_mask.unsqueeze(0).unsqueeze(0).cpu().float().clamp(0, 1) # Detect & correct an inverted mask m2d = mask[0, 0] @@ -794,6 +662,15 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024): if float(border.mean()) > 0.5: mask = 1.0 - mask + if mask_threshold > 0.0: + mask = torch.where(mask < mask_threshold, torch.zeros_like(mask), mask) + if border_shave > 0: + bs = border_shave + mask[..., :bs, :] = 0 + mask[..., -bs:, :] = 0 + mask[..., :, :bs] = 0 + mask[..., :, -bs:] = 0 + H, W = img.shape[-2:] if max(H, W) > max_image_size: scale = max_image_size / max(H, W) @@ -809,14 +686,14 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024): y_min, x_min = fg_pixels.min(dim=0).values.tolist() y_max, x_max = fg_pixels.max(dim=0).values.tolist() center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0 - size = int(max(y_max - y_min, x_max - x_min) * 1.1) + size = int(max(y_max - y_min, x_max - x_min) * pad_factor) half = size // 2 crop_x1 = int(center_x - half) crop_y1 = int(center_y - half) crop_x2 = crop_x1 + 2 * half crop_y2 = crop_y1 + 2 * half else: - logging.warning("Mask for the image is empty. Pixal3D requires a clean foreground mask.") + logging.warning("Mask for the image is empty; a clean foreground mask is required for best quality.") crop_x1, crop_y1, crop_x2, crop_y2 = 0, 0, W, H crop_bbox = (crop_x1, crop_y1, crop_x2, crop_y2) @@ -836,9 +713,77 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024): cropped_mask = mask[..., crop_y1:crop_y2, crop_x1:crop_x2] composite = (cropped_img * cropped_mask).clamp(0, 1) - composite = (composite * 255.0).round().clamp(0, 255).to(torch.uint8).float() / 255.0 return composite, crop_bbox, scene_size + +def _dino_condition_batch(clip_vision_model, image, mask, out_device, *, + pad_factor, mask_threshold=0.0, border_shave=0, want_patches=False): + """Normalize image/mask to a batch, then per item: masked square crop + DINOv3 encode at + 512 and 1024. Returns batched global tokens; with want_patches also the 2D patch grids and + the per-item composites / crop bboxes / scene sizes that the Pixal3D NAF+projection path needs.""" + # Normalize to batched form so the per-image loop is uniform. + if image.ndim == 3: + image = image.unsqueeze(0) + elif image.ndim == 4: + if image.shape[1] in [1, 3, 4] and image.shape[-1] not in [1, 3, 4]: + image = image.permute(0, 2, 3, 1) + + if mask.ndim == 4: + if mask.shape[1] == 1: + mask = mask.squeeze(1) + elif mask.shape[-1] == 1: + mask = mask.squeeze(-1) + else: + mask = mask[:, :, :, 0] # take first channel as fallback + if mask.ndim == 3: + if mask.shape[-1] == 1: + mask = mask.squeeze(-1).unsqueeze(0) + elif mask.ndim == 2: + mask = mask.unsqueeze(0) + + batch_size = image.shape[0] + if mask.shape[0] == 1 and batch_size > 1: + mask = mask.expand(batch_size, -1, -1) + elif mask.shape[0] != batch_size: + raise ValueError(f"Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}") + + cond_512_list, cond_1024_list = [], [] + patches_512_list, patches_1024_list = [], [] + composite_list, crop_bbox_list, scene_size_list = [], [], [] + for b in range(batch_size): + item_image = image[b] + item_mask = mask[b] if mask.size(0) > 1 else mask[0] + composite, crop_bbox, scene_size = _crop_image_with_mask( + item_image, item_mask, max_image_size=1024, pad_factor=pad_factor, + mask_threshold=mask_threshold, border_shave=border_shave) + c512 = _dinov3_encode(clip_vision_model, composite, 512, want_patches=want_patches) + c1024 = _dinov3_encode(clip_vision_model, composite, 1024, want_patches=want_patches) + if want_patches: + cond_512_list.append(c512["tokens"].to(out_device)) + cond_1024_list.append(c1024["tokens"].to(out_device)) + patches_512_list.append(c512["patches_2d"].to(out_device)) + patches_1024_list.append(c1024["patches_2d"].to(out_device)) + composite_list.append(composite) + crop_bbox_list.append(crop_bbox) + scene_size_list.append(scene_size) + else: + cond_512_list.append(c512.to(out_device)) + cond_1024_list.append(c1024.to(out_device)) + + out = { + "batch_size": batch_size, + "global_512": torch.cat(cond_512_list, dim=0), + "global_1024": torch.cat(cond_1024_list, dim=0), + } + if want_patches: + out["patches_512"] = torch.cat(patches_512_list, dim=0) + out["patches_1024"] = torch.cat(patches_1024_list, dim=0) + out["composites"] = composite_list + out["crop_bboxes"] = crop_bbox_list + out["scene_sizes"] = scene_size_list + return out + + class Pixal3DConditioning(IO.ComfyNode): @classmethod @@ -867,45 +812,15 @@ class Pixal3DConditioning(IO.ComfyNode): @classmethod def execute(cls, clip_vision_model, image, mask, camera_angle_x) -> IO.NodeOutput: naf_model = clip_vision_model.naf - if image.ndim == 3: - image = image.unsqueeze(0) - if mask.ndim == 2: - mask = mask.unsqueeze(0) - batch_size = image.shape[0] - if mask.shape[0] == 1 and batch_size > 1: - mask = mask.expand(batch_size, -1, -1) - elif mask.shape[0] != batch_size: - raise ValueError(f"Pixal3DConditioning mask batch {mask.shape[0]} != image batch {batch_size}") + out_device = comfy.model_management.intermediate_device() + compute_device = comfy.model_management.get_torch_device() - device = comfy.model_management.intermediate_device() - - cond_512_list, cond_1024_list = [], [] - patches_512_list, patches_1024_list = [], [] - composite_list = [] - crop_bbox_list, scene_size_list = [], [] - - torch_device = comfy.model_management.get_torch_device() - for b in range(batch_size): - item_image = image[b] - item_mask = mask[b] if mask.size(0) > 1 else mask[0] - composite, crop_bbox, scene_size = _crop_image_with_mask( - item_image, item_mask, max_image_size=1024) - crop_bbox_list.append(crop_bbox) - scene_size_list.append(scene_size) - composite_list.append(composite) - - cond_512 = _dinov3_encode(clip_vision_model, composite, 512, want_patches=True) - cond_1024 = _dinov3_encode(clip_vision_model, composite, 1024, want_patches=True) - cond_512_list.append(cond_512["tokens"].to(device)) - cond_1024_list.append(cond_1024["tokens"].to(device)) - patches_512_list.append(cond_512["patches_2d"].to(device)) - patches_1024_list.append(cond_1024["patches_2d"].to(device)) - - global_512 = torch.cat(cond_512_list, dim=0) - global_1024 = torch.cat(cond_1024_list, dim=0) - - fm_512_dino = torch.cat(patches_512_list, dim=0) - fm_1024_dino = torch.cat(patches_1024_list, dim=0) + cond = _dino_condition_batch(clip_vision_model, image, mask, out_device, pad_factor=1.1, want_patches=True) + batch_size = cond["batch_size"] + global_512, global_1024 = cond["global_512"], cond["global_1024"] + fm_512_dino, fm_1024_dino = cond["patches_512"], cond["patches_1024"] + composite_list = cond["composites"] + crop_bbox_list, scene_size_list = cond["crop_bboxes"], cond["scene_sizes"] # The LR DINO grid AND the NAF HR grid are sampled separately # NAF targets per stage: shape_512=512, shape_1024=512, tex_1024=1024. @@ -914,15 +829,13 @@ class Pixal3DConditioning(IO.ComfyNode): return None comfy.model_management.load_model_gpu(naf_model) inner = naf_model.model - target_dtype = comfy.model_management.text_encoder_dtype(torch_device) - if next(inner.parameters()).dtype != target_dtype: - inner.to(dtype=target_dtype) + model_dtype = next(inner.parameters()).dtype # set at load time (see clip_vision NAF) hrs = [] for i, c in enumerate(composites): img_i = comfy.utils.common_upscale(c, image_size, image_size, "lanczos", "disabled")\ - .to(torch_device).to(target_dtype) - lr_i = lr_feat[i:i + 1].to(torch_device).to(target_dtype) - hr_i = inner(img_i, lr_i, naf_target, output_device=device) + .to(compute_device).to(model_dtype) + lr_i = lr_feat[i:i + 1].to(compute_device).to(model_dtype) + hr_i = inner(img_i, lr_i, naf_target, output_device=out_device) hrs.append(hr_i) return torch.cat(hrs, dim=0) @@ -934,10 +847,10 @@ class Pixal3DConditioning(IO.ComfyNode): # FOV widget is in degrees for UX; trig + downstream projection expect radians. camera_angle_x = math.radians(float(camera_angle_x)) distance = 0.5 / math.tan(camera_angle_x / 2.0) - cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32) - dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32) - scale_t = torch.ones(batch_size, device=device, dtype=torch.float32) - T = build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32) + cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=out_device, dtype=torch.float32) + dist_t = torch.tensor([distance] * batch_size, device=out_device, dtype=torch.float32) + scale_t = torch.ones(batch_size, device=out_device, dtype=torch.float32) + T = build_proj_transform_matrix(dist_t, batch_size, device=out_device, dtype=torch.float32) proj_pack = { "stages": { @@ -958,7 +871,7 @@ class Pixal3DConditioning(IO.ComfyNode): # global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024. ss_proj_feats = compute_stage_proj_feats( proj_pack, "ss", dense_grid_resolution=16, batch_size=batch_size, - device=torch_device, + device=compute_device, ) neg_global = torch.zeros_like(global_512) neg_embeds = torch.zeros_like(global_1024)