mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-17 22:12:30 +08:00
working version
This commit is contained in:
parent
7d444a4fcc
commit
44adb27782
@ -46,7 +46,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
|
|
||||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||||
if var_length:
|
if var_length:
|
||||||
return out.contiguous().transpose(1, 2).values()
|
return out.transpose(1, 2).values()
|
||||||
if not skip_output_reshape:
|
if not skip_output_reshape:
|
||||||
out = (
|
out = (
|
||||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||||
|
|||||||
@ -767,8 +767,6 @@ class Trellis2(nn.Module):
|
|||||||
if embeds is None:
|
if embeds is None:
|
||||||
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
|
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
|
||||||
is_1024 = self.img2shape.resolution == 1024
|
is_1024 = self.img2shape.resolution == 1024
|
||||||
if is_1024:
|
|
||||||
context = embeds
|
|
||||||
coords = transformer_options.get("coords", None)
|
coords = transformer_options.get("coords", None)
|
||||||
mode = transformer_options.get("generation_mode", "structure_generation")
|
mode = transformer_options.get("generation_mode", "structure_generation")
|
||||||
if coords is not None:
|
if coords is not None:
|
||||||
@ -777,6 +775,8 @@ class Trellis2(nn.Module):
|
|||||||
else:
|
else:
|
||||||
mode = "structure_generation"
|
mode = "structure_generation"
|
||||||
not_struct_mode = False
|
not_struct_mode = False
|
||||||
|
if is_1024 and mode == "shape_generation":
|
||||||
|
context = embeds
|
||||||
sigmas = transformer_options.get("sigmas")[0].item()
|
sigmas = transformer_options.get("sigmas")[0].item()
|
||||||
if sigmas < 1.00001:
|
if sigmas < 1.00001:
|
||||||
timestep *= 1000.0
|
timestep *= 1000.0
|
||||||
|
|||||||
@ -1,9 +1,8 @@
|
|||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
from comfy_api.latest import ComfyExtension, IO, Types
|
from comfy_api.latest import ComfyExtension, IO, Types
|
||||||
from comfy.ldm.trellis2.vae import SparseTensor
|
from comfy.ldm.trellis2.vae import SparseTensor
|
||||||
from comfy.utils import ProgressBar, lanczos
|
|
||||||
import torch.nn.functional as TF
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import logging
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -39,93 +38,6 @@ tex_slat_normalization = {
|
|||||||
])[None]
|
])[None]
|
||||||
}
|
}
|
||||||
|
|
||||||
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
|
||||||
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
|
||||||
|
|
||||||
def smart_crop_square(image, mask, margin_ratio=0.1, bg_color=(128, 128, 128)):
|
|
||||||
nz = torch.nonzero(mask[0] > 0.5)
|
|
||||||
if nz.shape[0] == 0:
|
|
||||||
C, H, W = image.shape
|
|
||||||
side = max(H, W)
|
|
||||||
canvas = torch.full((C, side, side), 0.5, device=image.device) # Gray
|
|
||||||
canvas[:, (side-H)//2:(side-H)//2+H, (side-W)//2:(side-W)//2+W] = image
|
|
||||||
return canvas
|
|
||||||
|
|
||||||
y_min, x_min = nz.min(dim=0)[0]
|
|
||||||
y_max, x_max = nz.max(dim=0)[0]
|
|
||||||
|
|
||||||
obj_w, obj_h = x_max - x_min, y_max - y_min
|
|
||||||
center_x, center_y = (x_min + x_max) / 2, (y_min + y_max) / 2
|
|
||||||
|
|
||||||
side = int(max(obj_w, obj_h) * (1 + margin_ratio * 2))
|
|
||||||
half_side = side / 2
|
|
||||||
|
|
||||||
x1, y1 = int(center_x - half_side), int(center_y - half_side)
|
|
||||||
x2, y2 = x1 + side, y1 + side
|
|
||||||
|
|
||||||
C, H, W = image.shape
|
|
||||||
canvas = torch.ones((C, side, side), device=image.device)
|
|
||||||
for c in range(C):
|
|
||||||
canvas[c] *= (bg_color[c] / 255.0)
|
|
||||||
|
|
||||||
src_x1, src_y1 = max(0, x1), max(0, y1)
|
|
||||||
src_x2, src_y2 = min(W, x2), min(H, y2)
|
|
||||||
|
|
||||||
dst_x1, dst_y1 = max(0, -x1), max(0, -y1)
|
|
||||||
dst_x2 = dst_x1 + (src_x2 - src_x1)
|
|
||||||
dst_y2 = dst_y1 + (src_y2 - src_y1)
|
|
||||||
|
|
||||||
img_crop = image[:, src_y1:src_y2, src_x1:src_x2]
|
|
||||||
mask_crop = mask[0, src_y1:src_y2, src_x1:src_x2]
|
|
||||||
|
|
||||||
bg_val = torch.tensor(bg_color, device=image.device, dtype=image.dtype).view(-1, 1, 1) / 255.0
|
|
||||||
|
|
||||||
masked_crop = img_crop * mask_crop + bg_val * (1.0 - mask_crop)
|
|
||||||
|
|
||||||
canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = masked_crop
|
|
||||||
|
|
||||||
return canvas
|
|
||||||
|
|
||||||
def run_conditioning(model, image, mask, include_1024 = True, background_color = "black"):
|
|
||||||
model_internal = model.model
|
|
||||||
device = comfy.model_management.intermediate_device()
|
|
||||||
torch_device = comfy.model_management.get_torch_device()
|
|
||||||
|
|
||||||
bg_colors = {"black": (0, 0, 0), "gray": (128, 128, 128), "white": (255, 255, 255)}
|
|
||||||
bg_rgb = bg_colors.get(background_color, (128, 128, 128))
|
|
||||||
|
|
||||||
img_t = image[0].movedim(-1, 0).to(torch_device).float()
|
|
||||||
mask_t = mask[0].to(torch_device).float()
|
|
||||||
if mask_t.ndim == 2:
|
|
||||||
mask_t = mask_t.unsqueeze(0)
|
|
||||||
|
|
||||||
cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb)
|
|
||||||
|
|
||||||
def prepare_tensor(img, size):
|
|
||||||
resized = lanczos(img.unsqueeze(0), size, size)
|
|
||||||
return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
|
||||||
|
|
||||||
model_internal.image_size = 512
|
|
||||||
input_512 = prepare_tensor(cropped_img, 512)
|
|
||||||
cond_512 = model_internal(input_512)[0]
|
|
||||||
|
|
||||||
cond_1024 = None
|
|
||||||
if include_1024:
|
|
||||||
model_internal.image_size = 1024
|
|
||||||
input_1024 = prepare_tensor(cropped_img, 1024)
|
|
||||||
cond_1024 = model_internal(input_1024)[0]
|
|
||||||
|
|
||||||
conditioning = {
|
|
||||||
'cond_512': cond_512.to(device),
|
|
||||||
'neg_cond': torch.zeros_like(cond_512).to(device),
|
|
||||||
}
|
|
||||||
if cond_1024 is not None:
|
|
||||||
conditioning['cond_1024'] = cond_1024.to(device)
|
|
||||||
|
|
||||||
preprocessed_tensor = cropped_img.movedim(0, -1).unsqueeze(0).cpu()
|
|
||||||
|
|
||||||
return conditioning, preprocessed_tensor
|
|
||||||
|
|
||||||
class VaeDecodeShapeTrellis(IO.ComfyNode):
|
class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -245,6 +157,39 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
|||||||
out = Types.VOXEL(decoded.squeeze(1).float())
|
out = Types.VOXEL(decoded.squeeze(1).float())
|
||||||
return IO.NodeOutput(out)
|
return IO.NodeOutput(out)
|
||||||
|
|
||||||
|
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
|
||||||
|
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
|
||||||
|
|
||||||
|
def run_conditioning(model, cropped_img_tensor, include_1024=True):
|
||||||
|
model_internal = model.model
|
||||||
|
device = comfy.model_management.intermediate_device()
|
||||||
|
torch_device = comfy.model_management.get_torch_device()
|
||||||
|
|
||||||
|
img_t = cropped_img_tensor.to(torch_device)
|
||||||
|
|
||||||
|
def prepare_tensor(img, size):
|
||||||
|
resized = torch.nn.functional.interpolate(img, size=(size, size), mode='bicubic', align_corners=False).clamp(0.0, 1.0)
|
||||||
|
return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
||||||
|
|
||||||
|
model_internal.image_size = 512
|
||||||
|
input_512 = prepare_tensor(img_t, 512)
|
||||||
|
cond_512 = model_internal(input_512)[0]
|
||||||
|
|
||||||
|
cond_1024 = None
|
||||||
|
if include_1024:
|
||||||
|
model_internal.image_size = 1024
|
||||||
|
input_1024 = prepare_tensor(img_t, 1024)
|
||||||
|
cond_1024 = model_internal(input_1024)[0]
|
||||||
|
|
||||||
|
conditioning = {
|
||||||
|
'cond_512': cond_512.to(device),
|
||||||
|
'neg_cond': torch.zeros_like(cond_512).to(device),
|
||||||
|
}
|
||||||
|
if cond_1024 is not None:
|
||||||
|
conditioning['cond_1024'] = cond_1024.to(device)
|
||||||
|
|
||||||
|
return conditioning
|
||||||
|
|
||||||
class Trellis2Conditioning(IO.ComfyNode):
|
class Trellis2Conditioning(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -268,22 +213,60 @@ class Trellis2Conditioning(IO.ComfyNode):
|
|||||||
|
|
||||||
if image.ndim == 4:
|
if image.ndim == 4:
|
||||||
image = image[0]
|
image = image[0]
|
||||||
|
if mask.ndim == 3:
|
||||||
|
mask = mask[0]
|
||||||
|
|
||||||
# TODO
|
img_np = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||||
image = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
mask_np = (mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||||
image = Image.fromarray(image)
|
|
||||||
max_size = max(image.size)
|
|
||||||
scale = min(1, 1024 / max_size)
|
|
||||||
if scale < 1:
|
|
||||||
image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS)
|
|
||||||
new_h, new_w = int(mask.shape[-2] * scale), int(mask.shape[-1] * scale)
|
|
||||||
mask = TF.interpolate(mask.unsqueeze(0).float(), size=(new_h, new_w), mode='nearest').squeeze(0)
|
|
||||||
|
|
||||||
image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255
|
pil_img = Image.fromarray(img_np)
|
||||||
|
pil_mask = Image.fromarray(mask_np)
|
||||||
|
|
||||||
# could make 1024 an option
|
max_size = max(pil_img.size)
|
||||||
conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)
|
scale = min(1.0, 1024 / max_size)
|
||||||
embeds = conditioning["cond_1024"] # should add that
|
if scale < 1.0:
|
||||||
|
new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
|
||||||
|
pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||||
|
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
|
||||||
|
|
||||||
|
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
|
||||||
|
rgba_np[:, :, :3] = np.array(pil_img)
|
||||||
|
rgba_np[:, :, 3] = np.array(pil_mask)
|
||||||
|
|
||||||
|
alpha = rgba_np[:, :, 3]
|
||||||
|
bbox_coords = np.argwhere(alpha > 0.8 * 255)
|
||||||
|
|
||||||
|
if len(bbox_coords) > 0:
|
||||||
|
y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
|
||||||
|
y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
|
||||||
|
|
||||||
|
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
|
||||||
|
size = max(y_max - y_min, x_max - x_min)
|
||||||
|
|
||||||
|
crop_x1 = int(center_x - size // 2)
|
||||||
|
crop_y1 = int(center_y - size // 2)
|
||||||
|
crop_x2 = int(center_x + size // 2)
|
||||||
|
crop_y2 = int(center_y + size // 2)
|
||||||
|
|
||||||
|
rgba_pil = Image.fromarray(rgba_np, 'RGBA')
|
||||||
|
cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
|
||||||
|
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
|
||||||
|
else:
|
||||||
|
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
|
||||||
|
cropped_np = rgba_np.astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
bg_colors = {"black": [0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]}
|
||||||
|
bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32)
|
||||||
|
|
||||||
|
fg = cropped_np[:, :, :3]
|
||||||
|
alpha_float = cropped_np[:, :, 3:4]
|
||||||
|
composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
|
||||||
|
|
||||||
|
cropped_img_tensor = torch.from_numpy(composite_np).movedim(-1, 0).unsqueeze(0).float()
|
||||||
|
|
||||||
|
conditioning = run_conditioning(clip_vision_model, cropped_img_tensor, include_1024=True)
|
||||||
|
|
||||||
|
embeds = conditioning["cond_1024"]
|
||||||
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
||||||
negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]]
|
negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]]
|
||||||
return IO.NodeOutput(positive, negative)
|
return IO.NodeOutput(positive, negative)
|
||||||
@ -417,118 +400,168 @@ def simplify_fn(vertices, faces, target=100000):
|
|||||||
|
|
||||||
return final_vertices, final_faces
|
return final_vertices, final_faces
|
||||||
|
|
||||||
def fill_holes_fn(vertices, faces, max_hole_perimeter=3e-2):
|
def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
||||||
is_batched = vertices.ndim == 3
|
is_batched = vertices.ndim == 3
|
||||||
if is_batched:
|
if is_batched:
|
||||||
batch_size = vertices.shape[0]
|
v_list, f_list = [],[]
|
||||||
if batch_size > 1:
|
for i in range(vertices.shape[0]):
|
||||||
v_out, f_out = [], []
|
v_i, f_i = fill_holes_fn(vertices[i], faces[i], max_perimeter)
|
||||||
for i in range(batch_size):
|
v_list.append(v_i)
|
||||||
v, f = fill_holes_fn(vertices[i], faces[i], max_hole_perimeter)
|
f_list.append(f_i)
|
||||||
v_out.append(v)
|
return torch.stack(v_list), torch.stack(f_list)
|
||||||
f_out.append(f)
|
|
||||||
return torch.stack(v_out), torch.stack(f_out)
|
|
||||||
|
|
||||||
vertices = vertices.squeeze(0)
|
|
||||||
faces = faces.squeeze(0)
|
|
||||||
|
|
||||||
device = vertices.device
|
device = vertices.device
|
||||||
orig_vertices = vertices
|
v = vertices
|
||||||
orig_faces = faces
|
f = faces
|
||||||
|
|
||||||
edges = torch.cat([
|
if f.shape[0] == 0:
|
||||||
faces[:, [0, 1]],
|
return v, f
|
||||||
faces[:, [1, 2]],
|
|
||||||
faces[:, [2, 0]]
|
|
||||||
], dim=0)
|
|
||||||
|
|
||||||
|
edges = torch.cat([f[:, [0, 1]], f[:, [1, 2]], f[:, [2, 0]]], dim=0)
|
||||||
edges_sorted, _ = torch.sort(edges, dim=1)
|
edges_sorted, _ = torch.sort(edges, dim=1)
|
||||||
unique_edges, counts = torch.unique(edges_sorted, dim=0, return_counts=True)
|
|
||||||
|
max_v = v.shape[0]
|
||||||
|
packed_undirected = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long()
|
||||||
|
|
||||||
|
unique_packed, counts = torch.unique(packed_undirected, return_counts=True)
|
||||||
boundary_mask = counts == 1
|
boundary_mask = counts == 1
|
||||||
boundary_edges_sorted = unique_edges[boundary_mask]
|
boundary_packed = unique_packed[boundary_mask]
|
||||||
|
|
||||||
if boundary_edges_sorted.shape[0] == 0:
|
if boundary_packed.numel() == 0:
|
||||||
if is_batched:
|
return v, f
|
||||||
return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0)
|
|
||||||
return orig_vertices, orig_faces
|
|
||||||
|
|
||||||
max_idx = vertices.shape[0]
|
packed_directed_sorted = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long()
|
||||||
|
is_boundary = torch.isin(packed_directed_sorted, boundary_packed)
|
||||||
packed_edges_all = torch.sort(edges, dim=1).values
|
boundary_edges_directed = edges[is_boundary]
|
||||||
packed_edges_all = packed_edges_all[:, 0] * max_idx + packed_edges_all[:, 1]
|
|
||||||
|
|
||||||
packed_boundary = boundary_edges_sorted[:, 0] * max_idx + boundary_edges_sorted[:, 1]
|
|
||||||
|
|
||||||
is_boundary_edge = torch.isin(packed_edges_all, packed_boundary)
|
|
||||||
active_boundary_edges = edges[is_boundary_edge]
|
|
||||||
|
|
||||||
adj = {}
|
adj = {}
|
||||||
edges_np = active_boundary_edges.cpu().numpy()
|
in_deg = {}
|
||||||
for u, v in edges_np:
|
out_deg = {}
|
||||||
adj[u] = v
|
|
||||||
|
edges_list = boundary_edges_directed.tolist()
|
||||||
|
for u, v_idx in edges_list:
|
||||||
|
if u not in adj: adj[u] = []
|
||||||
|
adj[u].append(v_idx)
|
||||||
|
out_deg[u] = out_deg.get(u, 0) + 1
|
||||||
|
in_deg[v_idx] = in_deg.get(v_idx, 0) + 1
|
||||||
|
|
||||||
|
manifold_nodes = set()
|
||||||
|
for node in set(list(in_deg.keys()) + list(out_deg.keys())):
|
||||||
|
if in_deg.get(node, 0) == 1 and out_deg.get(node, 0) == 1:
|
||||||
|
manifold_nodes.add(node)
|
||||||
|
|
||||||
|
loops =[]
|
||||||
|
visited_nodes = set()
|
||||||
|
|
||||||
loops = []
|
|
||||||
visited_edges = set()
|
|
||||||
processed_nodes = set()
|
|
||||||
for start_node in list(adj.keys()):
|
for start_node in list(adj.keys()):
|
||||||
if start_node in processed_nodes:
|
if start_node not in manifold_nodes or start_node in visited_nodes:
|
||||||
continue
|
continue
|
||||||
current_loop, curr = [], start_node
|
|
||||||
while curr in adj:
|
curr = start_node
|
||||||
next_node = adj[curr]
|
current_loop =[]
|
||||||
if (curr, next_node) in visited_edges:
|
|
||||||
break
|
while True:
|
||||||
visited_edges.add((curr, next_node))
|
|
||||||
processed_nodes.add(curr)
|
|
||||||
current_loop.append(curr)
|
current_loop.append(curr)
|
||||||
|
visited_nodes.add(curr)
|
||||||
|
|
||||||
|
next_node = adj[curr][0]
|
||||||
|
|
||||||
|
if next_node == start_node:
|
||||||
|
if len(current_loop) >= 3:
|
||||||
|
loops.append(current_loop)
|
||||||
|
break
|
||||||
|
|
||||||
|
if next_node not in manifold_nodes or next_node in visited_nodes:
|
||||||
|
break
|
||||||
|
|
||||||
curr = next_node
|
curr = next_node
|
||||||
if curr == start_node:
|
|
||||||
loops.append(current_loop)
|
if len(current_loop) > len(edges_list):
|
||||||
break
|
|
||||||
if len(current_loop) > len(edges_np):
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if not loops:
|
new_faces =[]
|
||||||
if is_batched:
|
new_verts = []
|
||||||
return orig_vertices.unsqueeze(0), orig_faces.unsqueeze(0)
|
curr_v_idx = v.shape[0]
|
||||||
return orig_vertices, orig_faces
|
|
||||||
|
|
||||||
new_faces = []
|
for loop in loops:
|
||||||
v_offset = vertices.shape[0]
|
loop_indices = torch.tensor(loop, device=device, dtype=torch.long)
|
||||||
valid_new_verts = []
|
loop_points = v[loop_indices]
|
||||||
|
|
||||||
for loop_indices in loops:
|
# Calculate perimeter
|
||||||
if len(loop_indices) < 3:
|
p1 = loop_points
|
||||||
continue
|
p2 = torch.roll(loop_points, shifts=-1, dims=0)
|
||||||
loop_tensor = torch.tensor(loop_indices, dtype=torch.long, device=device)
|
perimeter = torch.norm(p1 - p2, dim=1).sum().item()
|
||||||
loop_verts = vertices[loop_tensor]
|
|
||||||
diffs = loop_verts - torch.roll(loop_verts, shifts=-1, dims=0)
|
|
||||||
perimeter = torch.norm(diffs, dim=1).sum()
|
|
||||||
|
|
||||||
if perimeter > max_hole_perimeter:
|
if perimeter <= max_perimeter:
|
||||||
continue
|
centroid = loop_points.mean(dim=0)
|
||||||
|
new_verts.append(centroid)
|
||||||
|
center_idx = curr_v_idx
|
||||||
|
curr_v_idx += 1
|
||||||
|
|
||||||
center = loop_verts.mean(dim=0)
|
for i in range(len(loop)):
|
||||||
valid_new_verts.append(center)
|
u_idx = loop[i]
|
||||||
c_idx = v_offset
|
v_next_idx = loop[(i + 1) % len(loop)]
|
||||||
v_offset += 1
|
new_faces.append([u_idx, v_next_idx, center_idx])
|
||||||
|
|
||||||
num_v = len(loop_indices)
|
if new_faces:
|
||||||
for i in range(num_v):
|
v = torch.cat([v, torch.stack(new_verts)], dim=0)
|
||||||
v_curr, v_next = loop_indices[i], loop_indices[(i + 1) % num_v]
|
f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0)
|
||||||
new_faces.append([v_curr, v_next, c_idx])
|
|
||||||
|
|
||||||
if len(valid_new_verts) > 0:
|
return v, f
|
||||||
added_vertices = torch.stack(valid_new_verts, dim=0)
|
|
||||||
added_faces = torch.tensor(new_faces, dtype=torch.long, device=device)
|
|
||||||
vertices = torch.cat([orig_vertices, added_vertices], dim=0)
|
|
||||||
faces = torch.cat([orig_faces, added_faces], dim=0)
|
|
||||||
else:
|
|
||||||
vertices, faces = orig_vertices, orig_faces
|
|
||||||
|
|
||||||
|
def merge_duplicate_vertices(vertices, faces, tolerance=1e-5):
|
||||||
|
is_batched = vertices.ndim == 3
|
||||||
if is_batched:
|
if is_batched:
|
||||||
return vertices.unsqueeze(0), faces.unsqueeze(0)
|
v_list, f_list = [],[]
|
||||||
|
for i in range(vertices.shape[0]):
|
||||||
|
v_i, f_i = merge_duplicate_vertices(vertices[i], faces[i], tolerance)
|
||||||
|
v_list.append(v_i)
|
||||||
|
f_list.append(f_i)
|
||||||
|
return torch.stack(v_list), torch.stack(f_list)
|
||||||
|
|
||||||
|
v_min = vertices.min(dim=0, keepdim=True)[0]
|
||||||
|
v_quant = ((vertices - v_min) / tolerance).round().long()
|
||||||
|
|
||||||
|
unique_quant, inverse_indices = torch.unique(v_quant, dim=0, return_inverse=True)
|
||||||
|
|
||||||
|
new_vertices = torch.zeros((unique_quant.shape[0], 3), dtype=vertices.dtype, device=vertices.device)
|
||||||
|
new_vertices.index_copy_(0, inverse_indices, vertices)
|
||||||
|
|
||||||
|
new_faces = inverse_indices[faces.long()]
|
||||||
|
|
||||||
|
valid = (new_faces[:, 0] != new_faces[:, 1]) & \
|
||||||
|
(new_faces[:, 1] != new_faces[:, 2]) & \
|
||||||
|
(new_faces[:, 2] != new_faces[:, 0])
|
||||||
|
|
||||||
|
return new_vertices, new_faces[valid]
|
||||||
|
|
||||||
|
def fix_normals(vertices, faces):
|
||||||
|
is_batched = vertices.ndim == 3
|
||||||
|
if is_batched:
|
||||||
|
v_list, f_list = [], []
|
||||||
|
for i in range(vertices.shape[0]):
|
||||||
|
v_i, f_i = fix_normals(vertices[i], faces[i])
|
||||||
|
v_list.append(v_i)
|
||||||
|
f_list.append(f_i)
|
||||||
|
return torch.stack(v_list), torch.stack(f_list)
|
||||||
|
|
||||||
|
if faces.shape[0] == 0:
|
||||||
|
return vertices, faces
|
||||||
|
|
||||||
|
center = vertices.mean(0)
|
||||||
|
v0 = vertices[faces[:, 0].long()]
|
||||||
|
v1 = vertices[faces[:, 1].long()]
|
||||||
|
v2 = vertices[faces[:, 2].long()]
|
||||||
|
|
||||||
|
normals = torch.cross(v1 - v0, v2 - v0, dim=1)
|
||||||
|
|
||||||
|
face_centers = (v0 + v1 + v2) / 3.0
|
||||||
|
dir_from_center = face_centers - center
|
||||||
|
|
||||||
|
dot = (normals * dir_from_center).sum(1)
|
||||||
|
flip_mask = dot < 0
|
||||||
|
|
||||||
|
faces[flip_mask] = faces[flip_mask][:, [0, 2, 1]]
|
||||||
return vertices, faces
|
return vertices, faces
|
||||||
|
|
||||||
class PostProcessMesh(IO.ComfyNode):
|
class PostProcessMesh(IO.ComfyNode):
|
||||||
@ -539,36 +572,31 @@ class PostProcessMesh(IO.ComfyNode):
|
|||||||
category="latent/3d",
|
category="latent/3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Mesh.Input("mesh"),
|
IO.Mesh.Input("mesh"),
|
||||||
IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000), # max?
|
IO.Int.Input("simplify", default=100_000, min=0, max=50_000_000),
|
||||||
IO.Float.Input("fill_holes_perimeter", default=0.003, min=0.0, step=0.0001)
|
IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001)
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.Mesh.Output("output_mesh"),
|
IO.Mesh.Output("output_mesh"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, mesh, simplify, fill_holes_perimeter):
|
def execute(cls, mesh, simplify, fill_holes_perimeter):
|
||||||
bar = ProgressBar(2)
|
|
||||||
mesh = copy.deepcopy(mesh)
|
mesh = copy.deepcopy(mesh)
|
||||||
verts, faces = mesh.vertices, mesh.faces
|
verts, faces = mesh.vertices, mesh.faces
|
||||||
|
|
||||||
if fill_holes_perimeter != 0.0:
|
verts, faces = merge_duplicate_vertices(verts, faces, tolerance=1e-5)
|
||||||
verts, faces = fill_holes_fn(verts, faces, max_hole_perimeter=fill_holes_perimeter)
|
|
||||||
bar.update(1)
|
|
||||||
else:
|
|
||||||
bar.update(1)
|
|
||||||
|
|
||||||
if simplify != 0:
|
if fill_holes_perimeter > 0:
|
||||||
verts, faces = simplify_fn(verts, faces, simplify)
|
verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter)
|
||||||
bar.update(1)
|
|
||||||
else:
|
|
||||||
bar.update(1)
|
|
||||||
|
|
||||||
# potentially adding laplacian smoothing
|
if simplify > 0 and faces.shape[0] > simplify:
|
||||||
|
verts, faces = simplify_fn(verts, faces, target=simplify)
|
||||||
|
|
||||||
|
verts, faces = fix_normals(verts, faces)
|
||||||
|
|
||||||
mesh.vertices = verts
|
mesh.vertices = verts
|
||||||
mesh.faces = faces
|
mesh.faces = faces
|
||||||
|
|
||||||
return IO.NodeOutput(mesh)
|
return IO.NodeOutput(mesh)
|
||||||
|
|
||||||
class Trellis2Extension(ComfyExtension):
|
class Trellis2Extension(ComfyExtension):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user