mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Cleanup
This commit is contained in:
parent
b6b12bd5fc
commit
e5d6bbc2de
@ -156,10 +156,12 @@ class DINOv3ViTRopePositionEmbedding(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class DINOv3ViTEmbeddings(nn.Module):
|
class DINOv3ViTEmbeddings(nn.Module):
|
||||||
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
|
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations, use_mask_token=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
|
self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
|
||||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype))
|
# mask_token is a pre-training param, omit it when the checkpoint does not ship it so strict loading stays clean
|
||||||
|
self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype)) if use_mask_token else None
|
||||||
|
|
||||||
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
|
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
|
||||||
self.patch_embeddings = operations.Conv2d(
|
self.patch_embeddings = operations.Conv2d(
|
||||||
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
|
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
|
||||||
@ -212,7 +214,7 @@ class DINOv3ViTLayer(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class DINOv3ViTModel(nn.Module):
|
class DINOv3ViTModel(nn.Module):
|
||||||
def __init__(self, config, dtype, device, operations):
|
def __init__(self, config, dtype, device, operations, use_mask_token=True):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
num_hidden_layers = config["num_hidden_layers"]
|
num_hidden_layers = config["num_hidden_layers"]
|
||||||
hidden_size = config["hidden_size"]
|
hidden_size = config["hidden_size"]
|
||||||
@ -228,7 +230,7 @@ class DINOv3ViTModel(nn.Module):
|
|||||||
|
|
||||||
self.embeddings = DINOv3ViTEmbeddings(
|
self.embeddings = DINOv3ViTEmbeddings(
|
||||||
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size,
|
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size,
|
||||||
dtype=dtype, device=device, operations=operations
|
dtype=dtype, device=device, operations=operations, use_mask_token=use_mask_token
|
||||||
)
|
)
|
||||||
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
|
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
|
||||||
rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device
|
rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device
|
||||||
|
|||||||
@ -51,17 +51,17 @@ class MHRHead(nn.Module):
|
|||||||
self.joint_rotation = _p(127, 3, 3)
|
self.joint_rotation = _p(127, 3, 3)
|
||||||
self.scale_mean = _p(68)
|
self.scale_mean = _p(68)
|
||||||
self.scale_comps = _p(28, 68)
|
self.scale_comps = _p(28, 68)
|
||||||
self.faces = _p(36874, 3, dtype=torch.int64)
|
self.register_buffer("faces", torch.empty(36874, 3, dtype=torch.int64))
|
||||||
self.hand_pose_mean = _p(54)
|
self.hand_pose_mean = _p(54)
|
||||||
self.hand_pose_comps = nn.Parameter(torch.eye(54), requires_grad=False)
|
self.hand_pose_comps = nn.Parameter(torch.eye(54), requires_grad=False)
|
||||||
self.hand_joint_idxs_left = _p(27, dtype=torch.int64)
|
self.register_buffer("hand_joint_idxs_left", torch.empty(27, dtype=torch.int64))
|
||||||
self.hand_joint_idxs_right = _p(27, dtype=torch.int64)
|
self.register_buffer("hand_joint_idxs_right", torch.empty(27, dtype=torch.int64))
|
||||||
self.keypoint_mapping = _p(308, 18439 + 127)
|
self.keypoint_mapping = _p(308, 18439 + 127)
|
||||||
# Some special buffers for the hand-version
|
# Some special buffers for the hand-version
|
||||||
self.right_wrist_coords = _p(3)
|
self.right_wrist_coords = _p(3)
|
||||||
self.root_coords = _p(3)
|
self.root_coords = _p(3)
|
||||||
self.local_to_world_wrist = _p(3, 3)
|
self.local_to_world_wrist = _p(3, 3)
|
||||||
self.nonhand_param_idxs = _p(145, dtype=torch.int64)
|
self.register_buffer("nonhand_param_idxs", torch.empty(145, dtype=torch.int64))
|
||||||
# Hand-painted per-vertex face region RGB (rainbow_face_semantic shader).
|
# Hand-painted per-vertex face region RGB (rainbow_face_semantic shader).
|
||||||
self.register_buffer("face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32))
|
self.register_buffer("face_region_rgb", torch.zeros(18439, 3, dtype=torch.float32))
|
||||||
|
|
||||||
|
|||||||
@ -44,8 +44,10 @@ class SAM3DBody(nn.Module):
|
|||||||
self.register_buffer("image_std", torch.tensor(IMAGE_STD).view(-1, 1, 1), False)
|
self.register_buffer("image_std", torch.tensor(IMAGE_STD).view(-1, 1, 1), False)
|
||||||
|
|
||||||
self.image_size = IMAGE_SIZE
|
self.image_size = IMAGE_SIZE
|
||||||
|
# Populated by the loader once weights are in place.
|
||||||
self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=operations)
|
self.canonical_colors = None
|
||||||
|
self.hand_vert_mask = None
|
||||||
|
self.backbone = DINOv3ViTModel(DINOV3_VITH_CONFIG, dtype=dtype, device=device, operations=operations, use_mask_token=False)
|
||||||
embed_dims = self.backbone.embed_dims
|
embed_dims = self.backbone.embed_dims
|
||||||
|
|
||||||
# MHR rig shared between body + hand pose heads via a non-registered
|
# MHR rig shared between body + hand pose heads via a non-registered
|
||||||
@ -61,8 +63,8 @@ class SAM3DBody(nn.Module):
|
|||||||
device=device, dtype=dtype, operations=operations,
|
device=device, dtype=dtype, operations=operations,
|
||||||
)
|
)
|
||||||
self.head_pose = MHRHead(**head_kwargs)
|
self.head_pose = MHRHead(**head_kwargs)
|
||||||
self.head_pose.hand_pose_comps_ori = nn.Parameter(
|
self.head_pose.register_buffer(
|
||||||
self.head_pose.hand_pose_comps.clone(), requires_grad=False
|
"hand_pose_comps_ori", self.head_pose.hand_pose_comps.clone(), persistent=False
|
||||||
)
|
)
|
||||||
self.head_pose.hand_pose_comps.data = (
|
self.head_pose.hand_pose_comps.data = (
|
||||||
torch.eye(54).to(self.head_pose.hand_pose_comps.data).float()
|
torch.eye(54).to(self.head_pose.hand_pose_comps.data).float()
|
||||||
@ -70,8 +72,8 @@ class SAM3DBody(nn.Module):
|
|||||||
self.init_pose = operations.Embedding(1, self.head_pose.npose, device=device, dtype=dtype)
|
self.init_pose = operations.Embedding(1, self.head_pose.npose, device=device, dtype=dtype)
|
||||||
|
|
||||||
self.head_pose_hand = MHRHead(enable_hand_model=True, **head_kwargs)
|
self.head_pose_hand = MHRHead(enable_hand_model=True, **head_kwargs)
|
||||||
self.head_pose_hand.hand_pose_comps_ori = nn.Parameter(
|
self.head_pose_hand.register_buffer(
|
||||||
self.head_pose_hand.hand_pose_comps.clone(), requires_grad=False
|
"hand_pose_comps_ori", self.head_pose_hand.hand_pose_comps.clone(), persistent=False
|
||||||
)
|
)
|
||||||
self.head_pose_hand.hand_pose_comps.data = (
|
self.head_pose_hand.hand_pose_comps.data = (
|
||||||
torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float()
|
torch.eye(54).to(self.head_pose_hand.hand_pose_comps.data).float()
|
||||||
|
|||||||
@ -78,14 +78,16 @@ class SAM3DBody_Loader(io.ComfyNode):
|
|||||||
operations = comfy.ops.pick_operations(torch_dtype, manual_cast_dtype, load_device=load_device, disable_fast_fp8=True)
|
operations = comfy.ops.pick_operations(torch_dtype, manual_cast_dtype, load_device=load_device, disable_fast_fp8=True)
|
||||||
|
|
||||||
model = SAM3DBody(dtype=torch_dtype, operations=operations)
|
model = SAM3DBody(dtype=torch_dtype, operations=operations)
|
||||||
model.load_state_dict(sd, strict=False)
|
sd.pop("hand_cls_embed.weight", None)
|
||||||
|
sd.pop("hand_cls_embed.bias", None)
|
||||||
|
missing, unexpected = model.load_state_dict(sd, strict=False)
|
||||||
|
missing = set(missing) - {"head_pose_hand.face_region_rgb"}
|
||||||
|
if missing or unexpected:
|
||||||
|
raise RuntimeError(f"SAM3D-Body checkpoint key mismatch: missing={sorted(missing)}, unexpected={sorted(unexpected)}")
|
||||||
|
|
||||||
model.eval()
|
|
||||||
model.backbone_dtype = torch_dtype
|
model.backbone_dtype = torch_dtype
|
||||||
model._sam3d_image_size = model.image_size
|
model.canonical_colors = compute_canonical_colors(model)
|
||||||
|
model.hand_vert_mask = compute_hand_vert_mask(model)
|
||||||
model._sam3d_canonical_colors = compute_canonical_colors(model)
|
|
||||||
model._sam3d_hand_vert_mask = compute_hand_vert_mask(model)
|
|
||||||
|
|
||||||
patcher = comfy.model_patcher.CoreModelPatcher(
|
patcher = comfy.model_patcher.CoreModelPatcher(
|
||||||
model,
|
model,
|
||||||
@ -153,7 +155,7 @@ class SAM3DBody_Predict(io.ComfyNode):
|
|||||||
), advanced=True
|
), advanced=True
|
||||||
),
|
),
|
||||||
io.Int.Input(
|
io.Int.Input(
|
||||||
"batch_size", #TODO: automate?
|
"batch_size",
|
||||||
default=64, min=1, max=512, step=1, advanced=True,
|
default=64, min=1, max=512, step=1, advanced=True,
|
||||||
tooltip=(
|
tooltip=(
|
||||||
"Max frames to process as a batch. Larger values utilize more VRAM for faster inference."
|
"Max frames to process as a batch. Larger values utilize more VRAM for faster inference."
|
||||||
@ -169,7 +171,7 @@ class SAM3DBody_Predict(io.ComfyNode):
|
|||||||
inner: SAM3DBody = sam3d_body_model.model
|
inner: SAM3DBody = sam3d_body_model.model
|
||||||
|
|
||||||
B, H, W, _ = image.shape
|
B, H, W, _ = image.shape
|
||||||
image_size = getattr(inner, "_sam3d_image_size", (512, 512))
|
image_size = inner.image_size
|
||||||
|
|
||||||
# Precedence: SAM3 track (masks + boxes) > detector boxes > full-frame fallback.
|
# Precedence: SAM3 track (masks + boxes) > detector boxes > full-frame fallback.
|
||||||
per_frame_bboxes, per_frame_masks = (None, None)
|
per_frame_bboxes, per_frame_masks = (None, None)
|
||||||
@ -234,8 +236,8 @@ class SAM3DBody_Predict(io.ComfyNode):
|
|||||||
"frames": frames_out,
|
"frames": frames_out,
|
||||||
"faces": inner.head_pose.faces.cpu().numpy(),
|
"faces": inner.head_pose.faces.cpu().numpy(),
|
||||||
"image_size": (int(H), int(W)),
|
"image_size": (int(H), int(W)),
|
||||||
"canonical_colors": getattr(inner, "_sam3d_canonical_colors", None),
|
"canonical_colors": inner.canonical_colors,
|
||||||
"hand_vert_mask": getattr(inner, "_sam3d_hand_vert_mask", None),
|
"hand_vert_mask": inner.hand_vert_mask,
|
||||||
}
|
}
|
||||||
return io.NodeOutput(mhr_pose_data)
|
return io.NodeOutput(mhr_pose_data)
|
||||||
|
|
||||||
@ -253,7 +255,7 @@ class SAM3DBody_FaceExpression(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SAM3DBody_FaceExpression",
|
node_id="SAM3DBody_FaceExpression",
|
||||||
description="Drive MHR face blendshapes from the core MediaPipe Face Landmarker.",
|
description="Drive MHR face blendshapes from the core MediaPipe Face Landmarker.",
|
||||||
display_name="Face Expression to SAM3D Body", #TODO: better name?
|
display_name="Face Expression to SAM3D Body",
|
||||||
category="image/detection",
|
category="image/detection",
|
||||||
inputs=[
|
inputs=[
|
||||||
SAM3DBodyModel.Input("sam3d_body_model"),
|
SAM3DBodyModel.Input("sam3d_body_model"),
|
||||||
|
|||||||
@ -168,7 +168,7 @@ def render_pose_data_openpose(
|
|||||||
)
|
)
|
||||||
if use_rig_only:
|
if use_rig_only:
|
||||||
face_vert_ids = face_vert_ids[_EYES_MOUTH_IDX]
|
face_vert_ids = face_vert_ids[_EYES_MOUTH_IDX]
|
||||||
except Exception as e:
|
except (ValueError, IndexError) as e:
|
||||||
logging.warning(f"[SAM3DBody] face landmarks disabled - {e}")
|
logging.warning(f"[SAM3DBody] face landmarks disabled - {e}")
|
||||||
face_vert_ids = None
|
face_vert_ids = None
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user