From 2d514f5f0c1578e7b4acf4793070450ba2d90647 Mon Sep 17 00:00:00 2001 From: kijai <40791699+kijai@users.noreply.github.com> Date: Sat, 6 Jun 2026 15:59:30 +0300 Subject: [PATCH] Use common dinov3, cleanup --- comfy/image_encoders/dino3.py | 15 +- comfy/ldm/sam3d_body/model/dinov3.py | 250 ------------------- comfy/ldm/sam3d_body/model/model.py | 6 +- comfy_extras/nodes_sam3d_body.py | 7 +- comfy_extras/sam3d_body/export/glb_shared.py | 62 ++--- comfy_extras/sam3d_body/rasterizer.py | 2 +- nodes.py | 2 +- 7 files changed, 47 insertions(+), 297 deletions(-) delete mode 100644 comfy/ldm/sam3d_body/model/dinov3.py diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index ad29b06f8..09eb9beab 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -159,7 +159,7 @@ class DINOv3ViTEmbeddings(nn.Module): def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations): super().__init__() self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype)) - self.mask_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype)) self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype)) self.patch_embeddings = operations.Conv2d( num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype @@ -240,6 +240,10 @@ class DINOv3ViTModel(nn.Module): for _ in range(num_hidden_layers)]) self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device) + self.patch_size = patch_size + self.embed_dim = self.embed_dims = hidden_size + self.num_prefix_tokens = 1 + num_register_tokens # cls + register + def get_input_embeddings(self): return self.embeddings.patch_embeddings @@ -257,3 +261,12 @@ class DINOv3ViTModel(nn.Module): sequence_output = norm(hidden_states) pooled_output = sequence_output[:, 0, :] return sequence_output, None, pooled_output, None + + def forward_features(self, pixel_values, **kwargs): + """Dense (B, C, H, W) patch-feature grid, CLS + register tokens dropped.""" + sequence_output = self.forward(pixel_values, **kwargs)[0] + b = pixel_values.shape[0] + h = pixel_values.shape[-2] // self.patch_size + w = pixel_values.shape[-1] // self.patch_size + patches = sequence_output[:, self.num_prefix_tokens:, :] + return patches.reshape(b, h, w, self.embed_dim).permute(0, 3, 1, 2).contiguous() diff --git a/comfy/ldm/sam3d_body/model/dinov3.py b/comfy/ldm/sam3d_body/model/dinov3.py deleted file mode 100644 index 97637b2f9..000000000 --- a/comfy/ldm/sam3d_body/model/dinov3.py +++ /dev/null @@ -1,250 +0,0 @@ -# DINOv3 ViT-H+ backbone for SAM 3D Body. -# -# Single-file consolidation of the inference path. SAM 3D Body only ships a -# `dinov3_vith16plus` checkpoint, so the architecture is hardcoded rather -# than reconstructed from Hydra-flavoured configs. -# -# Adapted from facebookresearch/dinov3 (DINOv3 License Agreement). Trimmed -# to what's actually exercised at inference: no multi-crop training path, -# no DINOHead, no causal blocks, no rmsnorm/Mlp variants, no rope shift / -# jitter / rescale (training-time augmentations). - -#TODO: Unify with TRELLIS2 - -import math -from typing import Optional, Tuple - -import torch -import torch.nn.functional as F -from comfy.ldm.modules.attention import optimized_attention -from torch import Tensor, nn - -# DINOv3 ViT-H+ architecture constants. -EMBED_DIM = 1280 -DEPTH = 32 -NUM_HEADS = 20 -FFN_RATIO = 6.0 -PATCH_SIZE = 16 -LAYERSCALE_INIT = 1.0e-5 -N_STORAGE_TOKENS = 4 -LAYERNORM_EPS = 1e-5 # "layernormbf16" preset uses 1e-5 -ROPE_BASE = 100.0 - -# RoPE (axial sin/cos, no learnable weights) - -def _rotate_half(x: Tensor) -> Tensor: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat([-x2, x1], dim=-1) - - -def _apply_rope(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor: - return x * cos + _rotate_half(x) * sin - - -class RopePositionEmbedding(nn.Module): - """Axial RoPE for 2D patch grids; periods buffer is deterministic.""" - - def __init__(self, embed_dim: int, num_heads: int, dtype=torch.float32, device=None): - super().__init__() - assert embed_dim % (4 * num_heads) == 0 - D_head = embed_dim // num_heads - # Periods are persistent so they round-trip through state_dict, but the - # values are deterministic from D_head/base; load_state_dict will - # overwrite this with the saved buffer either way. - periods = ROPE_BASE ** ( - 2 * torch.arange(D_head // 4, dtype=dtype, device=device) / (D_head // 2) - ) - self.register_buffer("periods", periods, persistent=True) - self._dtype = dtype - - def forward(self, H: int, W: int) -> Tuple[Tensor, Tensor]: - device, dtype = self.periods.device, self._dtype - coords_h = torch.arange(0.5, H, device=device, dtype=dtype) / H - coords_w = torch.arange(0.5, W, device=device, dtype=dtype) / W - coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) - coords = 2.0 * coords.flatten(0, 1) - 1.0 # [HW, 2] in [-1, +1] - angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :] - angles = angles.flatten(1, 2).tile(2) # [HW, D_head] - return torch.sin(angles), torch.cos(angles) - - -def _apply_rope_to_qk(q: Tensor, k: Tensor, rope: Tuple[Tensor, Tensor]): - """Apply RoPE only to the patch-token slice (skip CLS + storage tokens).""" - sin, cos = rope - rope_dtype = sin.dtype - q_dtype, k_dtype = q.dtype, k.dtype - q = q.to(rope_dtype) - k = k.to(rope_dtype) - prefix = q.shape[-2] - sin.shape[-2] - q_pre, q_rope = q[..., :prefix, :], q[..., prefix:, :] - k_pre, k_rope = k[..., :prefix, :], k[..., prefix:, :] - q = torch.cat([q_pre, _apply_rope(q_rope, sin, cos)], dim=-2) - k = torch.cat([k_pre, _apply_rope(k_rope, sin, cos)], dim=-2) - return q.to(q_dtype), k.to(k_dtype) - -# Layers - -class LayerScale(nn.Module): - def __init__(self, dim: int, init_values: float, device=None, dtype=None): - super().__init__() - self.gamma = nn.Parameter( - torch.full((dim,), init_values, device=device, dtype=dtype) - ) - - def forward(self, x: Tensor) -> Tensor: - return x * self.gamma - - -class SwiGLUFFN(nn.Module): - """w3(silu(w1(x)) * w2(x)).""" - - def __init__(self, in_features: int, hidden_features: int, align_to: int = 8, - device=None, dtype=None, operations=None): - super().__init__() - ops = operations if operations is not None else nn - d = int(hidden_features * 2 / 3) - h = d + (-d % align_to) - self.w1 = ops.Linear(in_features, h, bias=True, device=device, dtype=dtype) - self.w2 = ops.Linear(in_features, h, bias=True, device=device, dtype=dtype) - self.w3 = ops.Linear(h, in_features, bias=True, device=device, dtype=dtype) - - def forward(self, x: Tensor) -> Tensor: - return self.w3(F.silu(self.w1(x)) * self.w2(x)) - - -class SelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int, device=None, dtype=None, operations=None): - super().__init__() - ops = operations if operations is not None else nn - self.num_heads = num_heads - # DINOv3's `mask_k_bias` zeroes the K third of qkv.bias. The mask is - # deterministic from out_features, so the loader applies it in-place - # once after load_state_dict (see `apply_dinov3_qkv_bias_mask`) and the - # forward stays a plain F.linear. - self.qkv = ops.Linear(dim, dim * 3, bias=True, device=device, dtype=dtype) - self.proj = ops.Linear(dim, dim, bias=True, device=device, dtype=dtype) - - def forward(self, x: Tensor, rope: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor: - B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) - q, k, v = qkv.unbind(2) - q, k, v = (t.transpose(1, 2) for t in (q, k, v)) - if rope is not None: - q, k = _apply_rope_to_qk(q, k, rope) - # low_precision_attention=False forces attention_sage (when enabled - # globally in comfy) to fall back to pytorch SDPA. SAM 3D Body's - # regression heads (camera projection, MHR rig math) are sensitive - # to attention output precision; sage's int8/fp8 path drifts the - # keypoints and mesh visibly. - x = optimized_attention( - q, k, v, self.num_heads, skip_reshape=True, - low_precision_attention=False, - ) - return self.proj(x) - - -class Block(nn.Module): - def __init__(self, dim: int, num_heads: int, ffn_ratio: float, - device=None, dtype=None, operations=None): - super().__init__() - ops = operations if operations is not None else nn - self.norm1 = ops.LayerNorm(dim, eps=LAYERNORM_EPS, device=device, dtype=dtype) - self.attn = SelfAttention(dim, num_heads, device=device, dtype=dtype, operations=operations) - self.ls1 = LayerScale(dim, LAYERSCALE_INIT, device=device, dtype=dtype) - self.norm2 = ops.LayerNorm(dim, eps=LAYERNORM_EPS, device=device, dtype=dtype) - self.mlp = SwiGLUFFN(dim, int(dim * ffn_ratio), device=device, dtype=dtype, operations=operations) - self.ls2 = LayerScale(dim, LAYERSCALE_INIT, device=device, dtype=dtype) - - def forward(self, x: Tensor, rope=None) -> Tensor: - x = x + self.ls1(self.attn(self.norm1(x), rope=rope)) - x = x + self.ls2(self.mlp(self.norm2(x))) - return x - - -class PatchEmbed(nn.Module): - def __init__(self, in_chans=3, embed_dim=EMBED_DIM, patch_size=PATCH_SIZE, - device=None, dtype=None, operations=None): - super().__init__() - ops = operations if operations is not None else nn - self.proj = ops.Conv2d( - in_chans, embed_dim, - kernel_size=patch_size, stride=patch_size, - device=device, dtype=dtype, - ) - -# Encoder + wrapper - -class _DinoEncoder(nn.Module): - """Inner ViT module. Held under `Dinov3Backbone.encoder` so state_dict - keys (`backbone.encoder.*`) match the upstream layout.""" - - def __init__(self, device=None, dtype=None, operations=None): - super().__init__() - ops = operations if operations is not None else nn - self.patch_size = PATCH_SIZE - self.embed_dim = EMBED_DIM - - self.patch_embed = PatchEmbed( - embed_dim=EMBED_DIM, patch_size=PATCH_SIZE, - device=device, dtype=dtype, operations=operations, - ) - self.cls_token = nn.Parameter(torch.empty(1, 1, EMBED_DIM, device=device, dtype=dtype)) - self.storage_tokens = nn.Parameter( - torch.empty(1, N_STORAGE_TOKENS, EMBED_DIM, device=device, dtype=dtype) - ) - # The released config sets pos_embed_rope_dtype="fp32"; periods stays - # in fp32 regardless of the backbone weight dtype. - self.rope_embed = RopePositionEmbedding(EMBED_DIM, NUM_HEADS, dtype=torch.float32, device=device) - - self.blocks = nn.ModuleList([ - Block(EMBED_DIM, NUM_HEADS, FFN_RATIO, device=device, dtype=dtype, operations=operations) - for _ in range(DEPTH) - ]) - self.norm = ops.LayerNorm(EMBED_DIM, eps=LAYERNORM_EPS, device=device, dtype=dtype) - - def forward(self, x: Tensor) -> Tensor: - x = self.patch_embed.proj(x) # (B, embed_dim, H, W) - B, _, H, W = x.shape - x = x.flatten(2).transpose(1, 2) # (B, H*W, embed_dim) - - # Prepend CLS + storage tokens. - x = torch.cat([ - self.cls_token.expand(B, -1, -1), - self.storage_tokens.expand(B, -1, -1), - x, - ], dim=1) - - rope = self.rope_embed(H=H, W=W) - for blk in self.blocks: - x = blk(x, rope) - x = self.norm(x) - - # Drop CLS + storage tokens; reshape patch grid to (B, C, H, W). - x = x[:, 1 + N_STORAGE_TOKENS :] - return x.reshape(B, H, W, EMBED_DIM).permute(0, 3, 1, 2).contiguous() - - -class Dinov3Backbone(nn.Module): - """Public backbone interface used by SAM3DBody.""" - - def __init__(self, device=None, dtype=None, operations=None): - super().__init__() - self.encoder = _DinoEncoder(device=device, dtype=dtype, operations=operations) - self.patch_size = PATCH_SIZE - self.embed_dim = self.embed_dims = EMBED_DIM - - def forward(self, x: Tensor) -> Tensor: - return self.encoder(x) - - -def apply_dinov3_qkv_bias_mask(backbone: "Dinov3Backbone") -> None: - """Zero the K third of every block's qkv.bias in-place. - - Implements DINOv3's `mask_k_bias` once at load time so the per-block forward - stays a plain F.linear instead of cloning + slicing the bias every call. - """ - for blk in backbone.encoder.blocks: - qkv = blk.attn.qkv - if qkv.bias is not None: - o = qkv.out_features - qkv.bias.data[o // 3 : 2 * o // 3] = 0 diff --git a/comfy/ldm/sam3d_body/model/model.py b/comfy/ldm/sam3d_body/model/model.py index 7e2fbef4f..9d1a53ec8 100644 --- a/comfy/ldm/sam3d_body/model/model.py +++ b/comfy/ldm/sam3d_body/model/model.py @@ -7,7 +7,7 @@ import torch.nn.functional as F import comfy.model_management from comfy.ldm.sam3.sam import PositionEmbeddingRandom -from .dinov3 import Dinov3Backbone +from comfy.image_encoders.dino3 import DINOV3_VITH_CONFIG, DINOv3ViTModel from .prompt import PromptEncoder, PromptableDecoder from ..mhr.mhr_head import MHRHead from ..mhr.mhr_rig import MHRRig @@ -50,7 +50,7 @@ class SAM3DBody(nn.Module): self.image_size = IMAGE_SIZE - self.backbone = Dinov3Backbone(device=device, dtype=dtype, operations=operations) + self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=ops) embed_dims = self.backbone.embed_dims # MHR rig shared between body + hand pose heads via a non-registered @@ -612,7 +612,7 @@ class SAM3DBody(nn.Module): batch["ray_cond_hand"] = ray_cond[self.hand_batch_idx].clone() ray_cond = None - image_embeddings = self.backbone(x.type(self.backbone_dtype)) + image_embeddings = self.backbone.forward_features(x.type(self.backbone_dtype)) # bf16 mantissa too lossy for the heads — promote back. fp16 survives. if self.backbone_dtype != torch.float16: image_embeddings = image_embeddings.type(x.dtype) diff --git a/comfy_extras/nodes_sam3d_body.py b/comfy_extras/nodes_sam3d_body.py index c779a2f6e..fd966f571 100644 --- a/comfy_extras/nodes_sam3d_body.py +++ b/comfy_extras/nodes_sam3d_body.py @@ -15,7 +15,6 @@ from typing_extensions import override import folder_paths from comfy.ldm.sam3d_body.model.model import SAM3DBody -from comfy.ldm.sam3d_body.model.dinov3 import apply_dinov3_qkv_bias_mask from comfy_extras.sam3d_body.utils import ( apply_camera_override, cam_int_from_fov, @@ -79,8 +78,6 @@ class SAM3DBody_Loader(io.ComfyNode): model = SAM3DBody(dtype=torch_dtype, operations=operations) model.load_state_dict(sd, strict=False) - apply_dinov3_qkv_bias_mask(model.backbone) - model.eval() model.backbone_dtype = torch_dtype model._sam3d_image_size = model.image_size @@ -308,8 +305,8 @@ class SAM3DBody_FaceExpression(io.ComfyNode): @classmethod def execute(cls, mhr_pose_data, sam3d_body_model, image, - strength=1.0, mouth_strength=1.0, eye_strength=1.0, brow_strength=1.0, - input_threshold=0.15, blendshape_smooth_window=7) -> io.NodeOutput: + strength=1.0, mouth_strength=1.0, eye_strength=2.0, brow_strength=2.0, + input_threshold=0.02, blendshape_smooth_window=7) -> io.NodeOutput: comfy.model_management.load_model_gpu(sam3d_body_model) inner: SAM3DBody = sam3d_body_model.model diff --git a/comfy_extras/sam3d_body/export/glb_shared.py b/comfy_extras/sam3d_body/export/glb_shared.py index ba7268985..0935fe7f2 100644 --- a/comfy_extras/sam3d_body/export/glb_shared.py +++ b/comfy_extras/sam3d_body/export/glb_shared.py @@ -94,32 +94,36 @@ def _skel_state_compose_np(s1: np.ndarray, s2: np.ndarray) -> np.ndarray: return np.concatenate([t_res, q_res, s_res], axis=-1) +def _gaussian_smooth_time(arr: np.ndarray, window: int) -> np.ndarray: + """Edge-replicate Gaussian smoothing along axis 0 (time); sigma = window/4. + Endpoints replicate so they aren't pulled toward zero. Returns float64.""" + a = np.asarray(arr, dtype=np.float64) + n = a.shape[0] + half = window // 2 + sigma = max(0.5, window / 4.0) + x = np.arange(-half, half + 1, dtype=np.float64) + kernel = np.exp(-x * x / (2.0 * sigma * sigma)) + kernel = kernel / kernel.sum() + padded = np.concatenate([ + np.broadcast_to(a[:1], (half,) + a.shape[1:]), + a, + np.broadcast_to(a[-1:], (half,) + a.shape[1:]), + ], axis=0) + out = np.zeros_like(a) + for k, w in enumerate(kernel): + out += w * padded[k:k + n] + return out + + def gaussian_smooth_quats(q_seq: np.ndarray, window: int) -> np.ndarray: """Gaussian-smooth a (N, NJ, 4) quaternion sequence along time. Sign-aligns per joint first, convolves per-component, renormalizes. Suppresses multi- frame bone spikes at extreme poses without needing the upstream Smooth node.""" if window <= 1 or q_seq.shape[0] < 2: return q_seq - aligned = quat_sign_fix_per_joint(q_seq).astype(np.float64) - n = q_seq.shape[0] - half = window // 2 - sigma = max(0.5, window / 4.0) - x = np.arange(-half, half + 1, dtype=np.float64) - kernel = np.exp(-x * x / (2.0 * sigma * sigma)) - kernel = kernel / kernel.sum() - # Edge-replicate padding so endpoints don't get pulled toward zero. - pad = half - padded = np.concatenate([ - np.broadcast_to(aligned[:1], (pad,) + aligned.shape[1:]), - aligned, - np.broadcast_to(aligned[-1:], (pad,) + aligned.shape[1:]), - ], axis=0) - out = np.zeros_like(aligned) - for k, w in enumerate(kernel): - out += w * padded[k:k + n] + out = _gaussian_smooth_time(quat_sign_fix_per_joint(q_seq), window) norms = np.linalg.norm(out, axis=-1, keepdims=True) - out = out / np.maximum(norms, 1e-12) - return out.astype(np.float32) + return (out / np.maximum(norms, 1e-12)).astype(np.float32) def gaussian_smooth_positions(seq: np.ndarray, window: int) -> np.ndarray: @@ -128,22 +132,7 @@ def gaussian_smooth_positions(seq: np.ndarray, window: int) -> np.ndarray: derives sphere translations + limb TRS from them.""" if window <= 1 or seq.shape[0] < 2: return seq - s = np.asarray(seq, dtype=np.float64) - n = s.shape[0] - half = window // 2 - sigma = max(0.5, window / 4.0) - x = np.arange(-half, half + 1, dtype=np.float64) - kernel = np.exp(-x * x / (2.0 * sigma * sigma)) - kernel = kernel / kernel.sum() - padded = np.concatenate([ - np.broadcast_to(s[:1], (half,) + s.shape[1:]), - s, - np.broadcast_to(s[-1:], (half,) + s.shape[1:]), - ], axis=0) - out = np.zeros_like(s) - for k, wgt in enumerate(kernel): - out += wgt * padded[k:k + n] - return out.astype(np.float32) + return _gaussian_smooth_time(seq, window).astype(np.float32) def quat_sign_fix_per_joint(q_seq: np.ndarray) -> np.ndarray: @@ -261,7 +250,8 @@ class GLBWriter: self.accessors.append({ "bufferView": view_idx, "componentType": _FLOAT, "count": a.shape[0], "type": "VEC3", - "min": a.min(axis=0).tolist(), "max": a.max(axis=0).tolist(), + "min": a.min(axis=0).tolist() if a.shape[0] else [0.0, 0.0, 0.0], + "max": a.max(axis=0).tolist() if a.shape[0] else [0.0, 0.0, 0.0], }) return len(self.accessors) - 1 diff --git a/comfy_extras/sam3d_body/rasterizer.py b/comfy_extras/sam3d_body/rasterizer.py index 37d32f952..2b2b2259b 100644 --- a/comfy_extras/sam3d_body/rasterizer.py +++ b/comfy_extras/sam3d_body/rasterizer.py @@ -35,7 +35,7 @@ def rainbow_colors_from_canonical( Returns: (N_v, 3) float32 RGB in [0, 1]. """ - key = (id(positions), round(float(tilt_x_deg), 3), round(float(tilt_z_deg), 3)) + key = (hash(positions.tobytes()), round(float(tilt_x_deg), 3), round(float(tilt_z_deg), 3)) cached = _rainbow_cache.get(key) if cached is not None: return cached diff --git a/nodes.py b/nodes.py index 6da982c59..15e5d1b08 100644 --- a/nodes.py +++ b/nodes.py @@ -2457,7 +2457,7 @@ async def init_builtin_extra_nodes(): "nodes_moge.py", "nodes_mediapipe.py", "nodes_gaussian_splat.py", - "nodes_triposplat.py" + "nodes_triposplat.py", "nodes_sam3d_body.py", ]