From ad94d3bc93963f8b08a577da9cfe819113ac65dc Mon Sep 17 00:00:00 2001 From: kijai Date: Wed, 10 Jun 2026 10:30:54 +0300 Subject: [PATCH] mogemask --- comfy_extras/nodes_moge.py | 74 +++++++++++++++++++++++++++++++++++++- 1 file changed, 73 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_moge.py b/comfy_extras/nodes_moge.py index 3508781a0..7968c6cda 100644 --- a/comfy_extras/nodes_moge.py +++ b/comfy_extras/nodes_moge.py @@ -404,10 +404,82 @@ class MoGePointMapToMesh(io.ComfyNode): return io.NodeOutput(mesh) +class MoGeMaskOut(io.ComfyNode): + """Mark masked pixels as invalid in a MoGe geometry. MoGePointMapToMesh's + finite-check then drops them during triangulation, so the scene mesh has a + hole where the object lives. Use to cut out an object's footprint from a + MoGe scene mesh before merging in a Pixal3DAlignObject output.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MoGeMaskOut", + display_name="MoGe Mask Out", + search_aliases=["moge", "mask", "exclude", "cut out"], + category="image/geometry_estimation", + description=( + "Set the masked pixels to invalid in a MoGe geometry. Downstream " + "MoGePointMapToMesh skips invalid pixels, so the resulting scene " + "mesh has a hole there. Use to cut an object out of a MoGe scene " + "before merging with a Pixal3D reconstruction of that object." + ), + inputs=[ + MoGeGeometry.Input("moge_geometry"), + io.Mask.Input("mask", + tooltip="Pixels to exclude (1 = drop, 0 = keep). " + "Auto-resized to MoGe's image resolution via nearest sampling."), + io.Int.Input("batch_index", default=0, min=0, max=4096, + tooltip="Which image of a batched MoGe geometry to mask."), + io.Int.Input("dilate", default=0, min=0, max=128, + tooltip="Optional dilation (max-pool radius) applied to the mask before " + "masking, in MoGe-resolution pixels. Use to overshoot the object " + "silhouette and avoid leftover halos around the Pixal3D mesh."), + ], + outputs=[MoGeGeometry.Output("moge_geometry")], + ) + + @classmethod + def execute(cls, moge_geometry, mask, batch_index, dilate) -> io.NodeOutput: + if "points" not in moge_geometry: + raise ValueError("MoGeMaskOut: moge_geometry has no points; nothing to mask.") + out = dict(moge_geometry) + points = out["points"].clone() + scene_H, scene_W = points.shape[1], points.shape[2] + + m = mask[batch_index] if mask.ndim == 3 else mask + m = m.float() + if m.shape != (scene_H, scene_W): + m = torch.nn.functional.interpolate( + m[None, None], size=(scene_H, scene_W), mode="nearest" + ).squeeze(0).squeeze(0) + if dilate > 0: + k = 2 * dilate + 1 + m = torch.nn.functional.max_pool2d( + m[None, None], kernel_size=k, stride=1, padding=dilate + ).squeeze(0).squeeze(0) + drop = (m > 0.5).to(points.device) + + # inf in points makes triangulate_grid_mesh's finite check exclude the pixel. + points[batch_index][drop] = float("inf") + out["points"] = points + + if "mask" in out and out["mask"] is not None: + mask_field = out["mask"].clone() + mask_field[batch_index] = mask_field[batch_index] & ~drop.to(mask_field.device) + out["mask"] = mask_field + + if "depth" in out and out["depth"] is not None: + depth = out["depth"].clone() + depth[batch_index][drop.to(depth.device)] = float("inf") + out["depth"] = depth + + return io.NodeOutput(out) + + class MoGeExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: - return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh] + return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh, MoGeMaskOut] async def comfy_entrypoint() -> MoGeExtension: