diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py
index 4126fb536..469b460eb 100644
--- a/comfy_extras/nodes_trellis2.py
+++ b/comfy_extras/nodes_trellis2.py
@@ -6,6 +6,7 @@ import logging
 from PIL import Image
 import numpy as np
 import torch
+import scipy.spatial
 import copy
 
 shape_slat_normalization = {
@@ -45,26 +46,30 @@ def shape_norm(shape_latent, coords):
     samples = samples * std + mean
     return samples
 
-def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution, chunk_size=4096):
+def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution):
     """
     Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field.
-    Keeps chunking internal to prevent OOM crashes on large matrices.
     """
-    device = voxel_coords.device
+    device = comfy.model_management.vae_offload_device()
 
-    # Map Voxel Grid to Real 3D Space
     origin = torch.tensor([-0.5, -0.5, -0.5], device=device)
     voxel_size = 1.0 / resolution
-    voxel_pos = voxel_coords.float() * voxel_size + origin
+    # map voxels
+    voxel_pos = voxel_coords.to(device).float() * voxel_size + origin
 
     verts = mesh.vertices.to(device).squeeze(0)
-    v_colors = torch.zeros((verts.shape[0], 3), device=device)
+    voxel_colors = voxel_colors.to(device)
 
-    for i in range(0, verts.shape[0], chunk_size):
-        v_chunk = verts[i : i + chunk_size]
-        dists = torch.cdist(v_chunk, voxel_pos)
-        nearest_idx = torch.argmin(dists, dim=1)
-        v_colors[i : i + chunk_size] = voxel_colors[nearest_idx]
+    voxel_pos_np = voxel_pos.cpu().numpy()
+    verts_np = verts.cpu().numpy()
+
+    tree = scipy.spatial.cKDTree(voxel_pos_np)
+
+    # nearest neighbour k=1
+    _, nearest_idx_np = tree.query(verts_np, k=1, workers=-1)
+
+    nearest_idx = torch.from_numpy(nearest_idx_np).long()
+    v_colors = voxel_colors[nearest_idx]
 
     final_colors = (v_colors * 0.5 + 0.5).clamp(0, 1).unsqueeze(0)
 
@@ -343,7 +348,11 @@ class Trellis2Conditioning(IO.ComfyNode):
 
         alpha_float = cropped_np[:, :, 3:4]
         composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
-        cropped_img_tensor = torch.from_numpy(composite_np).movedim(-1, 0).unsqueeze(0).float()
+        # to match trellis2 code (quantize -> dequantize)
+        composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
+
+        cropped_img_tensor = torch.from_numpy(composite_uint8).float() / 255.0
+        cropped_img_tensor = cropped_img_tensor.movedim(-1, 0).unsqueeze(0)
 
         conditioning = run_conditioning(clip_vision_model, cropped_img_tensor, include_1024=True)
 