Miscellaneous cleanup

This commit is contained in:
kijai 2026-07-02 16:34:16 +03:00
parent 419e726061
commit d635cc412d
9 changed files with 177 additions and 292 deletions

View File

@ -163,6 +163,7 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
naf_sd = {k[len("naf."):]: sd.pop(k) for k in naf_keys}
naf = NAF().eval()
naf.load_state_dict(naf_sd, strict=False)
naf.to(comfy.model_management.text_encoder_dtype(clip.load_device))
clip.naf = comfy.model_patcher.CoreModelPatcher(naf, load_device=clip.load_device, offload_device=comfy.model_management.text_encoder_offload_device())
return clip

View File

@ -55,10 +55,12 @@ class SparseFeedForwardNet(nn.Module):
def forward(self, x: VarLenTensor) -> VarLenTensor:
return self.mlp(x)
class SparseMultiHeadRMSNorm(nn.Module):
def __init__(self, dim: int, heads: int, device, dtype):
class MultiHeadRMSNorm(nn.Module):
# Per-head qk-norm for both sparse (VarLenTensor) and dense inputs. gamma is [heads, dim]
# (per-head), so it's a broadcast multiply rather than F.rms_norm's 1-D weight
def __init__(self, dim: int, heads: int, device=None, dtype=None):
super().__init__()
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
self.gamma = nn.Parameter(torch.empty(heads, dim, device=device, dtype=dtype))
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
if isinstance(x, VarLenTensor):
@ -147,8 +149,8 @@ class SparseMultiHeadAttention(nn.Module):
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
if self.qk_rms_norm:
self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
@ -307,7 +309,7 @@ class ModulatedSparseTransformerCrossBlock(nn.Module):
operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype)
)
else:
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
self.modulation = nn.Parameter(torch.empty(6 * channels, device=device, dtype=dtype))
def _forward(self, x: SparseTensor, mod: torch.Tensor, context, transformer_options=None) -> SparseTensor:
if self.share_mod:
@ -444,24 +446,6 @@ class FeedForwardNet(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x)
# class MultiHeadRMSNorm(nn.Module):
# def __init__(self, dim: int, heads: int, device=None, dtype=None):
# super().__init__()
# self.scale = dim ** 0.5
# self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
# def forward(self, x: torch.Tensor) -> torch.Tensor:
# return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
class MultiHeadRMSNorm(nn.Module):
def __init__(self, dim: int, heads: int, device=None, dtype=None):
super().__init__()
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return (F.rms_norm(x.float(), (x.shape[-1],)) * self.gamma).to(x.dtype)
class MultiHeadAttention(nn.Module):
def __init__(
self,
@ -580,7 +564,7 @@ class ModulatedTransformerCrossBlock(nn.Module):
if not share_mod:
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
else:
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
self.modulation = nn.Parameter(torch.empty(6 * channels, device=device, dtype=dtype))
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context,
phases: Optional[torch.Tensor] = None, transformer_options=None) -> torch.Tensor:
@ -631,7 +615,7 @@ class SparseStructureFlowModel(nn.Module):
proj_in_channels: Optional[int] = None,
operations=None,
device = None,
dtype = torch.float32,
dtype = None,
**kwargs
):
super().__init__()

View File

@ -918,17 +918,14 @@ def flexible_dual_grid_to_mesh(
):
device = coords.device
if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset") \
or flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset.device != device:
flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([
[[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis
[[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis
], dtype=torch.int, device=device).unsqueeze(0)
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1") or flexible_dual_grid_to_mesh.quad_split_1.device != device:
flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2") or flexible_dual_grid_to_mesh.quad_split_2.device != device:
flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
# Small constant index tables — built per call (stateless), not cached on the function.
edge_neighbor_voxel_offset = torch.tensor([
[[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis
[[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis
[[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis
], dtype=torch.int, device=device).unsqueeze(0)
quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device)
quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device)
aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
@ -954,7 +951,7 @@ def flexible_dual_grid_to_mesh(
# Find connected voxels — direct gather instead of materializing the full [N, 3, 4, 3]
n_idx, axis_idx = intersected_flag.nonzero(as_tuple=True) # (M,), (M,)
offsets_per_axis = flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset[0] # (3, 4, 3)
offsets_per_axis = edge_neighbor_voxel_offset[0] # (3, 4, 3)
connected_voxel = coords[n_idx].unsqueeze(1) + offsets_per_axis[axis_idx] # (M, 4, 3)
M = connected_voxel.shape[0]
# flatten connected voxel coords and lookup. In-place to avoid extra memory allocation.
@ -973,12 +970,12 @@ def flexible_dual_grid_to_mesh(
mesh_vertices.add_(dual_vertices).mul_(voxel_size).add_(aabb[0].reshape(1, 3))
if split_weight is None:
# if split 1
atempt_triangles_0 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1]
atempt_triangles_0 = quad_indices[:, quad_split_1]
normals0 = torch.cross(mesh_vertices[atempt_triangles_0[:, 1]] - mesh_vertices[atempt_triangles_0[:, 0]], mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 0]])
normals1 = torch.cross(mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 1]], mesh_vertices[atempt_triangles_0[:, 3]] - mesh_vertices[atempt_triangles_0[:, 1]])
align0 = (normals0 * normals1).sum(dim=1, keepdim=True).abs()
# if split 2
atempt_triangles_1 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2]
atempt_triangles_1 = quad_indices[:, quad_split_2]
normals0 = torch.cross(mesh_vertices[atempt_triangles_1[:, 1]] - mesh_vertices[atempt_triangles_1[:, 0]], mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 0]])
normals1 = torch.cross(mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 1]], mesh_vertices[atempt_triangles_1[:, 3]] - mesh_vertices[atempt_triangles_1[:, 1]])
align1 = (normals0 * normals1).sum(dim=1, keepdim=True).abs()
@ -990,8 +987,8 @@ def flexible_dual_grid_to_mesh(
split_weight_ws_13 = split_weight_ws[:, 1] * split_weight_ws[:, 3]
mesh_triangles = torch.where(
split_weight_ws_02 > split_weight_ws_13,
quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1],
quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2]
quad_indices[:, quad_split_1],
quad_indices[:, quad_split_2]
).reshape(-1, 3)
return mesh_vertices, mesh_triangles

View File

@ -39,7 +39,12 @@ class MESH:
vertex_counts: torch.Tensor | None = None,
face_counts: torch.Tensor | None = None,
unlit: bool = False,
normals: torch.Tensor | None = None):
normals: torch.Tensor | None = None,
tangents: torch.Tensor | None = None,
normal_map: torch.Tensor | None = None,
occlusion_in_mr: bool = False,
material: dict | None = None,
emissive: torch.Tensor | None = None):
assert (vertex_counts is None) == (face_counts is None), \
"vertex_counts and face_counts must be provided together (both or neither)"
@ -59,6 +64,13 @@ class MESH:
self.face_counts = face_counts
# Render flat / emissive (no scene lighting) when saved, e.g. for gaussian-splat-derived meshes.
self.unlit = unlit
# Extra maps / material overrides attached by bake, normal/AO, and SetMeshMaterial nodes;
# consumed by SaveGLB. Declared here (with defaults) so consumers read them directly.
self.tangents = tangents # (B, N, 4) per-vertex tangents for normal mapping
self.normal_map = normal_map # tangent-space normal map: (B, H, W, 3)
self.occlusion_in_mr = occlusion_in_mr # True = R channel of metallic_roughness holds AO (ORM)
self.material = material # SetMeshMaterial scalar/factor overrides
self.emissive = emissive # emissive map: (B, H, W, 3)
class File3D:

View File

@ -1282,25 +1282,17 @@ def qem_simplify(
iteration = 0
total_collapses = 0
# progress bars (tqdm + optional comfy ProgressBar), best-effort
# progress bars (tqdm + comfy ProgressBar)
_start_faces = num_faces
_prog_total = max(1, _start_faces - int(target_faces))
try:
_qtq = _tqdm(total=100, desc="QEM simplify", leave=False)
except Exception:
_qtq = None
try:
_qpbar = _comfy_utils.ProgressBar(100)
except Exception:
_qpbar = None
_qtq = _tqdm(total=100, desc="QEM simplify", leave=False)
_qpbar = _comfy_utils.ProgressBar(100)
def _qreport():
pct = min(100, max(0, int(100 * (_start_faces - py_n_faces) / _prog_total)))
if _qtq is not None:
_qtq.n = pct
_qtq.refresh()
if _qpbar is not None:
_qpbar.update_absolute(pct, 100)
_qtq.n = pct
_qtq.refresh()
_qpbar.update_absolute(pct, 100)
while True:
if py_n_faces <= target_faces:
@ -1523,8 +1515,7 @@ def qem_simplify(
break
_qreport()
if _qtq is not None:
_qtq.close()
_qtq.close()
# finalize: compact verts and faces
final_v = verts[v_alive]

View File

@ -522,7 +522,6 @@ def _build_mdc_lut() -> Tuple[torch.Tensor, torch.Tensor]:
return K, group
@functools.lru_cache(maxsize=None)
def _mdc_lut(device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the two tensors built by ``_build_mdc_lut``, each moved to ``device``."""
    lut_k, lut_group = _build_mdc_lut()
    lut_k_on_device = lut_k.to(device)
    lut_group_on_device = lut_group.to(device)
    return lut_k_on_device, lut_group_on_device
@ -967,15 +966,11 @@ def remesh_narrow_band_dc(
n_levels += 1
_total_ticks = n_levels + 3 + int(smooth_iters)
_pbar = comfy.utils.ProgressBar(_total_ticks)
try:
_tq = _tqdm(total=_total_ticks, desc="Remesh DC", leave=False)
except Exception:
_tq = None
_tq = _tqdm(total=_total_ticks, desc="Remesh DC", leave=False)
def tick():
_pbar.update(1)
if _tq is not None:
_tq.update(1)
_tq.update(1)
# Step 1: sparse narrow-band voxel grid (coarse-to-fine)
voxel_coords, _band_tree = _build_narrow_band_voxels(

View File

@ -366,8 +366,8 @@ def lscm_chart(
x_free = solve_least_squares(A_free, b)
if not np.all(np.isfinite(x_free)):
fallback_to_ortho = True
except Exception:
fallback_to_ortho = True
except (sp.linalg.MatrixRankWarning, RuntimeError):
fallback_to_ortho = True # singular / under-constrained system
if fallback_to_ortho:
if pin_positions is not None and pin_positions.shape == (Vc, 2):

View File

@ -74,31 +74,22 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non
)
packed_tangents[i, :tn.shape[0]] = tn
out = Types.MESH(packed_vertices, packed_faces,
uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
metallic_roughness=metallic_roughness,
vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit,
normals=packed_normals)
if packed_tangents is not None:
out.tangents = packed_tangents
if normal_map is not None:
out.normal_map = normal_map
if occlusion_in_mr:
out.occlusion_in_mr = True
if material is not None:
out.material = material
if emissive is not None:
out.emissive = emissive
return out
return Types.MESH(packed_vertices, packed_faces,
uvs=packed_uvs, vertex_colors=packed_colors, texture=texture,
metallic_roughness=metallic_roughness,
vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit,
normals=packed_normals, tangents=packed_tangents,
normal_map=normal_map, occlusion_in_mr=occlusion_in_mr,
material=material, emissive=emissive)
def get_mesh_batch_item(mesh, index):
# Returns (vertices, faces, colors, uvs) for batch index, slicing to real lengths
# if the mesh carries per-item counts (variable-size batch).
v_colors = getattr(mesh, "vertex_colors", None)
v_uvs = getattr(mesh, "uvs", None)
v_normals = getattr(mesh, "normals", None)
if getattr(mesh, "vertex_counts", None) is not None:
v_colors = mesh.vertex_colors
v_uvs = mesh.uvs
v_normals = mesh.normals
if mesh.vertex_counts is not None:
vertex_count = int(mesh.vertex_counts[index].item())
face_count = int(mesh.face_counts[index].item())
vertices = mesh.vertices[index, :vertex_count]
@ -482,6 +473,7 @@ def save_glb(vertices, faces, filepath=None, metadata=None,
if (texture_png_bytes is not None and has_uv) or "COLOR_0" in primitive_attributes:
pbr["baseColorFactor"] = [1.0, 1.0, 1.0, 1.0]
pbr["roughnessFactor"] = 1.0
if mr_png_bytes is not None and has_uv:
mr_texture_index = add_image_texture(mr_byte_offset, len(mr_buffer))
@ -599,7 +591,7 @@ def mesh_item_to_glb_bytes(mesh, index, metadata=None):
assert a.ndim == 3 and a.shape[-1] == 3, f"{attr} must be (B, H, W, 3), got {tuple(t.shape)}"
return Image.fromarray(a, mode="RGB")
tangents_b = getattr(mesh, "tangents", None)
tangents_b = mesh.tangents
tangents_i = tangents_b[index, :vertices_i.shape[0]] if tangents_b is not None else None
return save_glb(
vertices_i, faces_i, None, metadata,
@ -607,12 +599,12 @@ def mesh_item_to_glb_bytes(mesh, index, metadata=None):
vertex_colors=v_colors,
texture_image=_img("texture"),
metallic_roughness_image=_img("metallic_roughness"),
unlit=getattr(mesh, "unlit", False),
unlit=mesh.unlit,
normals=normals_i,
normal_map_image=_img("normal_map"),
tangents=tangents_i,
occlusion_in_mr=getattr(mesh, "occlusion_in_mr", False),
material=getattr(mesh, "material", None),
occlusion_in_mr=mesh.occlusion_in_mr,
material=mesh.material,
emissive_image=_img("emissive"),
)
@ -810,7 +802,7 @@ class RotateMesh(IO.ComfyNode):
else:
out.vertices = rotate(mesh.vertices)
# Normals are directions; rotate them too (R is orthogonal) so they stay valid.
nrm = getattr(mesh, "normals", None)
nrm = mesh.normals
if nrm is not None:
out.normals = [rotate(n) for n in nrm] if isinstance(nrm, list) else rotate(nrm)
return IO.NodeOutput(out)
@ -860,7 +852,7 @@ class MeshSmoothNormals(IO.ComfyNode):
return IO.NodeOutput(out)
# Crease split changes per-item vertex counts -> rebuild as a variable-size batch.
tangents_b = getattr(mesh, "tangents", None)
tangents_b = mesh.tangents
v_list, f_list, n_list = [], [], []
c_list = [] if mesh.vertex_colors is not None else None
u_list = [] if mesh.uvs is not None else None
@ -889,11 +881,11 @@ class MeshSmoothNormals(IO.ComfyNode):
return IO.NodeOutput(mesh)
out = pack_variable_mesh_batch(
v_list, f_list, colors=c_list, uvs=u_list,
texture=mesh.texture, unlit=getattr(mesh, "unlit", False),
normals=n_list, metallic_roughness=getattr(mesh, "metallic_roughness", None),
tangents=t_list, normal_map=getattr(mesh, "normal_map", None),
occlusion_in_mr=getattr(mesh, "occlusion_in_mr", False),
material=getattr(mesh, "material", None), emissive=getattr(mesh, "emissive", None))
texture=mesh.texture, unlit=mesh.unlit,
normals=n_list, metallic_roughness=mesh.metallic_roughness,
tangents=t_list, normal_map=mesh.normal_map,
occlusion_in_mr=mesh.occlusion_in_mr,
material=mesh.material, emissive=mesh.emissive)
return IO.NodeOutput(out)

View File

@ -8,9 +8,7 @@ from server import PromptServer
import comfy.latent_formats
import comfy.model_management
import comfy.utils
from PIL import Image
import logging
import numpy as np
import math
import torch
@ -425,7 +423,6 @@ def _dinov3_encode(model, image_bchw, image_size, want_patches=False):
mean = torch.tensor(model.image_mean or [0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
std = torch.tensor(model.image_std or [0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
img_t = (img_t - mean) / std
model_internal.image_size = image_size
tokens = model_internal(img_t, skip_norm_elementwise=True)[0]
if not want_patches:
return tokens
@ -434,20 +431,6 @@ def _dinov3_encode(model, image_bchw, image_size, want_patches=False):
return {"tokens": tokens[:, :1 + n_reg], "patches_2d": _dinov3_patches_to_2d(tokens, image_size)}
def run_conditioning(model, cropped_pil_img):
device = comfy.model_management.intermediate_device()
img_np = np.array(cropped_pil_img).astype(np.float32) / 255.0
image_bchw = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).contiguous()
cond_512 = _dinov3_encode(model, image_bchw, 512)
cond_1024 = _dinov3_encode(model, image_bchw, 1024)
return {
"cond_512": cond_512.to(device),
"neg_cond": torch.zeros_like(cond_512).to(device),
"cond_1024": cond_1024.to(device),
}
class Trellis2Conditioning(IO.ComfyNode):
@classmethod
def define_schema(cls):
@ -467,124 +450,10 @@ class Trellis2Conditioning(IO.ComfyNode):
@classmethod
def execute(cls, clip_vision_model, image, mask) -> IO.NodeOutput:
# Normalize to batched form so per-image conditioning loop below is uniform.
if image.ndim == 3:
image = image.unsqueeze(0)
elif image.ndim == 4:
if image.shape[1] in [1, 3, 4] and image.shape[-1] not in [1, 3, 4]:
image = image.permute(0, 2, 3, 1)
# normalize mask to standard [B, H, W] (handling 2D, 3D, and 4D variants)
if mask.ndim == 4:
if mask.shape[1] == 1:
mask = mask.squeeze(1)
elif mask.shape[-1] == 1:
mask = mask.squeeze(-1)
else:
mask = mask[:, :, :, 0] # take first channel as fallback
if mask.ndim == 3:
if mask.shape[-1] == 1:
mask = mask.squeeze(-1).unsqueeze(0)
elif mask.ndim == 2:
mask = mask.unsqueeze(0)
batch_size = image.shape[0]
if mask.shape[0] == 1 and batch_size > 1:
mask = mask.expand(batch_size, -1, -1)
elif mask.shape[0] != batch_size:
raise ValueError(f"Trellis2Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}")
cond_512_list = []
cond_1024_list = []
for b in range(batch_size):
item_image = image[b]
item_mask = mask[b] if mask.size(0) > 1 else mask[0]
img_np = (item_image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
mask_np = (item_mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
# Ensure img_np is either 2D (grayscale) or 3D (RGB/RGBA)
if img_np.ndim == 3 and img_np.shape[-1] == 1:
img_np = img_np.squeeze(-1)
mask_np = mask_np.squeeze()
# detect inverted mask
border_pixels = np.concatenate([
mask_np[0, :], mask_np[-1, :], mask_np[:, 0], mask_np[:, -1]
])
if np.mean(border_pixels) > 127:
mask_np = 255 - mask_np
mask_np[mask_np < 35] = 0
border_shave = 4
mask_np[:border_shave, :] = 0
mask_np[-border_shave:, :] = 0
mask_np[:, :border_shave] = 0
mask_np[:, -border_shave:] = 0
pil_img = Image.fromarray(img_np)
pil_mask = Image.fromarray(mask_np)
max_size = max(pil_img.size)
scale = min(1.0, 1024 / max_size)
if scale < 1.0:
new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
rgba_np[:, :, :3] = np.array(pil_img.convert("RGB"))
rgba_np[:, :, 3] = np.array(pil_mask)
alpha = rgba_np[:, :, 3]
bbox_coords = np.argwhere(alpha > 0.8 * 255)
if len(bbox_coords) > 0:
y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
size = max(y_max - y_min, x_max - x_min)
crop_x1 = int(center_x - size // 2)
crop_y1 = int(center_y - size // 2)
crop_x2 = int(center_x + size // 2)
crop_y2 = int(center_y + size // 2)
rgba_pil = Image.fromarray(rgba_np)
cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
else:
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
cropped_np = rgba_np.astype(np.float32) / 255.0
bg_rgb = np.array([0.0, 0.0, 0.0], dtype=np.float32)
fg = cropped_np[:, :, :3]
alpha_float = cropped_np[:, :, 3:4]
composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
# Keep the image as 4-channel RGBA to force TRELLIS to bypass its internal background remover
rgb_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
alpha_uint8 = (alpha_float.squeeze(-1) * 255.0).round().clip(0, 255).astype(np.uint8)
rgba_composite = np.zeros((cropped_np.shape[0], cropped_np.shape[1], 4), dtype=np.uint8)
rgba_composite[:, :, :3] = rgb_uint8
rgba_composite[:, :, 3] = alpha_uint8
cropped_pil = Image.fromarray(rgba_composite, mode="RGBA")
# Convert to RGB to ensure the CLIP/DINO model receives a 3-channel image
item_conditioning = run_conditioning(clip_vision_model, cropped_pil.convert("RGB"))
cond_512_list.append(item_conditioning["cond_512"])
cond_1024_list.append(item_conditioning["cond_1024"])
cond_512_batched = torch.cat(cond_512_list, dim=0)
cond_1024_batched = torch.cat(cond_1024_list, dim=0)
out_device = comfy.model_management.intermediate_device()
cond = _dino_condition_batch(clip_vision_model, image, mask, out_device,
pad_factor=1.0, mask_threshold=35.0 / 255.0, border_shave=4)
cond_512_batched, cond_1024_batched = cond["global_512"], cond["global_1024"]
neg_cond_batched = torch.zeros_like(cond_512_batched)
neg_embeds_batched = torch.zeros_like(cond_1024_batched)
@ -781,12 +650,11 @@ def _dinov3_patches_to_2d(tokens, image_size, patch_size=16):
return patches.transpose(1, 2).reshape(tokens.shape[0], -1, h_p, w_p).contiguous()
def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
img = item_image.permute(2, 0, 1).unsqueeze(0).cpu().float()
mask = item_mask.unsqueeze(0).unsqueeze(0).cpu().float()
# Upstream went float→PIL uint8 implicitly; match that to keep composite bit-exact.
img = (img.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
mask = (mask.clamp(0, 1) * 255.0).to(torch.uint8).float() / 255.0
def _crop_image_with_mask(item_image, item_mask, max_image_size=1024, pad_factor=1.1,
mask_threshold=0.0, border_shave=0):
img = item_image[..., :3] if item_image.shape[-1] >= 3 else item_image[..., :1].repeat(1, 1, 3)
img = img.permute(2, 0, 1).unsqueeze(0).cpu().float().clamp(0, 1)
mask = item_mask.unsqueeze(0).unsqueeze(0).cpu().float().clamp(0, 1)
# Detect & correct an inverted mask
m2d = mask[0, 0]
@ -794,6 +662,15 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
if float(border.mean()) > 0.5:
mask = 1.0 - mask
if mask_threshold > 0.0:
mask = torch.where(mask < mask_threshold, torch.zeros_like(mask), mask)
if border_shave > 0:
bs = border_shave
mask[..., :bs, :] = 0
mask[..., -bs:, :] = 0
mask[..., :, :bs] = 0
mask[..., :, -bs:] = 0
H, W = img.shape[-2:]
if max(H, W) > max_image_size:
scale = max_image_size / max(H, W)
@ -809,14 +686,14 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
y_min, x_min = fg_pixels.min(dim=0).values.tolist()
y_max, x_max = fg_pixels.max(dim=0).values.tolist()
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
size = int(max(y_max - y_min, x_max - x_min) * 1.1)
size = int(max(y_max - y_min, x_max - x_min) * pad_factor)
half = size // 2
crop_x1 = int(center_x - half)
crop_y1 = int(center_y - half)
crop_x2 = crop_x1 + 2 * half
crop_y2 = crop_y1 + 2 * half
else:
logging.warning("Mask for the image is empty. Pixal3D requires a clean foreground mask.")
logging.warning("Mask for the image is empty; a clean foreground mask is required for best quality.")
crop_x1, crop_y1, crop_x2, crop_y2 = 0, 0, W, H
crop_bbox = (crop_x1, crop_y1, crop_x2, crop_y2)
@ -836,9 +713,77 @@ def _crop_image_with_mask(item_image, item_mask, max_image_size=1024):
cropped_mask = mask[..., crop_y1:crop_y2, crop_x1:crop_x2]
composite = (cropped_img * cropped_mask).clamp(0, 1)
composite = (composite * 255.0).round().clamp(0, 255).to(torch.uint8).float() / 255.0
return composite, crop_bbox, scene_size
def _dino_condition_batch(clip_vision_model, image, mask, out_device, *,
                          pad_factor, mask_threshold=0.0, border_shave=0, want_patches=False):
    """Batch the inputs, then crop and encode every item at 512 and 1024.

    ``image`` gets a leading batch axis when it is 3-D; a 4-D image whose second axis
    looks like a channel count (1, 3 or 4) while its last axis does not is permuted to
    channel-last. ``mask`` is reduced to three dimensions. A single mask is expanded
    across the image batch; any other batch-size mismatch raises ``ValueError``.

    Each item goes through ``_crop_image_with_mask`` (with ``max_image_size=1024`` and
    the given ``pad_factor`` / ``mask_threshold`` / ``border_shave``) and the resulting
    composite is passed to ``_dinov3_encode`` twice, at 512 and then at 1024.

    Returns a dict with ``batch_size``, ``global_512`` and ``global_1024`` (per-item
    results moved to ``out_device`` and concatenated on dim 0). With ``want_patches``
    it also holds ``patches_512`` / ``patches_1024`` (concatenated the same way) and
    the per-item lists ``composites``, ``crop_bboxes`` and ``scene_sizes``.
    """
    channel_counts = [1, 3, 4]

    # Image: add a batch axis, or move a channel-first batch to channel-last.
    if image.ndim == 3:
        image = image.unsqueeze(0)
    elif image.ndim == 4 and image.shape[1] in channel_counts and image.shape[-1] not in channel_counts:
        image = image.permute(0, 2, 3, 1)

    # Mask: drop the singleton channel axis of a 4-D mask (or keep its first channel) ...
    if mask.ndim == 4:
        if mask.shape[1] == 1:
            mask = mask.squeeze(1)
        else:
            mask = mask.squeeze(-1) if mask.shape[-1] == 1 else mask[:, :, :, 0]  # take first channel as fallback
    # ... then make sure a batch axis is present. This check also runs on the result
    # of the 4-D reduction above, exactly as a separate top-level step.
    if mask.ndim == 2:
        mask = mask.unsqueeze(0)
    elif mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask.squeeze(-1).unsqueeze(0)

    batch_size = image.shape[0]
    if mask.shape[0] == 1 and batch_size > 1:
        mask = mask.expand(batch_size, -1, -1)
    elif mask.shape[0] != batch_size:
        raise ValueError(f"Conditioning mask batch {mask.shape[0]} does not match image batch {batch_size}")

    tokens_512, tokens_1024 = [], []
    grids_512, grids_1024 = [], []
    composites, crop_bboxes, scene_sizes = [], [], []

    for idx in range(batch_size):
        item_image = image[idx]
        item_mask = mask[idx] if mask.size(0) > 1 else mask[0]
        composite, crop_bbox, scene_size = _crop_image_with_mask(
            item_image, item_mask, max_image_size=1024, pad_factor=pad_factor,
            mask_threshold=mask_threshold, border_shave=border_shave)
        enc_512 = _dinov3_encode(clip_vision_model, composite, 512, want_patches=want_patches)
        enc_1024 = _dinov3_encode(clip_vision_model, composite, 1024, want_patches=want_patches)

        if not want_patches:
            # Without patches the encoder result is used directly as the token tensor.
            tokens_512.append(enc_512.to(out_device))
            tokens_1024.append(enc_1024.to(out_device))
            continue

        tokens_512.append(enc_512["tokens"].to(out_device))
        tokens_1024.append(enc_1024["tokens"].to(out_device))
        grids_512.append(enc_512["patches_2d"].to(out_device))
        grids_1024.append(enc_1024["patches_2d"].to(out_device))
        composites.append(composite)
        crop_bboxes.append(crop_bbox)
        scene_sizes.append(scene_size)

    result = {
        "batch_size": batch_size,
        "global_512": torch.cat(tokens_512, dim=0),
        "global_1024": torch.cat(tokens_1024, dim=0),
    }
    if want_patches:
        result["patches_512"] = torch.cat(grids_512, dim=0)
        result["patches_1024"] = torch.cat(grids_1024, dim=0)
        result["composites"] = composites
        result["crop_bboxes"] = crop_bboxes
        result["scene_sizes"] = scene_sizes
    return result
class Pixal3DConditioning(IO.ComfyNode):
@classmethod
@ -867,45 +812,15 @@ class Pixal3DConditioning(IO.ComfyNode):
@classmethod
def execute(cls, clip_vision_model, image, mask, camera_angle_x) -> IO.NodeOutput:
naf_model = clip_vision_model.naf
if image.ndim == 3:
image = image.unsqueeze(0)
if mask.ndim == 2:
mask = mask.unsqueeze(0)
batch_size = image.shape[0]
if mask.shape[0] == 1 and batch_size > 1:
mask = mask.expand(batch_size, -1, -1)
elif mask.shape[0] != batch_size:
raise ValueError(f"Pixal3DConditioning mask batch {mask.shape[0]} != image batch {batch_size}")
out_device = comfy.model_management.intermediate_device()
compute_device = comfy.model_management.get_torch_device()
device = comfy.model_management.intermediate_device()
cond_512_list, cond_1024_list = [], []
patches_512_list, patches_1024_list = [], []
composite_list = []
crop_bbox_list, scene_size_list = [], []
torch_device = comfy.model_management.get_torch_device()
for b in range(batch_size):
item_image = image[b]
item_mask = mask[b] if mask.size(0) > 1 else mask[0]
composite, crop_bbox, scene_size = _crop_image_with_mask(
item_image, item_mask, max_image_size=1024)
crop_bbox_list.append(crop_bbox)
scene_size_list.append(scene_size)
composite_list.append(composite)
cond_512 = _dinov3_encode(clip_vision_model, composite, 512, want_patches=True)
cond_1024 = _dinov3_encode(clip_vision_model, composite, 1024, want_patches=True)
cond_512_list.append(cond_512["tokens"].to(device))
cond_1024_list.append(cond_1024["tokens"].to(device))
patches_512_list.append(cond_512["patches_2d"].to(device))
patches_1024_list.append(cond_1024["patches_2d"].to(device))
global_512 = torch.cat(cond_512_list, dim=0)
global_1024 = torch.cat(cond_1024_list, dim=0)
fm_512_dino = torch.cat(patches_512_list, dim=0)
fm_1024_dino = torch.cat(patches_1024_list, dim=0)
cond = _dino_condition_batch(clip_vision_model, image, mask, out_device, pad_factor=1.1, want_patches=True)
batch_size = cond["batch_size"]
global_512, global_1024 = cond["global_512"], cond["global_1024"]
fm_512_dino, fm_1024_dino = cond["patches_512"], cond["patches_1024"]
composite_list = cond["composites"]
crop_bbox_list, scene_size_list = cond["crop_bboxes"], cond["scene_sizes"]
# The LR DINO grid AND the NAF HR grid are sampled separately
# NAF targets per stage: shape_512=512, shape_1024=512, tex_1024=1024.
@ -914,15 +829,13 @@ class Pixal3DConditioning(IO.ComfyNode):
return None
comfy.model_management.load_model_gpu(naf_model)
inner = naf_model.model
target_dtype = comfy.model_management.text_encoder_dtype(torch_device)
if next(inner.parameters()).dtype != target_dtype:
inner.to(dtype=target_dtype)
model_dtype = next(inner.parameters()).dtype # set at load time (see clip_vision NAF)
hrs = []
for i, c in enumerate(composites):
img_i = comfy.utils.common_upscale(c, image_size, image_size, "lanczos", "disabled")\
.to(torch_device).to(target_dtype)
lr_i = lr_feat[i:i + 1].to(torch_device).to(target_dtype)
hr_i = inner(img_i, lr_i, naf_target, output_device=device)
.to(compute_device).to(model_dtype)
lr_i = lr_feat[i:i + 1].to(compute_device).to(model_dtype)
hr_i = inner(img_i, lr_i, naf_target, output_device=out_device)
hrs.append(hr_i)
return torch.cat(hrs, dim=0)
@ -934,10 +847,10 @@ class Pixal3DConditioning(IO.ComfyNode):
# FOV widget is in degrees for UX; trig + downstream projection expect radians.
camera_angle_x = math.radians(float(camera_angle_x))
distance = 0.5 / math.tan(camera_angle_x / 2.0)
cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=device, dtype=torch.float32)
dist_t = torch.tensor([distance] * batch_size, device=device, dtype=torch.float32)
scale_t = torch.ones(batch_size, device=device, dtype=torch.float32)
T = build_proj_transform_matrix(dist_t, batch_size, device=device, dtype=torch.float32)
cam_angle_t = torch.tensor([camera_angle_x] * batch_size, device=out_device, dtype=torch.float32)
dist_t = torch.tensor([distance] * batch_size, device=out_device, dtype=torch.float32)
scale_t = torch.ones(batch_size, device=out_device, dtype=torch.float32)
T = build_proj_transform_matrix(dist_t, batch_size, device=out_device, dtype=torch.float32)
proj_pack = {
"stages": {
@ -958,7 +871,7 @@ class Pixal3DConditioning(IO.ComfyNode):
# global_512 → SS/shape_512 cross-attn; global_1024 → shape_1024/tex_1024.
ss_proj_feats = compute_stage_proj_feats(
proj_pack, "ss", dense_grid_resolution=16, batch_size=batch_size,
device=torch_device,
device=compute_device,
)
neg_global = torch.zeros_like(global_512)
neg_embeds = torch.zeros_like(global_1024)