mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-05 22:21:31 +08:00
PBR baking
This commit is contained in:
parent
a4c8b5064b
commit
1697da460b
@ -17,6 +17,7 @@ class MESH:
|
|||||||
uvs: torch.Tensor | None = None,
|
uvs: torch.Tensor | None = None,
|
||||||
vertex_colors: torch.Tensor | None = None,
|
vertex_colors: torch.Tensor | None = None,
|
||||||
texture: torch.Tensor | None = None,
|
texture: torch.Tensor | None = None,
|
||||||
|
metallic_roughness: torch.Tensor | None = None,
|
||||||
vertex_counts: torch.Tensor | None = None,
|
vertex_counts: torch.Tensor | None = None,
|
||||||
face_counts: torch.Tensor | None = None):
|
face_counts: torch.Tensor | None = None):
|
||||||
|
|
||||||
@ -26,7 +27,9 @@ class MESH:
|
|||||||
self.faces = faces # faces: (B, M, 3)
|
self.faces = faces # faces: (B, M, 3)
|
||||||
self.uvs = uvs # uvs: (B, N, 2)
|
self.uvs = uvs # uvs: (B, N, 2)
|
||||||
self.vertex_colors = vertex_colors # vertex_colors: (B, N, 3 or 4)
|
self.vertex_colors = vertex_colors # vertex_colors: (B, N, 3 or 4)
|
||||||
self.texture = texture # texture: (B, H, W, 3)
|
self.texture = texture # texture (baseColor): (B, H, W, 3)
|
||||||
|
# glTF metallicRoughness texture: (B, H, W, 3), R unused, G=roughness, B=metallic
|
||||||
|
self.metallic_roughness = metallic_roughness
|
||||||
# When vertices/faces are zero-padded to a common N/M across the batch (variable-size mesh batch),
|
# When vertices/faces are zero-padded to a common N/M across the batch (variable-size mesh batch),
|
||||||
# these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed.
|
# these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed.
|
||||||
self.vertex_counts = vertex_counts
|
self.vertex_counts = vertex_counts
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@ -80,7 +80,8 @@ def get_mesh_batch_item(mesh, index):
|
|||||||
|
|
||||||
|
|
||||||
def save_glb(vertices, faces, filepath, metadata=None,
|
def save_glb(vertices, faces, filepath, metadata=None,
|
||||||
uvs=None, vertex_colors=None, texture_image=None):
|
uvs=None, vertex_colors=None, texture_image=None,
|
||||||
|
metallic_roughness_image=None):
|
||||||
"""
|
"""
|
||||||
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
|
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
|
||||||
|
|
||||||
@ -92,6 +93,8 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
|||||||
uvs: torch.Tensor of shape (N, 2) - Optional per-vertex texture coordinates
|
uvs: torch.Tensor of shape (N, 2) - Optional per-vertex texture coordinates
|
||||||
vertex_colors: torch.Tensor of shape (N, 3) or (N, 4) - Optional per-vertex colors in [0, 1]
|
vertex_colors: torch.Tensor of shape (N, 3) or (N, 4) - Optional per-vertex colors in [0, 1]
|
||||||
texture_image: PIL.Image - Optional baseColor texture, embedded as PNG
|
texture_image: PIL.Image - Optional baseColor texture, embedded as PNG
|
||||||
|
metallic_roughness_image: PIL.Image - Optional glTF metallicRoughness texture
|
||||||
|
(R unused, G=roughness, B=metallic), embedded as PNG
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Convert tensors to numpy arrays
|
# Convert tensors to numpy arrays
|
||||||
@ -126,12 +129,18 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
|||||||
buf = BytesIO()
|
buf = BytesIO()
|
||||||
texture_image.save(buf, format="PNG")
|
texture_image.save(buf, format="PNG")
|
||||||
texture_png_bytes = buf.getvalue()
|
texture_png_bytes = buf.getvalue()
|
||||||
|
mr_png_bytes = None
|
||||||
|
if metallic_roughness_image is not None:
|
||||||
|
buf = BytesIO()
|
||||||
|
metallic_roughness_image.save(buf, format="PNG")
|
||||||
|
mr_png_bytes = buf.getvalue()
|
||||||
|
|
||||||
vertices_buffer = vertices_np.tobytes()
|
vertices_buffer = vertices_np.tobytes()
|
||||||
indices_buffer = faces_np.tobytes()
|
indices_buffer = faces_np.tobytes()
|
||||||
uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b""
|
uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b""
|
||||||
colors_buffer = colors_np.tobytes() if colors_np is not None else b""
|
colors_buffer = colors_np.tobytes() if colors_np is not None else b""
|
||||||
texture_buffer = texture_png_bytes if texture_png_bytes is not None else b""
|
texture_buffer = texture_png_bytes if texture_png_bytes is not None else b""
|
||||||
|
mr_buffer = mr_png_bytes if mr_png_bytes is not None else b""
|
||||||
|
|
||||||
def pad_to_4_bytes(buffer):
|
def pad_to_4_bytes(buffer):
|
||||||
padding_length = (4 - (len(buffer) % 4)) % 4
|
padding_length = (4 - (len(buffer) % 4)) % 4
|
||||||
@ -142,6 +151,7 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
|||||||
uvs_buffer_padded = pad_to_4_bytes(uvs_buffer)
|
uvs_buffer_padded = pad_to_4_bytes(uvs_buffer)
|
||||||
colors_buffer_padded = pad_to_4_bytes(colors_buffer)
|
colors_buffer_padded = pad_to_4_bytes(colors_buffer)
|
||||||
texture_buffer_padded = pad_to_4_bytes(texture_buffer)
|
texture_buffer_padded = pad_to_4_bytes(texture_buffer)
|
||||||
|
mr_buffer_padded = pad_to_4_bytes(mr_buffer)
|
||||||
|
|
||||||
buffer_data = b"".join([
|
buffer_data = b"".join([
|
||||||
vertices_buffer_padded,
|
vertices_buffer_padded,
|
||||||
@ -149,6 +159,7 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
|||||||
uvs_buffer_padded,
|
uvs_buffer_padded,
|
||||||
colors_buffer_padded,
|
colors_buffer_padded,
|
||||||
texture_buffer_padded,
|
texture_buffer_padded,
|
||||||
|
mr_buffer_padded,
|
||||||
])
|
])
|
||||||
|
|
||||||
vertices_byte_length = len(vertices_buffer)
|
vertices_byte_length = len(vertices_buffer)
|
||||||
@ -158,6 +169,7 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
|||||||
uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded)
|
uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded)
|
||||||
colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded)
|
colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded)
|
||||||
texture_byte_offset = colors_byte_offset + len(colors_buffer_padded)
|
texture_byte_offset = colors_byte_offset + len(colors_buffer_padded)
|
||||||
|
mr_byte_offset = texture_byte_offset + len(texture_buffer_padded)
|
||||||
|
|
||||||
buffer_views = [
|
buffer_views = [
|
||||||
{
|
{
|
||||||
@ -251,8 +263,24 @@ def save_glb(vertices, faces, filepath, metadata=None,
|
|||||||
})
|
})
|
||||||
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
|
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
|
||||||
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
|
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
|
||||||
textures.append({"source": 0, "sampler": 0})
|
textures.append({"source": len(images) - 1, "sampler": 0})
|
||||||
pbr["baseColorTexture"] = {"index": 0, "texCoord": 0}
|
pbr["baseColorTexture"] = {"index": len(textures) - 1, "texCoord": 0}
|
||||||
|
|
||||||
|
if mr_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
|
||||||
|
buffer_views.append({
|
||||||
|
"buffer": 0,
|
||||||
|
"byteOffset": mr_byte_offset,
|
||||||
|
"byteLength": len(mr_buffer),
|
||||||
|
})
|
||||||
|
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
|
||||||
|
if not samplers:
|
||||||
|
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
|
||||||
|
textures.append({"source": len(images) - 1, "sampler": 0})
|
||||||
|
pbr["metallicRoughnessTexture"] = {"index": len(textures) - 1, "texCoord": 0}
|
||||||
|
# When a metallicRoughness texture is present, the factors scale it; use 1.0
|
||||||
|
# so the texture values pass through unchanged (glTF convention).
|
||||||
|
pbr["metallicFactor"] = 1.0
|
||||||
|
pbr["roughnessFactor"] = 1.0
|
||||||
|
|
||||||
materials.append({
|
materials.append({
|
||||||
"pbrMetallicRoughness": pbr,
|
"pbrMetallicRoughness": pbr,
|
||||||
@ -373,12 +401,20 @@ class SaveGLB(IO.ComfyNode):
|
|||||||
assert texture_np.ndim == 4 and texture_np.shape[-1] == 3, (
|
assert texture_np.ndim == 4 and texture_np.shape[-1] == 3, (
|
||||||
f"texture must be (B, H, W, 3) RGB, got shape {tuple(texture_np.shape)}"
|
f"texture must be (B, H, W, 3) RGB, got shape {tuple(texture_np.shape)}"
|
||||||
)
|
)
|
||||||
|
mr_b = getattr(mesh, "metallic_roughness", None)
|
||||||
|
mr_np = None
|
||||||
|
if mr_b is not None:
|
||||||
|
mr_np = (mr_b.clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)
|
||||||
|
assert mr_np.ndim == 4 and mr_np.shape[-1] == 3, (
|
||||||
|
f"metallic_roughness must be (B, H, W, 3), got shape {tuple(mr_np.shape)}"
|
||||||
|
)
|
||||||
for i in range(mesh.vertices.shape[0]):
|
for i in range(mesh.vertices.shape[0]):
|
||||||
vertices_i, faces_i, v_colors, uvs_i = get_mesh_batch_item(mesh, i)
|
vertices_i, faces_i, v_colors, uvs_i = get_mesh_batch_item(mesh, i)
|
||||||
if vertices_i.shape[0] == 0 or faces_i.shape[0] == 0:
|
if vertices_i.shape[0] == 0 or faces_i.shape[0] == 0:
|
||||||
logging.warning(f"SaveGLB: skipping empty mesh at batch index {i}")
|
logging.warning(f"SaveGLB: skipping empty mesh at batch index {i}")
|
||||||
continue
|
continue
|
||||||
tex_img = Image.fromarray(texture_np[i], mode="RGB") if texture_np is not None else None
|
tex_img = Image.fromarray(texture_np[i], mode="RGB") if texture_np is not None else None
|
||||||
|
mr_img = Image.fromarray(mr_np[i], mode="RGB") if mr_np is not None else None
|
||||||
f = f"{filename}_{counter:05}_.glb"
|
f = f"{filename}_{counter:05}_.glb"
|
||||||
save_glb(
|
save_glb(
|
||||||
vertices_i, faces_i,
|
vertices_i, faces_i,
|
||||||
@ -387,6 +423,7 @@ class SaveGLB(IO.ComfyNode):
|
|||||||
uvs=uvs_i,
|
uvs=uvs_i,
|
||||||
vertex_colors=v_colors,
|
vertex_colors=v_colors,
|
||||||
texture_image=tex_img,
|
texture_image=tex_img,
|
||||||
|
metallic_roughness_image=mr_img,
|
||||||
)
|
)
|
||||||
results.append({
|
results.append({
|
||||||
"filename": f,
|
"filename": f,
|
||||||
|
|||||||
@ -9,8 +9,10 @@ from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import folder_paths
|
import folder_paths
|
||||||
|
from comfy.ldm.trellis2 import sampling_preview
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
@ -19,6 +21,89 @@ ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
|
|||||||
NAFModel = io.Custom("NAF_MODEL")
|
NAFModel = io.Custom("NAF_MODEL")
|
||||||
|
|
||||||
|
|
||||||
|
# Texture latent -> base-color calibration for the per-step preview
|
||||||
|
def _tex_rgb_factors_path():
|
||||||
|
return os.path.join(folder_paths.get_folder_paths("vae_approx")[0], "trellis2_tex_rgb_factors.pt")
|
||||||
|
|
||||||
|
|
||||||
|
def _pool_albedo_to_input(in_coords, out_coords, out_colors):
|
||||||
|
in_sp = in_coords[:, 1:4].long()
|
||||||
|
out_sp = out_coords[:, 1:4].long()
|
||||||
|
in_b = in_coords[:, 0].long()
|
||||||
|
out_b = out_coords[:, 0].long()
|
||||||
|
in_res = int(in_sp.max().item()) + 1
|
||||||
|
out_res = int(out_sp.max().item()) + 1
|
||||||
|
parent = torch.floor(out_sp.float() * in_res / out_res).long().clamp(0, in_res - 1)
|
||||||
|
R = in_res
|
||||||
|
in_flat = ((in_b * R + in_sp[:, 0]) * R + in_sp[:, 1]) * R + in_sp[:, 2]
|
||||||
|
par_flat = ((out_b * R + parent[:, 0]) * R + parent[:, 1]) * R + parent[:, 2]
|
||||||
|
order = torch.argsort(in_flat)
|
||||||
|
in_sorted = in_flat[order]
|
||||||
|
pos = torch.searchsorted(in_sorted, par_flat).clamp(max=in_sorted.numel() - 1)
|
||||||
|
matched = in_sorted[pos] == par_flat
|
||||||
|
in_idx = order[pos][matched]
|
||||||
|
cols = out_colors[matched].float()
|
||||||
|
N = in_coords.shape[0]
|
||||||
|
csum = cols.new_zeros((N, 3))
|
||||||
|
ccount = cols.new_zeros((N, 1))
|
||||||
|
csum.index_add_(0, in_idx, cols)
|
||||||
|
ccount.index_add_(0, in_idx, torch.ones((in_idx.shape[0], 1), device=cols.device, dtype=cols.dtype))
|
||||||
|
valid = ccount[:, 0] > 0
|
||||||
|
albedo = torch.zeros_like(csum)
|
||||||
|
albedo[valid] = csum[valid] / ccount[valid]
|
||||||
|
return albedo, valid
|
||||||
|
|
||||||
|
|
||||||
|
def _calibrate_tex_rgb(in_latent, in_coords, out_colors, out_coords):
|
||||||
|
"""Accumulate one decode's (latent -> albedo) evidence, re-solve, persist, publish."""
|
||||||
|
try:
|
||||||
|
dev = out_colors.device
|
||||||
|
in_latent = in_latent.to(dev)
|
||||||
|
in_coords = in_coords.to(dev)
|
||||||
|
out_coords = out_coords.to(dev)
|
||||||
|
albedo, valid = _pool_albedo_to_input(in_coords, out_coords, out_colors)
|
||||||
|
X = in_latent[valid].float().cpu()
|
||||||
|
Y = albedo[valid].float().cpu()
|
||||||
|
if X.shape[0] < 64:
|
||||||
|
return
|
||||||
|
Xaug = torch.cat([X, torch.ones(X.shape[0], 1)], dim=1) # [K, C+1]
|
||||||
|
A_run = Xaug.transpose(0, 1) @ Xaug # [C+1, C+1]
|
||||||
|
B_run = Xaug.transpose(0, 1) @ Y # [C+1, 3]
|
||||||
|
|
||||||
|
path = _tex_rgb_factors_path()
|
||||||
|
if os.path.exists(path):
|
||||||
|
try:
|
||||||
|
prev = torch.load(path, map_location="cpu")
|
||||||
|
A_run = A_run + prev["A"]
|
||||||
|
B_run = B_run + prev["B"]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
|
torch.save({"A": A_run, "B": B_run}, path)
|
||||||
|
|
||||||
|
eye = torch.eye(A_run.shape[0])
|
||||||
|
WB = torch.linalg.solve(A_run + 1e-3 * eye, B_run) # [C+1, 3]
|
||||||
|
W, b = WB[:-1].contiguous(), WB[-1].contiguous()
|
||||||
|
sampling_preview.set_tex_rgb(W, b)
|
||||||
|
except Exception as e:
|
||||||
|
logging.debug(f"Trellis2 tex-rgb calibration skipped: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_tex_rgb_factors():
|
||||||
|
try:
|
||||||
|
path = _tex_rgb_factors_path()
|
||||||
|
if os.path.exists(path):
|
||||||
|
d = torch.load(path, map_location="cpu")
|
||||||
|
eye = torch.eye(d["A"].shape[0])
|
||||||
|
WB = torch.linalg.solve(d["A"] + 1e-3 * eye, d["B"])
|
||||||
|
sampling_preview.set_tex_rgb(WB[:-1].contiguous(), WB[-1].contiguous())
|
||||||
|
except Exception as e:
|
||||||
|
logging.debug(f"Trellis2 tex-rgb factor load skipped: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
_load_tex_rgb_factors()
|
||||||
|
|
||||||
|
|
||||||
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
def prepare_trellis_vae_for_decode(vae, sample_shape):
|
||||||
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
|
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
|
||||||
if len(sample_shape) == 5:
|
if len(sample_shape) == 5:
|
||||||
@ -174,6 +259,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
resolution = int(coord_resolution) * 16
|
resolution = int(coord_resolution) * 16
|
||||||
else:
|
else:
|
||||||
resolution = int(vae.first_stage_model.resolution.item())
|
resolution = int(vae.first_stage_model.resolution.item())
|
||||||
|
model_frame = samples.get("model_frame", "y_up")
|
||||||
sample_tensor = samples["samples"]
|
sample_tensor = samples["samples"]
|
||||||
device = comfy.model_management.get_torch_device()
|
device = comfy.model_management.get_torch_device()
|
||||||
coords = samples["coords"]
|
coords = samples["coords"]
|
||||||
@ -205,8 +291,13 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
coords_list = [stage_tensor.coords for stage_tensor in stage_tensors]
|
coords_list = [stage_tensor.coords for stage_tensor in stage_tensors]
|
||||||
subs.append(SparseTensor.from_tensor_list(feats_list, coords_list))
|
subs.append(SparseTensor.from_tensor_list(feats_list, coords_list))
|
||||||
|
|
||||||
vert_list = [v.float() for v, f in mesh]
|
# Rotate Z-up (Trellis2 training frame) vertices to glTF Y-up. Pixal3D outputs are already Y-up.
|
||||||
face_list = [f.int() for v, f in mesh]
|
if model_frame == "z_up":
|
||||||
|
vert_list = [torch.stack([v[..., 0], v[..., 2], -v[..., 1]], dim=-1).float().cpu()
|
||||||
|
for v, _ in mesh]
|
||||||
|
else:
|
||||||
|
vert_list = [v.float().cpu() for v, _ in mesh]
|
||||||
|
face_list = [f.int().cpu() for _, f in mesh]
|
||||||
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
|
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
|
||||||
mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list))
|
mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list))
|
||||||
else:
|
else:
|
||||||
@ -241,19 +332,32 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
|
||||||
trellis_vae = vae.first_stage_model
|
trellis_vae = vae.first_stage_model
|
||||||
coord_counts = samples.get("coord_counts")
|
coord_counts = samples.get("coord_counts")
|
||||||
|
model_frame = samples.get("model_frame", "y_up")
|
||||||
|
|
||||||
samples = samples["samples"]
|
samples = samples["samples"]
|
||||||
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
|
||||||
samples = samples.to(device)
|
samples = samples.to(device)
|
||||||
|
cal_in_latent = samples # [N, C] pre-denorm latent, for tex-rgb preview calibration
|
||||||
|
cal_in_coords = coords
|
||||||
std = tex_slat_normalization["std"].to(samples)
|
std = tex_slat_normalization["std"].to(samples)
|
||||||
mean = tex_slat_normalization["mean"].to(samples)
|
mean = tex_slat_normalization["mean"].to(samples)
|
||||||
samples = SparseTensor(feats = samples, coords=coords.to(device))
|
samples = SparseTensor(feats = samples, coords=coords.to(device))
|
||||||
samples = samples * std + mean
|
samples = samples * std + mean
|
||||||
|
|
||||||
voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides)
|
voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides)
|
||||||
color_feats = voxel.feats[:, :3]
|
# Keep all decoded channels. The texture VAE emits 6: base_color (0:3),
|
||||||
|
# metallic (3), roughness (4), alpha (5) — all in [0, 1]. Vertex-color
|
||||||
|
# consumers (PaintMesh) slice [:3]; BakeTextureFromVoxel uses the full
|
||||||
|
# PBR set. Older 3-channel checkpoints pass through unchanged.
|
||||||
|
color_feats = voxel.feats
|
||||||
voxel_coords = voxel.coords
|
voxel_coords = voxel.coords
|
||||||
|
|
||||||
|
# Calibrate the latent->base_color map for the per-step texture preview.
|
||||||
|
# Done here while input coords and voxel_coords share the model frame
|
||||||
|
# (before the z_up remap below) and on the real decoded albedo.
|
||||||
|
if color_feats.shape[0] > 0 and color_feats.shape[-1] >= 3:
|
||||||
|
_calibrate_tex_rgb(cal_in_latent, cal_in_coords, color_feats[:, :3], voxel_coords)
|
||||||
|
|
||||||
if voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
|
if voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
|
||||||
spatial = voxel_coords[:, -3:] if voxel_coords.shape[-1] == 4 else voxel_coords
|
spatial = voxel_coords[:, -3:] if voxel_coords.shape[-1] == 4 else voxel_coords
|
||||||
max_idx = int(spatial.max().item()) + 1
|
max_idx = int(spatial.max().item()) + 1
|
||||||
@ -261,6 +365,24 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
|
|||||||
else:
|
else:
|
||||||
tex_resolution = 1024
|
tex_resolution = 1024
|
||||||
|
|
||||||
|
# Remap Z-up voxel coords to Y-up: (x, y, z) -> (x, z, R-1-y), matching the
|
||||||
|
# R_x(-90°) applied to mesh vertices in VaeDecodeShapeTrellis. Keeps PaintMesh's
|
||||||
|
# NN lookup correctly aligned without it needing to know the source frame.
|
||||||
|
if model_frame == "z_up" and voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
|
||||||
|
R = tex_resolution
|
||||||
|
if voxel_coords.shape[-1] == 4:
|
||||||
|
batch_col = voxel_coords[:, :1]
|
||||||
|
spatial = voxel_coords[:, 1:]
|
||||||
|
spatial_yup = torch.stack(
|
||||||
|
[spatial[:, 0], spatial[:, 2], (R - 1) - spatial[:, 1]], dim=-1
|
||||||
|
)
|
||||||
|
voxel_coords = torch.cat([batch_col, spatial_yup], dim=-1)
|
||||||
|
else:
|
||||||
|
voxel_coords = torch.stack(
|
||||||
|
[voxel_coords[:, 0], voxel_coords[:, 2], (R - 1) - voxel_coords[:, 1]],
|
||||||
|
dim=-1,
|
||||||
|
)
|
||||||
|
|
||||||
voxel = Types.VOXEL(voxel_coords, color_feats, tex_resolution)
|
voxel = Types.VOXEL(voxel_coords, color_feats, tex_resolution)
|
||||||
return IO.NodeOutput(voxel)
|
return IO.NodeOutput(voxel)
|
||||||
|
|
||||||
@ -425,7 +547,9 @@ class Trellis2UpsampleStage(IO.ComfyNode):
|
|||||||
positive_out = _conditioning_set_extras(positive, extras)
|
positive_out = _conditioning_set_extras(positive, extras)
|
||||||
negative_out = _conditioning_set_extras(negative, extras)
|
negative_out = _conditioning_set_extras(negative, extras)
|
||||||
out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
|
out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
|
||||||
"coord_resolution": coord_resolution, "type": "trellis2"}
|
"coord_resolution": coord_resolution, "type": "trellis2",
|
||||||
|
"model_frame": shape_latent.get("model_frame",
|
||||||
|
"y_up" if proj_pack is not None else "z_up")}
|
||||||
return IO.NodeOutput(positive_out, negative_out, out_latent)
|
return IO.NodeOutput(positive_out, negative_out, out_latent)
|
||||||
|
|
||||||
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
||||||
@ -694,7 +818,8 @@ class Trellis2ShapeStage(IO.ComfyNode):
|
|||||||
positive_out = _conditioning_set_extras(positive, extras)
|
positive_out = _conditioning_set_extras(positive, extras)
|
||||||
negative_out = _conditioning_set_extras(negative, extras)
|
negative_out = _conditioning_set_extras(negative, extras)
|
||||||
out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
|
out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
|
||||||
"coord_resolution": coord_resolution, "type": "trellis2"}
|
"coord_resolution": coord_resolution, "type": "trellis2",
|
||||||
|
"model_frame": "y_up" if proj_pack is not None else "z_up"}
|
||||||
return IO.NodeOutput(positive_out, negative_out, out_latent)
|
return IO.NodeOutput(positive_out, negative_out, out_latent)
|
||||||
|
|
||||||
class Trellis2TextureStage(IO.ComfyNode):
|
class Trellis2TextureStage(IO.ComfyNode):
|
||||||
@ -747,7 +872,9 @@ class Trellis2TextureStage(IO.ComfyNode):
|
|||||||
|
|
||||||
positive_out = _conditioning_set_extras(positive, extras)
|
positive_out = _conditioning_set_extras(positive, extras)
|
||||||
negative_out = _conditioning_set_extras(negative, extras)
|
negative_out = _conditioning_set_extras(negative, extras)
|
||||||
out_latent = {"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts}
|
out_latent = {"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts,
|
||||||
|
"model_frame": shape_latent.get("model_frame",
|
||||||
|
"y_up" if proj_pack is not None else "z_up")}
|
||||||
if coord_resolution is not None:
|
if coord_resolution is not None:
|
||||||
out_latent["coord_resolution"] = coord_resolution
|
out_latent["coord_resolution"] = coord_resolution
|
||||||
return IO.NodeOutput(positive_out, negative_out, out_latent)
|
return IO.NodeOutput(positive_out, negative_out, out_latent)
|
||||||
@ -1018,6 +1145,8 @@ def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle
|
|||||||
points = vertices_world.unsqueeze(0).float()
|
points = vertices_world.unsqueeze(0).float()
|
||||||
T = transform_matrix.unsqueeze(0).float() if transform_matrix.ndim == 2 else transform_matrix.float()
|
T = transform_matrix.unsqueeze(0).float() if transform_matrix.ndim == 2 else transform_matrix.float()
|
||||||
cam = camera_angle_x.unsqueeze(0) if camera_angle_x.ndim == 0 else camera_angle_x
|
cam = camera_angle_x.unsqueeze(0) if camera_angle_x.ndim == 0 else camera_angle_x
|
||||||
|
T = T.to(points.device)
|
||||||
|
cam = cam.to(points.device)
|
||||||
uv_pix, depth, valid = _project_points_to_image(points, T, cam.float(), image_resolution)
|
uv_pix, depth, valid = _project_points_to_image(points, T, cam.float(), image_resolution)
|
||||||
uv = uv_pix.squeeze(0) / image_resolution
|
uv = uv_pix.squeeze(0) / image_resolution
|
||||||
return uv, depth.squeeze(0), valid.squeeze(0)
|
return uv, depth.squeeze(0), valid.squeeze(0)
|
||||||
@ -1108,13 +1237,25 @@ class Pixal3DAlignObject(IO.ComfyNode):
|
|||||||
scene_pixels = _crop_uv_to_scene_pixels(uv_crop, crop_bbox, (scene_W, scene_H))
|
scene_pixels = _crop_uv_to_scene_pixels(uv_crop, crop_bbox, (scene_W, scene_H))
|
||||||
in_scene = ((scene_pixels[:, 0] >= 0) & (scene_pixels[:, 0] < scene_W) &
|
in_scene = ((scene_pixels[:, 0] >= 0) & (scene_pixels[:, 0] < scene_W) &
|
||||||
(scene_pixels[:, 1] >= 0) & (scene_pixels[:, 1] < scene_H))
|
(scene_pixels[:, 1] >= 0) & (scene_pixels[:, 1] < scene_H))
|
||||||
|
# MoGe geometry and object_mask can land on CPU after passing between nodes;
|
||||||
|
# match the indexed tensor's device for sy/sx so the gather works on either.
|
||||||
|
moge_points = moge_points.to(scene_pixels.device)
|
||||||
|
moge_mask = moge_mask.to(scene_pixels.device)
|
||||||
sx = scene_pixels[:, 0].long().clamp(0, scene_W - 1)
|
sx = scene_pixels[:, 0].long().clamp(0, scene_W - 1)
|
||||||
sy = scene_pixels[:, 1].long().clamp(0, scene_H - 1)
|
sy = scene_pixels[:, 1].long().clamp(0, scene_H - 1)
|
||||||
moge_per_vertex = moge_points[batch_index, sy, sx]
|
moge_per_vertex = moge_points[batch_index, sy, sx]
|
||||||
|
# MoGe's perspective output is (X right, Y down, Z forward). Convert to glTF
|
||||||
|
# Y-up (X right, Y up, Z back) so the scale/translate fit runs in the same
|
||||||
|
# frame as vertices_one (Pixal3D model frame = glTF Y-up). Mirrors the
|
||||||
|
# `verts * [1, -1, -1]` step in MoGePointMapToMesh.
|
||||||
|
moge_per_vertex = moge_per_vertex * torch.tensor(
|
||||||
|
[1.0, -1.0, -1.0], dtype=moge_per_vertex.dtype, device=moge_per_vertex.device
|
||||||
|
)
|
||||||
moge_mask_per_vertex = moge_mask[batch_index, sy, sx]
|
moge_mask_per_vertex = moge_mask[batch_index, sy, sx]
|
||||||
keep = valid & in_scene & moge_mask_per_vertex
|
keep = valid & in_scene & moge_mask_per_vertex
|
||||||
if object_mask is not None:
|
if object_mask is not None:
|
||||||
om = object_mask if object_mask.ndim == 2 else object_mask[batch_index]
|
om = object_mask if object_mask.ndim == 2 else object_mask[batch_index]
|
||||||
|
om = om.to(sy.device)
|
||||||
keep = keep & (om[sy, sx] > 0.5)
|
keep = keep & (om[sy, sx] > 0.5)
|
||||||
|
|
||||||
finite = torch.isfinite(moge_per_vertex).all(dim=-1)
|
finite = torch.isfinite(moge_per_vertex).all(dim=-1)
|
||||||
@ -1131,25 +1272,23 @@ class Pixal3DAlignObject(IO.ComfyNode):
|
|||||||
q_mean = Q.mean(dim=0, keepdim=True)
|
q_mean = Q.mean(dim=0, keepdim=True)
|
||||||
P_c = P - p_mean
|
P_c = P - p_mean
|
||||||
Q_c = Q - q_mean
|
Q_c = Q - q_mean
|
||||||
num = (P_c * Q_c).sum()
|
# Rotation-invariant scale: ratio of RMS spreads. MoGe geometry is
|
||||||
den = (P_c * P_c).sum().clamp(min=1e-8)
|
# noisy and Pixal3D's mesh frame can be yawed relative to MoGe (paper
|
||||||
scale = float((num / den).item())
|
# acknowledges this), so the L2-optimal scalar (P_c · Q_c)/(P_c · P_c)
|
||||||
if not (scale > 0):
|
# gets multiplied by cos(yaw) and shrinks the object. Using
|
||||||
# Negative scale would mirror the mesh; treat as a camera-convention mismatch.
|
# sqrt(||Q_c||² / ||P_c||²) recovers the right size regardless of
|
||||||
logging.warning(
|
# rotation; translation still positions the mesh at MoGe's centroid.
|
||||||
f"Pixal3DAlignObject: computed scale={scale:.4f} <= 0; "
|
p_var = (P_c * P_c).sum().clamp(min=1e-8)
|
||||||
"refusing to apply mirroring. Check camera convention alignment.")
|
q_var = (Q_c * Q_c).sum()
|
||||||
scale = 1.0
|
scale = float(torch.sqrt(q_var / p_var).item())
|
||||||
aligned = vertices_one
|
t = q_mean - scale * p_mean
|
||||||
else:
|
aligned = scale * vertices_one + t
|
||||||
t = q_mean - scale * p_mean
|
|
||||||
aligned = scale * vertices_one + t
|
|
||||||
|
|
||||||
if vertices.ndim == 3:
|
if vertices.ndim == 3:
|
||||||
aligned = aligned.unsqueeze(0)
|
aligned = aligned.unsqueeze(0)
|
||||||
out_mesh = Types.MESH(vertices=aligned, faces=faces)
|
out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces.cpu())
|
||||||
else:
|
else:
|
||||||
out_mesh = Types.MESH(vertices=aligned, faces=faces_one)
|
out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces_one.cpu())
|
||||||
return IO.NodeOutput(out_mesh, float(scale))
|
return IO.NodeOutput(out_mesh, float(scale))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user