from typing_extensions import override import torch from comfy.ldm.rt_detr.rtdetr_v4 import COCO_CLASSES import comfy.model_management import comfy.utils from comfy_api.latest import ComfyExtension, io from torchvision.transforms import ToPILImage, ToTensor from PIL import ImageDraw, ImageFont class RTDETR_detect(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="RTDETR_detect", display_name="RT-DETR Detect", category="detection/", search_aliases=["bbox", "bounding box", "object detection", "coco"], inputs=[ io.Model.Input("model", display_name="model"), io.Image.Input("image", display_name="image"), io.Float.Input("threshold", display_name="threshold", default=0.5), io.Combo.Input("class_name", options=["all"] + COCO_CLASSES, default="all", tooltip="Filter detections by class. Set to 'all' to disable filtering."), io.Int.Input("max_detections", display_name="max_detections", default=100, tooltip="Maximum number of detections to return per image. In order of descending confidence score."), ], outputs=[ io.BoundingBox.Output("bboxes")], ) @classmethod def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput: B, H, W, C = image.shape image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled") comfy.model_management.load_model_gpu(model) results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts all_bbox_dicts = [] for det in results: keep = det['scores'] > threshold boxes = det['boxes'][keep].cpu() labels = det['labels'][keep].cpu() scores = det['scores'][keep].cpu() bbox_dicts = [ { "x": float(box[0]), "y": float(box[1]), "width": float(box[2] - box[0]), "height": float(box[3] - box[1]), "label": COCO_CLASSES[int(label)], "score": float(score) } for box, label, score in zip(boxes, labels, scores) if class_name == "all" or COCO_CLASSES[int(label)] == class_name ] bbox_dicts.sort(key=lambda d: d["score"], reverse=True) all_bbox_dicts.append(bbox_dicts[:max_detections]) return io.NodeOutput(all_bbox_dicts) class DrawBBoxes(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="DrawBBoxes", display_name="Draw BBoxes", category="detection/", search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"], inputs=[ io.Image.Input("image", optional=True), io.BoundingBox.Input("bboxes", force_input=True), ], outputs=[ io.Image.Output("out_image"), ], ) @classmethod def execute(cls, bboxes, image=None) -> io.NodeOutput: # Normalise to list[list[dict]], then fit to batch size B. B = image.shape[0] if image is not None else 1 if isinstance(bboxes, dict): bboxes = [[bboxes]] elif not isinstance(bboxes, list) or not bboxes: bboxes = [[]] elif isinstance(bboxes[0], dict): bboxes = [bboxes] # flat list → same detections for every image if len(bboxes) == 1: bboxes = bboxes * B bboxes = (bboxes + [[]] * B)[:B] if image is None: B = len(bboxes) max_w = max((int(d["x"] + d["width"]) for frame in bboxes for d in frame), default=640) max_h = max((int(d["y"] + d["height"]) for frame in bboxes for d in frame), default=640) image = torch.zeros((B, max_h, max_w, 3), dtype=torch.float32) all_out_images = [] for i in range(B): detections = bboxes[i] if detections: boxes = torch.tensor([[d["x"], d["y"], d["x"] + d["width"], d["y"] + d["height"]] for d in detections]) labels = [d.get("label") if d.get("label") in COCO_CLASSES else None for d in detections] scores = torch.tensor([d.get("score", 1.0) for d in detections]) else: boxes = torch.zeros((0, 4)) labels = [] scores = torch.zeros((0,)) pil_image = image[i].movedim(-1, 0) img = ToPILImage()(pil_image) if detections: img = cls.draw_detections(img, boxes, labels, scores) all_out_images.append(ToTensor()(img).unsqueeze(0).movedim(1, -1)) out_images = torch.cat(all_out_images, dim=0).to(comfy.model_management.intermediate_device()) return io.NodeOutput(out_images) @classmethod def draw_detections(cls, img, boxes, labels, scores): draw = ImageDraw.Draw(img) try: font = ImageFont.truetype('arial.ttf', 16) except Exception: font = ImageFont.load_default() colors = [(255,0,0),(0,200,0),(0,0,255),(255,165,0),(128,0,128), (0,255,255),(255,20,147),(100,149,237)] for box, label, score in sorted(zip(boxes, labels, scores), key=lambda x: x[2].item()): x1, y1, x2, y2 = box.tolist() color_idx = COCO_CLASSES.index(label) if label is not None else 0 c = colors[color_idx % len(colors)] draw.rectangle([x1, y1, x2, y2], outline=c, width=3) if label is not None: draw.text((x1 + 2, y1 + 2), f'{label} {score:.2f}', fill=c, font=font) return img class RTDETRExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ RTDETR_detect, DrawBBoxes, ] async def comfy_entrypoint() -> RTDETRExtension: return RTDETRExtension()