mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
mogemask
This commit is contained in:
parent
1697da460b
commit
ad94d3bc93
@ -404,10 +404,82 @@ class MoGePointMapToMesh(io.ComfyNode):
|
||||
return io.NodeOutput(mesh)
|
||||
|
||||
|
||||
class MoGeMaskOut(io.ComfyNode):
|
||||
"""Mark masked pixels as invalid in a MoGe geometry. MoGePointMapToMesh's
|
||||
finite-check then drops them during triangulation, so the scene mesh has a
|
||||
hole where the object lives. Use to cut out an object's footprint from a
|
||||
MoGe scene mesh before merging in a Pixal3DAlignObject output."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGeMaskOut",
|
||||
display_name="MoGe Mask Out",
|
||||
search_aliases=["moge", "mask", "exclude", "cut out"],
|
||||
category="image/geometry_estimation",
|
||||
description=(
|
||||
"Set the masked pixels to invalid in a MoGe geometry. Downstream "
|
||||
"MoGePointMapToMesh skips invalid pixels, so the resulting scene "
|
||||
"mesh has a hole there. Use to cut an object out of a MoGe scene "
|
||||
"before merging with a Pixal3D reconstruction of that object."
|
||||
),
|
||||
inputs=[
|
||||
MoGeGeometry.Input("moge_geometry"),
|
||||
io.Mask.Input("mask",
|
||||
tooltip="Pixels to exclude (1 = drop, 0 = keep). "
|
||||
"Auto-resized to MoGe's image resolution via nearest sampling."),
|
||||
io.Int.Input("batch_index", default=0, min=0, max=4096,
|
||||
tooltip="Which image of a batched MoGe geometry to mask."),
|
||||
io.Int.Input("dilate", default=0, min=0, max=128,
|
||||
tooltip="Optional dilation (max-pool radius) applied to the mask before "
|
||||
"masking, in MoGe-resolution pixels. Use to overshoot the object "
|
||||
"silhouette and avoid leftover halos around the Pixal3D mesh."),
|
||||
],
|
||||
outputs=[MoGeGeometry.Output("moge_geometry")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, moge_geometry, mask, batch_index, dilate) -> io.NodeOutput:
|
||||
if "points" not in moge_geometry:
|
||||
raise ValueError("MoGeMaskOut: moge_geometry has no points; nothing to mask.")
|
||||
out = dict(moge_geometry)
|
||||
points = out["points"].clone()
|
||||
scene_H, scene_W = points.shape[1], points.shape[2]
|
||||
|
||||
m = mask[batch_index] if mask.ndim == 3 else mask
|
||||
m = m.float()
|
||||
if m.shape != (scene_H, scene_W):
|
||||
m = torch.nn.functional.interpolate(
|
||||
m[None, None], size=(scene_H, scene_W), mode="nearest"
|
||||
).squeeze(0).squeeze(0)
|
||||
if dilate > 0:
|
||||
k = 2 * dilate + 1
|
||||
m = torch.nn.functional.max_pool2d(
|
||||
m[None, None], kernel_size=k, stride=1, padding=dilate
|
||||
).squeeze(0).squeeze(0)
|
||||
drop = (m > 0.5).to(points.device)
|
||||
|
||||
# inf in points makes triangulate_grid_mesh's finite check exclude the pixel.
|
||||
points[batch_index][drop] = float("inf")
|
||||
out["points"] = points
|
||||
|
||||
if "mask" in out and out["mask"] is not None:
|
||||
mask_field = out["mask"].clone()
|
||||
mask_field[batch_index] = mask_field[batch_index] & ~drop.to(mask_field.device)
|
||||
out["mask"] = mask_field
|
||||
|
||||
if "depth" in out and out["depth"] is not None:
|
||||
depth = out["depth"].clone()
|
||||
depth[batch_index][drop.to(depth.device)] = float("inf")
|
||||
out["depth"] = depth
|
||||
|
||||
return io.NodeOutput(out)
|
||||
|
||||
|
||||
class MoGeExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh]
|
||||
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh, MoGeMaskOut]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MoGeExtension:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user