This commit is contained in:
kijai 2026-07-02 09:51:27 +03:00
parent b6b12bd5fc
commit e5d6bbc2de
5 changed files with 32 additions and 26 deletions

View File

@ -156,10 +156,12 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
class DINOv3ViTEmbeddings(nn.Module):
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations, use_mask_token=True):
super().__init__()
self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype))
# mask_token is only used during pre-training; omit it when the checkpoint does not ship it so strict loading stays clean.
self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype)) if use_mask_token else None
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
self.patch_embeddings = operations.Conv2d(
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
@ -212,7 +214,7 @@ class DINOv3ViTLayer(nn.Module):
class DINOv3ViTModel(nn.Module):
def __init__(self, config, dtype, device, operations):
def __init__(self, config, dtype, device, operations, use_mask_token=True):
super().__init__()
num_hidden_layers = config["num_hidden_layers"]
hidden_size = config["hidden_size"]
@ -228,7 +230,7 @@ class DINOv3ViTModel(nn.Module):
self.embeddings = DINOv3ViTEmbeddings(
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size,
dtype=dtype, device=device, operations=operations
dtype=dtype, device=device, operations=operations, use_mask_token=use_mask_token
)
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device

View File

@ -51,17 +51,17 @@ class MHRHead(nn.Module):
self.joint_rotation = _p(127, 3, 3)
self.scale_mean = _p(68)
self.scale_comps = _p(28, 68)
self.faces = _p(36874, 3, dtype=torch.int64)
self.register_buffer("faces", torch.empty(36874, 3, dtype=torch.int64))
self.hand_pose_mean = _p(54)
self.hand_pose_comps = nn.Parameter(torch.eye(54), requires_grad=False)
self.hand_joint_idxs_left = _p(27, dtype=torch.int64)
self.hand_joint_idxs_right = _p(27, dtype=torch.int64)
self.register_buffer("hand_joint_idxs_left", torch.empty(27, dtype=torch.int64))
self.register_buffer("hand_joint_idxs_right", torch.empty(27, dtype=torch.int64))
self.keypoint_mapping = _p(308, 18439 + 127)
# Some special buffers for the hand-version
self.right_wrist_coords = _p(3)
self.root_coords = _p(3)
self.local_to_world_wrist = _p(3, 3)
self.nonhand_param_idxs = _p(145, dtype=torch.int64)
self.register_buffer("nonhand_param_idxs", torch.empty(145, dtype=torch.int64))
# Hand-painted per-vertex face region RGB (rainbow_face_semantic shader).
self.register_buffer("face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32))

View File

@ -44,8 +44,10 @@ class SAM3DBody(nn.Module):
self.register_buffer("image_std", torch.tensor(IMAGE_STD).view(-1, 1, 1), False)
self.image_size = IMAGE_SIZE
self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=operations)
# Populated by the loader once weights are in place.
self.canonical_colors = None
self.hand_vert_mask = None
self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=operations, use_mask_token=False)
embed_dims = self.backbone.embed_dims
# MHR rig shared between body + hand pose heads via a non-registered
@ -61,8 +63,8 @@ class SAM3DBody(nn.Module):
device=device, dtype=dtype, operations=operations,
)
self.head_pose = MHRHead(**head_kwargs)
self.head_pose.hand_pose_comps_ori = nn.Parameter(
self.head_pose.hand_pose_comps.clone(), requires_grad=False
self.head_pose.register_buffer(
"hand_pose_comps_ori", self.head_pose.hand_pose_comps.clone(), persistent=False
)
self.head_pose.hand_pose_comps.data = (
torch.eye(54).to(self.head_pose.hand_pose_comps.data).float()
@ -70,8 +72,8 @@ class SAM3DBody(nn.Module):
self.init_pose = operations.Embedding(1, self.head_pose.npose, device=device, dtype=dtype)
self.head_pose_hand = MHRHead(enable_hand_model=True, **head_kwargs)
self.head_pose_hand.hand_pose_comps_ori = nn.Parameter(
self.head_pose_hand.hand_pose_comps.clone(), requires_grad=False
self.head_pose_hand.register_buffer(
"hand_pose_comps_ori", self.head_pose_hand.hand_pose_comps.clone(), persistent=False
)
self.head_pose_hand.hand_pose_comps.data = (
torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float()

View File

@ -78,14 +78,16 @@ class SAM3DBody_Loader(io.ComfyNode):
operations = comfy.ops.pick_operations(torch_dtype, manual_cast_dtype, load_device=load_device, disable_fast_fp8=True)
model = SAM3DBody(dtype=torch_dtype, operations=operations)
model.load_state_dict(sd, strict=False)
sd.pop("hand_cls_embed.weight", None)
sd.pop("hand_cls_embed.bias", None)
missing, unexpected = model.load_state_dict(sd, strict=False)
missing = set(missing) - {"head_pose_hand.face_region_rgb"}
if missing or unexpected:
raise RuntimeError(f"SAM3D-Body checkpoint key mismatch: missing={sorted(missing)}, unexpected={sorted(unexpected)}")
model.eval()
model.backbone_dtype = torch_dtype
model._sam3d_image_size = model.image_size
model._sam3d_canonical_colors = compute_canonical_colors(model)
model._sam3d_hand_vert_mask = compute_hand_vert_mask(model)
model.canonical_colors = compute_canonical_colors(model)
model.hand_vert_mask = compute_hand_vert_mask(model)
patcher = comfy.model_patcher.CoreModelPatcher(
model,
@ -153,7 +155,7 @@ class SAM3DBody_Predict(io.ComfyNode):
), advanced=True
),
io.Int.Input(
"batch_size", #TODO: automate?
"batch_size",
default=64, min=1, max=512, step=1, advanced=True,
tooltip=(
"Max frames to process as a batch. Larger values utilize more VRAM for faster inference."
@ -169,7 +171,7 @@ class SAM3DBody_Predict(io.ComfyNode):
inner: SAM3DBody = sam3d_body_model.model
B, H, W, _ = image.shape
image_size = getattr(inner, "_sam3d_image_size", (512, 512))
image_size = inner.image_size
# Precedence: SAM3 track (masks + boxes) > detector boxes > full-frame fallback.
per_frame_bboxes, per_frame_masks = (None, None)
@ -234,8 +236,8 @@ class SAM3DBody_Predict(io.ComfyNode):
"frames": frames_out,
"faces": inner.head_pose.faces.cpu().numpy(),
"image_size": (int(H), int(W)),
"canonical_colors": getattr(inner, "_sam3d_canonical_colors", None),
"hand_vert_mask": getattr(inner, "_sam3d_hand_vert_mask", None),
"canonical_colors": inner.canonical_colors,
"hand_vert_mask": inner.hand_vert_mask,
}
return io.NodeOutput(mhr_pose_data)
@ -253,7 +255,7 @@ class SAM3DBody_FaceExpression(io.ComfyNode):
return io.Schema(
node_id="SAM3DBody_FaceExpression",
description="Drive MHR face blendshapes from the core MediaPipe Face Landmarker.",
display_name="Face Expression to SAM3D Body", #TODO: better name?
display_name="Face Expression to SAM3D Body",
category="image/detection",
inputs=[
SAM3DBodyModel.Input("sam3d_body_model"),

View File

@ -168,7 +168,7 @@ def render_pose_data_openpose(
)
if use_rig_only:
face_vert_ids = face_vert_ids[_EYES_MOUTH_IDX]
except Exception as e:
except (ValueError, IndexError) as e:
logging.warning(f"[SAM3DBody] face landmarks disabled - {e}")
face_vert_ids = None