From 2cced8971cefc882bf099e085fd8eeb6017ea960 Mon Sep 17 00:00:00 2001 From: kijai Date: Wed, 1 Jul 2026 00:58:24 +0300 Subject: [PATCH] Remove Pixal3DAlignObject, out of scope --- comfy_extras/nodes_moge.py | 74 +---------------- comfy_extras/nodes_trellis2.py | 147 +-------------------------------- 2 files changed, 2 insertions(+), 219 deletions(-) diff --git a/comfy_extras/nodes_moge.py b/comfy_extras/nodes_moge.py index d7b56e527..819421534 100644 --- a/comfy_extras/nodes_moge.py +++ b/comfy_extras/nodes_moge.py @@ -439,82 +439,10 @@ class MoGeGeometryToFOV(io.ComfyNode): return io.NodeOutput(fov, focal_pixels) -class MoGeMaskOut(io.ComfyNode): - """Mark masked pixels as invalid in a MoGe geometry. MoGePointMapToMesh's - finite-check then drops them during triangulation, so the scene mesh has a - hole where the object lives. Use to cut out an object's footprint from a - MoGe scene mesh before merging in a Pixal3DAlignObject output.""" - - @classmethod - def define_schema(cls): - return io.Schema( - node_id="MoGeMaskOut", - display_name="MoGe Mask Out", - search_aliases=["moge", "mask", "exclude", "cut out"], - category="image/geometry_estimation", - description=( - "Set the masked pixels to invalid in a MoGe geometry. Downstream " - "MoGePointMapToMesh skips invalid pixels, so the resulting scene " - "mesh has a hole there. Use to cut an object out of a MoGe scene " - "before merging with a Pixal3D reconstruction of that object." - ), - inputs=[ - MoGeGeometry.Input("moge_geometry"), - io.Mask.Input("mask", - tooltip="Pixels to exclude (1 = drop, 0 = keep). " - "Auto-resized to MoGe's image resolution via nearest sampling."), - io.Int.Input("batch_index", default=0, min=0, max=4096, - tooltip="Which image of a batched MoGe geometry to mask."), - io.Int.Input("dilate", default=0, min=0, max=128, - tooltip="Optional dilation (max-pool radius) applied to the mask before " - "masking, in MoGe-resolution pixels. Use to overshoot the object " - "silhouette and avoid leftover halos around the Pixal3D mesh."), - ], - outputs=[MoGeGeometry.Output("moge_geometry")], - ) - - @classmethod - def execute(cls, moge_geometry, mask, batch_index, dilate) -> io.NodeOutput: - if "points" not in moge_geometry: - raise ValueError("MoGeMaskOut: moge_geometry has no points; nothing to mask.") - out = dict(moge_geometry) - points = out["points"].clone() - scene_H, scene_W = points.shape[1], points.shape[2] - - m = mask[batch_index] if mask.ndim == 3 else mask - m = m.float() - if m.shape != (scene_H, scene_W): - m = torch.nn.functional.interpolate( - m[None, None], size=(scene_H, scene_W), mode="nearest" - ).squeeze(0).squeeze(0) - if dilate > 0: - k = 2 * dilate + 1 - m = torch.nn.functional.max_pool2d( - m[None, None], kernel_size=k, stride=1, padding=dilate - ).squeeze(0).squeeze(0) - drop = (m > 0.5).to(points.device) - - # inf in points makes triangulate_grid_mesh's finite check exclude the pixel. - points[batch_index][drop] = float("inf") - out["points"] = points - - if "mask" in out and out["mask"] is not None: - mask_field = out["mask"].clone() - mask_field[batch_index] = mask_field[batch_index] & ~drop.to(mask_field.device) - out["mask"] = mask_field - - if "depth" in out and out["depth"] is not None: - depth = out["depth"].clone() - depth[batch_index][drop.to(depth.device)] = float("inf") - out["depth"] = depth - - return io.NodeOutput(out) - - class MoGeExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: - return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh, MoGeGeometryToFOV, MoGeMaskOut] + return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh, MoGeGeometryToFOV] async def comfy_entrypoint() -> MoGeExtension: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index fe5ab8e88..9e496e4da 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,7 +1,7 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types, UI, io from comfy.ldm.trellis2.vae import SparseTensor -from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats +from comfy.ldm.trellis2.model import build_proj_transform_matrix, compute_stage_proj_feats from comfy.ldm.trellis2.naf.model import NAF from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch @@ -987,150 +987,6 @@ class Pixal3DConditioning(IO.ComfyNode): return IO.NodeOutput(positive, negative) -def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle_x, image_resolution): - points = vertices_world.unsqueeze(0).float() - T = transform_matrix.unsqueeze(0).float() if transform_matrix.ndim == 2 else transform_matrix.float() - cam = camera_angle_x.unsqueeze(0) if camera_angle_x.ndim == 0 else camera_angle_x - T = T.to(points.device) - cam = cam.to(points.device) - uv_pix, depth, valid = _project_points_to_image(points, T, cam.float(), image_resolution) - uv = uv_pix.squeeze(0) / image_resolution - return uv, depth.squeeze(0), valid.squeeze(0) - - -def _crop_uv_to_scene_pixels(uv_crop, crop_bbox, scene_image_size): - crop_x1, crop_y1, crop_x2, crop_y2 = crop_bbox - crop_w = max(1, crop_x2 - crop_x1) - crop_h = max(1, crop_y2 - crop_y1) - px = uv_crop[:, 0] * crop_w + crop_x1 - py = uv_crop[:, 1] * crop_h + crop_y1 - W, H = scene_image_size - return torch.stack([px.clamp(0, W - 1), py.clamp(0, H - 1)], dim=-1) - - -class Pixal3DAlignObject(IO.ComfyNode): - """Pixal3D paper §3.3 Global Alignment for a single object. - - Solves (scale, translation) aligning the mesh to MoGe's per-pixel point map. Requires - MoGe to have been computed on the same resized scene image as Pixal3DConditioning.""" - - @classmethod - def define_schema(cls): - return IO.Schema( - node_id="Pixal3DAlignObject", - category="latent/3d", - inputs=[ - IO.Mesh.Input("mesh"), - IO.Conditioning.Input("positive", tooltip="The positive conditioning from Pixal3DConditioning for this object — Pixal3DAlignObject reads transform_matrix / camera_angle_x / mesh_scale / crop_bboxes out of its proj_feat_pack."), - io.Custom("MOGE_GEOMETRY").Input("moge_geometry", tooltip="MoGe geometry computed on the original scene image."), - IO.Mask.Input( - "object_mask", - optional=True, - tooltip="Optional per-object scene-space mask. If connected, only vertices whose projected pixel falls inside the mask contribute to the alignment solve.", - ), - IO.Int.Input( - "batch_index", - default=0, min=0, max=1024, - tooltip="Which batch slot of the proj_feat_pack/MoGe geometry corresponds to this object.", - ), - ], - outputs=[ - IO.Mesh.Output("aligned_mesh"), - IO.Float.Output(display_name="scale"), - ], - ) - - @classmethod - def execute(cls, mesh, positive, moge_geometry, object_mask=None, batch_index=0) -> IO.NodeOutput: - proj_feat_pack = _proj_pack_from_conditioning(positive) - if proj_feat_pack is None: - raise ValueError("Pixal3DAlignObject: positive conditioning has no proj_feat_pack — connect a Pixal3DConditioning output.") - vertices = mesh.vertices - faces = mesh.faces - if vertices.ndim == 3: - vertices_one = vertices[0] - faces_one = faces[0] - else: - vertices_one = vertices - faces_one = faces - - T = proj_feat_pack["transform_matrix"][batch_index:batch_index + 1] - cam_angle = proj_feat_pack["camera_angle_x"][batch_index:batch_index + 1] - mesh_scale = proj_feat_pack["mesh_scale"][batch_index] - image_resolution = int(proj_feat_pack.get("image_resolution", 1024)) - crop_bbox = proj_feat_pack["crop_bboxes"][batch_index] - pack_scene_size = proj_feat_pack.get("scene_sizes", [None] * (batch_index + 1))[batch_index] - moge_points = moge_geometry["points"] - moge_mask = moge_geometry["mask"] - if moge_points.ndim != 4: - raise ValueError(f"MoGe points expected [B, H, W, 3]; got {tuple(moge_points.shape)}") - scene_H, scene_W = moge_points.shape[1], moge_points.shape[2] - if pack_scene_size is not None and pack_scene_size != (scene_W, scene_H): - raise ValueError( - f"Pixal3DAlignObject: MoGe geometry was computed on a {scene_W}x{scene_H} image, " - f"but the proj_feat_pack's bbox lives in a {pack_scene_size[0]}x{pack_scene_size[1]} " - "image. Run MoGe on the same resized scene image Pixal3DConditioning used." - ) - - # Vertices come out of VaeDecodeShapeTrellis in the Pixal3D model frame - # (no un-rotation). Apply _PROJ_GRID_ROTATION = R_x(-90°) to map model - # frame → ProjGrid world: (X, Y, Z) -> (X, -Z, Y). - v = vertices_one.float() - verts_world = torch.stack([v[..., 0], -v[..., 2], v[..., 1]], dim=-1) - verts_world = verts_world / float(mesh_scale.item()) - uv_crop, _depth, valid = _project_vertices_to_image_uv( - verts_world, T[0], cam_angle[0], image_resolution) - scene_pixels = _crop_uv_to_scene_pixels(uv_crop, crop_bbox, (scene_W, scene_H)) - in_scene = ((scene_pixels[:, 0] >= 0) & (scene_pixels[:, 0] < scene_W) & - (scene_pixels[:, 1] >= 0) & (scene_pixels[:, 1] < scene_H)) - # MoGe geometry and object_mask can land on CPU after passing between nodes; - # match the indexed tensor's device for sy/sx so the gather works on either. - moge_points = moge_points.to(scene_pixels.device) - moge_mask = moge_mask.to(scene_pixels.device) - sx = scene_pixels[:, 0].long().clamp(0, scene_W - 1) - sy = scene_pixels[:, 1].long().clamp(0, scene_H - 1) - moge_per_vertex = moge_points[batch_index, sy, sx] - # MoGe's perspective output is (X right, Y down, Z forward). Convert to glTF - # Y-up (X right, Y up, Z back) so the scale/translate fit runs in the same - # frame as vertices_one (Pixal3D model frame = glTF Y-up). - moge_per_vertex = moge_per_vertex * torch.tensor( - [1.0, -1.0, -1.0], dtype=moge_per_vertex.dtype, device=moge_per_vertex.device - ) - moge_mask_per_vertex = moge_mask[batch_index, sy, sx] - keep = valid & in_scene & moge_mask_per_vertex - if object_mask is not None: - om = object_mask if object_mask.ndim == 2 else object_mask[batch_index] - om = om.to(sy.device) - keep = keep & (om[sy, sx] > 0.5) - - finite = torch.isfinite(moge_per_vertex).all(dim=-1) - keep = keep & finite - - kept = int(keep.sum().item()) - if kept < 8: - scale = 1.0 - aligned = vertices_one - else: - P = vertices_one[keep].float() - Q = moge_per_vertex[keep].float() - p_mean = P.mean(dim=0, keepdim=True) - q_mean = Q.mean(dim=0, keepdim=True) - P_c = P - p_mean - Q_c = Q - q_mean - p_var = (P_c * P_c).sum().clamp(min=1e-8) - q_var = (Q_c * Q_c).sum() - scale = float(torch.sqrt(q_var / p_var).item()) - t = q_mean - scale * p_mean - aligned = scale * vertices_one + t - - if vertices.ndim == 3: - aligned = aligned.unsqueeze(0) - out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces.cpu()) - else: - out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces_one.cpu()) - return IO.NodeOutput(out_mesh, float(scale)) - - class LoadNAFModel(IO.ComfyNode): @classmethod @@ -1240,7 +1096,6 @@ class Trellis2Extension(ComfyExtension): return [ Trellis2Conditioning, Pixal3DConditioning, - Pixal3DAlignObject, LoadNAFModel, Trellis2ShapeStage, EmptyTrellis2LatentStructure,