mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-05 22:21:31 +08:00
Miscellaneous cleanup
This commit is contained in:
parent
419e726061
commit
d635cc412d
@ -163,6 +163,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
|||||||
naf_sd = {k[len("naf."):]: sd.pop(k) for k in naf_keys}
|
naf_sd = {k[len("naf."):]: sd.pop(k) for k in naf_keys}
|
||||||
naf = NAF().eval()
|
naf = NAF().eval()
|
||||||
naf.load_state_dict(naf_sd, strict=False)
|
naf.load_state_dict(naf_sd, strict=False)
|
||||||
|
naf.to(comfy.model_management.text_encoder_dtype(clip.load_device))
|
||||||
clip.naf = comfy.model_patcher.CoreModelPatcher(naf, load_device=clip.load_device, offload_device=comfy.model_management.text_encoder_offload_device())
|
clip.naf = comfy.model_patcher.CoreModelPatcher(naf, load_device=clip.load_device, offload_device=comfy.model_management.text_encoder_offload_device())
|
||||||
return clip
|
return clip
|
||||||
|
|
||||||
|
|||||||
@ -55,10 +55,12 @@ class SparseFeedForwardNet(nn.Module):
|
|||||||
def forward(self, x: VarLenTensor) -> VarLenTensor:
|
def forward(self, x: VarLenTensor) -> VarLenTensor:
|
||||||
return self.mlp(x)
|
return self.mlp(x)
|
||||||
|
|
||||||
class SparseMultiHeadRMSNorm(nn.Module):
|
class MultiHeadRMSNorm(nn.Module):
|
||||||
def __init__(self, dim: int, heads: int, device, dtype):
|
# Per-head qk-norm for both sparse (VarLenTensor) and dense inputs. gamma is [heads, dim]
|
||||||
|
# (per-head), so it's a broadcast multiply rather than F.rms_norm's 1-D weight
|
||||||
|
def __init__(self, dim: int, heads: int, device=None, dtype=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
self.gamma = nn.Parameter(torch.empty(heads, dim, device=device, dtype=dtype))
|
||||||
|
|
||||||
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
||||||
if isinstance(x, VarLenTensor):
|
if isinstance(x, VarLenTensor):
|
||||||
@ -147,8 +149,8 @@ class SparseMultiHeadAttention(nn.Module):
|
|||||||
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
|
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||||
|
|
||||||
if self.qk_rms_norm:
|
if self.qk_rms_norm:
|
||||||
self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||||
self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||||
|
|
||||||
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
|
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
|
||||||
|
|
||||||
@ -307,7 +309,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
|||||||
operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype)
|
operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
self.modulation = nn.Parameter(torch.empty(6 * channels, device=device, dtype=dtype))
|
||||||
|
|
||||||
def _forward(self, x: SparseTensor, mod: torch.Tensor, context, transformer_options=None) -> SparseTensor:
|
def _forward(self, x: SparseTensor, mod: torch.Tensor, context, transformer_options=None) -> SparseTensor:
|
||||||
if self.share_mod:
|
if self.share_mod:
|
||||||
@ -444,24 +446,6 @@ class FeedForwardNet(nn.Module):
|
|||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
return self.mlp(x)
|
return self.mlp(x)
|
||||||
|
|
||||||
# class MultiHeadRMSNorm(nn.Module):
|
|
||||||
# def __init__(self, dim: int, heads: int, device=None, dtype=None):
|
|
||||||
# super().__init__()
|
|
||||||
# self.scale = dim ** 0.5
|
|
||||||
# self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
|
||||||
|
|
||||||
# def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
# return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
|
|
||||||
|
|
||||||
class MultiHeadRMSNorm(nn.Module):
|
|
||||||
def __init__(self, dim: int, heads: int, device=None, dtype=None):
|
|
||||||
super().__init__()
|
|
||||||
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
return (F.rms_norm(x.float(), (x.shape[-1],)) * self.gamma).to(x.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
class MultiHeadAttention(nn.Module):
|
class MultiHeadAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -580,7 +564,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
|||||||
if not share_mod:
|
if not share_mod:
|
||||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
|
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
|
||||||
else:
|
else:
|
||||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
self.modulation = nn.Parameter(torch.empty(6 * channels, device=device, dtype=dtype))
|
||||||
|
|
||||||
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context,
|
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context,
|
||||||
phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor:
|
phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor:
|
||||||
@ -631,7 +615,7 @@ class SparseStructureFlowModel(nn.Module):
|
|||||||
proj_in_channels: Optional[int] = None,
|
proj_in_channels: Optional[int] = None,
|
||||||
operations=None,
|
operations=None,
|
||||||
device = None,
|
device = None,
|
||||||
dtype = torch.float32,
|
dtype = None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@ -918,17 +918,14 @@ def flexible_dual_grid_to_mesh(
|
|||||||
):
|
):
|
||||||
|
|
||||||
device = coords.device
|
device = coords.device
|
||||||
if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset") \
|
# Small constant index tables — built per call (stateless), not cached on the function.
|
||||||
or flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset.device != device:
|
edge_neighbor_voxel_offset = torch.tensor([
|
||||||
flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([
|
[[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis
|
||||||
[[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis
|
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis
|
||||||
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis
|
[[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis
|
||||||
[[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis
|
], dtype=torch.int, device=device).unsqueeze(0)
|
||||||
], dtype=torch.int, device=device).unsqueeze(0)
|
quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device)
|
||||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1") or flexible_dual_grid_to_mesh.quad_split_1.device != device:
|
quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device)
|
||||||
flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
|
|
||||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2") or flexible_dual_grid_to_mesh.quad_split_2.device != device:
|
|
||||||
flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
|
|
||||||
|
|
||||||
aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
|
aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
@ -954,7 +951,7 @@ def flexible_dual_grid_to_mesh(
|
|||||||
|
|
||||||
# Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3]
|
# Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3]
|
||||||
n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,)
|
n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,)
|
||||||
offsets_per_axis = flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset[0] # (3, 4, 3)
|
offsets_per_axis = edge_neighbor_voxel_offset[0] # (3, 4, 3)
|
||||||
connected_voxel = coords[n_idx].unsqueeze(1) + offsets_per_axis[axis_idx] # (M, 4, 3)
|
connected_voxel = coords[n_idx].unsqueeze(1) + offsets_per_axis[axis_idx] # (M, 4, 3)
|
||||||
M = connected_voxel.shape[0]
|
M = connected_voxel.shape[0]
|
||||||
# flatten connected voxel coords and lookup. In-place to avoid extra memory allocation.
|
# flatten connected voxel coords and lookup. In-place to avoid extra memory allocation.
|
||||||
@ -973,12 +970,12 @@ def flexible_dual_grid_to_mesh(
|
|||||||
mesh_vertices.add_(dual_vertices).mul_(voxel_size).add_(aabb[0].reshape(1, 3))
|
mesh_vertices.add_(dual_vertices).mul_(voxel_size).add_(aabb[0].reshape(1, 3))
|
||||||
if split_weight is None:
|
if split_weight is None:
|
||||||
# if split 1
|
# if split 1
|
||||||
atempt_triangles_0 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1]
|
atempt_triangles_0 = quad_indices[:, quad_split_1]
|
||||||
normals0 = torch.cross(mesh_vertices[atempt_triangles_0[:, 1]] - mesh_vertices[atempt_triangles_0[:, 0]], mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 0]])
|
normals0 = torch.cross(mesh_vertices[atempt_triangles_0[:, 1]] - mesh_vertices[atempt_triangles_0[:, 0]], mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 0]])
|
||||||
normals1 = torch.cross(mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 1]], mesh_vertices[atempt_triangles_0[:, 3]] - mesh_vertices[atempt_triangles_0[:, 1]])
|
normals1 = torch.cross(mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 1]], mesh_vertices[atempt_triangles_0[:, 3]] - mesh_vertices[atempt_triangles_0[:, 1]])
|
||||||
align0 = (normals0 * normals1).sum(dim=1, keepdim=True).abs()
|
align0 = (normals0 * normals1).sum(dim=1, keepdim=True).abs()
|
||||||
# if split 2
|
# if split 2
|
||||||
atempt_triangles_1 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2]
|
atempt_triangles_1 = quad_indices[:, quad_split_2]
|
||||||
normals0 = torch.cross(mesh_vertices[atempt_triangles_1[:, 1]] - mesh_vertices[atempt_triangles_1[:, 0]], mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 0]])
|
normals0 = torch.cross(mesh_vertices[atempt_triangles_1[:, 1]] - mesh_vertices[atempt_triangles_1[:, 0]], mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 0]])
|
||||||
normals1 = torch.cross(mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 1]], mesh_vertices[atempt_triangles_1[:, 3]] - mesh_vertices[atempt_triangles_1[:, 1]])
|
normals1 = torch.cross(mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 1]], mesh_vertices[atempt_triangles_1[:, 3]] - mesh_vertices[atempt_triangles_1[:, 1]])
|
||||||
align1 = (normals0 * normals1).sum(dim=1, keepdim=True).abs()
|
align1 = (normals0 * normals1).sum(dim=1, keepdim=True).abs()
|
||||||
@ -990,8 +987,8 @@ def flexible_dual_grid_to_mesh(
|
|||||||
split_weight_ws_13 = split_weight_ws[:, 1] * split_weight_ws[:, 3]
|
split_weight_ws_13 = split_weight_ws[:, 1] * split_weight_ws[:, 3]
|
||||||
mesh_triangles = torch.where(
|
mesh_triangles = torch.where(
|
||||||
split_weight_ws_02 > split_weight_ws_13,
|
split_weight_ws_02 > split_weight_ws_13,
|
||||||
quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1],
|
quad_indices[:, quad_split_1],
|
||||||
quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2]
|
quad_indices[:, quad_split_2]
|
||||||
).reshape(-1, 3)
|
).reshape(-1, 3)
|
||||||
|
|
||||||
return mesh_vertices, mesh_triangles
|
return mesh_vertices, mesh_triangles
|
||||||
|
|||||||
@ -39,7 +39,12 @@ class MESH:
|
|||||||
vertex_counts: torch.Tensor | None = None,
|
vertex_counts: torch.Tensor | None = None,
|
||||||
face_counts: torch.Tensor | None = None,
|
face_counts: torch.Tensor | None = None,
|
||||||
unlit: bool = False,
|
unlit: bool = False,
|
||||||
normals: torch.Tensor | None = None):
|
normals: torch.Tensor | None = None,
|
||||||
|
tangents: torch.Tensor | None = None,
|
||||||
|
normal_map: torch.Tensor | None = None,
|
||||||
|
occlusion_in_mr: bool = False,
|
||||||
|
material: dict | None = None,
|
||||||
|
emissive: torch.Tensor | None = None):
|
||||||
|
|
||||||
assert (vertex_counts is None) == (face_counts is None), \
|
assert (vertex_counts is None) == (face_counts is None), \
|
||||||
"vertex_counts and face_counts must be provided together (both or neither)"
|
"vertex_counts and face_counts must be provided together (both or neither)"
|
||||||
@ -59,6 +64,13 @@ class MESH:
|
|||||||
self.face_counts = face_counts
|
self.face_counts = face_counts
|
||||||
# Render flat / emissive (no scene lighting) when saved, e.g. for gaussian-splat-derived meshes.
|
# Render flat / emissive (no scene lighting) when saved, e.g. for gaussian-splat-derived meshes.
|
||||||
self.unlit = unlit
|
self.unlit = unlit
|
||||||
|
# Extra maps / material overrides attached by bake, normal/AO, and SetMeshMaterial nodes;
|
||||||
|
# consumed by SaveGLB. Declared here (with defaults) so consumers read them directly.
|
||||||
|
self.tangents = tangents # (B, N, 4) per-vertex tangents for normal mapping
|
||||||
|
self.normal_map = normal_map # tangent-space normal map: (B, H, W, 3)
|
||||||
|
self.occlusion_in_mr = occlusion_in_mr # True = R channel of metallic_roughness holds AO (ORM)
|
||||||
|
self.material = material # SetMeshMaterial scalar/factor overrides
|
||||||
|
self.emissive = emissive # emissive map: (B, H, W, 3)
|
||||||
|
|
||||||
|
|
||||||
class File3D:
|
class File3D:
|
||||||
|
|||||||
@ -1282,25 +1282,17 @@ def qem_simplify(
|
|||||||
iteration = 0
|
iteration = 0
|
||||||
total_collapses = 0
|
total_collapses = 0
|
||||||
|
|
||||||
# progress bars (tqdm + optional comfy ProgressBar), best-effort
|
# progress bars (tqdm + comfy ProgressBar)
|
||||||
_start_faces = num_faces
|
_start_faces = num_faces
|
||||||
_prog_total = max(1, _start_faces - int(target_faces))
|
_prog_total = max(1, _start_faces - int(target_faces))
|
||||||
try:
|
_qtq = _tqdm(total=100, desc="QEM simplify", leave=False)
|
||||||
_qtq = _tqdm(total=100, desc="QEM simplify", leave=False)
|
_qpbar = _comfy_utils.ProgressBar(100)
|
||||||
except Exception:
|
|
||||||
_qtq = None
|
|
||||||
try:
|
|
||||||
_qpbar = _comfy_utils.ProgressBar(100)
|
|
||||||
except Exception:
|
|
||||||
_qpbar = None
|
|
||||||
|
|
||||||
def _qreport():
|
def _qreport():
|
||||||
pct = min(100, max(0, int(100 * (_start_faces - py_n_faces) / _prog_total)))
|
pct = min(100, max(0, int(100 * (_start_faces - py_n_faces) / _prog_total)))
|
||||||
if _qtq is not None:
|
_qtq.n = pct
|
||||||
_qtq.n = pct
|
_qtq.refresh()
|
||||||
_qtq.refresh()
|
_qpbar.update_absolute(pct, 100)
|
||||||
if _qpbar is not None:
|
|
||||||
_qpbar.update_absolute(pct, 100)
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if py_n_faces <= target_faces:
|
if py_n_faces <= target_faces:
|
||||||
@ -1523,8 +1515,7 @@ def qem_simplify(
|
|||||||
break
|
break
|
||||||
|
|
||||||
_qreport()
|
_qreport()
|
||||||
if _qtq is not None:
|
_qtq.close()
|
||||||
_qtq.close()
|
|
||||||
|
|
||||||
# finalize: compact verts and faces
|
# finalize: compact verts and faces
|
||||||
final_v = verts[v_alive]
|
final_v = verts[v_alive]
|
||||||
|
|||||||
@ -522,7 +522,6 @@ def _build_mdc_lut() -> Tuple[torch.Tensor, torch.Tensor]:
|
|||||||
return K, group
|
return K, group
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache(maxsize=None)
|
|
||||||
def _mdc_lut(device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
def _mdc_lut(device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
K, g = _build_mdc_lut()
|
K, g = _build_mdc_lut()
|
||||||
return K.to(device), g.to(device)
|
return K.to(device), g.to(device)
|
||||||
@ -967,15 +966,11 @@ def remesh_narrow_band_dc(
|
|||||||
n_levels += 1
|
n_levels += 1
|
||||||
_total_ticks = n_levels + 3 + int(smooth_iters)
|
_total_ticks = n_levels + 3 + int(smooth_iters)
|
||||||
_pbar = comfy.utils.ProgressBar(_total_ticks)
|
_pbar = comfy.utils.ProgressBar(_total_ticks)
|
||||||
try:
|
_tq = _tqdm(total=_total_ticks, desc="Remesh DC", leave=False)
|
||||||
_tq = _tqdm(total=_total_ticks, desc="Remesh DC", leave=False)
|
|
||||||
except Exception:
|
|
||||||
_tq = None
|
|
||||||
|
|
||||||
def tick():
|
def tick():
|
||||||
_pbar.update(1)
|
_pbar.update(1)
|
||||||
if _tq is not None:
|
_tq.update(1)
|
||||||
_tq.update(1)
|
|
||||||
|
|
||||||
# Step 1: sparse narrow-band voxel grid (coarse-to-fine)
|
# Step 1: sparse narrow-band voxel grid (coarse-to-fine)
|
||||||
voxel_coords, _band_tree = _build_narrow_band_voxels(
|
voxel_coords, _band_tree = _build_narrow_band_voxels(
|
||||||
|
|||||||
@ -366,8 +366,8 @@ def lscm_chart(
|
|||||||
x_free = solve_least_squares(A_free, b)
|
x_free = solve_least_squares(A_free, b)
|
||||||
if not np.all(np.isfinite(x_free)):
|
if not np.all(np.isfinite(x_free)):
|
||||||
fallback_to_ortho = True
|
fallback_to_ortho = True
|
||||||
except Exception:
|
except (sp.linalg.MatrixRankWarning, RuntimeError):
|
||||||
fallback_to_ortho = True
|
fallback_to_ortho = True # singular / under-constrained system
|
||||||
|
|
||||||
if fallback_to_ortho:
|
if fallback_to_ortho:
|
||||||
if pin_positions is not None and pin_positions.shape == (Vc, 2):
|
if pin_positions is not None and pin_positions.shape == (Vc, 2):
|
||||||
|
|||||||
@ -74,31 +74,22 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non
|
|||||||
)
|
)
|
||||||
packed_tangents[i, :tn.shape[0]] = tn
|
packed_tangents[i, :tn.shape[0]] = tn
|
||||||
|
|
||||||
out = Types.MESH(packed_vertices, packed_faces,
|
return Types.MESH(packed_vertices, packed_faces,
|
||||||
uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
|
uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
|
||||||
metallic_roughness=metallic_roughness,
|
metallic_roughness=metallic_roughness,
|
||||||
vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit,
|
vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit,
|
||||||
normals=packed_normals)
|
normals=packed_normals, tangents=packed_tangents,
|
||||||
if packed_tangents is not None:
|
normal_map=normal_map, occlusion_in_mr=occlusion_in_mr,
|
||||||
out.tangents = packed_tangents
|
material=material, emissive=emissive)
|
||||||
if normal_map is not None:
|
|
||||||
out.normal_map = normal_map
|
|
||||||
if occlusion_in_mr:
|
|
||||||
out.occlusion_in_mr = True
|
|
||||||
if material is not None:
|
|
||||||
out.material = material
|
|
||||||
if emissive is not None:
|
|
||||||
out.emissive = emissive
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def get_mesh_batch_item(mesh, index):
|
def get_mesh_batch_item(mesh, index):
|
||||||
# Returns (vertices, faces, colors, uvs) for batch index, slicing to real lengths
|
# Returns (vertices, faces, colors, uvs) for batch index, slicing to real lengths
|
||||||
# if the mesh carries per-item counts (variable-size batch).
|
# if the mesh carries per-item counts (variable-size batch).
|
||||||
v_colors = getattr(mesh, "vertex_colors", None)
|
v_colors = mesh.vertex_colors
|
||||||
v_uvs = getattr(mesh, "uvs", None)
|
v_uvs = mesh.uvs
|
||||||
v_normals = getattr(mesh, "normals", None)
|
v_normals = mesh.normals
|
||||||
if getattr(mesh, "vertex_counts", None) is not None:
|
if mesh.vertex_counts is not None:
|
||||||
vertex_count = int(mesh.vertex_counts[index].item())
|
vertex_count = int(mesh.vertex_counts[index].item())
|
||||||
face_count = int(mesh.face_counts[index].item())
|
face_count = int(mesh.face_counts[index].item())
|
||||||
vertices = mesh.vertices[index, :vertex_count]
|
vertices = mesh.vertices[index, :vertex_count]
|
||||||
@ -482,6 +473,7 @@ def save_glb(vertices, faces, filepath=None, metadata=None,
|
|||||||
|
|
||||||
if (texture_png_bytes is not None and has_uv) or "COLOR_0" in primitive_attributes:
|
if (texture_png_bytes is not None and has_uv) or "COLOR_0" in primitive_attributes:
|
||||||
pbr["baseColorFactor"] = [1.0, 1.0, 1.0, 1.0]
|
pbr["baseColorFactor"] = [1.0, 1.0, 1.0, 1.0]
|
||||||
|
pbr["roughnessFactor"] = 1.0
|
||||||
|
|
||||||
if mr_png_bytes is not None and has_uv:
|
if mr_png_bytes is not None and has_uv:
|
||||||
mr_texture_index = add_image_texture(mr_byte_offset, len(mr_buffer))
|
mr_texture_index = add_image_texture(mr_byte_offset, len(mr_buffer))
|
||||||
@ -599,7 +591,7 @@ def mesh_item_to_glb_bytes(mesh, index, metadata=None):
|
|||||||
assert a.ndim == 3 and a.shape[-1] == 3, f"{attr} must be (B, H, W, 3), got {tuple(t.shape)}"
|
assert a.ndim == 3 and a.shape[-1] == 3, f"{attr} must be (B, H, W, 3), got {tuple(t.shape)}"
|
||||||
return Image.fromarray(a, mode="RGB")
|
return Image.fromarray(a, mode="RGB")
|
||||||
|
|
||||||
tangents_b = getattr(mesh, "tangents", None)
|
tangents_b = mesh.tangents
|
||||||
tangents_i = tangents_b[index, :vertices_i.shape[0]] if tangents_b is not None else None
|
tangents_i = tangents_b[index, :vertices_i.shape[0]] if tangents_b is not None else None
|
||||||
return save_glb(
|
return save_glb(
|
||||||
vertices_i, faces_i, None, metadata,
|
vertices_i, faces_i, None, metadata,
|
||||||
@ -607,12 +599,12 @@ def mesh_item_to_glb_bytes(mesh, index, metadata=None):
|
|||||||
vertex_colors=v_colors,
|
vertex_colors=v_colors,
|
||||||
texture_image=_img("texture"),
|
texture_image=_img("texture"),
|
||||||
metallic_roughness_image=_img("metallic_roughness"),
|
metallic_roughness_image=_img("metallic_roughness"),
|
||||||
unlit=getattr(mesh, "unlit", False),
|
unlit=mesh.unlit,
|
||||||
normals=normals_i,
|
normals=normals_i,
|
||||||
normal_map_image=_img("normal_map"),
|
normal_map_image=_img("normal_map"),
|
||||||
tangents=tangents_i,
|
tangents=tangents_i,
|
||||||
occlusion_in_mr=getattr(mesh, "occlusion_in_mr", False),
|
occlusion_in_mr=mesh.occlusion_in_mr,
|
||||||
material=getattr(mesh, "material", None),
|
material=mesh.material,
|
||||||
emissive_image=_img("emissive"),
|
emissive_image=_img("emissive"),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -810,7 +802,7 @@ class RotateMesh(IO.ComfyNode):
|
|||||||
else:
|
else:
|
||||||
out.vertices = rotate(mesh.vertices)
|
out.vertices = rotate(mesh.vertices)
|
||||||
# Normals are directions; rotate them too (R is orthogonal) so they stay valid.
|
# Normals are directions; rotate them too (R is orthogonal) so they stay valid.
|
||||||
nrm = getattr(mesh, "normals", None)
|
nrm = mesh.normals
|
||||||
if nrm is not None:
|
if nrm is not None:
|
||||||
out.normals = [rotate(n) for n in nrm] if isinstance(nrm, list) else rotate(nrm)
|
out.normals = [rotate(n) for n in nrm] if isinstance(nrm, list) else rotate(nrm)
|
||||||
return IO.NodeOutput(out)
|
return IO.NodeOutput(out)
|
||||||
@ -860,7 +852,7 @@ class MeshSmoothNormals(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(out)
|
return IO.NodeOutput(out)
|
||||||
|
|
||||||
# Crease split changes per-item vertex counts -> rebuild as a variable-size batch.
|
# Crease split changes per-item vertex counts -> rebuild as a variable-size batch.
|
||||||
tangents_b = getattr(mesh, "tangents", None)
|
tangents_b = mesh.tangents
|
||||||
v_list, f_list, n_list = [], [], []
|
v_list, f_list, n_list = [], [], []
|
||||||
c_list = [] if mesh.vertex_colors is not None else None
|
c_list = [] if mesh.vertex_colors is not None else None
|
||||||
u_list = [] if mesh.uvs is not None else None
|
u_list = [] if mesh.uvs is not None else None
|
||||||
@ -889,11 +881,11 @@ class MeshSmoothNormals(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(mesh)
|
return IO.NodeOutput(mesh)
|
||||||
out = pack_variable_mesh_batch(
|
out = pack_variable_mesh_batch(
|
||||||
v_list, f_list, colors=c_list, uvs=u_list,
|
v_list, f_list, colors=c_list, uvs=u_list,
|
||||||
texture=mesh.texture, unlit=getattr(mesh, "unlit", False),
|
texture=mesh.texture, unlit=mesh.unlit,
|
||||||
normals=n_list, metallic_roughness=getattr(mesh, "metallic_roughness", None),
|
normals=n_list, metallic_roughness=mesh.metallic_roughness,
|
||||||
tangents=t_list, normal_map=getattr(mesh, "normal_map", None),
|
tangents=t_list, normal_map=mesh.normal_map,
|
||||||
occlusion_in_mr=getattr(mesh, "occlusion_in_mr", False),
|
occlusion_in_mr=mesh.occlusion_in_mr,
|
||||||
material=getattr(mesh, "material", None), emissive=getattr(mesh, "emissive", None))
|
material=mesh.material, emissive=mesh.emissive)
|
||||||
return IO.NodeOutput(out)
|
return IO.NodeOutput(out)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -8,9 +8,7 @@ from server import PromptServer
|
|||||||
import comfy.latent_formats
|
import comfy.latent_formats
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
from PIL import Image
|
|
||||||
import logging
|
import logging
|
||||||
import numpy as np
|
|
||||||
import math
|
import math
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -425,7 +423,6 @@ def _dinov3_encode(model, image_bchw, image_size, want_patches=False):
|
|||||||
mean = torch.tensor(model.image_mean or [0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
|
mean = torch.tensor(model.image_mean or [0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
|
||||||
std = torch.tensor(model.image_std or [0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
|
std = torch.tensor(model.image_std or [0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
|
||||||
img_t = (img_t - mean) / std
|
img_t = (img_t - mean) / std
|
||||||
model_internal.image_size = image_size
|
|
||||||
tokens = model_internal(img_t, skip_norm_elementwise=True)[0]
|
tokens = model_internal(img_t, skip_norm_elementwise=True)[0]
|
||||||
if not want_patches:
|
if not want_patches:
|
||||||
return tokens
|
return tokens
|
||||||
@ -434,20 +431,6 @@ def _dinov3_encode(model, image_bchw, image_size, want_patches=False):
|
|||||||
return {"tokens": tokens[:, :1 + n_reg], "patches_2d": _dinov3_patches_to_2d(tokens, image_size)}
|
return {"tokens": tokens[:, :1 + n_reg], "patches_2d": _dinov3_patches_to_2d(tokens, image_size)}
|
||||||
|
|
||||||
|
|
||||||
def run_conditioning(model, cropped_pil_img):
|
|
||||||
device = comfy.model_management.intermediate_device()
|
|
||||||
|
|
||||||
img_np = np.array(cropped_pil_img).astype(np.float32) / 255.0
|
|
||||||
image_bchw = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).contiguous()
|
|
||||||
|
|
||||||
cond_512 = _dinov3_encode(model, image_bchw, 512)
|
|
||||||
cond_1024 = _dinov3_encode(model, image_bchw, 1024)
|
|
||||||
return {
|
|
||||||
"cond_512": cond_512.to(device),
|
|
||||||
"neg_cond": torch.zeros_like(cond_512).to(device),
|
|
||||||
"cond_1024": cond_1024.to(device),
|
|
||||||
}
|
|
||||||
|
|
||||||
class Trellis2Conditioning(IO.ComfyNode):
|
class Trellis2Conditioning(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
@ -467,124 +450,10 @@ class Trellis2Conditioning(IO.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, clip_vision_model, image, mask) -> IO.NodeOutput:
|
def execute(cls, clip_vision_model, image, mask) -> IO.NodeOutput:
|
||||||
# Normalize to batched form so per-image conditioning loop below is uniform.
|
out_device = comfy.model_management.intermediate_device()
|
||||||
if image.ndim == 3:
|
cond = _dino_condition_batch(clip_vision_model, image, mask, out_device,
|
||||||
image = image.unsqueeze(0)
|
pad_factor=1.0, mask_threshold=35.0 / 255.0, border_shave=4)
|
||||||
elif image.ndim == 4:
|
cond_512_batched, cond_1024_batched = cond["global_512"], cond["global_1024"]
|
||||||
if image.shape[1] in [1, 3, 4] and image.shape[-1] not in [1, 3, 4]:
|
|
||||||
image = image.permute(0, 2, 3, 1)
|
|
||||||
|
|
||||||
# normalize mask to standard [B, H, W] (handling 2D, 3D, and 4D variants)
|
|
||||||
if mask.ndim == 4:
|
|
||||||
if mask.shape[1] == 1:
|
|
||||||
mask = mask.squeeze(1)
|
|
||||||
elif mask.shape[-1] == 1:
|
|
||||||
mask = mask.squeeze(-1)
|
|
||||||
else:
|
|
||||||
mask = mask[:, :, :, 0] # take first channel as fallback
|
|
||||||
|
|
||||||
if mask.ndim == 3:
|
|
||||||
if mask.shape[-1] == 1:
|
|
||||||
mask = mask.squeeze(-1).unsqueeze(0)
|
|
||||||
elif mask.ndim == 2:
|
|
||||||
mask = mask.unsqueeze(0)
|
|
||||||
|
|
||||||
batch_size = image.shape[0]
|
|
||||||
if mask.shape[0] == 1 and batch_size > 1:
|
|
||||||
mask = mask.expand(batch_size, -1, -1)
|
|
||||||
elif mask.shape[0] != batch_size:
|
|
||||||
raise ValueError(f"Trellis2Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}")
|
|
||||||
|
|
||||||
cond_512_list = []
|
|
||||||
cond_1024_list = []
|
|
||||||
|
|
||||||
for b in range(batch_size):
|
|
||||||
item_image = image[b]
|
|
||||||
item_mask = mask[b] if mask.size(0) > 1 else mask[0]
|
|
||||||
|
|
||||||
img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
|
||||||
mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
|
||||||
|
|
||||||
# Ensure img_np is either 2D (grayscale) or 3D (RGB/RGBA)
|
|
||||||
if img_np.ndim == 3 and img_np.shape[-1] == 1:
|
|
||||||
img_np = img_np.squeeze(-1)
|
|
||||||
|
|
||||||
mask_np = mask_np.squeeze()
|
|
||||||
|
|
||||||
# detect inverted mask
|
|
||||||
border_pixels = np.concatenate([
|
|
||||||
mask_np[0, :], mask_np[-1, :], mask_np[:, 0], mask_np[:, -1]
|
|
||||||
])
|
|
||||||
if np.mean(border_pixels) > 127:
|
|
||||||
mask_np = 255 - mask_np
|
|
||||||
|
|
||||||
mask_np[mask_np < 35] = 0
|
|
||||||
|
|
||||||
border_shave = 4
|
|
||||||
mask_np[:border_shave, :] = 0
|
|
||||||
mask_np[-border_shave:, :] = 0
|
|
||||||
mask_np[:, :border_shave] = 0
|
|
||||||
mask_np[:, -border_shave:] = 0
|
|
||||||
|
|
||||||
pil_img = Image.fromarray(img_np)
|
|
||||||
pil_mask = Image.fromarray(mask_np)
|
|
||||||
|
|
||||||
max_size = max(pil_img.size)
|
|
||||||
scale = min(1.0, 1024 / max_size)
|
|
||||||
if scale < 1.0:
|
|
||||||
new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
|
|
||||||
pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
|
||||||
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
|
|
||||||
|
|
||||||
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
|
|
||||||
rgba_np[:, :, :3] = np.array(pil_img.convert("RGB"))
|
|
||||||
rgba_np[:, :, 3] = np.array(pil_mask)
|
|
||||||
|
|
||||||
alpha = rgba_np[:, :, 3]
|
|
||||||
bbox_coords = np.argwhere(alpha > 0.8 * 255)
|
|
||||||
|
|
||||||
if len(bbox_coords) > 0:
|
|
||||||
y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
|
|
||||||
y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
|
|
||||||
|
|
||||||
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
|
|
||||||
size = max(y_max - y_min, x_max - x_min)
|
|
||||||
|
|
||||||
crop_x1 = int(center_x - size // 2)
|
|
||||||
crop_y1 = int(center_y - size // 2)
|
|
||||||
crop_x2 = int(center_x + size // 2)
|
|
||||||
crop_y2 = int(center_y + size // 2)
|
|
||||||
|
|
||||||
rgba_pil = Image.fromarray(rgba_np)
|
|
||||||
cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
|
|
||||||
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
|
|
||||||
else:
|
|
||||||
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
|
|
||||||
cropped_np = rgba_np.astype(np.float32) / 255.0
|
|
||||||
|
|
||||||
bg_rgb = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
|
||||||
|
|
||||||
fg = cropped_np[:, :, :3]
|
|
||||||
alpha_float = cropped_np[:, :, 3:4]
|
|
||||||
composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
|
|
||||||
|
|
||||||
# Keep the image as 4-channel RGBA to force TRELLIS to bypass its internal background remover
|
|
||||||
rgb_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
|
|
||||||
alpha_uint8 = (alpha_float.squeeze(-1) * 255.0).round().clip(0, 255).astype(np.uint8)
|
|
||||||
|
|
||||||
rgba_composite = np.zeros((cropped_np.shape[0], cropped_np.shape[1], 4), dtype=np.uint8)
|
|
||||||
rgba_composite[:, :, :3] = rgb_uint8
|
|
||||||
rgba_composite[:, :, 3] = alpha_uint8
|
|
||||||
|
|
||||||
cropped_pil = Image.fromarray(rgba_composite, mode="RGBA")
|
|
||||||
|
|
||||||
# Convert to RGB to ensure the CLIP/DINO model receives a 3-channel image
|
|
||||||
item_conditioning = run_conditioning(clip_vision_model, cropped_pil.convert("RGB"))
|
|
||||||
cond_512_list.append(item_conditioning["cond_512"])
|
|
||||||
cond_1024_list.append(item_conditioning["cond_1024"])
|
|
||||||
|
|
||||||
cond_512_batched = torch.cat(cond_512_list, dim=0)
|
|
||||||
cond_1024_batched = torch.cat(cond_1024_list, dim=0)
|
|
||||||
neg_cond_batched = torch.zeros_like(cond_512_batched)
|
neg_cond_batched = torch.zeros_like(cond_512_batched)
|
||||||
neg_embeds_batched = torch.zeros_like(cond_1024_batched)
|
neg_embeds_batched = torch.zeros_like(cond_1024_batched)
|
||||||
|
|
||||||
@ -781,12 +650,11 @@ def _dinov3_patches_to_2d(tokens, image_size, patch_size=16):
|
|||||||
return patches.transpose(1, 2).reshape(tokens.shape[0], -1, h_p, w_p).contiguous()
|
return patches.transpose(1, 2).reshape(tokens.shape[0], -1, h_p, w_p).contiguous()
|
||||||
|
|
||||||
|
|
||||||
def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
|
def _crop_image_with_mask(item_image, item_mask, max_image_size=1024, pad_factor=1.1,
|
||||||
img = item_image.permute(2, 0, 1).unsqueeze(0).cpu().float()
|
mask_threshold=0.0, border_shave=0):
|
||||||
mask = item_mask.unsqueeze(0).unsqueeze(0).cpu().float()
|
img = item_image[..., :3] if item_image.shape[-1] >= 3 else item_image[..., :1].repeat(1, 1, 3)
|
||||||
# Upstream went float→PIL uint8 implicitly; match that to keep composite bit-exact.
|
img = img.permute(2, 0, 1).unsqueeze(0).cpu().float().clamp(0, 1)
|
||||||
img = (img.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
|
mask = item_mask.unsqueeze(0).unsqueeze(0).cpu().float().clamp(0, 1)
|
||||||
mask = (mask.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
|
|
||||||
|
|
||||||
# Detect & correct an inverted mask
|
# Detect & correct an inverted mask
|
||||||
m2d = mask[0, 0]
|
m2d = mask[0, 0]
|
||||||
@ -794,6 +662,15 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
|
|||||||
if float(border.mean()) > 0.5:
|
if float(border.mean()) > 0.5:
|
||||||
mask = 1.0 - mask
|
mask = 1.0 - mask
|
||||||
|
|
||||||
|
if mask_threshold > 0.0:
|
||||||
|
mask = torch.where(mask < mask_threshold, torch.zeros_like(mask), mask)
|
||||||
|
if border_shave > 0:
|
||||||
|
bs = border_shave
|
||||||
|
mask[..., :bs, :] = 0
|
||||||
|
mask[..., -bs:, :] = 0
|
||||||
|
mask[..., :, :bs] = 0
|
||||||
|
mask[..., :, -bs:] = 0
|
||||||
|
|
||||||
H, W = img.shape[-2:]
|
H, W = img.shape[-2:]
|
||||||
if max(H, W) > max_image_size:
|
if max(H, W) > max_image_size:
|
||||||
scale = max_image_size / max(H, W)
|
scale = max_image_size / max(H, W)
|
||||||
@ -809,14 +686,14 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
|
|||||||
y_min, x_min = fg_pixels.min(dim=0).values.tolist()
|
y_min, x_min = fg_pixels.min(dim=0).values.tolist()
|
||||||
y_max, x_max = fg_pixels.max(dim=0).values.tolist()
|
y_max, x_max = fg_pixels.max(dim=0).values.tolist()
|
||||||
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
|
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
|
||||||
size = int(max(y_max - y_min, x_max - x_min) * 1.1)
|
size = int(max(y_max - y_min, x_max - x_min) * pad_factor)
|
||||||
half = size // 2
|
half = size // 2
|
||||||
crop_x1 = int(center_x - half)
|
crop_x1 = int(center_x - half)
|
||||||
crop_y1 = int(center_y - half)
|
crop_y1 = int(center_y - half)
|
||||||
crop_x2 = crop_x1 + 2 * half
|
crop_x2 = crop_x1 + 2 * half
|
||||||
crop_y2 = crop_y1 + 2 * half
|
crop_y2 = crop_y1 + 2 * half
|
||||||
else:
|
else:
|
||||||
logging.warning("Mask for the image is empty. Pixal3D requires a clean foreground mask.")
|
logging.warning("Mask for the image is empty; a clean foreground mask is required for best quality.")
|
||||||
crop_x1, crop_y1, crop_x2, crop_y2 = 0, 0, W, H
|
crop_x1, crop_y1, crop_x2, crop_y2 = 0, 0, W, H
|
||||||
crop_bbox = (crop_x1, crop_y1, crop_x2, crop_y2)
|
crop_bbox = (crop_x1, crop_y1, crop_x2, crop_y2)
|
||||||
|
|
||||||
@ -836,9 +713,77 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
|
|||||||
cropped_mask = mask[..., crop_y1:crop_y2, crop_x1:crop_x2]
|
cropped_mask = mask[..., crop_y1:crop_y2, crop_x1:crop_x2]
|
||||||
|
|
||||||
composite = (cropped_img * cropped_mask).clamp(0, 1)
|
composite = (cropped_img * cropped_mask).clamp(0, 1)
|
||||||
composite = (composite * 255.0).round().clamp(0, 255).to(torch.uint8).float() / 255.0
|
|
||||||
return composite, crop_bbox, scene_size
|
return composite, crop_bbox, scene_size
|
||||||
|
|
||||||
|
|
||||||
|
def _dino_condition_batch(clip_vision_model, image, mask, out_device, *,
|
||||||
|
pad_factor, mask_threshold=0.0, border_shave=0, want_patches=False):
|
||||||
|
"""Normalize image/mask to a batch, then per item: masked square crop + DINOv3 encode at
|
||||||
|
512 and 1024. Returns batched global tokens; with want_patches also the 2D patch grids and
|
||||||
|
the per-item composites / crop bboxes / scene sizes that the Pixal3D NAF+projection path needs."""
|
||||||
|
# Normalize to batched form so the per-image loop is uniform.
|
||||||
|
if image.ndim == 3:
|
||||||
|
image = image.unsqueeze(0)
|
||||||
|
elif image.ndim == 4:
|
||||||
|
if image.shape[1] in [1, 3, 4] and image.shape[-1] not in [1, 3, 4]:
|
||||||
|
image = image.permute(0, 2, 3, 1)
|
||||||
|
|
||||||
|
if mask.ndim == 4:
|
||||||
|
if mask.shape[1] == 1:
|
||||||
|
mask = mask.squeeze(1)
|
||||||
|
elif mask.shape[-1] == 1:
|
||||||
|
mask = mask.squeeze(-1)
|
||||||
|
else:
|
||||||
|
mask = mask[:, :, :, 0] # take first channel as fallback
|
||||||
|
if mask.ndim == 3:
|
||||||
|
if mask.shape[-1] == 1:
|
||||||
|
mask = mask.squeeze(-1).unsqueeze(0)
|
||||||
|
elif mask.ndim == 2:
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
|
||||||
|
batch_size = image.shape[0]
|
||||||
|
if mask.shape[0] == 1 and batch_size > 1:
|
||||||
|
mask = mask.expand(batch_size, -1, -1)
|
||||||
|
elif mask.shape[0] != batch_size:
|
||||||
|
raise ValueError(f"Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}")
|
||||||
|
|
||||||
|
cond_512_list, cond_1024_list = [], []
|
||||||
|
patches_512_list, patches_1024_list = [], []
|
||||||
|
composite_list, crop_bbox_list, scene_size_list = [], [], []
|
||||||
|
for b in range(batch_size):
|
||||||
|
item_image = image[b]
|
||||||
|
item_mask = mask[b] if mask.size(0) > 1 else mask[0]
|
||||||
|
composite, crop_bbox, scene_size = _crop_image_with_mask(
|
||||||
|
item_image, item_mask, max_image_size=1024, pad_factor=pad_factor,
|
||||||
|
mask_threshold=mask_threshold, border_shave=border_shave)
|
||||||
|
c512 = _dinov3_encode(clip_vision_model, composite, 512, want_patches=want_patches)
|
||||||
|
c1024 = _dinov3_encode(clip_vision_model, composite, 1024, want_patches=want_patches)
|
||||||
|
if want_patches:
|
||||||
|
cond_512_list.append(c512["tokens"].to(out_device))
|
||||||
|
cond_1024_list.append(c1024["tokens"].to(out_device))
|
||||||
|
patches_512_list.append(c512["patches_2d"].to(out_device))
|
||||||
|
patches_1024_list.append(c1024["patches_2d"].to(out_device))
|
||||||
|
composite_list.append(composite)
|
||||||
|
crop_bbox_list.append(crop_bbox)
|
||||||
|
scene_size_list.append(scene_size)
|
||||||
|
else:
|
||||||
|
cond_512_list.append(c512.to(out_device))
|
||||||
|
cond_1024_list.append(c1024.to(out_device))
|
||||||
|
|
||||||
|
out = {
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"global_512": torch.cat(cond_512_list, dim=0),
|
||||||
|
"global_1024": torch.cat(cond_1024_list, dim=0),
|
||||||
|
}
|
||||||
|
if want_patches:
|
||||||
|
out["patches_512"] = torch.cat(patches_512_list, dim=0)
|
||||||
|
out["patches_1024"] = torch.cat(patches_1024_list, dim=0)
|
||||||
|
out["composites"] = composite_list
|
||||||
|
out["crop_bboxes"] = crop_bbox_list
|
||||||
|
out["scene_sizes"] = scene_size_list
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class Pixal3DConditioning(IO.ComfyNode):
|
class Pixal3DConditioning(IO.ComfyNode):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -867,45 +812,15 @@ class Pixal3DConditioning(IO.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, clip_vision_model, image, mask, camera_angle_x) -> IO.NodeOutput:
|
def execute(cls, clip_vision_model, image, mask, camera_angle_x) -> IO.NodeOutput:
|
||||||
naf_model = clip_vision_model.naf
|
naf_model = clip_vision_model.naf
|
||||||
if image.ndim == 3:
|
out_device = comfy.model_management.intermediate_device()
|
||||||
image = image.unsqueeze(0)
|
compute_device = comfy.model_management.get_torch_device()
|
||||||
if mask.ndim == 2:
|
|
||||||
mask = mask.unsqueeze(0)
|
|
||||||
batch_size = image.shape[0]
|
|
||||||
if mask.shape[0] == 1 and batch_size > 1:
|
|
||||||
mask = mask.expand(batch_size, -1, -1)
|
|
||||||
elif mask.shape[0] != batch_size:
|
|
||||||
raise ValueError(f"Pixal3DConditioning mask batch {mask.shape[0]} != image batch {batch_size}")
|
|
||||||
|
|
||||||
device = comfy.model_management.intermediate_device()
|
cond = _dino_condition_batch(clip_vision_model, image, mask, out_device, pad_factor=1.1, want_patches=True)
|
||||||
|
batch_size = cond["batch_size"]
|
||||||
cond_512_list, cond_1024_list = [], []
|
global_512, global_1024 = cond["global_512"], cond["global_1024"]
|
||||||
patches_512_list, patches_1024_list = [], []
|
fm_512_dino, fm_1024_dino = cond["patches_512"], cond["patches_1024"]
|
||||||
composite_list = []
|
composite_list = cond["composites"]
|
||||||
crop_bbox_list, scene_size_list = [], []
|
crop_bbox_list, scene_size_list = cond["crop_bboxes"], cond["scene_sizes"]
|
||||||
|
|
||||||
torch_device = comfy.model_management.get_torch_device()
|
|
||||||
for b in range(batch_size):
|
|
||||||
item_image = image[b]
|
|
||||||
item_mask = mask[b] if mask.size(0) > 1 else mask[0]
|
|
||||||
composite, crop_bbox, scene_size = _crop_image_with_mask(
|
|
||||||
item_image, item_mask, max_image_size=1024)
|
|
||||||
crop_bbox_list.append(crop_bbox)
|
|
||||||
scene_size_list.append(scene_size)
|
|
||||||
composite_list.append(composite)
|
|
||||||
|
|
||||||
cond_512 = _dinov3_encode(clip_vision_model, composite, 512, want_patches=True)
|
|
||||||
cond_1024 = _dinov3_encode(clip_vision_model, composite, 1024, want_patches=True)
|
|
||||||
cond_512_list.append(cond_512["tokens"].to(device))
|
|
||||||
cond_1024_list.append(cond_1024["tokens"].to(device))
|
|
||||||
patches_512_list.append(cond_512["patches_2d"].to(device))
|
|
||||||
patches_1024_list.append(cond_1024["patches_2d"].to(device))
|
|
||||||
|
|
||||||
global_512 = torch.cat(cond_512_list, dim=0)
|
|
||||||
global_1024 = torch.cat(cond_1024_list, dim=0)
|
|
||||||
|
|
||||||
fm_512_dino = torch.cat(patches_512_list, dim=0)
|
|
||||||
fm_1024_dino = torch.cat(patches_1024_list, dim=0)
|
|
||||||
|
|
||||||
# The LR DINO grid AND the NAF HR grid are sampled separately
|
# The LR DINO grid AND the NAF HR grid are sampled separately
|
||||||
# NAF targets per stage: shape_512=512, shape_1024=512, tex_1024=1024.
|
# NAF targets per stage: shape_512=512, shape_1024=512, tex_1024=1024.
|
||||||
@ -914,15 +829,13 @@ class Pixal3DConditioning(IO.ComfyNode):
|
|||||||
return None
|
return None
|
||||||
comfy.model_management.load_model_gpu(naf_model)
|
comfy.model_management.load_model_gpu(naf_model)
|
||||||
inner = naf_model.model
|
inner = naf_model.model
|
||||||
target_dtype = comfy.model_management.text_encoder_dtype(torch_device)
|
model_dtype = next(inner.parameters()).dtype # set at load time (see clip_vision NAF)
|
||||||
if next(inner.parameters()).dtype != target_dtype:
|
|
||||||
inner.to(dtype=target_dtype)
|
|
||||||
hrs = []
|
hrs = []
|
||||||
for i, c in enumerate(composites):
|
for i, c in enumerate(composites):
|
||||||
img_i = comfy.utils.common_upscale(c, image_size, image_size, "lanczos", "disabled")\
|
img_i = comfy.utils.common_upscale(c, image_size, image_size, "lanczos", "disabled")\
|
||||||
.to(torch_device).to(target_dtype)
|
.to(compute_device).to(model_dtype)
|
||||||
lr_i = lr_feat[i:i + 1].to(torch_device).to(target_dtype)
|
lr_i = lr_feat[i:i + 1].to(compute_device).to(model_dtype)
|
||||||
hr_i = inner(img_i, lr_i, naf_target, output_device=device)
|
hr_i = inner(img_i, lr_i, naf_target, output_device=out_device)
|
||||||
hrs.append(hr_i)
|
hrs.append(hr_i)
|
||||||
return torch.cat(hrs, dim=0)
|
return torch.cat(hrs, dim=0)
|
||||||
|
|
||||||
@ -934,10 +847,10 @@ class Pixal3DConditioning(IO.ComfyNode):
|
|||||||
# FOV widget is in degrees for UX; trig + downstream projection expect radians.
|
# FOV widget is in degrees for UX; trig + downstream projection expect radians.
|
||||||
camera_angle_x = math.radians(float(camera_angle_x))
|
camera_angle_x = math.radians(float(camera_angle_x))
|
||||||
distance = 0.5 / math.tan(camera_angle_x / 2.0)
|
distance = 0.5 / math.tan(camera_angle_x / 2.0)
|
||||||
cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32)
|
cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=out_device, dtype=torch.float32)
|
||||||
dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32)
|
dist_t = torch.tensor([distance] * batch_size, device=out_device, dtype=torch.float32)
|
||||||
scale_t = torch.ones(batch_size, device=device, dtype=torch.float32)
|
scale_t = torch.ones(batch_size, device=out_device, dtype=torch.float32)
|
||||||
T = build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32)
|
T = build_proj_transform_matrix(dist_t, batch_size, device=out_device, dtype=torch.float32)
|
||||||
|
|
||||||
proj_pack = {
|
proj_pack = {
|
||||||
"stages": {
|
"stages": {
|
||||||
@ -958,7 +871,7 @@ class Pixal3DConditioning(IO.ComfyNode):
|
|||||||
# global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024.
|
# global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024.
|
||||||
ss_proj_feats = compute_stage_proj_feats(
|
ss_proj_feats = compute_stage_proj_feats(
|
||||||
proj_pack, "ss", dense_grid_resolution=16, batch_size=batch_size,
|
proj_pack, "ss", dense_grid_resolution=16, batch_size=batch_size,
|
||||||
device=torch_device,
|
device=compute_device,
|
||||||
)
|
)
|
||||||
neg_global = torch.zeros_like(global_512)
|
neg_global = torch.zeros_like(global_512)
|
||||||
neg_embeds = torch.zeros_like(global_1024)
|
neg_embeds = torch.zeros_like(global_1024)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user