diff --git a/comfy_extras/nodes_sam3.py b/comfy_extras/nodes_sam3.py
index 7eee2c66e..0ea5ac1f7 100644
--- a/comfy_extras/nodes_sam3.py
+++ b/comfy_extras/nodes_sam3.py
@@ -116,26 +116,32 @@ class SAM3_Detect(io.ComfyNode):
         image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
 
-        # Convert bboxes to normalized cxcywh format [1, N, 4]
-        # BoundingBox type can be: single dict, list of dicts, or list of lists of dicts (per-frame)
-        boxes_tensor = None
+        # Convert bboxes to normalized cxcywh format, per-frame list of [1, N, 4] tensors.
+        # Supports: single dict (all frames), list[dict] (all frames), list[list[dict]] (per-frame).
+        def _boxes_to_tensor(box_list):
+            coords = []
+            for d in box_list:
+                cx = (d["x"] + d["width"] / 2) / W
+                cy = (d["y"] + d["height"] / 2) / H
+                coords.append([cx, cy, d["width"] / W, d["height"] / H])
+            return torch.tensor([coords], dtype=torch.float32)  # [1, N, 4]
+
+        per_frame_boxes = None
 
         if bboxes is not None:
-            # Flatten to list of dicts
             if isinstance(bboxes, dict):
-                flat_boxes = [bboxes]
+                # Single box → same for all frames
+                shared = _boxes_to_tensor([bboxes])
+                per_frame_boxes = [shared] * B
             elif isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list):
-                flat_boxes = [d for frame in bboxes for d in frame]  # per-frame list of lists
-            elif isinstance(bboxes, list):
-                flat_boxes = bboxes
-            else:
-                flat_boxes = []
-            if flat_boxes:
-                coords = []
-                for d in flat_boxes:
-                    cx = (d["x"] + d["width"] / 2) / W
-                    cy = (d["y"] + d["height"] / 2) / H
-                    coords.append([cx, cy, d["width"] / W, d["height"] / H])
-                boxes_tensor = torch.tensor([coords], dtype=torch.float32)  # [1, N, 4]
+                # list[list[dict]] → per-frame boxes
+                per_frame_boxes = [_boxes_to_tensor(frame_boxes) if frame_boxes else None for frame_boxes in bboxes]
+                # Pad to B if fewer frames provided
+                while len(per_frame_boxes) < B:
+                    per_frame_boxes.append(per_frame_boxes[-1] if per_frame_boxes else None)
+            elif isinstance(bboxes, list) and len(bboxes) > 0:
+                # list[dict] → same boxes for all frames
+                shared = _boxes_to_tensor(bboxes)
+                per_frame_boxes = [shared] * B
 
         # Parse point prompts from JSON (KJNodes PointsEditor format: [{"x": int, "y": int}, ...])
         pos_pts = json.loads(positive_coords) if positive_coords else []
@@ -165,10 +171,12 @@ class SAM3_Detect(io.ComfyNode):
         all_bbox_dicts = []
         all_masks = []
         pbar = comfy.utils.ProgressBar(B)
-        b_boxes_tensor = boxes_tensor.to(device=device, dtype=dtype) if boxes_tensor is not None else None
 
         for b in range(B):
             frame = image_in[b:b+1].to(device=device, dtype=dtype)
+            b_boxes = None
+            if per_frame_boxes is not None and per_frame_boxes[b] is not None:
+                b_boxes = per_frame_boxes[b].to(device=device, dtype=dtype)
             frame_bbox_dicts = []
             frame_masks = []
 
@@ -182,8 +190,8 @@ class SAM3_Detect(io.ComfyNode):
                     frame_masks.append((mask[0] > 0).float())
 
             # Box prompts: SAM decoder path (segment inside each box)
-            if b_boxes_tensor is not None and not has_text:
-                for box_cxcywh in b_boxes_tensor[0]:
+            if b_boxes is not None and not has_text:
+                for box_cxcywh in b_boxes[0]:
                     cx, cy, bw, bh = box_cxcywh.tolist()
                     # Convert cxcywh normalized → xyxy in 1008 space → [1, 2, 2] corners
                     sam_box = torch.tensor([[[(cx - bw/2) * 1008, (cy - bh/2) * 1008],
@@ -199,7 +207,7 @@ class SAM3_Detect(io.ComfyNode):
             for text_embeddings, text_mask, max_det in cond_list:
                 results = sam3_model(
                     frame, text_embeddings=text_embeddings, text_mask=text_mask,
-                    boxes=b_boxes_tensor, threshold=threshold, orig_size=(H, W))
+                    boxes=b_boxes, threshold=threshold, orig_size=(H, W))
 
                 pred_boxes = results["boxes"][0]
                 scores = results["scores"][0]
@@ -234,7 +242,10 @@ class SAM3_Detect(io.ComfyNode):
             else:
                 all_masks.append((combined > 0).any(dim=0).float())
         else:
-            all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
+            if individual_masks:
+                all_masks.append(torch.zeros(0, H, W, device=comfy.model_management.intermediate_device()))
+            else:
+                all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
         pbar.update(1)
 
         mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks)