mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 23:42:36 +08:00
515 lines
24 KiB
Python
515 lines
24 KiB
Python
"""
|
|
SAM3 (Segment Anything 3) nodes for detection, segmentation, and video tracking.
|
|
"""
|
|
|
|
from typing_extensions import override
|
|
|
|
import json
|
|
import os
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import comfy.model_management
|
|
import comfy.utils
|
|
import folder_paths
|
|
from comfy_api.latest import ComfyExtension, io, ui
|
|
import av
|
|
from fractions import Fraction
|
|
|
|
|
|
def _extract_text_prompts(conditioning, device, dtype):
|
|
"""Extract list of (text_embeddings, text_mask) from conditioning."""
|
|
cond_meta = conditioning[0][1]
|
|
multi = cond_meta.get("sam3_multi_cond")
|
|
prompts = []
|
|
if multi is not None:
|
|
for entry in multi:
|
|
emb = entry["cond"].to(device=device, dtype=dtype)
|
|
mask = entry["attention_mask"].to(device) if entry["attention_mask"] is not None else None
|
|
if mask is None:
|
|
mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device)
|
|
prompts.append((emb, mask, entry.get("max_detections", 1)))
|
|
else:
|
|
emb = conditioning[0][0].to(device=device, dtype=dtype)
|
|
mask = cond_meta.get("attention_mask")
|
|
if mask is not None:
|
|
mask = mask.to(device)
|
|
else:
|
|
mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device)
|
|
prompts.append((emb, mask, 1))
|
|
return prompts
|
|
|
|
|
|
def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device, dtype, iterations):
|
|
"""Refine a coarse detector mask via SAM decoder, cropping to the detection box.
|
|
|
|
Returns: [1, H, W] binary mask
|
|
"""
|
|
def _coarse_fallback():
|
|
return (F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W),
|
|
mode="bilinear", align_corners=False)[0] > 0).float()
|
|
|
|
if iterations <= 0:
|
|
return _coarse_fallback()
|
|
|
|
pad_frac = 0.1
|
|
x1, y1, x2, y2 = box_xyxy.tolist()
|
|
bw, bh = x2 - x1, y2 - y1
|
|
cx1 = max(0, int(x1 - bw * pad_frac))
|
|
cy1 = max(0, int(y1 - bh * pad_frac))
|
|
cx2 = min(W, int(x2 + bw * pad_frac))
|
|
cy2 = min(H, int(y2 + bh * pad_frac))
|
|
if cx2 <= cx1 or cy2 <= cy1:
|
|
return _coarse_fallback()
|
|
|
|
crop = orig_image_hwc[cy1:cy2, cx1:cx2]
|
|
crop_1008 = comfy.utils.common_upscale(crop.unsqueeze(0).movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
|
|
crop_frame = crop_1008.to(device=device, dtype=dtype)
|
|
crop_h, crop_w = cy2 - cy1, cx2 - cx1
|
|
|
|
# Crop coarse mask and refine via SAM on the cropped image
|
|
mask_h, mask_w = coarse_mask.shape[-2:]
|
|
mx1, my1 = int(cx1 / W * mask_w), int(cy1 / H * mask_h)
|
|
mx2, my2 = int(cx2 / W * mask_w), int(cy2 / H * mask_h)
|
|
mask_logit = coarse_mask[..., my1:my2, mx1:mx2].unsqueeze(0).unsqueeze(0)
|
|
for _ in range(iterations):
|
|
coarse_input = F.interpolate(mask_logit, size=(1008, 1008), mode="bilinear", align_corners=False)
|
|
mask_logit = sam3_model.forward_segment(crop_frame, mask_inputs=coarse_input)
|
|
|
|
refined_crop = F.interpolate(mask_logit, size=(crop_h, crop_w), mode="bilinear", align_corners=False)
|
|
full_mask = torch.zeros(1, 1, H, W, device=device, dtype=dtype)
|
|
full_mask[:, :, cy1:cy2, cx1:cx2] = refined_crop
|
|
coarse_full = F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False)
|
|
return ((full_mask[0] > 0) | (coarse_full[0] > 0)).float()
|
|
|
|
|
|
|
|
class SAM3_Detect(io.ComfyNode):
|
|
"""Open-vocabulary detection and segmentation using text, box, or point prompts."""
|
|
|
|
@classmethod
|
|
def define_schema(cls):
|
|
return io.Schema(
|
|
node_id="SAM3_Detect",
|
|
display_name="SAM3 Detect",
|
|
category="detection/",
|
|
search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"],
|
|
inputs=[
|
|
io.Model.Input("model", display_name="model"),
|
|
io.Image.Input("image", display_name="image"),
|
|
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning from CLIPTextEncode"),
|
|
io.BoundingBox.Input("bboxes", display_name="bboxes", force_input=True, optional=True, tooltip="Bounding boxes to segment within"),
|
|
io.String.Input("positive_coords", display_name="positive_coords", force_input=True, optional=True, tooltip="Positive point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"),
|
|
io.String.Input("negative_coords", display_name="negative_coords", force_input=True, optional=True, tooltip="Negative point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"),
|
|
io.Float.Input("threshold", display_name="threshold", default=0.5, min=0.0, max=1.0, step=0.01),
|
|
io.Int.Input("refine_iterations", display_name="refine_iterations", default=2, min=0, max=5, tooltip="SAM decoder refinement passes (0=use raw detector masks)"),
|
|
io.Boolean.Input("individual_masks", display_name="individual_masks", default=False, tooltip="Output per-object masks instead of union"),
|
|
],
|
|
outputs=[
|
|
io.Mask.Output("masks"),
|
|
io.BoundingBox.Output("bboxes"),
|
|
],
|
|
)
|
|
|
|
@classmethod
|
|
def execute(cls, model, image, conditioning=None, bboxes=None, positive_coords=None, negative_coords=None, threshold=0.5, refine_iterations=2, individual_masks=False) -> io.NodeOutput:
|
|
B, H, W, C = image.shape
|
|
|
|
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled")
|
|
|
|
# Convert bboxes to normalized cxcywh format [1, N, 4]
|
|
# BoundingBox type can be: single dict, list of dicts, or list of lists of dicts (per-frame)
|
|
boxes_tensor = None
|
|
if bboxes is not None:
|
|
# Flatten to list of dicts
|
|
if isinstance(bboxes, dict):
|
|
flat_boxes = [bboxes]
|
|
elif isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list):
|
|
flat_boxes = [d for frame in bboxes for d in frame] # per-frame list of lists
|
|
elif isinstance(bboxes, list):
|
|
flat_boxes = bboxes
|
|
else:
|
|
flat_boxes = []
|
|
if flat_boxes:
|
|
coords = []
|
|
for d in flat_boxes:
|
|
cx = (d["x"] + d["width"] / 2) / W
|
|
cy = (d["y"] + d["height"] / 2) / H
|
|
coords.append([cx, cy, d["width"] / W, d["height"] / H])
|
|
boxes_tensor = torch.tensor([coords], dtype=torch.float32) # [1, N, 4]
|
|
|
|
# Parse point prompts from JSON (KJNodes PointsEditor format: [{"x": int, "y": int}, ...])
|
|
pos_pts = json.loads(positive_coords) if positive_coords else []
|
|
neg_pts = json.loads(negative_coords) if negative_coords else []
|
|
has_points = len(pos_pts) > 0 or len(neg_pts) > 0
|
|
|
|
comfy.model_management.load_model_gpu(model)
|
|
device = comfy.model_management.get_torch_device()
|
|
dtype = model.model.get_dtype()
|
|
sam3_model = model.model.diffusion_model
|
|
|
|
# Build point inputs for tracker SAM decoder path
|
|
point_inputs = None
|
|
if has_points:
|
|
all_coords = [[p["x"] / W * 1008, p["y"] / H * 1008] for p in pos_pts] + \
|
|
[[p["x"] / W * 1008, p["y"] / H * 1008] for p in neg_pts]
|
|
all_labels = [1] * len(pos_pts) + [0] * len(neg_pts)
|
|
point_inputs = {
|
|
"point_coords": torch.tensor([all_coords], dtype=dtype, device=device),
|
|
"point_labels": torch.tensor([all_labels], dtype=torch.int32, device=device),
|
|
}
|
|
|
|
cond_list = _extract_text_prompts(conditioning, device, dtype) if conditioning is not None and len(conditioning) > 0 else []
|
|
has_text = len(cond_list) > 0
|
|
|
|
# Run per-image through detector (text/boxes) and/or tracker (points)
|
|
all_bbox_dicts = []
|
|
all_masks = []
|
|
pbar = comfy.utils.ProgressBar(B)
|
|
b_boxes_tensor = boxes_tensor.to(device=device, dtype=dtype) if boxes_tensor is not None else None
|
|
|
|
for b in range(B):
|
|
frame = image_in[b:b+1].to(device=device, dtype=dtype)
|
|
|
|
frame_bbox_dicts = []
|
|
frame_masks = []
|
|
|
|
# Point prompts: tracker SAM decoder path with iterative refinement
|
|
if point_inputs is not None:
|
|
mask_logit = sam3_model.forward_segment(frame, point_inputs=point_inputs)
|
|
for _ in range(max(0, refine_iterations - 1)):
|
|
mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit)
|
|
mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False)
|
|
frame_masks.append((mask[0] > 0).float())
|
|
|
|
# Box prompts: SAM decoder path (segment inside each box)
|
|
if b_boxes_tensor is not None and not has_text:
|
|
for box_cxcywh in b_boxes_tensor[0]:
|
|
cx, cy, bw, bh = box_cxcywh.tolist()
|
|
# Convert cxcywh normalized → xyxy in 1008 space → [1, 2, 2] corners
|
|
sam_box = torch.tensor([[[(cx - bw/2) * 1008, (cy - bh/2) * 1008],
|
|
[(cx + bw/2) * 1008, (cy + bh/2) * 1008]]],
|
|
device=device, dtype=dtype)
|
|
mask_logit = sam3_model.forward_segment(frame, box_inputs=sam_box)
|
|
for _ in range(max(0, refine_iterations - 1)):
|
|
mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit)
|
|
mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False)
|
|
frame_masks.append((mask[0] > 0).float())
|
|
|
|
# Text prompts: run detector per text prompt (each detects one category)
|
|
for text_embeddings, text_mask, max_det in cond_list:
|
|
results = sam3_model(
|
|
frame, text_embeddings=text_embeddings, text_mask=text_mask,
|
|
boxes=b_boxes_tensor, threshold=threshold, orig_size=(H, W))
|
|
|
|
pred_boxes = results["boxes"][0]
|
|
scores = results["scores"][0]
|
|
masks = results["masks"][0]
|
|
|
|
probs = scores.sigmoid()
|
|
keep = probs > threshold
|
|
kept_boxes = pred_boxes[keep].cpu()
|
|
kept_scores = probs[keep].cpu()
|
|
kept_masks = masks[keep]
|
|
|
|
order = kept_scores.argsort(descending=True)[:max_det]
|
|
kept_boxes = kept_boxes[order]
|
|
kept_scores = kept_scores[order]
|
|
kept_masks = kept_masks[order]
|
|
|
|
for box, score in zip(kept_boxes, kept_scores):
|
|
frame_bbox_dicts.append({
|
|
"x": float(box[0]), "y": float(box[1]),
|
|
"width": float(box[2] - box[0]), "height": float(box[3] - box[1]),
|
|
"score": float(score),
|
|
})
|
|
for m, box in zip(kept_masks, kept_boxes):
|
|
frame_masks.append(_refine_mask(
|
|
sam3_model, image[b], m, box, H, W, device, dtype, refine_iterations))
|
|
|
|
all_bbox_dicts.append(frame_bbox_dicts)
|
|
if len(frame_masks) > 0:
|
|
combined = torch.cat(frame_masks, dim=0) # [N_obj, H, W]
|
|
if individual_masks:
|
|
all_masks.append(combined)
|
|
else:
|
|
all_masks.append((combined > 0).any(dim=0).float())
|
|
else:
|
|
all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device()))
|
|
pbar.update(1)
|
|
|
|
mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks)
|
|
return io.NodeOutput(mask_out, all_bbox_dicts)
|
|
|
|
|
|
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
|
|
|
|
class SAM3_VideoTrack(io.ComfyNode):
|
|
"""Track objects across video frames using SAM3's memory-based tracker."""
|
|
|
|
@classmethod
|
|
def define_schema(cls):
|
|
return io.Schema(
|
|
node_id="SAM3_VideoTrack",
|
|
display_name="SAM3 Video Track",
|
|
category="detection/",
|
|
search_aliases=["sam3", "video", "track", "propagate"],
|
|
inputs=[
|
|
io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"),
|
|
io.Model.Input("model", display_name="model"),
|
|
io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"),
|
|
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"),
|
|
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"),
|
|
io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."),
|
|
io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."),
|
|
],
|
|
outputs=[
|
|
SAM3TrackData.Output("track_data", display_name="track_data"),
|
|
],
|
|
)
|
|
|
|
@classmethod
|
|
def execute(cls, images, model, initial_mask=None, conditioning=None, detection_threshold=0.5, max_objects=0, detect_interval=1) -> io.NodeOutput:
|
|
N, H, W, C = images.shape
|
|
|
|
comfy.model_management.load_model_gpu(model)
|
|
device = comfy.model_management.get_torch_device()
|
|
dtype = model.model.get_dtype()
|
|
sam3_model = model.model.diffusion_model
|
|
|
|
frames = images.movedim(-1, 1)
|
|
frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype)
|
|
|
|
init_masks = None
|
|
if initial_mask is not None:
|
|
init_masks = initial_mask.unsqueeze(1).to(device=device, dtype=dtype)
|
|
|
|
pbar = comfy.utils.ProgressBar(N)
|
|
|
|
text_prompts = None
|
|
if conditioning is not None:
|
|
text_prompts = [(emb, mask) for emb, mask, _ in _extract_text_prompts(conditioning, device, dtype)]
|
|
elif initial_mask is None:
|
|
raise ValueError("Either initial_mask or conditioning must be provided")
|
|
|
|
result = sam3_model.forward_video(
|
|
images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts,
|
|
new_det_thresh=detection_threshold, max_objects=max_objects,
|
|
detect_interval=detect_interval)
|
|
result["orig_size"] = (H, W)
|
|
return io.NodeOutput(result)
|
|
|
|
|
|
class SAM3_TrackPreview(io.ComfyNode):
|
|
"""Visualize tracked objects with distinct colors as a video preview. No tensor output — saves to temp video."""
|
|
|
|
@classmethod
|
|
def define_schema(cls):
|
|
return io.Schema(
|
|
node_id="SAM3_TrackPreview",
|
|
display_name="SAM3 Track Preview",
|
|
category="detection/",
|
|
inputs=[
|
|
SAM3TrackData.Input("track_data", display_name="track_data"),
|
|
io.Image.Input("images", display_name="images", optional=True),
|
|
io.Float.Input("opacity", display_name="opacity", default=0.5, min=0.0, max=1.0, step=0.05),
|
|
io.Float.Input("fps", display_name="fps", default=24.0, min=1.0, max=120.0, step=1.0),
|
|
],
|
|
is_output_node=True,
|
|
)
|
|
|
|
COLORS = [
|
|
(0.12, 0.47, 0.71), (1.0, 0.5, 0.05), (0.17, 0.63, 0.17), (0.84, 0.15, 0.16),
|
|
(0.58, 0.4, 0.74), (0.55, 0.34, 0.29), (0.89, 0.47, 0.76), (0.5, 0.5, 0.5),
|
|
(0.74, 0.74, 0.13), (0.09, 0.75, 0.81), (0.94, 0.76, 0.06), (0.42, 0.68, 0.84),
|
|
]
|
|
|
|
# 5x3 bitmap font atlas for digits 0-9 [10, 5, 3]
|
|
_glyph_cache = {} # (device, scale) -> (glyphs, outlines, gh, gw, oh, ow)
|
|
|
|
@staticmethod
|
|
def _get_glyphs(device, scale=3):
|
|
key = (device, scale)
|
|
if key in SAM3_TrackPreview._glyph_cache:
|
|
return SAM3_TrackPreview._glyph_cache[key]
|
|
atlas = torch.tensor([
|
|
[[1,1,1],[1,0,1],[1,0,1],[1,0,1],[1,1,1]],
|
|
[[0,1,0],[1,1,0],[0,1,0],[0,1,0],[1,1,1]],
|
|
[[1,1,1],[0,0,1],[1,1,1],[1,0,0],[1,1,1]],
|
|
[[1,1,1],[0,0,1],[1,1,1],[0,0,1],[1,1,1]],
|
|
[[1,0,1],[1,0,1],[1,1,1],[0,0,1],[0,0,1]],
|
|
[[1,1,1],[1,0,0],[1,1,1],[0,0,1],[1,1,1]],
|
|
[[1,1,1],[1,0,0],[1,1,1],[1,0,1],[1,1,1]],
|
|
[[1,1,1],[0,0,1],[0,0,1],[0,0,1],[0,0,1]],
|
|
[[1,1,1],[1,0,1],[1,1,1],[1,0,1],[1,1,1]],
|
|
[[1,1,1],[1,0,1],[1,1,1],[0,0,1],[1,1,1]],
|
|
], dtype=torch.bool)
|
|
glyphs, outlines = [], []
|
|
for d in range(10):
|
|
g = atlas[d].repeat_interleave(scale, 0).repeat_interleave(scale, 1)
|
|
padded = F.pad(g.float().unsqueeze(0).unsqueeze(0), (1,1,1,1))
|
|
o = (F.max_pool2d(padded, 3, stride=1, padding=1)[0, 0] > 0)
|
|
glyphs.append(g.to(device))
|
|
outlines.append(o.to(device))
|
|
gh, gw = glyphs[0].shape
|
|
oh, ow = outlines[0].shape
|
|
SAM3_TrackPreview._glyph_cache[key] = (glyphs, outlines, gh, gw, oh, ow)
|
|
return SAM3_TrackPreview._glyph_cache[key]
|
|
|
|
@staticmethod
|
|
def _draw_number_gpu(frame, number, cx, cy, color, scale=3):
|
|
"""Draw a number on a GPU tensor [H, W, 3] float 0-1 at (cx, cy) with outline."""
|
|
H, W = frame.shape[:2]
|
|
device = frame.device
|
|
glyphs, outlines, gh, gw, oh, ow = SAM3_TrackPreview._get_glyphs(device, scale)
|
|
color_t = torch.tensor(color, device=device, dtype=frame.dtype)
|
|
digs = [int(d) for d in str(number)]
|
|
total_w = len(digs) * (gw + scale) - scale
|
|
x0 = cx - total_w // 2
|
|
y0 = cy - gh // 2
|
|
for i, d in enumerate(digs):
|
|
dx = x0 + i * (gw + scale)
|
|
# Black outline
|
|
oy0, ox0 = y0 - 1, dx - 1
|
|
osy1, osx1 = max(0, -oy0), max(0, -ox0)
|
|
osy2, osx2 = min(oh, H - oy0), min(ow, W - ox0)
|
|
if osy2 > osy1 and osx2 > osx1:
|
|
fy1, fx1 = oy0 + osy1, ox0 + osx1
|
|
frame[fy1:fy1+(osy2-osy1), fx1:fx1+(osx2-osx1)][outlines[d][osy1:osy2, osx1:osx2]] = 0
|
|
# Colored fill
|
|
sy1, sx1 = max(0, -y0), max(0, -dx)
|
|
sy2, sx2 = min(gh, H - y0), min(gw, W - dx)
|
|
if sy2 > sy1 and sx2 > sx1:
|
|
fy1, fx1 = y0 + sy1, dx + sx1
|
|
frame[fy1:fy1+(sy2-sy1), fx1:fx1+(sx2-sx1)][glyphs[d][sy1:sy2, sx1:sx2]] = color_t
|
|
|
|
@classmethod
|
|
def execute(cls, track_data, images=None, opacity=0.5, fps=24.0) -> io.NodeOutput:
|
|
|
|
from comfy.ldm.sam3.tracker import unpack_masks
|
|
packed = track_data["packed_masks"]
|
|
H, W = track_data["orig_size"]
|
|
if images is not None:
|
|
H, W = images.shape[1], images.shape[2]
|
|
if packed is None:
|
|
N, N_obj = track_data["n_frames"], 0
|
|
else:
|
|
N, N_obj = packed.shape[0], packed.shape[1]
|
|
|
|
gpu = comfy.model_management.get_torch_device()
|
|
temp_dir = folder_paths.get_temp_directory()
|
|
filename = "sam3_track_preview.mp4"
|
|
filepath = os.path.join(temp_dir, filename)
|
|
with av.open(filepath, mode='w') as output:
|
|
stream = output.add_stream('h264', rate=Fraction(round(fps * 1000), 1000))
|
|
stream.width = W
|
|
stream.height = H
|
|
stream.pix_fmt = 'yuv420p'
|
|
|
|
frame_cpu = torch.empty(H, W, 3, dtype=torch.uint8)
|
|
frame_np = frame_cpu.numpy()
|
|
if N_obj > 0:
|
|
colors_t = torch.tensor([cls.COLORS[i % len(cls.COLORS)] for i in range(N_obj)],
|
|
device=gpu, dtype=torch.float32)
|
|
grid_y = torch.arange(H, device=gpu).view(1, H, 1)
|
|
grid_x = torch.arange(W, device=gpu).view(1, 1, W)
|
|
for t in range(N):
|
|
if images is not None:
|
|
frame = images[t].clone()
|
|
else:
|
|
frame = torch.zeros(H, W, 3)
|
|
|
|
if N_obj > 0:
|
|
frame_binary = unpack_masks(packed[t:t+1].to(gpu)) # [1, N_obj, H, W] bool
|
|
frame_masks = F.interpolate(frame_binary.float(), size=(H, W), mode="nearest")[0]
|
|
frame_gpu = frame.to(gpu)
|
|
bool_masks = frame_masks > 0.5
|
|
any_mask = bool_masks.any(dim=0)
|
|
if any_mask.any():
|
|
obj_idx_map = bool_masks.to(torch.uint8).argmax(dim=0)
|
|
color_overlay = colors_t[obj_idx_map]
|
|
mask_3d = any_mask.unsqueeze(-1)
|
|
frame_gpu = torch.where(mask_3d, frame_gpu * (1 - opacity) + color_overlay * opacity, frame_gpu)
|
|
area = bool_masks.sum(dim=(-1, -2)).clamp_(min=1)
|
|
cy = (bool_masks * grid_y).sum(dim=(-1, -2)) // area
|
|
cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area
|
|
has = area > 1
|
|
scores = track_data.get("scores", [])
|
|
for obj_idx in range(N_obj):
|
|
if has[obj_idx]:
|
|
_cx, _cy = int(cx[obj_idx]), int(cy[obj_idx])
|
|
color = cls.COLORS[obj_idx % len(cls.COLORS)]
|
|
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color)
|
|
if obj_idx < len(scores) and scores[obj_idx] < 1.0:
|
|
SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100),
|
|
_cx, _cy + 5 * 3 + 3, color, scale=2)
|
|
frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte())
|
|
else:
|
|
frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte())
|
|
|
|
vframe = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
|
|
output.mux(stream.encode(vframe.reformat(format='yuv420p')))
|
|
output.mux(stream.encode(None))
|
|
return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(filename, "", io.FolderType.temp)]))
|
|
|
|
|
|
class SAM3_TrackToMask(io.ComfyNode):
|
|
"""Select tracked objects by index and output as mask."""
|
|
|
|
@classmethod
|
|
def define_schema(cls):
|
|
return io.Schema(
|
|
node_id="SAM3_TrackToMask",
|
|
display_name="SAM3 Track to Mask",
|
|
category="detection/",
|
|
inputs=[
|
|
SAM3TrackData.Input("track_data", display_name="track_data"),
|
|
io.String.Input("object_indices", display_name="object_indices", default="",
|
|
tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Empty = all objects."),
|
|
],
|
|
outputs=[
|
|
io.Mask.Output("masks", display_name="masks"),
|
|
],
|
|
)
|
|
|
|
@classmethod
|
|
def execute(cls, track_data, object_indices="") -> io.NodeOutput:
|
|
from comfy.ldm.sam3.tracker import unpack_masks
|
|
packed = track_data["packed_masks"]
|
|
H, W = track_data["orig_size"]
|
|
|
|
if packed is None:
|
|
N = track_data["n_frames"]
|
|
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
|
|
|
|
N, N_obj = packed.shape[0], packed.shape[1]
|
|
|
|
if object_indices.strip():
|
|
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
|
|
indices = [i for i in indices if 0 <= i < N_obj]
|
|
else:
|
|
indices = list(range(N_obj))
|
|
|
|
if not indices:
|
|
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
|
|
|
|
selected = packed[:, indices]
|
|
binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool
|
|
union = binary.any(dim=1, keepdim=True).float()
|
|
mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0]
|
|
return io.NodeOutput(mask_out)
|
|
|
|
|
|
class SAM3Extension(ComfyExtension):
|
|
@override
|
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
|
return [
|
|
SAM3_Detect,
|
|
SAM3_VideoTrack,
|
|
SAM3_TrackPreview,
|
|
SAM3_TrackToMask,
|
|
]
|
|
|
|
|
|
async def comfy_entrypoint() -> SAM3Extension:
|
|
return SAM3Extension()
|