mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 23:42:36 +08:00
Fix per frame bbox detection
This commit is contained in:
parent
0c7b5addba
commit
322b3467d8
@ -116,26 +116,32 @@ class SAM3_Detect(io.ComfyNode):
|
||||
|
||||
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
|
||||
|
||||
# Convert bboxes to normalized cxcywh format [1, N, 4]
|
||||
# BoundingBox type can be: single dict, list of dicts, or list of lists of dicts (per-frame)
|
||||
boxes_tensor = None
|
||||
# Convert bboxes to normalized cxcywh format, per-frame list of [1, N, 4] tensors.
|
||||
# Supports: single dict (all frames), list[dict] (all frames), list[list[dict]] (per-frame).
|
||||
def _boxes_to_tensor(box_list):
    """Convert a list of pixel-space box dicts to a normalized cxcywh tensor.

    Args:
        box_list: Sequence of dicts, each with "x", "y", "width" and
            "height" keys. The center is computed as x + width / 2 and
            y + height / 2, so x/y are treated as the box's minimum corner.

    Returns:
        A float32 tensor of shape [1, N, 4] holding (cx, cy, w, h) per box.
        Horizontal values are divided by W and vertical values by H, both
        taken from the enclosing scope (presumably the source image width
        and height in pixels; confirm against the caller).
    """
    coords = [
        [
            (d["x"] + d["width"] / 2) / W,
            (d["y"] + d["height"] / 2) / H,
            d["width"] / W,
            d["height"] / H,
        ]
        for d in box_list
    ]
    # Reshape explicitly so an empty box list still yields [1, 0, 4]
    # instead of the [1, 0] tensor that torch.tensor([[]]) would produce.
    return torch.tensor(coords, dtype=torch.float32).reshape(1, len(coords), 4)
|
||||
|
||||
per_frame_boxes = None
|
||||
if bboxes is not None:
|
||||
# Flatten to list of dicts
|
||||
if isinstance(bboxes, dict):
|
||||
flat_boxes = [bboxes]
|
||||
# Single box → same for all frames
|
||||
shared = _boxes_to_tensor([bboxes])
|
||||
per_frame_boxes = [shared] * B
|
||||
elif isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list):
|
||||
flat_boxes = [d for frame in bboxes for d in frame] # per-frame list of lists
|
||||
elif isinstance(bboxes, list):
|
||||
flat_boxes = bboxes
|
||||
else:
|
||||
flat_boxes = []
|
||||
if flat_boxes:
|
||||
coords = []
|
||||
for d in flat_boxes:
|
||||
cx = (d["x"] + d["width"] / 2) / W
|
||||
cy = (d["y"] + d["height"] / 2) / H
|
||||
coords.append([cx, cy, d["width"] / W, d["height"] / H])
|
||||
boxes_tensor = torch.tensor([coords], dtype=torch.float32) # [1, N, 4]
|
||||
# list[list[dict]] → per-frame boxes
|
||||
per_frame_boxes = [_boxes_to_tensor(frame_boxes) if frame_boxes else None for frame_boxes in bboxes]
|
||||
# Pad to B if fewer frames provided
|
||||
while len(per_frame_boxes) < B:
|
||||
per_frame_boxes.append(per_frame_boxes[-1] if per_frame_boxes else None)
|
||||
elif isinstance(bboxes, list) and len(bboxes) > 0:
|
||||
# list[dict] → same boxes for all frames
|
||||
shared = _boxes_to_tensor(bboxes)
|
||||
per_frame_boxes = [shared] * B
|
||||
|
||||
# Parse point prompts from JSON (KJNodes PointsEditor format: [{"x": int, "y": int}, ...])
|
||||
pos_pts = json.loads(positive_coords) if positive_coords else []
|
||||
@ -165,10 +171,12 @@ class SAM3_Detect(io.ComfyNode):
|
||||
all_bbox_dicts = []
|
||||
all_masks = []
|
||||
pbar = comfy.utils.ProgressBar(B)
|
||||
b_boxes_tensor = boxes_tensor.to(device=device, dtype=dtype) if boxes_tensor is not None else None
|
||||
|
||||
for b in range(B):
|
||||
frame = image_in[b:b+1].to(device=device, dtype=dtype)
|
||||
b_boxes = None
|
||||
if per_frame_boxes is not None and per_frame_boxes[b] is not None:
|
||||
b_boxes = per_frame_boxes[b].to(device=device, dtype=dtype)
|
||||
|
||||
frame_bbox_dicts = []
|
||||
frame_masks = []
|
||||
@ -182,8 +190,8 @@ class SAM3_Detect(io.ComfyNode):
|
||||
frame_masks.append((mask[0] > 0).float())
|
||||
|
||||
# Box prompts: SAM decoder path (segment inside each box)
|
||||
if b_boxes_tensor is not None and not has_text:
|
||||
for box_cxcywh in b_boxes_tensor[0]:
|
||||
if b_boxes is not None and not has_text:
|
||||
for box_cxcywh in b_boxes[0]:
|
||||
cx, cy, bw, bh = box_cxcywh.tolist()
|
||||
# Convert cxcywh normalized → xyxy in 1008 space → [1, 2, 2] corners
|
||||
sam_box = torch.tensor([[[(cx - bw/2) * 1008, (cy - bh/2) * 1008],
|
||||
@ -199,7 +207,7 @@ class SAM3_Detect(io.ComfyNode):
|
||||
for text_embeddings, text_mask, max_det in cond_list:
|
||||
results = sam3_model(
|
||||
frame, text_embeddings=text_embeddings, text_mask=text_mask,
|
||||
boxes=b_boxes_tensor, threshold=threshold, orig_size=(H, W))
|
||||
boxes=b_boxes, threshold=threshold, orig_size=(H, W))
|
||||
|
||||
pred_boxes = results["boxes"][0]
|
||||
scores = results["scores"][0]
|
||||
@ -234,7 +242,10 @@ class SAM3_Detect(io.ComfyNode):
|
||||
else:
|
||||
all_masks.append((combined > 0).any(dim=0).float())
|
||||
else:
|
||||
all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
|
||||
if individual_masks:
|
||||
all_masks.append(torch.zeros(0, H, W, device=comfy.model_management.intermediate_device()))
|
||||
else:
|
||||
all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
|
||||
pbar.update(1)
|
||||
|
||||
mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user