From 44adb27782ea1df23ea43ca44fde808ac8b893d2 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 10 Mar 2026 22:27:10 +0200 Subject: [PATCH] working version --- comfy/ldm/trellis2/attention.py | 2 +- comfy/ldm/trellis2/model.py | 4 +- comfy_extras/nodes_trellis2.py | 430 +++++++++++++++++--------------- 3 files changed, 232 insertions(+), 204 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 0b9c12294..681666430 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -46,7 +46,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) if var_length: - return out.contiguous().transpose(1, 2).values() + return out.transpose(1, 2).values() if not skip_output_reshape: out = ( out.transpose(1, 2).reshape(b, -1, heads * dim_head) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index bd8309f2b..4bbfbff5f 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -767,8 +767,6 @@ class Trellis2(nn.Module): if embeds is None: raise ValueError("Trellis2.forward requires 'embeds' in kwargs") is_1024 = self.img2shape.resolution == 1024 - if is_1024: - context = embeds coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") if coords is not None: @@ -777,6 +775,8 @@ class Trellis2(nn.Module): else: mode = "structure_generation" not_struct_mode = False + if is_1024 and mode == "shape_generation": + context = embeds sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: timestep *= 1000.0 diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 739233523..23b2f72bb 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,9 +1,8 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy.ldm.trellis2.vae import SparseTensor -from comfy.utils import ProgressBar, lanczos -import torch.nn.functional as TF import comfy.model_management +import logging from PIL import Image import numpy as np import torch @@ -39,93 +38,6 @@ tex_slat_normalization = { ])[None] } -dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) -dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) - -def smart_crop_square(image, mask, margin_ratio=0.1, bg_color=(128, 128, 128)): - nz = torch.nonzero(mask[0] > 0.5) - if nz.shape[0] == 0: - C, H, W = image.shape - side = max(H, W) - canvas = torch.full((C, side, side), 0.5, device=image.device) # Gray - canvas[:, (side-H)//2:(side-H)//2+H, (side-W)//2:(side-W)//2+W] = image - return canvas - - y_min, x_min = nz.min(dim=0)[0] - y_max, x_max = nz.max(dim=0)[0] - - obj_w, obj_h = x_max - x_min, y_max - y_min - center_x, center_y = (x_min + x_max) / 2, (y_min + y_max) / 2 - - side = int(max(obj_w, obj_h) * (1 + margin_ratio * 2)) - half_side = side / 2 - - x1, y1 = int(center_x - half_side), int(center_y - half_side) - x2, y2 = x1 + side, y1 + side - - C, H, W = image.shape - canvas = torch.ones((C, side, side), device=image.device) - for c in range(C): - canvas[c] *= (bg_color[c] / 255.0) - - src_x1, src_y1 = max(0, x1), max(0, y1) - src_x2, src_y2 = min(W, x2), min(H, y2) - - dst_x1, dst_y1 = max(0, -x1), max(0, -y1) - dst_x2 = dst_x1 + (src_x2 - src_x1) - dst_y2 = dst_y1 + (src_y2 - src_y1) - - img_crop = image[:, src_y1:src_y2, src_x1:src_x2] - mask_crop = mask[0, src_y1:src_y2, src_x1:src_x2] - - bg_val = torch.tensor(bg_color, device=image.device, dtype=image.dtype).view(-1, 1, 1) / 255.0 - - masked_crop = img_crop * mask_crop + bg_val * (1.0 - mask_crop) - - canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = masked_crop - - return canvas - -def run_conditioning(model, image, mask, include_1024 = True, background_color = "black"): - model_internal = model.model - device = comfy.model_management.intermediate_device() - torch_device = comfy.model_management.get_torch_device() - - bg_colors = {"black": (0, 0, 0), "gray": (128, 128, 128), "white": (255, 255, 255)} - bg_rgb = bg_colors.get(background_color, (128, 128, 128)) - - img_t = image[0].movedim(-1, 0).to(torch_device).float() - mask_t = mask[0].to(torch_device).float() - if mask_t.ndim == 2: - mask_t = mask_t.unsqueeze(0) - - cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb) - - def prepare_tensor(img, size): - resized = lanczos(img.unsqueeze(0), size, size) - return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) - - model_internal.image_size = 512 - input_512 = prepare_tensor(cropped_img, 512) - cond_512 = model_internal(input_512)[0] - - cond_1024 = None - if include_1024: - model_internal.image_size = 1024 - input_1024 = prepare_tensor(cropped_img, 1024) - cond_1024 = model_internal(input_1024)[0] - - conditioning = { - 'cond_512': cond_512.to(device), - 'neg_cond': torch.zeros_like(cond_512).to(device), - } - if cond_1024 is not None: - conditioning['cond_1024'] = cond_1024.to(device) - - preprocessed_tensor = cropped_img.movedim(0, -1).unsqueeze(0).cpu() - - return conditioning, preprocessed_tensor - class VaeDecodeShapeTrellis(IO.ComfyNode): @classmethod def define_schema(cls): @@ -245,6 +157,39 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): out = Types.VOXEL(decoded.squeeze(1).float()) return IO.NodeOutput(out) +dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) +dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + +def run_conditioning(model, cropped_img_tensor, include_1024=True): + model_internal = model.model + device = comfy.model_management.intermediate_device() + torch_device = comfy.model_management.get_torch_device() + + img_t = cropped_img_tensor.to(torch_device) + + def prepare_tensor(img, size): + resized = torch.nn.functional.interpolate(img, size=(size, size), mode='bicubic', align_corners=False).clamp(0.0, 1.0) + return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) + + model_internal.image_size = 512 + input_512 = prepare_tensor(img_t, 512) + cond_512 = model_internal(input_512)[0] + + cond_1024 = None + if include_1024: + model_internal.image_size = 1024 + input_1024 = prepare_tensor(img_t, 1024) + cond_1024 = model_internal(input_1024)[0] + + conditioning = { + 'cond_512': cond_512.to(device), + 'neg_cond': torch.zeros_like(cond_512).to(device), + } + if cond_1024 is not None: + conditioning['cond_1024'] = cond_1024.to(device) + + return conditioning + class Trellis2Conditioning(IO.ComfyNode): @classmethod def define_schema(cls): @@ -268,22 +213,60 @@ class Trellis2Conditioning(IO.ComfyNode): if image.ndim == 4: image = image[0] + if mask.ndim == 3: + mask = mask[0] - # TODO - image = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) - image = Image.fromarray(image) - max_size = max(image.size) - scale = min(1, 1024 / max_size) - if scale < 1: - image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS) - new_h, new_w = int(mask.shape[-2] * scale), int(mask.shape[-1] * scale) - mask = TF.interpolate(mask.unsqueeze(0).float(), size=(new_h, new_w), mode='nearest').squeeze(0) + img_np = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + mask_np = (mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) - image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255 + pil_img = Image.fromarray(img_np) + pil_mask = Image.fromarray(mask_np) - # could make 1024 an option - conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) - embeds = conditioning["cond_1024"] # should add that + max_size = max(pil_img.size) + scale = min(1.0, 1024 / max_size) + if scale < 1.0: + new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale) + pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS) + pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST) + + rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8) + rgba_np[:, :, :3] = np.array(pil_img) + rgba_np[:, :, 3] = np.array(pil_mask) + + alpha = rgba_np[:, :, 3] + bbox_coords = np.argwhere(alpha > 0.8 * 255) + + if len(bbox_coords) > 0: + y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1]) + y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1]) + + center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0 + size = max(y_max - y_min, x_max - x_min) + + crop_x1 = int(center_x - size // 2) + crop_y1 = int(center_y - size // 2) + crop_x2 = int(center_x + size // 2) + crop_y2 = int(center_y + size // 2) + + rgba_pil = Image.fromarray(rgba_np, 'RGBA') + cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)) + cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0 + else: + logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.") + cropped_np = rgba_np.astype(np.float32) / 255.0 + + bg_colors = {"black": [0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} + bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32) + + fg = cropped_np[:, :, :3] + alpha_float = cropped_np[:, :, 3:4] + composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float) + + cropped_img_tensor = torch.from_numpy(composite_np).movedim(-1, 0).unsqueeze(0).float() + + conditioning = run_conditioning(clip_vision_model, cropped_img_tensor, include_1024=True) + + embeds = conditioning["cond_1024"] positive = [[conditioning["cond_512"], {"embeds": embeds}]] negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]] return IO.NodeOutput(positive, negative) @@ -417,118 +400,168 @@ def simplify_fn(vertices, faces, target=100000): return final_vertices, final_faces -def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2): +def fill_holes_fn(vertices, faces, max_perimeter=0.03): is_batched = vertices.ndim == 3 if is_batched: - batch_size = vertices.shape[0] - if batch_size > 1: - v_out, f_out = [], [] - for i in range(batch_size): - v, f = fill_holes_fn(vertices[i], faces[i], max_hole_perimeter) - v_out.append(v) - f_out.append(f) - return torch.stack(v_out), torch.stack(f_out) - - vertices = vertices.squeeze(0) - faces = faces.squeeze(0) + v_list, f_list = [],[] + for i in range(vertices.shape[0]): + v_i, f_i = fill_holes_fn(vertices[i], faces[i], max_perimeter) + v_list.append(v_i) + f_list.append(f_i) + return torch.stack(v_list), torch.stack(f_list) device = vertices.device - orig_vertices = vertices - orig_faces = faces + v = vertices + f = faces - edges = torch.cat([ - faces[:, [0, 1]], - faces[:, [1, 2]], - faces[:, [2, 0]] - ], dim=0) + if f.shape[0] == 0: + return v, f + edges = torch.cat([f[:, [0, 1]], f[:, [1, 2]], f[:, [2, 0]]], dim=0) edges_sorted, _ = torch.sort(edges, dim=1) - unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True) + + max_v = v.shape[0] + packed_undirected = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long() + + unique_packed, counts = torch.unique(packed_undirected, return_counts=True) boundary_mask = counts == 1 - boundary_edges_sorted = unique_edges[boundary_mask] + boundary_packed = unique_packed[boundary_mask] - if boundary_edges_sorted.shape[0] == 0: - if is_batched: - return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) - return orig_vertices, orig_faces + if boundary_packed.numel() == 0: + return v, f - max_idx = vertices.shape[0] - - packed_edges_all = torch.sort(edges, dim=1).values - packed_edges_all = packed_edges_all[:, 0] * max_idx + packed_edges_all[:, 1] - - packed_boundary = boundary_edges_sorted[:, 0] * max_idx + boundary_edges_sorted[:, 1] - - is_boundary_edge = torch.isin(packed_edges_all, packed_boundary) - active_boundary_edges = edges[is_boundary_edge] + packed_directed_sorted = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long() + is_boundary = torch.isin(packed_directed_sorted, boundary_packed) + boundary_edges_directed = edges[is_boundary] adj = {} - edges_np = active_boundary_edges.cpu().numpy() - for u, v in edges_np: - adj[u] = v + in_deg = {} + out_deg = {} + + edges_list = boundary_edges_directed.tolist() + for u, v_idx in edges_list: + if u not in adj: adj[u] = [] + adj[u].append(v_idx) + out_deg[u] = out_deg.get(u, 0) + 1 + in_deg[v_idx] = in_deg.get(v_idx, 0) + 1 + + manifold_nodes = set() + for node in set(list(in_deg.keys()) + list(out_deg.keys())): + if in_deg.get(node, 0) == 1 and out_deg.get(node, 0) == 1: + manifold_nodes.add(node) + + loops =[] + visited_nodes = set() - loops = [] - visited_edges = set() - processed_nodes = set() for start_node in list(adj.keys()): - if start_node in processed_nodes: + if start_node not in manifold_nodes or start_node in visited_nodes: continue - current_loop, curr = [], start_node - while curr in adj: - next_node = adj[curr] - if (curr, next_node) in visited_edges: - break - visited_edges.add((curr, next_node)) - processed_nodes.add(curr) + + curr = start_node + current_loop =[] + + while True: current_loop.append(curr) + visited_nodes.add(curr) + + next_node = adj[curr][0] + + if next_node == start_node: + if len(current_loop) >= 3: + loops.append(current_loop) + break + + if next_node not in manifold_nodes or next_node in visited_nodes: + break + curr = next_node - if curr == start_node: - loops.append(current_loop) - break - if len(current_loop) > len(edges_np): + + if len(current_loop) > len(edges_list): break - if not loops: - if is_batched: - return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0) - return orig_vertices, orig_faces + new_faces =[] + new_verts = [] + curr_v_idx = v.shape[0] - new_faces = [] - v_offset = vertices.shape[0] - valid_new_verts = [] + for loop in loops: + loop_indices = torch.tensor(loop, device=device, dtype=torch.long) + loop_points = v[loop_indices] - for loop_indices in loops: - if len(loop_indices) < 3: - continue - loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device) - loop_verts = vertices[loop_tensor] - diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0) - perimeter = torch.norm(diffs, dim=1).sum() + # Calculate perimeter + p1 = loop_points + p2 = torch.roll(loop_points, shifts=-1, dims=0) + perimeter = torch.norm(p1 - p2, dim=1).sum().item() - if perimeter > max_hole_perimeter: - continue + if perimeter <= max_perimeter: + centroid = loop_points.mean(dim=0) + new_verts.append(centroid) + center_idx = curr_v_idx + curr_v_idx += 1 - center = loop_verts.mean(dim=0) - valid_new_verts.append(center) - c_idx = v_offset - v_offset += 1 + for i in range(len(loop)): + u_idx = loop[i] + v_next_idx = loop[(i + 1) % len(loop)] + new_faces.append([u_idx, v_next_idx, center_idx]) - num_v = len(loop_indices) - for i in range(num_v): - v_curr, v_next = loop_indices[i], loop_indices[(i + 1) % num_v] - new_faces.append([v_curr, v_next, c_idx]) + if new_faces: + v = torch.cat([v, torch.stack(new_verts)], dim=0) + f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0) - if len(valid_new_verts) > 0: - added_vertices = torch.stack(valid_new_verts, dim=0) - added_faces = torch.tensor(new_faces, dtype=torch.long, device=device) - vertices = torch.cat([orig_vertices, added_vertices], dim=0) - faces = torch.cat([orig_faces, added_faces], dim=0) - else: - vertices, faces = orig_vertices, orig_faces + return v, f +def merge_duplicate_vertices(vertices, faces, tolerance=1e-5): + is_batched = vertices.ndim == 3 if is_batched: - return vertices.unsqueeze(0), faces.unsqueeze(0) + v_list, f_list = [],[] + for i in range(vertices.shape[0]): + v_i, f_i = merge_duplicate_vertices(vertices[i], faces[i], tolerance) + v_list.append(v_i) + f_list.append(f_i) + return torch.stack(v_list), torch.stack(f_list) + v_min = vertices.min(dim=0, keepdim=True)[0] + v_quant = ((vertices - v_min) / tolerance).round().long() + + unique_quant, inverse_indices = torch.unique(v_quant, dim=0, return_inverse=True) + + new_vertices = torch.zeros((unique_quant.shape[0], 3), dtype=vertices.dtype, device=vertices.device) + new_vertices.index_copy_(0, inverse_indices, vertices) + + new_faces = inverse_indices[faces.long()] + + valid = (new_faces[:, 0] != new_faces[:, 1]) & \ + (new_faces[:, 1] != new_faces[:, 2]) & \ + (new_faces[:, 2] != new_faces[:, 0]) + + return new_vertices, new_faces[valid] + +def fix_normals(vertices, faces): + is_batched = vertices.ndim == 3 + if is_batched: + v_list, f_list = [], [] + for i in range(vertices.shape[0]): + v_i, f_i = fix_normals(vertices[i], faces[i]) + v_list.append(v_i) + f_list.append(f_i) + return torch.stack(v_list), torch.stack(f_list) + + if faces.shape[0] == 0: + return vertices, faces + + center = vertices.mean(0) + v0 = vertices[faces[:, 0].long()] + v1 = vertices[faces[:, 1].long()] + v2 = vertices[faces[:, 2].long()] + + normals = torch.cross(v1 - v0, v2 - v0, dim=1) + + face_centers = (v0 + v1 + v2) / 3.0 + dir_from_center = face_centers - center + + dot = (normals * dir_from_center).sum(1) + flip_mask = dot < 0 + + faces[flip_mask] = faces[flip_mask][:, [0, 2, 1]] return vertices, faces class PostProcessMesh(IO.ComfyNode): @@ -539,36 +572,31 @@ class PostProcessMesh(IO.ComfyNode): category="latent/3d", inputs=[ IO.Mesh.Input("mesh"), - IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), # max? - IO.Float.Input("fill_holes_perimeter", default=0.003, min=0.0, step=0.0001) + IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), + IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001) ], outputs=[ IO.Mesh.Output("output_mesh"), ] ) + @classmethod def execute(cls, mesh, simplify, fill_holes_perimeter): - bar = ProgressBar(2) mesh = copy.deepcopy(mesh) verts, faces = mesh.vertices, mesh.faces - if fill_holes_perimeter != 0.0: - verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter) - bar.update(1) - else: - bar.update(1) + verts, faces = merge_duplicate_vertices(verts, faces, tolerance=1e-5) - if simplify != 0: - verts, faces = simplify_fn(verts, faces, simplify) - bar.update(1) - else: - bar.update(1) + if fill_holes_perimeter > 0: + verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter) - # potentially adding laplacian smoothing + if simplify > 0 and faces.shape[0] > simplify: + verts, faces = simplify_fn(verts, faces, target=simplify) + + verts, faces = fix_normals(verts, faces) mesh.vertices = verts mesh.faces = faces - return IO.NodeOutput(mesh) class Trellis2Extension(ComfyExtension):