diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py index 09eb9beab..ea04074cf 100644 --- a/comfy/image_encoders/dino3.py +++ b/comfy/image_encoders/dino3.py @@ -156,10 +156,12 @@ class DINOv3ViTRopePositionEmbedding(nn.Module): class DINOv3ViTEmbeddings(nn.Module): - def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations): + def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations, use_mask_token=True): super().__init__() self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype)) - self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype)) + # mask_token is a pre-training param, omit it when the checkpoint does not ship it so strict loading stays clean + self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype)) if use_mask_token else None + self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype)) self.patch_embeddings = operations.Conv2d( num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype @@ -212,7 +214,7 @@ class DINOv3ViTLayer(nn.Module): class DINOv3ViTModel(nn.Module): - def __init__(self, config, dtype, device, operations): + def __init__(self, config, dtype, device, operations, use_mask_token=True): super().__init__() num_hidden_layers = config["num_hidden_layers"] hidden_size = config["hidden_size"] @@ -228,7 +230,7 @@ class DINOv3ViTModel(nn.Module): self.embeddings = DINOv3ViTEmbeddings( hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size, - dtype=dtype, device=device, operations=operations + dtype=dtype, device=device, operations=operations, use_mask_token=use_mask_token ) self.rope_embeddings = DINOv3ViTRopePositionEmbedding( rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device diff --git a/comfy/ldm/sam3d_body/mhr/mhr_head.py b/comfy/ldm/sam3d_body/mhr/mhr_head.py index d226180a3..fbeac93c5 100644 --- a/comfy/ldm/sam3d_body/mhr/mhr_head.py +++ b/comfy/ldm/sam3d_body/mhr/mhr_head.py @@ -51,17 +51,17 @@ class MHRHead(nn.Module): self.joint_rotation = _p(127, 3, 3) self.scale_mean = _p(68) self.scale_comps = _p(28, 68) - self.faces = _p(36874, 3, dtype=torch.int64) + self.register_buffer("faces", torch.empty(36874, 3, dtype=torch.int64)) self.hand_pose_mean = _p(54) self.hand_pose_comps = nn.Parameter(torch.eye(54), requires_grad=False) - self.hand_joint_idxs_left = _p(27, dtype=torch.int64) - self.hand_joint_idxs_right = _p(27, dtype=torch.int64) + self.register_buffer("hand_joint_idxs_left", torch.empty(27, dtype=torch.int64)) + self.register_buffer("hand_joint_idxs_right", torch.empty(27, dtype=torch.int64)) self.keypoint_mapping = _p(308, 18439 + 127) # Some special buffers for the hand-version self.right_wrist_coords = _p(3) self.root_coords = _p(3) self.local_to_world_wrist = _p(3, 3) - self.nonhand_param_idxs = _p(145, dtype=torch.int64) + self.register_buffer("nonhand_param_idxs", torch.empty(145, dtype=torch.int64)) # Hand-painted per-vertex face region RGB (rainbow_face_semantic shader). self.register_buffer("face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32)) diff --git a/comfy/ldm/sam3d_body/model/model.py b/comfy/ldm/sam3d_body/model/model.py index 2063ce559..db5193366 100644 --- a/comfy/ldm/sam3d_body/model/model.py +++ b/comfy/ldm/sam3d_body/model/model.py @@ -44,8 +44,10 @@ class SAM3DBody(nn.Module): self.register_buffer("image_std", torch.tensor(IMAGE_STD).view(-1, 1, 1), False) self.image_size = IMAGE_SIZE - - self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=operations) + # Populated by the loader once weights are in place. + self.canonical_colors = None + self.hand_vert_mask = None + self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=operations, use_mask_token=False) embed_dims = self.backbone.embed_dims # MHR rig shared between body + hand pose heads via a non-registered @@ -61,8 +63,8 @@ class SAM3DBody(nn.Module): device=device, dtype=dtype, operations=operations, ) self.head_pose = MHRHead(**head_kwargs) - self.head_pose.hand_pose_comps_ori = nn.Parameter( - self.head_pose.hand_pose_comps.clone(), requires_grad=False + self.head_pose.register_buffer( + "hand_pose_comps_ori", self.head_pose.hand_pose_comps.clone(), persistent=False ) self.head_pose.hand_pose_comps.data = ( torch.eye(54).to(self.head_pose.hand_pose_comps.data).float() @@ -70,8 +72,8 @@ class SAM3DBody(nn.Module): self.init_pose = operations.Embedding(1, self.head_pose.npose, device=device, dtype=dtype) self.head_pose_hand = MHRHead(enable_hand_model=True, **head_kwargs) - self.head_pose_hand.hand_pose_comps_ori = nn.Parameter( - self.head_pose_hand.hand_pose_comps.clone(), requires_grad=False + self.head_pose_hand.register_buffer( + "hand_pose_comps_ori", self.head_pose_hand.hand_pose_comps.clone(), persistent=False ) self.head_pose_hand.hand_pose_comps.data = ( torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float() diff --git a/comfy_extras/nodes_sam3d_body.py b/comfy_extras/nodes_sam3d_body.py index 865bbce94..a48d1a85d 100644 --- a/comfy_extras/nodes_sam3d_body.py +++ b/comfy_extras/nodes_sam3d_body.py @@ -78,14 +78,16 @@ class SAM3DBody_Loader(io.ComfyNode): operations = comfy.ops.pick_operations(torch_dtype, manual_cast_dtype, load_device=load_device, disable_fast_fp8=True) model = SAM3DBody(dtype=torch_dtype, operations=operations) - model.load_state_dict(sd, strict=False) + sd.pop("hand_cls_embed.weight", None) + sd.pop("hand_cls_embed.bias", None) + missing, unexpected = model.load_state_dict(sd, strict=False) + missing = set(missing) - {"head_pose_hand.face_region_rgb"} + if missing or unexpected: + raise RuntimeError(f"SAM3D-Body checkpoint key mismatch: missing={sorted(missing)}, unexpected={sorted(unexpected)}") - model.eval() model.backbone_dtype = torch_dtype - model._sam3d_image_size = model.image_size - - model._sam3d_canonical_colors = compute_canonical_colors(model) - model._sam3d_hand_vert_mask = compute_hand_vert_mask(model) + model.canonical_colors = compute_canonical_colors(model) + model.hand_vert_mask = compute_hand_vert_mask(model) patcher = comfy.model_patcher.CoreModelPatcher( model, @@ -153,7 +155,7 @@ class SAM3DBody_Predict(io.ComfyNode): ), advanced=True ), io.Int.Input( - "batch_size", #TODO: automate? + "batch_size", default=64, min=1, max=512, step=1, advanced=True, tooltip=( "Max frames to process as a batch. Larger values utilize more VRAM for faster inference." @@ -169,7 +171,7 @@ class SAM3DBody_Predict(io.ComfyNode): inner: SAM3DBody = sam3d_body_model.model B, H, W, _ = image.shape - image_size = getattr(inner, "_sam3d_image_size", (512, 512)) + image_size = inner.image_size # Precedence: SAM3 track (masks + boxes) > detector boxes > full-frame fallback. per_frame_bboxes, per_frame_masks = (None, None) @@ -234,8 +236,8 @@ class SAM3DBody_Predict(io.ComfyNode): "frames": frames_out, "faces": inner.head_pose.faces.cpu().numpy(), "image_size": (int(H), int(W)), - "canonical_colors": getattr(inner, "_sam3d_canonical_colors", None), - "hand_vert_mask": getattr(inner, "_sam3d_hand_vert_mask", None), + "canonical_colors": inner.canonical_colors, + "hand_vert_mask": inner.hand_vert_mask, } return io.NodeOutput(mhr_pose_data) @@ -253,7 +255,7 @@ class SAM3DBody_FaceExpression(io.ComfyNode): return io.Schema( node_id="SAM3DBody_FaceExpression", description="Drive MHR face blendshapes from the core MediaPipe Face Landmarker.", - display_name="Face Expression to SAM3D Body", #TODO: better name? + display_name="Face Expression to SAM3D Body", category="image/detection", inputs=[ SAM3DBodyModel.Input("sam3d_body_model"), diff --git a/comfy_extras/sam3d_body/export/openpose_2d.py b/comfy_extras/sam3d_body/export/openpose_2d.py index 8f5c4ab81..3bf3a62ac 100644 --- a/comfy_extras/sam3d_body/export/openpose_2d.py +++ b/comfy_extras/sam3d_body/export/openpose_2d.py @@ -168,7 +168,7 @@ def render_pose_data_openpose( ) if use_rig_only: face_vert_ids = face_vert_ids[_EYES_MOUTH_IDX] - except Exception as e: + except (ValueError, IndexError) as e: logging.warning(f"[SAM3DBody] face landmarks disabled - {e}") face_vert_ids = None