mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Miscellaneous cleanup
This commit is contained in:
parent
419e726061
commit
d635cc412d
@ -163,6 +163,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
naf_sd = {k[len("naf."):]: sd.pop(k) for k in naf_keys}
|
||||
naf = NAF().eval()
|
||||
naf.load_state_dict(naf_sd, strict=False)
|
||||
naf.to(comfy.model_management.text_encoder_dtype(clip.load_device))
|
||||
clip.naf = comfy.model_patcher.CoreModelPatcher(naf, load_device=clip.load_device, offload_device=comfy.model_management.text_encoder_offload_device())
|
||||
return clip
|
||||
|
||||
|
||||
@ -55,10 +55,12 @@ class SparseFeedForwardNet(nn.Module):
|
||||
def forward(self, x: VarLenTensor) -> VarLenTensor:
|
||||
return self.mlp(x)
|
||||
|
||||
class SparseMultiHeadRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, heads: int, device, dtype):
|
||||
class MultiHeadRMSNorm(nn.Module):
|
||||
# Per-head qk-norm for both sparse (VarLenTensor) and dense inputs. gamma is [heads, dim]
|
||||
# (per-head), so it's a broadcast multiply rather than F.rms_norm's 1-D weight
|
||||
def __init__(self, dim: int, heads: int, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
||||
self.gamma = nn.Parameter(torch.empty(heads, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
||||
if isinstance(x, VarLenTensor):
|
||||
@ -147,8 +149,8 @@ class SparseMultiHeadAttention(nn.Module):
|
||||
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
if self.qk_rms_norm:
|
||||
self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
|
||||
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
|
||||
|
||||
@ -307,7 +309,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||
self.modulation = nn.Parameter(torch.empty(6 * channels, device=device, dtype=dtype))
|
||||
|
||||
def _forward(self, x: SparseTensor, mod: torch.Tensor, context, transformer_options=None) -> SparseTensor:
|
||||
if self.share_mod:
|
||||
@ -444,24 +446,6 @@ class FeedForwardNet(nn.Module):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.mlp(x)
|
||||
|
||||
# class MultiHeadRMSNorm(nn.Module):
|
||||
# def __init__(self, dim: int, heads: int, device=None, dtype=None):
|
||||
# super().__init__()
|
||||
# self.scale = dim ** 0.5
|
||||
# self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
||||
|
||||
# def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
|
||||
|
||||
class MultiHeadRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, heads: int, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (F.rms_norm(x.float(), (x.shape[-1],)) * self.gamma).to(x.dtype)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -580,7 +564,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
|
||||
else:
|
||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||
self.modulation = nn.Parameter(torch.empty(6 * channels, device=device, dtype=dtype))
|
||||
|
||||
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context,
|
||||
phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor:
|
||||
@ -631,7 +615,7 @@ class SparseStructureFlowModel(nn.Module):
|
||||
proj_in_channels: Optional[int] = None,
|
||||
operations=None,
|
||||
device = None,
|
||||
dtype = torch.float32,
|
||||
dtype = None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -918,17 +918,14 @@ def flexible_dual_grid_to_mesh(
|
||||
):
|
||||
|
||||
device = coords.device
|
||||
if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset") \
|
||||
or flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset.device != device:
|
||||
flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([
|
||||
[[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis
|
||||
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis
|
||||
[[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis
|
||||
], dtype=torch.int, device=device).unsqueeze(0)
|
||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1") or flexible_dual_grid_to_mesh.quad_split_1.device != device:
|
||||
flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
|
||||
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2") or flexible_dual_grid_to_mesh.quad_split_2.device != device:
|
||||
flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
|
||||
# Small constant index tables — built per call (stateless), not cached on the function.
|
||||
edge_neighbor_voxel_offset = torch.tensor([
|
||||
[[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis
|
||||
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis
|
||||
[[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis
|
||||
], dtype=torch.int, device=device).unsqueeze(0)
|
||||
quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device)
|
||||
quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device)
|
||||
|
||||
aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
|
||||
|
||||
@ -954,7 +951,7 @@ def flexible_dual_grid_to_mesh(
|
||||
|
||||
# Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3]
|
||||
n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,)
|
||||
offsets_per_axis = flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset[0] # (3, 4, 3)
|
||||
offsets_per_axis = edge_neighbor_voxel_offset[0] # (3, 4, 3)
|
||||
connected_voxel = coords[n_idx].unsqueeze(1) + offsets_per_axis[axis_idx] # (M, 4, 3)
|
||||
M = connected_voxel.shape[0]
|
||||
# flatten connected voxel coords and lookup. In-place to avoid extra memory allocation.
|
||||
@ -973,12 +970,12 @@ def flexible_dual_grid_to_mesh(
|
||||
mesh_vertices.add_(dual_vertices).mul_(voxel_size).add_(aabb[0].reshape(1, 3))
|
||||
if split_weight is None:
|
||||
# if split 1
|
||||
atempt_triangles_0 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1]
|
||||
atempt_triangles_0 = quad_indices[:, quad_split_1]
|
||||
normals0 = torch.cross(mesh_vertices[atempt_triangles_0[:, 1]] - mesh_vertices[atempt_triangles_0[:, 0]], mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 0]])
|
||||
normals1 = torch.cross(mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 1]], mesh_vertices[atempt_triangles_0[:, 3]] - mesh_vertices[atempt_triangles_0[:, 1]])
|
||||
align0 = (normals0 * normals1).sum(dim=1, keepdim=True).abs()
|
||||
# if split 2
|
||||
atempt_triangles_1 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2]
|
||||
atempt_triangles_1 = quad_indices[:, quad_split_2]
|
||||
normals0 = torch.cross(mesh_vertices[atempt_triangles_1[:, 1]] - mesh_vertices[atempt_triangles_1[:, 0]], mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 0]])
|
||||
normals1 = torch.cross(mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 1]], mesh_vertices[atempt_triangles_1[:, 3]] - mesh_vertices[atempt_triangles_1[:, 1]])
|
||||
align1 = (normals0 * normals1).sum(dim=1, keepdim=True).abs()
|
||||
@ -990,8 +987,8 @@ def flexible_dual_grid_to_mesh(
|
||||
split_weight_ws_13 = split_weight_ws[:, 1] * split_weight_ws[:, 3]
|
||||
mesh_triangles = torch.where(
|
||||
split_weight_ws_02 > split_weight_ws_13,
|
||||
quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1],
|
||||
quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2]
|
||||
quad_indices[:, quad_split_1],
|
||||
quad_indices[:, quad_split_2]
|
||||
).reshape(-1, 3)
|
||||
|
||||
return mesh_vertices, mesh_triangles
|
||||
|
||||
@ -39,7 +39,12 @@ class MESH:
|
||||
vertex_counts: torch.Tensor | None = None,
|
||||
face_counts: torch.Tensor | None = None,
|
||||
unlit: bool = False,
|
||||
normals: torch.Tensor | None = None):
|
||||
normals: torch.Tensor | None = None,
|
||||
tangents: torch.Tensor | None = None,
|
||||
normal_map: torch.Tensor | None = None,
|
||||
occlusion_in_mr: bool = False,
|
||||
material: dict | None = None,
|
||||
emissive: torch.Tensor | None = None):
|
||||
|
||||
assert (vertex_counts is None) == (face_counts is None), \
|
||||
"vertex_counts and face_counts must be provided together (both or neither)"
|
||||
@ -59,6 +64,13 @@ class MESH:
|
||||
self.face_counts = face_counts
|
||||
# Render flat / emissive (no scene lighting) when saved, e.g. for gaussian-splat-derived meshes.
|
||||
self.unlit = unlit
|
||||
# Extra maps / material overrides attached by bake, normal/AO, and SetMeshMaterial nodes;
|
||||
# consumed by SaveGLB. Declared here (with defaults) so consumers read them directly.
|
||||
self.tangents = tangents # (B, N, 4) per-vertex tangents for normal mapping
|
||||
self.normal_map = normal_map # tangent-space normal map: (B, H, W, 3)
|
||||
self.occlusion_in_mr = occlusion_in_mr # True = R channel of metallic_roughness holds AO (ORM)
|
||||
self.material = material # SetMeshMaterial scalar/factor overrides
|
||||
self.emissive = emissive # emissive map: (B, H, W, 3)
|
||||
|
||||
|
||||
class File3D:
|
||||
|
||||
@ -1282,25 +1282,17 @@ def qem_simplify(
|
||||
iteration = 0
|
||||
total_collapses = 0
|
||||
|
||||
# progress bars (tqdm + optional comfy ProgressBar), best-effort
|
||||
# progress bars (tqdm + comfy ProgressBar)
|
||||
_start_faces = num_faces
|
||||
_prog_total = max(1, _start_faces - int(target_faces))
|
||||
try:
|
||||
_qtq = _tqdm(total=100, desc="QEM simplify", leave=False)
|
||||
except Exception:
|
||||
_qtq = None
|
||||
try:
|
||||
_qpbar = _comfy_utils.ProgressBar(100)
|
||||
except Exception:
|
||||
_qpbar = None
|
||||
_qtq = _tqdm(total=100, desc="QEM simplify", leave=False)
|
||||
_qpbar = _comfy_utils.ProgressBar(100)
|
||||
|
||||
def _qreport():
|
||||
pct = min(100, max(0, int(100 * (_start_faces - py_n_faces) / _prog_total)))
|
||||
if _qtq is not None:
|
||||
_qtq.n = pct
|
||||
_qtq.refresh()
|
||||
if _qpbar is not None:
|
||||
_qpbar.update_absolute(pct, 100)
|
||||
_qtq.n = pct
|
||||
_qtq.refresh()
|
||||
_qpbar.update_absolute(pct, 100)
|
||||
|
||||
while True:
|
||||
if py_n_faces <= target_faces:
|
||||
@ -1523,8 +1515,7 @@ def qem_simplify(
|
||||
break
|
||||
|
||||
_qreport()
|
||||
if _qtq is not None:
|
||||
_qtq.close()
|
||||
_qtq.close()
|
||||
|
||||
# finalize: compact verts and faces
|
||||
final_v = verts[v_alive]
|
||||
|
||||
@ -522,7 +522,6 @@ def _build_mdc_lut() -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return K, group
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=None)
def _mdc_lut(device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the two tensors produced by ``_build_mdc_lut()``, moved to ``device``.

    The result is memoized per ``device`` argument by an unbounded ``lru_cache``,
    so repeated calls with the same device reuse the already-transferred pair.
    """
    table, group = _build_mdc_lut()
    on_device = (table.to(device), group.to(device))
    return on_device
|
||||
@ -967,15 +966,11 @@ def remesh_narrow_band_dc(
|
||||
n_levels += 1
|
||||
_total_ticks = n_levels + 3 + int(smooth_iters)
|
||||
_pbar = comfy.utils.ProgressBar(_total_ticks)
|
||||
try:
|
||||
_tq = _tqdm(total=_total_ticks, desc="Remesh DC", leave=False)
|
||||
except Exception:
|
||||
_tq = None
|
||||
_tq = _tqdm(total=_total_ticks, desc="Remesh DC", leave=False)
|
||||
|
||||
def tick():
|
||||
_pbar.update(1)
|
||||
if _tq is not None:
|
||||
_tq.update(1)
|
||||
_tq.update(1)
|
||||
|
||||
# Step 1: sparse narrow-band voxel grid (coarse-to-fine)
|
||||
voxel_coords, _band_tree = _build_narrow_band_voxels(
|
||||
|
||||
@ -366,8 +366,8 @@ def lscm_chart(
|
||||
x_free = solve_least_squares(A_free, b)
|
||||
if not np.all(np.isfinite(x_free)):
|
||||
fallback_to_ortho = True
|
||||
except Exception:
|
||||
fallback_to_ortho = True
|
||||
except (sp.linalg.MatrixRankWarning, RuntimeError):
|
||||
fallback_to_ortho = True # singular / under-constrained system
|
||||
|
||||
if fallback_to_ortho:
|
||||
if pin_positions is not None and pin_positions.shape == (Vc, 2):
|
||||
|
||||
@ -74,31 +74,22 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non
|
||||
)
|
||||
packed_tangents[i, :tn.shape[0]] = tn
|
||||
|
||||
out = Types.MESH(packed_vertices, packed_faces,
|
||||
uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
|
||||
metallic_roughness=metallic_roughness,
|
||||
vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit,
|
||||
normals=packed_normals)
|
||||
if packed_tangents is not None:
|
||||
out.tangents = packed_tangents
|
||||
if normal_map is not None:
|
||||
out.normal_map = normal_map
|
||||
if occlusion_in_mr:
|
||||
out.occlusion_in_mr = True
|
||||
if material is not None:
|
||||
out.material = material
|
||||
if emissive is not None:
|
||||
out.emissive = emissive
|
||||
return out
|
||||
return Types.MESH(packed_vertices, packed_faces,
|
||||
uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
|
||||
metallic_roughness=metallic_roughness,
|
||||
vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit,
|
||||
normals=packed_normals, tangents=packed_tangents,
|
||||
normal_map=normal_map, occlusion_in_mr=occlusion_in_mr,
|
||||
material=material, emissive=emissive)
|
||||
|
||||
|
||||
def get_mesh_batch_item(mesh, index):
|
||||
# Returns (vertices, faces, colors, uvs) for batch index, slicing to real lengths
|
||||
# if the mesh carries per-item counts (variable-size batch).
|
||||
v_colors = getattr(mesh, "vertex_colors", None)
|
||||
v_uvs = getattr(mesh, "uvs", None)
|
||||
v_normals = getattr(mesh, "normals", None)
|
||||
if getattr(mesh, "vertex_counts", None) is not None:
|
||||
v_colors = mesh.vertex_colors
|
||||
v_uvs = mesh.uvs
|
||||
v_normals = mesh.normals
|
||||
if mesh.vertex_counts is not None:
|
||||
vertex_count = int(mesh.vertex_counts[index].item())
|
||||
face_count = int(mesh.face_counts[index].item())
|
||||
vertices = mesh.vertices[index, :vertex_count]
|
||||
@ -482,6 +473,7 @@ def save_glb(vertices, faces, filepath=None, metadata=None,
|
||||
|
||||
if (texture_png_bytes is not None and has_uv) or "COLOR_0" in primitive_attributes:
|
||||
pbr["baseColorFactor"] = [1.0, 1.0, 1.0, 1.0]
|
||||
pbr["roughnessFactor"] = 1.0
|
||||
|
||||
if mr_png_bytes is not None and has_uv:
|
||||
mr_texture_index = add_image_texture(mr_byte_offset, len(mr_buffer))
|
||||
@ -599,7 +591,7 @@ def mesh_item_to_glb_bytes(mesh, index, metadata=None):
|
||||
assert a.ndim == 3 and a.shape[-1] == 3, f"{attr} must be (B, H, W, 3), got {tuple(t.shape)}"
|
||||
return Image.fromarray(a, mode="RGB")
|
||||
|
||||
tangents_b = getattr(mesh, "tangents", None)
|
||||
tangents_b = mesh.tangents
|
||||
tangents_i = tangents_b[index, :vertices_i.shape[0]] if tangents_b is not None else None
|
||||
return save_glb(
|
||||
vertices_i, faces_i, None, metadata,
|
||||
@ -607,12 +599,12 @@ def mesh_item_to_glb_bytes(mesh, index, metadata=None):
|
||||
vertex_colors=v_colors,
|
||||
texture_image=_img("texture"),
|
||||
metallic_roughness_image=_img("metallic_roughness"),
|
||||
unlit=getattr(mesh, "unlit", False),
|
||||
unlit=mesh.unlit,
|
||||
normals=normals_i,
|
||||
normal_map_image=_img("normal_map"),
|
||||
tangents=tangents_i,
|
||||
occlusion_in_mr=getattr(mesh, "occlusion_in_mr", False),
|
||||
material=getattr(mesh, "material", None),
|
||||
occlusion_in_mr=mesh.occlusion_in_mr,
|
||||
material=mesh.material,
|
||||
emissive_image=_img("emissive"),
|
||||
)
|
||||
|
||||
@ -810,7 +802,7 @@ class RotateMesh(IO.ComfyNode):
|
||||
else:
|
||||
out.vertices = rotate(mesh.vertices)
|
||||
# Normals are directions; rotate them too (R is orthogonal) so they stay valid.
|
||||
nrm = getattr(mesh, "normals", None)
|
||||
nrm = mesh.normals
|
||||
if nrm is not None:
|
||||
out.normals = [rotate(n) for n in nrm] if isinstance(nrm, list) else rotate(nrm)
|
||||
return IO.NodeOutput(out)
|
||||
@ -860,7 +852,7 @@ class MeshSmoothNormals(IO.ComfyNode):
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
# Crease split changes per-item vertex counts -> rebuild as a variable-size batch.
|
||||
tangents_b = getattr(mesh, "tangents", None)
|
||||
tangents_b = mesh.tangents
|
||||
v_list, f_list, n_list = [], [], []
|
||||
c_list = [] if mesh.vertex_colors is not None else None
|
||||
u_list = [] if mesh.uvs is not None else None
|
||||
@ -889,11 +881,11 @@ class MeshSmoothNormals(IO.ComfyNode):
|
||||
return IO.NodeOutput(mesh)
|
||||
out = pack_variable_mesh_batch(
|
||||
v_list, f_list, colors=c_list, uvs=u_list,
|
||||
texture=mesh.texture, unlit=getattr(mesh, "unlit", False),
|
||||
normals=n_list, metallic_roughness=getattr(mesh, "metallic_roughness", None),
|
||||
tangents=t_list, normal_map=getattr(mesh, "normal_map", None),
|
||||
occlusion_in_mr=getattr(mesh, "occlusion_in_mr", False),
|
||||
material=getattr(mesh, "material", None), emissive=getattr(mesh, "emissive", None))
|
||||
texture=mesh.texture, unlit=mesh.unlit,
|
||||
normals=n_list, metallic_roughness=mesh.metallic_roughness,
|
||||
tangents=t_list, normal_map=mesh.normal_map,
|
||||
occlusion_in_mr=mesh.occlusion_in_mr,
|
||||
material=mesh.material, emissive=mesh.emissive)
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
|
||||
|
||||
@ -8,9 +8,7 @@ from server import PromptServer
|
||||
import comfy.latent_formats
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
from PIL import Image
|
||||
import logging
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
|
||||
@ -425,7 +423,6 @@ def _dinov3_encode(model, image_bchw, image_size, want_patches=False):
|
||||
mean = torch.tensor(model.image_mean or [0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
|
||||
std = torch.tensor(model.image_std or [0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
|
||||
img_t = (img_t - mean) / std
|
||||
model_internal.image_size = image_size
|
||||
tokens = model_internal(img_t, skip_norm_elementwise=True)[0]
|
||||
if not want_patches:
|
||||
return tokens
|
||||
@ -434,20 +431,6 @@ def _dinov3_encode(model, image_bchw, image_size, want_patches=False):
|
||||
return {"tokens": tokens[:, :1 + n_reg], "patches_2d": _dinov3_patches_to_2d(tokens, image_size)}
|
||||
|
||||
|
||||
def run_conditioning(model, cropped_pil_img):
|
||||
device = comfy.model_management.intermediate_device()
|
||||
|
||||
img_np = np.array(cropped_pil_img).astype(np.float32) / 255.0
|
||||
image_bchw = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).contiguous()
|
||||
|
||||
cond_512 = _dinov3_encode(model, image_bchw, 512)
|
||||
cond_1024 = _dinov3_encode(model, image_bchw, 1024)
|
||||
return {
|
||||
"cond_512": cond_512.to(device),
|
||||
"neg_cond": torch.zeros_like(cond_512).to(device),
|
||||
"cond_1024": cond_1024.to(device),
|
||||
}
|
||||
|
||||
class Trellis2Conditioning(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -467,124 +450,10 @@ class Trellis2Conditioning(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip_vision_model, image, mask) -> IO.NodeOutput:
|
||||
# Normalize to batched form so per-image conditioning loop below is uniform.
|
||||
if image.ndim == 3:
|
||||
image = image.unsqueeze(0)
|
||||
elif image.ndim == 4:
|
||||
if image.shape[1] in [1, 3, 4] and image.shape[-1] not in [1, 3, 4]:
|
||||
image = image.permute(0, 2, 3, 1)
|
||||
|
||||
# normalize mask to standard [B, H, W] (handling 2D, 3D, and 4D variants)
|
||||
if mask.ndim == 4:
|
||||
if mask.shape[1] == 1:
|
||||
mask = mask.squeeze(1)
|
||||
elif mask.shape[-1] == 1:
|
||||
mask = mask.squeeze(-1)
|
||||
else:
|
||||
mask = mask[:, :, :, 0] # take first channel as fallback
|
||||
|
||||
if mask.ndim == 3:
|
||||
if mask.shape[-1] == 1:
|
||||
mask = mask.squeeze(-1).unsqueeze(0)
|
||||
elif mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
|
||||
batch_size = image.shape[0]
|
||||
if mask.shape[0] == 1 and batch_size > 1:
|
||||
mask = mask.expand(batch_size, -1, -1)
|
||||
elif mask.shape[0] != batch_size:
|
||||
raise ValueError(f"Trellis2Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}")
|
||||
|
||||
cond_512_list = []
|
||||
cond_1024_list = []
|
||||
|
||||
for b in range(batch_size):
|
||||
item_image = image[b]
|
||||
item_mask = mask[b] if mask.size(0) > 1 else mask[0]
|
||||
|
||||
img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
|
||||
# Ensure img_np is either 2D (grayscale) or 3D (RGB/RGBA)
|
||||
if img_np.ndim == 3 and img_np.shape[-1] == 1:
|
||||
img_np = img_np.squeeze(-1)
|
||||
|
||||
mask_np = mask_np.squeeze()
|
||||
|
||||
# detect inverted mask
|
||||
border_pixels = np.concatenate([
|
||||
mask_np[0, :], mask_np[-1, :], mask_np[:, 0], mask_np[:, -1]
|
||||
])
|
||||
if np.mean(border_pixels) > 127:
|
||||
mask_np = 255 - mask_np
|
||||
|
||||
mask_np[mask_np < 35] = 0
|
||||
|
||||
border_shave = 4
|
||||
mask_np[:border_shave, :] = 0
|
||||
mask_np[-border_shave:, :] = 0
|
||||
mask_np[:, :border_shave] = 0
|
||||
mask_np[:, -border_shave:] = 0
|
||||
|
||||
pil_img = Image.fromarray(img_np)
|
||||
pil_mask = Image.fromarray(mask_np)
|
||||
|
||||
max_size = max(pil_img.size)
|
||||
scale = min(1.0, 1024 / max_size)
|
||||
if scale < 1.0:
|
||||
new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
|
||||
pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
|
||||
|
||||
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
|
||||
rgba_np[:, :, :3] = np.array(pil_img.convert("RGB"))
|
||||
rgba_np[:, :, 3] = np.array(pil_mask)
|
||||
|
||||
alpha = rgba_np[:, :, 3]
|
||||
bbox_coords = np.argwhere(alpha > 0.8 * 255)
|
||||
|
||||
if len(bbox_coords) > 0:
|
||||
y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
|
||||
y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
|
||||
|
||||
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
|
||||
size = max(y_max - y_min, x_max - x_min)
|
||||
|
||||
crop_x1 = int(center_x - size // 2)
|
||||
crop_y1 = int(center_y - size // 2)
|
||||
crop_x2 = int(center_x + size // 2)
|
||||
crop_y2 = int(center_y + size // 2)
|
||||
|
||||
rgba_pil = Image.fromarray(rgba_np)
|
||||
cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
|
||||
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
|
||||
else:
|
||||
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
|
||||
cropped_np = rgba_np.astype(np.float32) / 255.0
|
||||
|
||||
bg_rgb = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||
|
||||
fg = cropped_np[:, :, :3]
|
||||
alpha_float = cropped_np[:, :, 3:4]
|
||||
composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
|
||||
|
||||
# Keep the image as 4-channel RGBA to force TRELLIS to bypass its internal background remover
|
||||
rgb_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
|
||||
alpha_uint8 = (alpha_float.squeeze(-1) * 255.0).round().clip(0, 255).astype(np.uint8)
|
||||
|
||||
rgba_composite = np.zeros((cropped_np.shape[0], cropped_np.shape[1], 4), dtype=np.uint8)
|
||||
rgba_composite[:, :, :3] = rgb_uint8
|
||||
rgba_composite[:, :, 3] = alpha_uint8
|
||||
|
||||
cropped_pil = Image.fromarray(rgba_composite, mode="RGBA")
|
||||
|
||||
# Convert to RGB to ensure the CLIP/DINO model receives a 3-channel image
|
||||
item_conditioning = run_conditioning(clip_vision_model, cropped_pil.convert("RGB"))
|
||||
cond_512_list.append(item_conditioning["cond_512"])
|
||||
cond_1024_list.append(item_conditioning["cond_1024"])
|
||||
|
||||
cond_512_batched = torch.cat(cond_512_list, dim=0)
|
||||
cond_1024_batched = torch.cat(cond_1024_list, dim=0)
|
||||
out_device = comfy.model_management.intermediate_device()
|
||||
cond = _dino_condition_batch(clip_vision_model, image, mask, out_device,
|
||||
pad_factor=1.0, mask_threshold=35.0 / 255.0, border_shave=4)
|
||||
cond_512_batched, cond_1024_batched = cond["global_512"], cond["global_1024"]
|
||||
neg_cond_batched = torch.zeros_like(cond_512_batched)
|
||||
neg_embeds_batched = torch.zeros_like(cond_1024_batched)
|
||||
|
||||
@ -781,12 +650,11 @@ def _dinov3_patches_to_2d(tokens, image_size, patch_size=16):
|
||||
return patches.transpose(1, 2).reshape(tokens.shape[0], -1, h_p, w_p).contiguous()
|
||||
|
||||
|
||||
def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
|
||||
img = item_image.permute(2, 0, 1).unsqueeze(0).cpu().float()
|
||||
mask = item_mask.unsqueeze(0).unsqueeze(0).cpu().float()
|
||||
# Upstream went float→PIL uint8 implicitly; match that to keep composite bit-exact.
|
||||
img = (img.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
|
||||
mask = (mask.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
|
||||
def _crop_image_with_mask(item_image, item_mask, max_image_size=1024, pad_factor=1.1,
|
||||
mask_threshold=0.0, border_shave=0):
|
||||
img = item_image[..., :3] if item_image.shape[-1] >= 3 else item_image[..., :1].repeat(1, 1, 3)
|
||||
img = img.permute(2, 0, 1).unsqueeze(0).cpu().float().clamp(0, 1)
|
||||
mask = item_mask.unsqueeze(0).unsqueeze(0).cpu().float().clamp(0, 1)
|
||||
|
||||
# Detect & correct an inverted mask
|
||||
m2d = mask[0, 0]
|
||||
@ -794,6 +662,15 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
|
||||
if float(border.mean()) > 0.5:
|
||||
mask = 1.0 - mask
|
||||
|
||||
if mask_threshold > 0.0:
|
||||
mask = torch.where(mask < mask_threshold, torch.zeros_like(mask), mask)
|
||||
if border_shave > 0:
|
||||
bs = border_shave
|
||||
mask[..., :bs, :] = 0
|
||||
mask[..., -bs:, :] = 0
|
||||
mask[..., :, :bs] = 0
|
||||
mask[..., :, -bs:] = 0
|
||||
|
||||
H, W = img.shape[-2:]
|
||||
if max(H, W) > max_image_size:
|
||||
scale = max_image_size / max(H, W)
|
||||
@ -809,14 +686,14 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
|
||||
y_min, x_min = fg_pixels.min(dim=0).values.tolist()
|
||||
y_max, x_max = fg_pixels.max(dim=0).values.tolist()
|
||||
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
|
||||
size = int(max(y_max - y_min, x_max - x_min) * 1.1)
|
||||
size = int(max(y_max - y_min, x_max - x_min) * pad_factor)
|
||||
half = size // 2
|
||||
crop_x1 = int(center_x - half)
|
||||
crop_y1 = int(center_y - half)
|
||||
crop_x2 = crop_x1 + 2 * half
|
||||
crop_y2 = crop_y1 + 2 * half
|
||||
else:
|
||||
logging.warning("Mask for the image is empty. Pixal3D requires a clean foreground mask.")
|
||||
logging.warning("Mask for the image is empty; a clean foreground mask is required for best quality.")
|
||||
crop_x1, crop_y1, crop_x2, crop_y2 = 0, 0, W, H
|
||||
crop_bbox = (crop_x1, crop_y1, crop_x2, crop_y2)
|
||||
|
||||
@ -836,9 +713,77 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
|
||||
cropped_mask = mask[..., crop_y1:crop_y2, crop_x1:crop_x2]
|
||||
|
||||
composite = (cropped_img * cropped_mask).clamp(0, 1)
|
||||
composite = (composite * 255.0).round().clamp(0, 255).to(torch.uint8).float() / 255.0
|
||||
return composite, crop_bbox, scene_size
|
||||
|
||||
|
||||
def _dino_condition_batch(clip_vision_model, image, mask, out_device, *,
                          pad_factor, mask_threshold=0.0, border_shave=0, want_patches=False):
    """Crop every image to its masked foreground and DINOv3-encode it at 512 and 1024.

    The image and mask are first normalized to batched form:

    * image: a 3-D tensor gets a leading batch dim; a 4-D tensor whose dim 1 is
      1/3/4 while its last dim is not is permuted to channels-last.
    * mask: a 4-D tensor loses its singleton channel dim (dim 1 or the last dim),
      otherwise index 0 of the last dim is kept; a 3-D tensor with a trailing
      dim of size 1 is treated as one mask; a 2-D tensor gets a leading batch dim.
    * a single mask is expanded across the image batch.

    Each item is then passed through ``_crop_image_with_mask`` (with
    ``max_image_size=1024`` and the given ``pad_factor`` / ``mask_threshold`` /
    ``border_shave``) and through ``_dinov3_encode`` at sizes 512 and 1024.

    Returns:
        dict with ``"batch_size"``, ``"global_512"`` and ``"global_1024"`` (per-item
        results concatenated on dim 0, on ``out_device``). When ``want_patches`` is
        true it also holds ``"patches_512"`` / ``"patches_1024"`` (concatenated
        ``"patches_2d"`` outputs) and the per-item lists ``"composites"``,
        ``"crop_bboxes"`` and ``"scene_sizes"``.

    Raises:
        ValueError: if the mask batch is neither 1 nor equal to the image batch.
    """
    # Image -> (B, H, W, C).
    if image.ndim == 3:
        image = image.unsqueeze(0)
    elif image.ndim == 4 and image.shape[1] in (1, 3, 4) and image.shape[-1] not in (1, 3, 4):
        image = image.permute(0, 2, 3, 1)

    # Mask -> (B, H, W). The 4-D reduction runs first so its result can still be
    # picked up by the 3-D rule below.
    if mask.ndim == 4:
        if mask.shape[1] == 1:
            mask = mask.squeeze(1)
        elif mask.shape[-1] == 1:
            mask = mask.squeeze(-1)
        else:
            mask = mask[..., 0]  # take first channel as fallback
    if mask.ndim == 2:
        mask = mask.unsqueeze(0)
    elif mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask.squeeze(-1).unsqueeze(0)

    batch_size = image.shape[0]
    if mask.shape[0] == 1 and batch_size > 1:
        mask = mask.expand(batch_size, -1, -1)
    elif mask.shape[0] != batch_size:
        raise ValueError(f"Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}")

    tokens_512, tokens_1024 = [], []
    grids_512, grids_1024 = [], []
    composites, bboxes, scene_sizes = [], [], []

    # After the check above the mask batch equals the image batch, so the two
    # can be walked in lockstep.
    for item_image, item_mask in zip(image, mask):
        composite, crop_bbox, scene_size = _crop_image_with_mask(
            item_image, item_mask, max_image_size=1024, pad_factor=pad_factor,
            mask_threshold=mask_threshold, border_shave=border_shave)
        enc_512 = _dinov3_encode(clip_vision_model, composite, 512, want_patches=want_patches)
        enc_1024 = _dinov3_encode(clip_vision_model, composite, 1024, want_patches=want_patches)

        if not want_patches:
            # The encoder returns the token tensor directly in this mode.
            tokens_512.append(enc_512.to(out_device))
            tokens_1024.append(enc_1024.to(out_device))
            continue

        tokens_512.append(enc_512["tokens"].to(out_device))
        tokens_1024.append(enc_1024["tokens"].to(out_device))
        grids_512.append(enc_512["patches_2d"].to(out_device))
        grids_1024.append(enc_1024["patches_2d"].to(out_device))
        composites.append(composite)
        bboxes.append(crop_bbox)
        scene_sizes.append(scene_size)

    result = {
        "batch_size": batch_size,
        "global_512": torch.cat(tokens_512, dim=0),
        "global_1024": torch.cat(tokens_1024, dim=0),
    }
    if want_patches:
        result["patches_512"] = torch.cat(grids_512, dim=0)
        result["patches_1024"] = torch.cat(grids_1024, dim=0)
        result["composites"] = composites
        result["crop_bboxes"] = bboxes
        result["scene_sizes"] = scene_sizes
    return result
|
||||
|
||||
|
||||
class Pixal3DConditioning(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -867,45 +812,15 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, clip_vision_model, image, mask, camera_angle_x) -> IO.NodeOutput:
|
||||
naf_model = clip_vision_model.naf
|
||||
if image.ndim == 3:
|
||||
image = image.unsqueeze(0)
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
batch_size = image.shape[0]
|
||||
if mask.shape[0] == 1 and batch_size > 1:
|
||||
mask = mask.expand(batch_size, -1, -1)
|
||||
elif mask.shape[0] != batch_size:
|
||||
raise ValueError(f"Pixal3DConditioning mask batch {mask.shape[0]} != image batch {batch_size}")
|
||||
out_device = comfy.model_management.intermediate_device()
|
||||
compute_device = comfy.model_management.get_torch_device()
|
||||
|
||||
device = comfy.model_management.intermediate_device()
|
||||
|
||||
cond_512_list, cond_1024_list = [], []
|
||||
patches_512_list, patches_1024_list = [], []
|
||||
composite_list = []
|
||||
crop_bbox_list, scene_size_list = [], []
|
||||
|
||||
torch_device = comfy.model_management.get_torch_device()
|
||||
for b in range(batch_size):
|
||||
item_image = image[b]
|
||||
item_mask = mask[b] if mask.size(0) > 1 else mask[0]
|
||||
composite, crop_bbox, scene_size = _crop_image_with_mask(
|
||||
item_image, item_mask, max_image_size=1024)
|
||||
crop_bbox_list.append(crop_bbox)
|
||||
scene_size_list.append(scene_size)
|
||||
composite_list.append(composite)
|
||||
|
||||
cond_512 = _dinov3_encode(clip_vision_model, composite, 512, want_patches=True)
|
||||
cond_1024 = _dinov3_encode(clip_vision_model, composite, 1024, want_patches=True)
|
||||
cond_512_list.append(cond_512["tokens"].to(device))
|
||||
cond_1024_list.append(cond_1024["tokens"].to(device))
|
||||
patches_512_list.append(cond_512["patches_2d"].to(device))
|
||||
patches_1024_list.append(cond_1024["patches_2d"].to(device))
|
||||
|
||||
global_512 = torch.cat(cond_512_list, dim=0)
|
||||
global_1024 = torch.cat(cond_1024_list, dim=0)
|
||||
|
||||
fm_512_dino = torch.cat(patches_512_list, dim=0)
|
||||
fm_1024_dino = torch.cat(patches_1024_list, dim=0)
|
||||
cond = _dino_condition_batch(clip_vision_model, image, mask, out_device, pad_factor=1.1, want_patches=True)
|
||||
batch_size = cond["batch_size"]
|
||||
global_512, global_1024 = cond["global_512"], cond["global_1024"]
|
||||
fm_512_dino, fm_1024_dino = cond["patches_512"], cond["patches_1024"]
|
||||
composite_list = cond["composites"]
|
||||
crop_bbox_list, scene_size_list = cond["crop_bboxes"], cond["scene_sizes"]
|
||||
|
||||
# The LR DINO grid AND the NAF HR grid are sampled separately
|
||||
# NAF targets per stage: shape_512=512, shape_1024=512, tex_1024=1024.
|
||||
@ -914,15 +829,13 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
return None
|
||||
comfy.model_management.load_model_gpu(naf_model)
|
||||
inner = naf_model.model
|
||||
target_dtype = comfy.model_management.text_encoder_dtype(torch_device)
|
||||
if next(inner.parameters()).dtype != target_dtype:
|
||||
inner.to(dtype=target_dtype)
|
||||
model_dtype = next(inner.parameters()).dtype # set at load time (see clip_vision NAF)
|
||||
hrs = []
|
||||
for i, c in enumerate(composites):
|
||||
img_i = comfy.utils.common_upscale(c, image_size, image_size, "lanczos", "disabled")\
|
||||
.to(torch_device).to(target_dtype)
|
||||
lr_i = lr_feat[i:i + 1].to(torch_device).to(target_dtype)
|
||||
hr_i = inner(img_i, lr_i, naf_target, output_device=device)
|
||||
.to(compute_device).to(model_dtype)
|
||||
lr_i = lr_feat[i:i + 1].to(compute_device).to(model_dtype)
|
||||
hr_i = inner(img_i, lr_i, naf_target, output_device=out_device)
|
||||
hrs.append(hr_i)
|
||||
return torch.cat(hrs, dim=0)
|
||||
|
||||
@ -934,10 +847,10 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
# FOV widget is in degrees for UX; trig + downstream projection expect radians.
|
||||
camera_angle_x = math.radians(float(camera_angle_x))
|
||||
distance = 0.5 / math.tan(camera_angle_x / 2.0)
|
||||
cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32)
|
||||
dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32)
|
||||
scale_t = torch.ones(batch_size, device=device, dtype=torch.float32)
|
||||
T = build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32)
|
||||
cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=out_device, dtype=torch.float32)
|
||||
dist_t = torch.tensor([distance] * batch_size, device=out_device, dtype=torch.float32)
|
||||
scale_t = torch.ones(batch_size, device=out_device, dtype=torch.float32)
|
||||
T = build_proj_transform_matrix(dist_t, batch_size, device=out_device, dtype=torch.float32)
|
||||
|
||||
proj_pack = {
|
||||
"stages": {
|
||||
@ -958,7 +871,7 @@ class Pixal3DConditioning(IO.ComfyNode):
|
||||
# global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024.
|
||||
ss_proj_feats = compute_stage_proj_feats(
|
||||
proj_pack, "ss", dense_grid_resolution=16, batch_size=batch_size,
|
||||
device=torch_device,
|
||||
device=compute_device,
|
||||
)
|
||||
neg_global = torch.zeros_like(global_512)
|
||||
neg_embeds = torch.zeros_like(global_1024)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user