mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-05 22:21:31 +08:00
Remove Pixal3DAlignObject, out of scope
This commit is contained in:
parent
eec0692bcb
commit
2cced8971c
@ -439,82 +439,10 @@ class MoGeGeometryToFOV(io.ComfyNode):
|
|||||||
return io.NodeOutput(fov, focal_pixels)
|
return io.NodeOutput(fov, focal_pixels)
|
||||||
|
|
||||||
|
|
||||||
class MoGeMaskOut(io.ComfyNode):
    """Mark masked pixels as invalid in a MoGe geometry.

    The selected image's ``points`` (and ``depth``, when present) are set to
    ``inf`` under the mask and the geometry's own validity ``mask`` is cleared
    there. MoGePointMapToMesh's finite-check then drops those pixels during
    triangulation, so the scene mesh has a hole where the object lives. Use
    to cut out an object's footprint from a MoGe scene mesh before merging in
    a Pixal3DAlignObject output.
    """

    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="MoGeMaskOut",
            display_name="MoGe Mask Out",
            search_aliases=["moge", "mask", "exclude", "cut out"],
            category="image/geometry_estimation",
            description=(
                "Set the masked pixels to invalid in a MoGe geometry. Downstream "
                "MoGePointMapToMesh skips invalid pixels, so the resulting scene "
                "mesh has a hole there. Use to cut an object out of a MoGe scene "
                "before merging with a Pixal3D reconstruction of that object."
            ),
            inputs=[
                MoGeGeometry.Input("moge_geometry"),
                io.Mask.Input("mask",
                              tooltip="Pixels to exclude (1 = drop, 0 = keep). "
                                      "Auto-resized to MoGe's image resolution via nearest sampling."),
                io.Int.Input("batch_index", default=0, min=0, max=4096,
                             tooltip="Which image of a batched MoGe geometry to mask."),
                io.Int.Input("dilate", default=0, min=0, max=128,
                             tooltip="Optional dilation (max-pool radius) applied to the mask before "
                                     "masking, in MoGe-resolution pixels. Use to overshoot the object "
                                     "silhouette and avoid leftover halos around the Pixal3D mesh."),
            ],
            outputs=[MoGeGeometry.Output("moge_geometry")],
        )

    @classmethod
    def execute(cls, moge_geometry, mask, batch_index, dilate) -> io.NodeOutput:
        """Return a copy of ``moge_geometry`` with the masked pixels invalidated.

        Args:
            moge_geometry: mapping holding at least ``"points"``; ``"mask"``
                and ``"depth"`` are updated too when present and not None.
            mask: 2-D mask, or a 3-D batch of masks. A batch holding a single
                mask is applied to whichever image ``batch_index`` selects.
            batch_index: index of the geometry image to modify.
            dilate: radius, in geometry pixels, of the max-pool dilation
                applied to the mask (0 disables it).

        Raises:
            ValueError: if ``moge_geometry`` has no ``"points"`` entry.
        """
        if "points" not in moge_geometry:
            raise ValueError("MoGeMaskOut: moge_geometry has no points; nothing to mask.")

        # Shallow-copy the dict and clone every tensor we write to, so the
        # caller's geometry is left untouched.
        out = dict(moge_geometry)
        points = out["points"].clone()
        scene_H, scene_W = points.shape[1], points.shape[2]

        if mask.ndim == 3:
            # batch_index addresses the geometry batch. A lone mask is reused
            # for any selected image instead of failing with an IndexError.
            m = mask[0] if mask.shape[0] == 1 else mask[batch_index]
        else:
            m = mask
        m = m.float()

        if m.shape != (scene_H, scene_W):
            m = torch.nn.functional.interpolate(
                m[None, None], size=(scene_H, scene_W), mode="nearest"
            ).squeeze(0).squeeze(0)

        if dilate > 0:
            # Stride-1 max-pool with a (2r+1) window and padding r keeps the
            # resolution while growing the mask by r pixels.
            k = 2 * dilate + 1
            m = torch.nn.functional.max_pool2d(
                m[None, None], kernel_size=k, stride=1, padding=dilate
            ).squeeze(0).squeeze(0)

        drop = (m > 0.5).to(points.device)

        # inf in points makes triangulate_grid_mesh's finite check exclude the pixel.
        points[batch_index][drop] = float("inf")
        out["points"] = points

        if "mask" in out and out["mask"] is not None:
            mask_field = out["mask"].clone()
            mask_field[batch_index] = mask_field[batch_index] & ~drop.to(mask_field.device)
            out["mask"] = mask_field

        if "depth" in out and out["depth"] is not None:
            depth = out["depth"].clone()
            depth[batch_index][drop.to(depth.device)] = float("inf")
            out["depth"] = depth

        return io.NodeOutput(out)
|
|
||||||
|
|
||||||
|
|
||||||
class MoGeExtension(ComfyExtension):
|
class MoGeExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh, MoGeGeometryToFOV, MoGeMaskOut]
|
return [LoadMoGeModel, MoGeInference, MoGePanoramaInference, MoGeRender, MoGePointMapToMesh, MoGeGeometryToFOV]
|
||||||
|
|
||||||
|
|
||||||
async def comfy_entrypoint() -> MoGeExtension:
|
async def comfy_entrypoint() -> MoGeExtension:
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, IO, Types, UI, io
|
from comfy_api.latest import ComfyExtension, IO, Types, UI, io
|
||||||
from comfy.ldm.trellis2.vae import SparseTensor
|
from comfy.ldm.trellis2.vae import SparseTensor
|
||||||
from comfy.ldm.trellis2.model import build_proj_transform_matrix, _project_points_to_image, compute_stage_proj_feats
|
from comfy.ldm.trellis2.model import build_proj_transform_matrix, compute_stage_proj_feats
|
||||||
from comfy.ldm.trellis2.naf.model import NAF
|
from comfy.ldm.trellis2.naf.model import NAF
|
||||||
|
|
||||||
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
||||||
@ -987,150 +987,6 @@ class Pixal3DConditioning(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(positive, negative)
|
return IO.NodeOutput(positive, negative)
|
||||||
|
|
||||||
|
|
||||||
def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle_x, image_resolution):
    """Project one vertex set into the image and return normalized UVs.

    A leading batch axis of size one is added to the vertices (and to the
    transform / camera angle when they arrive without one), everything is
    moved to the vertices' device, and ``_project_points_to_image`` does the
    projection.

    Returns:
        ``(uv, depth, valid)`` with the batch axis squeezed away again; ``uv``
        is the projected pixel coordinate divided by ``image_resolution``.
    """
    batched_points = vertices_world.float().unsqueeze(0)
    device = batched_points.device

    if transform_matrix.ndim == 2:
        transform = transform_matrix.unsqueeze(0).float()
    else:
        transform = transform_matrix.float()
    transform = transform.to(device)

    if camera_angle_x.ndim == 0:
        fov = camera_angle_x.unsqueeze(0)
    else:
        fov = camera_angle_x
    fov = fov.to(device).float()

    pixel_uv, depth, valid = _project_points_to_image(batched_points, transform, fov, image_resolution)
    normalized_uv = pixel_uv.squeeze(0) / image_resolution
    return normalized_uv, depth.squeeze(0), valid.squeeze(0)
|
|
||||||
|
|
||||||
|
|
||||||
def _crop_uv_to_scene_pixels(uv_crop, crop_bbox, scene_image_size):
|
|
||||||
crop_x1, crop_y1, crop_x2, crop_y2 = crop_bbox
|
|
||||||
crop_w = max(1, crop_x2 - crop_x1)
|
|
||||||
crop_h = max(1, crop_y2 - crop_y1)
|
|
||||||
px = uv_crop[:, 0] * crop_w + crop_x1
|
|
||||||
py = uv_crop[:, 1] * crop_h + crop_y1
|
|
||||||
W, H = scene_image_size
|
|
||||||
return torch.stack([px.clamp(0, W - 1), py.clamp(0, H - 1)], dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
class Pixal3DAlignObject(IO.ComfyNode):
    """Pixal3D paper §3.3 Global Alignment for a single object.

    Solves (scale, translation) aligning the mesh to MoGe's per-pixel point map. Requires
    MoGe to have been computed on the same resized scene image as Pixal3DConditioning.
    """

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="Pixal3DAlignObject",
            category="latent/3d",
            inputs=[
                IO.Mesh.Input("mesh"),
                IO.Conditioning.Input("positive", tooltip="The positive conditioning from Pixal3DConditioning for this object — Pixal3DAlignObject reads transform_matrix / camera_angle_x / mesh_scale / crop_bboxes out of its proj_feat_pack."),
                io.Custom("MOGE_GEOMETRY").Input("moge_geometry", tooltip="MoGe geometry computed on the original scene image."),
                IO.Mask.Input(
                    "object_mask",
                    optional=True,
                    tooltip="Optional per-object scene-space mask. If connected, only vertices whose projected pixel falls inside the mask contribute to the alignment solve.",
                ),
                IO.Int.Input(
                    "batch_index",
                    default=0, min=0, max=1024,
                    tooltip="Which batch slot of the proj_feat_pack/MoGe geometry corresponds to this object.",
                ),
            ],
            outputs=[
                IO.Mesh.Output("aligned_mesh"),
                IO.Float.Output(display_name="scale"),
            ],
        )

    @classmethod
    def execute(cls, mesh, positive, moge_geometry, object_mask=None, batch_index=0) -> IO.NodeOutput:
        """Fit a uniform scale and a translation from the mesh to the MoGe points.

        Each vertex is projected into the scene image and paired with the MoGe
        point under that pixel. Vertices whose projection is invalid, falls
        outside the scene image, hits an invalid or non-finite MoGe pixel, or
        lies outside ``object_mask`` are excluded from the fit.

        Returns:
            The aligned mesh (on CPU) and the fitted scale. With fewer than 8
            usable correspondences the vertices are returned unchanged and
            the scale is 1.0.

        Raises:
            ValueError: if the conditioning carries no proj_feat_pack, the
                MoGe points are not 4-D, or the MoGe resolution differs from
                the scene size recorded in the proj_feat_pack.
        """
        proj_feat_pack = _proj_pack_from_conditioning(positive)
        if proj_feat_pack is None:
            raise ValueError("Pixal3DAlignObject: positive conditioning has no proj_feat_pack — connect a Pixal3DConditioning output.")

        vertices = mesh.vertices
        faces = mesh.faces
        batched = vertices.ndim == 3
        # For a batched mesh only the first entry is aligned.
        vertices_one = vertices[0] if batched else vertices

        T = proj_feat_pack["transform_matrix"][batch_index:batch_index + 1]
        cam_angle = proj_feat_pack["camera_angle_x"][batch_index:batch_index + 1]
        mesh_scale = proj_feat_pack["mesh_scale"][batch_index]
        image_resolution = int(proj_feat_pack.get("image_resolution", 1024))
        crop_bbox = proj_feat_pack["crop_bboxes"][batch_index]
        pack_scene_size = proj_feat_pack.get("scene_sizes", [None] * (batch_index + 1))[batch_index]

        moge_points = moge_geometry["points"]
        moge_mask = moge_geometry["mask"]
        if moge_points.ndim != 4:
            raise ValueError(f"MoGe points expected [B, H, W, 3]; got {tuple(moge_points.shape)}")
        scene_H, scene_W = moge_points.shape[1], moge_points.shape[2]
        # Normalize to a tuple first: a list never compares equal to a tuple,
        # which would otherwise reject a matching size stored as a list.
        if pack_scene_size is not None and tuple(pack_scene_size) != (scene_W, scene_H):
            raise ValueError(
                f"Pixal3DAlignObject: MoGe geometry was computed on a {scene_W}x{scene_H} image, "
                f"but the proj_feat_pack's bbox lives in a {pack_scene_size[0]}x{pack_scene_size[1]} "
                "image. Run MoGe on the same resized scene image Pixal3DConditioning used."
            )

        # Vertices come out of VaeDecodeShapeTrellis in the Pixal3D model frame
        # (no un-rotation). Apply _PROJ_GRID_ROTATION = R_x(-90°) to map model
        # frame → ProjGrid world: (X, Y, Z) -> (X, -Z, Y).
        v = vertices_one.float()
        verts_world = torch.stack([v[..., 0], -v[..., 2], v[..., 1]], dim=-1)
        verts_world = verts_world / float(mesh_scale.item())
        uv_crop, _, valid = _project_vertices_to_image_uv(
            verts_world, T[0], cam_angle[0], image_resolution)

        # The bounds test must see the *unclamped* scene coordinates:
        # _crop_uv_to_scene_pixels clamps to the image, so testing its output
        # would always pass and out-of-image vertices would be paired with
        # border pixels.
        crop_x1, crop_y1, crop_x2, crop_y2 = crop_bbox
        raw_x = uv_crop[:, 0] * max(1, crop_x2 - crop_x1) + crop_x1
        raw_y = uv_crop[:, 1] * max(1, crop_y2 - crop_y1) + crop_y1
        in_scene = (raw_x >= 0) & (raw_x < scene_W) & (raw_y >= 0) & (raw_y < scene_H)

        # Clamped coordinates are still used for the gather so that every
        # index is valid; rejected vertices are filtered out through `keep`.
        scene_pixels = _crop_uv_to_scene_pixels(uv_crop, crop_bbox, (scene_W, scene_H))

        # MoGe geometry and object_mask can land on CPU after passing between nodes;
        # match the indexed tensor's device for sy/sx so the gather works on either.
        moge_points = moge_points.to(scene_pixels.device)
        moge_mask = moge_mask.to(scene_pixels.device)
        sx = scene_pixels[:, 0].long().clamp(0, scene_W - 1)
        sy = scene_pixels[:, 1].long().clamp(0, scene_H - 1)
        moge_per_vertex = moge_points[batch_index, sy, sx]
        # MoGe's perspective output is (X right, Y down, Z forward). Convert to glTF
        # Y-up (X right, Y up, Z back) so the scale/translate fit runs in the same
        # frame as vertices_one (Pixal3D model frame = glTF Y-up).
        moge_per_vertex = moge_per_vertex * torch.tensor(
            [1.0, -1.0, -1.0], dtype=moge_per_vertex.dtype, device=moge_per_vertex.device
        )
        moge_mask_per_vertex = moge_mask[batch_index, sy, sx]

        keep = valid & in_scene & moge_mask_per_vertex
        if object_mask is not None:
            om = object_mask if object_mask.ndim == 2 else object_mask[batch_index]
            om = om.to(sy.device)
            keep = keep & (om[sy, sx] > 0.5)

        finite = torch.isfinite(moge_per_vertex).all(dim=-1)
        keep = keep & finite

        kept = int(keep.sum().item())
        if kept < 8:
            # Too few correspondences for a meaningful fit: pass the mesh through.
            scale = 1.0
            aligned = vertices_one
        else:
            P = vertices_one[keep].float()
            Q = moge_per_vertex[keep].float()
            p_mean = P.mean(dim=0, keepdim=True)
            q_mean = Q.mean(dim=0, keepdim=True)
            P_c = P - p_mean
            Q_c = Q - q_mean
            # Scale is the ratio of the two point sets' spread about their
            # centroids; the clamp guards against a degenerate source set.
            p_var = (P_c * P_c).sum().clamp(min=1e-8)
            q_var = (Q_c * Q_c).sum()
            scale = float(torch.sqrt(q_var / p_var).item())
            # Translation that maps the scaled source centroid onto the target centroid.
            t = q_mean - scale * p_mean
            aligned = scale * vertices_one + t

        if batched:
            aligned = aligned.unsqueeze(0)
            out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces.cpu())
        else:
            out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces.cpu())
        return IO.NodeOutput(out_mesh, float(scale))
|
|
||||||
|
|
||||||
|
|
||||||
class LoadNAFModel(IO.ComfyNode):
|
class LoadNAFModel(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1240,7 +1096,6 @@ class Trellis2Extension(ComfyExtension):
|
|||||||
return [
|
return [
|
||||||
Trellis2Conditioning,
|
Trellis2Conditioning,
|
||||||
Pixal3DConditioning,
|
Pixal3DConditioning,
|
||||||
Pixal3DAlignObject,
|
|
||||||
LoadNAFModel,
|
LoadNAFModel,
|
||||||
Trellis2ShapeStage,
|
Trellis2ShapeStage,
|
||||||
EmptyTrellis2LatentStructure,
|
EmptyTrellis2LatentStructure,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user