mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
fixes to vae and cumesh impl.
This commit is contained in:
parent
cd0f7ba64e
commit
f2c0320fe8
@ -5,6 +5,10 @@ import torch
|
||||
from typing import Dict, Callable
|
||||
|
||||
NO_TRITION = False
|
||||
# Probe TF32 support once at import time and keep the result as a plain bool
# so it can be handed to the Triton kernel launch as the `allow_tf32` argument.
try:
    # The probe must be *called*: storing the bare function object would make
    # `allow_tf32` always truthy, regardless of what the hardware supports.
    allow_tf32 = bool(torch.cuda.is_tf32_supported())
except Exception:
    # Best-effort fallback: if the probe is missing or raises for any reason,
    # run without TF32.
    allow_tf32 = False
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
@ -102,10 +106,13 @@ try:
|
||||
grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
|
||||
sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid](
|
||||
input, weight, bias, neighbor, sorted_idx, output,
|
||||
N, LOGN, Ci, Co, V, #
|
||||
N, LOGN, Ci, Co, V,
|
||||
B1=128,
|
||||
B2=64,
|
||||
BK=32,
|
||||
valid_kernel=valid_kernel,
|
||||
valid_kernel_seg=valid_kernel_seg,
|
||||
allow_tf32=torch.cuda.is_tf32_supported(),
|
||||
allow_tf32=allow_tf32,
|
||||
)
|
||||
return output
|
||||
except:
|
||||
@ -140,16 +147,16 @@ def build_submanifold_neighbor_map(
|
||||
|
||||
neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long)
|
||||
|
||||
b = coords[:, 0]
|
||||
x = coords[:, 1]
|
||||
y = coords[:, 2]
|
||||
z = coords[:, 3]
|
||||
b = coords[:, 0].long()
|
||||
x = coords[:, 1].long()
|
||||
y = coords[:, 2].long()
|
||||
z = coords[:, 3].long()
|
||||
|
||||
offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)
|
||||
|
||||
ox = x[:, None] - (Kw // 2) * Dw
|
||||
oy = y[:, None] - (Kh // 2) * Dh
|
||||
oz = z[:, None] - (Kd // 2) * Dd
|
||||
ox = x - (Kw // 2) * Dw
|
||||
oy = y - (Kh // 2) * Dh
|
||||
oz = z - (Kd // 2) * Dd
|
||||
|
||||
for v in range(half_V):
|
||||
if v == half_V - 1:
|
||||
@ -158,10 +165,11 @@ def build_submanifold_neighbor_map(
|
||||
|
||||
dx, dy, dz = offsets[v]
|
||||
|
||||
kx = ox[:, v] + dx
|
||||
ky = oy[:, v] + dy
|
||||
kz = oz[:, v] + dz
|
||||
kx = ox + dx
|
||||
ky = oy + dy
|
||||
kz = oz + dz
|
||||
|
||||
# Check spatial bounds
|
||||
valid = (
|
||||
(kx >= 0) & (kx < W) &
|
||||
(ky >= 0) & (ky < H) &
|
||||
@ -169,22 +177,22 @@ def build_submanifold_neighbor_map(
|
||||
)
|
||||
|
||||
flat = (
|
||||
b * (W * H * D) +
|
||||
kx * (H * D) +
|
||||
ky * D +
|
||||
kz
|
||||
b[valid] * (W * H * D) +
|
||||
kx[valid] * (H * D) +
|
||||
ky[valid] * D +
|
||||
kz[valid]
|
||||
)
|
||||
|
||||
flat = flat[valid]
|
||||
idx = torch.nonzero(valid, as_tuple=False).squeeze(1)
|
||||
if flat.numel() > 0:
|
||||
found = hashmap.lookup_flat(flat)
|
||||
idx_in_M = torch.where(valid)[0]
|
||||
neighbor[idx_in_M, v] = found
|
||||
|
||||
found = hashmap.lookup_flat(flat)
|
||||
|
||||
neighbor[idx, v] = found
|
||||
|
||||
# symmetric write
|
||||
valid_found = found != INVALID
|
||||
neighbor[found[valid_found], V - 1 - v] = idx[valid_found]
|
||||
valid_found_mask = (found != INVALID)
|
||||
if valid_found_mask.any():
|
||||
src_points = idx_in_M[valid_found_mask]
|
||||
dst_points = found[valid_found_mask]
|
||||
neighbor[dst_points, V - 1 - v] = src_points
|
||||
|
||||
return neighbor
|
||||
|
||||
@ -461,31 +469,118 @@ class Mesh:
|
||||
def cpu(self):
|
||||
return self.to('cpu')
|
||||
|
||||
# TODO could be an option
|
||||
# could make this into a new node
|
||||
def fill_holes(self, max_hole_perimeter=3e-2):
    """Close small boundary loops of the mesh with triangle fans, in place.

    A boundary edge is an edge used by exactly one triangle of ``self.faces``.
    Boundary edges are chained into closed loops by following the winding
    direction of the faces that own them. Each closed loop with at least
    three vertices whose perimeter does not exceed ``max_hole_perimeter`` is
    filled by adding one vertex at the mean of the loop vertices and one
    triangle per loop edge.

    Args:
        max_hole_perimeter: Largest loop perimeter (sum of edge lengths, in
            the units of ``self.vertices``) that is still filled.

    Returns:
        None. ``self.vertices`` and ``self.faces`` are replaced by extended
        tensors when at least one hole was filled; otherwise both are left
        untouched.
    """
    vertices = self.vertices
    faces = self.faces
    device = vertices.device

    # Directed half-edges of every triangle, in the triangle's winding order.
    half_edges = torch.cat([
        faces[:, [0, 1]],
        faces[:, [1, 2]],
        faces[:, [2, 0]],
    ], dim=0)
    if half_edges.shape[0] == 0:
        return

    # Pack each undirected edge (lo, hi) into one integer key so edge usage
    # can be counted with a single 1-D unique(). The cast to int64 keeps the
    # product from overflowing when the face indices are stored as int32.
    ordered = torch.sort(half_edges, dim=1).values.long()
    edge_keys = ordered[:, 0] * vertices.shape[0] + ordered[:, 1]
    _, inverse, counts = torch.unique(
        edge_keys, return_inverse=True, return_counts=True
    )

    # Half-edges whose undirected edge is used by exactly one face.
    boundary = half_edges[counts[inverse] == 1]
    if boundary.shape[0] == 0:
        return

    # Successor map along the boundary. If a vertex starts more than one
    # boundary half-edge, the last one wins.
    next_vertex = {u: v for u, v in boundary.tolist()}

    # Walk the successor map and keep only the walks that return to their
    # starting vertex. Marking vertices as visited bounds the walk, since
    # every vertex has at most one successor.
    loops = []
    visited = set()
    for start in next_vertex:
        if start in visited:
            continue
        loop = []
        curr = start
        while curr in next_vertex and curr not in visited:
            visited.add(curr)
            loop.append(curr)
            curr = next_vertex[curr]
        if curr == start:
            loops.append(loop)

    if not loops:
        return

    first_new_index = vertices.shape[0]
    centers = []
    fan_faces = []

    for loop in loops:
        if len(loop) < 3:
            continue

        ring = vertices[torch.tensor(loop, dtype=torch.long, device=device)]

        # Perimeter = sum of distances between consecutive loop vertices,
        # including the closing edge (roll pairs the last with the first).
        perimeter = torch.norm(ring - torch.roll(ring, shifts=-1, dims=0), dim=1).sum()
        if perimeter > max_hole_perimeter:
            continue

        center_index = first_new_index + len(centers)
        centers.append(ring.mean(dim=0))

        for i, v_curr in enumerate(loop):
            v_next = loop[(i + 1) % len(loop)]
            # The loop follows the half-edge direction of the existing
            # faces, so the patch triangle must traverse the shared edge
            # the opposite way (v_next -> v_curr) to keep a consistent
            # winding with its neighbour.
            fan_faces.append([v_next, v_curr, center_index])

    if not centers:
        return

    added_vertices = torch.stack(centers, dim=0)
    added_faces = torch.tensor(fan_faces, dtype=torch.long, device=faces.device)

    self.vertices = torch.cat([vertices, added_vertices], dim=0)
    self.faces = torch.cat([faces, added_faces], dim=0)
|
||||
|
||||
# TODO could be an option
|
||||
def simplify(self, target=1000000, verbose: bool=False, options: dict={}):
|
||||
|
||||
@ -208,7 +208,7 @@ class SparseResBlockC2S3d(nn.Module):
|
||||
self.to_subdiv = SparseLinear(channels, 8)
|
||||
self.updown = SparseChannel2Spatial(2)
|
||||
|
||||
def _forward(self, x, subdiv = None):
|
||||
def forward(self, x, subdiv = None):
|
||||
if self.pred_subdiv:
|
||||
subdiv = self.to_subdiv(x)
|
||||
h = x.replace(self.norm1(x.feats))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user