Compare commits

...

19 Commits

Author SHA1 Message Date
Matt Miller
86e30d0927
Merge 9557a41c83 into 65045730a6 2026-05-08 22:21:57 +00:00
Matt Miller
9557a41c83 Add idempotence tests for asset preview endpoints
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
PUT with the same preview_id twice should be a 200 no-op, and DELETE on
an asset with no preview set should still return 204. The existing test
file covered set/clear happy paths and the 404/400 error cases but did
not exercise idempotence, which is part of the documented contract for
both verbs.
2026-05-08 15:21:34 -07:00
Matt Miller
05939232b1 Fix Ruff lint failures
- Drop unused 'result' assignment in clear_asset_preview_route (DELETE
  returns 204, the result isn't needed).
- Remove unused 'pytest' import from test_preview.py.
2026-05-08 15:21:34 -07:00
Cursor Agent
9bd5710acb Add asset preview management endpoints (PUT/DELETE /api/assets/{id}/preview)
- PUT /api/assets/{id}/preview: sets the preview asset for an existing asset,
  returns the full updated Asset object. Returns 404 if asset ID doesn't exist.
- DELETE /api/assets/{id}/preview: clears the preview asset link, returns 204.
  Returns 404 if asset ID doesn't exist.

Added OpenAPI spec entries under the assets tag with no x-runtime restriction.
Implemented handlers using the existing set_asset_preview service function.
Added integration tests covering happy paths and 404 cases.

Co-authored-by: Matt Miller <MillerMedia@users.noreply.github.com>
2026-05-08 15:21:34 -07:00
Alexander Piskun
65045730a6
[Partner Nodes] additionally use Baidu server to detect the accessibility of internet (#13803)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-08 13:11:52 -07:00
Matt Miller
87878f354f
Add cloud-runtime FE-facing operations to spec (#13734)
* Add cloud-runtime FE-facing operations to openapi.yaml

Add ~67 cloud-runtime FE-facing path operations to the core OpenAPI spec,
each tagged with x-runtime: [cloud] at the operation level. These operations
are served by the cloud runtime; the local runtime returns 404 for all of
these paths.

Domain groups added:
- Jobs / prompts: /api/job/*, /api/jobs/*/cancel, /api/prompt/*, etc.
- History v2: /api/history_v2, /api/history_v2/{prompt_id}
- Cloud logs: /api/logs
- Asset extensions: /api/assets/download, export, import, etc.
- Custom nodes: /api/experiment/nodes (cloud install/uninstall)
- Hub: /api/hub/profiles, /api/hub/workflows, /api/hub/labels, etc.
- Workflows: /api/workflows CRUD, versioning, fork, publish
- Auth/session: /api/auth/session, /api/auth/token, /.well-known/jwks.json
- Billing: /api/billing/balance, plans, subscribe, topup, etc.
- Workspace: /api/workspace/*, /api/workspaces/*
- User/settings/misc: /api/user, /api/secrets, /api/feedback, etc.

Also adds corresponding cloud-only component schemas (CloudJob, CloudWorkflow,
BillingPlan, Workspace, HubProfile, AuthSession, etc.), all tagged with
x-runtime: [cloud].

Spectral lint passes under the existing ruleset with zero new warnings.

* Add job_id field to Asset schema and deprecate prompt_id (#13736)

- Add job_id as a nullable UUID field to the Asset schema
- Mark prompt_id as deprecated with note pointing to job_id
- No x-runtime tag needed as both runtimes populate the field

* Add hash field to Asset schemas and deprecate asset_hash (#13738)

- Add 'hash' as a nullable string field to Asset and AssetUpdated schemas
- Mark 'asset_hash' as deprecated with a note pointing to 'hash'
- AssetCreated inherits 'hash' via allOf from Asset
- Spectral lint clean (no new warnings)

* Fix method drift on cloud-runtime endpoints

Three PUT operations were added that should be PATCH (cloud serves
PATCH for partial updates):

- /api/workflows/{workflow_id}
- /api/workspaces/{id}
- /api/workspace/members/{userId}

Two POST operations were added that should be GET (cloud serves GET
with query params):

- /api/assets/remote-metadata (url moves to query param)
- /api/files/mask-layers (response shape replaced — operation queries
  related mask layer filenames, not file uploads)

* Add missing cloud-runtime operations and schemas

PR review surfaced operations the cloud runtime serves that weren't
covered by the initial spec push, plus one path family missed entirely.

New methods on existing paths:

- /api/auth/session: add POST (create session cookie) and DELETE (logout)
- /api/secrets/{id}: add GET (read metadata) and PATCH (update)
- /api/hub/profiles: add POST (create profile)
- /api/hub/workflows: add POST (publish to hub)
- /api/hub/workflows/{share_id}: add DELETE (unpublish)
- /api/workspaces/{id}: add DELETE (soft-delete workspace)
- /api/workspace/members/{user_id}/api-keys: add DELETE (bulk revoke)
- /api/workflows/{workflow_id}/versions: add POST (create new version)
- /api/userdata/{file}/publish: add GET (read publish info)

New path family:

- /api/tasks (GET list) and /api/tasks/{task_id} (GET detail) for the
  background task framework

New component schemas (all tagged x-runtime: [cloud]):

CreateSessionResponse, DeleteSessionResponse, UpdateSecretRequest,
BulkRevokeAPIKeysResponse, CreateHubProfileRequest, PublishHubWorkflowRequest,
HubWorkflowDetail, AssetInfo, CreateWorkflowVersionRequest,
WorkflowVersionResponse, WorkflowPublishInfo, TaskEntry, TaskResponse,
TasksListResponse. Existing SecretMeta extended with provider and
last_used_at fields the cloud runtime actually returns.

New tag: task. Spectral lint passes with zero errors.

* Add job_id and prompt_id to AssetUpdated schema

Mirrors the Asset schema's deprecation pattern: prompt_id is marked
deprecated with a description pointing to job_id; job_id is the new
preferred field. PUT /api/assets/{id} responses can now carry both fields
consistent with the other Asset-returning endpoints.

* feat: add width and height fields to Asset schema (#13745)

Add nullable integer fields 'width' and 'height' to the Asset schema
in openapi.yaml. These expose original image dimensions in pixels for
clients that need pre-thumbnail size info. Both fields are null for
non-image assets or assets ingested before dimension extraction.

Co-authored-by: Matt Miller <MillerMedia@users.noreply.github.com>

* Remove /api/job/{job_id} and /api/job/{job_id}/outputs

These two paths are not actually served by the cloud runtime — they
return 404 with a redirect message pointing callers to the canonical
`/api/jobs/{job_id}` (plural). Declaring them with `x-runtime: [cloud]`
and a 200 response schema is incorrect.

`/api/job/{job_id}/status` stays — it is a real cloud-served endpoint.

Also drops the now-orphaned `CloudJob` and `CloudJobOutputs` component
schemas. `CloudJobStatus` is retained.
2026-05-08 12:39:16 -07:00
Alexis Rolland
c5ecd231a2
fix: Fix bug when mask not on same device (CORE-181) (#13801)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-05-08 23:06:29 +08:00
drozbay
9864f5ac86
fix: Stop LTXVImgToVideoInplace from mutating input latents and dropping noise_mask (#13793) 2026-05-08 23:02:17 +08:00
drozbay
05cd076bc1
fix: Make LTXVAddGuide center-crop guide images to match other LTXV nodes (#13794) 2026-05-08 22:48:59 +08:00
Yousef R. Gamaleldin
d3c18c1636
Add support for BiRefNet background remove model (CORE-46) (#12747) 2026-05-08 17:59:24 +08:00
omahs
bac6fc35fb
Fix typos (#10986) 2026-05-08 17:14:45 +08:00
Alexander Piskun
56c74094c7
[Partner Nodes] use "adaptive" aspect ratio for SD2 nodes (#13800)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-07 23:39:13 -07:00
Alexis Rolland
594de378fe
Update nodes categories and display names (CORE-89) (#13786) 2026-05-08 01:02:55 -04:00
Jedrzej Kosinski
c8673542f7
fix: make NodeReplaceManager.register() idempotent (#13596) 2026-05-07 19:21:12 -07:00
comfyanonymous
df7bf1d3dc
Update warning message for ComfyUI frontend installation. (#13796) 2026-05-07 19:04:30 -07:00
Talmaj
ef8f25601a
Add I2V for causal forcing model. (#13719) 2026-05-07 18:38:36 -07:00
Jukka Seppänen
8dc3f3f209
Improve SAM3 large input handling (#13767) 2026-05-07 17:18:28 -07:00
Alexander Piskun
c011fb520c
[Partner Nodes] new NanoBanana2 node with DynamicCombo/Autogrow (#13753)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
* feat(api-nodes): new NanoBanana2 node with  DynamicCombo/Autogrow

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* feat: improved status text on uploading

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* feat: improved status text on uploading (2)

Signed-off-by: bigcat88 <bigcat88@icloud.com>

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-07 12:19:44 -07:00
Alexander Piskun
c945a433ae
fix(api-nodes): fixed price badge for Kling V3 model in the Motion Control node (#13790)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-07 11:55:09 -07:00
60 changed files with 6468 additions and 146 deletions

View File

@ -36,6 +36,7 @@ from app.assets.services import (
list_tags,
remove_tags,
resolve_asset_for_download,
set_asset_preview,
update_asset_metadata,
upload_from_temp_path,
)
@ -545,6 +546,70 @@ async def delete_asset_route(request: web.Request) -> web.Response:
return web.Response(status=204)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}/preview")
@_require_assets_feature_enabled
async def set_asset_preview_route(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"]))
try:
payload = await request.json()
except Exception:
return _build_error_response(
400, "INVALID_JSON", "Request body must be valid JSON."
)
preview_id = payload.get("preview_id")
if not preview_id:
return _build_error_response(
400, "INVALID_BODY", "preview_id is required."
)
try:
result = set_asset_preview(
reference_id=reference_id,
preview_reference_id=preview_id,
owner_id=USER_MANAGER.get_request_user_id(request),
)
response_payload = _build_asset_response(result)
except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
except ValueError as ve:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id}
)
except Exception:
logging.exception(
"set_asset_preview failed for reference_id=%s",
reference_id,
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(response_payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/preview")
@_require_assets_feature_enabled
async def clear_asset_preview_route(request: web.Request) -> web.Response:
reference_id = str(uuid.UUID(request.match_info["id"]))
try:
set_asset_preview(
reference_id=reference_id,
preview_reference_id=None,
owner_id=USER_MANAGER.get_request_user_id(request),
)
except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
except ValueError as ve:
return _build_error_response(
404, "ASSET_NOT_FOUND", str(ve), {"id": reference_id}
)
except Exception:
logging.exception(
"clear_asset_preview failed for reference_id=%s",
reference_id,
)
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.Response(status=204)
@ROUTES.get("/api/tags")
@_require_assets_feature_enabled
async def get_tags(request: web.Request) -> web.Response:

View File

@ -27,7 +27,7 @@ def frontend_install_warning_message():
return f"""
{get_missing_requirements_message()}
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
The ComfyUI frontend is shipped in a pip package so it needs to be updated separately from the ComfyUI code.
""".strip()
def parse_version(version: str) -> tuple[int, int, int]:

View File

@ -1,5 +1,7 @@
from __future__ import annotations
import logging
from aiohttp import web
from typing import TYPE_CHECKING, TypedDict
@ -31,8 +33,22 @@ class NodeReplaceManager:
self._replacements: dict[str, list[NodeReplace]] = {}
def register(self, node_replace: NodeReplace):
"""Register a node replacement mapping."""
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
"""Register a node replacement mapping.
Idempotent: if a replacement with the same (old_node_id, new_node_id)
is already registered, the duplicate is ignored. This prevents stale
entries from accumulating when custom nodes are reloaded in the same
process (e.g. via ComfyUI-Manager).
"""
existing = self._replacements.setdefault(node_replace.old_node_id, [])
for entry in existing:
if entry.new_node_id == node_replace.new_node_id:
logging.debug(
"Node replacement %s -> %s already registered, ignoring duplicate.",
node_replace.old_node_id, node_replace.new_node_id,
)
return
existing.append(node_replace)
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
"""Get replacements for an old node ID."""

View File

@ -0,0 +1,7 @@
{
"model_type": "birefnet",
"image_std": [1.0, 1.0, 1.0],
"image_mean": [0.0, 0.0, 0.0],
"image_size": 1024,
"resize_to_original": true
}

View File

@ -0,0 +1,689 @@
import torch
import comfy.ops
import numpy as np
import torch.nn as nn
from functools import partial
import torch.nn.functional as F
from torchvision.ops import deform_conv2d
from comfy.ldm.modules.attention import optimized_attention_for_device
CXT = [3072, 1536, 768, 384][1:][::-1][-3:]
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
self.kv = operations.Linear(dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
def forward(self, x):
B, N, C = x.shape
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
x = optimized_attention(
q, k, v, heads=self.num_heads, skip_output_reshape=True, skip_reshape=True
).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, device=None, dtype=None, operations=None):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = operations.Linear(in_features, hidden_features, device=device, dtype=dtype)
self.act = nn.GELU()
self.fc2 = operations.Linear(hidden_features, out_features, device=device, dtype=dtype)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads, device=device, dtype=dtype))
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
return x
class SwinTransformerBlock(nn.Module):
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None,
norm_layer=nn.LayerNorm, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.norm1 = norm_layer(dim, device=device, dtype=dtype)
self.attn = WindowAttention(
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, device=device, dtype=dtype, operations=operations)
self.norm2 = norm_layer(dim, device=device, dtype=dtype)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, device=device, dtype=dtype, operations=operations)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
B, L, C = x.shape
H, W = self.H, self.W
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
x_windows = window_partition(shifted_x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
attn_windows = self.attn(x_windows, mask=attn_mask)
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class PatchMerging(nn.Module):
def __init__(self, dim, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.reduction = operations.Linear(4 * dim, 2 * dim, bias=False, device=device, dtype=dtype)
self.norm = operations.LayerNorm(4 * dim, device=device, dtype=dtype)
def forward(self, x, H, W):
B, L, C = x.shape
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class BasicLayer(nn.Module):
def __init__(self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
norm_layer=nn.LayerNorm,
downsample=None,
device=None, dtype=None, operations=None):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
norm_layer=norm_layer,
device=device, dtype=dtype, operations=operations)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, device=device, dtype=dtype, operations=operations)
else:
self.downsample = None
def forward(self, x, H, W):
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
blk.H, blk.W = H, W
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None, device=None, dtype=None, operations=None):
super().__init__()
patch_size = (patch_size, patch_size)
self.patch_size = patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype)
if norm_layer is not None:
self.norm = norm_layer(embed_dim, device=device, dtype=dtype)
else:
self.norm = None
def forward(self, x):
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
class SwinTransformer(nn.Module):
def __init__(self,
pretrain_img_size=224,
patch_size=4,
in_channels=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
device=None, dtype=None, operations=None):
super().__init__()
norm_layer = partial(operations.LayerNorm, device=device, dtype=dtype)
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.patch_embed = PatchEmbed(
patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
device=device, dtype=dtype, operations=operations,
norm_layer=norm_layer if self.patch_norm else None)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
device=device, dtype=dtype, operations=operations)
self.layers.append(layer)
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
def forward(self, x):
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
outs = []
x = x.flatten(2).transpose(1, 2)
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return tuple(outs)
class DeformableConv2d(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False, device=None, dtype=None, operations=None):
super(DeformableConv2d, self).__init__()
kernel_size = kernel_size if type(kernel_size) is tuple else (kernel_size, kernel_size)
self.stride = stride if type(stride) is tuple else (stride, stride)
self.padding = padding
self.offset_conv = operations.Conv2d(in_channels,
2 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True, device=device, dtype=dtype)
self.modulator_conv = operations.Conv2d(in_channels,
1 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True, device=device, dtype=dtype)
self.regular_conv = operations.Conv2d(in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=bias, device=device, dtype=dtype)
def forward(self, x):
offset = self.offset_conv(x)
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
x = deform_conv2d(
input=x,
offset=offset,
weight=weight,
bias=None,
padding=self.padding,
mask=modulator,
stride=self.stride,
)
comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
return x
class BasicDecBlk(nn.Module):
def __init__(self, in_channels=64, out_channels=64, inter_channels=64, device=None, dtype=None, operations=None):
super(BasicDecBlk, self).__init__()
inter_channels = 64
self.conv_in = operations.Conv2d(in_channels, inter_channels, 3, 1, padding=1, device=device, dtype=dtype)
self.relu_in = nn.ReLU(inplace=True)
self.dec_att = ASPPDeformable(in_channels=inter_channels, device=device, dtype=dtype, operations=operations)
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, padding=1, device=device, dtype=dtype)
self.bn_in = operations.BatchNorm2d(inter_channels, device=device, dtype=dtype)
self.bn_out = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
def forward(self, x):
x = self.conv_in(x)
x = self.bn_in(x)
x = self.relu_in(x)
x = self.dec_att(x)
x = self.conv_out(x)
x = self.bn_out(x)
return x
class BasicLatBlk(nn.Module):
def __init__(self, in_channels=64, out_channels=64, device=None, dtype=None, operations=None):
super(BasicLatBlk, self).__init__()
self.conv = operations.Conv2d(in_channels, out_channels, 1, 1, 0, device=device, dtype=dtype)
def forward(self, x):
x = self.conv(x)
return x
class _ASPPModuleDeformable(nn.Module):
def __init__(self, in_channels, planes, kernel_size, padding, device, dtype, operations):
super(_ASPPModuleDeformable, self).__init__()
self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
stride=1, padding=padding, bias=False, device=device, dtype=dtype, operations=operations)
self.bn = operations.BatchNorm2d(planes, device=device, dtype=dtype)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.atrous_conv(x)
x = self.bn(x)
return self.relu(x)
class ASPPDeformable(nn.Module):
def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7], device=None, dtype=None, operations=None):
super(ASPPDeformable, self).__init__()
self.down_scale = 1
if out_channels is None:
out_channels = in_channels
self.in_channelster = 256 // self.down_scale
self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0, device=device, dtype=dtype, operations=operations)
self.aspp_deforms = nn.ModuleList([
_ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2), device=device, dtype=dtype, operations=operations)
for conv_size in parallel_block_sizes
])
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
operations.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False, device=device, dtype=dtype),
operations.BatchNorm2d(self.in_channelster, device=device, dtype=dtype),
nn.ReLU(inplace=True))
self.conv1 = operations.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False, device=device, dtype=dtype)
self.bn1 = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x1 = self.aspp1(x)
x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
x5 = self.global_avg_pool(x)
x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
return x
class BiRefNet(nn.Module):
def __init__(self, config=None, dtype=None, device=None, operations=None):
super(BiRefNet, self).__init__()
self.bb = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12, device=device, dtype=dtype, operations=operations)
channels = [1536, 768, 384, 192]
channels = [c * 2 for c in channels]
self.cxt = channels[1:][::-1][-3:]
self.squeeze_module = nn.Sequential(*[
BasicDecBlk(channels[0]+sum(self.cxt), channels[0], device=device, dtype=dtype, operations=operations)
for _ in range(1)
])
self.decoder = Decoder(channels, device=device, dtype=dtype, operations=operations)
def forward_enc(self, x):
x1, x2, x3, x4 = self.bb(x)
B, C, H, W = x.shape
x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x4 = torch.cat(
(
*[
F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
][-len(CXT):],
x4
),
dim=1
)
return (x1, x2, x3, x4)
def forward_ori(self, x):
(x1, x2, x3, x4) = self.forward_enc(x)
x4 = self.squeeze_module(x4)
features = [x, x1, x2, x3, x4]
scaled_preds = self.decoder(features)
return scaled_preds
def forward(self, pixel_values, intermediate_output=None):
scaled_preds = self.forward_ori(pixel_values)
return scaled_preds
class Decoder(nn.Module):
def __init__(self, channels, device, dtype, operations):
super(Decoder, self).__init__()
# factory kwargs
fk = {"device":device, "dtype":dtype, "operations":operations}
DecoderBlock = partial(BasicDecBlk, **fk)
LateralBlock = partial(BasicLatBlk, **fk)
DBlock = partial(SimpleConvs, **fk)
self.split = True
N_dec_ipt = 64
ic = 64
ipt_cha_opt = 1
self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[1])
self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[2])
self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt]), channels[3])
self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt]), channels[3]//2)
fk = {"device":device, "dtype":dtype}
self.conv_out1 = nn.Sequential(operations.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt]), 1, 1, 1, 0, **fk))
self.lateral_block4 = LateralBlock(channels[1], channels[1])
self.lateral_block3 = LateralBlock(channels[2], channels[2])
self.lateral_block2 = LateralBlock(channels[3], channels[3])
self.conv_ms_spvn_4 = operations.Conv2d(channels[1], 1, 1, 1, 0, **fk)
self.conv_ms_spvn_3 = operations.Conv2d(channels[2], 1, 1, 1, 0, **fk)
self.conv_ms_spvn_2 = operations.Conv2d(channels[3], 1, 1, 1, 0, **fk)
_N = 16
self.gdt_convs_4 = nn.Sequential(operations.Conv2d(channels[0] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
self.gdt_convs_3 = nn.Sequential(operations.Conv2d(channels[1] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
self.gdt_convs_2 = nn.Sequential(operations.Conv2d(channels[2] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
[setattr(self, f"gdt_convs_pred_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
[setattr(self, f"gdt_convs_attn_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
def get_patches_batch(self, x, p):
_size_h, _size_w = p.shape[2:]
patches_batch = []
for idx in range(x.shape[0]):
columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
patches_x = []
for column_x in columns_x:
patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
patch_sample = torch.cat(patches_x, dim=1)
patches_batch.append(patch_sample)
return torch.cat(patches_batch, dim=0)
def forward(self, features):
x, x1, x2, x3, x4 = features
patches_batch = self.get_patches_batch(x, x4) if self.split else x
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
p4 = self.decoder_block4(x4)
p4_gdt = self.gdt_convs_4(p4)
gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
p4 = p4 * gdt_attn_4
_p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
_p3 = _p4 + self.lateral_block4(x3)
patches_batch = self.get_patches_batch(x, _p3) if self.split else x
_p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
p3 = self.decoder_block3(_p3)
p3_gdt = self.gdt_convs_3(p3)
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
p3 = p3 * gdt_attn_3
_p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
_p2 = _p3 + self.lateral_block3(x2)
patches_batch = self.get_patches_batch(x, _p2) if self.split else x
_p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
p2 = self.decoder_block2(_p2)
p2_gdt = self.gdt_convs_2(p2)
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
p2 = p2 * gdt_attn_2
_p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
_p1 = _p2 + self.lateral_block2(x1)
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
_p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
_p1 = self.decoder_block1(_p1)
_p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
p1_out = self.conv_out1(_p1)
return p1_out
class SimpleConvs(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, inter_channels=64, device=None, dtype=None, operations=None
) -> None:
super().__init__()
self.conv1 = operations.Conv2d(in_channels, inter_channels, 3, 1, 1, device=device, dtype=dtype)
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, 1, device=device, dtype=dtype)
def forward(self, x):
return self.conv_out(self.conv1(x))

78
comfy/bg_removal_model.py Normal file
View File

@ -0,0 +1,78 @@
from .utils import load_torch_file
import os
import json
import torch
import logging
import comfy.ops
import comfy.model_patcher
import comfy.model_management
import comfy.clip_model
import comfy.background_removal.birefnet
BG_REMOVAL_MODELS = {
"birefnet": comfy.background_removal.birefnet.BiRefNet
}
class BackgroundRemovalModel():
def __init__(self, json_config):
with open(json_config) as f:
config = json.load(f)
self.image_size = config.get("image_size", 1024)
self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0])
self.image_std = config.get("image_std", [1.0, 1.0, 1.0])
self.model_type = config.get("model_type", "birefnet")
self.config = config.copy()
model_class = BG_REMOVAL_MODELS.get(self.model_type)
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image):
comfy.model_management.load_model_gpu(self.patcher)
H, W = image.shape[1], image.shape[2]
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
out = self.model(pixel_values=pixel_values)
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
if mask.ndim == 3:
mask = mask.unsqueeze(0)
if mask.shape[1] != 1:
mask = mask.movedim(-1, 1)
return mask
def load_background_removal_model(sd):
if "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal"), "birefnet.json")
else:
return None
bg_model = BackgroundRemovalModel(json_config)
m, u = bg_model.load_sd(sd)
if len(m) > 0:
logging.warning("missing background removal: {}".format(m))
u = set(u)
keys = list(sd.keys())
for k in keys:
if k not in u:
sd.pop(k)
return bg_model
def load(ckpt_path):
sd = load_torch_file(ckpt_path)
return load_background_removal_model(sd)

View File

@ -93,7 +93,7 @@ class Hook:
self.hook_scope = hook_scope
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
self.custom_should_register = default_should_register
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
'''Can be overridden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
@property
def strength(self):

View File

@ -1859,6 +1859,23 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
output = torch.zeros_like(x)
s_in = x.new_ones([x.shape[0]])
current_start_frame = 0
# I2V: seed KV cache with the initial image latent before the denoising loop
initial_latent = transformer_options.get("ar_config", {}).get("initial_latent", None)
if initial_latent is not None:
initial_latent = inner_model.process_latent_in(initial_latent).to(device=device, dtype=model_dtype)
n_init = initial_latent.shape[2]
output[:, :, :n_init] = initial_latent
ar_state = {"start_frame": 0, "kv_caches": kv_caches, "crossattn_caches": crossattn_caches}
transformer_options["ar_state"] = ar_state
zero_sigma = sigmas.new_zeros([1])
_ = model(initial_latent, zero_sigma * s_in, **extra_args)
current_start_frame = n_init
remaining = lat_t - n_init
num_blocks = -(-remaining // num_frame_per_block)
num_sigma_steps = len(sigmas) - 1
total_real_steps = num_blocks * num_sigma_steps
step_count = 0

View File

@ -140,7 +140,7 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
# according the the formula provided in https://arxiv.org/abs/2010.02502
# according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')

View File

@ -561,7 +561,8 @@ class SAM3Model(nn.Module):
return high_res_masks
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
new_det_thresh=0.5, max_objects=0, detect_interval=1):
new_det_thresh=0.5, max_objects=0, detect_interval=1,
target_device=None, target_dtype=None):
"""Track video with optional per-frame text-prompted detection."""
bb = self.detector.backbone["vision_backbone"]
@ -589,8 +590,10 @@ class SAM3Model(nn.Module):
return self.tracker.track_video_with_detection(
backbone_fn, images, initial_masks, detect_fn,
new_det_thresh=new_det_thresh, max_objects=max_objects,
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar,
target_device=target_device, target_dtype=target_dtype)
# SAM3 (non-multiplex) — no detection support, requires initial masks
if initial_masks is None:
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb,
target_device=target_device, target_dtype=target_dtype)

View File

@ -200,8 +200,13 @@ def pack_masks(masks):
def unpack_masks(packed):
"""Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8]."""
shifts = torch.arange(8, device=packed.device)
return ((packed.unsqueeze(-1) >> shifts) & 1).view(*packed.shape[:-1], -1).bool()
bits = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device)
return (packed.unsqueeze(-1) & bits).bool().view(*packed.shape[:-1], -1)
def _prep_frame(images, idx, device, dt, size):
"""Slice CPU full-res frames, transfer to GPU in target dtype, and resize to (size, size)."""
return comfy.utils.common_upscale(images[idx].to(device=device, dtype=dt), size, size, "bicubic", crop="disabled")
def _compute_backbone(backbone_fn, frame, frame_idx=None):
@ -1078,16 +1083,19 @@ class SAM3Tracker(nn.Module):
# SAM3: drop last FPN level
return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1]
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None):
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None,
target_device=None, target_dtype=None):
"""Track one object, computing backbone per frame to save VRAM."""
N = images.shape[0]
device, dt = images.device, images.dtype
device = target_device if target_device is not None else images.device
dt = target_dtype if target_dtype is not None else images.dtype
size = self.image_size
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
all_masks = []
for frame_idx in tqdm(range(N), desc="tracking"):
vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame(
backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx)
backbone_fn, _prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), frame_idx=frame_idx)
mask_input = None
if frame_idx == 0:
mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt),
@ -1114,12 +1122,13 @@ class SAM3Tracker(nn.Module):
return torch.cat(all_masks, dim=0) # [N, 1, H, W]
def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs):
def track_video(self, backbone_fn, images, initial_masks, pbar=None,
target_device=None, target_dtype=None, **kwargs):
"""Track one or more objects across video frames.
Args:
backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame
images: [N, 3, 1008, 1008] video frames
images: [N, 3, H, W] CPU full-res video frames (resized per-frame to self.image_size)
initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object)
pbar: optional progress bar
@ -1130,7 +1139,8 @@ class SAM3Tracker(nn.Module):
per_object = []
for obj_idx in range(N_obj):
obj_masks = self._track_single_object(
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar)
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar,
target_device=target_device, target_dtype=target_dtype)
per_object.append(obj_masks)
return torch.cat(per_object, dim=1) # [N, N_obj, H, W]
@ -1632,11 +1642,18 @@ class SAM31Tracker(nn.Module):
return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item()
return []
INTERNAL_MAX_OBJECTS = 64 # Hard ceiling on accumulated tracks; max_objects=0 or any value above this is clamped here.
def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None,
new_det_thresh=0.5, max_objects=0, detect_interval=1,
backbone_obj=None, pbar=None):
backbone_obj=None, pbar=None, target_device=None, target_dtype=None):
"""Track with optional per-frame detection. Returns [N, max_N_obj, H, W] mask logits."""
N, device, dt = images.shape[0], images.device, images.dtype
if max_objects <= 0 or max_objects > self.INTERNAL_MAX_OBJECTS:
max_objects = self.INTERNAL_MAX_OBJECTS
N = images.shape[0]
device = target_device if target_device is not None else images.device
dt = target_dtype if target_dtype is not None else images.dtype
size = self.image_size
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
all_masks = []
idev = comfy.model_management.intermediate_device()
@ -1656,7 +1673,7 @@ class SAM31Tracker(nn.Module):
prefetch = True
except RuntimeError:
pass
cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0)
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(0, 1), device, dt, size), frame_idx=0)
for frame_idx in tqdm(range(N), desc="tracking"):
vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb
@ -1666,7 +1683,7 @@ class SAM31Tracker(nn.Module):
backbone_stream.wait_stream(torch.cuda.current_stream(device))
with torch.cuda.stream(backbone_stream):
next_bb = self._compute_backbone_frame(
backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
# Per-frame detection with NMS (skip if no detect_fn, or interval/max not met)
det_masks = torch.empty(0, device=device)
@ -1687,7 +1704,7 @@ class SAM31Tracker(nn.Module):
current_out = self._condition_with_masks(
initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos,
feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj,
images[frame_idx:frame_idx + 1], trunk_out)
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out)
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
obj_scores = [1.0] * mux_state.total_valid_entries
if keep_alive is not None:
@ -1702,7 +1719,7 @@ class SAM31Tracker(nn.Module):
current_out = self._condition_with_masks(
det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop,
output_dict, N, mux_state, backbone_obj,
images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0)
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out, threshold=0.0)
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
obj_scores = det_scores[:mux_state.total_valid_entries].tolist()
if keep_alive is not None:
@ -1718,7 +1735,7 @@ class SAM31Tracker(nn.Module):
torch.cuda.current_stream(device).wait_stream(backbone_stream)
cur_bb = next_bb
else:
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
continue
else:
N_obj = mux_state.total_valid_entries
@ -1768,7 +1785,7 @@ class SAM31Tracker(nn.Module):
torch.cuda.current_stream(device).wait_stream(backbone_stream)
cur_bb = next_bb
else:
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
if not all_masks or all(m is None for m in all_masks):
return {"packed_masks": None, "n_frames": N, "scores": []}

View File

@ -562,6 +562,25 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None
running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
@ -749,6 +768,9 @@ class manual_cast(disable_weight_init):
class Conv3d(disable_weight_init.Conv3d):
comfy_cast_weights = True
class BatchNorm2d(disable_weight_init.BatchNorm2d):
comfy_cast_weights = True
class GroupNorm(disable_weight_init.GroupNorm):
comfy_cast_weights = True

View File

@ -17,6 +17,7 @@ if TYPE_CHECKING:
from spandrel import ImageModelDescriptor
from comfy.clip_vision import ClipVisionModel
from comfy.clip_vision import Output as ClipVisionOutput_
from comfy.bg_removal_model import BackgroundRemovalModel
from comfy.controlnet import ControlNet
from comfy.hooks import HookGroup, HookKeyframeGroup
from comfy.model_patcher import ModelPatcher
@ -614,6 +615,11 @@ class Model(ComfyTypeIO):
if TYPE_CHECKING:
Type = ModelPatcher
@comfytype(io_type="BACKGROUND_REMOVAL")
class BackgroundRemoval(ComfyTypeIO):
if TYPE_CHECKING:
Type = BackgroundRemovalModel
@comfytype(io_type="CLIP_VISION")
class ClipVision(ComfyTypeIO):
if TYPE_CHECKING:
@ -2257,6 +2263,7 @@ __all__ = [
"ModelPatch",
"ClipVision",
"ClipVisionOutput",
"BackgroundRemoval",
"AudioEncoder",
"AudioEncoderOutput",
"StyleModel",

View File

@ -1271,7 +1271,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
)
def _seedance2_text_inputs(resolutions: list[str]):
def _seedance2_text_inputs(resolutions: list[str], default_ratio: str = "16:9"):
return [
IO.String.Input(
"prompt",
@ -1287,6 +1287,7 @@ def _seedance2_text_inputs(resolutions: list[str]):
IO.Combo.Input(
"ratio",
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
default=default_ratio,
tooltip="Aspect ratio of the output video.",
),
IO.Int.Input(
@ -1420,8 +1421,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
IO.DynamicCombo.Option(
"Seedance 2.0",
_seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Fast",
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),
@ -1588,9 +1595,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
def _seedance2_reference_inputs(resolutions: list[str]):
def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16:9"):
return [
*_seedance2_text_inputs(resolutions),
*_seedance2_text_inputs(resolutions, default_ratio=default_ratio),
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplateNames(
@ -1668,8 +1675,14 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])),
IO.DynamicCombo.Option(
"Seedance 2.0",
_seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Fast",
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),

View File

@ -83,13 +83,16 @@ class GeminiImageModel(str, Enum):
async def create_image_parts(
cls: type[IO.ComfyNode],
images: Input.Image,
images: Input.Image | list[Input.Image],
image_limit: int = 0,
) -> list[GeminiPart]:
image_parts: list[GeminiPart] = []
if image_limit < 0:
raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.")
total_images = get_number_of_images(images)
# Accept either a single (possibly-batched) tensor or a list of them; share URL budget across all.
images_list: list[Input.Image] = images if isinstance(images, list) else [images]
total_images = sum(get_number_of_images(img) for img in images_list)
if total_images <= 0:
raise ValueError("No images provided to create_image_parts; at least one image is required.")
@ -98,10 +101,18 @@ async def create_image_parts(
# Number of images we'll send as URLs (fileData)
num_url_images = min(effective_max, 10) # Vertex API max number of image links
upload_kwargs: dict = {"wait_label": "Uploading reference images"}
if effective_max > num_url_images:
# Split path (e.g. 11+ images): suppress per-image counter to avoid a confusing dual-fraction label.
upload_kwargs = {
"wait_label": f"Uploading reference images ({num_url_images}+)",
"show_batch_index": False,
}
reference_images_urls = await upload_images_to_comfyapi(
cls,
images,
images_list,
max_images=num_url_images,
**upload_kwargs,
)
for reference_image_url in reference_images_urls:
image_parts.append(
@ -112,15 +123,22 @@ async def create_image_parts(
)
)
)
for idx in range(num_url_images, effective_max):
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
data=tensor_to_base64_string(images[idx]),
if effective_max > num_url_images:
flat: list[torch.Tensor] = []
for tensor in images_list:
if len(tensor.shape) == 4:
flat.extend(tensor[i] for i in range(tensor.shape[0]))
else:
flat.append(tensor)
for idx in range(num_url_images, effective_max):
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
data=tensor_to_base64_string(flat[idx]),
)
)
)
)
return image_parts
@ -891,10 +909,6 @@ class GeminiNanoBanana2(IO.ComfyNode):
"9:16",
"16:9",
"21:9",
# "1:4",
# "4:1",
# "8:1",
# "1:8",
],
default="auto",
tooltip="If set to 'auto', matches your input image's aspect ratio; "
@ -902,12 +916,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
),
IO.Combo.Input(
"resolution",
options=[
# "512px",
"1K",
"2K",
"4K",
],
options=["1K", "2K", "4K"],
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
),
IO.Combo.Input(
@ -956,6 +965,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
],
is_api_node=True,
price_badge=GEMINI_IMAGE_2_PRICE_BADGE,
is_deprecated=True,
)
@classmethod
@ -1016,6 +1026,197 @@ class GeminiNanoBanana2(IO.ComfyNode):
)
def _nano_banana_2_v2_model_inputs():
return [
IO.Combo.Input(
"aspect_ratio",
options=[
"auto",
"1:1",
"2:3",
"3:2",
"3:4",
"4:3",
"4:5",
"5:4",
"9:16",
"16:9",
"21:9",
"1:4",
"4:1",
"8:1",
"1:8",
],
default="auto",
tooltip="If set to 'auto', matches your input image's aspect ratio; "
"if no image is provided, a 16:9 square is usually generated.",
),
IO.Combo.Input(
"resolution",
options=["1K", "2K", "4K"],
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
),
IO.Combo.Input(
"thinking_level",
options=["MINIMAL", "HIGH"],
),
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, 15)],
min=0,
),
tooltip="Optional reference image(s). Up to 14 images total.",
),
IO.Custom("GEMINI_INPUT_FILES").Input(
"files",
optional=True,
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
]
class GeminiNanoBanana2V2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GeminiNanoBanana2V2",
display_name="Nano Banana 2",
category="api node/image/Gemini",
description="Generate or edit images synchronously via Google Vertex API.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text prompt describing the image to generate or the edits to apply. "
"Include any constraints, styles, or details the model should follow.",
default="",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"Nano Banana 2 (Gemini 3.1 Flash Image)",
_nano_banana_2_v2_model_inputs(),
),
],
),
IO.Int.Input(
"seed",
default=42,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide "
"the same response for repeated requests. Deterministic output isn't guaranteed. "
"Also, changing the model or parameter settings, such as the temperature, "
"can cause variations in the response even when you use the same seed value. "
"By default, a random seed value is used.",
),
IO.Combo.Input(
"response_modalities",
options=["IMAGE", "IMAGE+TEXT"],
advanced=True,
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
advanced=True,
),
],
outputs=[
IO.Image.Output(),
IO.String.Output(),
IO.Image.Output(
display_name="thought_image",
tooltip="First image from the model's thinking process. "
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
expr="""
(
$r := $lookup(widgets, "model.resolution");
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
response_modalities: str,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
model_choice = model["model"]
if model_choice == "Nano Banana 2 (Gemini 3.1 Flash Image)":
model_id = "gemini-3.1-flash-image-preview"
else:
model_id = model_choice
images = model.get("images") or {}
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
if images:
image_tensors: list[Input.Image] = [t for t in images.values() if t is not None]
if image_tensors:
if sum(get_number_of_images(t) for t in image_tensors) > 14:
raise ValueError("The current maximum number of supported images is 14.")
parts.extend(await create_image_parts(cls, image_tensors))
files = model.get("files")
if files is not None:
parts.extend(files)
image_config = GeminiImageConfig(imageSize=model["resolution"])
if model["aspect_ratio"] != "auto":
image_config.aspectRatio = model["aspect_ratio"]
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model_id}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
],
generationConfig=GeminiImageGenerationConfig(
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(
await get_image_from_response(response),
get_text_from_response(response),
await get_image_from_response(response, thought=True),
)
class GeminiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -1024,6 +1225,7 @@ class GeminiExtension(ComfyExtension):
GeminiImage,
GeminiImage2,
GeminiNanoBanana2,
GeminiNanoBanana2V2,
GeminiInputFiles,
]

View File

@ -2787,11 +2787,15 @@ class MotionControl(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
depends_on=IO.PriceBadgeDepends(widgets=["mode", "model"]),
expr="""
(
$prices := {"std": 0.07, "pro": 0.112};
{"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}}
$prices := {
"kling-v3": {"std": 0.126, "pro": 0.168},
"kling-v2-6": {"std": 0.07, "pro": 0.112}
};
$modelPrices := $lookup($prices, widgets.model);
{"type":"usd","usd": $lookup($modelPrices, widgets.mode), "format":{"suffix":"/second"}}
)
""",
),

View File

@ -488,10 +488,30 @@ async def _diagnose_connectivity() -> dict[str, bool]:
"api_accessible": False,
}
timeout = aiohttp.ClientTimeout(total=5.0)
# Probe Google and Baidu in parallel: Google is blocked by the GFW in mainland China, so a Baidu probe is required
# to correctly detect that Chinese users with working internet do have working internet.
internet_probe_urls = ("https://www.google.com", "https://www.baidu.com")
async with aiohttp.ClientSession(timeout=timeout) as session:
with contextlib.suppress(ClientError, OSError):
async with session.get("https://www.google.com") as resp:
results["internet_accessible"] = resp.status < 500
async def _probe(url: str) -> bool:
try:
async with session.get(url) as resp:
return resp.status < 500
except (ClientError, OSError, asyncio.TimeoutError):
return False
probe_tasks = [asyncio.create_task(_probe(u)) for u in internet_probe_urls]
try:
for fut in asyncio.as_completed(probe_tasks):
if await fut:
results["internet_accessible"] = True
break
finally:
for t in probe_tasks:
if not t.done():
t.cancel()
await asyncio.gather(*probe_tasks, return_exceptions=True)
if not results["internet_accessible"]:
return results

View File

@ -92,7 +92,7 @@ class SamplerEulerCFGpp(io.ComfyNode):
return io.Schema(
node_id="SamplerEulerCFGpp",
display_name="SamplerEulerCFG++",
category="_for_testing", # "sampling/custom_sampling/samplers"
category="experimental", # "sampling/custom_sampling/samplers"
inputs=[
io.Combo.Input("version", options=["regular", "alternative"], advanced=True),
],

View File

@ -2,6 +2,7 @@
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
- ARVideoI2V: image-to-video conditioning for AR models (seeds KV cache with start image)
"""
import torch
@ -9,6 +10,7 @@ from typing_extensions import override
import comfy.model_management
import comfy.samplers
import comfy.utils
from comfy_api.latest import ComfyExtension, io
@ -71,12 +73,62 @@ class SamplerARVideo(io.ComfyNode):
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
class ARVideoI2V(io.ComfyNode):
"""Image-to-video setup for AR video models (Causal Forcing, Self-Forcing).
VAE-encodes the start image and stores it in the model's transformer_options
so that sample_ar_video can seed the KV cache before denoising.
Uses the same T2V model checkpoint -- no separate I2V architecture needed.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ARVideoI2V",
category="conditioning/video_models",
inputs=[
io.Model.Input("model"),
io.Vae.Input("vae"),
io.Image.Input("start_image"),
io.Int.Input("width", default=832, min=16, max=8192, step=16),
io.Int.Input("height", default=480, min=16, max=8192, step=16),
io.Int.Input("length", default=81, min=1, max=1024, step=4),
io.Int.Input("batch_size", default=1, min=1, max=64),
],
outputs=[
io.Model.Output(display_name="MODEL"),
io.Latent.Output(display_name="LATENT"),
],
)
@classmethod
def execute(cls, model, vae, start_image, width, height, length, batch_size) -> io.NodeOutput:
start_image = comfy.utils.common_upscale(
start_image[:1].movedim(-1, 1), width, height, "bilinear", "center"
).movedim(1, -1)
initial_latent = vae.encode(start_image[:, :, :, :3])
m = model.clone()
to = m.model_options.setdefault("transformer_options", {})
ar_cfg = to.setdefault("ar_config", {})
ar_cfg["initial_latent"] = initial_latent
lat_t = ((length - 1) // 4) + 1
latent = torch.zeros(
[batch_size, 16, lat_t, height // 8, width // 8],
device=comfy.model_management.intermediate_device(),
)
return io.NodeOutput(m, {"samples": latent})
class ARVideoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyARVideoLatent,
SamplerARVideo,
ARVideoI2V,
]

View File

@ -25,7 +25,7 @@ class UNetSelfAttentionMultiply(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="UNetSelfAttentionMultiply",
category="_for_testing/attention_experiments",
category="experimental/attention_experiments",
inputs=[
io.Model.Input("model"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
@ -48,7 +48,7 @@ class UNetCrossAttentionMultiply(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="UNetCrossAttentionMultiply",
category="_for_testing/attention_experiments",
category="experimental/attention_experiments",
inputs=[
io.Model.Input("model"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
@ -72,7 +72,7 @@ class CLIPAttentionMultiply(io.ComfyNode):
return io.Schema(
node_id="CLIPAttentionMultiply",
search_aliases=["clip attention scale", "text encoder attention"],
category="_for_testing/attention_experiments",
category="experimental/attention_experiments",
inputs=[
io.Clip.Input("clip"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
@ -106,7 +106,7 @@ class UNetTemporalAttentionMultiply(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="UNetTemporalAttentionMultiply",
category="_for_testing/attention_experiments",
category="experimental/attention_experiments",
inputs=[
io.Model.Input("model"),
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),

View File

@ -10,6 +10,7 @@ class AudioEncoderLoader(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AudioEncoderLoader",
display_name="Load Audio Encoder",
category="loaders",
inputs=[
io.Combo.Input(

View File

@ -0,0 +1,60 @@
import folder_paths
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy.bg_removal_model import load
class LoadBackgroundRemovalModel(IO.ComfyNode):
@classmethod
def define_schema(cls):
files = folder_paths.get_filename_list("background_removal")
return IO.Schema(
node_id="LoadBackgroundRemovalModel",
display_name="Load Background Removal Model",
category="loaders",
inputs=[
IO.Combo.Input("bg_removal_name", options=sorted(files), tooltip="The model used to remove backgrounds from images"),
],
outputs=[
IO.BackgroundRemoval.Output("bg_model")
]
)
@classmethod
def execute(cls, bg_removal_name):
path = folder_paths.get_full_path_or_raise("background_removal", bg_removal_name)
bg = load(path)
if bg is None:
raise RuntimeError("ERROR: background model file is invalid and does not contain a valid background removal model.")
return IO.NodeOutput(bg)
class RemoveBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="RemoveBackground",
display_name="Remove Background",
category="image/background removal",
inputs=[
IO.Image.Input("image", tooltip="Input image to remove the background from"),
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask")
],
outputs=[
IO.Mask.Output("mask", tooltip="Generated foreground mask")
]
)
@classmethod
def execute(cls, image, bg_removal_model):
mask = bg_removal_model.encode_image(image)
return IO.NodeOutput(mask)
class BackgroundRemovalExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
LoadBackgroundRemovalModel,
RemoveBackground
]
async def comfy_entrypoint() -> BackgroundRemovalExtension:
return BackgroundRemovalExtension()

View File

@ -153,7 +153,7 @@ class WanCameraEmbedding(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="WanCameraEmbedding",
category="camera",
category="conditioning/video_models",
inputs=[
io.Combo.Input(
"camera_pose",

View File

@ -203,7 +203,7 @@ class JoinImageWithAlpha(io.ComfyNode):
@classmethod
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = max(len(image), len(alpha))
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
alpha = 1.0 - resize_mask(alpha.to(image), image.shape[1:])
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
image = comfy.utils.repeat_to_batch_size(image, batch_size)
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))

View File

@ -8,7 +8,7 @@ class CLIPTextEncodeControlnet(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CLIPTextEncodeControlnet",
category="_for_testing/conditioning",
category="experimental/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Conditioning.Input("conditioning"),
@ -35,7 +35,7 @@ class T5TokenizerOptions(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="T5TokenizerOptions",
category="_for_testing/conditioning",
category="experimental/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),

View File

@ -10,7 +10,7 @@ class ContextWindowsManualNode(io.ComfyNode):
return io.Schema(
node_id="ContextWindowsManual",
display_name="Context Windows (Manual)",
category="context",
category="model_patches",
description="Manually set context windows.",
inputs=[
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),

View File

@ -984,7 +984,7 @@ class AddNoise(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="AddNoise",
category="_for_testing/custom_sampling/noise",
category="experimental/custom_sampling/noise",
is_experimental=True,
inputs=[
io.Model.Input("model"),
@ -1034,7 +1034,7 @@ class ManualSigmas(io.ComfyNode):
return io.Schema(
node_id="ManualSigmas",
search_aliases=["custom noise schedule", "define sigmas"],
category="_for_testing/custom_sampling",
category="experimental/custom_sampling",
is_experimental=True,
inputs=[
io.String.Input("sigmas", default="1, 0.5", multiline=False)

View File

@ -13,7 +13,7 @@ class DifferentialDiffusion(io.ComfyNode):
node_id="DifferentialDiffusion",
search_aliases=["inpaint gradient", "variable denoise strength"],
display_name="Differential Diffusion",
category="_for_testing",
category="experimental",
inputs=[
io.Model.Input("model"),
io.Float.Input(

View File

@ -102,7 +102,7 @@ class FluxDisableGuidance(io.ComfyNode):
append = execute # TODO: remove
PREFERED_KONTEXT_RESOLUTIONS = [
PREFERRED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
@ -143,7 +143,7 @@ class FluxKontextImageScale(io.ComfyNode):
width = image.shape[2]
height = image.shape[1]
aspect_ratio = width / height
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
return io.NodeOutput(image)

View File

@ -60,7 +60,7 @@ class FreSca(io.ComfyNode):
node_id="FreSca",
search_aliases=["frequency guidance"],
display_name="FreSca",
category="_for_testing",
category="experimental",
description="Applies frequency-dependent scaling to the guidance",
inputs=[
io.Model.Input("model"),

View File

@ -131,6 +131,8 @@ class HunyuanVideo15SuperResolution(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="HunyuanVideo15SuperResolution",
display_name="Hunyuan Video 1.5 Super Resolution",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
@ -381,6 +383,8 @@ class HunyuanRefinerLatent(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="HunyuanRefinerLatent",
display_name="Hunyuan Latent Refiner",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),

View File

@ -40,7 +40,7 @@ class Hunyuan3Dv2Conditioning(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="Hunyuan3Dv2Conditioning",
category="conditioning/video_models",
category="conditioning/3d_models",
inputs=[
IO.ClipVisionOutput.Input("clip_vision_output"),
],
@ -65,7 +65,7 @@ class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="Hunyuan3Dv2ConditioningMultiView",
category="conditioning/video_models",
category="conditioning/3d_models",
inputs=[
IO.ClipVisionOutput.Input("front", optional=True),
IO.ClipVisionOutput.Input("left", optional=True),
@ -424,6 +424,7 @@ class VoxelToMeshBasic(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="VoxelToMeshBasic",
display_name="Voxel to Mesh (Basic)",
category="3d",
inputs=[
IO.Voxel.Input("voxel"),
@ -453,6 +454,7 @@ class VoxelToMesh(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="VoxelToMesh",
display_name="Voxel to Mesh",
category="3d",
inputs=[
IO.Voxel.Input("voxel"),

View File

@ -102,6 +102,7 @@ class HypernetworkLoader(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="HypernetworkLoader",
display_name="Load Hypernetwork",
category="loaders",
inputs=[
IO.Model.Input("model"),

View File

@ -91,7 +91,7 @@ class LoraSave(io.ComfyNode):
node_id="LoraSave",
search_aliases=["export lora"],
display_name="Extract and Save Lora",
category="_for_testing",
category="experimental",
inputs=[
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True),

View File

@ -106,12 +106,12 @@ class LTXVImgToVideoInplace(io.ComfyNode):
if bypass:
return (latent,)
samples = latent["samples"]
samples = latent["samples"].clone()
_, height_scale_factor, width_scale_factor = (
vae.downscale_index_formula
)
batch, _, latent_frames, latent_height, latent_width = samples.shape
_, _, _, latent_height, latent_width = samples.shape
width = latent_width * width_scale_factor
height = latent_height * height_scale_factor
@ -124,11 +124,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
samples[:, :, :t.shape[2]] = t
conditioning_latent_frames_mask = torch.ones(
(batch, 1, latent_frames, 1, 1),
dtype=torch.float32,
device=samples.device,
)
conditioning_latent_frames_mask = get_noise_mask(latent)
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
@ -236,7 +232,7 @@ class LTXVAddGuide(io.ComfyNode):
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
return encode_pixels, t
@ -594,7 +590,8 @@ class LTXVPreprocess(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LTXVPreprocess",
category="image",
display_name="LTXV Preprocess",
category="video/preprocessors",
inputs=[
io.Image.Input("image"),
io.Int.Input(

View File

@ -11,7 +11,7 @@ class Mahiro(io.ComfyNode):
return io.Schema(
node_id="Mahiro",
display_name="Positive-Biased Guidance",
category="_for_testing",
category="experimental",
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
inputs=[
io.Model.Input("model"),

View File

@ -40,10 +40,21 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
inverse_mask = torch.ones_like(mask) - mask
source_portion = mask * source[..., :visible_height, :visible_width]
destination_portion = inverse_mask * destination[..., top:bottom, left:right]
source_rgb = source[:, :3, :visible_height, :visible_width]
dest_slice = destination[..., top:bottom, left:right]
if destination.shape[1] == 4:
if torch.max(dest_slice) == 0:
destination[:, :3, top:bottom, left:right] = source_rgb
destination[:, 3:4, top:bottom, left:right] = mask
else:
destination[:, :3, top:bottom, left:right] = (mask * source_rgb) + (inverse_mask * dest_slice[:, :3])
destination[:, 3:4, top:bottom, left:right] = torch.max(mask, dest_slice[:, 3:4])
else:
source_portion = mask * source_rgb
destination_portion = inverse_mask * dest_slice
destination[..., top:bottom, left:right] = source_portion + destination_portion
destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination
class LatentCompositeMasked(IO.ComfyNode):
@ -84,18 +95,23 @@ class ImageCompositeMasked(IO.ComfyNode):
display_name="Image Composite Masked",
category="image",
inputs=[
IO.Image.Input("destination"),
IO.Image.Input("source"),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Boolean.Input("resize_source", default=False),
IO.Image.Input("destination", optional=True),
IO.Mask.Input("mask", optional=True),
],
outputs=[IO.Image.Output()],
)
@classmethod
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
def execute(cls, source, x, y, resize_source, destination = None, mask = None) -> IO.NodeOutput:
if destination is None: # transparent rgba
B, H, W, C = source.shape
destination = torch.zeros((B, H, W, 4), dtype=source.dtype, device=source.device)
if C == 3:
source = torch.nn.functional.pad(source, (0, 1), value=1.0)
destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
@ -381,7 +397,6 @@ class GrowMask(IO.ComfyNode):
expand_mask = execute # TODO: remove
class ThresholdMask(IO.ComfyNode):
@classmethod
def define_schema(cls):

View File

@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode):
return io.Schema(
node_id="ComfyMathExpression",
display_name="Math Expression",
category="math",
category="logic",
search_aliases=[
"expression", "formula", "calculate", "calculator",
"eval", "math",

View File

@ -21,7 +21,7 @@ class NumberConvertNode(io.ComfyNode):
return io.Schema(
node_id="ComfyNumberConvert",
display_name="Number Convert",
category="math",
category="utils",
search_aliases=[
"int to float", "float to int", "number convert",
"int2float", "float2int", "cast", "parse number",

View File

@ -24,8 +24,8 @@ class PerpNeg(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PerpNeg",
display_name="Perp-Neg (DEPRECATED by PerpNegGuider)",
category="_for_testing",
display_name="Perp-Neg (DEPRECATED by Perp-Neg Guider)",
category="experimental",
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("empty_conditioning"),
@ -127,7 +127,8 @@ class PerpNegGuider(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PerpNegGuider",
category="_for_testing",
display_name="Perp-Neg Guider",
category="experimental",
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("positive"),

View File

@ -123,7 +123,7 @@ class PhotoMakerLoader(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PhotoMakerLoader",
category="_for_testing/photomaker",
category="experimental/photomaker",
inputs=[
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
],
@ -149,7 +149,7 @@ class PhotoMakerEncode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PhotoMakerEncode",
category="_for_testing/photomaker",
category="experimental/photomaker",
inputs=[
io.Photomaker.Input("photomaker"),
io.Image.Input("image"),

View File

@ -116,6 +116,7 @@ class Quantize(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ImageQuantize",
display_name="Quantize Image",
category="image/postprocessing",
inputs=[
io.Image.Input("image"),
@ -181,6 +182,7 @@ class Sharpen(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ImageSharpen",
display_name="Sharpen Image",
category="image/postprocessing",
inputs=[
io.Image.Input("image"),
@ -436,7 +438,7 @@ class ResizeImageMaskNode(io.ComfyNode):
node_id="ResizeImageMaskNode",
display_name="Resize Image/Mask",
description="Resize an image or mask using various scaling methods.",
category="transform",
category="image/transform",
search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"],
inputs=[
io.MatchType.Input("input", template=template),

View File

@ -15,7 +15,7 @@ class RTDETR_detect(io.ComfyNode):
return io.Schema(
node_id="RTDETR_detect",
display_name="RT-DETR Detect",
category="detection/",
category="detection",
search_aliases=["bbox", "bounding box", "object detection", "coco"],
inputs=[
io.Model.Input("model", display_name="model"),
@ -71,7 +71,7 @@ class DrawBBoxes(io.ComfyNode):
return io.Schema(
node_id="DrawBBoxes",
display_name="Draw BBoxes",
category="detection/",
category="detection",
search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"],
inputs=[
io.Image.Input("image", optional=True),

View File

@ -113,7 +113,7 @@ class SelfAttentionGuidance(io.ComfyNode):
return io.Schema(
node_id="SelfAttentionGuidance",
display_name="Self-Attention Guidance",
category="_for_testing",
category="experimental",
inputs=[
io.Model.Input("model"),
io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),

View File

@ -93,7 +93,7 @@ class SAM3_Detect(io.ComfyNode):
return io.Schema(
node_id="SAM3_Detect",
display_name="SAM3 Detect",
category="detection/",
category="detection",
search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"],
inputs=[
io.Model.Input("model", display_name="model"),
@ -265,15 +265,15 @@ class SAM3_VideoTrack(io.ComfyNode):
return io.Schema(
node_id="SAM3_VideoTrack",
display_name="SAM3 Video Track",
category="detection/",
category="detection",
search_aliases=["sam3", "video", "track", "propagate"],
inputs=[
io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"),
io.Model.Input("model", display_name="model"),
io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"),
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"),
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"),
io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."),
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection."),
io.Int.Input("max_objects", display_name="max_objects", default=4, min=0, max=64, tooltip="Max tracked objects. Initial masks count toward this limit. 0 uses the internal cap of 64."),
io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."),
],
outputs=[
@ -290,8 +290,7 @@ class SAM3_VideoTrack(io.ComfyNode):
dtype = model.model.get_dtype()
sam3_model = model.model.diffusion_model
frames = images[..., :3].movedim(-1, 1)
frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype)
frames_in = images[..., :3].movedim(-1, 1)
init_masks = None
if initial_mask is not None:
@ -308,7 +307,7 @@ class SAM3_VideoTrack(io.ComfyNode):
result = sam3_model.forward_video(
images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts,
new_det_thresh=detection_threshold, max_objects=max_objects,
detect_interval=detect_interval)
detect_interval=detect_interval, target_device=device, target_dtype=dtype)
result["orig_size"] = (H, W)
return io.NodeOutput(result)
@ -321,7 +320,7 @@ class SAM3_TrackPreview(io.ComfyNode):
return io.Schema(
node_id="SAM3_TrackPreview",
display_name="SAM3 Track Preview",
category="detection/",
category="detection",
inputs=[
SAM3TrackData.Input("track_data", display_name="track_data"),
io.Image.Input("images", display_name="images", optional=True),
@ -449,14 +448,18 @@ class SAM3_TrackPreview(io.ComfyNode):
cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area
has = area > 1
scores = track_data.get("scores", [])
label_scale = max(3, H // 240) # Scale font with resolutio
size_caps = (area.float().sqrt() / 15).clamp_(min=1).long().tolist() #cap per-object so the number doesn't dwarf small masks
for obj_idx in range(N_obj):
if has[obj_idx]:
_cx, _cy = int(cx[obj_idx]), int(cy[obj_idx])
color = cls.COLORS[obj_idx % len(cls.COLORS)]
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color)
obj_scale = min(label_scale, size_caps[obj_idx])
score_scale = max(1, obj_scale * 2 // 3)
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color, scale=obj_scale)
if obj_idx < len(scores) and scores[obj_idx] < 1.0:
SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100),
_cx, _cy + 5 * 3 + 3, color, scale=2)
_cx, _cy + 5 * obj_scale + 3, color, scale=score_scale)
frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte())
else:
frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte())
@ -475,7 +478,7 @@ class SAM3_TrackToMask(io.ComfyNode):
return io.Schema(
node_id="SAM3_TrackToMask",
display_name="SAM3 Track to Mask",
category="detection/",
category="detection",
inputs=[
SAM3TrackData.Input("track_data", display_name="track_data"),
io.String.Input("object_indices", display_name="object_indices", default="",
@ -507,9 +510,10 @@ class SAM3_TrackToMask(io.ComfyNode):
if not indices:
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
selected = packed[:, indices]
binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool
union = binary.any(dim=1, keepdim=True).float()
union_packed = packed[:, indices[0]].clone()
for i in indices[1:]:
union_packed |= packed[:, i]
union = unpack_masks(union_packed).unsqueeze(1).float() # [N, 1, Hm, Wm]
mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0]
return io.NodeOutput(mask_out)

View File

@ -119,7 +119,7 @@ class StableCascade_SuperResolutionControlnet(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StableCascade_SuperResolutionControlnet",
category="_for_testing/stable_cascade",
category="experimental/stable_cascade",
is_experimental=True,
inputs=[
io.Image.Input("image"),

View File

@ -26,7 +26,8 @@ class TextGenerate(io.ComfyNode):
return io.Schema(
node_id="TextGenerate",
category="textgen",
display_name="Generate Text",
category="text",
search_aliases=["LLM", "gemma"],
inputs=[
io.Clip.Input("clip"),
@ -157,6 +158,7 @@ class TextGenerateLTX2Prompt(TextGenerate):
parent_schema = super().define_schema()
return io.Schema(
node_id="TextGenerateLTX2Prompt",
display_name="Generate LTX2 Prompt",
category=parent_schema.category,
inputs=parent_schema.inputs,
outputs=parent_schema.outputs,

View File

@ -10,7 +10,7 @@ class TorchCompileModel(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="TorchCompileModel",
category="_for_testing",
category="experimental",
inputs=[
io.Model.Input("model"),
io.Combo.Input(

View File

@ -1361,7 +1361,7 @@ class SaveLoRA(io.ComfyNode):
node_id="SaveLoRA",
search_aliases=["export lora"],
display_name="Save LoRA Weights",
category="loaders",
category="advanced/model_merging",
is_experimental=True,
is_output_node=True,
inputs=[

View File

@ -15,7 +15,7 @@ class ImageOnlyCheckpointLoader:
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "loaders/video_models"
CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)

View File

@ -22,7 +22,7 @@ class SaveImageWebsocket:
OUTPUT_NODE = True
CATEGORY = "api/image"
CATEGORY = "image"
def save_images(self, images):
pbar = comfy.utils.ProgressBar(images.shape[0])
@ -42,3 +42,7 @@ class SaveImageWebsocket:
NODE_CLASS_MAPPINGS = {
"SaveImageWebsocket": SaveImageWebsocket,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SaveImageWebsocket": "Save Image (Websocket)",
}

View File

@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
folder_names_and_paths["background_removal"] = ([os.path.join(models_dir, "background_removal")], supported_pt_extensions)
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)

View File

@ -330,7 +330,7 @@ class VAEDecodeTiled:
RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode"
CATEGORY = "_for_testing"
CATEGORY = "experimental"
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
if tile_size < overlap * 4:
@ -377,7 +377,7 @@ class VAEEncodeTiled:
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "_for_testing"
CATEGORY = "experimental"
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
@ -493,7 +493,7 @@ class SaveLatent:
OUTPUT_NODE = True
CATEGORY = "_for_testing"
CATEGORY = "experimental"
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
@ -538,7 +538,7 @@ class LoadLatent:
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
return {"required": {"latent": [sorted(files), ]}, }
CATEGORY = "_for_testing"
CATEGORY = "experimental"
RETURN_TYPES = ("LATENT", )
FUNCTION = "load"
@ -1443,7 +1443,7 @@ class LatentBlend:
RETURN_TYPES = ("LATENT",)
FUNCTION = "blend"
CATEGORY = "_for_testing"
CATEGORY = "experimental"
def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
@ -2092,6 +2092,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"StyleModelLoader": "Load Style Model",
"CLIPVisionLoader": "Load CLIP Vision",
"UNETLoader": "Load Diffusion Model",
"unCLIPCheckpointLoader": "Load unCLIP Checkpoint",
"GLIGENLoader": "Load GLIGEN Model",
# Conditioning
"CLIPVisionEncode": "CLIP Vision Encode",
"StyleModelApply": "Apply Style Model",
@ -2140,7 +2142,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImageSharpen": "Sharpen Image",
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
"GetImageSize": "Get Image Size",
# _for_testing
# experimental
"VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)",
}
@ -2427,6 +2429,7 @@ async def init_builtin_extra_nodes():
"nodes_number_convert.py",
"nodes_painter.py",
"nodes_curve.py",
"nodes_bg_removal.py",
"nodes_rtdetr.py",
"nodes_frame_interpolation.py",
"nodes_sam3.py",

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,90 @@
"""Tests for NodeReplaceManager registration behavior."""
import importlib
import sys
import types
import pytest
@pytest.fixture
def NodeReplaceManager(monkeypatch):
"""Provide NodeReplaceManager with `nodes` stubbed.
`app.node_replace_manager` does `import nodes` at module level, which pulls in
torch + the full ComfyUI graph. register() doesn't actually need it, so we
stub `nodes` per-test (via monkeypatch so it's torn down) and reload the
module so it picks up the stub instead of any cached real import.
"""
fake_nodes = types.ModuleType("nodes")
fake_nodes.NODE_CLASS_MAPPINGS = {}
monkeypatch.setitem(sys.modules, "nodes", fake_nodes)
monkeypatch.delitem(sys.modules, "app.node_replace_manager", raising=False)
module = importlib.import_module("app.node_replace_manager")
yield module.NodeReplaceManager
# Drop the freshly-imported module so the next test (or a later real import
# of `nodes`) starts from a clean slate.
sys.modules.pop("app.node_replace_manager", None)
class FakeNodeReplace:
"""Lightweight stand-in for comfy_api.latest._io.NodeReplace."""
def __init__(self, new_node_id, old_node_id, old_widget_ids=None,
input_mapping=None, output_mapping=None):
self.new_node_id = new_node_id
self.old_node_id = old_node_id
self.old_widget_ids = old_widget_ids
self.input_mapping = input_mapping
self.output_mapping = output_mapping
def test_register_adds_replacement(NodeReplaceManager):
manager = NodeReplaceManager()
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
assert manager.has_replacement("OldNode")
assert len(manager.get_replacement("OldNode")) == 1
def test_register_allows_multiple_alternatives_for_same_old_node(NodeReplaceManager):
"""Different new_node_ids for the same old_node_id should all be kept."""
manager = NodeReplaceManager()
manager.register(FakeNodeReplace(new_node_id="AltA", old_node_id="OldNode"))
manager.register(FakeNodeReplace(new_node_id="AltB", old_node_id="OldNode"))
replacements = manager.get_replacement("OldNode")
assert len(replacements) == 2
assert {r.new_node_id for r in replacements} == {"AltA", "AltB"}
def test_register_is_idempotent_for_duplicate_pair(NodeReplaceManager):
"""Re-registering the same (old_node_id, new_node_id) should be a no-op."""
manager = NodeReplaceManager()
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
assert len(manager.get_replacement("OldNode")) == 1
def test_register_idempotent_preserves_first_registration(NodeReplaceManager):
"""First registration wins; later duplicates with different mappings are ignored."""
manager = NodeReplaceManager()
first = FakeNodeReplace(
new_node_id="NewNode", old_node_id="OldNode",
input_mapping=[{"new_id": "a", "old_id": "x"}],
)
second = FakeNodeReplace(
new_node_id="NewNode", old_node_id="OldNode",
input_mapping=[{"new_id": "b", "old_id": "y"}],
)
manager.register(first)
manager.register(second)
replacements = manager.get_replacement("OldNode")
assert len(replacements) == 1
assert replacements[0] is first
def test_register_dedupe_does_not_affect_other_old_nodes(NodeReplaceManager):
manager = NodeReplaceManager()
manager.register(FakeNodeReplace(new_node_id="NewA", old_node_id="OldA"))
manager.register(FakeNodeReplace(new_node_id="NewA", old_node_id="OldA"))
manager.register(FakeNodeReplace(new_node_id="NewB", old_node_id="OldB"))
assert len(manager.get_replacement("OldA")) == 1
assert len(manager.get_replacement("OldB")) == 1

View File

@ -0,0 +1,153 @@
import uuid
import requests
def test_set_preview_success(
http: requests.Session, api_base: str, asset_factory, make_asset_bytes
):
"""PUT /api/assets/{id}/preview sets the preview and returns the full Asset."""
main_data = make_asset_bytes("main_asset.png", 2048)
preview_data = make_asset_bytes("preview_asset.png", 1024)
main = asset_factory("main_asset.png", ["input", "unit-tests"], {}, main_data)
preview = asset_factory("preview_asset.png", ["input", "unit-tests"], {}, preview_data)
r = http.put(
f"{api_base}/api/assets/{main['id']}/preview",
json={"preview_id": preview["id"]},
timeout=120,
)
body = r.json()
assert r.status_code == 200, body
assert body["id"] == main["id"]
assert body["preview_id"] == preview["id"]
def test_clear_preview_success(
http: requests.Session, api_base: str, asset_factory, make_asset_bytes
):
"""DELETE /api/assets/{id}/preview clears the preview and returns 204."""
main_data = make_asset_bytes("main_clear.png", 2048)
preview_data = make_asset_bytes("prev_clear.png", 1024)
main = asset_factory("main_clear.png", ["input", "unit-tests"], {}, main_data)
preview = asset_factory("prev_clear.png", ["input", "unit-tests"], {}, preview_data)
# First set a preview
r_set = http.put(
f"{api_base}/api/assets/{main['id']}/preview",
json={"preview_id": preview["id"]},
timeout=120,
)
assert r_set.status_code == 200
# Now clear it
r_del = http.delete(f"{api_base}/api/assets/{main['id']}/preview", timeout=120)
assert r_del.status_code == 204
# Verify preview_id is cleared
r_get = http.get(f"{api_base}/api/assets/{main['id']}", timeout=120)
detail = r_get.json()
assert r_get.status_code == 200, detail
assert detail.get("preview_id") is None
def test_set_preview_asset_not_found(http: requests.Session, api_base: str):
"""PUT /api/assets/{id}/preview returns 404 for non-existent asset ID."""
fake_id = str(uuid.uuid4())
r = http.put(
f"{api_base}/api/assets/{fake_id}/preview",
json={"preview_id": str(uuid.uuid4())},
timeout=120,
)
body = r.json()
assert r.status_code == 404
assert body["error"]["code"] == "ASSET_NOT_FOUND"
def test_clear_preview_asset_not_found(http: requests.Session, api_base: str):
"""DELETE /api/assets/{id}/preview returns 404 for non-existent asset ID."""
fake_id = str(uuid.uuid4())
r = http.delete(f"{api_base}/api/assets/{fake_id}/preview", timeout=120)
body = r.json()
assert r.status_code == 404
assert body["error"]["code"] == "ASSET_NOT_FOUND"
def test_set_preview_missing_body(
http: requests.Session, api_base: str, asset_factory, make_asset_bytes
):
"""PUT /api/assets/{id}/preview returns 400 when preview_id is not provided."""
main_data = make_asset_bytes("main_nobody.png", 2048)
main = asset_factory("main_nobody.png", ["input", "unit-tests"], {}, main_data)
r = http.put(
f"{api_base}/api/assets/{main['id']}/preview",
json={},
timeout=120,
)
body = r.json()
assert r.status_code == 400
assert body["error"]["code"] == "INVALID_BODY"
def test_set_preview_invalid_preview_id(
http: requests.Session, api_base: str, asset_factory, make_asset_bytes
):
"""PUT /api/assets/{id}/preview returns 404 when preview_id references non-existent asset."""
main_data = make_asset_bytes("main_badprev.png", 2048)
main = asset_factory("main_badprev.png", ["input", "unit-tests"], {}, main_data)
fake_preview_id = str(uuid.uuid4())
r = http.put(
f"{api_base}/api/assets/{main['id']}/preview",
json={"preview_id": fake_preview_id},
timeout=120,
)
body = r.json()
assert r.status_code == 404
assert body["error"]["code"] == "ASSET_NOT_FOUND"
def test_clear_preview_idempotent_when_no_preview_set(
http: requests.Session, api_base: str, asset_factory, make_asset_bytes
):
"""DELETE /api/assets/{id}/preview returns 204 even if no preview was previously set."""
main_data = make_asset_bytes("main_idem_clear.png", 2048)
main = asset_factory("main_idem_clear.png", ["input", "unit-tests"], {}, main_data)
# Asset has no preview at this point — DELETE should still succeed.
r_first = http.delete(f"{api_base}/api/assets/{main['id']}/preview", timeout=120)
assert r_first.status_code == 204, r_first.text
# And a second DELETE on the same asset should also be a 204 no-op.
r_second = http.delete(f"{api_base}/api/assets/{main['id']}/preview", timeout=120)
assert r_second.status_code == 204, r_second.text
def test_set_preview_idempotent_with_same_id(
http: requests.Session, api_base: str, asset_factory, make_asset_bytes
):
"""PUT /api/assets/{id}/preview with the same preview_id twice is a 200 no-op."""
main_data = make_asset_bytes("main_idem_set.png", 2048)
preview_data = make_asset_bytes("prev_idem_set.png", 1024)
main = asset_factory("main_idem_set.png", ["input", "unit-tests"], {}, main_data)
preview = asset_factory("prev_idem_set.png", ["input", "unit-tests"], {}, preview_data)
r_first = http.put(
f"{api_base}/api/assets/{main['id']}/preview",
json={"preview_id": preview["id"]},
timeout=120,
)
assert r_first.status_code == 200, r_first.text
assert r_first.json()["preview_id"] == preview["id"]
r_second = http.put(
f"{api_base}/api/assets/{main['id']}/preview",
json={"preview_id": preview["id"]},
timeout=120,
)
assert r_second.status_code == 200, r_second.text
assert r_second.json()["preview_id"] == preview["id"]

View File

@ -21,7 +21,7 @@ class TestAsyncProgressUpdate(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "execute"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
async def execute(self, value, sleep_seconds):
start = time.time()
@ -51,7 +51,7 @@ class TestSyncProgressUpdate(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "execute"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
def execute(self, value, sleep_seconds):
start = time.time()

View File

@ -21,7 +21,7 @@ class TestAsyncValidation(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
@classmethod
async def VALIDATE_INPUTS(cls, value, threshold):
@ -53,7 +53,7 @@ class TestAsyncError(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "error_execution"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
async def error_execution(self, value, error_after):
await asyncio.sleep(error_after)
@ -74,7 +74,7 @@ class TestAsyncValidationError(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
@classmethod
async def VALIDATE_INPUTS(cls, value, max_value):
@ -105,7 +105,7 @@ class TestAsyncTimeout(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "timeout_execution"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
async def timeout_execution(self, value, timeout, operation_time):
try:
@ -129,7 +129,7 @@ class TestSyncError(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "sync_error"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
def sync_error(self, value):
raise RuntimeError("Intentional sync execution error for testing")
@ -150,7 +150,7 @@ class TestAsyncLazyCheck(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
async def check_lazy_status(self, condition, input1, input2):
# Simulate async checking (e.g., querying remote service)
@ -184,7 +184,7 @@ class TestDynamicAsyncGeneration(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",)
FUNCTION = "generate_async_workflow"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
g = GraphBuilder()
@ -229,7 +229,7 @@ class TestAsyncResourceUser(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "use_resource"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
async def use_resource(self, value, resource_id, duration):
# Check if resource is already in use
@ -265,7 +265,7 @@ class TestAsyncBatchProcessing(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process_batch"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
async def process_batch(self, images, process_time_per_item, unique_id):
batch_size = images.shape[0]
@ -305,7 +305,7 @@ class TestAsyncConcurrentLimit(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "limited_execution"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
async def limited_execution(self, value, duration, node_id):
async with self._semaphore:

View File

@ -409,7 +409,7 @@ class TestSleep(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "sleep"
CATEGORY = "_for_testing"
CATEGORY = "experimental"
async def sleep(self, value, seconds, unique_id):
pbar = ProgressBar(seconds, node_id=unique_id)
@ -440,7 +440,7 @@ class TestParallelSleep(ComfyNodeABC):
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "parallel_sleep"
CATEGORY = "_for_testing"
CATEGORY = "experimental"
OUTPUT_NODE = True
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
@ -474,7 +474,7 @@ class TestOutputNodeWithSocketOutput:
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing"
CATEGORY = "experimental"
OUTPUT_NODE = True
def process(self, image, value):