Fix per-frame bbox detection

This commit is contained in:
kijai 2026-04-15 01:06:23 +03:00
parent 0c7b5addba
commit 322b3467d8

View File

@@ -116,26 +116,32 @@ class SAM3_Detect(io.ComfyNode):
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
# Convert bboxes to normalized cxcywh format [1, N, 4]
# BoundingBox type can be: single dict, list of dicts, or list of lists of dicts (per-frame)
boxes_tensor = None
# Convert bboxes to normalized cxcywh format, per-frame list of [1, N, 4] tensors.
# Supports: single dict (all frames), list[dict] (all frames), list[list[dict]] (per-frame).
def _boxes_to_tensor(box_list):
    """Convert a list of pixel-space bbox dicts to a normalized cxcywh tensor.

    Each dict carries top-left pixel coords and size under the keys
    "x", "y", "width", "height". Coordinates are normalized by the
    closure variables W and H (presumably the source image width/height
    — defined in the enclosing method, not visible here; TODO confirm).

    Returns a float32 tensor of shape [1, N, 4] with rows
    (center_x, center_y, width, height), all in [0, 1] relative units.
    """
    coords = []
    for d in box_list:
        # Top-left + half-extent → normalized box center.
        cx = (d["x"] + d["width"] / 2) / W
        cy = (d["y"] + d["height"] / 2) / H
        coords.append([cx, cy, d["width"] / W, d["height"] / H])
    return torch.tensor([coords], dtype=torch.float32) # [1, N, 4]
per_frame_boxes = None
if bboxes is not None:
# Flatten to list of dicts
if isinstance(bboxes, dict):
flat_boxes = [bboxes]
# Single box → same for all frames
shared = _boxes_to_tensor([bboxes])
per_frame_boxes = [shared] * B
elif isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list):
flat_boxes = [d for frame in bboxes for d in frame] # per-frame list of lists
elif isinstance(bboxes, list):
flat_boxes = bboxes
else:
flat_boxes = []
if flat_boxes:
coords = []
for d in flat_boxes:
cx = (d["x"] + d["width"] / 2) / W
cy = (d["y"] + d["height"] / 2) / H
coords.append([cx, cy, d["width"] / W, d["height"] / H])
boxes_tensor = torch.tensor([coords], dtype=torch.float32) # [1, N, 4]
# list[list[dict]] → per-frame boxes
per_frame_boxes = [_boxes_to_tensor(frame_boxes) if frame_boxes else None for frame_boxes in bboxes]
# Pad to B if fewer frames provided
while len(per_frame_boxes) < B:
per_frame_boxes.append(per_frame_boxes[-1] if per_frame_boxes else None)
elif isinstance(bboxes, list) and len(bboxes) > 0:
# list[dict] → same boxes for all frames
shared = _boxes_to_tensor(bboxes)
per_frame_boxes = [shared] * B
# Parse point prompts from JSON (KJNodes PointsEditor format: [{"x": int, "y": int}, ...])
pos_pts = json.loads(positive_coords) if positive_coords else []
@@ -165,10 +171,12 @@ class SAM3_Detect(io.ComfyNode):
all_bbox_dicts = []
all_masks = []
pbar = comfy.utils.ProgressBar(B)
b_boxes_tensor = boxes_tensor.to(device=device, dtype=dtype) if boxes_tensor is not None else None
for b in range(B):
frame = image_in[b:b+1].to(device=device, dtype=dtype)
b_boxes = None
if per_frame_boxes is not None and per_frame_boxes[b] is not None:
b_boxes = per_frame_boxes[b].to(device=device, dtype=dtype)
frame_bbox_dicts = []
frame_masks = []
@@ -182,8 +190,8 @@ class SAM3_Detect(io.ComfyNode):
frame_masks.append((mask[0] > 0).float())
# Box prompts: SAM decoder path (segment inside each box)
if b_boxes_tensor is not None and not has_text:
for box_cxcywh in b_boxes_tensor[0]:
if b_boxes is not None and not has_text:
for box_cxcywh in b_boxes[0]:
cx, cy, bw, bh = box_cxcywh.tolist()
# Convert cxcywh normalized → xyxy in 1008 space → [1, 2, 2] corners
sam_box = torch.tensor([[[(cx - bw/2) * 1008, (cy - bh/2) * 1008],
@@ -199,7 +207,7 @@ class SAM3_Detect(io.ComfyNode):
for text_embeddings, text_mask, max_det in cond_list:
results = sam3_model(
frame, text_embeddings=text_embeddings, text_mask=text_mask,
boxes=b_boxes_tensor, threshold=threshold, orig_size=(H, W))
boxes=b_boxes, threshold=threshold, orig_size=(H, W))
pred_boxes = results["boxes"][0]
scores = results["scores"][0]
@@ -234,7 +242,10 @@ class SAM3_Detect(io.ComfyNode):
else:
all_masks.append((combined > 0).any(dim=0).float())
else:
all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
if individual_masks:
all_masks.append(torch.zeros(0, H, W, device=comfy.model_management.intermediate_device()))
else:
all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
pbar.update(1)
mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks)