PBR baking

This commit is contained in:
kijai 2026-06-10 10:30:41 +03:00
parent a4c8b5064b
commit 1697da460b
4 changed files with 1398 additions and 26 deletions

View File

@ -17,6 +17,7 @@ class MESH:
uvs: torch.Tensor | None = None, uvs: torch.Tensor | None = None,
vertex_colors: torch.Tensor | None = None, vertex_colors: torch.Tensor | None = None,
texture: torch.Tensor | None = None, texture: torch.Tensor | None = None,
metallic_roughness: torch.Tensor | None = None,
vertex_counts: torch.Tensor | None = None, vertex_counts: torch.Tensor | None = None,
face_counts: torch.Tensor | None = None): face_counts: torch.Tensor | None = None):
@ -26,7 +27,9 @@ class MESH:
self.faces = faces # faces: (B, M, 3) self.faces = faces # faces: (B, M, 3)
self.uvs = uvs # uvs: (B, N, 2) self.uvs = uvs # uvs: (B, N, 2)
self.vertex_colors = vertex_colors # vertex_colors: (B, N, 3 or 4) self.vertex_colors = vertex_colors # vertex_colors: (B, N, 3 or 4)
self.texture = texture # texture: (B, H, W, 3) self.texture = texture # texture (baseColor): (B, H, W, 3)
# glTF metallicRoughness texture: (B, H, W, 3), R unused, G=roughness, B=metallic
self.metallic_roughness = metallic_roughness
# When vertices/faces are zero-padded to a common N/M across the batch (variable-size mesh batch), # When vertices/faces are zero-padded to a common N/M across the batch (variable-size mesh batch),
# these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed. # these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed.
self.vertex_counts = vertex_counts self.vertex_counts = vertex_counts

File diff suppressed because it is too large Load Diff

View File

@ -80,7 +80,8 @@ def get_mesh_batch_item(mesh, index):
def save_glb(vertices, faces, filepath, metadata=None, def save_glb(vertices, faces, filepath, metadata=None,
uvs=None, vertex_colors=None, texture_image=None): uvs=None, vertex_colors=None, texture_image=None,
metallic_roughness_image=None):
""" """
Save PyTorch tensor vertices and faces as a GLB file without external dependencies. Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
@ -92,6 +93,8 @@ def save_glb(vertices, faces, filepath, metadata=None,
uvs: torch.Tensor of shape (N, 2) - Optional per-vertex texture coordinates uvs: torch.Tensor of shape (N, 2) - Optional per-vertex texture coordinates
vertex_colors: torch.Tensor of shape (N, 3) or (N, 4) - Optional per-vertex colors in [0, 1] vertex_colors: torch.Tensor of shape (N, 3) or (N, 4) - Optional per-vertex colors in [0, 1]
texture_image: PIL.Image - Optional baseColor texture, embedded as PNG texture_image: PIL.Image - Optional baseColor texture, embedded as PNG
metallic_roughness_image: PIL.Image - Optional glTF metallicRoughness texture
(R unused, G=roughness, B=metallic), embedded as PNG
""" """
# Convert tensors to numpy arrays # Convert tensors to numpy arrays
@ -126,12 +129,18 @@ def save_glb(vertices, faces, filepath, metadata=None,
buf = BytesIO() buf = BytesIO()
texture_image.save(buf, format="PNG") texture_image.save(buf, format="PNG")
texture_png_bytes = buf.getvalue() texture_png_bytes = buf.getvalue()
mr_png_bytes = None
if metallic_roughness_image is not None:
buf = BytesIO()
metallic_roughness_image.save(buf, format="PNG")
mr_png_bytes = buf.getvalue()
vertices_buffer = vertices_np.tobytes() vertices_buffer = vertices_np.tobytes()
indices_buffer = faces_np.tobytes() indices_buffer = faces_np.tobytes()
uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b"" uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b""
colors_buffer = colors_np.tobytes() if colors_np is not None else b"" colors_buffer = colors_np.tobytes() if colors_np is not None else b""
texture_buffer = texture_png_bytes if texture_png_bytes is not None else b"" texture_buffer = texture_png_bytes if texture_png_bytes is not None else b""
mr_buffer = mr_png_bytes if mr_png_bytes is not None else b""
def pad_to_4_bytes(buffer): def pad_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4 padding_length = (4 - (len(buffer) % 4)) % 4
@ -142,6 +151,7 @@ def save_glb(vertices, faces, filepath, metadata=None,
uvs_buffer_padded = pad_to_4_bytes(uvs_buffer) uvs_buffer_padded = pad_to_4_bytes(uvs_buffer)
colors_buffer_padded = pad_to_4_bytes(colors_buffer) colors_buffer_padded = pad_to_4_bytes(colors_buffer)
texture_buffer_padded = pad_to_4_bytes(texture_buffer) texture_buffer_padded = pad_to_4_bytes(texture_buffer)
mr_buffer_padded = pad_to_4_bytes(mr_buffer)
buffer_data = b"".join([ buffer_data = b"".join([
vertices_buffer_padded, vertices_buffer_padded,
@ -149,6 +159,7 @@ def save_glb(vertices, faces, filepath, metadata=None,
uvs_buffer_padded, uvs_buffer_padded,
colors_buffer_padded, colors_buffer_padded,
texture_buffer_padded, texture_buffer_padded,
mr_buffer_padded,
]) ])
vertices_byte_length = len(vertices_buffer) vertices_byte_length = len(vertices_buffer)
@ -158,6 +169,7 @@ def save_glb(vertices, faces, filepath, metadata=None,
uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded) uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded)
colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded) colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded)
texture_byte_offset = colors_byte_offset + len(colors_buffer_padded) texture_byte_offset = colors_byte_offset + len(colors_buffer_padded)
mr_byte_offset = texture_byte_offset + len(texture_buffer_padded)
buffer_views = [ buffer_views = [
{ {
@ -251,8 +263,24 @@ def save_glb(vertices, faces, filepath, metadata=None,
}) })
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"}) images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071}) samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
textures.append({"source": 0, "sampler": 0}) textures.append({"source": len(images) - 1, "sampler": 0})
pbr["baseColorTexture"] = {"index": 0, "texCoord": 0} pbr["baseColorTexture"] = {"index": len(textures) - 1, "texCoord": 0}
if mr_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
buffer_views.append({
"buffer": 0,
"byteOffset": mr_byte_offset,
"byteLength": len(mr_buffer),
})
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
if not samplers:
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
textures.append({"source": len(images) - 1, "sampler": 0})
pbr["metallicRoughnessTexture"] = {"index": len(textures) - 1, "texCoord": 0}
# When a metallicRoughness texture is present, the factors scale it; use 1.0
# so the texture values pass through unchanged (glTF convention).
pbr["metallicFactor"] = 1.0
pbr["roughnessFactor"] = 1.0
materials.append({ materials.append({
"pbrMetallicRoughness": pbr, "pbrMetallicRoughness": pbr,
@ -373,12 +401,20 @@ class SaveGLB(IO.ComfyNode):
assert texture_np.ndim == 4 and texture_np.shape[-1] == 3, ( assert texture_np.ndim == 4 and texture_np.shape[-1] == 3, (
f"texture must be (B, H, W, 3) RGB, got shape {tuple(texture_np.shape)}" f"texture must be (B, H, W, 3) RGB, got shape {tuple(texture_np.shape)}"
) )
mr_b = getattr(mesh, "metallic_roughness", None)
mr_np = None
if mr_b is not None:
mr_np = (mr_b.clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)
assert mr_np.ndim == 4 and mr_np.shape[-1] == 3, (
f"metallic_roughness must be (B, H, W, 3), got shape {tuple(mr_np.shape)}"
)
for i in range(mesh.vertices.shape[0]): for i in range(mesh.vertices.shape[0]):
vertices_i, faces_i, v_colors, uvs_i = get_mesh_batch_item(mesh, i) vertices_i, faces_i, v_colors, uvs_i = get_mesh_batch_item(mesh, i)
if vertices_i.shape[0] == 0 or faces_i.shape[0] == 0: if vertices_i.shape[0] == 0 or faces_i.shape[0] == 0:
logging.warning(f"SaveGLB: skipping empty mesh at batch index {i}") logging.warning(f"SaveGLB: skipping empty mesh at batch index {i}")
continue continue
tex_img = Image.fromarray(texture_np[i], mode="RGB") if texture_np is not None else None tex_img = Image.fromarray(texture_np[i], mode="RGB") if texture_np is not None else None
mr_img = Image.fromarray(mr_np[i], mode="RGB") if mr_np is not None else None
f = f"{filename}_{counter:05}_.glb" f = f"{filename}_{counter:05}_.glb"
save_glb( save_glb(
vertices_i, faces_i, vertices_i, faces_i,
@ -387,6 +423,7 @@ class SaveGLB(IO.ComfyNode):
uvs=uvs_i, uvs=uvs_i,
vertex_colors=v_colors, vertex_colors=v_colors,
texture_image=tex_img, texture_image=tex_img,
metallic_roughness_image=mr_img,
) )
results.append({ results.append({
"filename": f, "filename": f,

View File

@ -9,8 +9,10 @@ from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
import comfy.model_management import comfy.model_management
import comfy.utils import comfy.utils
import folder_paths import folder_paths
from comfy.ldm.trellis2 import sampling_preview
from PIL import Image from PIL import Image
import logging import logging
import os
import numpy as np import numpy as np
import math import math
import torch import torch
@ -19,6 +21,89 @@ ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
NAFModel = io.Custom("NAF_MODEL") NAFModel = io.Custom("NAF_MODEL")
# Texture latent -> base-color calibration for the per-step preview
def _tex_rgb_factors_path():
return os.path.join(folder_paths.get_folder_paths("vae_approx")[0], "trellis2_tex_rgb_factors.pt")
def _pool_albedo_to_input(in_coords, out_coords, out_colors):
in_sp = in_coords[:, 1:4].long()
out_sp = out_coords[:, 1:4].long()
in_b = in_coords[:, 0].long()
out_b = out_coords[:, 0].long()
in_res = int(in_sp.max().item()) + 1
out_res = int(out_sp.max().item()) + 1
parent = torch.floor(out_sp.float() * in_res / out_res).long().clamp(0, in_res - 1)
R = in_res
in_flat = ((in_b * R + in_sp[:, 0]) * R + in_sp[:, 1]) * R + in_sp[:, 2]
par_flat = ((out_b * R + parent[:, 0]) * R + parent[:, 1]) * R + parent[:, 2]
order = torch.argsort(in_flat)
in_sorted = in_flat[order]
pos = torch.searchsorted(in_sorted, par_flat).clamp(max=in_sorted.numel() - 1)
matched = in_sorted[pos] == par_flat
in_idx = order[pos][matched]
cols = out_colors[matched].float()
N = in_coords.shape[0]
csum = cols.new_zeros((N, 3))
ccount = cols.new_zeros((N, 1))
csum.index_add_(0, in_idx, cols)
ccount.index_add_(0, in_idx, torch.ones((in_idx.shape[0], 1), device=cols.device, dtype=cols.dtype))
valid = ccount[:, 0] > 0
albedo = torch.zeros_like(csum)
albedo[valid] = csum[valid] / ccount[valid]
return albedo, valid
def _calibrate_tex_rgb(in_latent, in_coords, out_colors, out_coords):
"""Accumulate one decode's (latent -> albedo) evidence, re-solve, persist, publish."""
try:
dev = out_colors.device
in_latent = in_latent.to(dev)
in_coords = in_coords.to(dev)
out_coords = out_coords.to(dev)
albedo, valid = _pool_albedo_to_input(in_coords, out_coords, out_colors)
X = in_latent[valid].float().cpu()
Y = albedo[valid].float().cpu()
if X.shape[0] < 64:
return
Xaug = torch.cat([X, torch.ones(X.shape[0], 1)], dim=1) # [K, C+1]
A_run = Xaug.transpose(0, 1) @ Xaug # [C+1, C+1]
B_run = Xaug.transpose(0, 1) @ Y # [C+1, 3]
path = _tex_rgb_factors_path()
if os.path.exists(path):
try:
prev = torch.load(path, map_location="cpu")
A_run = A_run + prev["A"]
B_run = B_run + prev["B"]
except Exception:
pass
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save({"A": A_run, "B": B_run}, path)
eye = torch.eye(A_run.shape[0])
WB = torch.linalg.solve(A_run + 1e-3 * eye, B_run) # [C+1, 3]
W, b = WB[:-1].contiguous(), WB[-1].contiguous()
sampling_preview.set_tex_rgb(W, b)
except Exception as e:
logging.debug(f"Trellis2 tex-rgb calibration skipped: {e}")
def _load_tex_rgb_factors():
try:
path = _tex_rgb_factors_path()
if os.path.exists(path):
d = torch.load(path, map_location="cpu")
eye = torch.eye(d["A"].shape[0])
WB = torch.linalg.solve(d["A"] + 1e-3 * eye, d["B"])
sampling_preview.set_tex_rgb(WB[:-1].contiguous(), WB[-1].contiguous())
except Exception as e:
logging.debug(f"Trellis2 tex-rgb factor load skipped: {e}")
_load_tex_rgb_factors()
def prepare_trellis_vae_for_decode(vae, sample_shape): def prepare_trellis_vae_for_decode(vae, sample_shape):
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype) memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
if len(sample_shape) == 5: if len(sample_shape) == 5:
@ -174,6 +259,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
resolution = int(coord_resolution) * 16 resolution = int(coord_resolution) * 16
else: else:
resolution = int(vae.first_stage_model.resolution.item()) resolution = int(vae.first_stage_model.resolution.item())
model_frame = samples.get("model_frame", "y_up")
sample_tensor = samples["samples"] sample_tensor = samples["samples"]
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
coords = samples["coords"] coords = samples["coords"]
@ -205,8 +291,13 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
coords_list = [stage_tensor.coords for stage_tensor in stage_tensors] coords_list = [stage_tensor.coords for stage_tensor in stage_tensors]
subs.append(SparseTensor.from_tensor_list(feats_list, coords_list)) subs.append(SparseTensor.from_tensor_list(feats_list, coords_list))
vert_list = [v.float() for v, f in mesh] # Rotate Z-up (Trellis2 training frame) vertices to glTF Y-up. Pixal3D outputs are already Y-up.
face_list = [f.int() for v, f in mesh] if model_frame == "z_up":
vert_list = [torch.stack([v[..., 0], v[..., 2], -v[..., 1]], dim=-1).float().cpu()
for v, _ in mesh]
else:
vert_list = [v.float().cpu() for v, _ in mesh]
face_list = [f.int().cpu() for _, f in mesh]
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list): if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list)) mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list))
else: else:
@ -241,19 +332,32 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
prepare_trellis_vae_for_decode(vae, sample_tensor.shape) prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
trellis_vae = vae.first_stage_model trellis_vae = vae.first_stage_model
coord_counts = samples.get("coord_counts") coord_counts = samples.get("coord_counts")
model_frame = samples.get("model_frame", "y_up")
samples = samples["samples"] samples = samples["samples"]
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts) samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
samples = samples.to(device) samples = samples.to(device)
cal_in_latent = samples # [N, C] pre-denorm latent, for tex-rgb preview calibration
cal_in_coords = coords
std = tex_slat_normalization["std"].to(samples) std = tex_slat_normalization["std"].to(samples)
mean = tex_slat_normalization["mean"].to(samples) mean = tex_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords.to(device)) samples = SparseTensor(feats = samples, coords=coords.to(device))
samples = samples * std + mean samples = samples * std + mean
voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides) voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides)
color_feats = voxel.feats[:, :3] # Keep all decoded channels. The texture VAE emits 6: base_color (0:3),
# metallic (3), roughness (4), alpha (5) — all in [0, 1]. Vertex-color
# consumers (PaintMesh) slice [:3]; BakeTextureFromVoxel uses the full
# PBR set. Older 3-channel checkpoints pass through unchanged.
color_feats = voxel.feats
voxel_coords = voxel.coords voxel_coords = voxel.coords
# Calibrate the latent->base_color map for the per-step texture preview.
# Done here while input coords and voxel_coords share the model frame
# (before the z_up remap below) and on the real decoded albedo.
if color_feats.shape[0] > 0 and color_feats.shape[-1] >= 3:
_calibrate_tex_rgb(cal_in_latent, cal_in_coords, color_feats[:, :3], voxel_coords)
if voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3: if voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
spatial = voxel_coords[:, -3:] if voxel_coords.shape[-1] == 4 else voxel_coords spatial = voxel_coords[:, -3:] if voxel_coords.shape[-1] == 4 else voxel_coords
max_idx = int(spatial.max().item()) + 1 max_idx = int(spatial.max().item()) + 1
@ -261,6 +365,24 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
else: else:
tex_resolution = 1024 tex_resolution = 1024
# Remap Z-up voxel coords to Y-up: (x, y, z) -> (x, z, R-1-y), matching the
# R_x(-90°) applied to mesh vertices in VaeDecodeShapeTrellis. Keeps PaintMesh's
# NN lookup correctly aligned without it needing to know the source frame.
if model_frame == "z_up" and voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
R = tex_resolution
if voxel_coords.shape[-1] == 4:
batch_col = voxel_coords[:, :1]
spatial = voxel_coords[:, 1:]
spatial_yup = torch.stack(
[spatial[:, 0], spatial[:, 2], (R - 1) - spatial[:, 1]], dim=-1
)
voxel_coords = torch.cat([batch_col, spatial_yup], dim=-1)
else:
voxel_coords = torch.stack(
[voxel_coords[:, 0], voxel_coords[:, 2], (R - 1) - voxel_coords[:, 1]],
dim=-1,
)
voxel = Types.VOXEL(voxel_coords, color_feats, tex_resolution) voxel = Types.VOXEL(voxel_coords, color_feats, tex_resolution)
return IO.NodeOutput(voxel) return IO.NodeOutput(voxel)
@ -425,7 +547,9 @@ class Trellis2UpsampleStage(IO.ComfyNode):
positive_out = _conditioning_set_extras(positive, extras) positive_out = _conditioning_set_extras(positive, extras)
negative_out = _conditioning_set_extras(negative, extras) negative_out = _conditioning_set_extras(negative, extras)
out_latent = {"samples": latent, "coords": coords, "coord_counts": counts, out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
"coord_resolution": coord_resolution, "type": "trellis2"} "coord_resolution": coord_resolution, "type": "trellis2",
"model_frame": shape_latent.get("model_frame",
"y_up" if proj_pack is not None else "z_up")}
return IO.NodeOutput(positive_out, negative_out, out_latent) return IO.NodeOutput(positive_out, negative_out, out_latent)
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
@ -694,7 +818,8 @@ class Trellis2ShapeStage(IO.ComfyNode):
positive_out = _conditioning_set_extras(positive, extras) positive_out = _conditioning_set_extras(positive, extras)
negative_out = _conditioning_set_extras(negative, extras) negative_out = _conditioning_set_extras(negative, extras)
out_latent = {"samples": latent, "coords": coords, "coord_counts": counts, out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
"coord_resolution": coord_resolution, "type": "trellis2"} "coord_resolution": coord_resolution, "type": "trellis2",
"model_frame": "y_up" if proj_pack is not None else "z_up"}
return IO.NodeOutput(positive_out, negative_out, out_latent) return IO.NodeOutput(positive_out, negative_out, out_latent)
class Trellis2TextureStage(IO.ComfyNode): class Trellis2TextureStage(IO.ComfyNode):
@ -747,7 +872,9 @@ class Trellis2TextureStage(IO.ComfyNode):
positive_out = _conditioning_set_extras(positive, extras) positive_out = _conditioning_set_extras(positive, extras)
negative_out = _conditioning_set_extras(negative, extras) negative_out = _conditioning_set_extras(negative, extras)
out_latent = {"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts} out_latent = {"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts,
"model_frame": shape_latent.get("model_frame",
"y_up" if proj_pack is not None else "z_up")}
if coord_resolution is not None: if coord_resolution is not None:
out_latent["coord_resolution"] = coord_resolution out_latent["coord_resolution"] = coord_resolution
return IO.NodeOutput(positive_out, negative_out, out_latent) return IO.NodeOutput(positive_out, negative_out, out_latent)
@ -1018,6 +1145,8 @@ def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle
points = vertices_world.unsqueeze(0).float() points = vertices_world.unsqueeze(0).float()
T = transform_matrix.unsqueeze(0).float() if transform_matrix.ndim == 2 else transform_matrix.float() T = transform_matrix.unsqueeze(0).float() if transform_matrix.ndim == 2 else transform_matrix.float()
cam = camera_angle_x.unsqueeze(0) if camera_angle_x.ndim == 0 else camera_angle_x cam = camera_angle_x.unsqueeze(0) if camera_angle_x.ndim == 0 else camera_angle_x
T = T.to(points.device)
cam = cam.to(points.device)
uv_pix, depth, valid = _project_points_to_image(points, T, cam.float(), image_resolution) uv_pix, depth, valid = _project_points_to_image(points, T, cam.float(), image_resolution)
uv = uv_pix.squeeze(0) / image_resolution uv = uv_pix.squeeze(0) / image_resolution
return uv, depth.squeeze(0), valid.squeeze(0) return uv, depth.squeeze(0), valid.squeeze(0)
@ -1108,13 +1237,25 @@ class Pixal3DAlignObject(IO.ComfyNode):
scene_pixels = _crop_uv_to_scene_pixels(uv_crop, crop_bbox, (scene_W, scene_H)) scene_pixels = _crop_uv_to_scene_pixels(uv_crop, crop_bbox, (scene_W, scene_H))
in_scene = ((scene_pixels[:, 0] >= 0) & (scene_pixels[:, 0] < scene_W) & in_scene = ((scene_pixels[:, 0] >= 0) & (scene_pixels[:, 0] < scene_W) &
(scene_pixels[:, 1] >= 0) & (scene_pixels[:, 1] < scene_H)) (scene_pixels[:, 1] >= 0) & (scene_pixels[:, 1] < scene_H))
# MoGe geometry and object_mask can land on CPU after passing between nodes;
# match the indexed tensor's device for sy/sx so the gather works on either.
moge_points = moge_points.to(scene_pixels.device)
moge_mask = moge_mask.to(scene_pixels.device)
sx = scene_pixels[:, 0].long().clamp(0, scene_W - 1) sx = scene_pixels[:, 0].long().clamp(0, scene_W - 1)
sy = scene_pixels[:, 1].long().clamp(0, scene_H - 1) sy = scene_pixels[:, 1].long().clamp(0, scene_H - 1)
moge_per_vertex = moge_points[batch_index, sy, sx] moge_per_vertex = moge_points[batch_index, sy, sx]
# MoGe's perspective output is (X right, Y down, Z forward). Convert to glTF
# Y-up (X right, Y up, Z back) so the scale/translate fit runs in the same
# frame as vertices_one (Pixal3D model frame = glTF Y-up). Mirrors the
# `verts * [1, -1, -1]` step in MoGePointMapToMesh.
moge_per_vertex = moge_per_vertex * torch.tensor(
[1.0, -1.0, -1.0], dtype=moge_per_vertex.dtype, device=moge_per_vertex.device
)
moge_mask_per_vertex = moge_mask[batch_index, sy, sx] moge_mask_per_vertex = moge_mask[batch_index, sy, sx]
keep = valid & in_scene & moge_mask_per_vertex keep = valid & in_scene & moge_mask_per_vertex
if object_mask is not None: if object_mask is not None:
om = object_mask if object_mask.ndim == 2 else object_mask[batch_index] om = object_mask if object_mask.ndim == 2 else object_mask[batch_index]
om = om.to(sy.device)
keep = keep & (om[sy, sx] > 0.5) keep = keep & (om[sy, sx] > 0.5)
finite = torch.isfinite(moge_per_vertex).all(dim=-1) finite = torch.isfinite(moge_per_vertex).all(dim=-1)
@ -1131,25 +1272,23 @@ class Pixal3DAlignObject(IO.ComfyNode):
q_mean = Q.mean(dim=0, keepdim=True) q_mean = Q.mean(dim=0, keepdim=True)
P_c = P - p_mean P_c = P - p_mean
Q_c = Q - q_mean Q_c = Q - q_mean
num = (P_c * Q_c).sum() # Rotation-invariant scale: ratio of RMS spreads. MoGe geometry is
den = (P_c * P_c).sum().clamp(min=1e-8) # noisy and Pixal3D's mesh frame can be yawed relative to MoGe (paper
scale = float((num / den).item()) # acknowledges this), so the L2-optimal scalar (P_c · Q_c)/(P_c · P_c)
if not (scale > 0): # gets multiplied by cos(yaw) and shrinks the object. Using
# Negative scale would mirror the mesh; treat as a camera-convention mismatch. # sqrt(||Q_c||² / ||P_c||²) recovers the right size regardless of
logging.warning( # rotation; translation still positions the mesh at MoGe's centroid.
f"Pixal3DAlignObject: computed scale={scale:.4f} <= 0; " p_var = (P_c * P_c).sum().clamp(min=1e-8)
"refusing to apply mirroring. Check camera convention alignment.") q_var = (Q_c * Q_c).sum()
scale = 1.0 scale = float(torch.sqrt(q_var / p_var).item())
aligned = vertices_one
else:
t = q_mean - scale * p_mean t = q_mean - scale * p_mean
aligned = scale * vertices_one + t aligned = scale * vertices_one + t
if vertices.ndim == 3: if vertices.ndim == 3:
aligned = aligned.unsqueeze(0) aligned = aligned.unsqueeze(0)
out_mesh = Types.MESH(vertices=aligned, faces=faces) out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces.cpu())
else: else:
out_mesh = Types.MESH(vertices=aligned, faces=faces_one) out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces_one.cpu())
return IO.NodeOutput(out_mesh, float(scale)) return IO.NodeOutput(out_mesh, float(scale))