This commit is contained in:
kijai 2026-06-10 10:30:54 +03:00
parent 1697da460b
commit ad94d3bc93

View File

@ -404,10 +404,82 @@ class MoGePointMapToMesh(io.ComfyNode):
return io.NodeOutput(mesh)
class MoGeMaskOut(io.ComfyNode):
"""Mark masked pixels as invalid in a MoGe geometry. MoGePointMapToMesh's
finite-check then drops them during triangulation, so the scene mesh has a
hole where the object lives. Use to cut out an object's footprint from a
MoGe scene mesh before merging in a Pixal3DAlignObject output."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MoGeMaskOut",
display_name="MoGe Mask Out",
search_aliases=["moge", "mask", "exclude", "cut out"],
category="image/geometry_estimation",
description=(
"Set the masked pixels to invalid in a MoGe geometry. Downstream "
"MoGePointMapToMesh skips invalid pixels, so the resulting scene "
"mesh has a hole there. Use to cut an object out of a MoGe scene "
"before merging with a Pixal3D reconstruction of that object."
),
inputs=[
MoGeGeometry.Input("moge_geometry"),
io.Mask.Input("mask",
tooltip="Pixels to exclude (1 = drop, 0 = keep). "
"Auto-resized to MoGe's image resolution via nearest sampling."),
io.Int.Input("batch_index", default=0, min=0, max=4096,
tooltip="Which image of a batched MoGe geometry to mask."),
io.Int.Input("dilate", default=0, min=0, max=128,
tooltip="Optional dilation (max-pool radius) applied to the mask before "
"masking, in MoGe-resolution pixels. Use to overshoot the object "
"silhouette and avoid leftover halos around the Pixal3D mesh."),
],
outputs=[MoGeGeometry.Output("moge_geometry")],
)
@classmethod
def execute(cls, moge_geometry, mask, batch_index, dilate) -> io.NodeOutput:
if "points" not in moge_geometry:
raise ValueError("MoGeMaskOut: moge_geometry has no points; nothing to mask.")
out = dict(moge_geometry)
points = out["points"].clone()
scene_H, scene_W = points.shape[1], points.shape[2]
m = mask[batch_index] if mask.ndim == 3 else mask
m = m.float()
if m.shape != (scene_H, scene_W):
m = torch.nn.functional.interpolate(
m[None, None], size=(scene_H, scene_W), mode="nearest"
).squeeze(0).squeeze(0)
if dilate > 0:
k = 2 * dilate + 1
m = torch.nn.functional.max_pool2d(
m[None, None], kernel_size=k, stride=1, padding=dilate
).squeeze(0).squeeze(0)
drop = (m > 0.5).to(points.device)
# inf in points makes triangulate_grid_mesh's finite check exclude the pixel.
points[batch_index][drop] = float("inf")
out["points"] = points
if "mask" in out and out["mask"] is not None:
mask_field = out["mask"].clone()
mask_field[batch_index] = mask_field[batch_index] & ~drop.to(mask_field.device)
out["mask"] = mask_field
if "depth" in out and out["depth"] is not None:
depth = out["depth"].clone()
depth[batch_index][drop.to(depth.device)] = float("inf")
out["depth"] = depth
return io.NodeOutput(out)
class MoGeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh]
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh, MoGeMaskOut]
async def comfy_entrypoint() -> MoGeExtension: