mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 07:22:34 +08:00
post-process rewrite + light texture model work
This commit is contained in:
parent
44adb27782
commit
011f624dd5
@ -810,6 +810,10 @@ class Trellis2(nn.Module):
|
|||||||
elif mode == "texture_generation":
|
elif mode == "texture_generation":
|
||||||
if self.shape2txt is None:
|
if self.shape2txt is None:
|
||||||
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
||||||
|
slat = transformer_options.get("shape_slat")
|
||||||
|
if slat is None:
|
||||||
|
raise ValueError("shape_slat can't be None")
|
||||||
|
x = sparse_cat([x, slat])
|
||||||
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
|
||||||
else: # structure
|
else: # structure
|
||||||
#timestep = timestep_reshift(timestep)
|
#timestep = timestep_reshift(timestep)
|
||||||
|
|||||||
@ -38,6 +38,13 @@ tex_slat_normalization = {
|
|||||||
])[None]
|
])[None]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def shape_norm(shape_latent, coords):
|
||||||
|
std = shape_slat_normalization["std"].to(shape_latent)
|
||||||
|
mean = shape_slat_normalization["mean"].to(shape_latent)
|
||||||
|
samples = SparseTensor(feats = shape_latent, coords=coords)
|
||||||
|
samples = samples * std + mean
|
||||||
|
return samples
|
||||||
|
|
||||||
class VaeDecodeShapeTrellis(IO.ComfyNode):
|
class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -70,10 +77,7 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
|
|||||||
|
|
||||||
samples = samples["samples"]
|
samples = samples["samples"]
|
||||||
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||||
std = shape_slat_normalization["std"].to(samples)
|
samples = shape_norm(samples, coords)
|
||||||
mean = shape_slat_normalization["mean"].to(samples)
|
|
||||||
samples = SparseTensor(feats = samples, coords=coords)
|
|
||||||
samples = samples * std + mean
|
|
||||||
|
|
||||||
mesh, subs = vae.decode_shape_slat(samples, resolution)
|
mesh, subs = vae.decode_shape_slat(samples, resolution)
|
||||||
faces = torch.stack([m.faces for m in mesh])
|
faces = torch.stack([m.faces for m in mesh])
|
||||||
@ -313,6 +317,8 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Voxel.Input("structure_output"),
|
IO.Voxel.Input("structure_output"),
|
||||||
|
IO.Latent.Input("shape_latent"),
|
||||||
|
IO.Model.Input("model")
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Latent.Output(),
|
IO.Latent.Output(),
|
||||||
@ -321,11 +327,15 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, structure_output, model):
|
def execute(cls, structure_output, shape_latent, model):
|
||||||
# TODO
|
# TODO
|
||||||
decoded = structure_output.data.unsqueeze(1)
|
decoded = structure_output.data.unsqueeze(1)
|
||||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||||
in_channels = 32
|
in_channels = 32
|
||||||
|
|
||||||
|
shape_latent = shape_latent["samples"]
|
||||||
|
shape_latent = shape_norm(shape_latent, coords)
|
||||||
|
|
||||||
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
|
||||||
model = model.clone()
|
model = model.clone()
|
||||||
model.model_options = model.model_options.copy()
|
model.model_options = model.model_options.copy()
|
||||||
@ -336,6 +346,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
|||||||
|
|
||||||
model.model_options["transformer_options"]["coords"] = coords
|
model.model_options["transformer_options"]["coords"] = coords
|
||||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
||||||
|
model.model_options["transformer_options"]["shape_slat"] = shape_latent
|
||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)
|
||||||
|
|
||||||
|
|
||||||
@ -360,25 +371,34 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
|||||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||||
|
|
||||||
def simplify_fn(vertices, faces, target=100000):
|
def simplify_fn(vertices, faces, target=100000):
|
||||||
|
is_batched = vertices.ndim == 3
|
||||||
|
if is_batched:
|
||||||
|
v_list, f_list = [], []
|
||||||
|
for i in range(vertices.shape[0]):
|
||||||
|
v_i, f_i = simplify_fn(vertices[i], faces[i], target)
|
||||||
|
v_list.append(v_i)
|
||||||
|
f_list.append(f_i)
|
||||||
|
return torch.stack(v_list), torch.stack(f_list)
|
||||||
|
|
||||||
if vertices.shape[0] <= target:
|
if faces.shape[0] <= target:
|
||||||
return vertices, faces
|
return vertices, faces
|
||||||
|
|
||||||
min_feat = vertices.min(dim=0)[0]
|
device = vertices.device
|
||||||
max_feat = vertices.max(dim=0)[0]
|
target_v = target / 2.0
|
||||||
extent = (max_feat - min_feat).max()
|
|
||||||
|
|
||||||
grid_resolution = int(torch.sqrt(torch.tensor(target)).item() * 1.5)
|
min_v = vertices.min(dim=0)[0]
|
||||||
voxel_size = extent / grid_resolution
|
max_v = vertices.max(dim=0)[0]
|
||||||
|
extent = max_v - min_v
|
||||||
|
|
||||||
quantized_coords = ((vertices - min_feat) / voxel_size).long()
|
volume = (extent[0] * extent[1] * extent[2]).clamp(min=1e-8)
|
||||||
|
cell_size = (volume / target_v) ** (1/3.0)
|
||||||
|
|
||||||
unique_coords, inverse_indices = torch.unique(quantized_coords, dim=0, return_inverse=True)
|
quantized = ((vertices - min_v) / cell_size).round().long()
|
||||||
|
unique_coords, inverse_indices = torch.unique(quantized, dim=0, return_inverse=True)
|
||||||
|
num_cells = unique_coords.shape[0]
|
||||||
|
|
||||||
num_new_verts = unique_coords.shape[0]
|
new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device)
|
||||||
new_vertices = torch.zeros((num_new_verts, 3), dtype=vertices.dtype, device=vertices.device)
|
counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device)
|
||||||
|
|
||||||
counts = torch.zeros((num_new_verts, 1), dtype=vertices.dtype, device=vertices.device)
|
|
||||||
|
|
||||||
new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices)
|
new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices)
|
||||||
counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1]))
|
counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1]))
|
||||||
@ -387,11 +407,9 @@ def simplify_fn(vertices, faces, target=100000):
|
|||||||
|
|
||||||
new_faces = inverse_indices[faces]
|
new_faces = inverse_indices[faces]
|
||||||
|
|
||||||
v0 = new_faces[:, 0]
|
valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \
|
||||||
v1 = new_faces[:, 1]
|
(new_faces[:, 1] != new_faces[:, 2]) & \
|
||||||
v2 = new_faces[:, 2]
|
(new_faces[:, 2] != new_faces[:, 0])
|
||||||
|
|
||||||
valid_mask = (v0 != v1) & (v1 != v2) & (v2 != v0)
|
|
||||||
new_faces = new_faces[valid_mask]
|
new_faces = new_faces[valid_mask]
|
||||||
|
|
||||||
unique_face_indices, inv_face = torch.unique(new_faces.reshape(-1), return_inverse=True)
|
unique_face_indices, inv_face = torch.unique(new_faces.reshape(-1), return_inverse=True)
|
||||||
@ -414,7 +432,7 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
|||||||
v = vertices
|
v = vertices
|
||||||
f = faces
|
f = faces
|
||||||
|
|
||||||
if f.shape[0] == 0:
|
if f.numel() == 0:
|
||||||
return v, f
|
return v, f
|
||||||
|
|
||||||
edges = torch.cat([f[:, [0, 1]], f[:, [1, 2]], f[:, [2, 0]]], dim=0)
|
edges = torch.cat([f[:, [0, 1]], f[:, [1, 2]], f[:, [2, 0]]], dim=0)
|
||||||
@ -424,145 +442,75 @@ def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
|||||||
packed_undirected = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long()
|
packed_undirected = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long()
|
||||||
|
|
||||||
unique_packed, counts = torch.unique(packed_undirected, return_counts=True)
|
unique_packed, counts = torch.unique(packed_undirected, return_counts=True)
|
||||||
boundary_mask = counts == 1
|
boundary_packed = unique_packed[counts == 1]
|
||||||
boundary_packed = unique_packed[boundary_mask]
|
|
||||||
|
|
||||||
if boundary_packed.numel() == 0:
|
if boundary_packed.numel() == 0:
|
||||||
return v, f
|
return v, f
|
||||||
|
|
||||||
packed_directed_sorted = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long()
|
packed_directed_sorted = edges[:, 0].min(edges[:, 1]).long() * max_v + edges[:, 0].max(edges[:, 1]).long()
|
||||||
is_boundary = torch.isin(packed_directed_sorted, boundary_packed)
|
is_boundary = torch.isin(packed_directed_sorted, boundary_packed)
|
||||||
boundary_edges_directed = edges[is_boundary]
|
b_edges = edges[is_boundary]
|
||||||
|
|
||||||
adj = {}
|
adj = {u.item(): v_idx.item() for u, v_idx in b_edges}
|
||||||
in_deg = {}
|
|
||||||
out_deg = {}
|
|
||||||
|
|
||||||
edges_list = boundary_edges_directed.tolist()
|
|
||||||
for u, v_idx in edges_list:
|
|
||||||
if u not in adj: adj[u] = []
|
|
||||||
adj[u].append(v_idx)
|
|
||||||
out_deg[u] = out_deg.get(u, 0) + 1
|
|
||||||
in_deg[v_idx] = in_deg.get(v_idx, 0) + 1
|
|
||||||
|
|
||||||
manifold_nodes = set()
|
|
||||||
for node in set(list(in_deg.keys()) + list(out_deg.keys())):
|
|
||||||
if in_deg.get(node, 0) == 1 and out_deg.get(node, 0) == 1:
|
|
||||||
manifold_nodes.add(node)
|
|
||||||
|
|
||||||
loops =[]
|
loops =[]
|
||||||
visited_nodes = set()
|
visited = set()
|
||||||
|
|
||||||
for start_node in list(adj.keys()):
|
for start_node in adj.keys():
|
||||||
if start_node not in manifold_nodes or start_node in visited_nodes:
|
if start_node in visited:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
curr = start_node
|
curr = start_node
|
||||||
current_loop =[]
|
loop = []
|
||||||
|
|
||||||
while True:
|
while curr not in visited:
|
||||||
current_loop.append(curr)
|
visited.add(curr)
|
||||||
visited_nodes.add(curr)
|
loop.append(curr)
|
||||||
|
curr = adj.get(curr, -1)
|
||||||
|
|
||||||
next_node = adj[curr][0]
|
if curr == -1:
|
||||||
|
loop = []
|
||||||
if next_node == start_node:
|
break
|
||||||
if len(current_loop) >= 3:
|
if curr == start_node:
|
||||||
loops.append(current_loop)
|
loops.append(loop)
|
||||||
break
|
break
|
||||||
|
|
||||||
if next_node not in manifold_nodes or next_node in visited_nodes:
|
new_verts =[]
|
||||||
break
|
new_faces = []
|
||||||
|
v_idx = v.shape[0]
|
||||||
curr = next_node
|
|
||||||
|
|
||||||
if len(current_loop) > len(edges_list):
|
|
||||||
break
|
|
||||||
|
|
||||||
new_faces =[]
|
|
||||||
new_verts = []
|
|
||||||
curr_v_idx = v.shape[0]
|
|
||||||
|
|
||||||
for loop in loops:
|
for loop in loops:
|
||||||
loop_indices = torch.tensor(loop, device=device, dtype=torch.long)
|
loop_t = torch.tensor(loop, device=device, dtype=torch.long)
|
||||||
loop_points = v[loop_indices]
|
loop_v = v[loop_t]
|
||||||
|
|
||||||
# Calculate perimeter
|
diffs = loop_v - torch.roll(loop_v, shifts=-1, dims=0)
|
||||||
p1 = loop_points
|
perimeter = torch.norm(diffs, dim=1).sum().item()
|
||||||
p2 = torch.roll(loop_points, shifts=-1, dims=0)
|
|
||||||
perimeter = torch.norm(p1 - p2, dim=1).sum().item()
|
|
||||||
|
|
||||||
if perimeter <= max_perimeter:
|
if perimeter <= max_perimeter:
|
||||||
centroid = loop_points.mean(dim=0)
|
new_verts.append(loop_v.mean(dim=0))
|
||||||
new_verts.append(centroid)
|
|
||||||
center_idx = curr_v_idx
|
|
||||||
curr_v_idx += 1
|
|
||||||
|
|
||||||
for i in range(len(loop)):
|
for i in range(len(loop)):
|
||||||
u_idx = loop[i]
|
new_faces.append([loop[i], loop[(i + 1) % len(loop)], v_idx])
|
||||||
v_next_idx = loop[(i + 1) % len(loop)]
|
v_idx += 1
|
||||||
new_faces.append([u_idx, v_next_idx, center_idx])
|
|
||||||
|
|
||||||
if new_faces:
|
if new_verts:
|
||||||
v = torch.cat([v, torch.stack(new_verts)], dim=0)
|
v = torch.cat([v, torch.stack(new_verts)], dim=0)
|
||||||
f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0)
|
f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0)
|
||||||
|
|
||||||
return v, f
|
return v, f
|
||||||
|
|
||||||
def merge_duplicate_vertices(vertices, faces, tolerance=1e-5):
|
def make_double_sided(vertices, faces):
|
||||||
is_batched = vertices.ndim == 3
|
is_batched = vertices.ndim == 3
|
||||||
if is_batched:
|
if is_batched:
|
||||||
v_list, f_list = [],[]
|
f_list =[]
|
||||||
for i in range(vertices.shape[0]):
|
for i in range(faces.shape[0]):
|
||||||
v_i, f_i = merge_duplicate_vertices(vertices[i], faces[i], tolerance)
|
f_inv = faces[i][:,[0, 2, 1]]
|
||||||
v_list.append(v_i)
|
f_list.append(torch.cat([faces[i], f_inv], dim=0))
|
||||||
f_list.append(f_i)
|
return vertices, torch.stack(f_list)
|
||||||
return torch.stack(v_list), torch.stack(f_list)
|
|
||||||
|
|
||||||
v_min = vertices.min(dim=0, keepdim=True)[0]
|
faces_inv = faces[:, [0, 2, 1]]
|
||||||
v_quant = ((vertices - v_min) / tolerance).round().long()
|
faces_double = torch.cat([faces, faces_inv], dim=0)
|
||||||
|
return vertices, faces_double
|
||||||
unique_quant, inverse_indices = torch.unique(v_quant, dim=0, return_inverse=True)
|
|
||||||
|
|
||||||
new_vertices = torch.zeros((unique_quant.shape[0], 3), dtype=vertices.dtype, device=vertices.device)
|
|
||||||
new_vertices.index_copy_(0, inverse_indices, vertices)
|
|
||||||
|
|
||||||
new_faces = inverse_indices[faces.long()]
|
|
||||||
|
|
||||||
valid = (new_faces[:, 0] != new_faces[:, 1]) & \
|
|
||||||
(new_faces[:, 1] != new_faces[:, 2]) & \
|
|
||||||
(new_faces[:, 2] != new_faces[:, 0])
|
|
||||||
|
|
||||||
return new_vertices, new_faces[valid]
|
|
||||||
|
|
||||||
def fix_normals(vertices, faces):
|
|
||||||
is_batched = vertices.ndim == 3
|
|
||||||
if is_batched:
|
|
||||||
v_list, f_list = [], []
|
|
||||||
for i in range(vertices.shape[0]):
|
|
||||||
v_i, f_i = fix_normals(vertices[i], faces[i])
|
|
||||||
v_list.append(v_i)
|
|
||||||
f_list.append(f_i)
|
|
||||||
return torch.stack(v_list), torch.stack(f_list)
|
|
||||||
|
|
||||||
if faces.shape[0] == 0:
|
|
||||||
return vertices, faces
|
|
||||||
|
|
||||||
center = vertices.mean(0)
|
|
||||||
v0 = vertices[faces[:, 0].long()]
|
|
||||||
v1 = vertices[faces[:, 1].long()]
|
|
||||||
v2 = vertices[faces[:, 2].long()]
|
|
||||||
|
|
||||||
normals = torch.cross(v1 - v0, v2 - v0, dim=1)
|
|
||||||
|
|
||||||
face_centers = (v0 + v1 + v2) / 3.0
|
|
||||||
dir_from_center = face_centers - center
|
|
||||||
|
|
||||||
dot = (normals * dir_from_center).sum(1)
|
|
||||||
flip_mask = dot < 0
|
|
||||||
|
|
||||||
faces[flip_mask] = faces[flip_mask][:, [0, 2, 1]]
|
|
||||||
return vertices, faces
|
|
||||||
|
|
||||||
class PostProcessMesh(IO.ComfyNode):
|
class PostProcessMesh(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -572,7 +520,7 @@ class PostProcessMesh(IO.ComfyNode):
|
|||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Mesh.Input("mesh"),
|
IO.Mesh.Input("mesh"),
|
||||||
IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000),
|
IO.Int.Input("simplify", default=1_000_000, min=0, max=50_000_000),
|
||||||
IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001)
|
IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001)
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
@ -585,15 +533,13 @@ class PostProcessMesh(IO.ComfyNode):
|
|||||||
mesh = copy.deepcopy(mesh)
|
mesh = copy.deepcopy(mesh)
|
||||||
verts, faces = mesh.vertices, mesh.faces
|
verts, faces = mesh.vertices, mesh.faces
|
||||||
|
|
||||||
verts, faces = merge_duplicate_vertices(verts, faces, tolerance=1e-5)
|
|
||||||
|
|
||||||
if fill_holes_perimeter > 0:
|
if fill_holes_perimeter > 0:
|
||||||
verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter)
|
verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter)
|
||||||
|
|
||||||
if simplify > 0 and faces.shape[0] > simplify:
|
if simplify > 0 and faces.shape[0] > simplify:
|
||||||
verts, faces = simplify_fn(verts, faces, target=simplify)
|
verts, faces = simplify_fn(verts, faces, target=simplify)
|
||||||
|
|
||||||
verts, faces = fix_normals(verts, faces)
|
verts, faces = make_double_sided(verts, faces)
|
||||||
|
|
||||||
mesh.vertices = verts
|
mesh.vertices = verts
|
||||||
mesh.faces = faces
|
mesh.faces = faces
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user