mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-04 21:51:00 +08:00
Remove Pixal3DAlignObject, out of scope
This commit is contained in:
parent
eec0692bcb
commit
2cced8971c
@ -439,82 +439,10 @@ class MoGeGeometryToFOV(io.ComfyNode):
|
||||
return io.NodeOutput(fov, focal_pixels)
|
||||
|
||||
|
||||
class MoGeMaskOut(io.ComfyNode):
|
||||
"""Mark masked pixels as invalid in a MoGe geometry. MoGePointMapToMesh's
|
||||
finite-check then drops them during triangulation, so the scene mesh has a
|
||||
hole where the object lives. Use to cut out an object's footprint from a
|
||||
MoGe scene mesh before merging in a Pixal3DAlignObject output."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="MoGeMaskOut",
|
||||
display_name="MoGe Mask Out",
|
||||
search_aliases=["moge", "mask", "exclude", "cut out"],
|
||||
category="image/geometry_estimation",
|
||||
description=(
|
||||
"Set the masked pixels to invalid in a MoGe geometry. Downstream "
|
||||
"MoGePointMapToMesh skips invalid pixels, so the resulting scene "
|
||||
"mesh has a hole there. Use to cut an object out of a MoGe scene "
|
||||
"before merging with a Pixal3D reconstruction of that object."
|
||||
),
|
||||
inputs=[
|
||||
MoGeGeometry.Input("moge_geometry"),
|
||||
io.Mask.Input("mask",
|
||||
tooltip="Pixels to exclude (1 = drop, 0 = keep). "
|
||||
"Auto-resized to MoGe's image resolution via nearest sampling."),
|
||||
io.Int.Input("batch_index", default=0, min=0, max=4096,
|
||||
tooltip="Which image of a batched MoGe geometry to mask."),
|
||||
io.Int.Input("dilate", default=0, min=0, max=128,
|
||||
tooltip="Optional dilation (max-pool radius) applied to the mask before "
|
||||
"masking, in MoGe-resolution pixels. Use to overshoot the object "
|
||||
"silhouette and avoid leftover halos around the Pixal3D mesh."),
|
||||
],
|
||||
outputs=[MoGeGeometry.Output("moge_geometry")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, moge_geometry, mask, batch_index, dilate) -> io.NodeOutput:
|
||||
if "points" not in moge_geometry:
|
||||
raise ValueError("MoGeMaskOut: moge_geometry has no points; nothing to mask.")
|
||||
out = dict(moge_geometry)
|
||||
points = out["points"].clone()
|
||||
scene_H, scene_W = points.shape[1], points.shape[2]
|
||||
|
||||
m = mask[batch_index] if mask.ndim == 3 else mask
|
||||
m = m.float()
|
||||
if m.shape != (scene_H, scene_W):
|
||||
m = torch.nn.functional.interpolate(
|
||||
m[None, None], size=(scene_H, scene_W), mode="nearest"
|
||||
).squeeze(0).squeeze(0)
|
||||
if dilate > 0:
|
||||
k = 2 * dilate + 1
|
||||
m = torch.nn.functional.max_pool2d(
|
||||
m[None, None], kernel_size=k, stride=1, padding=dilate
|
||||
).squeeze(0).squeeze(0)
|
||||
drop = (m > 0.5).to(points.device)
|
||||
|
||||
# inf in points makes triangulate_grid_mesh's finite check exclude the pixel.
|
||||
points[batch_index][drop] = float("inf")
|
||||
out["points"] = points
|
||||
|
||||
if "mask" in out and out["mask"] is not None:
|
||||
mask_field = out["mask"].clone()
|
||||
mask_field[batch_index] = mask_field[batch_index] & ~drop.to(mask_field.device)
|
||||
out["mask"] = mask_field
|
||||
|
||||
if "depth" in out and out["depth"] is not None:
|
||||
depth = out["depth"].clone()
|
||||
depth[batch_index][drop.to(depth.device)] = float("inf")
|
||||
out["depth"] = depth
|
||||
|
||||
return io.NodeOutput(out)
|
||||
|
||||
|
||||
class MoGeExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh, MoGeGeometryToFOV, MoGeMaskOut]
|
||||
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh, MoGeGeometryToFOV]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> MoGeExtension:
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types, UI, io
|
||||
from comfy.ldm.trellis2.vae import SparseTensor
|
||||
from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats
|
||||
from comfy.ldm.trellis2.model import build_proj_transform_matrix, compute_stage_proj_feats
|
||||
from comfy.ldm.trellis2.naf.model import NAF
|
||||
|
||||
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
||||
@ -987,150 +987,6 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
|
||||
def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle_x, image_resolution):
|
||||
points = vertices_world.unsqueeze(0).float()
|
||||
T = transform_matrix.unsqueeze(0).float() if transform_matrix.ndim == 2 else transform_matrix.float()
|
||||
cam = camera_angle_x.unsqueeze(0) if camera_angle_x.ndim == 0 else camera_angle_x
|
||||
T = T.to(points.device)
|
||||
cam = cam.to(points.device)
|
||||
uv_pix, depth, valid = _project_points_to_image(points, T, cam.float(), image_resolution)
|
||||
uv = uv_pix.squeeze(0) / image_resolution
|
||||
return uv, depth.squeeze(0), valid.squeeze(0)
|
||||
|
||||
|
||||
def _crop_uv_to_scene_pixels(uv_crop, crop_bbox, scene_image_size):
|
||||
crop_x1, crop_y1, crop_x2, crop_y2 = crop_bbox
|
||||
crop_w = max(1, crop_x2 - crop_x1)
|
||||
crop_h = max(1, crop_y2 - crop_y1)
|
||||
px = uv_crop[:, 0] * crop_w + crop_x1
|
||||
py = uv_crop[:, 1] * crop_h + crop_y1
|
||||
W, H = scene_image_size
|
||||
return torch.stack([px.clamp(0, W - 1), py.clamp(0, H - 1)], dim=-1)
|
||||
|
||||
|
||||
class Pixal3DAlignObject(IO.ComfyNode):
|
||||
"""Pixal3D paper §3.3 Global Alignment for a single object.
|
||||
|
||||
Solves (scale, translation) aligning the mesh to MoGe's per-pixel point map. Requires
|
||||
MoGe to have been computed on the same resized scene image as Pixal3DConditioning."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Pixal3DAlignObject",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.Conditioning.Input("positive", tooltip="The positive conditioning from Pixal3DConditioning for this object — Pixal3DAlignObject reads transform_matrix / camera_angle_x / mesh_scale / crop_bboxes out of its proj_feat_pack."),
|
||||
io.Custom("MOGE_GEOMETRY").Input("moge_geometry", tooltip="MoGe geometry computed on the original scene image."),
|
||||
IO.Mask.Input(
|
||||
"object_mask",
|
||||
optional=True,
|
||||
tooltip="Optional per-object scene-space mask. If connected, only vertices whose projected pixel falls inside the mask contribute to the alignment solve.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"batch_index",
|
||||
default=0, min=0, max=1024,
|
||||
tooltip="Which batch slot of the proj_feat_pack/MoGe geometry corresponds to this object.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output("aligned_mesh"),
|
||||
IO.Float.Output(display_name="scale"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, positive, moge_geometry, object_mask=None, batch_index=0) -> IO.NodeOutput:
|
||||
proj_feat_pack = _proj_pack_from_conditioning(positive)
|
||||
if proj_feat_pack is None:
|
||||
raise ValueError("Pixal3DAlignObject: positive conditioning has no proj_feat_pack — connect a Pixal3DConditioning output.")
|
||||
vertices = mesh.vertices
|
||||
faces = mesh.faces
|
||||
if vertices.ndim == 3:
|
||||
vertices_one = vertices[0]
|
||||
faces_one = faces[0]
|
||||
else:
|
||||
vertices_one = vertices
|
||||
faces_one = faces
|
||||
|
||||
T = proj_feat_pack["transform_matrix"][batch_index:batch_index + 1]
|
||||
cam_angle = proj_feat_pack["camera_angle_x"][batch_index:batch_index + 1]
|
||||
mesh_scale = proj_feat_pack["mesh_scale"][batch_index]
|
||||
image_resolution = int(proj_feat_pack.get("image_resolution", 1024))
|
||||
crop_bbox = proj_feat_pack["crop_bboxes"][batch_index]
|
||||
pack_scene_size = proj_feat_pack.get("scene_sizes", [None] * (batch_index + 1))[batch_index]
|
||||
moge_points = moge_geometry["points"]
|
||||
moge_mask = moge_geometry["mask"]
|
||||
if moge_points.ndim != 4:
|
||||
raise ValueError(f"MoGe points expected [B, H, W, 3]; got {tuple(moge_points.shape)}")
|
||||
scene_H, scene_W = moge_points.shape[1], moge_points.shape[2]
|
||||
if pack_scene_size is not None and pack_scene_size != (scene_W, scene_H):
|
||||
raise ValueError(
|
||||
f"Pixal3DAlignObject: MoGe geometry was computed on a {scene_W}x{scene_H} image, "
|
||||
f"but the proj_feat_pack's bbox lives in a {pack_scene_size[0]}x{pack_scene_size[1]} "
|
||||
"image. Run MoGe on the same resized scene image Pixal3DConditioning used."
|
||||
)
|
||||
|
||||
# Vertices come out of VaeDecodeShapeTrellis in the Pixal3D model frame
|
||||
# (no un-rotation). Apply _PROJ_GRID_ROTATION = R_x(-90°) to map model
|
||||
# frame → ProjGrid world: (X, Y, Z) -> (X, -Z, Y).
|
||||
v = vertices_one.float()
|
||||
verts_world = torch.stack([v[..., 0], -v[..., 2], v[..., 1]], dim=-1)
|
||||
verts_world = verts_world / float(mesh_scale.item())
|
||||
uv_crop, _depth, valid = _project_vertices_to_image_uv(
|
||||
verts_world, T[0], cam_angle[0], image_resolution)
|
||||
scene_pixels = _crop_uv_to_scene_pixels(uv_crop, crop_bbox, (scene_W, scene_H))
|
||||
in_scene = ((scene_pixels[:, 0] >= 0) & (scene_pixels[:, 0] < scene_W) &
|
||||
(scene_pixels[:, 1] >= 0) & (scene_pixels[:, 1] < scene_H))
|
||||
# MoGe geometry and object_mask can land on CPU after passing between nodes;
|
||||
# match the indexed tensor's device for sy/sx so the gather works on either.
|
||||
moge_points = moge_points.to(scene_pixels.device)
|
||||
moge_mask = moge_mask.to(scene_pixels.device)
|
||||
sx = scene_pixels[:, 0].long().clamp(0, scene_W - 1)
|
||||
sy = scene_pixels[:, 1].long().clamp(0, scene_H - 1)
|
||||
moge_per_vertex = moge_points[batch_index, sy, sx]
|
||||
# MoGe's perspective output is (X right, Y down, Z forward). Convert to glTF
|
||||
# Y-up (X right, Y up, Z back) so the scale/translate fit runs in the same
|
||||
# frame as vertices_one (Pixal3D model frame = glTF Y-up).
|
||||
moge_per_vertex = moge_per_vertex * torch.tensor(
|
||||
[1.0, -1.0, -1.0], dtype=moge_per_vertex.dtype, device=moge_per_vertex.device
|
||||
)
|
||||
moge_mask_per_vertex = moge_mask[batch_index, sy, sx]
|
||||
keep = valid & in_scene & moge_mask_per_vertex
|
||||
if object_mask is not None:
|
||||
om = object_mask if object_mask.ndim == 2 else object_mask[batch_index]
|
||||
om = om.to(sy.device)
|
||||
keep = keep & (om[sy, sx] > 0.5)
|
||||
|
||||
finite = torch.isfinite(moge_per_vertex).all(dim=-1)
|
||||
keep = keep & finite
|
||||
|
||||
kept = int(keep.sum().item())
|
||||
if kept < 8:
|
||||
scale = 1.0
|
||||
aligned = vertices_one
|
||||
else:
|
||||
P = vertices_one[keep].float()
|
||||
Q = moge_per_vertex[keep].float()
|
||||
p_mean = P.mean(dim=0, keepdim=True)
|
||||
q_mean = Q.mean(dim=0, keepdim=True)
|
||||
P_c = P - p_mean
|
||||
Q_c = Q - q_mean
|
||||
p_var = (P_c * P_c).sum().clamp(min=1e-8)
|
||||
q_var = (Q_c * Q_c).sum()
|
||||
scale = float(torch.sqrt(q_var / p_var).item())
|
||||
t = q_mean - scale * p_mean
|
||||
aligned = scale * vertices_one + t
|
||||
|
||||
if vertices.ndim == 3:
|
||||
aligned = aligned.unsqueeze(0)
|
||||
out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces.cpu())
|
||||
else:
|
||||
out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces_one.cpu())
|
||||
return IO.NodeOutput(out_mesh, float(scale))
|
||||
|
||||
|
||||
class LoadNAFModel(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -1240,7 +1096,6 @@ class Trellis2Extension(ComfyExtension):
|
||||
return [
|
||||
Trellis2Conditioning,
|
||||
Pixal3DConditioning,
|
||||
Pixal3DAlignObject,
|
||||
LoadNAFModel,
|
||||
Trellis2ShapeStage,
|
||||
EmptyTrellis2LatentStructure,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user