mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-27 09:27:24 +08:00
Vertex Clustering, Mask Fix, Normal Fix (#14035)
* Vertex Clustering, Mask Fix, Normal Fix
* Detect inverted mask
* Update the decimate mesh
This commit is contained in:
parent
7547897b6f
commit
e90bde2f82
@ -625,7 +625,6 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, normals_in, target_faces,
|
||||
# Finalize
|
||||
final_v = verts[v_alive]
|
||||
final_c = colors[v_alive] if colors is not None else None
|
||||
final_n = normals[v_alive] if normals is not None else None
|
||||
|
||||
remap = torch.full((num_verts,), -1, dtype=torch.int64, device=device)
|
||||
remap[v_alive] = torch.arange(int(v_alive.sum().item()), device=device)
|
||||
@ -640,26 +639,9 @@ def _qem_simplify_fast(vertices, faces_in, colors_in, normals_in, target_faces,
|
||||
if final_f.numel() > 0:
|
||||
final_f = torch.unique(torch.sort(final_f, dim=1)[0], dim=0)
|
||||
|
||||
if final_n is not None and final_f.numel() > 0:
|
||||
v0, v1, v2 = final_v[final_f[:, 0]], final_v[final_f[:, 1]], final_v[final_f[:, 2]]
|
||||
|
||||
# calculate the actual normal of the simplified faces
|
||||
face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
|
||||
|
||||
# Get the average reference normal for each face
|
||||
n0, n1, n2 = final_n[final_f[:, 0]], final_n[final_f[:, 1]], final_n[final_f[:, 2]]
|
||||
ref_face_normals = (n0 + n1 + n2) / 3.0
|
||||
|
||||
# Dot product to check if they point in the same direction
|
||||
dot_products = (face_normals * ref_face_normals).sum(dim=-1)
|
||||
|
||||
# Flip the indices of ONLY the incorrect faces (swap vertex 1 and 2)
|
||||
wrong_way_mask = dot_products < 0
|
||||
final_f[wrong_way_mask] = final_f[wrong_way_mask][:, [0, 2, 1]]
|
||||
|
||||
final_v, final_f = _cleanup_mesh(final_v, final_f, min_angle_deg=0.5, max_aspect=100.0)
|
||||
|
||||
return final_v, final_f, final_c, final_n
|
||||
return final_v, final_f, final_c, None
|
||||
|
||||
|
||||
def simplify_fn_fast(vertices, faces, colors=None, normals=None, target=100000, max_edge_length=None):
|
||||
@ -709,6 +691,79 @@ def simplify_fn_fast(vertices, faces, colors=None, normals=None, target=100000,
|
||||
)
|
||||
return final_v, final_f, final_c, final_n
|
||||
|
||||
def simplify_fn_vertex(vertices, faces, colors=None, target=100000):
    """Simplify a triangle mesh by uniform-grid vertex clustering.

    Vertices are snapped to a regular grid whose cell size is derived from the
    bounding-box volume and ``target``; all vertices that land in the same
    cell are merged into their mean position (and mean color). Faces that
    collapse to fewer than three distinct cells are dropped, and vertices no
    longer referenced by any face are removed.

    Args:
        vertices: ``(V, 3)`` tensor, or a 3-D tensor whose first axis is a
            batch of such meshes.
        faces: integer tensor of vertex indices, one row of three per face
            (with a matching leading batch axis when ``vertices`` is 3-D).
        colors: optional 2-D per-vertex tensor aligned with ``vertices``.
        target: face budget. A mesh with ``faces.shape[0] <= target`` is
            returned unchanged; otherwise the grid aims for about
            ``target / 4`` cells.

    Returns:
        ``(vertices, faces, colors)`` on the input device and with the input
        dtypes; ``colors`` is ``None`` when no colors were given.

    Note:
        The batched path re-stacks the per-item results with ``torch.stack``,
        which requires every simplified item to have identical shapes.
    """
    if vertices.ndim == 3:
        v_list, f_list, c_list = [], [], []
        for i in range(vertices.shape[0]):
            c_in = colors[i] if colors is not None else None
            v_i, f_i, c_i = simplify_fn_vertex(vertices[i], faces[i], c_in, target)
            v_list.append(v_i)
            f_list.append(f_i)
            if c_i is not None:
                c_list.append(c_i)

        c_out = torch.stack(c_list) if c_list else None
        return torch.stack(v_list), torch.stack(f_list), c_out

    if faces.shape[0] <= target:
        return vertices, faces, colors

    device = vertices.device
    target_v = max(target / 4.0, 1.0)

    min_v = vertices.min(dim=0)[0]
    max_v = vertices.max(dim=0)[0]
    extent = max_v - min_v

    # Clamp so a degenerate (flat) bounding box still yields a positive cell size.
    volume = (extent[0] * extent[1] * extent[2]).clamp(min=1e-8)
    cell_size = (volume / target_v) ** (1 / 3.0)

    # Use CPU-side ordered reductions here so repeated runs produce identical
    # simplified meshes instead of relying on GPU scatter-add accumulation order.
    vertices_np = vertices.detach().cpu().numpy()
    faces_np = faces.detach().cpu().numpy()
    colors_np = colors.detach().cpu().numpy() if colors is not None else None
    min_v_np = min_v.detach().cpu().numpy()
    cell_size_value = float(cell_size.detach().cpu())

    quantized = np.rint((vertices_np - min_v_np) / cell_size_value).astype(np.int64)
    unique_coords, inverse_indices = np.unique(quantized, axis=0, return_inverse=True)
    # NumPy 2.0.0 returns the inverse with shape (N, 1) when ``axis`` is given;
    # flatten so the indexing below behaves the same on every NumPy version.
    inverse_indices = inverse_indices.reshape(-1)
    num_cells = unique_coords.shape[0]

    # Accumulate in at least float32 so half-precision input does not lose
    # precision while summing many vertices into one cell.
    acc_dtype = np.result_type(vertices_np.dtype, np.float32)
    new_vertices_np = np.zeros((num_cells, 3), dtype=acc_dtype)
    np.add.at(new_vertices_np, inverse_indices, vertices_np)

    counts_np = np.bincount(inverse_indices, minlength=num_cells).astype(acc_dtype).reshape(-1, 1)
    new_vertices_np = new_vertices_np / np.clip(counts_np, 1, None)

    new_colors = None
    if colors_np is not None:
        # Summing in the color dtype would wrap around for integer colors
        # (e.g. uint8), so accumulate in a floating dtype instead.
        color_acc_dtype = np.result_type(colors_np.dtype, np.float32)
        color_sums = np.zeros((num_cells, colors_np.shape[1]), dtype=color_acc_dtype)
        np.add.at(color_sums, inverse_indices, colors_np)
        new_colors = color_sums / np.clip(counts_np, 1, None)
        if not np.issubdtype(colors_np.dtype, np.floating):
            # Round rather than truncate when casting back to an integer dtype.
            new_colors = np.rint(new_colors)

    # Re-index faces to cell ids and drop faces that collapsed onto an edge/point.
    new_faces = inverse_indices[faces_np]
    valid_mask = (
        (new_faces[:, 0] != new_faces[:, 1])
        & (new_faces[:, 1] != new_faces[:, 2])
        & (new_faces[:, 2] != new_faces[:, 0])
    )
    new_faces = new_faces[valid_mask]

    if new_faces.size == 0:
        final_vertices_np = new_vertices_np[:0]
        final_faces_np = np.empty((0, 3), dtype=np.int64)
        final_colors_np = new_colors[:0] if new_colors is not None else None
    else:
        # Compact the vertex list down to the cells still used by a face.
        unique_face_indices, inv_face = np.unique(new_faces.reshape(-1), return_inverse=True)
        final_vertices_np = new_vertices_np[unique_face_indices]
        final_faces_np = inv_face.reshape(-1, 3).astype(np.int64)
        final_colors_np = new_colors[unique_face_indices] if new_colors is not None else None

    final_vertices = torch.from_numpy(final_vertices_np).to(device=device, dtype=vertices.dtype)
    final_faces = torch.from_numpy(final_faces_np).to(device=device, dtype=faces.dtype)
    final_colors = (
        torch.from_numpy(final_colors_np).to(device=device, dtype=colors.dtype)
        if final_colors_np is not None
        else None
    )

    return final_vertices, final_faces, final_colors
|
||||
|
||||
def compute_vertex_normals(verts, faces):
|
||||
"""Computes area-weighted vertex normals."""
|
||||
# QUICK FIX: Ensure indices are int64 for scatter_add_
|
||||
@ -781,6 +836,235 @@ def _process_mesh_batch(mesh, per_item_fn):
|
||||
return IO.NodeOutput(mesh)
|
||||
|
||||
|
||||
def _orient_components_bfs(adj_ptr, adj_neighbors, adj_flip, num_faces):
    """Label edge-connected face components and choose flips that make each one consistent.

    Args:
        adj_ptr: CSR row pointer of length ``num_faces + 1``.
        adj_neighbors: neighbour face id for every CSR entry.
        adj_flip: bool per CSR entry; True when the two faces traverse their
            shared edge in the same direction (i.e. disagree in winding).
        num_faces: total number of faces.

    Returns:
        ``(flip_state, component_id)`` NumPy arrays of length ``num_faces``.
    """
    visited = np.zeros(num_faces, dtype=bool)
    flip_state = np.zeros(num_faces, dtype=bool)
    component_id = np.full(num_faces, -1, dtype=np.int64)
    # Every face is enqueued at most once per component, so num_faces slots suffice
    # (the caller guarantees the adjacency holds no duplicate neighbour entries).
    queue = np.empty(num_faces, dtype=np.int64)
    comp_counter = 0

    for seed in range(num_faces):
        if visited[seed]:
            continue

        visited[seed] = True
        component_id[seed] = comp_counter
        queue[0] = seed
        q_head, q_tail = 0, 1

        while q_head < q_tail:
            current = queue[q_head]
            q_head += 1

            start = adj_ptr[current]
            end = adj_ptr[current + 1]
            if start == end:
                continue

            nbrs = adj_neighbors[start:end]
            unvisited_mask = ~visited[nbrs]
            if not unvisited_mask.any():
                continue

            nbrs_new = nbrs[unvisited_mask]
            flips_new = adj_flip[start:end][unvisited_mask]

            visited[nbrs_new] = True
            component_id[nbrs_new] = comp_counter
            # A neighbour flips iff exactly one of (source flipped, edge disagrees) holds.
            flip_state[nbrs_new] = flips_new ^ flip_state[current]

            n_new = nbrs_new.size
            queue[q_tail:q_tail + n_new] = nbrs_new
            q_tail += n_new

        comp_counter += 1

    return flip_state, component_id


def fix_face_orientation(vertices, faces, reference_normals=None):
    """Make triangle winding consistent and, per connected component, outward-facing.

    The function works in two stages:

    1. Faces that share an edge used by exactly two faces are linked. A
       breadth-first walk over those links flips faces so that neighbours
       traverse their shared edge in opposite directions.
    2. Each connected component is then flipped as a whole if it appears to
       point inward. With ``reference_normals`` the decision is a vote of
       face normals against the averaged per-vertex reference normals;
       without it, the face with the largest centroid coordinate on each
       axis votes, and the component is flipped unless at least two of the
       three axes see a positive normal component.

    Args:
        vertices: ``(V, 3)`` tensor of positions.
        faces: ``(F, 3)`` integer tensor of vertex indices.
        reference_normals: optional per-vertex normals, indexed by the same
            vertex ids as ``vertices``.

    Returns:
        A tensor shaped and typed like ``faces`` with corrected winding. The
        input tensor itself is returned when it has no faces.
    """
    num_faces = faces.shape[0]
    if num_faces == 0:
        return faces

    device = faces.device
    # Work in int64: the edge hash below multiplies two vertex ids and would
    # overflow a 32-bit face tensor on large meshes.
    corrected = faces.clone().to(torch.int64)

    idx = torch.tensor([[0, 1], [1, 2], [2, 0]], dtype=torch.int64, device=device)
    edges = corrected[:, idx]  # (num_faces, 3, 2) directed edges

    # Undirected key per edge: (min, max) packed into one integer.
    edges_canon = torch.sort(edges, dim=2)[0]
    edges_flat = edges_canon.reshape(-1, 2)
    max_vert = vertices.shape[0]
    edge_hash = edges_flat[:, 0] * max_vert + edges_flat[:, 1]

    hash_sorted, sort_idx = torch.sort(edge_hash)

    # Runs of equal hashes are the faces incident to one undirected edge.
    _, run_lengths = torch.unique_consecutive(hash_sorted, return_counts=True)
    run_starts = torch.cumsum(run_lengths, dim=0) - run_lengths
    manifold_starts = run_starts[run_lengths == 2]

    if manifold_starts.numel() > 0:
        slot_a = sort_idx[manifold_starts]
        slot_b = sort_idx[manifold_starts + 1]
        f_a = slot_a // 3
        f_b = slot_b // 3

        dir_edge_a = edges[f_a, slot_a % 3]
        dir_edge_b = edges[f_b, slot_b % 3]

        # Consistently wound neighbours walk the shared edge in opposite directions.
        opposite = (dir_edge_a == dir_edge_b.flip(dims=[1])).all(dim=1)
        needs_flip_rel = ~opposite

        src_np = torch.cat([f_a, f_b]).cpu().numpy()
        dst_np = torch.cat([f_b, f_a]).cpu().numpy()
        flip_np = torch.cat([needs_flip_rel, needs_flip_rel]).cpu().numpy()

        # Drop self-links and repeated (face, neighbour) pairs. Two faces that
        # share more than one edge (duplicate / double-sided triangles) would
        # otherwise be enqueued several times and overflow the BFS queue.
        keep = src_np != dst_np
        src_np, dst_np, flip_np = src_np[keep], dst_np[keep], flip_np[keep]
        pair_key = src_np * np.int64(num_faces) + dst_np
        _, first_idx = np.unique(pair_key, return_index=True)
        # np.unique sorts the keys, so the selection is already grouped by source face.
        src_np, dst_np, flip_np = src_np[first_idx], dst_np[first_idx], flip_np[first_idx]

        # Build CSR-style adjacency on CPU using NumPy
        adj_ptr_np = np.zeros(num_faces + 1, dtype=np.int64)
        adj_ptr_np[1:] = np.cumsum(np.bincount(src_np, minlength=num_faces))

        flip_state_np, component_id_np = _orient_components_bfs(
            adj_ptr_np, dst_np, flip_np, num_faces
        )

        flip_state = torch.from_numpy(flip_state_np).to(device=device)
        component_id = torch.from_numpy(component_id_np).to(device=device)

        if flip_state.any():
            corrected[flip_state] = corrected[flip_state][:, [0, 2, 1]]
    else:
        # No shared manifold edges: every face is its own component.
        component_id = torch.arange(num_faces, device=device)

    v0 = vertices[corrected[:, 0]]
    v1 = vertices[corrected[:, 1]]
    v2 = vertices[corrected[:, 2]]

    face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)
    face_normals = face_normals / (torch.norm(face_normals, dim=-1, keepdim=True) + 1e-8)

    num_components = int(component_id.max().item()) + 1 if component_id.numel() > 0 else 0

    if reference_normals is not None:
        n0 = reference_normals[corrected[:, 0]]
        n1 = reference_normals[corrected[:, 1]]
        n2 = reference_normals[corrected[:, 2]]
        ref_normals = (n0 + n1 + n2) / 3.0
        ref_normals = ref_normals / (torch.norm(ref_normals, dim=-1, keepdim=True) + 1e-8)

        votes = (face_normals * ref_normals).sum(dim=-1)

        outward_votes_comp = torch.zeros(num_components, dtype=torch.int64, device=device)
        inward_votes_comp = torch.zeros(num_components, dtype=torch.int64, device=device)
        outward_votes_comp.scatter_add_(0, component_id, (votes > 0).to(torch.int64))
        inward_votes_comp.scatter_add_(0, component_id, (votes < 0).to(torch.int64))

        n_faces_comp_int = torch.zeros(num_components, dtype=torch.int64, device=device)
        n_faces_comp_int.scatter_add_(
            0, component_id, torch.ones(num_faces, dtype=torch.int64, device=device)
        )

        # Require a margin of 10% of the component (at least one face) before flipping.
        thresholds = torch.maximum(torch.ones_like(n_faces_comp_int), n_faces_comp_int // 10)
        should_flip_comp = inward_votes_comp > outward_votes_comp + thresholds
    else:
        # Heuristic: the face whose centroid is furthest along +axis should
        # have a normal with a positive component on that axis.
        face_centroids = (v0 + v1 + v2) / 3.0

        votes_by_axis = []
        for axis in range(3):
            coords = face_centroids[:, axis]

            # Double stable sort acts as a vectorized lexsort on (component_id, coords)
            order = torch.argsort(coords, stable=True)
            order = order[torch.argsort(component_id[order], stable=True)]

            # The last entry of every component group is its extreme face on this axis.
            comp_id_sorted = component_id[order]
            group_ends = torch.nonzero(comp_id_sorted[1:] != comp_id_sorted[:-1], as_tuple=True)[0]
            group_ends = torch.cat(
                [group_ends, torch.tensor([len(comp_id_sorted) - 1], device=device)]
            )

            extreme_normals = face_normals[order[group_ends]]
            votes_by_axis.append(extreme_normals[:, axis] > 0)

        stacked_votes = torch.stack(votes_by_axis, dim=0)
        # Keep the component only if at least 2 axes agree it points outward.
        should_flip_comp = stacked_votes.sum(dim=0) < 2

    should_flip_face = should_flip_comp[component_id]
    if should_flip_face.any():
        corrected[should_flip_face] = corrected[should_flip_face][:, [0, 2, 1]]

    return corrected.to(faces.dtype)
|
||||
|
||||
|
||||
def unweld_and_offset_mesh(vertices, faces, colors=None, z_offset=1e-4):
    """Give every face its own three vertices and push them along the face normal.

    Each triangle is detached from its neighbours (its corner positions are
    copied), then translated by ``z_offset`` along its unit face normal. The
    output faces are simply ``[0, 1, 2], [3, 4, 5], ...``.

    Args:
        vertices: ``(V, 3)`` positions, or ``(B, V, 3)`` for a batch.
        faces: ``(F, 3)`` vertex indices, or ``(B, F, 3)`` for a batch.
        colors: optional per-vertex attributes aligned with ``vertices``;
            they are copied per face corner.
        z_offset: distance each face is moved along its normal.

    Returns:
        ``(vertices, faces, colors)`` where vertices are ``(F*3, 3)`` (or
        ``(B, F*3, 3)``), faces are ``(F, 3)`` (or ``(B, F, 3)``) and colors
        is ``None`` when no colors were supplied. The tuple always has three
        entries, for batched and unbatched input alike.
    """
    is_batched = vertices.ndim == 3
    device = vertices.device

    num_faces = faces.shape[-2]
    lead_shape = tuple(vertices.shape[:-2])  # () or (B,)
    face_idx = faces.to(torch.int64)  # advanced indexing expects long indices

    if is_batched:
        # Broadcast a batch index against the face indices to gather all
        # corners in parallel without a Python loop.
        batch_idx = torch.arange(vertices.shape[0], device=device).view(-1, 1, 1)

        def gather_corners(per_vertex):
            return per_vertex[batch_idx, face_idx]
    else:
        def gather_corners(per_vertex):
            return per_vertex[face_idx]

    corners = gather_corners(vertices)  # (..., F, 3 corners, 3 coords)
    p0, p1, p2 = corners.unbind(dim=-2)

    # Unit face normals; the epsilon keeps degenerate faces from dividing by zero.
    fn = torch.cross(p1 - p0, p2 - p0, dim=-1)
    fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8)

    # Explicit sizes (instead of -1) keep reshape valid for a zero-face mesh.
    out_v = (corners + fn.unsqueeze(-2) * z_offset).reshape(*lead_shape, num_faces * 3, 3)

    # Sequential indices: every face owns three consecutive vertices.
    out_f = torch.arange(num_faces * 3, device=device).reshape(num_faces, 3)
    if is_batched:
        out_f = out_f.unsqueeze(0).expand(vertices.shape[0], -1, -1)

    out_c = None
    if colors is not None:
        out_c = gather_corners(colors).reshape(*lead_shape, num_faces * 3, colors.shape[-1])

    return out_v, out_f, out_c
|
||||
|
||||
class DecimateMesh(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -801,8 +1085,23 @@ class DecimateMesh(IO.ComfyNode):
|
||||
def execute(cls, mesh, target_face_count):
    """Decimate every mesh in ``mesh`` down to roughly ``target_face_count`` faces.

    Each item is simplified with the QEM simplifier, re-oriented, and then
    unwelded with a small offset along the face normals. If any of those steps
    raises, a warning is logged and vertex clustering is used instead. Items
    that are already within budget (or a non-positive budget) pass through
    unchanged.
    """
    def _fn(v, f, c):
        if target_face_count > 0 and f.shape[0] > target_face_count:
            try:
                # Per-vertex normals: sum the unit normals of incident faces, then renormalize.
                v0, v1, v2 = v[f[:, 0]], v[f[:, 1]], v[f[:, 2]]
                fn = torch.cross(v1 - v0, v2 - v0, dim=-1)
                fn = fn / (torch.norm(fn, dim=-1, keepdim=True) + 1e-8)

                n = torch.zeros_like(v)
                n.index_add_(0, f[:, 0], fn)
                n.index_add_(0, f[:, 1], fn)
                n.index_add_(0, f[:, 2], fn)
                n = n / (torch.norm(n, dim=-1, keepdim=True) + 1e-8)

                v, f, c, _ = simplify_fn_fast(v, f, colors=c, normals=n, target=target_face_count)
                f = fix_face_orientation(v, f)
                v, f, c = unweld_and_offset_mesh(v, f, colors=c, z_offset=1e-4)
            except Exception as e:
                # Deliberate best-effort: any failure in the QEM path falls back
                # to the simpler clustering simplifier rather than aborting.
                logging.warning("Ran into an error while QEM Simplifying, falling back to vertex clustering:\n" + str(e))
                v, f, c = simplify_fn_vertex(v, f, c, target_face_count)
        return v, f, c

    return _process_mesh_batch(mesh, _fn)
|
||||
|
||||
|
||||
@ -445,13 +445,31 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@classmethod
|
||||
def execute(cls, clip_vision_model, image, mask) -> IO.NodeOutput:
|
||||
# Normalize to batched form so per-image conditioning loop below is uniform.
|
||||
if image.ndim == 3:
|
||||
image = image.unsqueeze(0)
|
||||
if mask.ndim == 2:
|
||||
elif image.ndim == 4:
|
||||
if image.shape[1] in [1, 3, 4] and image.shape[-1] not in [1, 3, 4]:
|
||||
image = image.permute(0, 2, 3, 1)
|
||||
|
||||
# normalize mask to standard [B, H, W] (handling 2D, 3D, and 4D variants)
|
||||
if mask.ndim == 4:
|
||||
if mask.shape[1] == 1:
|
||||
mask = mask.squeeze(1)
|
||||
elif mask.shape[-1] == 1:
|
||||
mask = mask.squeeze(-1)
|
||||
else:
|
||||
mask = mask[:, :, :, 0] # take first channel as fallback
|
||||
|
||||
if mask.ndim == 3:
|
||||
if mask.shape[-1] == 1:
|
||||
mask = mask.squeeze(-1).unsqueeze(0)
|
||||
elif mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
batch_size = image.shape[0]
|
||||
if mask.shape[0] == 1 and batch_size > 1:
|
||||
mask = mask.expand(batch_size, -1, -1)
|
||||
@ -468,6 +486,27 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
|
||||
# Ensure img_np is either 2D (grayscale) or 3D (RGB/RGBA)
|
||||
if img_np.ndim == 3 and img_np.shape[-1] == 1:
|
||||
img_np = img_np.squeeze(-1)
|
||||
|
||||
mask_np = mask_np.squeeze()
|
||||
|
||||
# detect inverted mask
|
||||
border_pixels = np.concatenate([
|
||||
mask_np[0, :], mask_np[-1, :], mask_np[:, 0], mask_np[:, -1]
|
||||
])
|
||||
if np.mean(border_pixels) > 127:
|
||||
mask_np = 255 - mask_np
|
||||
|
||||
mask_np[mask_np < 35] = 0
|
||||
|
||||
border_shave = 4
|
||||
mask_np[:border_shave, :] = 0
|
||||
mask_np[-border_shave:, :] = 0
|
||||
mask_np[:, :border_shave] = 0
|
||||
mask_np[:, -border_shave:] = 0
|
||||
|
||||
pil_img = Image.fromarray(img_np)
|
||||
pil_mask = Image.fromarray(mask_np)
|
||||
|
||||
@ -479,7 +518,7 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
|
||||
|
||||
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
|
||||
rgba_np[:, :, :3] = np.array(pil_img)
|
||||
rgba_np[:, :, :3] = np.array(pil_img.convert("RGB"))
|
||||
rgba_np[:, :, 3] = np.array(pil_mask)
|
||||
|
||||
alpha = rgba_np[:, :, 3]
|
||||
@ -511,12 +550,18 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
alpha_float = cropped_np[:, :, 3:4]
|
||||
composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
|
||||
|
||||
# to match trellis2 code (quantize -> dequantize)
|
||||
composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
|
||||
# Keep the image as 4-channel RGBA to force TRELLIS to bypass its internal background remover
|
||||
rgb_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
|
||||
alpha_uint8 = (alpha_float.squeeze(-1) * 255.0).round().clip(0, 255).astype(np.uint8)
|
||||
|
||||
cropped_pil = Image.fromarray(composite_uint8)
|
||||
rgba_composite = np.zeros((cropped_np.shape[0], cropped_np.shape[1], 4), dtype=np.uint8)
|
||||
rgba_composite[:, :, :3] = rgb_uint8
|
||||
rgba_composite[:, :, 3] = alpha_uint8
|
||||
|
||||
item_conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True)
|
||||
cropped_pil = Image.fromarray(rgba_composite, mode="RGBA")
|
||||
|
||||
# Convert to RGB to ensure the CLIP/DINO model receives a 3-channel image
|
||||
item_conditioning = run_conditioning(clip_vision_model, cropped_pil.convert("RGB"), include_1024=True)
|
||||
cond_512_list.append(item_conditioning["cond_512"])
|
||||
cond_1024_list.append(item_conditioning["cond_1024"])
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user