mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-05 22:21:31 +08:00
PBR baking
This commit is contained in:
parent
a4c8b5064b
commit
1697da460b
@ -17,6 +17,7 @@ class MESH:
|
||||
uvs: torch.Tensor | None = None,
|
||||
vertex_colors: torch.Tensor | None = None,
|
||||
texture: torch.Tensor | None = None,
|
||||
metallic_roughness: torch.Tensor | None = None,
|
||||
vertex_counts: torch.Tensor | None = None,
|
||||
face_counts: torch.Tensor | None = None):
|
||||
|
||||
@ -26,7 +27,9 @@ class MESH:
|
||||
self.faces = faces # faces: (B, M, 3)
|
||||
self.uvs = uvs # uvs: (B, N, 2)
|
||||
self.vertex_colors = vertex_colors # vertex_colors: (B, N, 3 or 4)
|
||||
self.texture = texture # texture: (B, H, W, 3)
|
||||
self.texture = texture # texture (baseColor): (B, H, W, 3)
|
||||
# glTF metallicRoughness texture: (B, H, W, 3), R unused, G=roughness, B=metallic
|
||||
self.metallic_roughness = metallic_roughness
|
||||
# When vertices/faces are zero-padded to a common N/M across the batch (variable-size mesh batch),
|
||||
# these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed.
|
||||
self.vertex_counts = vertex_counts
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -80,7 +80,8 @@ def get_mesh_batch_item(mesh, index):
|
||||
|
||||
|
||||
def save_glb(vertices, faces, filepath, metadata=None,
|
||||
uvs=None, vertex_colors=None, texture_image=None):
|
||||
uvs=None, vertex_colors=None, texture_image=None,
|
||||
metallic_roughness_image=None):
|
||||
"""
|
||||
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
|
||||
|
||||
@ -92,6 +93,8 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
uvs: torch.Tensor of shape (N, 2) - Optional per-vertex texture coordinates
|
||||
vertex_colors: torch.Tensor of shape (N, 3) or (N, 4) - Optional per-vertex colors in [0, 1]
|
||||
texture_image: PIL.Image - Optional baseColor texture, embedded as PNG
|
||||
metallic_roughness_image: PIL.Image - Optional glTF metallicRoughness texture
|
||||
(R unused, G=roughness, B=metallic), embedded as PNG
|
||||
"""
|
||||
|
||||
# Convert tensors to numpy arrays
|
||||
@ -126,12 +129,18 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
buf = BytesIO()
|
||||
texture_image.save(buf, format="PNG")
|
||||
texture_png_bytes = buf.getvalue()
|
||||
mr_png_bytes = None
|
||||
if metallic_roughness_image is not None:
|
||||
buf = BytesIO()
|
||||
metallic_roughness_image.save(buf, format="PNG")
|
||||
mr_png_bytes = buf.getvalue()
|
||||
|
||||
vertices_buffer = vertices_np.tobytes()
|
||||
indices_buffer = faces_np.tobytes()
|
||||
uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b""
|
||||
colors_buffer = colors_np.tobytes() if colors_np is not None else b""
|
||||
texture_buffer = texture_png_bytes if texture_png_bytes is not None else b""
|
||||
mr_buffer = mr_png_bytes if mr_png_bytes is not None else b""
|
||||
|
||||
def pad_to_4_bytes(buffer):
|
||||
padding_length = (4 - (len(buffer) % 4)) % 4
|
||||
@ -142,6 +151,7 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
uvs_buffer_padded = pad_to_4_bytes(uvs_buffer)
|
||||
colors_buffer_padded = pad_to_4_bytes(colors_buffer)
|
||||
texture_buffer_padded = pad_to_4_bytes(texture_buffer)
|
||||
mr_buffer_padded = pad_to_4_bytes(mr_buffer)
|
||||
|
||||
buffer_data = b"".join([
|
||||
vertices_buffer_padded,
|
||||
@ -149,6 +159,7 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
uvs_buffer_padded,
|
||||
colors_buffer_padded,
|
||||
texture_buffer_padded,
|
||||
mr_buffer_padded,
|
||||
])
|
||||
|
||||
vertices_byte_length = len(vertices_buffer)
|
||||
@ -158,6 +169,7 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded)
|
||||
colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded)
|
||||
texture_byte_offset = colors_byte_offset + len(colors_buffer_padded)
|
||||
mr_byte_offset = texture_byte_offset + len(texture_buffer_padded)
|
||||
|
||||
buffer_views = [
|
||||
{
|
||||
@ -251,8 +263,24 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
||||
})
|
||||
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
|
||||
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
|
||||
textures.append({"source": 0, "sampler": 0})
|
||||
pbr["baseColorTexture"] = {"index": 0, "texCoord": 0}
|
||||
textures.append({"source": len(images) - 1, "sampler": 0})
|
||||
pbr["baseColorTexture"] = {"index": len(textures) - 1, "texCoord": 0}
|
||||
|
||||
if mr_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
|
||||
buffer_views.append({
|
||||
"buffer": 0,
|
||||
"byteOffset": mr_byte_offset,
|
||||
"byteLength": len(mr_buffer),
|
||||
})
|
||||
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
|
||||
if not samplers:
|
||||
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
|
||||
textures.append({"source": len(images) - 1, "sampler": 0})
|
||||
pbr["metallicRoughnessTexture"] = {"index": len(textures) - 1, "texCoord": 0}
|
||||
# When a metallicRoughness texture is present, the factors scale it; use 1.0
|
||||
# so the texture values pass through unchanged (glTF convention).
|
||||
pbr["metallicFactor"] = 1.0
|
||||
pbr["roughnessFactor"] = 1.0
|
||||
|
||||
materials.append({
|
||||
"pbrMetallicRoughness": pbr,
|
||||
@ -373,12 +401,20 @@ class SaveGLB(IO.ComfyNode):
|
||||
assert texture_np.ndim == 4 and texture_np.shape[-1] == 3, (
|
||||
f"texture must be (B, H, W, 3) RGB, got shape {tuple(texture_np.shape)}"
|
||||
)
|
||||
mr_b = getattr(mesh, "metallic_roughness", None)
|
||||
mr_np = None
|
||||
if mr_b is not None:
|
||||
mr_np = (mr_b.clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)
|
||||
assert mr_np.ndim == 4 and mr_np.shape[-1] == 3, (
|
||||
f"metallic_roughness must be (B, H, W, 3), got shape {tuple(mr_np.shape)}"
|
||||
)
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
vertices_i, faces_i, v_colors, uvs_i = get_mesh_batch_item(mesh, i)
|
||||
if vertices_i.shape[0] == 0 or faces_i.shape[0] == 0:
|
||||
logging.warning(f"SaveGLB: skipping empty mesh at batch index {i}")
|
||||
continue
|
||||
tex_img = Image.fromarray(texture_np[i], mode="RGB") if texture_np is not None else None
|
||||
mr_img = Image.fromarray(mr_np[i], mode="RGB") if mr_np is not None else None
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(
|
||||
vertices_i, faces_i,
|
||||
@ -387,6 +423,7 @@ class SaveGLB(IO.ComfyNode):
|
||||
uvs=uvs_i,
|
||||
vertex_colors=v_colors,
|
||||
texture_image=tex_img,
|
||||
metallic_roughness_image=mr_img,
|
||||
)
|
||||
results.append({
|
||||
"filename": f,
|
||||
|
||||
@ -9,8 +9,10 @@ from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import folder_paths
|
||||
from comfy.ldm.trellis2 import sampling_preview
|
||||
from PIL import Image
|
||||
import logging
|
||||
import os
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
@ -19,6 +21,89 @@ ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
|
||||
NAFModel = io.Custom("NAF_MODEL")
|
||||
|
||||
|
||||
# Texture latent -> base-color calibration for the per-step preview
|
||||
def _tex_rgb_factors_path():
    """Return the path of the persisted Trellis2 texture-latent -> RGB statistics file.

    The file lives in the first directory registered under the "vae_approx"
    folder key and is named ``trellis2_tex_rgb_factors.pt``.
    """
    approx_dirs = folder_paths.get_folder_paths("vae_approx")
    return os.path.join(approx_dirs[0], "trellis2_tex_rgb_factors.pt")
|
||||
|
||||
|
||||
def _pool_albedo_to_input(in_coords, out_coords, out_colors):
|
||||
in_sp = in_coords[:, 1:4].long()
|
||||
out_sp = out_coords[:, 1:4].long()
|
||||
in_b = in_coords[:, 0].long()
|
||||
out_b = out_coords[:, 0].long()
|
||||
in_res = int(in_sp.max().item()) + 1
|
||||
out_res = int(out_sp.max().item()) + 1
|
||||
parent = torch.floor(out_sp.float() * in_res / out_res).long().clamp(0, in_res - 1)
|
||||
R = in_res
|
||||
in_flat = ((in_b * R + in_sp[:, 0]) * R + in_sp[:, 1]) * R + in_sp[:, 2]
|
||||
par_flat = ((out_b * R + parent[:, 0]) * R + parent[:, 1]) * R + parent[:, 2]
|
||||
order = torch.argsort(in_flat)
|
||||
in_sorted = in_flat[order]
|
||||
pos = torch.searchsorted(in_sorted, par_flat).clamp(max=in_sorted.numel() - 1)
|
||||
matched = in_sorted[pos] == par_flat
|
||||
in_idx = order[pos][matched]
|
||||
cols = out_colors[matched].float()
|
||||
N = in_coords.shape[0]
|
||||
csum = cols.new_zeros((N, 3))
|
||||
ccount = cols.new_zeros((N, 1))
|
||||
csum.index_add_(0, in_idx, cols)
|
||||
ccount.index_add_(0, in_idx, torch.ones((in_idx.shape[0], 1), device=cols.device, dtype=cols.dtype))
|
||||
valid = ccount[:, 0] > 0
|
||||
albedo = torch.zeros_like(csum)
|
||||
albedo[valid] = csum[valid] / ccount[valid]
|
||||
return albedo, valid
|
||||
|
||||
|
||||
def _calibrate_tex_rgb(in_latent, in_coords, out_colors, out_coords):
    """Accumulate one decode's (latent -> albedo) evidence, re-solve, persist, publish.

    Fits an affine map from the texture latent to the decoded base color by
    least squares. The normal-equation sums (``A = X'X``, ``B = X'Y`` with a
    bias column appended to X) are added to the sums stored on disk, written
    back, solved with a small diagonal regularizer, and the resulting weight
    and bias are handed to ``sampling_preview.set_tex_rgb``.

    Args:
        in_latent: tensor [N, C] of latent features, one row per input voxel.
        in_coords: tensor [N, 4] of input voxel coordinates (batch, x, y, z).
        out_colors: tensor [M, 3] of decoded colors.
        out_coords: tensor [M, 4] of decoded voxel coordinates (batch, x, y, z).

    Returns:
        None. This is best-effort: any failure is logged at debug level and
        swallowed so that a calibration problem never breaks the decode.
    """
    try:
        dev = out_colors.device
        in_latent = in_latent.to(dev)
        in_coords = in_coords.to(dev)
        out_coords = out_coords.to(dev)
        albedo, valid = _pool_albedo_to_input(in_coords, out_coords, out_colors)
        X = in_latent[valid].float().cpu()
        Y = albedo[valid].float().cpu()
        if X.shape[0] < 64:
            # Too few matched voxels in this decode; leave stored statistics untouched.
            return
        Xaug = torch.cat([X, torch.ones(X.shape[0], 1)], dim=1)  # [K, C+1]
        A_run = Xaug.transpose(0, 1) @ Xaug  # [C+1, C+1]
        B_run = Xaug.transpose(0, 1) @ Y  # [C+1, 3]

        path = _tex_rgb_factors_path()
        if os.path.exists(path):
            try:
                # The file only ever holds two tensors, so the restricted
                # unpickler is sufficient and avoids executing arbitrary pickles.
                prev = torch.load(path, map_location="cpu", weights_only=True)
                prev_A, prev_B = prev["A"], prev["B"]
                if prev_A.shape != A_run.shape or prev_B.shape != B_run.shape:
                    raise ValueError(
                        f"stored shapes {tuple(prev_A.shape)}/{tuple(prev_B.shape)} do not match "
                        f"{tuple(A_run.shape)}/{tuple(B_run.shape)}"
                    )
                # Update both sums together so A and B can never get out of sync.
                A_run, B_run = A_run + prev_A, B_run + prev_B
            except Exception as e:
                # Unreadable or incompatible history: restart from this decode only.
                logging.debug(f"Trellis2 tex-rgb previous factors ignored: {e}")
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # Write to a sibling temp file and swap it in so that an interrupted
        # save cannot leave a truncated statistics file behind.
        tmp_path = path + ".tmp"
        torch.save({"A": A_run, "B": B_run}, tmp_path)
        os.replace(tmp_path, path)

        # Regularized solve; the 1e-3 diagonal keeps A invertible.
        eye = torch.eye(A_run.shape[0])
        WB = torch.linalg.solve(A_run + 1e-3 * eye, B_run)  # [C+1, 3]
        # Last row is the bias (it corresponds to the appended ones column).
        W, b = WB[:-1].contiguous(), WB[-1].contiguous()
        sampling_preview.set_tex_rgb(W, b)
    except Exception as e:
        logging.debug(f"Trellis2 tex-rgb calibration skipped: {e}")
|
||||
|
||||
|
||||
def _load_tex_rgb_factors():
    """Publish the latent -> RGB preview map from previously saved statistics.

    If the statistics file exists, its accumulated normal-equation sums
    (``"A"`` and ``"B"``) are solved with a 1e-3 diagonal regularizer and the
    resulting weight matrix and bias are passed to
    ``sampling_preview.set_tex_rgb``. A missing file is a no-op.

    Returns:
        None. Best-effort: any failure is logged at debug level and swallowed.
    """
    try:
        path = _tex_rgb_factors_path()
        if os.path.exists(path):
            # The file holds only a dict of two tensors; weights_only=True keeps
            # torch.load from unpickling arbitrary objects if the file was tampered with.
            d = torch.load(path, map_location="cpu", weights_only=True)
            eye = torch.eye(d["A"].shape[0])
            WB = torch.linalg.solve(d["A"] + 1e-3 * eye, d["B"])
            # All rows but the last are the weights; the last row is the bias.
            sampling_preview.set_tex_rgb(WB[:-1].contiguous(), WB[-1].contiguous())
    except Exception as e:
        logging.debug(f"Trellis2 tex-rgb factor load skipped: {e}")
|
||||
|
||||
|
||||
_load_tex_rgb_factors()
|
||||
|
||||
|
||||
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
||||
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
|
||||
if len(sample_shape) == 5:
|
||||
@ -174,6 +259,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
resolution = int(coord_resolution) * 16
|
||||
else:
|
||||
resolution = int(vae.first_stage_model.resolution.item())
|
||||
model_frame = samples.get("model_frame", "y_up")
|
||||
sample_tensor = samples["samples"]
|
||||
device = comfy.model_management.get_torch_device()
|
||||
coords = samples["coords"]
|
||||
@ -205,8 +291,13 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
coords_list = [stage_tensor.coords for stage_tensor in stage_tensors]
|
||||
subs.append(SparseTensor.from_tensor_list(feats_list, coords_list))
|
||||
|
||||
vert_list = [v.float() for v, f in mesh]
|
||||
face_list = [f.int() for v, f in mesh]
|
||||
# Rotate Z-up (Trellis2 training frame) vertices to glTF Y-up. Pixal3D outputs are already Y-up.
|
||||
if model_frame == "z_up":
|
||||
vert_list = [torch.stack([v[..., 0], v[..., 2], -v[..., 1]], dim=-1).float().cpu()
|
||||
for v, _ in mesh]
|
||||
else:
|
||||
vert_list = [v.float().cpu() for v, _ in mesh]
|
||||
face_list = [f.int().cpu() for _, f in mesh]
|
||||
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
|
||||
mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list))
|
||||
else:
|
||||
@ -241,19 +332,32 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||
trellis_vae = vae.first_stage_model
|
||||
coord_counts = samples.get("coord_counts")
|
||||
model_frame = samples.get("model_frame", "y_up")
|
||||
|
||||
samples = samples["samples"]
|
||||
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
||||
samples = samples.to(device)
|
||||
cal_in_latent = samples # [N, C] pre-denorm latent, for tex-rgb preview calibration
|
||||
cal_in_coords = coords
|
||||
std = tex_slat_normalization["std"].to(samples)
|
||||
mean = tex_slat_normalization["mean"].to(samples)
|
||||
samples = SparseTensor(feats = samples, coords=coords.to(device))
|
||||
samples = samples * std + mean
|
||||
|
||||
voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides)
|
||||
color_feats = voxel.feats[:, :3]
|
||||
# Keep all decoded channels. The texture VAE emits 6: base_color (0:3),
|
||||
# metallic (3), roughness (4), alpha (5) — all in [0, 1]. Vertex-color
|
||||
# consumers (PaintMesh) slice [:3]; BakeTextureFromVoxel uses the full
|
||||
# PBR set. Older 3-channel checkpoints pass through unchanged.
|
||||
color_feats = voxel.feats
|
||||
voxel_coords = voxel.coords
|
||||
|
||||
# Calibrate the latent->base_color map for the per-step texture preview.
|
||||
# Done here while input coords and voxel_coords share the model frame
|
||||
# (before the z_up remap below) and on the real decoded albedo.
|
||||
if color_feats.shape[0] > 0 and color_feats.shape[-1] >= 3:
|
||||
_calibrate_tex_rgb(cal_in_latent, cal_in_coords, color_feats[:, :3], voxel_coords)
|
||||
|
||||
if voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
|
||||
spatial = voxel_coords[:, -3:] if voxel_coords.shape[-1] == 4 else voxel_coords
|
||||
max_idx = int(spatial.max().item()) + 1
|
||||
@ -261,6 +365,24 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
else:
|
||||
tex_resolution = 1024
|
||||
|
||||
# Remap Z-up voxel coords to Y-up: (x, y, z) -> (x, z, R-1-y), matching the
|
||||
# R_x(-90°) applied to mesh vertices in VaeDecodeShapeTrellis. Keeps PaintMesh's
|
||||
# NN lookup correctly aligned without it needing to know the source frame.
|
||||
if model_frame == "z_up" and voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
|
||||
R = tex_resolution
|
||||
if voxel_coords.shape[-1] == 4:
|
||||
batch_col = voxel_coords[:, :1]
|
||||
spatial = voxel_coords[:, 1:]
|
||||
spatial_yup = torch.stack(
|
||||
[spatial[:, 0], spatial[:, 2], (R - 1) - spatial[:, 1]], dim=-1
|
||||
)
|
||||
voxel_coords = torch.cat([batch_col, spatial_yup], dim=-1)
|
||||
else:
|
||||
voxel_coords = torch.stack(
|
||||
[voxel_coords[:, 0], voxel_coords[:, 2], (R - 1) - voxel_coords[:, 1]],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
voxel = Types.VOXEL(voxel_coords, color_feats, tex_resolution)
|
||||
return IO.NodeOutput(voxel)
|
||||
|
||||
@ -425,7 +547,9 @@ class Trellis2UpsampleStage(IO.ComfyNode):
|
||||
positive_out = _conditioning_set_extras(positive, extras)
|
||||
negative_out = _conditioning_set_extras(negative, extras)
|
||||
out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
|
||||
"coord_resolution": coord_resolution, "type": "trellis2"}
|
||||
"coord_resolution": coord_resolution, "type": "trellis2",
|
||||
"model_frame": shape_latent.get("model_frame",
|
||||
"y_up" if proj_pack is not None else "z_up")}
|
||||
return IO.NodeOutput(positive_out, negative_out, out_latent)
|
||||
|
||||
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
||||
@ -694,7 +818,8 @@ class Trellis2ShapeStage(IO.ComfyNode):
|
||||
positive_out = _conditioning_set_extras(positive, extras)
|
||||
negative_out = _conditioning_set_extras(negative, extras)
|
||||
out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
|
||||
"coord_resolution": coord_resolution, "type": "trellis2"}
|
||||
"coord_resolution": coord_resolution, "type": "trellis2",
|
||||
"model_frame": "y_up" if proj_pack is not None else "z_up"}
|
||||
return IO.NodeOutput(positive_out, negative_out, out_latent)
|
||||
|
||||
class Trellis2TextureStage(IO.ComfyNode):
|
||||
@ -747,7 +872,9 @@ class Trellis2TextureStage(IO.ComfyNode):
|
||||
|
||||
positive_out = _conditioning_set_extras(positive, extras)
|
||||
negative_out = _conditioning_set_extras(negative, extras)
|
||||
out_latent = {"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts}
|
||||
out_latent = {"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts,
|
||||
"model_frame": shape_latent.get("model_frame",
|
||||
"y_up" if proj_pack is not None else "z_up")}
|
||||
if coord_resolution is not None:
|
||||
out_latent["coord_resolution"] = coord_resolution
|
||||
return IO.NodeOutput(positive_out, negative_out, out_latent)
|
||||
@ -1018,6 +1145,8 @@ def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle
|
||||
points = vertices_world.unsqueeze(0).float()
|
||||
T = transform_matrix.unsqueeze(0).float() if transform_matrix.ndim == 2 else transform_matrix.float()
|
||||
cam = camera_angle_x.unsqueeze(0) if camera_angle_x.ndim == 0 else camera_angle_x
|
||||
T = T.to(points.device)
|
||||
cam = cam.to(points.device)
|
||||
uv_pix, depth, valid = _project_points_to_image(points, T, cam.float(), image_resolution)
|
||||
uv = uv_pix.squeeze(0) / image_resolution
|
||||
return uv, depth.squeeze(0), valid.squeeze(0)
|
||||
@ -1108,13 +1237,25 @@ class Pixal3DAlignObject(IO.ComfyNode):
|
||||
scene_pixels = _crop_uv_to_scene_pixels(uv_crop, crop_bbox, (scene_W, scene_H))
|
||||
in_scene = ((scene_pixels[:, 0] >= 0) & (scene_pixels[:, 0] < scene_W) &
|
||||
(scene_pixels[:, 1] >= 0) & (scene_pixels[:, 1] < scene_H))
|
||||
# MoGe geometry and object_mask can land on CPU after passing between nodes;
|
||||
# match the indexed tensor's device for sy/sx so the gather works on either.
|
||||
moge_points = moge_points.to(scene_pixels.device)
|
||||
moge_mask = moge_mask.to(scene_pixels.device)
|
||||
sx = scene_pixels[:, 0].long().clamp(0, scene_W - 1)
|
||||
sy = scene_pixels[:, 1].long().clamp(0, scene_H - 1)
|
||||
moge_per_vertex = moge_points[batch_index, sy, sx]
|
||||
# MoGe's perspective output is (X right, Y down, Z forward). Convert to glTF
|
||||
# Y-up (X right, Y up, Z back) so the scale/translate fit runs in the same
|
||||
# frame as vertices_one (Pixal3D model frame = glTF Y-up). Mirrors the
|
||||
# `verts * [1, -1, -1]` step in MoGePointMapToMesh.
|
||||
moge_per_vertex = moge_per_vertex * torch.tensor(
|
||||
[1.0, -1.0, -1.0], dtype=moge_per_vertex.dtype, device=moge_per_vertex.device
|
||||
)
|
||||
moge_mask_per_vertex = moge_mask[batch_index, sy, sx]
|
||||
keep = valid & in_scene & moge_mask_per_vertex
|
||||
if object_mask is not None:
|
||||
om = object_mask if object_mask.ndim == 2 else object_mask[batch_index]
|
||||
om = om.to(sy.device)
|
||||
keep = keep & (om[sy, sx] > 0.5)
|
||||
|
||||
finite = torch.isfinite(moge_per_vertex).all(dim=-1)
|
||||
@ -1131,25 +1272,23 @@ class Pixal3DAlignObject(IO.ComfyNode):
|
||||
q_mean = Q.mean(dim=0, keepdim=True)
|
||||
P_c = P - p_mean
|
||||
Q_c = Q - q_mean
|
||||
num = (P_c * Q_c).sum()
|
||||
den = (P_c * P_c).sum().clamp(min=1e-8)
|
||||
scale = float((num / den).item())
|
||||
if not (scale > 0):
|
||||
# Negative scale would mirror the mesh; treat as a camera-convention mismatch.
|
||||
logging.warning(
|
||||
f"Pixal3DAlignObject: computed scale={scale:.4f} <= 0; "
|
||||
"refusing to apply mirroring. Check camera convention alignment.")
|
||||
scale = 1.0
|
||||
aligned = vertices_one
|
||||
else:
|
||||
t = q_mean - scale * p_mean
|
||||
aligned = scale * vertices_one + t
|
||||
# Rotation-invariant scale: ratio of RMS spreads. MoGe geometry is
|
||||
# noisy and Pixal3D's mesh frame can be yawed relative to MoGe (paper
|
||||
# acknowledges this), so the L2-optimal scalar (P_c · Q_c)/(P_c · P_c)
|
||||
# gets multiplied by cos(yaw) and shrinks the object. Using
|
||||
# sqrt(||Q_c||² / ||P_c||²) recovers the right size regardless of
|
||||
# rotation; translation still positions the mesh at MoGe's centroid.
|
||||
p_var = (P_c * P_c).sum().clamp(min=1e-8)
|
||||
q_var = (Q_c * Q_c).sum()
|
||||
scale = float(torch.sqrt(q_var / p_var).item())
|
||||
t = q_mean - scale * p_mean
|
||||
aligned = scale * vertices_one + t
|
||||
|
||||
if vertices.ndim == 3:
|
||||
aligned = aligned.unsqueeze(0)
|
||||
out_mesh = Types.MESH(vertices=aligned, faces=faces)
|
||||
out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces.cpu())
|
||||
else:
|
||||
out_mesh = Types.MESH(vertices=aligned, faces=faces_one)
|
||||
out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces_one.cpu())
|
||||
return IO.NodeOutput(out_mesh, float(scale))
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user