mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 09:12:31 +08:00
Merge branch 'breifnet' of https://github.com/yousef-rafat/ComfyUI into breifnet
This commit is contained in:
commit
7b89cd587c
@ -27,7 +27,7 @@ def frontend_install_warning_message():
|
|||||||
return f"""
|
return f"""
|
||||||
{get_missing_requirements_message()}
|
{get_missing_requirements_message()}
|
||||||
|
|
||||||
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
The ComfyUI frontend is shipped in a pip package so it needs to be updated separately from the ComfyUI code.
|
||||||
""".strip()
|
""".strip()
|
||||||
|
|
||||||
def parse_version(version: str) -> tuple[int, int, int]:
|
def parse_version(version: str) -> tuple[int, int, int]:
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, TypedDict
|
from typing import TYPE_CHECKING, TypedDict
|
||||||
@ -31,8 +33,22 @@ class NodeReplaceManager:
|
|||||||
self._replacements: dict[str, list[NodeReplace]] = {}
|
self._replacements: dict[str, list[NodeReplace]] = {}
|
||||||
|
|
||||||
def register(self, node_replace: NodeReplace):
|
def register(self, node_replace: NodeReplace):
|
||||||
"""Register a node replacement mapping."""
|
"""Register a node replacement mapping.
|
||||||
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
|
|
||||||
|
Idempotent: if a replacement with the same (old_node_id, new_node_id)
|
||||||
|
is already registered, the duplicate is ignored. This prevents stale
|
||||||
|
entries from accumulating when custom nodes are reloaded in the same
|
||||||
|
process (e.g. via ComfyUI-Manager).
|
||||||
|
"""
|
||||||
|
existing = self._replacements.setdefault(node_replace.old_node_id, [])
|
||||||
|
for entry in existing:
|
||||||
|
if entry.new_node_id == node_replace.new_node_id:
|
||||||
|
logging.debug(
|
||||||
|
"Node replacement %s -> %s already registered, ignoring duplicate.",
|
||||||
|
node_replace.old_node_id, node_replace.new_node_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
existing.append(node_replace)
|
||||||
|
|
||||||
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
|
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
|
||||||
"""Get replacements for an old node ID."""
|
"""Get replacements for an old node ID."""
|
||||||
|
|||||||
@ -93,7 +93,7 @@ class Hook:
|
|||||||
self.hook_scope = hook_scope
|
self.hook_scope = hook_scope
|
||||||
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
|
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
|
||||||
self.custom_should_register = default_should_register
|
self.custom_should_register = default_should_register
|
||||||
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
'''Can be overridden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def strength(self):
|
def strength(self):
|
||||||
|
|||||||
@ -1859,6 +1859,23 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
output = torch.zeros_like(x)
|
output = torch.zeros_like(x)
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
current_start_frame = 0
|
current_start_frame = 0
|
||||||
|
|
||||||
|
# I2V: seed KV cache with the initial image latent before the denoising loop
|
||||||
|
initial_latent = transformer_options.get("ar_config", {}).get("initial_latent", None)
|
||||||
|
if initial_latent is not None:
|
||||||
|
initial_latent = inner_model.process_latent_in(initial_latent).to(device=device, dtype=model_dtype)
|
||||||
|
n_init = initial_latent.shape[2]
|
||||||
|
output[:, :, :n_init] = initial_latent
|
||||||
|
|
||||||
|
ar_state = {"start_frame": 0, "kv_caches": kv_caches, "crossattn_caches": crossattn_caches}
|
||||||
|
transformer_options["ar_state"] = ar_state
|
||||||
|
zero_sigma = sigmas.new_zeros([1])
|
||||||
|
_ = model(initial_latent, zero_sigma * s_in, **extra_args)
|
||||||
|
|
||||||
|
current_start_frame = n_init
|
||||||
|
remaining = lat_t - n_init
|
||||||
|
num_blocks = -(-remaining // num_frame_per_block)
|
||||||
|
|
||||||
num_sigma_steps = len(sigmas) - 1
|
num_sigma_steps = len(sigmas) - 1
|
||||||
total_real_steps = num_blocks * num_sigma_steps
|
total_real_steps = num_blocks * num_sigma_steps
|
||||||
step_count = 0
|
step_count = 0
|
||||||
|
|||||||
@ -140,7 +140,7 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
|||||||
alphas = alphacums[ddim_timesteps]
|
alphas = alphacums[ddim_timesteps]
|
||||||
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
||||||
|
|
||||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
# according to the formula provided in https://arxiv.org/abs/2010.02502
|
||||||
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
||||||
if verbose:
|
if verbose:
|
||||||
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||||
|
|||||||
@ -561,7 +561,8 @@ class SAM3Model(nn.Module):
|
|||||||
return high_res_masks
|
return high_res_masks
|
||||||
|
|
||||||
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
|
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
|
||||||
new_det_thresh=0.5, max_objects=0, detect_interval=1):
|
new_det_thresh=0.5, max_objects=0, detect_interval=1,
|
||||||
|
target_device=None, target_dtype=None):
|
||||||
"""Track video with optional per-frame text-prompted detection."""
|
"""Track video with optional per-frame text-prompted detection."""
|
||||||
bb = self.detector.backbone["vision_backbone"]
|
bb = self.detector.backbone["vision_backbone"]
|
||||||
|
|
||||||
@ -589,8 +590,10 @@ class SAM3Model(nn.Module):
|
|||||||
return self.tracker.track_video_with_detection(
|
return self.tracker.track_video_with_detection(
|
||||||
backbone_fn, images, initial_masks, detect_fn,
|
backbone_fn, images, initial_masks, detect_fn,
|
||||||
new_det_thresh=new_det_thresh, max_objects=max_objects,
|
new_det_thresh=new_det_thresh, max_objects=max_objects,
|
||||||
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
|
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar,
|
||||||
|
target_device=target_device, target_dtype=target_dtype)
|
||||||
# SAM3 (non-multiplex) — no detection support, requires initial masks
|
# SAM3 (non-multiplex) — no detection support, requires initial masks
|
||||||
if initial_masks is None:
|
if initial_masks is None:
|
||||||
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
|
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
|
||||||
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)
|
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb,
|
||||||
|
target_device=target_device, target_dtype=target_dtype)
|
||||||
|
|||||||
@ -200,8 +200,13 @@ def pack_masks(masks):
|
|||||||
|
|
||||||
def unpack_masks(packed):
|
def unpack_masks(packed):
|
||||||
"""Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8]."""
|
"""Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8]."""
|
||||||
shifts = torch.arange(8, device=packed.device)
|
bits = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device)
|
||||||
return ((packed.unsqueeze(-1) >> shifts) & 1).view(*packed.shape[:-1], -1).bool()
|
return (packed.unsqueeze(-1) & bits).bool().view(*packed.shape[:-1], -1)
|
||||||
|
|
||||||
|
|
||||||
|
def _prep_frame(images, idx, device, dt, size):
|
||||||
|
"""Slice CPU full-res frames, transfer to GPU in target dtype, and resize to (size, size)."""
|
||||||
|
return comfy.utils.common_upscale(images[idx].to(device=device, dtype=dt), size, size, "bicubic", crop="disabled")
|
||||||
|
|
||||||
|
|
||||||
def _compute_backbone(backbone_fn, frame, frame_idx=None):
|
def _compute_backbone(backbone_fn, frame, frame_idx=None):
|
||||||
@ -1078,16 +1083,19 @@ class SAM3Tracker(nn.Module):
|
|||||||
# SAM3: drop last FPN level
|
# SAM3: drop last FPN level
|
||||||
return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1]
|
return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1]
|
||||||
|
|
||||||
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None):
|
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None,
|
||||||
|
target_device=None, target_dtype=None):
|
||||||
"""Track one object, computing backbone per frame to save VRAM."""
|
"""Track one object, computing backbone per frame to save VRAM."""
|
||||||
N = images.shape[0]
|
N = images.shape[0]
|
||||||
device, dt = images.device, images.dtype
|
device = target_device if target_device is not None else images.device
|
||||||
|
dt = target_dtype if target_dtype is not None else images.dtype
|
||||||
|
size = self.image_size
|
||||||
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
|
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
|
||||||
all_masks = []
|
all_masks = []
|
||||||
|
|
||||||
for frame_idx in tqdm(range(N), desc="tracking"):
|
for frame_idx in tqdm(range(N), desc="tracking"):
|
||||||
vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame(
|
vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame(
|
||||||
backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx)
|
backbone_fn, _prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), frame_idx=frame_idx)
|
||||||
mask_input = None
|
mask_input = None
|
||||||
if frame_idx == 0:
|
if frame_idx == 0:
|
||||||
mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt),
|
mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt),
|
||||||
@ -1114,12 +1122,13 @@ class SAM3Tracker(nn.Module):
|
|||||||
|
|
||||||
return torch.cat(all_masks, dim=0) # [N, 1, H, W]
|
return torch.cat(all_masks, dim=0) # [N, 1, H, W]
|
||||||
|
|
||||||
def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs):
|
def track_video(self, backbone_fn, images, initial_masks, pbar=None,
|
||||||
|
target_device=None, target_dtype=None, **kwargs):
|
||||||
"""Track one or more objects across video frames.
|
"""Track one or more objects across video frames.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame
|
backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame
|
||||||
images: [N, 3, 1008, 1008] video frames
|
images: [N, 3, H, W] CPU full-res video frames (resized per-frame to self.image_size)
|
||||||
initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object)
|
initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object)
|
||||||
pbar: optional progress bar
|
pbar: optional progress bar
|
||||||
|
|
||||||
@ -1130,7 +1139,8 @@ class SAM3Tracker(nn.Module):
|
|||||||
per_object = []
|
per_object = []
|
||||||
for obj_idx in range(N_obj):
|
for obj_idx in range(N_obj):
|
||||||
obj_masks = self._track_single_object(
|
obj_masks = self._track_single_object(
|
||||||
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar)
|
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar,
|
||||||
|
target_device=target_device, target_dtype=target_dtype)
|
||||||
per_object.append(obj_masks)
|
per_object.append(obj_masks)
|
||||||
|
|
||||||
return torch.cat(per_object, dim=1) # [N, N_obj, H, W]
|
return torch.cat(per_object, dim=1) # [N, N_obj, H, W]
|
||||||
@ -1632,11 +1642,18 @@ class SAM31Tracker(nn.Module):
|
|||||||
return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item()
|
return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item()
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
INTERNAL_MAX_OBJECTS = 64 # Hard ceiling on accumulated tracks; max_objects=0 or any value above this is clamped here.
|
||||||
|
|
||||||
def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None,
|
def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None,
|
||||||
new_det_thresh=0.5, max_objects=0, detect_interval=1,
|
new_det_thresh=0.5, max_objects=0, detect_interval=1,
|
||||||
backbone_obj=None, pbar=None):
|
backbone_obj=None, pbar=None, target_device=None, target_dtype=None):
|
||||||
"""Track with optional per-frame detection. Returns [N, max_N_obj, H, W] mask logits."""
|
"""Track with optional per-frame detection. Returns [N, max_N_obj, H, W] mask logits."""
|
||||||
N, device, dt = images.shape[0], images.device, images.dtype
|
if max_objects <= 0 or max_objects > self.INTERNAL_MAX_OBJECTS:
|
||||||
|
max_objects = self.INTERNAL_MAX_OBJECTS
|
||||||
|
N = images.shape[0]
|
||||||
|
device = target_device if target_device is not None else images.device
|
||||||
|
dt = target_dtype if target_dtype is not None else images.dtype
|
||||||
|
size = self.image_size
|
||||||
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
|
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
|
||||||
all_masks = []
|
all_masks = []
|
||||||
idev = comfy.model_management.intermediate_device()
|
idev = comfy.model_management.intermediate_device()
|
||||||
@ -1656,7 +1673,7 @@ class SAM31Tracker(nn.Module):
|
|||||||
prefetch = True
|
prefetch = True
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
pass
|
pass
|
||||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0)
|
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(0, 1), device, dt, size), frame_idx=0)
|
||||||
|
|
||||||
for frame_idx in tqdm(range(N), desc="tracking"):
|
for frame_idx in tqdm(range(N), desc="tracking"):
|
||||||
vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb
|
vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb
|
||||||
@ -1666,7 +1683,7 @@ class SAM31Tracker(nn.Module):
|
|||||||
backbone_stream.wait_stream(torch.cuda.current_stream(device))
|
backbone_stream.wait_stream(torch.cuda.current_stream(device))
|
||||||
with torch.cuda.stream(backbone_stream):
|
with torch.cuda.stream(backbone_stream):
|
||||||
next_bb = self._compute_backbone_frame(
|
next_bb = self._compute_backbone_frame(
|
||||||
backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||||
|
|
||||||
# Per-frame detection with NMS (skip if no detect_fn, or interval/max not met)
|
# Per-frame detection with NMS (skip if no detect_fn, or interval/max not met)
|
||||||
det_masks = torch.empty(0, device=device)
|
det_masks = torch.empty(0, device=device)
|
||||||
@ -1687,7 +1704,7 @@ class SAM31Tracker(nn.Module):
|
|||||||
current_out = self._condition_with_masks(
|
current_out = self._condition_with_masks(
|
||||||
initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos,
|
initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos,
|
||||||
feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj,
|
feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj,
|
||||||
images[frame_idx:frame_idx + 1], trunk_out)
|
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out)
|
||||||
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
|
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
|
||||||
obj_scores = [1.0] * mux_state.total_valid_entries
|
obj_scores = [1.0] * mux_state.total_valid_entries
|
||||||
if keep_alive is not None:
|
if keep_alive is not None:
|
||||||
@ -1702,7 +1719,7 @@ class SAM31Tracker(nn.Module):
|
|||||||
current_out = self._condition_with_masks(
|
current_out = self._condition_with_masks(
|
||||||
det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop,
|
det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop,
|
||||||
output_dict, N, mux_state, backbone_obj,
|
output_dict, N, mux_state, backbone_obj,
|
||||||
images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0)
|
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out, threshold=0.0)
|
||||||
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
|
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
|
||||||
obj_scores = det_scores[:mux_state.total_valid_entries].tolist()
|
obj_scores = det_scores[:mux_state.total_valid_entries].tolist()
|
||||||
if keep_alive is not None:
|
if keep_alive is not None:
|
||||||
@ -1718,7 +1735,7 @@ class SAM31Tracker(nn.Module):
|
|||||||
torch.cuda.current_stream(device).wait_stream(backbone_stream)
|
torch.cuda.current_stream(device).wait_stream(backbone_stream)
|
||||||
cur_bb = next_bb
|
cur_bb = next_bb
|
||||||
else:
|
else:
|
||||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
N_obj = mux_state.total_valid_entries
|
N_obj = mux_state.total_valid_entries
|
||||||
@ -1768,7 +1785,7 @@ class SAM31Tracker(nn.Module):
|
|||||||
torch.cuda.current_stream(device).wait_stream(backbone_stream)
|
torch.cuda.current_stream(device).wait_stream(backbone_stream)
|
||||||
cur_bb = next_bb
|
cur_bb = next_bb
|
||||||
else:
|
else:
|
||||||
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
|
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
|
||||||
|
|
||||||
if not all_masks or all(m is None for m in all_masks):
|
if not all_masks or all(m is None for m in all_masks):
|
||||||
return {"packed_masks": None, "n_frames": N, "scores": []}
|
return {"packed_masks": None, "n_frames": N, "scores": []}
|
||||||
|
|||||||
@ -1271,7 +1271,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _seedance2_text_inputs(resolutions: list[str]):
|
def _seedance2_text_inputs(resolutions: list[str], default_ratio: str = "16:9"):
|
||||||
return [
|
return [
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
@ -1287,6 +1287,7 @@ def _seedance2_text_inputs(resolutions: list[str]):
|
|||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"ratio",
|
"ratio",
|
||||||
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
|
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
|
||||||
|
default=default_ratio,
|
||||||
tooltip="Aspect ratio of the output video.",
|
tooltip="Aspect ratio of the output video.",
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
@ -1420,8 +1421,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
|||||||
IO.DynamicCombo.Input(
|
IO.DynamicCombo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[
|
options=[
|
||||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
|
IO.DynamicCombo.Option(
|
||||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
|
"Seedance 2.0",
|
||||||
|
_seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"Seedance 2.0 Fast",
|
||||||
|
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||||
),
|
),
|
||||||
@ -1588,9 +1595,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||||
|
|
||||||
|
|
||||||
def _seedance2_reference_inputs(resolutions: list[str]):
|
def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16:9"):
|
||||||
return [
|
return [
|
||||||
*_seedance2_text_inputs(resolutions),
|
*_seedance2_text_inputs(resolutions, default_ratio=default_ratio),
|
||||||
IO.Autogrow.Input(
|
IO.Autogrow.Input(
|
||||||
"reference_images",
|
"reference_images",
|
||||||
template=IO.Autogrow.TemplateNames(
|
template=IO.Autogrow.TemplateNames(
|
||||||
@ -1668,8 +1675,14 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
|||||||
IO.DynamicCombo.Input(
|
IO.DynamicCombo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[
|
options=[
|
||||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])),
|
IO.DynamicCombo.Option(
|
||||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])),
|
"Seedance 2.0",
|
||||||
|
_seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"Seedance 2.0 Fast",
|
||||||
|
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||||
),
|
),
|
||||||
|
|||||||
@ -92,7 +92,7 @@ class SamplerEulerCFGpp(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SamplerEulerCFGpp",
|
node_id="SamplerEulerCFGpp",
|
||||||
display_name="SamplerEulerCFG++",
|
display_name="SamplerEulerCFG++",
|
||||||
category="_for_testing", # "sampling/custom_sampling/samplers"
|
category="experimental", # "sampling/custom_sampling/samplers"
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Combo.Input("version", options=["regular", "alternative"], advanced=True),
|
io.Combo.Input("version", options=["regular", "alternative"], advanced=True),
|
||||||
],
|
],
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
||||||
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
|
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
|
||||||
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
|
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
|
||||||
|
- ARVideoI2V: image-to-video conditioning for AR models (seeds KV cache with start image)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -9,6 +10,7 @@ from typing_extensions import override
|
|||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
|
import comfy.utils
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
@ -71,12 +73,62 @@ class SamplerARVideo(io.ComfyNode):
|
|||||||
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
|
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
|
||||||
|
|
||||||
|
|
||||||
|
class ARVideoI2V(io.ComfyNode):
|
||||||
|
"""Image-to-video setup for AR video models (Causal Forcing, Self-Forcing).
|
||||||
|
|
||||||
|
VAE-encodes the start image and stores it in the model's transformer_options
|
||||||
|
so that sample_ar_video can seed the KV cache before denoising.
|
||||||
|
Uses the same T2V model checkpoint -- no separate I2V architecture needed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="ARVideoI2V",
|
||||||
|
category="conditioning/video_models",
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model"),
|
||||||
|
io.Vae.Input("vae"),
|
||||||
|
io.Image.Input("start_image"),
|
||||||
|
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||||
|
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||||
|
io.Int.Input("length", default=81, min=1, max=1024, step=4),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(display_name="MODEL"),
|
||||||
|
io.Latent.Output(display_name="LATENT"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model, vae, start_image, width, height, length, batch_size) -> io.NodeOutput:
|
||||||
|
start_image = comfy.utils.common_upscale(
|
||||||
|
start_image[:1].movedim(-1, 1), width, height, "bilinear", "center"
|
||||||
|
).movedim(1, -1)
|
||||||
|
|
||||||
|
initial_latent = vae.encode(start_image[:, :, :, :3])
|
||||||
|
|
||||||
|
m = model.clone()
|
||||||
|
to = m.model_options.setdefault("transformer_options", {})
|
||||||
|
ar_cfg = to.setdefault("ar_config", {})
|
||||||
|
ar_cfg["initial_latent"] = initial_latent
|
||||||
|
|
||||||
|
lat_t = ((length - 1) // 4) + 1
|
||||||
|
latent = torch.zeros(
|
||||||
|
[batch_size, 16, lat_t, height // 8, width // 8],
|
||||||
|
device=comfy.model_management.intermediate_device(),
|
||||||
|
)
|
||||||
|
return io.NodeOutput(m, {"samples": latent})
|
||||||
|
|
||||||
|
|
||||||
class ARVideoExtension(ComfyExtension):
|
class ARVideoExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
EmptyARVideoLatent,
|
EmptyARVideoLatent,
|
||||||
SamplerARVideo,
|
SamplerARVideo,
|
||||||
|
ARVideoI2V,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -25,7 +25,7 @@ class UNetSelfAttentionMultiply(io.ComfyNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="UNetSelfAttentionMultiply",
|
node_id="UNetSelfAttentionMultiply",
|
||||||
category="_for_testing/attention_experiments",
|
category="experimental/attention_experiments",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||||
@ -48,7 +48,7 @@ class UNetCrossAttentionMultiply(io.ComfyNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="UNetCrossAttentionMultiply",
|
node_id="UNetCrossAttentionMultiply",
|
||||||
category="_for_testing/attention_experiments",
|
category="experimental/attention_experiments",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||||
@ -72,7 +72,7 @@ class CLIPAttentionMultiply(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="CLIPAttentionMultiply",
|
node_id="CLIPAttentionMultiply",
|
||||||
search_aliases=["clip attention scale", "text encoder attention"],
|
search_aliases=["clip attention scale", "text encoder attention"],
|
||||||
category="_for_testing/attention_experiments",
|
category="experimental/attention_experiments",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Clip.Input("clip"),
|
io.Clip.Input("clip"),
|
||||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||||
@ -106,7 +106,7 @@ class UNetTemporalAttentionMultiply(io.ComfyNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="UNetTemporalAttentionMultiply",
|
node_id="UNetTemporalAttentionMultiply",
|
||||||
category="_for_testing/attention_experiments",
|
category="experimental/attention_experiments",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||||
|
|||||||
@ -10,6 +10,7 @@ class AudioEncoderLoader(io.ComfyNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="AudioEncoderLoader",
|
node_id="AudioEncoderLoader",
|
||||||
|
display_name="Load Audio Encoder",
|
||||||
category="loaders",
|
category="loaders",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
|
|||||||
@ -153,7 +153,7 @@ class WanCameraEmbedding(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="WanCameraEmbedding",
|
node_id="WanCameraEmbedding",
|
||||||
category="camera",
|
category="conditioning/video_models",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"camera_pose",
|
"camera_pose",
|
||||||
|
|||||||
@ -8,7 +8,7 @@ class CLIPTextEncodeControlnet(io.ComfyNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="CLIPTextEncodeControlnet",
|
node_id="CLIPTextEncodeControlnet",
|
||||||
category="_for_testing/conditioning",
|
category="experimental/conditioning",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Clip.Input("clip"),
|
io.Clip.Input("clip"),
|
||||||
io.Conditioning.Input("conditioning"),
|
io.Conditioning.Input("conditioning"),
|
||||||
@ -35,7 +35,7 @@ class T5TokenizerOptions(io.ComfyNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="T5TokenizerOptions",
|
node_id="T5TokenizerOptions",
|
||||||
category="_for_testing/conditioning",
|
category="experimental/conditioning",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Clip.Input("clip"),
|
io.Clip.Input("clip"),
|
||||||
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
|
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
|
||||||
|
|||||||
@ -10,7 +10,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ContextWindowsManual",
|
node_id="ContextWindowsManual",
|
||||||
display_name="Context Windows (Manual)",
|
display_name="Context Windows (Manual)",
|
||||||
category="context",
|
category="model_patches",
|
||||||
description="Manually set context windows.",
|
description="Manually set context windows.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||||
|
|||||||
@ -984,7 +984,7 @@ class AddNoise(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="AddNoise",
|
node_id="AddNoise",
|
||||||
category="_for_testing/custom_sampling/noise",
|
category="experimental/custom_sampling/noise",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
@ -1034,7 +1034,7 @@ class ManualSigmas(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ManualSigmas",
|
node_id="ManualSigmas",
|
||||||
search_aliases=["custom noise schedule", "define sigmas"],
|
search_aliases=["custom noise schedule", "define sigmas"],
|
||||||
category="_for_testing/custom_sampling",
|
category="experimental/custom_sampling",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input("sigmas", default="1, 0.5", multiline=False)
|
io.String.Input("sigmas", default="1, 0.5", multiline=False)
|
||||||
|
|||||||
@ -13,7 +13,7 @@ class DifferentialDiffusion(io.ComfyNode):
|
|||||||
node_id="DifferentialDiffusion",
|
node_id="DifferentialDiffusion",
|
||||||
search_aliases=["inpaint gradient", "variable denoise strength"],
|
search_aliases=["inpaint gradient", "variable denoise strength"],
|
||||||
display_name="Differential Diffusion",
|
display_name="Differential Diffusion",
|
||||||
category="_for_testing",
|
category="experimental",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Float.Input(
|
io.Float.Input(
|
||||||
|
|||||||
@ -102,7 +102,7 @@ class FluxDisableGuidance(io.ComfyNode):
|
|||||||
append = execute # TODO: remove
|
append = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
PREFERED_KONTEXT_RESOLUTIONS = [
|
PREFERRED_KONTEXT_RESOLUTIONS = [
|
||||||
(672, 1568),
|
(672, 1568),
|
||||||
(688, 1504),
|
(688, 1504),
|
||||||
(720, 1456),
|
(720, 1456),
|
||||||
@ -143,7 +143,7 @@ class FluxKontextImageScale(io.ComfyNode):
|
|||||||
width = image.shape[2]
|
width = image.shape[2]
|
||||||
height = image.shape[1]
|
height = image.shape[1]
|
||||||
aspect_ratio = width / height
|
aspect_ratio = width / height
|
||||||
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
|
||||||
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
||||||
return io.NodeOutput(image)
|
return io.NodeOutput(image)
|
||||||
|
|
||||||
|
|||||||
@ -60,7 +60,7 @@ class FreSca(io.ComfyNode):
|
|||||||
node_id="FreSca",
|
node_id="FreSca",
|
||||||
search_aliases=["frequency guidance"],
|
search_aliases=["frequency guidance"],
|
||||||
display_name="FreSca",
|
display_name="FreSca",
|
||||||
category="_for_testing",
|
category="experimental",
|
||||||
description="Applies frequency-dependent scaling to the guidance",
|
description="Applies frequency-dependent scaling to the guidance",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
|
|||||||
@ -131,6 +131,8 @@ class HunyuanVideo15SuperResolution(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="HunyuanVideo15SuperResolution",
|
node_id="HunyuanVideo15SuperResolution",
|
||||||
|
display_name="Hunyuan Video 1.5 Super Resolution",
|
||||||
|
category="conditioning/video_models",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Conditioning.Input("positive"),
|
io.Conditioning.Input("positive"),
|
||||||
io.Conditioning.Input("negative"),
|
io.Conditioning.Input("negative"),
|
||||||
@ -381,6 +383,8 @@ class HunyuanRefinerLatent(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="HunyuanRefinerLatent",
|
node_id="HunyuanRefinerLatent",
|
||||||
|
display_name="Hunyuan Latent Refiner",
|
||||||
|
category="conditioning/video_models",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Conditioning.Input("positive"),
|
io.Conditioning.Input("positive"),
|
||||||
io.Conditioning.Input("negative"),
|
io.Conditioning.Input("negative"),
|
||||||
|
|||||||
@ -40,7 +40,7 @@ class Hunyuan3Dv2Conditioning(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="Hunyuan3Dv2Conditioning",
|
node_id="Hunyuan3Dv2Conditioning",
|
||||||
category="conditioning/video_models",
|
category="conditioning/3d_models",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.ClipVisionOutput.Input("clip_vision_output"),
|
IO.ClipVisionOutput.Input("clip_vision_output"),
|
||||||
],
|
],
|
||||||
@ -65,7 +65,7 @@ class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="Hunyuan3Dv2ConditioningMultiView",
|
node_id="Hunyuan3Dv2ConditioningMultiView",
|
||||||
category="conditioning/video_models",
|
category="conditioning/3d_models",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.ClipVisionOutput.Input("front", optional=True),
|
IO.ClipVisionOutput.Input("front", optional=True),
|
||||||
IO.ClipVisionOutput.Input("left", optional=True),
|
IO.ClipVisionOutput.Input("left", optional=True),
|
||||||
@ -424,6 +424,7 @@ class VoxelToMeshBasic(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="VoxelToMeshBasic",
|
node_id="VoxelToMeshBasic",
|
||||||
|
display_name="Voxel to Mesh (Basic)",
|
||||||
category="3d",
|
category="3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Voxel.Input("voxel"),
|
IO.Voxel.Input("voxel"),
|
||||||
@ -453,6 +454,7 @@ class VoxelToMesh(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="VoxelToMesh",
|
node_id="VoxelToMesh",
|
||||||
|
display_name="Voxel to Mesh",
|
||||||
category="3d",
|
category="3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Voxel.Input("voxel"),
|
IO.Voxel.Input("voxel"),
|
||||||
|
|||||||
@ -102,6 +102,7 @@ class HypernetworkLoader(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="HypernetworkLoader",
|
node_id="HypernetworkLoader",
|
||||||
|
display_name="Load Hypernetwork",
|
||||||
category="loaders",
|
category="loaders",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Model.Input("model"),
|
IO.Model.Input("model"),
|
||||||
|
|||||||
@ -91,7 +91,7 @@ class LoraSave(io.ComfyNode):
|
|||||||
node_id="LoraSave",
|
node_id="LoraSave",
|
||||||
search_aliases=["export lora"],
|
search_aliases=["export lora"],
|
||||||
display_name="Extract and Save Lora",
|
display_name="Extract and Save Lora",
|
||||||
category="_for_testing",
|
category="experimental",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
|
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
|
||||||
io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True),
|
io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True),
|
||||||
|
|||||||
@ -594,7 +594,8 @@ class LTXVPreprocess(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="LTXVPreprocess",
|
node_id="LTXVPreprocess",
|
||||||
category="image",
|
display_name="LTXV Preprocess",
|
||||||
|
category="video/preprocessors",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
io.Int.Input(
|
io.Int.Input(
|
||||||
|
|||||||
@ -11,7 +11,7 @@ class Mahiro(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="Mahiro",
|
node_id="Mahiro",
|
||||||
display_name="Positive-Biased Guidance",
|
display_name="Positive-Biased Guidance",
|
||||||
category="_for_testing",
|
category="experimental",
|
||||||
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
|
|||||||
@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ComfyMathExpression",
|
node_id="ComfyMathExpression",
|
||||||
display_name="Math Expression",
|
display_name="Math Expression",
|
||||||
category="math",
|
category="logic",
|
||||||
search_aliases=[
|
search_aliases=[
|
||||||
"expression", "formula", "calculate", "calculator",
|
"expression", "formula", "calculate", "calculator",
|
||||||
"eval", "math",
|
"eval", "math",
|
||||||
|
|||||||
@ -21,7 +21,7 @@ class NumberConvertNode(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ComfyNumberConvert",
|
node_id="ComfyNumberConvert",
|
||||||
display_name="Number Convert",
|
display_name="Number Convert",
|
||||||
category="math",
|
category="utils",
|
||||||
search_aliases=[
|
search_aliases=[
|
||||||
"int to float", "float to int", "number convert",
|
"int to float", "float to int", "number convert",
|
||||||
"int2float", "float2int", "cast", "parse number",
|
"int2float", "float2int", "cast", "parse number",
|
||||||
|
|||||||
@ -24,8 +24,8 @@ class PerpNeg(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="PerpNeg",
|
node_id="PerpNeg",
|
||||||
display_name="Perp-Neg (DEPRECATED by PerpNegGuider)",
|
display_name="Perp-Neg (DEPRECATED by Perp-Neg Guider)",
|
||||||
category="_for_testing",
|
category="experimental",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Conditioning.Input("empty_conditioning"),
|
io.Conditioning.Input("empty_conditioning"),
|
||||||
@ -127,7 +127,8 @@ class PerpNegGuider(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="PerpNegGuider",
|
node_id="PerpNegGuider",
|
||||||
category="_for_testing",
|
display_name="Perp-Neg Guider",
|
||||||
|
category="experimental",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Conditioning.Input("positive"),
|
io.Conditioning.Input("positive"),
|
||||||
|
|||||||
@ -123,7 +123,7 @@ class PhotoMakerLoader(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="PhotoMakerLoader",
|
node_id="PhotoMakerLoader",
|
||||||
category="_for_testing/photomaker",
|
category="experimental/photomaker",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
|
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
|
||||||
],
|
],
|
||||||
@ -149,7 +149,7 @@ class PhotoMakerEncode(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="PhotoMakerEncode",
|
node_id="PhotoMakerEncode",
|
||||||
category="_for_testing/photomaker",
|
category="experimental/photomaker",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Photomaker.Input("photomaker"),
|
io.Photomaker.Input("photomaker"),
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
|
|||||||
@ -116,6 +116,7 @@ class Quantize(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ImageQuantize",
|
node_id="ImageQuantize",
|
||||||
|
display_name="Quantize Image",
|
||||||
category="image/postprocessing",
|
category="image/postprocessing",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
@ -181,6 +182,7 @@ class Sharpen(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ImageSharpen",
|
node_id="ImageSharpen",
|
||||||
|
display_name="Sharpen Image",
|
||||||
category="image/postprocessing",
|
category="image/postprocessing",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
@ -436,7 +438,7 @@ class ResizeImageMaskNode(io.ComfyNode):
|
|||||||
node_id="ResizeImageMaskNode",
|
node_id="ResizeImageMaskNode",
|
||||||
display_name="Resize Image/Mask",
|
display_name="Resize Image/Mask",
|
||||||
description="Resize an image or mask using various scaling methods.",
|
description="Resize an image or mask using various scaling methods.",
|
||||||
category="transform",
|
category="image/transform",
|
||||||
search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"],
|
search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.MatchType.Input("input", template=template),
|
io.MatchType.Input("input", template=template),
|
||||||
|
|||||||
@ -15,7 +15,7 @@ class RTDETR_detect(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="RTDETR_detect",
|
node_id="RTDETR_detect",
|
||||||
display_name="RT-DETR Detect",
|
display_name="RT-DETR Detect",
|
||||||
category="detection/",
|
category="detection",
|
||||||
search_aliases=["bbox", "bounding box", "object detection", "coco"],
|
search_aliases=["bbox", "bounding box", "object detection", "coco"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model", display_name="model"),
|
io.Model.Input("model", display_name="model"),
|
||||||
@ -71,7 +71,7 @@ class DrawBBoxes(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="DrawBBoxes",
|
node_id="DrawBBoxes",
|
||||||
display_name="Draw BBoxes",
|
display_name="Draw BBoxes",
|
||||||
category="detection/",
|
category="detection",
|
||||||
search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"],
|
search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image", optional=True),
|
io.Image.Input("image", optional=True),
|
||||||
|
|||||||
@ -113,7 +113,7 @@ class SelfAttentionGuidance(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SelfAttentionGuidance",
|
node_id="SelfAttentionGuidance",
|
||||||
display_name="Self-Attention Guidance",
|
display_name="Self-Attention Guidance",
|
||||||
category="_for_testing",
|
category="experimental",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),
|
io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),
|
||||||
|
|||||||
@ -93,7 +93,7 @@ class SAM3_Detect(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SAM3_Detect",
|
node_id="SAM3_Detect",
|
||||||
display_name="SAM3 Detect",
|
display_name="SAM3 Detect",
|
||||||
category="detection/",
|
category="detection",
|
||||||
search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"],
|
search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model", display_name="model"),
|
io.Model.Input("model", display_name="model"),
|
||||||
@ -265,15 +265,15 @@ class SAM3_VideoTrack(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SAM3_VideoTrack",
|
node_id="SAM3_VideoTrack",
|
||||||
display_name="SAM3 Video Track",
|
display_name="SAM3 Video Track",
|
||||||
category="detection/",
|
category="detection",
|
||||||
search_aliases=["sam3", "video", "track", "propagate"],
|
search_aliases=["sam3", "video", "track", "propagate"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"),
|
io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"),
|
||||||
io.Model.Input("model", display_name="model"),
|
io.Model.Input("model", display_name="model"),
|
||||||
io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"),
|
io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"),
|
||||||
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"),
|
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"),
|
||||||
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"),
|
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection."),
|
||||||
io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."),
|
io.Int.Input("max_objects", display_name="max_objects", default=4, min=0, max=64, tooltip="Max tracked objects. Initial masks count toward this limit. 0 uses the internal cap of 64."),
|
||||||
io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."),
|
io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
@ -290,8 +290,7 @@ class SAM3_VideoTrack(io.ComfyNode):
|
|||||||
dtype = model.model.get_dtype()
|
dtype = model.model.get_dtype()
|
||||||
sam3_model = model.model.diffusion_model
|
sam3_model = model.model.diffusion_model
|
||||||
|
|
||||||
frames = images[..., :3].movedim(-1, 1)
|
frames_in = images[..., :3].movedim(-1, 1)
|
||||||
frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
init_masks = None
|
init_masks = None
|
||||||
if initial_mask is not None:
|
if initial_mask is not None:
|
||||||
@ -308,7 +307,7 @@ class SAM3_VideoTrack(io.ComfyNode):
|
|||||||
result = sam3_model.forward_video(
|
result = sam3_model.forward_video(
|
||||||
images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts,
|
images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts,
|
||||||
new_det_thresh=detection_threshold, max_objects=max_objects,
|
new_det_thresh=detection_threshold, max_objects=max_objects,
|
||||||
detect_interval=detect_interval)
|
detect_interval=detect_interval, target_device=device, target_dtype=dtype)
|
||||||
result["orig_size"] = (H, W)
|
result["orig_size"] = (H, W)
|
||||||
return io.NodeOutput(result)
|
return io.NodeOutput(result)
|
||||||
|
|
||||||
@ -321,7 +320,7 @@ class SAM3_TrackPreview(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SAM3_TrackPreview",
|
node_id="SAM3_TrackPreview",
|
||||||
display_name="SAM3 Track Preview",
|
display_name="SAM3 Track Preview",
|
||||||
category="detection/",
|
category="detection",
|
||||||
inputs=[
|
inputs=[
|
||||||
SAM3TrackData.Input("track_data", display_name="track_data"),
|
SAM3TrackData.Input("track_data", display_name="track_data"),
|
||||||
io.Image.Input("images", display_name="images", optional=True),
|
io.Image.Input("images", display_name="images", optional=True),
|
||||||
@ -449,14 +448,18 @@ class SAM3_TrackPreview(io.ComfyNode):
|
|||||||
cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area
|
cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area
|
||||||
has = area > 1
|
has = area > 1
|
||||||
scores = track_data.get("scores", [])
|
scores = track_data.get("scores", [])
|
||||||
|
label_scale = max(3, H // 240) # Scale font with resolution
|
||||||
|
size_caps = (area.float().sqrt() / 15).clamp_(min=1).long().tolist() #cap per-object so the number doesn't dwarf small masks
|
||||||
for obj_idx in range(N_obj):
|
for obj_idx in range(N_obj):
|
||||||
if has[obj_idx]:
|
if has[obj_idx]:
|
||||||
_cx, _cy = int(cx[obj_idx]), int(cy[obj_idx])
|
_cx, _cy = int(cx[obj_idx]), int(cy[obj_idx])
|
||||||
color = cls.COLORS[obj_idx % len(cls.COLORS)]
|
color = cls.COLORS[obj_idx % len(cls.COLORS)]
|
||||||
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color)
|
obj_scale = min(label_scale, size_caps[obj_idx])
|
||||||
|
score_scale = max(1, obj_scale * 2 // 3)
|
||||||
|
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color, scale=obj_scale)
|
||||||
if obj_idx < len(scores) and scores[obj_idx] < 1.0:
|
if obj_idx < len(scores) and scores[obj_idx] < 1.0:
|
||||||
SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100),
|
SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100),
|
||||||
_cx, _cy + 5 * 3 + 3, color, scale=2)
|
_cx, _cy + 5 * obj_scale + 3, color, scale=score_scale)
|
||||||
frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte())
|
frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte())
|
||||||
else:
|
else:
|
||||||
frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte())
|
frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte())
|
||||||
@ -475,7 +478,7 @@ class SAM3_TrackToMask(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SAM3_TrackToMask",
|
node_id="SAM3_TrackToMask",
|
||||||
display_name="SAM3 Track to Mask",
|
display_name="SAM3 Track to Mask",
|
||||||
category="detection/",
|
category="detection",
|
||||||
inputs=[
|
inputs=[
|
||||||
SAM3TrackData.Input("track_data", display_name="track_data"),
|
SAM3TrackData.Input("track_data", display_name="track_data"),
|
||||||
io.String.Input("object_indices", display_name="object_indices", default="",
|
io.String.Input("object_indices", display_name="object_indices", default="",
|
||||||
@ -507,9 +510,10 @@ class SAM3_TrackToMask(io.ComfyNode):
|
|||||||
if not indices:
|
if not indices:
|
||||||
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
|
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
|
||||||
|
|
||||||
selected = packed[:, indices]
|
union_packed = packed[:, indices[0]].clone()
|
||||||
binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool
|
for i in indices[1:]:
|
||||||
union = binary.any(dim=1, keepdim=True).float()
|
union_packed |= packed[:, i]
|
||||||
|
union = unpack_masks(union_packed).unsqueeze(1).float() # [N, 1, Hm, Wm]
|
||||||
mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0]
|
mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0]
|
||||||
return io.NodeOutput(mask_out)
|
return io.NodeOutput(mask_out)
|
||||||
|
|
||||||
|
|||||||
@ -119,7 +119,7 @@ class StableCascade_SuperResolutionControlnet(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="StableCascade_SuperResolutionControlnet",
|
node_id="StableCascade_SuperResolutionControlnet",
|
||||||
category="_for_testing/stable_cascade",
|
category="experimental/stable_cascade",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
|
|||||||
@ -26,7 +26,8 @@ class TextGenerate(io.ComfyNode):
|
|||||||
|
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="TextGenerate",
|
node_id="TextGenerate",
|
||||||
category="textgen",
|
display_name="Generate Text",
|
||||||
|
category="text",
|
||||||
search_aliases=["LLM", "gemma"],
|
search_aliases=["LLM", "gemma"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Clip.Input("clip"),
|
io.Clip.Input("clip"),
|
||||||
@ -157,6 +158,7 @@ class TextGenerateLTX2Prompt(TextGenerate):
|
|||||||
parent_schema = super().define_schema()
|
parent_schema = super().define_schema()
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="TextGenerateLTX2Prompt",
|
node_id="TextGenerateLTX2Prompt",
|
||||||
|
display_name="Generate LTX2 Prompt",
|
||||||
category=parent_schema.category,
|
category=parent_schema.category,
|
||||||
inputs=parent_schema.inputs,
|
inputs=parent_schema.inputs,
|
||||||
outputs=parent_schema.outputs,
|
outputs=parent_schema.outputs,
|
||||||
|
|||||||
@ -10,7 +10,7 @@ class TorchCompileModel(io.ComfyNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="TorchCompileModel",
|
node_id="TorchCompileModel",
|
||||||
category="_for_testing",
|
category="experimental",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
|
|||||||
@ -1361,7 +1361,7 @@ class SaveLoRA(io.ComfyNode):
|
|||||||
node_id="SaveLoRA",
|
node_id="SaveLoRA",
|
||||||
search_aliases=["export lora"],
|
search_aliases=["export lora"],
|
||||||
display_name="Save LoRA Weights",
|
display_name="Save LoRA Weights",
|
||||||
category="loaders",
|
category="advanced/model_merging",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
is_output_node=True,
|
is_output_node=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
|
|||||||
@ -15,7 +15,7 @@ class ImageOnlyCheckpointLoader:
|
|||||||
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
|
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
|
||||||
FUNCTION = "load_checkpoint"
|
FUNCTION = "load_checkpoint"
|
||||||
|
|
||||||
CATEGORY = "loaders/video_models"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
|
|||||||
@ -22,7 +22,7 @@ class SaveImageWebsocket:
|
|||||||
|
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
CATEGORY = "api/image"
|
CATEGORY = "image"
|
||||||
|
|
||||||
def save_images(self, images):
|
def save_images(self, images):
|
||||||
pbar = comfy.utils.ProgressBar(images.shape[0])
|
pbar = comfy.utils.ProgressBar(images.shape[0])
|
||||||
@ -42,3 +42,7 @@ class SaveImageWebsocket:
|
|||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"SaveImageWebsocket": SaveImageWebsocket,
|
"SaveImageWebsocket": SaveImageWebsocket,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"SaveImageWebsocket": "Save Image (Websocket)",
|
||||||
|
}
|
||||||
14
nodes.py
14
nodes.py
@ -330,7 +330,7 @@ class VAEDecodeTiled:
|
|||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "decode"
|
FUNCTION = "decode"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
|
|
||||||
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
|
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
|
||||||
if tile_size < overlap * 4:
|
if tile_size < overlap * 4:
|
||||||
@ -377,7 +377,7 @@ class VAEEncodeTiled:
|
|||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
|
|
||||||
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
|
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
|
||||||
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
|
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
|
||||||
@ -493,7 +493,7 @@ class SaveLatent:
|
|||||||
|
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
|
|
||||||
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||||
@ -538,7 +538,7 @@ class LoadLatent:
|
|||||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
|
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
|
||||||
return {"required": {"latent": [sorted(files), ]}, }
|
return {"required": {"latent": [sorted(files), ]}, }
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT", )
|
RETURN_TYPES = ("LATENT", )
|
||||||
FUNCTION = "load"
|
FUNCTION = "load"
|
||||||
@ -1443,7 +1443,7 @@ class LatentBlend:
|
|||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "blend"
|
FUNCTION = "blend"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
|
|
||||||
def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
|
def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
|
||||||
|
|
||||||
@ -2092,6 +2092,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"StyleModelLoader": "Load Style Model",
|
"StyleModelLoader": "Load Style Model",
|
||||||
"CLIPVisionLoader": "Load CLIP Vision",
|
"CLIPVisionLoader": "Load CLIP Vision",
|
||||||
"UNETLoader": "Load Diffusion Model",
|
"UNETLoader": "Load Diffusion Model",
|
||||||
|
"unCLIPCheckpointLoader": "Load unCLIP Checkpoint",
|
||||||
|
"GLIGENLoader": "Load GLIGEN Model",
|
||||||
# Conditioning
|
# Conditioning
|
||||||
"CLIPVisionEncode": "CLIP Vision Encode",
|
"CLIPVisionEncode": "CLIP Vision Encode",
|
||||||
"StyleModelApply": "Apply Style Model",
|
"StyleModelApply": "Apply Style Model",
|
||||||
@ -2140,7 +2142,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"ImageSharpen": "Sharpen Image",
|
"ImageSharpen": "Sharpen Image",
|
||||||
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
||||||
"GetImageSize": "Get Image Size",
|
"GetImageSize": "Get Image Size",
|
||||||
# _for_testing
|
# experimental
|
||||||
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
||||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||||
}
|
}
|
||||||
|
|||||||
90
tests-unit/app_test/node_replace_manager_test.py
Normal file
90
tests-unit/app_test/node_replace_manager_test.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
"""Tests for NodeReplaceManager registration behavior."""
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def NodeReplaceManager(monkeypatch):
    """Yield the NodeReplaceManager class imported against a stubbed `nodes` module.

    `app.node_replace_manager` performs `import nodes` at module level, which
    would drag in torch and the full ComfyUI graph. register() has no need for
    any of that, so for the duration of each test a placeholder `nodes` module
    is installed (through monkeypatch, so it is undone automatically) and
    `app.node_replace_manager` is imported fresh so that it binds to the
    placeholder rather than to a previously cached real import.
    """
    target_module = "app.node_replace_manager"

    # Placeholder exposing only an empty NODE_CLASS_MAPPINGS attribute.
    stub_nodes = types.ModuleType("nodes")
    stub_nodes.NODE_CLASS_MAPPINGS = {}
    monkeypatch.setitem(sys.modules, "nodes", stub_nodes)

    # Evict any cached copy so the import below re-executes the module body.
    monkeypatch.delitem(sys.modules, target_module, raising=False)
    fresh_module = importlib.import_module(target_module)

    yield fresh_module.NodeReplaceManager

    # Discard the stub-bound module so the next test (or a later real import
    # of `nodes`) begins from a clean slate.
    sys.modules.pop(target_module, None)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeNodeReplace:
    """Lightweight stand-in for comfy_api.latest._io.NodeReplace.

    Stores the five constructor arguments verbatim as same-named attributes;
    no validation or copying is performed.
    """

    def __init__(self, new_node_id, old_node_id, old_widget_ids=None,
                 input_mapping=None, output_mapping=None):
        self.new_node_id = new_node_id
        self.old_node_id = old_node_id
        self.old_widget_ids = old_widget_ids
        self.input_mapping = input_mapping
        self.output_mapping = output_mapping

    def __repr__(self):
        # A readable repr makes failed assertions on registered replacements
        # show their contents instead of an opaque object address.
        return (
            f"{type(self).__name__}("
            f"new_node_id={self.new_node_id!r}, "
            f"old_node_id={self.old_node_id!r}, "
            f"old_widget_ids={self.old_widget_ids!r}, "
            f"input_mapping={self.input_mapping!r}, "
            f"output_mapping={self.output_mapping!r})"
        )
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_adds_replacement(NodeReplaceManager):
    """One register() call makes the old node id report exactly one replacement."""
    manager = NodeReplaceManager()
    replacement = FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode")
    manager.register(replacement)

    assert manager.has_replacement("OldNode")
    registered = manager.get_replacement("OldNode")
    assert len(registered) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_allows_multiple_alternatives_for_same_old_node(NodeReplaceManager):
    """Distinct new_node_ids registered for one old_node_id must all be retained."""
    manager = NodeReplaceManager()
    for alternative in ("AltA", "AltB"):
        manager.register(FakeNodeReplace(new_node_id=alternative, old_node_id="OldNode"))

    found = manager.get_replacement("OldNode")
    assert len(found) == 2
    new_ids = {entry.new_node_id for entry in found}
    assert new_ids == {"AltA", "AltB"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_is_idempotent_for_duplicate_pair(NodeReplaceManager):
    """Registering the same (old_node_id, new_node_id) pair repeatedly keeps one entry."""
    manager = NodeReplaceManager()
    # Three separate-but-equal registrations of the same pair.
    for _ in range(3):
        manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))

    stored = manager.get_replacement("OldNode")
    assert len(stored) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_idempotent_preserves_first_registration(NodeReplaceManager):
    """The earliest registration is kept; a later duplicate pair with a different mapping is dropped."""
    manager = NodeReplaceManager()

    # Same (old_node_id, new_node_id) pair, differing only in input_mapping.
    shared_ids = {"new_node_id": "NewNode", "old_node_id": "OldNode"}
    first = FakeNodeReplace(
        input_mapping=[{"new_id": "a", "old_id": "x"}],
        **shared_ids,
    )
    second = FakeNodeReplace(
        input_mapping=[{"new_id": "b", "old_id": "y"}],
        **shared_ids,
    )

    for candidate in (first, second):
        manager.register(candidate)

    kept = manager.get_replacement("OldNode")
    assert len(kept) == 1
    assert kept[0] is first
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_dedupe_does_not_affect_other_old_nodes(NodeReplaceManager):
    """De-duplicating one old node id leaves registrations for other ids intact."""
    manager = NodeReplaceManager()
    # (new_node_id, old_node_id) pairs; the first pair is deliberately repeated.
    registrations = (
        ("NewA", "OldA"),
        ("NewA", "OldA"),
        ("NewB", "OldB"),
    )
    for new_id, old_id in registrations:
        manager.register(FakeNodeReplace(new_node_id=new_id, old_node_id=old_id))

    for old_id in ("OldA", "OldB"):
        assert len(manager.get_replacement(old_id)) == 1
|
||||||
@ -21,7 +21,7 @@ class TestAsyncProgressUpdate(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "execute"
|
FUNCTION = "execute"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
async def execute(self, value, sleep_seconds):
|
async def execute(self, value, sleep_seconds):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@ -51,7 +51,7 @@ class TestSyncProgressUpdate(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "execute"
|
FUNCTION = "execute"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
def execute(self, value, sleep_seconds):
|
def execute(self, value, sleep_seconds):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|||||||
@ -21,7 +21,7 @@ class TestAsyncValidation(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def VALIDATE_INPUTS(cls, value, threshold):
|
async def VALIDATE_INPUTS(cls, value, threshold):
|
||||||
@ -53,7 +53,7 @@ class TestAsyncError(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "error_execution"
|
FUNCTION = "error_execution"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
async def error_execution(self, value, error_after):
|
async def error_execution(self, value, error_after):
|
||||||
await asyncio.sleep(error_after)
|
await asyncio.sleep(error_after)
|
||||||
@ -74,7 +74,7 @@ class TestAsyncValidationError(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def VALIDATE_INPUTS(cls, value, max_value):
|
async def VALIDATE_INPUTS(cls, value, max_value):
|
||||||
@ -105,7 +105,7 @@ class TestAsyncTimeout(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "timeout_execution"
|
FUNCTION = "timeout_execution"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
async def timeout_execution(self, value, timeout, operation_time):
|
async def timeout_execution(self, value, timeout, operation_time):
|
||||||
try:
|
try:
|
||||||
@ -129,7 +129,7 @@ class TestSyncError(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "sync_error"
|
FUNCTION = "sync_error"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
def sync_error(self, value):
|
def sync_error(self, value):
|
||||||
raise RuntimeError("Intentional sync execution error for testing")
|
raise RuntimeError("Intentional sync execution error for testing")
|
||||||
@ -150,7 +150,7 @@ class TestAsyncLazyCheck(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
async def check_lazy_status(self, condition, input1, input2):
|
async def check_lazy_status(self, condition, input1, input2):
|
||||||
# Simulate async checking (e.g., querying remote service)
|
# Simulate async checking (e.g., querying remote service)
|
||||||
@ -184,7 +184,7 @@ class TestDynamicAsyncGeneration(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "generate_async_workflow"
|
FUNCTION = "generate_async_workflow"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
|
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
|
||||||
g = GraphBuilder()
|
g = GraphBuilder()
|
||||||
@ -229,7 +229,7 @@ class TestAsyncResourceUser(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "use_resource"
|
FUNCTION = "use_resource"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
async def use_resource(self, value, resource_id, duration):
|
async def use_resource(self, value, resource_id, duration):
|
||||||
# Check if resource is already in use
|
# Check if resource is already in use
|
||||||
@ -265,7 +265,7 @@ class TestAsyncBatchProcessing(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "process_batch"
|
FUNCTION = "process_batch"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
async def process_batch(self, images, process_time_per_item, unique_id):
|
async def process_batch(self, images, process_time_per_item, unique_id):
|
||||||
batch_size = images.shape[0]
|
batch_size = images.shape[0]
|
||||||
@ -305,7 +305,7 @@ class TestAsyncConcurrentLimit(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "limited_execution"
|
FUNCTION = "limited_execution"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
async def limited_execution(self, value, duration, node_id):
|
async def limited_execution(self, value, duration, node_id):
|
||||||
async with self._semaphore:
|
async with self._semaphore:
|
||||||
|
|||||||
@ -409,7 +409,7 @@ class TestSleep(ComfyNodeABC):
|
|||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "sleep"
|
FUNCTION = "sleep"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
|
|
||||||
async def sleep(self, value, seconds, unique_id):
|
async def sleep(self, value, seconds, unique_id):
|
||||||
pbar = ProgressBar(seconds, node_id=unique_id)
|
pbar = ProgressBar(seconds, node_id=unique_id)
|
||||||
@ -440,7 +440,7 @@ class TestParallelSleep(ComfyNodeABC):
|
|||||||
}
|
}
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "parallel_sleep"
|
FUNCTION = "parallel_sleep"
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
|
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
|
||||||
@ -474,7 +474,7 @@ class TestOutputNodeWithSocketOutput:
|
|||||||
}
|
}
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
def process(self, image, value):
|
def process(self, image, value):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user