diff --git a/comfy_extras/nodes_sam3.py b/comfy_extras/nodes_sam3.py index 56cbafbc6..5cf92ccb3 100644 --- a/comfy_extras/nodes_sam3.py +++ b/comfy_extras/nodes_sam3.py @@ -61,7 +61,7 @@ def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device if cx2 <= cx1 or cy2 <= cy1: return _coarse_fallback() - crop = orig_image_hwc[cy1:cy2, cx1:cx2] + crop = orig_image_hwc[cy1:cy2, cx1:cx2, :3] crop_1008 = comfy.utils.common_upscale(crop.unsqueeze(0).movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled") crop_frame = crop_1008.to(device=device, dtype=dtype) crop_h, crop_w = cy2 - cy1, cx2 - cx1 @@ -115,8 +115,7 @@ class SAM3_Detect(io.ComfyNode): @classmethod def execute(cls, model, image, conditioning=None, bboxes=None, positive_coords=None, negative_coords=None, threshold=0.5, refine_iterations=2, individual_masks=False) -> io.NodeOutput: B, H, W, C = image.shape - - image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled") + image_in = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled") # Convert bboxes to normalized cxcywh format, per-frame list of [1, N, 4] tensors. # Supports: single dict (all frames), list[dict] (all frames), list[list[dict]] (per-frame). @@ -291,7 +290,7 @@ class SAM3_VideoTrack(io.ComfyNode): dtype = model.model.get_dtype() sam3_model = model.model.diffusion_model - frames = images.movedim(-1, 1) + frames = images[..., :3].movedim(-1, 1) frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype) init_masks = None