PBR baking

This commit is contained in:
kijai 2026-06-10 10:30:41 +03:00
parent a4c8b5064b
commit 1697da460b
4 changed files with 1398 additions and 26 deletions

View File

@ -17,6 +17,7 @@ class MESH:
uvs: torch.Tensor | None = None,
vertex_colors: torch.Tensor | None = None,
texture: torch.Tensor | None = None,
metallic_roughness: torch.Tensor | None = None,
vertex_counts: torch.Tensor | None = None,
face_counts: torch.Tensor | None = None):
@ -26,7 +27,9 @@ class MESH:
self.faces = faces # faces: (B, M, 3)
self.uvs = uvs # uvs: (B, N, 2)
self.vertex_colors = vertex_colors # vertex_colors: (B, N, 3 or 4)
self.texture = texture # texture: (B, H, W, 3)
self.texture = texture # texture (baseColor): (B, H, W, 3)
# glTF metallicRoughness texture: (B, H, W, 3), R unused, G=roughness, B=metallic
self.metallic_roughness = metallic_roughness
# When vertices/faces are zero-padded to a common N/M across the batch (variable-size mesh batch),
# these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed.
self.vertex_counts = vertex_counts

File diff suppressed because it is too large Load Diff

View File

@ -80,7 +80,8 @@ def get_mesh_batch_item(mesh, index):
def save_glb(vertices, faces, filepath, metadata=None,
uvs=None, vertex_colors=None, texture_image=None):
uvs=None, vertex_colors=None, texture_image=None,
metallic_roughness_image=None):
"""
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
@ -92,6 +93,8 @@ def save_glb(vertices, faces, filepath, metadata=None,
uvs: torch.Tensor of shape (N, 2) - Optional per-vertex texture coordinates
vertex_colors: torch.Tensor of shape (N, 3) or (N, 4) - Optional per-vertex colors in [0, 1]
texture_image: PIL.Image - Optional baseColor texture, embedded as PNG
metallic_roughness_image: PIL.Image - Optional glTF metallicRoughness texture
(R unused, G=roughness, B=metallic), embedded as PNG
"""
# Convert tensors to numpy arrays
@ -126,12 +129,18 @@ def save_glb(vertices, faces, filepath, metadata=None,
buf = BytesIO()
texture_image.save(buf, format="PNG")
texture_png_bytes = buf.getvalue()
mr_png_bytes = None
if metallic_roughness_image is not None:
buf = BytesIO()
metallic_roughness_image.save(buf, format="PNG")
mr_png_bytes = buf.getvalue()
vertices_buffer = vertices_np.tobytes()
indices_buffer = faces_np.tobytes()
uvs_buffer = uvs_np.tobytes() if uvs_np is not None else b""
colors_buffer = colors_np.tobytes() if colors_np is not None else b""
texture_buffer = texture_png_bytes if texture_png_bytes is not None else b""
mr_buffer = mr_png_bytes if mr_png_bytes is not None else b""
def pad_to_4_bytes(buffer):
padding_length = (4 - (len(buffer) % 4)) % 4
@ -142,6 +151,7 @@ def save_glb(vertices, faces, filepath, metadata=None,
uvs_buffer_padded = pad_to_4_bytes(uvs_buffer)
colors_buffer_padded = pad_to_4_bytes(colors_buffer)
texture_buffer_padded = pad_to_4_bytes(texture_buffer)
mr_buffer_padded = pad_to_4_bytes(mr_buffer)
buffer_data = b"".join([
vertices_buffer_padded,
@ -149,6 +159,7 @@ def save_glb(vertices, faces, filepath, metadata=None,
uvs_buffer_padded,
colors_buffer_padded,
texture_buffer_padded,
mr_buffer_padded,
])
vertices_byte_length = len(vertices_buffer)
@ -158,6 +169,7 @@ def save_glb(vertices, faces, filepath, metadata=None,
uvs_byte_offset = indices_byte_offset + len(indices_buffer_padded)
colors_byte_offset = uvs_byte_offset + len(uvs_buffer_padded)
texture_byte_offset = colors_byte_offset + len(colors_buffer_padded)
mr_byte_offset = texture_byte_offset + len(texture_buffer_padded)
buffer_views = [
{
@ -251,8 +263,24 @@ def save_glb(vertices, faces, filepath, metadata=None,
})
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
textures.append({"source": 0, "sampler": 0})
pbr["baseColorTexture"] = {"index": 0, "texCoord": 0}
textures.append({"source": len(images) - 1, "sampler": 0})
pbr["baseColorTexture"] = {"index": len(textures) - 1, "texCoord": 0}
if mr_png_bytes is not None and "TEXCOORD_0" in primitive_attributes:
buffer_views.append({
"buffer": 0,
"byteOffset": mr_byte_offset,
"byteLength": len(mr_buffer),
})
images.append({"bufferView": len(buffer_views) - 1, "mimeType": "image/png"})
if not samplers:
samplers.append({"magFilter": 9729, "minFilter": 9729, "wrapS": 33071, "wrapT": 33071})
textures.append({"source": len(images) - 1, "sampler": 0})
pbr["metallicRoughnessTexture"] = {"index": len(textures) - 1, "texCoord": 0}
# When a metallicRoughness texture is present, the factors scale it; use 1.0
# so the texture values pass through unchanged (glTF convention).
pbr["metallicFactor"] = 1.0
pbr["roughnessFactor"] = 1.0
materials.append({
"pbrMetallicRoughness": pbr,
@ -373,12 +401,20 @@ class SaveGLB(IO.ComfyNode):
assert texture_np.ndim == 4 and texture_np.shape[-1] == 3, (
f"texture must be (B, H, W, 3) RGB, got shape {tuple(texture_np.shape)}"
)
mr_b = getattr(mesh, "metallic_roughness", None)
mr_np = None
if mr_b is not None:
mr_np = (mr_b.clamp(0.0, 1.0).cpu().numpy() * 255).astype(np.uint8)
assert mr_np.ndim == 4 and mr_np.shape[-1] == 3, (
f"metallic_roughness must be (B, H, W, 3), got shape {tuple(mr_np.shape)}"
)
for i in range(mesh.vertices.shape[0]):
vertices_i, faces_i, v_colors, uvs_i = get_mesh_batch_item(mesh, i)
if vertices_i.shape[0] == 0 or faces_i.shape[0] == 0:
logging.warning(f"SaveGLB: skipping empty mesh at batch index {i}")
continue
tex_img = Image.fromarray(texture_np[i], mode="RGB") if texture_np is not None else None
mr_img = Image.fromarray(mr_np[i], mode="RGB") if mr_np is not None else None
f = f"{filename}_{counter:05}_.glb"
save_glb(
vertices_i, faces_i,
@ -387,6 +423,7 @@ class SaveGLB(IO.ComfyNode):
uvs=uvs_i,
vertex_colors=v_colors,
texture_image=tex_img,
metallic_roughness_image=mr_img,
)
results.append({
"filename": f,

View File

@ -9,8 +9,10 @@ from comfy_extras.nodes_mesh_postprocess import pack_variable_mesh_batch
import comfy.model_management
import comfy.utils
import folder_paths
from comfy.ldm.trellis2 import sampling_preview
from PIL import Image
import logging
import os
import numpy as np
import math
import torch
@ -19,6 +21,89 @@ ShapeSubdivides = io.Custom("SHAPE_SUBDIVIDES")
NAFModel = io.Custom("NAF_MODEL")
# Texture latent -> base-color calibration for the per-step preview
def _tex_rgb_factors_path():
return os.path.join(folder_paths.get_folder_paths("vae_approx")[0], "trellis2_tex_rgb_factors.pt")
def _pool_albedo_to_input(in_coords, out_coords, out_colors):
in_sp = in_coords[:, 1:4].long()
out_sp = out_coords[:, 1:4].long()
in_b = in_coords[:, 0].long()
out_b = out_coords[:, 0].long()
in_res = int(in_sp.max().item()) + 1
out_res = int(out_sp.max().item()) + 1
parent = torch.floor(out_sp.float() * in_res / out_res).long().clamp(0, in_res - 1)
R = in_res
in_flat = ((in_b * R + in_sp[:, 0]) * R + in_sp[:, 1]) * R + in_sp[:, 2]
par_flat = ((out_b * R + parent[:, 0]) * R + parent[:, 1]) * R + parent[:, 2]
order = torch.argsort(in_flat)
in_sorted = in_flat[order]
pos = torch.searchsorted(in_sorted, par_flat).clamp(max=in_sorted.numel() - 1)
matched = in_sorted[pos] == par_flat
in_idx = order[pos][matched]
cols = out_colors[matched].float()
N = in_coords.shape[0]
csum = cols.new_zeros((N, 3))
ccount = cols.new_zeros((N, 1))
csum.index_add_(0, in_idx, cols)
ccount.index_add_(0, in_idx, torch.ones((in_idx.shape[0], 1), device=cols.device, dtype=cols.dtype))
valid = ccount[:, 0] > 0
albedo = torch.zeros_like(csum)
albedo[valid] = csum[valid] / ccount[valid]
return albedo, valid
def _calibrate_tex_rgb(in_latent, in_coords, out_colors, out_coords):
"""Accumulate one decode's (latent -> albedo) evidence, re-solve, persist, publish."""
try:
dev = out_colors.device
in_latent = in_latent.to(dev)
in_coords = in_coords.to(dev)
out_coords = out_coords.to(dev)
albedo, valid = _pool_albedo_to_input(in_coords, out_coords, out_colors)
X = in_latent[valid].float().cpu()
Y = albedo[valid].float().cpu()
if X.shape[0] < 64:
return
Xaug = torch.cat([X, torch.ones(X.shape[0], 1)], dim=1) # [K, C+1]
A_run = Xaug.transpose(0, 1) @ Xaug # [C+1, C+1]
B_run = Xaug.transpose(0, 1) @ Y # [C+1, 3]
path = _tex_rgb_factors_path()
if os.path.exists(path):
try:
prev = torch.load(path, map_location="cpu")
A_run = A_run + prev["A"]
B_run = B_run + prev["B"]
except Exception:
pass
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save({"A": A_run, "B": B_run}, path)
eye = torch.eye(A_run.shape[0])
WB = torch.linalg.solve(A_run + 1e-3 * eye, B_run) # [C+1, 3]
W, b = WB[:-1].contiguous(), WB[-1].contiguous()
sampling_preview.set_tex_rgb(W, b)
except Exception as e:
logging.debug(f"Trellis2 tex-rgb calibration skipped: {e}")
def _load_tex_rgb_factors():
try:
path = _tex_rgb_factors_path()
if os.path.exists(path):
d = torch.load(path, map_location="cpu")
eye = torch.eye(d["A"].shape[0])
WB = torch.linalg.solve(d["A"] + 1e-3 * eye, d["B"])
sampling_preview.set_tex_rgb(WB[:-1].contiguous(), WB[-1].contiguous())
except Exception as e:
logging.debug(f"Trellis2 tex-rgb factor load skipped: {e}")
_load_tex_rgb_factors()
def prepare_trellis_vae_for_decode(vae, sample_shape):
memory_required = vae.memory_used_decode(sample_shape, vae.vae_dtype)
if len(sample_shape) == 5:
@ -174,6 +259,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
resolution = int(coord_resolution) * 16
else:
resolution = int(vae.first_stage_model.resolution.item())
model_frame = samples.get("model_frame", "y_up")
sample_tensor = samples["samples"]
device = comfy.model_management.get_torch_device()
coords = samples["coords"]
@ -205,8 +291,13 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
coords_list = [stage_tensor.coords for stage_tensor in stage_tensors]
subs.append(SparseTensor.from_tensor_list(feats_list, coords_list))
vert_list = [v.float() for v, f in mesh]
face_list = [f.int() for v, f in mesh]
# Rotate Z-up (Trellis2 training frame) vertices to glTF Y-up. Pixal3D outputs are already Y-up.
if model_frame == "z_up":
vert_list = [torch.stack([v[..., 0], v[..., 2], -v[..., 1]], dim=-1).float().cpu()
for v, _ in mesh]
else:
vert_list = [v.float().cpu() for v, _ in mesh]
face_list = [f.int().cpu() for _, f in mesh]
if all(v.shape == vert_list[0].shape for v in vert_list) and all(f.shape == face_list[0].shape for f in face_list):
mesh = Types.MESH(vertices=torch.stack(vert_list), faces=torch.stack(face_list))
else:
@ -241,19 +332,32 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
prepare_trellis_vae_for_decode(vae, sample_tensor.shape)
trellis_vae = vae.first_stage_model
coord_counts = samples.get("coord_counts")
model_frame = samples.get("model_frame", "y_up")
samples = samples["samples"]
samples, coords = flatten_batched_sparse_latent(samples, coords, coord_counts)
samples = samples.to(device)
cal_in_latent = samples # [N, C] pre-denorm latent, for tex-rgb preview calibration
cal_in_coords = coords
std = tex_slat_normalization["std"].to(samples)
mean = tex_slat_normalization["mean"].to(samples)
samples = SparseTensor(feats = samples, coords=coords.to(device))
samples = samples * std + mean
voxel = trellis_vae.decode_tex_slat(samples, shape_subdivides)
color_feats = voxel.feats[:, :3]
# Keep all decoded channels. The texture VAE emits 6: base_color (0:3),
# metallic (3), roughness (4), alpha (5) — all in [0, 1]. Vertex-color
# consumers (PaintMesh) slice [:3]; BakeTextureFromVoxel uses the full
# PBR set. Older 3-channel checkpoints pass through unchanged.
color_feats = voxel.feats
voxel_coords = voxel.coords
# Calibrate the latent->base_color map for the per-step texture preview.
# Done here while input coords and voxel_coords share the model frame
# (before the z_up remap below) and on the real decoded albedo.
if color_feats.shape[0] > 0 and color_feats.shape[-1] >= 3:
_calibrate_tex_rgb(cal_in_latent, cal_in_coords, color_feats[:, :3], voxel_coords)
if voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
spatial = voxel_coords[:, -3:] if voxel_coords.shape[-1] == 4 else voxel_coords
max_idx = int(spatial.max().item()) + 1
@ -261,6 +365,24 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
else:
tex_resolution = 1024
# Remap Z-up voxel coords to Y-up: (x, y, z) -> (x, z, R-1-y), matching the
# R_x(-90°) applied to mesh vertices in VaeDecodeShapeTrellis. Keeps PaintMesh's
# NN lookup correctly aligned without it needing to know the source frame.
if model_frame == "z_up" and voxel_coords.numel() > 0 and voxel_coords.shape[-1] >= 3:
R = tex_resolution
if voxel_coords.shape[-1] == 4:
batch_col = voxel_coords[:, :1]
spatial = voxel_coords[:, 1:]
spatial_yup = torch.stack(
[spatial[:, 0], spatial[:, 2], (R - 1) - spatial[:, 1]], dim=-1
)
voxel_coords = torch.cat([batch_col, spatial_yup], dim=-1)
else:
voxel_coords = torch.stack(
[voxel_coords[:, 0], voxel_coords[:, 2], (R - 1) - voxel_coords[:, 1]],
dim=-1,
)
voxel = Types.VOXEL(voxel_coords, color_feats, tex_resolution)
return IO.NodeOutput(voxel)
@ -425,7 +547,9 @@ class Trellis2UpsampleStage(IO.ComfyNode):
positive_out = _conditioning_set_extras(positive, extras)
negative_out = _conditioning_set_extras(negative, extras)
out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
"coord_resolution": coord_resolution, "type": "trellis2"}
"coord_resolution": coord_resolution, "type": "trellis2",
"model_frame": shape_latent.get("model_frame",
"y_up" if proj_pack is not None else "z_up")}
return IO.NodeOutput(positive_out, negative_out, out_latent)
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
@ -694,7 +818,8 @@ class Trellis2ShapeStage(IO.ComfyNode):
positive_out = _conditioning_set_extras(positive, extras)
negative_out = _conditioning_set_extras(negative, extras)
out_latent = {"samples": latent, "coords": coords, "coord_counts": counts,
"coord_resolution": coord_resolution, "type": "trellis2"}
"coord_resolution": coord_resolution, "type": "trellis2",
"model_frame": "y_up" if proj_pack is not None else "z_up"}
return IO.NodeOutput(positive_out, negative_out, out_latent)
class Trellis2TextureStage(IO.ComfyNode):
@ -747,7 +872,9 @@ class Trellis2TextureStage(IO.ComfyNode):
positive_out = _conditioning_set_extras(positive, extras)
negative_out = _conditioning_set_extras(negative, extras)
out_latent = {"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts}
out_latent = {"samples": latent, "type": "trellis2", "coords": coords, "coord_counts": counts,
"model_frame": shape_latent.get("model_frame",
"y_up" if proj_pack is not None else "z_up")}
if coord_resolution is not None:
out_latent["coord_resolution"] = coord_resolution
return IO.NodeOutput(positive_out, negative_out, out_latent)
@ -1018,6 +1145,8 @@ def _project_vertices_to_image_uv(vertices_world, transform_matrix, camera_angle
points = vertices_world.unsqueeze(0).float()
T = transform_matrix.unsqueeze(0).float() if transform_matrix.ndim == 2 else transform_matrix.float()
cam = camera_angle_x.unsqueeze(0) if camera_angle_x.ndim == 0 else camera_angle_x
T = T.to(points.device)
cam = cam.to(points.device)
uv_pix, depth, valid = _project_points_to_image(points, T, cam.float(), image_resolution)
uv = uv_pix.squeeze(0) / image_resolution
return uv, depth.squeeze(0), valid.squeeze(0)
@ -1108,13 +1237,25 @@ class Pixal3DAlignObject(IO.ComfyNode):
scene_pixels = _crop_uv_to_scene_pixels(uv_crop, crop_bbox, (scene_W, scene_H))
in_scene = ((scene_pixels[:, 0] >= 0) & (scene_pixels[:, 0] < scene_W) &
(scene_pixels[:, 1] >= 0) & (scene_pixels[:, 1] < scene_H))
# MoGe geometry and object_mask can land on CPU after passing between nodes;
# match the indexed tensor's device for sy/sx so the gather works on either.
moge_points = moge_points.to(scene_pixels.device)
moge_mask = moge_mask.to(scene_pixels.device)
sx = scene_pixels[:, 0].long().clamp(0, scene_W - 1)
sy = scene_pixels[:, 1].long().clamp(0, scene_H - 1)
moge_per_vertex = moge_points[batch_index, sy, sx]
# MoGe's perspective output is (X right, Y down, Z forward). Convert to glTF
# Y-up (X right, Y up, Z back) so the scale/translate fit runs in the same
# frame as vertices_one (Pixal3D model frame = glTF Y-up). Mirrors the
# `verts * [1, -1, -1]` step in MoGePointMapToMesh.
moge_per_vertex = moge_per_vertex * torch.tensor(
[1.0, -1.0, -1.0], dtype=moge_per_vertex.dtype, device=moge_per_vertex.device
)
moge_mask_per_vertex = moge_mask[batch_index, sy, sx]
keep = valid & in_scene & moge_mask_per_vertex
if object_mask is not None:
om = object_mask if object_mask.ndim == 2 else object_mask[batch_index]
om = om.to(sy.device)
keep = keep & (om[sy, sx] > 0.5)
finite = torch.isfinite(moge_per_vertex).all(dim=-1)
@ -1131,25 +1272,23 @@ class Pixal3DAlignObject(IO.ComfyNode):
q_mean = Q.mean(dim=0, keepdim=True)
P_c = P - p_mean
Q_c = Q - q_mean
num = (P_c * Q_c).sum()
den = (P_c * P_c).sum().clamp(min=1e-8)
scale = float((num / den).item())
if not (scale > 0):
# Negative scale would mirror the mesh; treat as a camera-convention mismatch.
logging.warning(
f"Pixal3DAlignObject: computed scale={scale:.4f} <= 0; "
"refusing to apply mirroring. Check camera convention alignment.")
scale = 1.0
aligned = vertices_one
else:
t = q_mean - scale * p_mean
aligned = scale * vertices_one + t
# Rotation-invariant scale: ratio of RMS spreads. MoGe geometry is
# noisy and Pixal3D's mesh frame can be yawed relative to MoGe (paper
# acknowledges this), so the L2-optimal scalar (P_c · Q_c)/(P_c · P_c)
# gets multiplied by cos(yaw) and shrinks the object. Using
# sqrt(||Q_c||² / ||P_c||²) recovers the right size regardless of
# rotation; translation still positions the mesh at MoGe's centroid.
p_var = (P_c * P_c).sum().clamp(min=1e-8)
q_var = (Q_c * Q_c).sum()
scale = float(torch.sqrt(q_var / p_var).item())
t = q_mean - scale * p_mean
aligned = scale * vertices_one + t
if vertices.ndim == 3:
aligned = aligned.unsqueeze(0)
out_mesh = Types.MESH(vertices=aligned, faces=faces)
out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces.cpu())
else:
out_mesh = Types.MESH(vertices=aligned, faces=faces_one)
out_mesh = Types.MESH(vertices=aligned.cpu(), faces=faces_one.cpu())
return IO.NodeOutput(out_mesh, float(scale))