mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-18 07:35:05 +08:00
Node adjustments
This commit is contained in:
parent
57ba8555fe
commit
0c66a69c91
@ -6,6 +6,7 @@ import comfy.model_management
|
||||
import comfy.utils
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from torchvision.transforms import ToPILImage, ToTensor
|
||||
from PIL import ImageDraw, ImageFont
|
||||
|
||||
|
||||
class RFDETR_detect(io.ComfyNode):
|
||||
@ -15,18 +16,20 @@ class RFDETR_detect(io.ComfyNode):
|
||||
node_id="RFDETR_detect",
|
||||
display_name="RF-DETR Detect",
|
||||
category="detection/",
|
||||
search_aliases=["bbox", "bounding box", "object detection", "coco"],
|
||||
inputs=[
|
||||
io.Model.Input("model", display_name="model"),
|
||||
io.Image.Input("image", display_name="image"),
|
||||
io.Float.Input("threshold", display_name="threshold", default=0.5),
|
||||
io.Combo.Input("class_name", options=["all"] + COCO_CLASSES, default="all")
|
||||
io.Combo.Input("class_name", options=["all"] + COCO_CLASSES, default="all", tooltip="Filter detections by class. Set to 'all' to disable filtering."),
|
||||
io.Int.Input("max_detections", display_name="max_detections", default=100, tooltip="Maximum number of detections to return per image. In order of descending confidence score."),
|
||||
],
|
||||
outputs=[
|
||||
io.BoundingBox.Output("bbox")],
|
||||
io.BoundingBox.Output("bboxes")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, image, threshold, class_name) -> io.NodeOutput:
|
||||
def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput:
|
||||
B, H, W, C = image.shape
|
||||
|
||||
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
|
||||
@ -48,12 +51,7 @@ class RFDETR_detect(io.ComfyNode):
|
||||
for i in range(B):
|
||||
boxes, labels, scores = _postprocess(results[i:i+1], threshold=threshold)
|
||||
|
||||
print(f'\nImage {i + 1}/{B}: Detected {len(boxes)} objects (threshold={threshold}):')
|
||||
for box, label, score in sorted(zip(boxes, labels, scores), key=lambda x: -x[2].item()):
|
||||
print(f' {COCO_CLASSES[label.item()]:20s} {score:.3f} '
|
||||
f'[{box[0]:.0f},{box[1]:.0f},{box[2]:.0f},{box[3]:.0f}]')
|
||||
|
||||
bbox_dicts = [
|
||||
bbox_dicts = sorted([
|
||||
{
|
||||
"x": float(box[0]),
|
||||
"y": float(box[1]),
|
||||
@ -64,67 +62,71 @@ class RFDETR_detect(io.ComfyNode):
|
||||
}
|
||||
for box, label, score in zip(boxes, labels, scores)
|
||||
if class_name == "all" or COCO_CLASSES[int(label)] == class_name
|
||||
]
|
||||
], key=lambda d: d["score"], reverse=True)[:max_detections]
|
||||
all_bbox_dicts.append(bbox_dicts)
|
||||
|
||||
return io.NodeOutput(all_bbox_dicts)
|
||||
|
||||
|
||||
class RFDETR_draw(io.ComfyNode):
|
||||
class DrawBBoxes(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RFDETR_draw",
|
||||
display_name="RF-DETR Draw Detections",
|
||||
node_id="DrawBBoxes",
|
||||
display_name="Draw BBoxes",
|
||||
category="detection/",
|
||||
search_aliases=["bbox", "bounding box", "object detection", "rf_detr", "visualize detections", "coco"],
|
||||
inputs=[
|
||||
io.Image.Input("image", display_name="image", optional=True),
|
||||
io.BoundingBox.Input("bbox", display_name="bbox", force_input=True),
|
||||
io.Image.Input("image", optional=True),
|
||||
io.BoundingBox.Input("bboxes", force_input=True),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("out_image", display_name="out_image"),
|
||||
io.Image.Output("out_image"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, bbox, image=None) -> io.NodeOutput:
|
||||
# Normalise bbox to a list-of-lists (one list of detections per image).
|
||||
def execute(cls, bboxes, image=None) -> io.NodeOutput:
|
||||
# Normalise bboxes to a list-of-lists (one list of detections per image).
|
||||
# It may arrive as: a bare dict, a flat list of dicts, or a list of lists.
|
||||
B = image.shape[0] if image is not None else 1
|
||||
if isinstance(bbox, dict):
|
||||
bbox = [[bbox]] * B
|
||||
elif not isinstance(bbox, list) or len(bbox) == 0:
|
||||
bbox = [[]] * B
|
||||
elif not isinstance(bbox[0], list):
|
||||
# flat list of dicts → same detections for every image
|
||||
bbox = [bbox] * B
|
||||
if isinstance(bboxes, dict):
|
||||
bboxes = [[bboxes]] * B
|
||||
elif not isinstance(bboxes, list) or len(bboxes) == 0:
|
||||
bboxes = [[]] * B
|
||||
elif not isinstance(bboxes[0], list):
|
||||
# flat list of dicts: same detections for every image
|
||||
bboxes = [bboxes] * B
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros((B, 3, 640, 640), dtype=torch.uint8)
|
||||
B = len(bboxes)
|
||||
max_w = max((int(d["x"] + d["width"]) for frame in bboxes for d in frame), default=640)
|
||||
max_h = max((int(d["y"] + d["height"]) for frame in bboxes for d in frame), default=640)
|
||||
image = torch.zeros((B, max_h, max_w, 3), dtype=torch.float32)
|
||||
|
||||
all_out_images = []
|
||||
for i in range(B):
|
||||
detections = bbox[i]
|
||||
detections = bboxes[i]
|
||||
if detections:
|
||||
boxes = torch.tensor([[d["x"], d["y"], d["x"] + d["width"], d["y"] + d["height"]] for d in detections])
|
||||
labels = torch.tensor([COCO_CLASSES.index(lbl) if (lbl := d.get("label")) in COCO_CLASSES else 0 for d in detections])
|
||||
labels = [d.get("label") if d.get("label") in COCO_CLASSES else None for d in detections]
|
||||
scores = torch.tensor([d.get("score", 1.0) for d in detections])
|
||||
else:
|
||||
boxes = torch.zeros((0, 4))
|
||||
labels = torch.zeros((0,), dtype=torch.long)
|
||||
labels = []
|
||||
scores = torch.zeros((0,))
|
||||
|
||||
pil_image = image[i].movedim(-1, 0)
|
||||
img = ToPILImage()(pil_image)
|
||||
out_image_pil = cls.draw_detections(img, boxes, labels, scores)
|
||||
all_out_images.append(ToTensor()(out_image_pil).unsqueeze(0).movedim(1, -1))
|
||||
if detections:
|
||||
img = cls.draw_detections(img, boxes, labels, scores)
|
||||
all_out_images.append(ToTensor()(img).unsqueeze(0).movedim(1, -1))
|
||||
|
||||
out_images = torch.cat(all_out_images, dim=0)
|
||||
out_images = torch.cat(all_out_images, dim=0).to(comfy.model_management.intermediate_device())
|
||||
return io.NodeOutput(out_images)
|
||||
|
||||
@classmethod
|
||||
def draw_detections(cls, img, boxes, labels, scores):
|
||||
from PIL import ImageDraw, ImageFont
|
||||
draw = ImageDraw.Draw(img)
|
||||
try:
|
||||
font = ImageFont.truetype('arial.ttf', 16)
|
||||
@ -134,10 +136,11 @@ class RFDETR_draw(io.ComfyNode):
|
||||
(0,255,255),(255,20,147),(100,149,237)]
|
||||
for box, label, score in sorted(zip(boxes, labels, scores), key=lambda x: x[2].item()):
|
||||
x1, y1, x2, y2 = box.tolist()
|
||||
c = colors[label.item() % len(colors)]
|
||||
color_idx = COCO_CLASSES.index(label) if label is not None else 0
|
||||
c = colors[color_idx % len(colors)]
|
||||
draw.rectangle([x1, y1, x2, y2], outline=c, width=3)
|
||||
draw.text((x1 + 2, y1 + 2),
|
||||
f'{COCO_CLASSES[label.item()]} {score:.2f}', fill=c, font=font)
|
||||
if label is not None:
|
||||
draw.text((x1 + 2, y1 + 2), f'{label} {score:.2f}', fill=c, font=font)
|
||||
return img
|
||||
|
||||
|
||||
@ -146,7 +149,7 @@ class RFDETRExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
RFDETR_detect,
|
||||
RFDETR_draw,
|
||||
DrawBBoxes,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -661,6 +661,7 @@ class CropByBBoxes(io.ComfyNode):
|
||||
io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."),
|
||||
io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."),
|
||||
io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."),
|
||||
io.Combo.Input("keep_aspect", options=["stretch", "pad"], default="stretch", tooltip="Whether to stretch the crop to fit the output size, or pad with black pixels to preserve aspect ratio."),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output(tooltip="All crops stacked into a single image batch."),
|
||||
@ -668,7 +669,7 @@ class CropByBBoxes(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, image, bboxes, output_width, output_height, padding) -> io.NodeOutput:
|
||||
def execute(cls, image, bboxes, output_width, output_height, padding, keep_aspect="stretch") -> io.NodeOutput:
|
||||
total_frames = image.shape[0]
|
||||
img_h = image.shape[1]
|
||||
img_w = image.shape[2]
|
||||
@ -716,7 +717,19 @@ class CropByBBoxes(io.ComfyNode):
|
||||
x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2
|
||||
|
||||
crop_chw = frame_chw[:, :, y1:y2, x1:x2] # (1, C, crop_h, crop_w)
|
||||
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
|
||||
|
||||
if keep_aspect == "pad":
|
||||
crop_h, crop_w = y2 - y1, x2 - x1
|
||||
scale = min(output_width / crop_w, output_height / crop_h)
|
||||
scaled_w = int(round(crop_w * scale))
|
||||
scaled_h = int(round(crop_h * scale))
|
||||
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
|
||||
pad_left = (output_width - scaled_w) // 2
|
||||
pad_top = (output_height - scaled_h) // 2
|
||||
resized = torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)
|
||||
resized[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
|
||||
else: # "stretch"
|
||||
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
|
||||
crops.append(resized)
|
||||
|
||||
if not crops:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user