mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-13 18:47:29 +08:00
Compare commits
4 Commits
59cba27043
...
58ccd5be3c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
58ccd5be3c | ||
|
|
8dc3f3f209 | ||
|
|
c011fb520c | ||
|
|
c945a433ae |
@ -561,7 +561,8 @@ class SAM3Model(nn.Module):
|
||||
return high_res_masks
|
||||
|
||||
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
|
||||
new_det_thresh=0.5, max_objects=0, detect_interval=1):
|
||||
new_det_thresh=0.5, max_objects=0, detect_interval=1,
|
||||
target_device=None, target_dtype=None):
|
||||
"""Track video with optional per-frame text-prompted detection."""
|
||||
bb = self.detector.backbone["vision_backbone"]
|
||||
|
||||
@ -589,8 +590,10 @@ class SAM3Model(nn.Module):
|
||||
return self.tracker.track_video_with_detection(
|
||||
backbone_fn, images, initial_masks, detect_fn,
|
||||
new_det_thresh=new_det_thresh, max_objects=max_objects,
|
||||
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
|
||||
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar,
|
||||
target_device=target_device, target_dtype=target_dtype)
|
||||
# SAM3 (non-multiplex) — no detection support, requires initial masks
|
||||
if initial_masks is None:
|
||||
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
|
||||
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)
|
||||
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb,
|
||||
target_device=target_device, target_dtype=target_dtype)
|
||||
|
||||
@ -200,8 +200,13 @@ def pack_masks(masks):
|
||||
|
||||
def unpack_masks(packed):
|
||||
"""Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8]."""
|
||||
shifts = torch.arange(8, device=packed.device)
|
||||
return ((packed.unsqueeze(-1) >> shifts) & 1).view(*packed.shape[:-1], -1).bool()
|
||||
bits = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device)
|
||||
return (packed.unsqueeze(-1) & bits).bool().view(*packed.shape[:-1], -1)
|
||||
|
||||
|
||||
def _prep_frame(images, idx, device, dt, size):
|
||||
"""Slice CPU full-res frames, transfer to GPU in target dtype, and resize to (size, size)."""
|
||||
return comfy.utils.common_upscale(images[idx].to(device=device, dtype=dt), size, size, "bicubic", crop="disabled")
|
||||
|
||||
|
||||
def _compute_backbone(backbone_fn, frame, frame_idx=None):
|
||||
@ -1078,16 +1083,19 @@ class SAM3Tracker(nn.Module):
|
||||
# SAM3: drop last FPN level
|
||||
return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1]
|
||||
|
||||
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None):
|
||||
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None,
|
||||
target_device=None, target_dtype=None):
|
||||
"""Track one object, computing backbone per frame to save VRAM."""
|
||||
N = images.shape[0]
|
||||
device, dt = images.device, images.dtype
|
||||
device = target_device if target_device is not None else images.device
|
||||
dt = target_dtype if target_dtype is not None else images.dtype
|
||||
size = self.image_size
|
||||
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
|
||||
all_masks = []
|
||||
|
||||
for frame_idx in tqdm(range(N), desc="tracking"):
|
||||
vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame(
|
||||
backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx)
|
||||
backbone_fn, _prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), frame_idx=frame_idx)
|
||||
mask_input = None
|
||||
if frame_idx == 0:
|
||||
mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt),
|
||||
@ -1114,12 +1122,13 @@ class SAM3Tracker(nn.Module):
|
||||
|
||||
return torch.cat(all_masks, dim=0) # [N, 1, H, W]
|
||||
|
||||
def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs):
|
||||
def track_video(self, backbone_fn, images, initial_masks, pbar=None,
|
||||
target_device=None, target_dtype=None, **kwargs):
|
||||
"""Track one or more objects across video frames.
|
||||
|
||||
Args:
|
||||
backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame
|
||||
images: [N, 3, 1008, 1008] video frames
|
||||
images: [N, 3, H, W] CPU full-res video frames (resized per-frame to self.image_size)
|
||||
initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object)
|
||||
pbar: optional progress bar
|
||||
|
||||
@ -1130,7 +1139,8 @@ class SAM3Tracker(nn.Module):
|
||||
per_object = []
|
||||
for obj_idx in range(N_obj):
|
||||
obj_masks = self._track_single_object(
|
||||
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar)
|
||||
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar,
|
||||
target_device=target_device, target_dtype=target_dtype)
|
||||
per_object.append(obj_masks)
|
||||
|
||||
return torch.cat(per_object, dim=1) # [N, N_obj, H, W]
|
||||
@ -1632,11 +1642,18 @@ class SAM31Tracker(nn.Module):
|
||||
return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item()
|
||||
return []
|
||||
|
||||
INTERNAL_MAX_OBJECTS = 64 # Hard ceiling on accumulated tracks; max_objects=0 or any value above this is clamped here.
|
||||
|
||||
def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None,
|
||||
new_det_thresh=0.5, max_objects=0, detect_interval=1,
|
||||
backbone_obj=None, pbar=None):
|
||||
backbone_obj=None, pbar=None, target_device=None, target_dtype=None):
|
||||
"""Track with optional per-frame detection. Returns [N, max_N_obj, H, W] mask logits."""
|
||||
N, device, dt = images.shape[0], images.device, images.dtype
|
||||
if max_objects <= 0 or max_objects > self.INTERNAL_MAX_OBJECTS:
|
||||
max_objects = self.INTERNAL_MAX_OBJECTS
|
||||
N = images.shape[0]
|
||||
device = target_device if target_device is not None else images.device
|
||||
dt = target_dtype if target_dtype is not None else images.dtype
|
||||
size = self.image_size
|
||||
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
|
||||
all_masks = []
|
||||
idev = comfy.model_management.intermediate_device()
|
||||
@ -1656,7 +1673,7 @@ class SAM31Tracker(nn.Module):
|
||||
prefetch = True
|
||||
except RuntimeError:
|
||||
pass
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0)
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(0, 1), device, dt, size), frame_idx=0)
|
||||
|
||||
for frame_idx in tqdm(range(N), desc="tracking"):
|
||||
vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb
|
||||
@ -1666,7 +1683,7 @@ class SAM31Tracker(nn.Module):
|
||||
backbone_stream.wait_stream(torch.cuda.current_stream(device))
|
||||
with torch.cuda.stream(backbone_stream):
|
||||
next_bb = self._compute_backbone_frame(
|
||||
backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
||||
backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||
|
||||
# Per-frame detection with NMS (skip if no detect_fn, or interval/max not met)
|
||||
det_masks = torch.empty(0, device=device)
|
||||
@ -1687,7 +1704,7 @@ class SAM31Tracker(nn.Module):
|
||||
current_out = self._condition_with_masks(
|
||||
initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos,
|
||||
feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj,
|
||||
images[frame_idx:frame_idx + 1], trunk_out)
|
||||
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out)
|
||||
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
|
||||
obj_scores = [1.0] * mux_state.total_valid_entries
|
||||
if keep_alive is not None:
|
||||
@ -1702,7 +1719,7 @@ class SAM31Tracker(nn.Module):
|
||||
current_out = self._condition_with_masks(
|
||||
det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop,
|
||||
output_dict, N, mux_state, backbone_obj,
|
||||
images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0)
|
||||
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out, threshold=0.0)
|
||||
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
|
||||
obj_scores = det_scores[:mux_state.total_valid_entries].tolist()
|
||||
if keep_alive is not None:
|
||||
@ -1718,7 +1735,7 @@ class SAM31Tracker(nn.Module):
|
||||
torch.cuda.current_stream(device).wait_stream(backbone_stream)
|
||||
cur_bb = next_bb
|
||||
else:
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||
continue
|
||||
else:
|
||||
N_obj = mux_state.total_valid_entries
|
||||
@ -1768,7 +1785,7 @@ class SAM31Tracker(nn.Module):
|
||||
torch.cuda.current_stream(device).wait_stream(backbone_stream)
|
||||
cur_bb = next_bb
|
||||
else:
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
||||
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||
|
||||
if not all_masks or all(m is None for m in all_masks):
|
||||
return {"packed_masks": None, "n_frames": N, "scores": []}
|
||||
|
||||
@ -83,13 +83,16 @@ class GeminiImageModel(str, Enum):
|
||||
|
||||
async def create_image_parts(
|
||||
cls: type[IO.ComfyNode],
|
||||
images: Input.Image,
|
||||
images: Input.Image | list[Input.Image],
|
||||
image_limit: int = 0,
|
||||
) -> list[GeminiPart]:
|
||||
image_parts: list[GeminiPart] = []
|
||||
if image_limit < 0:
|
||||
raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.")
|
||||
total_images = get_number_of_images(images)
|
||||
|
||||
# Accept either a single (possibly-batched) tensor or a list of them; share URL budget across all.
|
||||
images_list: list[Input.Image] = images if isinstance(images, list) else [images]
|
||||
total_images = sum(get_number_of_images(img) for img in images_list)
|
||||
if total_images <= 0:
|
||||
raise ValueError("No images provided to create_image_parts; at least one image is required.")
|
||||
|
||||
@ -98,10 +101,18 @@ async def create_image_parts(
|
||||
|
||||
# Number of images we'll send as URLs (fileData)
|
||||
num_url_images = min(effective_max, 10) # Vertex API max number of image links
|
||||
upload_kwargs: dict = {"wait_label": "Uploading reference images"}
|
||||
if effective_max > num_url_images:
|
||||
# Split path (e.g. 11+ images): suppress per-image counter to avoid a confusing dual-fraction label.
|
||||
upload_kwargs = {
|
||||
"wait_label": f"Uploading reference images ({num_url_images}+)",
|
||||
"show_batch_index": False,
|
||||
}
|
||||
reference_images_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
images,
|
||||
images_list,
|
||||
max_images=num_url_images,
|
||||
**upload_kwargs,
|
||||
)
|
||||
for reference_image_url in reference_images_urls:
|
||||
image_parts.append(
|
||||
@ -112,15 +123,22 @@ async def create_image_parts(
|
||||
)
|
||||
)
|
||||
)
|
||||
for idx in range(num_url_images, effective_max):
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=tensor_to_base64_string(images[idx]),
|
||||
if effective_max > num_url_images:
|
||||
flat: list[torch.Tensor] = []
|
||||
for tensor in images_list:
|
||||
if len(tensor.shape) == 4:
|
||||
flat.extend(tensor[i] for i in range(tensor.shape[0]))
|
||||
else:
|
||||
flat.append(tensor)
|
||||
for idx in range(num_url_images, effective_max):
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=tensor_to_base64_string(flat[idx]),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
return image_parts
|
||||
|
||||
|
||||
@ -891,10 +909,6 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
"9:16",
|
||||
"16:9",
|
||||
"21:9",
|
||||
# "1:4",
|
||||
# "4:1",
|
||||
# "8:1",
|
||||
# "1:8",
|
||||
],
|
||||
default="auto",
|
||||
tooltip="If set to 'auto', matches your input image's aspect ratio; "
|
||||
@ -902,12 +916,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
# "512px",
|
||||
"1K",
|
||||
"2K",
|
||||
"4K",
|
||||
],
|
||||
options=["1K", "2K", "4K"],
|
||||
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
@ -956,6 +965,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=GEMINI_IMAGE_2_PRICE_BADGE,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1016,6 +1026,197 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
def _nano_banana_2_v2_model_inputs():
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[
|
||||
"auto",
|
||||
"1:1",
|
||||
"2:3",
|
||||
"3:2",
|
||||
"3:4",
|
||||
"4:3",
|
||||
"4:5",
|
||||
"5:4",
|
||||
"9:16",
|
||||
"16:9",
|
||||
"21:9",
|
||||
"1:4",
|
||||
"4:1",
|
||||
"8:1",
|
||||
"1:8",
|
||||
],
|
||||
default="auto",
|
||||
tooltip="If set to 'auto', matches your input image's aspect ratio; "
|
||||
"if no image is provided, a 16:9 square is usually generated.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["1K", "2K", "4K"],
|
||||
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"thinking_level",
|
||||
options=["MINIMAL", "HIGH"],
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, 15)],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional reference image(s). Up to 14 images total.",
|
||||
),
|
||||
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||
"files",
|
||||
optional=True,
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNanoBanana2V2",
|
||||
display_name="Nano Banana 2",
|
||||
category="api node/image/Gemini",
|
||||
description="Generate or edit images synchronously via Google Vertex API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text prompt describing the image to generate or the edits to apply. "
|
||||
"Include any constraints, styles, or details the model should follow.",
|
||||
default="",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
"Nano Banana 2 (Gemini 3.1 Flash Image)",
|
||||
_nano_banana_2_v2_model_inputs(),
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide "
|
||||
"the same response for repeated requests. Deterministic output isn't guaranteed. "
|
||||
"Also, changing the model or parameter settings, such as the temperature, "
|
||||
"can cause variations in the response even when you use the same seed value. "
|
||||
"By default, a random seed value is used.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"response_modalities",
|
||||
options=["IMAGE", "IMAGE+TEXT"],
|
||||
advanced=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default=GEMINI_IMAGE_SYS_PROMPT,
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
IO.String.Output(),
|
||||
IO.Image.Output(
|
||||
display_name="thought_image",
|
||||
tooltip="First image from the model's thinking process. "
|
||||
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
|
||||
),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$r := $lookup(widgets, "model.resolution");
|
||||
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
response_modalities: str,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_choice = model["model"]
|
||||
if model_choice == "Nano Banana 2 (Gemini 3.1 Flash Image)":
|
||||
model_id = "gemini-3.1-flash-image-preview"
|
||||
else:
|
||||
model_id = model_choice
|
||||
|
||||
images = model.get("images") or {}
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
if images:
|
||||
image_tensors: list[Input.Image] = [t for t in images.values() if t is not None]
|
||||
if image_tensors:
|
||||
if sum(get_number_of_images(t) for t in image_tensors) > 14:
|
||||
raise ValueError("The current maximum number of supported images is 14.")
|
||||
parts.extend(await create_image_parts(cls, image_tensors))
|
||||
files = model.get("files")
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
image_config = GeminiImageConfig(imageSize=model["resolution"])
|
||||
if model["aspect_ratio"] != "auto":
|
||||
image_config.aspectRatio = model["aspect_ratio"]
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model_id}", method="POST"),
|
||||
data=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||
],
|
||||
generationConfig=GeminiImageGenerationConfig(
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=image_config,
|
||||
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await get_image_from_response(response),
|
||||
get_text_from_response(response),
|
||||
await get_image_from_response(response, thought=True),
|
||||
)
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -1024,6 +1225,7 @@ class GeminiExtension(ComfyExtension):
|
||||
GeminiImage,
|
||||
GeminiImage2,
|
||||
GeminiNanoBanana2,
|
||||
GeminiNanoBanana2V2,
|
||||
GeminiInputFiles,
|
||||
]
|
||||
|
||||
|
||||
@ -2787,11 +2787,15 @@ class MotionControl(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["mode", "model"]),
|
||||
expr="""
|
||||
(
|
||||
$prices := {"std": 0.07, "pro": 0.112};
|
||||
{"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}}
|
||||
$prices := {
|
||||
"kling-v3": {"std": 0.126, "pro": 0.168},
|
||||
"kling-v2-6": {"std": 0.07, "pro": 0.112}
|
||||
};
|
||||
$modelPrices := $lookup($prices, widgets.model);
|
||||
{"type":"usd","usd": $lookup($modelPrices, widgets.mode), "format":{"suffix":"/second"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
|
||||
@ -272,8 +272,8 @@ class SAM3_VideoTrack(io.ComfyNode):
|
||||
io.Model.Input("model", display_name="model"),
|
||||
io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"),
|
||||
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"),
|
||||
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"),
|
||||
io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."),
|
||||
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection."),
|
||||
io.Int.Input("max_objects", display_name="max_objects", default=4, min=0, max=64, tooltip="Max tracked objects. Initial masks count toward this limit. 0 uses the internal cap of 64."),
|
||||
io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."),
|
||||
],
|
||||
outputs=[
|
||||
@ -290,8 +290,7 @@ class SAM3_VideoTrack(io.ComfyNode):
|
||||
dtype = model.model.get_dtype()
|
||||
sam3_model = model.model.diffusion_model
|
||||
|
||||
frames = images[..., :3].movedim(-1, 1)
|
||||
frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype)
|
||||
frames_in = images[..., :3].movedim(-1, 1)
|
||||
|
||||
init_masks = None
|
||||
if initial_mask is not None:
|
||||
@ -308,7 +307,7 @@ class SAM3_VideoTrack(io.ComfyNode):
|
||||
result = sam3_model.forward_video(
|
||||
images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts,
|
||||
new_det_thresh=detection_threshold, max_objects=max_objects,
|
||||
detect_interval=detect_interval)
|
||||
detect_interval=detect_interval, target_device=device, target_dtype=dtype)
|
||||
result["orig_size"] = (H, W)
|
||||
return io.NodeOutput(result)
|
||||
|
||||
@ -449,14 +448,18 @@ class SAM3_TrackPreview(io.ComfyNode):
|
||||
cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area
|
||||
has = area > 1
|
||||
scores = track_data.get("scores", [])
|
||||
label_scale = max(3, H // 240) # Scale font with resolutio
|
||||
size_caps = (area.float().sqrt() / 15).clamp_(min=1).long().tolist() #cap per-object so the number doesn't dwarf small masks
|
||||
for obj_idx in range(N_obj):
|
||||
if has[obj_idx]:
|
||||
_cx, _cy = int(cx[obj_idx]), int(cy[obj_idx])
|
||||
color = cls.COLORS[obj_idx % len(cls.COLORS)]
|
||||
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color)
|
||||
obj_scale = min(label_scale, size_caps[obj_idx])
|
||||
score_scale = max(1, obj_scale * 2 // 3)
|
||||
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color, scale=obj_scale)
|
||||
if obj_idx < len(scores) and scores[obj_idx] < 1.0:
|
||||
SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100),
|
||||
_cx, _cy + 5 * 3 + 3, color, scale=2)
|
||||
_cx, _cy + 5 * obj_scale + 3, color, scale=score_scale)
|
||||
frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte())
|
||||
else:
|
||||
frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte())
|
||||
@ -507,9 +510,10 @@ class SAM3_TrackToMask(io.ComfyNode):
|
||||
if not indices:
|
||||
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
|
||||
|
||||
selected = packed[:, indices]
|
||||
binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool
|
||||
union = binary.any(dim=1, keepdim=True).float()
|
||||
union_packed = packed[:, indices[0]].clone()
|
||||
for i in indices[1:]:
|
||||
union_packed |= packed[:, i]
|
||||
union = unpack_masks(union_packed).unsqueeze(1).float() # [N, 1, Hm, Wm]
|
||||
mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0]
|
||||
return io.NodeOutput(mask_out)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user