""" SAM3 (Segment Anything 3) nodes for detection, segmentation, and video tracking. """ from typing_extensions import override import json import os import torch import torch.nn.functional as F import comfy.model_management import comfy.utils import folder_paths from comfy_api.latest import ComfyExtension, io, ui import av from fractions import Fraction def _extract_text_prompts(conditioning, device, dtype): """Extract list of (text_embeddings, text_mask) from conditioning.""" cond_meta = conditioning[0][1] multi = cond_meta.get("sam3_multi_cond") prompts = [] if multi is not None: for entry in multi: emb = entry["cond"].to(device=device, dtype=dtype) mask = entry["attention_mask"].to(device) if entry["attention_mask"] is not None else None if mask is None: mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device) prompts.append((emb, mask, entry.get("max_detections", 1))) else: emb = conditioning[0][0].to(device=device, dtype=dtype) mask = cond_meta.get("attention_mask") if mask is not None: mask = mask.to(device) else: mask = torch.ones(emb.shape[0], emb.shape[1], dtype=torch.int64, device=device) prompts.append((emb, mask, 1)) return prompts def _refine_mask(sam3_model, orig_image_hwc, coarse_mask, box_xyxy, H, W, device, dtype, iterations): """Refine a coarse detector mask via SAM decoder, cropping to the detection box. Returns: [1, H, W] binary mask """ def _coarse_fallback(): return (F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False)[0] > 0).float() if iterations <= 0: return _coarse_fallback() pad_frac = 0.1 x1, y1, x2, y2 = box_xyxy.tolist() bw, bh = x2 - x1, y2 - y1 cx1 = max(0, int(x1 - bw * pad_frac)) cy1 = max(0, int(y1 - bh * pad_frac)) cx2 = min(W, int(x2 + bw * pad_frac)) cy2 = min(H, int(y2 + bh * pad_frac)) if cx2 <= cx1 or cy2 <= cy1: return _coarse_fallback() crop = orig_image_hwc[cy1:cy2, cx1:cx2] crop_1008 = comfy.utils.common_upscale(crop.unsqueeze(0).movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled") crop_frame = crop_1008.to(device=device, dtype=dtype) crop_h, crop_w = cy2 - cy1, cx2 - cx1 # Crop coarse mask and refine via SAM on the cropped image mask_h, mask_w = coarse_mask.shape[-2:] mx1, my1 = int(cx1 / W * mask_w), int(cy1 / H * mask_h) mx2, my2 = int(cx2 / W * mask_w), int(cy2 / H * mask_h) if mx2 <= mx1 or my2 <= my1: return _coarse_fallback() mask_logit = coarse_mask[..., my1:my2, mx1:mx2].unsqueeze(0).unsqueeze(0) for _ in range(iterations): coarse_input = F.interpolate(mask_logit, size=(1008, 1008), mode="bilinear", align_corners=False) mask_logit = sam3_model.forward_segment(crop_frame, mask_inputs=coarse_input) refined_crop = F.interpolate(mask_logit, size=(crop_h, crop_w), mode="bilinear", align_corners=False) full_mask = torch.zeros(1, 1, H, W, device=device, dtype=dtype) full_mask[:, :, cy1:cy2, cx1:cx2] = refined_crop coarse_full = F.interpolate(coarse_mask.unsqueeze(0).unsqueeze(0), size=(H, W), mode="bilinear", align_corners=False) return ((full_mask[0] > 0) | (coarse_full[0] > 0)).float() class SAM3_Detect(io.ComfyNode): """Open-vocabulary detection and segmentation using text, box, or point prompts.""" @classmethod def define_schema(cls): return io.Schema( node_id="SAM3_Detect", display_name="SAM3 Detect", category="detection/", search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"], inputs=[ io.Model.Input("model", display_name="model"), io.Image.Input("image", display_name="image"), io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning from CLIPTextEncode"), io.BoundingBox.Input("bboxes", display_name="bboxes", force_input=True, optional=True, tooltip="Bounding boxes to segment within"), io.String.Input("positive_coords", display_name="positive_coords", force_input=True, optional=True, tooltip="Positive point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"), io.String.Input("negative_coords", display_name="negative_coords", force_input=True, optional=True, tooltip="Negative point prompts as JSON [{\"x\": int, \"y\": int}, ...] (pixel coords)"), io.Float.Input("threshold", display_name="threshold", default=0.5, min=0.0, max=1.0, step=0.01), io.Int.Input("refine_iterations", display_name="refine_iterations", default=2, min=0, max=5, tooltip="SAM decoder refinement passes (0=use raw detector masks)"), io.Boolean.Input("individual_masks", display_name="individual_masks", default=False, tooltip="Output per-object masks instead of union"), ], outputs=[ io.Mask.Output("masks"), io.BoundingBox.Output("bboxes"), ], ) @classmethod def execute(cls, model, image, conditioning=None, bboxes=None, positive_coords=None, negative_coords=None, threshold=0.5, refine_iterations=2, individual_masks=False) -> io.NodeOutput: B, H, W, C = image.shape image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 1008, 1008, "bilinear", crop="disabled") # Convert bboxes to normalized cxcywh format, per-frame list of [1, N, 4] tensors. # Supports: single dict (all frames), list[dict] (all frames), list[list[dict]] (per-frame). def _boxes_to_tensor(box_list): coords = [] for d in box_list: cx = (d["x"] + d["width"] / 2) / W cy = (d["y"] + d["height"] / 2) / H coords.append([cx, cy, d["width"] / W, d["height"] / H]) return torch.tensor([coords], dtype=torch.float32) # [1, N, 4] per_frame_boxes = None if bboxes is not None: if isinstance(bboxes, dict): # Single box → same for all frames shared = _boxes_to_tensor([bboxes]) per_frame_boxes = [shared] * B elif isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list): # list[list[dict]] → per-frame boxes per_frame_boxes = [_boxes_to_tensor(frame_boxes) if frame_boxes else None for frame_boxes in bboxes] # Pad to B if fewer frames provided while len(per_frame_boxes) < B: per_frame_boxes.append(per_frame_boxes[-1] if per_frame_boxes else None) elif isinstance(bboxes, list) and len(bboxes) > 0: # list[dict] → same boxes for all frames shared = _boxes_to_tensor(bboxes) per_frame_boxes = [shared] * B # Parse point prompts from JSON (KJNodes PointsEditor format: [{"x": int, "y": int}, ...]) pos_pts = json.loads(positive_coords) if positive_coords else [] neg_pts = json.loads(negative_coords) if negative_coords else [] has_points = len(pos_pts) > 0 or len(neg_pts) > 0 comfy.model_management.load_model_gpu(model) device = comfy.model_management.get_torch_device() dtype = model.model.get_dtype() sam3_model = model.model.diffusion_model # Build point inputs for tracker SAM decoder path point_inputs = None if has_points: all_coords = [[p["x"] / W * 1008, p["y"] / H * 1008] for p in pos_pts] + \ [[p["x"] / W * 1008, p["y"] / H * 1008] for p in neg_pts] all_labels = [1] * len(pos_pts) + [0] * len(neg_pts) point_inputs = { "point_coords": torch.tensor([all_coords], dtype=dtype, device=device), "point_labels": torch.tensor([all_labels], dtype=torch.int32, device=device), } cond_list = _extract_text_prompts(conditioning, device, dtype) if conditioning is not None and len(conditioning) > 0 else [] has_text = len(cond_list) > 0 # Run per-image through detector (text/boxes) and/or tracker (points) all_bbox_dicts = [] all_masks = [] pbar = comfy.utils.ProgressBar(B) for b in range(B): frame = image_in[b:b+1].to(device=device, dtype=dtype) b_boxes = None if per_frame_boxes is not None and per_frame_boxes[b] is not None: b_boxes = per_frame_boxes[b].to(device=device, dtype=dtype) frame_bbox_dicts = [] frame_masks = [] # Point prompts: tracker SAM decoder path with iterative refinement if point_inputs is not None: mask_logit = sam3_model.forward_segment(frame, point_inputs=point_inputs) for _ in range(max(0, refine_iterations - 1)): mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit) mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False) frame_masks.append((mask[0] > 0).float()) # Box prompts: SAM decoder path (segment inside each box) if b_boxes is not None and not has_text: for box_cxcywh in b_boxes[0]: cx, cy, bw, bh = box_cxcywh.tolist() # Convert cxcywh normalized → xyxy in 1008 space → [1, 2, 2] corners sam_box = torch.tensor([[[(cx - bw/2) * 1008, (cy - bh/2) * 1008], [(cx + bw/2) * 1008, (cy + bh/2) * 1008]]], device=device, dtype=dtype) mask_logit = sam3_model.forward_segment(frame, box_inputs=sam_box) for _ in range(max(0, refine_iterations - 1)): mask_logit = sam3_model.forward_segment(frame, mask_inputs=mask_logit) mask = F.interpolate(mask_logit, size=(H, W), mode="bilinear", align_corners=False) frame_masks.append((mask[0] > 0).float()) # Text prompts: run detector per text prompt (each detects one category) for text_embeddings, text_mask, max_det in cond_list: results = sam3_model( frame, text_embeddings=text_embeddings, text_mask=text_mask, boxes=b_boxes, threshold=threshold, orig_size=(H, W)) pred_boxes = results["boxes"][0] scores = results["scores"][0] masks = results["masks"][0] probs = scores.sigmoid() keep = probs > threshold kept_boxes = pred_boxes[keep].cpu() kept_scores = probs[keep].cpu() kept_masks = masks[keep] order = kept_scores.argsort(descending=True)[:max_det] kept_boxes = kept_boxes[order] kept_scores = kept_scores[order] kept_masks = kept_masks[order] for box, score in zip(kept_boxes, kept_scores): frame_bbox_dicts.append({ "x": float(box[0]), "y": float(box[1]), "width": float(box[2] - box[0]), "height": float(box[3] - box[1]), "score": float(score), }) for m, box in zip(kept_masks, kept_boxes): frame_masks.append(_refine_mask( sam3_model, image[b], m, box, H, W, device, dtype, refine_iterations)) all_bbox_dicts.append(frame_bbox_dicts) if len(frame_masks) > 0: combined = torch.cat(frame_masks, dim=0) # [N_obj, H, W] if individual_masks: all_masks.append(combined) else: all_masks.append((combined > 0).any(dim=0).float()) else: if individual_masks: all_masks.append(torch.zeros(0, H, W, device=comfy.model_management.intermediate_device())) else: all_masks.append(torch.zeros(H, W, device=comfy.model_management.intermediate_device())) pbar.update(1) idev = comfy.model_management.intermediate_device() all_masks = [m.to(idev) for m in all_masks] mask_out = torch.cat(all_masks, dim=0) if individual_masks else torch.stack(all_masks) return io.NodeOutput(mask_out, all_bbox_dicts) SAM3TrackData = io.Custom("SAM3_TRACK_DATA") class SAM3_VideoTrack(io.ComfyNode): """Track objects across video frames using SAM3's memory-based tracker.""" @classmethod def define_schema(cls): return io.Schema( node_id="SAM3_VideoTrack", display_name="SAM3 Video Track", category="detection/", search_aliases=["sam3", "video", "track", "propagate"], inputs=[ io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"), io.Model.Input("model", display_name="model"), io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"), io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"), io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"), io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."), io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."), ], outputs=[ SAM3TrackData.Output("track_data", display_name="track_data"), ], ) @classmethod def execute(cls, images, model, initial_mask=None, conditioning=None, detection_threshold=0.5, max_objects=0, detect_interval=1) -> io.NodeOutput: N, H, W, C = images.shape comfy.model_management.load_model_gpu(model) device = comfy.model_management.get_torch_device() dtype = model.model.get_dtype() sam3_model = model.model.diffusion_model frames = images.movedim(-1, 1) frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype) init_masks = None if initial_mask is not None: init_masks = initial_mask.unsqueeze(1).to(device=device, dtype=dtype) pbar = comfy.utils.ProgressBar(N) text_prompts = None if conditioning is not None and len(conditioning) > 0: text_prompts = [(emb, mask) for emb, mask, _ in _extract_text_prompts(conditioning, device, dtype)] elif initial_mask is None: raise ValueError("Either initial_mask or conditioning must be provided") result = sam3_model.forward_video( images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts, new_det_thresh=detection_threshold, max_objects=max_objects, detect_interval=detect_interval) result["orig_size"] = (H, W) return io.NodeOutput(result) class SAM3_TrackPreview(io.ComfyNode): """Visualize tracked objects with distinct colors as a video preview. No tensor output — saves to temp video.""" @classmethod def define_schema(cls): return io.Schema( node_id="SAM3_TrackPreview", display_name="SAM3 Track Preview", category="detection/", inputs=[ SAM3TrackData.Input("track_data", display_name="track_data"), io.Image.Input("images", display_name="images", optional=True), io.Float.Input("opacity", display_name="opacity", default=0.5, min=0.0, max=1.0, step=0.05), io.Float.Input("fps", display_name="fps", default=24.0, min=1.0, max=120.0, step=1.0), ], is_output_node=True, ) COLORS = [ (0.12, 0.47, 0.71), (1.0, 0.5, 0.05), (0.17, 0.63, 0.17), (0.84, 0.15, 0.16), (0.58, 0.4, 0.74), (0.55, 0.34, 0.29), (0.89, 0.47, 0.76), (0.5, 0.5, 0.5), (0.74, 0.74, 0.13), (0.09, 0.75, 0.81), (0.94, 0.76, 0.06), (0.42, 0.68, 0.84), ] # 5x3 bitmap font atlas for digits 0-9 [10, 5, 3] _glyph_cache = {} # (device, scale) -> (glyphs, outlines, gh, gw, oh, ow) @staticmethod def _get_glyphs(device, scale=3): key = (device, scale) if key in SAM3_TrackPreview._glyph_cache: return SAM3_TrackPreview._glyph_cache[key] atlas = torch.tensor([ [[1,1,1],[1,0,1],[1,0,1],[1,0,1],[1,1,1]], [[0,1,0],[1,1,0],[0,1,0],[0,1,0],[1,1,1]], [[1,1,1],[0,0,1],[1,1,1],[1,0,0],[1,1,1]], [[1,1,1],[0,0,1],[1,1,1],[0,0,1],[1,1,1]], [[1,0,1],[1,0,1],[1,1,1],[0,0,1],[0,0,1]], [[1,1,1],[1,0,0],[1,1,1],[0,0,1],[1,1,1]], [[1,1,1],[1,0,0],[1,1,1],[1,0,1],[1,1,1]], [[1,1,1],[0,0,1],[0,0,1],[0,0,1],[0,0,1]], [[1,1,1],[1,0,1],[1,1,1],[1,0,1],[1,1,1]], [[1,1,1],[1,0,1],[1,1,1],[0,0,1],[1,1,1]], ], dtype=torch.bool) glyphs, outlines = [], [] for d in range(10): g = atlas[d].repeat_interleave(scale, 0).repeat_interleave(scale, 1) padded = F.pad(g.float().unsqueeze(0).unsqueeze(0), (1,1,1,1)) o = (F.max_pool2d(padded, 3, stride=1, padding=1)[0, 0] > 0) glyphs.append(g.to(device)) outlines.append(o.to(device)) gh, gw = glyphs[0].shape oh, ow = outlines[0].shape SAM3_TrackPreview._glyph_cache[key] = (glyphs, outlines, gh, gw, oh, ow) return SAM3_TrackPreview._glyph_cache[key] @staticmethod def _draw_number_gpu(frame, number, cx, cy, color, scale=3): """Draw a number on a GPU tensor [H, W, 3] float 0-1 at (cx, cy) with outline.""" H, W = frame.shape[:2] device = frame.device glyphs, outlines, gh, gw, oh, ow = SAM3_TrackPreview._get_glyphs(device, scale) color_t = torch.tensor(color, device=device, dtype=frame.dtype) digs = [int(d) for d in str(number)] total_w = len(digs) * (gw + scale) - scale x0 = cx - total_w // 2 y0 = cy - gh // 2 for i, d in enumerate(digs): dx = x0 + i * (gw + scale) # Black outline oy0, ox0 = y0 - 1, dx - 1 osy1, osx1 = max(0, -oy0), max(0, -ox0) osy2, osx2 = min(oh, H - oy0), min(ow, W - ox0) if osy2 > osy1 and osx2 > osx1: fy1, fx1 = oy0 + osy1, ox0 + osx1 frame[fy1:fy1+(osy2-osy1), fx1:fx1+(osx2-osx1)][outlines[d][osy1:osy2, osx1:osx2]] = 0 # Colored fill sy1, sx1 = max(0, -y0), max(0, -dx) sy2, sx2 = min(gh, H - y0), min(gw, W - dx) if sy2 > sy1 and sx2 > sx1: fy1, fx1 = y0 + sy1, dx + sx1 frame[fy1:fy1+(sy2-sy1), fx1:fx1+(sx2-sx1)][glyphs[d][sy1:sy2, sx1:sx2]] = color_t @classmethod def execute(cls, track_data, images=None, opacity=0.5, fps=24.0) -> io.NodeOutput: from comfy.ldm.sam3.tracker import unpack_masks packed = track_data["packed_masks"] H, W = track_data["orig_size"] if images is not None: H, W = images.shape[1], images.shape[2] if packed is None: N, N_obj = track_data["n_frames"], 0 else: N, N_obj = packed.shape[0], packed.shape[1] import uuid gpu = comfy.model_management.get_torch_device() temp_dir = folder_paths.get_temp_directory() filename = f"sam3_track_preview_{uuid.uuid4().hex[:8]}.mp4" filepath = os.path.join(temp_dir, filename) with av.open(filepath, mode='w') as output: stream = output.add_stream('h264', rate=Fraction(round(fps * 1000), 1000)) stream.width = W stream.height = H stream.pix_fmt = 'yuv420p' frame_cpu = torch.empty(H, W, 3, dtype=torch.uint8) frame_np = frame_cpu.numpy() if N_obj > 0: colors_t = torch.tensor([cls.COLORS[i % len(cls.COLORS)] for i in range(N_obj)], device=gpu, dtype=torch.float32) grid_y = torch.arange(H, device=gpu).view(1, H, 1) grid_x = torch.arange(W, device=gpu).view(1, 1, W) for t in range(N): if images is not None and t < images.shape[0]: frame = images[t].clone() else: frame = torch.zeros(H, W, 3) if N_obj > 0: frame_binary = unpack_masks(packed[t:t+1].to(gpu)) # [1, N_obj, H, W] bool frame_masks = F.interpolate(frame_binary.float(), size=(H, W), mode="nearest")[0] frame_gpu = frame.to(gpu) bool_masks = frame_masks > 0.5 any_mask = bool_masks.any(dim=0) if any_mask.any(): obj_idx_map = bool_masks.to(torch.uint8).argmax(dim=0) color_overlay = colors_t[obj_idx_map] mask_3d = any_mask.unsqueeze(-1) frame_gpu = torch.where(mask_3d, frame_gpu * (1 - opacity) + color_overlay * opacity, frame_gpu) area = bool_masks.sum(dim=(-1, -2)).clamp_(min=1) cy = (bool_masks * grid_y).sum(dim=(-1, -2)) // area cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area has = area > 1 scores = track_data.get("scores", []) for obj_idx in range(N_obj): if has[obj_idx]: _cx, _cy = int(cx[obj_idx]), int(cy[obj_idx]) color = cls.COLORS[obj_idx % len(cls.COLORS)] SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color) if obj_idx < len(scores) and scores[obj_idx] < 1.0: SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100), _cx, _cy + 5 * 3 + 3, color, scale=2) frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte()) else: frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte()) vframe = av.VideoFrame.from_ndarray(frame_np, format='rgb24') output.mux(stream.encode(vframe.reformat(format='yuv420p'))) output.mux(stream.encode(None)) return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(filename, "", io.FolderType.temp)])) class SAM3_TrackToMask(io.ComfyNode): """Select tracked objects by index and output as mask.""" @classmethod def define_schema(cls): return io.Schema( node_id="SAM3_TrackToMask", display_name="SAM3 Track to Mask", category="detection/", inputs=[ SAM3TrackData.Input("track_data", display_name="track_data"), io.String.Input("object_indices", display_name="object_indices", default="", tooltip="Comma-separated object indices to include (e.g. '0,2,3'). Empty = all objects."), ], outputs=[ io.Mask.Output("masks", display_name="masks"), ], ) @classmethod def execute(cls, track_data, object_indices="") -> io.NodeOutput: from comfy.ldm.sam3.tracker import unpack_masks packed = track_data["packed_masks"] H, W = track_data["orig_size"] if packed is None: N = track_data["n_frames"] return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device())) N, N_obj = packed.shape[0], packed.shape[1] if object_indices.strip(): indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()] indices = [i for i in indices if 0 <= i < N_obj] else: indices = list(range(N_obj)) if not indices: return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device())) selected = packed[:, indices] binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool union = binary.any(dim=1, keepdim=True).float() mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0] return io.NodeOutput(mask_out) class SAM3Extension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ SAM3_Detect, SAM3_VideoTrack, SAM3_TrackPreview, SAM3_TrackToMask, ] async def comfy_entrypoint() -> SAM3Extension: return SAM3Extension()