mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-03 22:02:51 +08:00
Trellis2: share batched mesh helpers
This commit is contained in:
parent
c297a9f839
commit
40219ab0fc
53
comfy_extras/mesh_batch_utils.py
Normal file
53
comfy_extras/mesh_batch_utils.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import torch
|
||||||
|
from comfy_api.latest import Types
|
||||||
|
|
||||||
|
|
||||||
|
def pack_variable_mesh_batch(vertices, faces, colors=None):
|
||||||
|
batch_size = len(vertices)
|
||||||
|
max_vertices = max(v.shape[0] for v in vertices)
|
||||||
|
max_faces = max(f.shape[0] for f in faces)
|
||||||
|
|
||||||
|
packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1]))
|
||||||
|
packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1]))
|
||||||
|
vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64)
|
||||||
|
face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64)
|
||||||
|
|
||||||
|
for i, (v, f) in enumerate(zip(vertices, faces)):
|
||||||
|
packed_vertices[i, :v.shape[0]] = v
|
||||||
|
packed_faces[i, :f.shape[0]] = f
|
||||||
|
|
||||||
|
mesh = Types.MESH(packed_vertices, packed_faces)
|
||||||
|
mesh.vertex_counts = vertex_counts
|
||||||
|
mesh.face_counts = face_counts
|
||||||
|
|
||||||
|
if colors is not None:
|
||||||
|
max_colors = max(c.shape[0] for c in colors)
|
||||||
|
packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1]))
|
||||||
|
color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64)
|
||||||
|
for i, c in enumerate(colors):
|
||||||
|
packed_colors[i, :c.shape[0]] = c
|
||||||
|
mesh.colors = packed_colors
|
||||||
|
mesh.color_counts = color_counts
|
||||||
|
|
||||||
|
return mesh
|
||||||
|
|
||||||
|
|
||||||
|
def get_mesh_batch_item(mesh, index):
|
||||||
|
if hasattr(mesh, "vertex_counts"):
|
||||||
|
vertex_count = int(mesh.vertex_counts[index].item())
|
||||||
|
face_count = int(mesh.face_counts[index].item())
|
||||||
|
vertices = mesh.vertices[index, :vertex_count]
|
||||||
|
faces = mesh.faces[index, :face_count]
|
||||||
|
colors = None
|
||||||
|
if hasattr(mesh, "colors") and mesh.colors is not None:
|
||||||
|
if hasattr(mesh, "color_counts"):
|
||||||
|
color_count = int(mesh.color_counts[index].item())
|
||||||
|
colors = mesh.colors[index, :color_count]
|
||||||
|
else:
|
||||||
|
colors = mesh.colors[index, :vertex_count]
|
||||||
|
return vertices, faces, colors
|
||||||
|
|
||||||
|
colors = None
|
||||||
|
if hasattr(mesh, "colors") and mesh.colors is not None:
|
||||||
|
colors = mesh.colors[index]
|
||||||
|
return mesh.vertices[index], mesh.faces[index], colors
|
||||||
@ -10,6 +10,7 @@ from comfy.cli_args import args
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, IO, Types
|
from comfy_api.latest import ComfyExtension, IO, Types
|
||||||
from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa
|
from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa
|
||||||
|
from comfy_extras.mesh_batch_utils import pack_variable_mesh_batch, get_mesh_batch_item
|
||||||
|
|
||||||
|
|
||||||
class EmptyLatentHunyuan3Dv2(IO.ComfyNode):
|
class EmptyLatentHunyuan3Dv2(IO.ComfyNode):
|
||||||
@ -631,58 +632,6 @@ def save_glb(vertices, faces, filepath, metadata=None, colors=None):
|
|||||||
|
|
||||||
return filepath
|
return filepath
|
||||||
|
|
||||||
|
|
||||||
def pack_variable_mesh_batch(vertices, faces, colors=None):
|
|
||||||
batch_size = len(vertices)
|
|
||||||
max_vertices = max(v.shape[0] for v in vertices)
|
|
||||||
max_faces = max(f.shape[0] for f in faces)
|
|
||||||
|
|
||||||
packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1]))
|
|
||||||
packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1]))
|
|
||||||
vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64)
|
|
||||||
face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64)
|
|
||||||
|
|
||||||
for i, (v, f) in enumerate(zip(vertices, faces)):
|
|
||||||
packed_vertices[i, :v.shape[0]] = v
|
|
||||||
packed_faces[i, :f.shape[0]] = f
|
|
||||||
|
|
||||||
mesh = Types.MESH(packed_vertices, packed_faces)
|
|
||||||
mesh.vertex_counts = vertex_counts
|
|
||||||
mesh.face_counts = face_counts
|
|
||||||
|
|
||||||
if colors is not None:
|
|
||||||
max_colors = max(c.shape[0] for c in colors)
|
|
||||||
packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1]))
|
|
||||||
color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64)
|
|
||||||
for i, c in enumerate(colors):
|
|
||||||
packed_colors[i, :c.shape[0]] = c
|
|
||||||
mesh.colors = packed_colors
|
|
||||||
mesh.color_counts = color_counts
|
|
||||||
|
|
||||||
return mesh
|
|
||||||
|
|
||||||
|
|
||||||
def get_mesh_batch_item(mesh, index):
|
|
||||||
if hasattr(mesh, "vertex_counts"):
|
|
||||||
vertex_count = int(mesh.vertex_counts[index].item())
|
|
||||||
face_count = int(mesh.face_counts[index].item())
|
|
||||||
vertices = mesh.vertices[index, :vertex_count]
|
|
||||||
faces = mesh.faces[index, :face_count]
|
|
||||||
colors = None
|
|
||||||
if hasattr(mesh, "colors") and mesh.colors is not None:
|
|
||||||
if hasattr(mesh, "color_counts"):
|
|
||||||
color_count = int(mesh.color_counts[index].item())
|
|
||||||
colors = mesh.colors[index, :color_count]
|
|
||||||
else:
|
|
||||||
colors = mesh.colors[index, :vertex_count]
|
|
||||||
return vertices, faces, colors
|
|
||||||
|
|
||||||
colors = None
|
|
||||||
if hasattr(mesh, "colors") and mesh.colors is not None:
|
|
||||||
colors = mesh.colors[index]
|
|
||||||
return mesh.vertices[index], mesh.faces[index], colors
|
|
||||||
|
|
||||||
|
|
||||||
class SaveGLB(IO.ComfyNode):
|
class SaveGLB(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, IO, Types
|
from comfy_api.latest import ComfyExtension, IO, Types
|
||||||
from comfy.ldm.trellis2.vae import SparseTensor
|
from comfy.ldm.trellis2.vae import SparseTensor
|
||||||
|
from comfy_extras.mesh_batch_utils import pack_variable_mesh_batch, get_mesh_batch_item
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -8,57 +9,6 @@ import torch
|
|||||||
import scipy
|
import scipy
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
|
||||||
def pack_variable_mesh_batch(vertices, faces, colors=None):
|
|
||||||
batch_size = len(vertices)
|
|
||||||
max_vertices = max(v.shape[0] for v in vertices)
|
|
||||||
max_faces = max(f.shape[0] for f in faces)
|
|
||||||
|
|
||||||
packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1]))
|
|
||||||
packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1]))
|
|
||||||
vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64)
|
|
||||||
face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64)
|
|
||||||
|
|
||||||
for i, (v, f) in enumerate(zip(vertices, faces)):
|
|
||||||
packed_vertices[i, :v.shape[0]] = v
|
|
||||||
packed_faces[i, :f.shape[0]] = f
|
|
||||||
|
|
||||||
mesh = Types.MESH(packed_vertices, packed_faces)
|
|
||||||
mesh.vertex_counts = vertex_counts
|
|
||||||
mesh.face_counts = face_counts
|
|
||||||
|
|
||||||
if colors is not None:
|
|
||||||
max_colors = max(c.shape[0] for c in colors)
|
|
||||||
packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1]))
|
|
||||||
color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64)
|
|
||||||
for i, c in enumerate(colors):
|
|
||||||
packed_colors[i, :c.shape[0]] = c
|
|
||||||
mesh.colors = packed_colors
|
|
||||||
mesh.color_counts = color_counts
|
|
||||||
|
|
||||||
return mesh
|
|
||||||
|
|
||||||
|
|
||||||
def get_mesh_batch_item(mesh, index):
|
|
||||||
if hasattr(mesh, "vertex_counts"):
|
|
||||||
vertex_count = int(mesh.vertex_counts[index].item())
|
|
||||||
face_count = int(mesh.face_counts[index].item())
|
|
||||||
vertices = mesh.vertices[index, :vertex_count]
|
|
||||||
faces = mesh.faces[index, :face_count]
|
|
||||||
colors = None
|
|
||||||
if hasattr(mesh, "colors") and mesh.colors is not None:
|
|
||||||
if hasattr(mesh, "color_counts"):
|
|
||||||
color_count = int(mesh.color_counts[index].item())
|
|
||||||
colors = mesh.colors[index, :color_count]
|
|
||||||
else:
|
|
||||||
colors = mesh.colors[index, :vertex_count]
|
|
||||||
return vertices, faces, colors
|
|
||||||
|
|
||||||
colors = None
|
|
||||||
if hasattr(mesh, "colors") and mesh.colors is not None:
|
|
||||||
colors = mesh.colors[index]
|
|
||||||
return mesh.vertices[index], mesh.faces[index], colors
|
|
||||||
|
|
||||||
shape_slat_normalization = {
|
shape_slat_normalization = {
|
||||||
"mean": torch.tensor([
|
"mean": torch.tensor([
|
||||||
0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
|
0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
|
||||||
@ -130,14 +80,14 @@ def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution):
|
|||||||
|
|
||||||
final_colors = linear_colors.unsqueeze(0)
|
final_colors = linear_colors.unsqueeze(0)
|
||||||
|
|
||||||
out_mesh = copy.deepcopy(mesh)
|
out_mesh = copy.copy(mesh)
|
||||||
out_mesh.colors = final_colors
|
out_mesh.colors = final_colors
|
||||||
|
|
||||||
return out_mesh
|
return out_mesh
|
||||||
|
|
||||||
|
|
||||||
def paint_mesh_default_colors(mesh):
|
def paint_mesh_default_colors(mesh):
|
||||||
out_mesh = copy.deepcopy(mesh)
|
out_mesh = copy.copy(mesh)
|
||||||
vertex_count = mesh.vertices.shape[1]
|
vertex_count = mesh.vertices.shape[1]
|
||||||
out_mesh.colors = mesh.vertices.new_zeros((1, vertex_count, 3))
|
out_mesh.colors = mesh.vertices.new_zeros((1, vertex_count, 3))
|
||||||
return out_mesh
|
return out_mesh
|
||||||
@ -400,7 +350,7 @@ class Trellis2Conditioning(IO.ComfyNode):
|
|||||||
mask = mask.unsqueeze(0)
|
mask = mask.unsqueeze(0)
|
||||||
batch_size = image.shape[0]
|
batch_size = image.shape[0]
|
||||||
if mask.shape[0] == 1 and batch_size > 1:
|
if mask.shape[0] == 1 and batch_size > 1:
|
||||||
mask = mask.repeat(batch_size, 1, 1)
|
mask = mask.expand(batch_size, -1, -1)
|
||||||
elif mask.shape[0] != batch_size:
|
elif mask.shape[0] != batch_size:
|
||||||
raise ValueError(f"Trellis2Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}")
|
raise ValueError(f"Trellis2Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}")
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user