pytorch -> scipy

This commit is contained in:
Yousef Rafat 2026-03-25 17:03:52 +02:00
parent fe25190cae
commit d2c37c222a

View File

@ -6,6 +6,7 @@ import logging
from PIL import Image
import numpy as np
import torch
import scipy
import copy
shape_slat_normalization = {
@ -45,26 +46,30 @@ def shape_norm(shape_latent, coords):
samples = samples * std + mean
return samples
def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution, chunk_size=4096):
def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution):
"""
Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field.
Uses a scipy.spatial.cKDTree nearest-neighbour query instead of a dense V x N distance matrix, avoiding OOM on large meshes.
"""
device = voxel_coords.device
device = comfy.model_management.vae_offload_device()
# Map Voxel Grid to Real 3D Space
origin = torch.tensor([-0.5, -0.5, -0.5], device=device)
voxel_size = 1.0 / resolution
voxel_pos = voxel_coords.float() * voxel_size + origin
# map voxels
voxel_pos = voxel_coords.to(device).float() * voxel_size + origin
verts = mesh.vertices.to(device).squeeze(0)
v_colors = torch.zeros((verts.shape[0], 3), device=device)
voxel_colors = voxel_colors.to(device)
for i in range(0, verts.shape[0], chunk_size):
v_chunk = verts[i : i + chunk_size]
dists = torch.cdist(v_chunk, voxel_pos)
nearest_idx = torch.argmin(dists, dim=1)
v_colors[i : i + chunk_size] = voxel_colors[nearest_idx]
voxel_pos_np = voxel_pos.numpy()
verts_np = verts.numpy()
tree = scipy.spatial.cKDTree(voxel_pos_np)
# nearest-neighbour query: k=1 returns the single closest voxel per vertex; workers=-1 parallelizes across all CPU cores
_, nearest_idx_np = tree.query(verts_np, k=1, workers=-1)
nearest_idx = torch.from_numpy(nearest_idx_np).long()
v_colors = voxel_colors[nearest_idx]
final_colors = (v_colors * 0.5 + 0.5).clamp(0, 1).unsqueeze(0)
@ -343,7 +348,11 @@ class Trellis2Conditioning(IO.ComfyNode):
alpha_float = cropped_np[:, :, 3:4]
composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
cropped_img_tensor = torch.from_numpy(composite_np).movedim(-1, 0).unsqueeze(0).float()
# to match trellis2 code (quantize -> dequantize)
composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
cropped_img_tensor = torch.from_numpy(composite_uint8).float() / 255.0
cropped_img_tensor = cropped_img_tensor.movedim(-1, 0).unsqueeze(0)
conditioning = run_conditioning(clip_vision_model, cropped_img_tensor, include_1024=True)