Remove Pixal3DAlignObject, out of scope

This commit is contained in:
kijai 2026-07-01 00:58:24 +03:00
parent eec0692bcb
commit 2cced8971c
2 changed files with 2 additions and 219 deletions

View File

@ -439,82 +439,10 @@ class MoGeGeometryToFOV(io.ComfyNode):
return io.NodeOutput(fov, focal_pixels)
class MoGeMaskOut(io.ComfyNode):
"""Mark masked pixels as invalid in a MoGe geometry. MoGePointMapToMesh's
finite-check then drops them during triangulation, so the scene mesh has a
hole where the object lives. Use to cut out an object's footprint from a
MoGe scene mesh before merging in a Pixal3DAlignObject output."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="MoGeMaskOut",
display_name="MoGe Mask Out",
search_aliases=["moge", "mask", "exclude", "cut out"],
category="image/geometry_estimation",
description=(
"Set the masked pixels to invalid in a MoGe geometry. Downstream "
"MoGePointMapToMesh skips invalid pixels, so the resulting scene "
"mesh has a hole there. Use to cut an object out of a MoGe scene "
"before merging with a Pixal3D reconstruction of that object."
),
inputs=[
MoGeGeometry.Input("moge_geometry"),
io.Mask.Input("mask",
tooltip="Pixels to exclude (1 = drop, 0 = keep). "
"Auto-resized to MoGe's image resolution via nearest sampling."),
io.Int.Input("batch_index", default=0, min=0, max=4096,
tooltip="Which image of a batched MoGe geometry to mask."),
io.Int.Input("dilate", default=0, min=0, max=128,
tooltip="Optional dilation (max-pool radius) applied to the mask before "
"masking, in MoGe-resolution pixels. Use to overshoot the object "
"silhouette and avoid leftover halos around the Pixal3D mesh."),
],
outputs=[MoGeGeometry.Output("moge_geometry")],
)
@classmethod
def execute(cls, moge_geometry, mask, batch_index, dilate) -> io.NodeOutput:
if "points" not in moge_geometry:
raise ValueError("MoGeMaskOut: moge_geometry has no points; nothing to mask.")
out = dict(moge_geometry)
points = out["points"].clone()
scene_H, scene_W = points.shape[1], points.shape[2]
m = mask[batch_index] if mask.ndim == 3 else mask
m = m.float()
if m.shape != (scene_H, scene_W):
m = torch.nn.functional.interpolate(
m[None, None], size=(scene_H, scene_W), mode="nearest"
).squeeze(0).squeeze(0)
if dilate > 0:
k = 2 * dilate + 1
m = torch.nn.functional.max_pool2d(
m[None, None], kernel_size=k, stride=1, padding=dilate
).squeeze(0).squeeze(0)
drop = (m > 0.5).to(points.device)
# inf in points makes triangulate_grid_mesh's finite check exclude the pixel.
points[batch_index][drop] = float("inf")
out["points"] = points
if "mask" in out and out["mask"] is not None:
mask_field = out["mask"].clone()
mask_field[batch_index] = mask_field[batch_index] & ~drop.to(mask_field.device)
out["mask"] = mask_field
if "depth" in out and out["depth"] is not None:
depth = out["depth"].clone()
depth[batch_index][drop.to(depth.device)] = float("inf")
out["depth"] = depth
return io.NodeOutput(out)
class MoGeExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh, MoGeGeometryToFOV, MoGeMaskOut]
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh, MoGeGeometryToFOV]
async def comfy_entrypoint() -> MoGeExtension:

View File

@ -1,7 +1,7 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types, UI, io
from comfy.ldm.trellis2.vae import SparseTensor
from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats
from comfy.ldm.trellis2.model import build_proj_transform_matrix, compute_stage_proj_feats
from comfy.ldm.trellis2.naf.model import NAF
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
@ -987,150 +987,6 @@ class Pixal3DConditioning(IO.ComfyNode):
return IO.NodeOutput(positive, negative)
def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle_x, image_resolution):
points = vertices_world.unsqueeze(0).float()
T = transform_matrix.unsqueeze(0).float() if transform_matrix.ndim == 2 else transform_matrix.float()
cam = camera_angle_x.unsqueeze(0) if camera_angle_x.ndim == 0 else camera_angle_x
T = T.to(points.device)
cam = cam.to(points.device)
uv_pix, depth, valid = _project_points_to_image(points, T, cam.float(), image_resolution)
uv = uv_pix.squeeze(0) / image_resolution
return uv, depth.squeeze(0), valid.squeeze(0)
def _crop_uv_to_scene_pixels(uv_crop, crop_bbox, scene_image_size):
crop_x1, crop_y1, crop_x2, crop_y2 = crop_bbox
crop_w = max(1, crop_x2 - crop_x1)
crop_h = max(1, crop_y2 - crop_y1)
px = uv_crop[:, 0] * crop_w + crop_x1
py = uv_crop[:, 1] * crop_h + crop_y1
W, H = scene_image_size
return torch.stack([px.clamp(0, W - 1), py.clamp(0, H - 1)], dim=-1)
class Pixal3DAlignObject(IO.ComfyNode):
"""Pixal3D paper §3.3 Global Alignment for a single object.
Solves (scale, translation) aligning the mesh to MoGe's per-pixel point map. Requires
MoGe to have been computed on the same resized scene image as Pixal3DConditioning."""
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Pixal3DAlignObject",
category="latent/3d",
inputs=[
IO.Mesh.Input("mesh"),
IO.Conditioning.Input("positive", tooltip="The positive conditioning from Pixal3DConditioning for this object — Pixal3DAlignObject reads transform_matrix / camera_angle_x / mesh_scale / crop_bboxes out of its proj_feat_pack."),
io.Custom("MOGE_GEOMETRY").Input("moge_geometry", tooltip="MoGe geometry computed on the original scene image."),
IO.Mask.Input(
"object_mask",
optional=True,
tooltip="Optional per-object scene-space mask. If connected, only vertices whose projected pixel falls inside the mask contribute to the alignment solve.",
),
IO.Int.Input(
"batch_index",
default=0, min=0, max=1024,
tooltip="Which batch slot of the proj_feat_pack/MoGe geometry corresponds to this object.",
),
],
outputs=[
IO.Mesh.Output("aligned_mesh"),
IO.Float.Output(display_name="scale"),
],
)
@classmethod
def execute(cls, mesh, positive, moge_geometry, object_mask=None, batch_index=0) -> IO.NodeOutput:
proj_feat_pack = _proj_pack_from_conditioning(positive)
if proj_feat_pack is None:
raise ValueError("Pixal3DAlignObject: positive conditioning has no proj_feat_pack — connect a Pixal3DConditioning output.")
vertices = mesh.vertices
faces = mesh.faces
if vertices.ndim == 3:
vertices_one = vertices[0]
faces_one = faces[0]
else:
vertices_one = vertices
faces_one = faces
T = proj_feat_pack["transform_matrix"][batch_index:batch_index + 1]
cam_angle = proj_feat_pack["camera_angle_x"][batch_index:batch_index + 1]
mesh_scale = proj_feat_pack["mesh_scale"][batch_index]
image_resolution = int(proj_feat_pack.get("image_resolution", 1024))
crop_bbox = proj_feat_pack["crop_bboxes"][batch_index]
pack_scene_size = proj_feat_pack.get("scene_sizes", [None] * (batch_index + 1))[batch_index]
moge_points = moge_geometry["points"]
moge_mask = moge_geometry["mask"]
if moge_points.ndim != 4:
raise ValueError(f"MoGe points expected [B, H, W, 3]; got {tuple(moge_points.shape)}")
scene_H, scene_W = moge_points.shape[1], moge_points.shape[2]
if pack_scene_size is not None and pack_scene_size != (scene_W, scene_H):
raise ValueError(
f"Pixal3DAlignObject: MoGe geometry was computed on a {scene_W}x{scene_H} image, "
f"but the proj_feat_pack's bbox lives in a {pack_scene_size[0]}x{pack_scene_size[1]} "
"image. Run MoGe on the same resized scene image Pixal3DConditioning used."
)
# Vertices come out of VaeDecodeShapeTrellis in the Pixal3D model frame
# (no un-rotation). Apply _PROJ_GRID_ROTATION = R_x(-90°) to map model
# frame → ProjGrid world: (X, Y, Z) -> (X, -Z, Y).
v = vertices_one.float()
verts_world = torch.stack([v[..., 0], -v[..., 2], v[..., 1]], dim=-1)
verts_world = verts_world / float(mesh_scale.item())
uv_crop, _depth, valid = _project_vertices_to_image_uv(
verts_world, T[0], cam_angle[0], image_resolution)
scene_pixels = _crop_uv_to_scene_pixels(uv_crop, crop_bbox, (scene_W, scene_H))
in_scene = ((scene_pixels[:, 0] >= 0) & (scene_pixels[:, 0] < scene_W) &
(scene_pixels[:, 1] >= 0) & (scene_pixels[:, 1] < scene_H))
# MoGe geometry and object_mask can land on CPU after passing between nodes;
# match the indexed tensor's device for sy/sx so the gather works on either.
moge_points = moge_points.to(scene_pixels.device)
moge_mask = moge_mask.to(scene_pixels.device)
sx = scene_pixels[:, 0].long().clamp(0, scene_W - 1)
sy = scene_pixels[:, 1].long().clamp(0, scene_H - 1)
moge_per_vertex = moge_points[batch_index, sy, sx]
# MoGe's perspective output is (X right, Y down, Z forward). Convert to glTF
# Y-up (X right, Y up, Z back) so the scale/translate fit runs in the same
# frame as vertices_one (Pixal3D model frame = glTF Y-up).
moge_per_vertex = moge_per_vertex * torch.tensor(
[1.0, -1.0, -1.0], dtype=moge_per_vertex.dtype, device=moge_per_vertex.device
)
moge_mask_per_vertex = moge_mask[batch_index, sy, sx]
keep = valid & in_scene & moge_mask_per_vertex
if object_mask is not None:
om = object_mask if object_mask.ndim == 2 else object_mask[batch_index]
om = om.to(sy.device)
keep = keep & (om[sy, sx] > 0.5)
finite = torch.isfinite(moge_per_vertex).all(dim=-1)
keep = keep & finite
kept = int(keep.sum().item())
if kept < 8:
scale = 1.0
aligned = vertices_one
else:
P = vertices_one[keep].float()
Q = moge_per_vertex[keep].float()
p_mean = P.mean(dim=0, keepdim=True)
q_mean = Q.mean(dim=0, keepdim=True)
P_c = P - p_mean
Q_c = Q - q_mean
p_var = (P_c * P_c).sum().clamp(min=1e-8)
q_var = (Q_c * Q_c).sum()
scale = float(torch.sqrt(q_var / p_var).item())
t = q_mean - scale * p_mean
aligned = scale * vertices_one + t
if vertices.ndim == 3:
aligned = aligned.unsqueeze(0)
out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces.cpu())
else:
out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces_one.cpu())
return IO.NodeOutput(out_mesh, float(scale))
class LoadNAFModel(IO.ComfyNode):
@classmethod
@ -1240,7 +1096,6 @@ class Trellis2Extension(ComfyExtension):
return [
Trellis2Conditioning,
Pixal3DConditioning,
Pixal3DAlignObject,
LoadNAFModel,
Trellis2ShapeStage,
EmptyTrellis2LatentStructure,