Merge branch 'master' into ltxv_self_attn_mask

This commit is contained in:
Jukka Seppänen 2026-05-09 00:47:19 +03:00 committed by GitHub
commit 1e76c3b9c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
58 changed files with 6193 additions and 146 deletions

View File

@ -27,7 +27,7 @@ def frontend_install_warning_message():
return f"""
{get_missing_requirements_message()}
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
The ComfyUI frontend is shipped in a pip package so it needs to be updated separately from the ComfyUI code.
""".strip()
def parse_version(version: str) -> tuple[int, int, int]:

View File

@ -1,5 +1,7 @@
from __future__ import annotations
import logging
from aiohttp import web
from typing import TYPE_CHECKING, TypedDict
@ -31,8 +33,22 @@ class NodeReplaceManager:
self._replacements: dict[str, list[NodeReplace]] = {}
def register(self, node_replace: NodeReplace):
"""Register a node replacement mapping."""
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
"""Register a node replacement mapping.
Idempotent: if a replacement with the same (old_node_id, new_node_id)
is already registered, the duplicate is ignored. This prevents stale
entries from accumulating when custom nodes are reloaded in the same
process (e.g. via ComfyUI-Manager).
"""
existing = self._replacements.setdefault(node_replace.old_node_id, [])
for entry in existing:
if entry.new_node_id == node_replace.new_node_id:
logging.debug(
"Node replacement %s -> %s already registered, ignoring duplicate.",
node_replace.old_node_id, node_replace.new_node_id,
)
return
existing.append(node_replace)
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
"""Get replacements for an old node ID."""

View File

@ -0,0 +1,7 @@
{
"model_type": "birefnet",
"image_std": [1.0, 1.0, 1.0],
"image_mean": [0.0, 0.0, 0.0],
"image_size": 1024,
"resize_to_original": true
}

View File

@ -0,0 +1,689 @@
import torch
import comfy.ops
import numpy as np
import torch.nn as nn
from functools import partial
import torch.nn.functional as F
from torchvision.ops import deform_conv2d
from comfy.ldm.modules.attention import optimized_attention_for_device
CXT = [3072, 1536, 768, 384][1:][::-1][-3:]
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
self.kv = operations.Linear(dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
def forward(self, x):
B, N, C = x.shape
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
x = optimized_attention(
q, k, v, heads=self.num_heads, skip_output_reshape=True, skip_reshape=True
).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, device=None, dtype=None, operations=None):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = operations.Linear(in_features, hidden_features, device=device, dtype=dtype)
self.act = nn.GELU()
self.fc2 = operations.Linear(hidden_features, out_features, device=device, dtype=dtype)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads, device=device, dtype=dtype))
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
return x
class SwinTransformerBlock(nn.Module):
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None,
norm_layer=nn.LayerNorm, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.norm1 = norm_layer(dim, device=device, dtype=dtype)
self.attn = WindowAttention(
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, device=device, dtype=dtype, operations=operations)
self.norm2 = norm_layer(dim, device=device, dtype=dtype)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, device=device, dtype=dtype, operations=operations)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
B, L, C = x.shape
H, W = self.H, self.W
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
x_windows = window_partition(shifted_x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
attn_windows = self.attn(x_windows, mask=attn_mask)
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class PatchMerging(nn.Module):
def __init__(self, dim, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.reduction = operations.Linear(4 * dim, 2 * dim, bias=False, device=device, dtype=dtype)
self.norm = operations.LayerNorm(4 * dim, device=device, dtype=dtype)
def forward(self, x, H, W):
B, L, C = x.shape
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class BasicLayer(nn.Module):
def __init__(self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
norm_layer=nn.LayerNorm,
downsample=None,
device=None, dtype=None, operations=None):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
norm_layer=norm_layer,
device=device, dtype=dtype, operations=operations)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, device=device, dtype=dtype, operations=operations)
else:
self.downsample = None
def forward(self, x, H, W):
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
blk.H, blk.W = H, W
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None, device=None, dtype=None, operations=None):
super().__init__()
patch_size = (patch_size, patch_size)
self.patch_size = patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype)
if norm_layer is not None:
self.norm = norm_layer(embed_dim, device=device, dtype=dtype)
else:
self.norm = None
def forward(self, x):
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
class SwinTransformer(nn.Module):
def __init__(self,
pretrain_img_size=224,
patch_size=4,
in_channels=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
device=None, dtype=None, operations=None):
super().__init__()
norm_layer = partial(operations.LayerNorm, device=device, dtype=dtype)
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.patch_embed = PatchEmbed(
patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
device=device, dtype=dtype, operations=operations,
norm_layer=norm_layer if self.patch_norm else None)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
device=device, dtype=dtype, operations=operations)
self.layers.append(layer)
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
def forward(self, x):
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
outs = []
x = x.flatten(2).transpose(1, 2)
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return tuple(outs)
class DeformableConv2d(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False, device=None, dtype=None, operations=None):
super(DeformableConv2d, self).__init__()
kernel_size = kernel_size if type(kernel_size) is tuple else (kernel_size, kernel_size)
self.stride = stride if type(stride) is tuple else (stride, stride)
self.padding = padding
self.offset_conv = operations.Conv2d(in_channels,
2 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True, device=device, dtype=dtype)
self.modulator_conv = operations.Conv2d(in_channels,
1 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True, device=device, dtype=dtype)
self.regular_conv = operations.Conv2d(in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=bias, device=device, dtype=dtype)
def forward(self, x):
offset = self.offset_conv(x)
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
x = deform_conv2d(
input=x,
offset=offset,
weight=weight,
bias=None,
padding=self.padding,
mask=modulator,
stride=self.stride,
)
comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
return x
class BasicDecBlk(nn.Module):
def __init__(self, in_channels=64, out_channels=64, inter_channels=64, device=None, dtype=None, operations=None):
super(BasicDecBlk, self).__init__()
inter_channels = 64
self.conv_in = operations.Conv2d(in_channels, inter_channels, 3, 1, padding=1, device=device, dtype=dtype)
self.relu_in = nn.ReLU(inplace=True)
self.dec_att = ASPPDeformable(in_channels=inter_channels, device=device, dtype=dtype, operations=operations)
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, padding=1, device=device, dtype=dtype)
self.bn_in = operations.BatchNorm2d(inter_channels, device=device, dtype=dtype)
self.bn_out = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
def forward(self, x):
x = self.conv_in(x)
x = self.bn_in(x)
x = self.relu_in(x)
x = self.dec_att(x)
x = self.conv_out(x)
x = self.bn_out(x)
return x
class BasicLatBlk(nn.Module):
def __init__(self, in_channels=64, out_channels=64, device=None, dtype=None, operations=None):
super(BasicLatBlk, self).__init__()
self.conv = operations.Conv2d(in_channels, out_channels, 1, 1, 0, device=device, dtype=dtype)
def forward(self, x):
x = self.conv(x)
return x
class _ASPPModuleDeformable(nn.Module):
def __init__(self, in_channels, planes, kernel_size, padding, device, dtype, operations):
super(_ASPPModuleDeformable, self).__init__()
self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
stride=1, padding=padding, bias=False, device=device, dtype=dtype, operations=operations)
self.bn = operations.BatchNorm2d(planes, device=device, dtype=dtype)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.atrous_conv(x)
x = self.bn(x)
return self.relu(x)
class ASPPDeformable(nn.Module):
def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7], device=None, dtype=None, operations=None):
super(ASPPDeformable, self).__init__()
self.down_scale = 1
if out_channels is None:
out_channels = in_channels
self.in_channelster = 256 // self.down_scale
self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0, device=device, dtype=dtype, operations=operations)
self.aspp_deforms = nn.ModuleList([
_ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2), device=device, dtype=dtype, operations=operations)
for conv_size in parallel_block_sizes
])
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
operations.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False, device=device, dtype=dtype),
operations.BatchNorm2d(self.in_channelster, device=device, dtype=dtype),
nn.ReLU(inplace=True))
self.conv1 = operations.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False, device=device, dtype=dtype)
self.bn1 = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x1 = self.aspp1(x)
x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
x5 = self.global_avg_pool(x)
x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
return x
class BiRefNet(nn.Module):
def __init__(self, config=None, dtype=None, device=None, operations=None):
super(BiRefNet, self).__init__()
self.bb = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12, device=device, dtype=dtype, operations=operations)
channels = [1536, 768, 384, 192]
channels = [c * 2 for c in channels]
self.cxt = channels[1:][::-1][-3:]
self.squeeze_module = nn.Sequential(*[
BasicDecBlk(channels[0]+sum(self.cxt), channels[0], device=device, dtype=dtype, operations=operations)
for _ in range(1)
])
self.decoder = Decoder(channels, device=device, dtype=dtype, operations=operations)
def forward_enc(self, x):
x1, x2, x3, x4 = self.bb(x)
B, C, H, W = x.shape
x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x4 = torch.cat(
(
*[
F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
][-len(CXT):],
x4
),
dim=1
)
return (x1, x2, x3, x4)
def forward_ori(self, x):
(x1, x2, x3, x4) = self.forward_enc(x)
x4 = self.squeeze_module(x4)
features = [x, x1, x2, x3, x4]
scaled_preds = self.decoder(features)
return scaled_preds
def forward(self, pixel_values, intermediate_output=None):
scaled_preds = self.forward_ori(pixel_values)
return scaled_preds
class Decoder(nn.Module):
def __init__(self, channels, device, dtype, operations):
super(Decoder, self).__init__()
# factory kwargs
fk = {"device":device, "dtype":dtype, "operations":operations}
DecoderBlock = partial(BasicDecBlk, **fk)
LateralBlock = partial(BasicLatBlk, **fk)
DBlock = partial(SimpleConvs, **fk)
self.split = True
N_dec_ipt = 64
ic = 64
ipt_cha_opt = 1
self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[1])
self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[2])
self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt]), channels[3])
self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt]), channels[3]//2)
fk = {"device":device, "dtype":dtype}
self.conv_out1 = nn.Sequential(operations.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt]), 1, 1, 1, 0, **fk))
self.lateral_block4 = LateralBlock(channels[1], channels[1])
self.lateral_block3 = LateralBlock(channels[2], channels[2])
self.lateral_block2 = LateralBlock(channels[3], channels[3])
self.conv_ms_spvn_4 = operations.Conv2d(channels[1], 1, 1, 1, 0, **fk)
self.conv_ms_spvn_3 = operations.Conv2d(channels[2], 1, 1, 1, 0, **fk)
self.conv_ms_spvn_2 = operations.Conv2d(channels[3], 1, 1, 1, 0, **fk)
_N = 16
self.gdt_convs_4 = nn.Sequential(operations.Conv2d(channels[0] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
self.gdt_convs_3 = nn.Sequential(operations.Conv2d(channels[1] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
self.gdt_convs_2 = nn.Sequential(operations.Conv2d(channels[2] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
[setattr(self, f"gdt_convs_pred_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
[setattr(self, f"gdt_convs_attn_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
def get_patches_batch(self, x, p):
_size_h, _size_w = p.shape[2:]
patches_batch = []
for idx in range(x.shape[0]):
columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
patches_x = []
for column_x in columns_x:
patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
patch_sample = torch.cat(patches_x, dim=1)
patches_batch.append(patch_sample)
return torch.cat(patches_batch, dim=0)
def forward(self, features):
x, x1, x2, x3, x4 = features
patches_batch = self.get_patches_batch(x, x4) if self.split else x
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
p4 = self.decoder_block4(x4)
p4_gdt = self.gdt_convs_4(p4)
gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
p4 = p4 * gdt_attn_4
_p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
_p3 = _p4 + self.lateral_block4(x3)
patches_batch = self.get_patches_batch(x, _p3) if self.split else x
_p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
p3 = self.decoder_block3(_p3)
p3_gdt = self.gdt_convs_3(p3)
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
p3 = p3 * gdt_attn_3
_p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
_p2 = _p3 + self.lateral_block3(x2)
patches_batch = self.get_patches_batch(x, _p2) if self.split else x
_p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
p2 = self.decoder_block2(_p2)
p2_gdt = self.gdt_convs_2(p2)
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
p2 = p2 * gdt_attn_2
_p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
_p1 = _p2 + self.lateral_block2(x1)
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
_p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
_p1 = self.decoder_block1(_p1)
_p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
p1_out = self.conv_out1(_p1)
return p1_out
class SimpleConvs(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, inter_channels=64, device=None, dtype=None, operations=None
) -> None:
super().__init__()
self.conv1 = operations.Conv2d(in_channels, inter_channels, 3, 1, 1, device=device, dtype=dtype)
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, 1, device=device, dtype=dtype)
def forward(self, x):
return self.conv_out(self.conv1(x))

78
comfy/bg_removal_model.py Normal file
View File

@ -0,0 +1,78 @@
from .utils import load_torch_file
import os
import json
import torch
import logging
import comfy.ops
import comfy.model_patcher
import comfy.model_management
import comfy.clip_model
import comfy.background_removal.birefnet
BG_REMOVAL_MODELS = {
"birefnet": comfy.background_removal.birefnet.BiRefNet
}
class BackgroundRemovalModel():
def __init__(self, json_config):
with open(json_config) as f:
config = json.load(f)
self.image_size = config.get("image_size", 1024)
self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0])
self.image_std = config.get("image_std", [1.0, 1.0, 1.0])
self.model_type = config.get("model_type", "birefnet")
self.config = config.copy()
model_class = BG_REMOVAL_MODELS.get(self.model_type)
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image):
comfy.model_management.load_model_gpu(self.patcher)
H, W = image.shape[1], image.shape[2]
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
out = self.model(pixel_values=pixel_values)
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
if mask.ndim == 3:
mask = mask.unsqueeze(0)
if mask.shape[1] != 1:
mask = mask.movedim(-1, 1)
return mask
def load_background_removal_model(sd):
if "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal"), "birefnet.json")
else:
return None
bg_model = BackgroundRemovalModel(json_config)
m, u = bg_model.load_sd(sd)
if len(m) > 0:
logging.warning("missing background removal: {}".format(m))
u = set(u)
keys = list(sd.keys())
for k in keys:
if k not in u:
sd.pop(k)
return bg_model
def load(ckpt_path):
sd = load_torch_file(ckpt_path)
return load_background_removal_model(sd)

View File

@ -93,7 +93,7 @@ class Hook:
self.hook_scope = hook_scope
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
self.custom_should_register = default_should_register
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
'''Can be overridden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
@property
def strength(self):

View File

@ -1859,6 +1859,23 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
output = torch.zeros_like(x)
s_in = x.new_ones([x.shape[0]])
current_start_frame = 0
# I2V: seed KV cache with the initial image latent before the denoising loop
initial_latent = transformer_options.get("ar_config", {}).get("initial_latent", None)
if initial_latent is not None:
initial_latent = inner_model.process_latent_in(initial_latent).to(device=device, dtype=model_dtype)
n_init = initial_latent.shape[2]
output[:, :, :n_init] = initial_latent
ar_state = {"start_frame": 0, "kv_caches": kv_caches, "crossattn_caches": crossattn_caches}
transformer_options["ar_state"] = ar_state
zero_sigma = sigmas.new_zeros([1])
_ = model(initial_latent, zero_sigma * s_in, **extra_args)
current_start_frame = n_init
remaining = lat_t - n_init
num_blocks = -(-remaining // num_frame_per_block)
num_sigma_steps = len(sigmas) - 1
total_real_steps = num_blocks * num_sigma_steps
step_count = 0

View File

@ -140,7 +140,7 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
# according the the formula provided in https://arxiv.org/abs/2010.02502
# according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')

View File

@ -561,7 +561,8 @@ class SAM3Model(nn.Module):
return high_res_masks
def forward_video(self, images, initial_masks, pbar=None, text_prompts=None,
new_det_thresh=0.5, max_objects=0, detect_interval=1):
new_det_thresh=0.5, max_objects=0, detect_interval=1,
target_device=None, target_dtype=None):
"""Track video with optional per-frame text-prompted detection."""
bb = self.detector.backbone["vision_backbone"]
@ -589,8 +590,10 @@ class SAM3Model(nn.Module):
return self.tracker.track_video_with_detection(
backbone_fn, images, initial_masks, detect_fn,
new_det_thresh=new_det_thresh, max_objects=max_objects,
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar)
detect_interval=detect_interval, backbone_obj=bb, pbar=pbar,
target_device=target_device, target_dtype=target_dtype)
# SAM3 (non-multiplex) — no detection support, requires initial masks
if initial_masks is None:
raise ValueError("SAM3 (non-multiplex) requires initial_mask for video tracking")
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb)
return self.tracker.track_video(backbone_fn, images, initial_masks, pbar=pbar, backbone_obj=bb,
target_device=target_device, target_dtype=target_dtype)

View File

@ -200,8 +200,13 @@ def pack_masks(masks):
def unpack_masks(packed):
"""Unpack bit-packed [*, H, W//8] uint8 to bool [*, H, W*8]."""
shifts = torch.arange(8, device=packed.device)
return ((packed.unsqueeze(-1) >> shifts) & 1).view(*packed.shape[:-1], -1).bool()
bits = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8, device=packed.device)
return (packed.unsqueeze(-1) & bits).bool().view(*packed.shape[:-1], -1)
def _prep_frame(images, idx, device, dt, size):
"""Slice CPU full-res frames, transfer to GPU in target dtype, and resize to (size, size)."""
return comfy.utils.common_upscale(images[idx].to(device=device, dtype=dt), size, size, "bicubic", crop="disabled")
def _compute_backbone(backbone_fn, frame, frame_idx=None):
@ -1078,16 +1083,19 @@ class SAM3Tracker(nn.Module):
# SAM3: drop last FPN level
return vision_feats[:-1], vision_pos[:-1], feat_sizes[:-1]
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None):
def _track_single_object(self, backbone_fn, images, initial_mask, pbar=None,
target_device=None, target_dtype=None):
"""Track one object, computing backbone per frame to save VRAM."""
N = images.shape[0]
device, dt = images.device, images.dtype
device = target_device if target_device is not None else images.device
dt = target_dtype if target_dtype is not None else images.dtype
size = self.image_size
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
all_masks = []
for frame_idx in tqdm(range(N), desc="tracking"):
vision_feats, vision_pos, feat_sizes = self._compute_backbone_frame(
backbone_fn, images[frame_idx:frame_idx + 1], frame_idx=frame_idx)
backbone_fn, _prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), frame_idx=frame_idx)
mask_input = None
if frame_idx == 0:
mask_input = F.interpolate(initial_mask.to(device=device, dtype=dt),
@ -1114,12 +1122,13 @@ class SAM3Tracker(nn.Module):
return torch.cat(all_masks, dim=0) # [N, 1, H, W]
def track_video(self, backbone_fn, images, initial_masks, pbar=None, **kwargs):
def track_video(self, backbone_fn, images, initial_masks, pbar=None,
target_device=None, target_dtype=None, **kwargs):
"""Track one or more objects across video frames.
Args:
backbone_fn: callable that returns (sam2_features, sam2_positions, trunk_out) for a frame
images: [N, 3, 1008, 1008] video frames
images: [N, 3, H, W] CPU full-res video frames (resized per-frame to self.image_size)
initial_masks: [N_obj, 1, H, W] binary masks for first frame (one per object)
pbar: optional progress bar
@ -1130,7 +1139,8 @@ class SAM3Tracker(nn.Module):
per_object = []
for obj_idx in range(N_obj):
obj_masks = self._track_single_object(
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar)
backbone_fn, images, initial_masks[obj_idx:obj_idx + 1], pbar=pbar,
target_device=target_device, target_dtype=target_dtype)
per_object.append(obj_masks)
return torch.cat(per_object, dim=1) # [N, N_obj, H, W]
@ -1632,11 +1642,18 @@ class SAM31Tracker(nn.Module):
return det_scores[new_dets].tolist() if det_scores is not None else [0.0] * new_dets.sum().item()
return []
INTERNAL_MAX_OBJECTS = 64 # Hard ceiling on accumulated tracks; max_objects=0 or any value above this is clamped here.
def track_video_with_detection(self, backbone_fn, images, initial_masks, detect_fn=None,
new_det_thresh=0.5, max_objects=0, detect_interval=1,
backbone_obj=None, pbar=None):
backbone_obj=None, pbar=None, target_device=None, target_dtype=None):
"""Track with optional per-frame detection. Returns [N, max_N_obj, H, W] mask logits."""
N, device, dt = images.shape[0], images.device, images.dtype
if max_objects <= 0 or max_objects > self.INTERNAL_MAX_OBJECTS:
max_objects = self.INTERNAL_MAX_OBJECTS
N = images.shape[0]
device = target_device if target_device is not None else images.device
dt = target_dtype if target_dtype is not None else images.dtype
size = self.image_size
output_dict = {"cond_frame_outputs": {}, "non_cond_frame_outputs": {}}
all_masks = []
idev = comfy.model_management.intermediate_device()
@ -1656,7 +1673,7 @@ class SAM31Tracker(nn.Module):
prefetch = True
except RuntimeError:
pass
cur_bb = self._compute_backbone_frame(backbone_fn, images[0:1], frame_idx=0)
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(0, 1), device, dt, size), frame_idx=0)
for frame_idx in tqdm(range(N), desc="tracking"):
vision_feats, vision_pos, feat_sizes, high_res_prop, trunk_out = cur_bb
@ -1666,7 +1683,7 @@ class SAM31Tracker(nn.Module):
backbone_stream.wait_stream(torch.cuda.current_stream(device))
with torch.cuda.stream(backbone_stream):
next_bb = self._compute_backbone_frame(
backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
# Per-frame detection with NMS (skip if no detect_fn, or interval/max not met)
det_masks = torch.empty(0, device=device)
@ -1687,7 +1704,7 @@ class SAM31Tracker(nn.Module):
current_out = self._condition_with_masks(
initial_masks.to(device=device, dtype=dt), frame_idx, vision_feats, vision_pos,
feat_sizes, high_res_prop, output_dict, N, mux_state, backbone_obj,
images[frame_idx:frame_idx + 1], trunk_out)
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out)
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
obj_scores = [1.0] * mux_state.total_valid_entries
if keep_alive is not None:
@ -1702,7 +1719,7 @@ class SAM31Tracker(nn.Module):
current_out = self._condition_with_masks(
det_masks, frame_idx, vision_feats, vision_pos, feat_sizes, high_res_prop,
output_dict, N, mux_state, backbone_obj,
images[frame_idx:frame_idx + 1], trunk_out, threshold=0.0)
_prep_frame(images, slice(frame_idx, frame_idx + 1), device, dt, size), trunk_out, threshold=0.0)
last_occluded = torch.full((mux_state.total_valid_entries,), -1, device=device, dtype=torch.long)
obj_scores = det_scores[:mux_state.total_valid_entries].tolist()
if keep_alive is not None:
@ -1718,7 +1735,7 @@ class SAM31Tracker(nn.Module):
torch.cuda.current_stream(device).wait_stream(backbone_stream)
cur_bb = next_bb
else:
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
continue
else:
N_obj = mux_state.total_valid_entries
@ -1768,7 +1785,7 @@ class SAM31Tracker(nn.Module):
torch.cuda.current_stream(device).wait_stream(backbone_stream)
cur_bb = next_bb
else:
cur_bb = self._compute_backbone_frame(backbone_fn, images[frame_idx + 1:frame_idx + 2], frame_idx=frame_idx + 1)
cur_bb = self._compute_backbone_frame(backbone_fn, _prep_frame(images, slice(frame_idx + 1, frame_idx + 2), device, dt, size), frame_idx=frame_idx + 1)
if not all_masks or all(m is None for m in all_masks):
return {"packed_masks": None, "n_frames": N, "scores": []}

View File

@ -562,6 +562,25 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None
running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
@ -749,6 +768,9 @@ class manual_cast(disable_weight_init):
class Conv3d(disable_weight_init.Conv3d):
comfy_cast_weights = True
class BatchNorm2d(disable_weight_init.BatchNorm2d):
comfy_cast_weights = True
class GroupNorm(disable_weight_init.GroupNorm):
comfy_cast_weights = True

View File

@ -17,6 +17,7 @@ if TYPE_CHECKING:
from spandrel import ImageModelDescriptor
from comfy.clip_vision import ClipVisionModel
from comfy.clip_vision import Output as ClipVisionOutput_
from comfy.bg_removal_model import BackgroundRemovalModel
from comfy.controlnet import ControlNet
from comfy.hooks import HookGroup, HookKeyframeGroup
from comfy.model_patcher import ModelPatcher
@ -614,6 +615,11 @@ class Model(ComfyTypeIO):
if TYPE_CHECKING:
Type = ModelPatcher
@comfytype(io_type="BACKGROUND_REMOVAL")
class BackgroundRemoval(ComfyTypeIO):
if TYPE_CHECKING:
Type = BackgroundRemovalModel
@comfytype(io_type="CLIP_VISION")
class ClipVision(ComfyTypeIO):
if TYPE_CHECKING:
@ -2257,6 +2263,7 @@ __all__ = [
"ModelPatch",
"ClipVision",
"ClipVisionOutput",
"BackgroundRemoval",
"AudioEncoder",
"AudioEncoderOutput",
"StyleModel",

View File

@ -1271,7 +1271,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
)
def _seedance2_text_inputs(resolutions: list[str]):
def _seedance2_text_inputs(resolutions: list[str], default_ratio: str = "16:9"):
return [
IO.String.Input(
"prompt",
@ -1287,6 +1287,7 @@ def _seedance2_text_inputs(resolutions: list[str]):
IO.Combo.Input(
"ratio",
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
default=default_ratio,
tooltip="Aspect ratio of the output video.",
),
IO.Int.Input(
@ -1420,8 +1421,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
IO.DynamicCombo.Option(
"Seedance 2.0",
_seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Fast",
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),
@ -1588,9 +1595,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
def _seedance2_reference_inputs(resolutions: list[str]):
def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16:9"):
return [
*_seedance2_text_inputs(resolutions),
*_seedance2_text_inputs(resolutions, default_ratio=default_ratio),
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplateNames(
@ -1668,8 +1675,14 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])),
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])),
IO.DynamicCombo.Option(
"Seedance 2.0",
_seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Fast",
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
),
],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
),

View File

@ -83,13 +83,16 @@ class GeminiImageModel(str, Enum):
async def create_image_parts(
cls: type[IO.ComfyNode],
images: Input.Image,
images: Input.Image | list[Input.Image],
image_limit: int = 0,
) -> list[GeminiPart]:
image_parts: list[GeminiPart] = []
if image_limit < 0:
raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.")
total_images = get_number_of_images(images)
# Accept either a single (possibly-batched) tensor or a list of them; share URL budget across all.
images_list: list[Input.Image] = images if isinstance(images, list) else [images]
total_images = sum(get_number_of_images(img) for img in images_list)
if total_images <= 0:
raise ValueError("No images provided to create_image_parts; at least one image is required.")
@ -98,10 +101,18 @@ async def create_image_parts(
# Number of images we'll send as URLs (fileData)
num_url_images = min(effective_max, 10) # Vertex API max number of image links
upload_kwargs: dict = {"wait_label": "Uploading reference images"}
if effective_max > num_url_images:
# Split path (e.g. 11+ images): suppress per-image counter to avoid a confusing dual-fraction label.
upload_kwargs = {
"wait_label": f"Uploading reference images ({num_url_images}+)",
"show_batch_index": False,
}
reference_images_urls = await upload_images_to_comfyapi(
cls,
images,
images_list,
max_images=num_url_images,
**upload_kwargs,
)
for reference_image_url in reference_images_urls:
image_parts.append(
@ -112,15 +123,22 @@ async def create_image_parts(
)
)
)
for idx in range(num_url_images, effective_max):
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
data=tensor_to_base64_string(images[idx]),
if effective_max > num_url_images:
flat: list[torch.Tensor] = []
for tensor in images_list:
if len(tensor.shape) == 4:
flat.extend(tensor[i] for i in range(tensor.shape[0]))
else:
flat.append(tensor)
for idx in range(num_url_images, effective_max):
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
data=tensor_to_base64_string(flat[idx]),
)
)
)
)
return image_parts
@ -891,10 +909,6 @@ class GeminiNanoBanana2(IO.ComfyNode):
"9:16",
"16:9",
"21:9",
# "1:4",
# "4:1",
# "8:1",
# "1:8",
],
default="auto",
tooltip="If set to 'auto', matches your input image's aspect ratio; "
@ -902,12 +916,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
),
IO.Combo.Input(
"resolution",
options=[
# "512px",
"1K",
"2K",
"4K",
],
options=["1K", "2K", "4K"],
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
),
IO.Combo.Input(
@ -956,6 +965,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
],
is_api_node=True,
price_badge=GEMINI_IMAGE_2_PRICE_BADGE,
is_deprecated=True,
)
@classmethod
@ -1016,6 +1026,197 @@ class GeminiNanoBanana2(IO.ComfyNode):
)
def _nano_banana_2_v2_model_inputs():
return [
IO.Combo.Input(
"aspect_ratio",
options=[
"auto",
"1:1",
"2:3",
"3:2",
"3:4",
"4:3",
"4:5",
"5:4",
"9:16",
"16:9",
"21:9",
"1:4",
"4:1",
"8:1",
"1:8",
],
default="auto",
tooltip="If set to 'auto', matches your input image's aspect ratio; "
"if no image is provided, a 16:9 square is usually generated.",
),
IO.Combo.Input(
"resolution",
options=["1K", "2K", "4K"],
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
),
IO.Combo.Input(
"thinking_level",
options=["MINIMAL", "HIGH"],
),
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, 15)],
min=0,
),
tooltip="Optional reference image(s). Up to 14 images total.",
),
IO.Custom("GEMINI_INPUT_FILES").Input(
"files",
optional=True,
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Generate Content Input Files node.",
),
]
class GeminiNanoBanana2V2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GeminiNanoBanana2V2",
display_name="Nano Banana 2",
category="api node/image/Gemini",
description="Generate or edit images synchronously via Google Vertex API.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Text prompt describing the image to generate or the edits to apply. "
"Include any constraints, styles, or details the model should follow.",
default="",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"Nano Banana 2 (Gemini 3.1 Flash Image)",
_nano_banana_2_v2_model_inputs(),
),
],
),
IO.Int.Input(
"seed",
default=42,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide "
"the same response for repeated requests. Deterministic output isn't guaranteed. "
"Also, changing the model or parameter settings, such as the temperature, "
"can cause variations in the response even when you use the same seed value. "
"By default, a random seed value is used.",
),
IO.Combo.Input(
"response_modalities",
options=["IMAGE", "IMAGE+TEXT"],
advanced=True,
),
IO.String.Input(
"system_prompt",
multiline=True,
default=GEMINI_IMAGE_SYS_PROMPT,
optional=True,
tooltip="Foundational instructions that dictate an AI's behavior.",
advanced=True,
),
],
outputs=[
IO.Image.Output(),
IO.String.Output(),
IO.Image.Output(
display_name="thought_image",
tooltip="First image from the model's thinking process. "
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.resolution"]),
expr="""
(
$r := $lookup(widgets, "model.resolution");
$prices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
response_modalities: str,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
model_choice = model["model"]
if model_choice == "Nano Banana 2 (Gemini 3.1 Flash Image)":
model_id = "gemini-3.1-flash-image-preview"
else:
model_id = model_choice
images = model.get("images") or {}
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
if images:
image_tensors: list[Input.Image] = [t for t in images.values() if t is not None]
if image_tensors:
if sum(get_number_of_images(t) for t in image_tensors) > 14:
raise ValueError("The current maximum number of supported images is 14.")
parts.extend(await create_image_parts(cls, image_tensors))
files = model.get("files")
if files is not None:
parts.extend(files)
image_config = GeminiImageConfig(imageSize=model["resolution"])
if model["aspect_ratio"] != "auto":
image_config.aspectRatio = model["aspect_ratio"]
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
ApiEndpoint(path=f"/proxy/vertexai/gemini/{model_id}", method="POST"),
data=GeminiImageGenerateContentRequest(
contents=[
GeminiContent(role=GeminiRole.user, parts=parts),
],
generationConfig=GeminiImageGenerationConfig(
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
return IO.NodeOutput(
await get_image_from_response(response),
get_text_from_response(response),
await get_image_from_response(response, thought=True),
)
class GeminiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -1024,6 +1225,7 @@ class GeminiExtension(ComfyExtension):
GeminiImage,
GeminiImage2,
GeminiNanoBanana2,
GeminiNanoBanana2V2,
GeminiInputFiles,
]

View File

@ -2787,11 +2787,15 @@ class MotionControl(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["mode"]),
depends_on=IO.PriceBadgeDepends(widgets=["mode", "model"]),
expr="""
(
$prices := {"std": 0.07, "pro": 0.112};
{"type":"usd","usd": $lookup($prices, widgets.mode), "format":{"suffix":"/second"}}
$prices := {
"kling-v3": {"std": 0.126, "pro": 0.168},
"kling-v2-6": {"std": 0.07, "pro": 0.112}
};
$modelPrices := $lookup($prices, widgets.model);
{"type":"usd","usd": $lookup($modelPrices, widgets.mode), "format":{"suffix":"/second"}}
)
""",
),

View File

@ -488,10 +488,30 @@ async def _diagnose_connectivity() -> dict[str, bool]:
"api_accessible": False,
}
timeout = aiohttp.ClientTimeout(total=5.0)
# Probe Google and Baidu in parallel: Google is blocked by the GFW in mainland China, so a Baidu probe is required
# to correctly detect that Chinese users with working internet do have working internet.
internet_probe_urls = ("https://www.google.com", "https://www.baidu.com")
async with aiohttp.ClientSession(timeout=timeout) as session:
with contextlib.suppress(ClientError, OSError):
async with session.get("https://www.google.com") as resp:
results["internet_accessible"] = resp.status < 500
async def _probe(url: str) -> bool:
try:
async with session.get(url) as resp:
return resp.status < 500
except (ClientError, OSError, asyncio.TimeoutError):
return False
probe_tasks = [asyncio.create_task(_probe(u)) for u in internet_probe_urls]
try:
for fut in asyncio.as_completed(probe_tasks):
if await fut:
results["internet_accessible"] = True
break
finally:
for t in probe_tasks:
if not t.done():
t.cancel()
await asyncio.gather(*probe_tasks, return_exceptions=True)
if not results["internet_accessible"]:
return results

View File

@ -92,7 +92,7 @@ class SamplerEulerCFGpp(io.ComfyNode):
return io.Schema(
node_id="SamplerEulerCFGpp",
display_name="SamplerEulerCFG++",
category="_for_testing", # "sampling/custom_sampling/samplers"
category="experimental", # "sampling/custom_sampling/samplers"
inputs=[
io.Combo.Input("version", options=["regular", "alternative"], advanced=True),
],

View File

@ -2,6 +2,7 @@
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
- ARVideoI2V: image-to-video conditioning for AR models (seeds KV cache with start image)
"""
import torch
@ -9,6 +10,7 @@ from typing_extensions import override
import comfy.model_management
import comfy.samplers
import comfy.utils
from comfy_api.latest import ComfyExtension, io
@ -71,12 +73,62 @@ class SamplerARVideo(io.ComfyNode):
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
class ARVideoI2V(io.ComfyNode):
"""Image-to-video setup for AR video models (Causal Forcing, Self-Forcing).
VAE-encodes the start image and stores it in the model's transformer_options
so that sample_ar_video can seed the KV cache before denoising.
Uses the same T2V model checkpoint -- no separate I2V architecture needed.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ARVideoI2V",
category="conditioning/video_models",
inputs=[
io.Model.Input("model"),
io.Vae.Input("vae"),
io.Image.Input("start_image"),
io.Int.Input("width", default=832, min=16, max=8192, step=16),
io.Int.Input("height", default=480, min=16, max=8192, step=16),
io.Int.Input("length", default=81, min=1, max=1024, step=4),
io.Int.Input("batch_size", default=1, min=1, max=64),
],
outputs=[
io.Model.Output(display_name="MODEL"),
io.Latent.Output(display_name="LATENT"),
],
)
@classmethod
def execute(cls, model, vae, start_image, width, height, length, batch_size) -> io.NodeOutput:
start_image = comfy.utils.common_upscale(
start_image[:1].movedim(-1, 1), width, height, "bilinear", "center"
).movedim(1, -1)
initial_latent = vae.encode(start_image[:, :, :, :3])
m = model.clone()
to = m.model_options.setdefault("transformer_options", {})
ar_cfg = to.setdefault("ar_config", {})
ar_cfg["initial_latent"] = initial_latent
lat_t = ((length - 1) // 4) + 1
latent = torch.zeros(
[batch_size, 16, lat_t, height // 8, width // 8],
device=comfy.model_management.intermediate_device(),
)
return io.NodeOutput(m, {"samples": latent})
class ARVideoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyARVideoLatent,
SamplerARVideo,
ARVideoI2V,
]

View File

@ -25,7 +25,7 @@ class UNetSelfAttentionMultiply(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="UNetSelfAttentionMultiply",
category="_for_testing/attention_experiments",
category="experimental/attention_experiments",
inputs=[
io.Model.Input("model"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
@ -48,7 +48,7 @@ class UNetCrossAttentionMultiply(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="UNetCrossAttentionMultiply",
category="_for_testing/attention_experiments",
category="experimental/attention_experiments",
inputs=[
io.Model.Input("model"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
@ -72,7 +72,7 @@ class CLIPAttentionMultiply(io.ComfyNode):
return io.Schema(
node_id="CLIPAttentionMultiply",
search_aliases=["clip attention scale", "text encoder attention"],
category="_for_testing/attention_experiments",
category="experimental/attention_experiments",
inputs=[
io.Clip.Input("clip"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
@ -106,7 +106,7 @@ class UNetTemporalAttentionMultiply(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="UNetTemporalAttentionMultiply",
category="_for_testing/attention_experiments",
category="experimental/attention_experiments",
inputs=[
io.Model.Input("model"),
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),

View File

@ -10,6 +10,7 @@ class AudioEncoderLoader(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AudioEncoderLoader",
display_name="Load Audio Encoder",
category="loaders",
inputs=[
io.Combo.Input(

View File

@ -0,0 +1,60 @@
import folder_paths
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy.bg_removal_model import load
class LoadBackgroundRemovalModel(IO.ComfyNode):
@classmethod
def define_schema(cls):
files = folder_paths.get_filename_list("background_removal")
return IO.Schema(
node_id="LoadBackgroundRemovalModel",
display_name="Load Background Removal Model",
category="loaders",
inputs=[
IO.Combo.Input("bg_removal_name", options=sorted(files), tooltip="The model used to remove backgrounds from images"),
],
outputs=[
IO.BackgroundRemoval.Output("bg_model")
]
)
@classmethod
def execute(cls, bg_removal_name):
path = folder_paths.get_full_path_or_raise("background_removal", bg_removal_name)
bg = load(path)
if bg is None:
raise RuntimeError("ERROR: background model file is invalid and does not contain a valid background removal model.")
return IO.NodeOutput(bg)
class RemoveBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="RemoveBackground",
display_name="Remove Background",
category="image/background removal",
inputs=[
IO.Image.Input("image", tooltip="Input image to remove the background from"),
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask")
],
outputs=[
IO.Mask.Output("mask", tooltip="Generated foreground mask")
]
)
@classmethod
def execute(cls, image, bg_removal_model):
mask = bg_removal_model.encode_image(image)
return IO.NodeOutput(mask)
class BackgroundRemovalExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
LoadBackgroundRemovalModel,
RemoveBackground
]
async def comfy_entrypoint() -> BackgroundRemovalExtension:
return BackgroundRemovalExtension()

View File

@ -153,7 +153,7 @@ class WanCameraEmbedding(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="WanCameraEmbedding",
category="camera",
category="conditioning/video_models",
inputs=[
io.Combo.Input(
"camera_pose",

View File

@ -203,7 +203,7 @@ class JoinImageWithAlpha(io.ComfyNode):
@classmethod
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = max(len(image), len(alpha))
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
alpha = 1.0 - resize_mask(alpha.to(image), image.shape[1:])
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
image = comfy.utils.repeat_to_batch_size(image, batch_size)
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))

View File

@ -8,7 +8,7 @@ class CLIPTextEncodeControlnet(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CLIPTextEncodeControlnet",
category="_for_testing/conditioning",
category="experimental/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Conditioning.Input("conditioning"),
@ -35,7 +35,7 @@ class T5TokenizerOptions(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="T5TokenizerOptions",
category="_for_testing/conditioning",
category="experimental/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),

View File

@ -10,7 +10,7 @@ class ContextWindowsManualNode(io.ComfyNode):
return io.Schema(
node_id="ContextWindowsManual",
display_name="Context Windows (Manual)",
category="context",
category="model_patches",
description="Manually set context windows.",
inputs=[
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),

View File

@ -984,7 +984,7 @@ class AddNoise(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="AddNoise",
category="_for_testing/custom_sampling/noise",
category="experimental/custom_sampling/noise",
is_experimental=True,
inputs=[
io.Model.Input("model"),
@ -1034,7 +1034,7 @@ class ManualSigmas(io.ComfyNode):
return io.Schema(
node_id="ManualSigmas",
search_aliases=["custom noise schedule", "define sigmas"],
category="_for_testing/custom_sampling",
category="experimental/custom_sampling",
is_experimental=True,
inputs=[
io.String.Input("sigmas", default="1, 0.5", multiline=False)

View File

@ -13,7 +13,7 @@ class DifferentialDiffusion(io.ComfyNode):
node_id="DifferentialDiffusion",
search_aliases=["inpaint gradient", "variable denoise strength"],
display_name="Differential Diffusion",
category="_for_testing",
category="experimental",
inputs=[
io.Model.Input("model"),
io.Float.Input(

View File

@ -102,7 +102,7 @@ class FluxDisableGuidance(io.ComfyNode):
append = execute # TODO: remove
PREFERED_KONTEXT_RESOLUTIONS = [
PREFERRED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
@ -143,7 +143,7 @@ class FluxKontextImageScale(io.ComfyNode):
width = image.shape[2]
height = image.shape[1]
aspect_ratio = width / height
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
return io.NodeOutput(image)

View File

@ -60,7 +60,7 @@ class FreSca(io.ComfyNode):
node_id="FreSca",
search_aliases=["frequency guidance"],
display_name="FreSca",
category="_for_testing",
category="experimental",
description="Applies frequency-dependent scaling to the guidance",
inputs=[
io.Model.Input("model"),

View File

@ -131,6 +131,8 @@ class HunyuanVideo15SuperResolution(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="HunyuanVideo15SuperResolution",
display_name="Hunyuan Video 1.5 Super Resolution",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
@ -381,6 +383,8 @@ class HunyuanRefinerLatent(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="HunyuanRefinerLatent",
display_name="Hunyuan Latent Refiner",
category="conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),

View File

@ -40,7 +40,7 @@ class Hunyuan3Dv2Conditioning(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="Hunyuan3Dv2Conditioning",
category="conditioning/video_models",
category="conditioning/3d_models",
inputs=[
IO.ClipVisionOutput.Input("clip_vision_output"),
],
@ -65,7 +65,7 @@ class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="Hunyuan3Dv2ConditioningMultiView",
category="conditioning/video_models",
category="conditioning/3d_models",
inputs=[
IO.ClipVisionOutput.Input("front", optional=True),
IO.ClipVisionOutput.Input("left", optional=True),
@ -424,6 +424,7 @@ class VoxelToMeshBasic(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="VoxelToMeshBasic",
display_name="Voxel to Mesh (Basic)",
category="3d",
inputs=[
IO.Voxel.Input("voxel"),
@ -453,6 +454,7 @@ class VoxelToMesh(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="VoxelToMesh",
display_name="Voxel to Mesh",
category="3d",
inputs=[
IO.Voxel.Input("voxel"),

View File

@ -102,6 +102,7 @@ class HypernetworkLoader(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="HypernetworkLoader",
display_name="Load Hypernetwork",
category="loaders",
inputs=[
IO.Model.Input("model"),

View File

@ -91,7 +91,7 @@ class LoraSave(io.ComfyNode):
node_id="LoraSave",
search_aliases=["export lora"],
display_name="Extract and Save Lora",
category="_for_testing",
category="experimental",
inputs=[
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True),

View File

@ -106,12 +106,12 @@ class LTXVImgToVideoInplace(io.ComfyNode):
if bypass:
return (latent,)
samples = latent["samples"]
samples = latent["samples"].clone()
_, height_scale_factor, width_scale_factor = (
vae.downscale_index_formula
)
batch, _, latent_frames, latent_height, latent_width = samples.shape
_, _, _, latent_height, latent_width = samples.shape
width = latent_width * width_scale_factor
height = latent_height * height_scale_factor
@ -124,11 +124,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
samples[:, :, :t.shape[2]] = t
conditioning_latent_frames_mask = torch.ones(
(batch, 1, latent_frames, 1, 1),
dtype=torch.float32,
device=samples.device,
)
conditioning_latent_frames_mask = get_noise_mask(latent)
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
@ -236,7 +232,7 @@ class LTXVAddGuide(io.ComfyNode):
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
return encode_pixels, t
@ -594,7 +590,8 @@ class LTXVPreprocess(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LTXVPreprocess",
category="image",
display_name="LTXV Preprocess",
category="video/preprocessors",
inputs=[
io.Image.Input("image"),
io.Int.Input(

View File

@ -11,7 +11,7 @@ class Mahiro(io.ComfyNode):
return io.Schema(
node_id="Mahiro",
display_name="Positive-Biased Guidance",
category="_for_testing",
category="experimental",
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
inputs=[
io.Model.Input("model"),

View File

@ -40,10 +40,21 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
inverse_mask = torch.ones_like(mask) - mask
source_portion = mask * source[..., :visible_height, :visible_width]
destination_portion = inverse_mask * destination[..., top:bottom, left:right]
source_rgb = source[:, :3, :visible_height, :visible_width]
dest_slice = destination[..., top:bottom, left:right]
if destination.shape[1] == 4:
if torch.max(dest_slice) == 0:
destination[:, :3, top:bottom, left:right] = source_rgb
destination[:, 3:4, top:bottom, left:right] = mask
else:
destination[:, :3, top:bottom, left:right] = (mask * source_rgb) + (inverse_mask * dest_slice[:, :3])
destination[:, 3:4, top:bottom, left:right] = torch.max(mask, dest_slice[:, 3:4])
else:
source_portion = mask * source_rgb
destination_portion = inverse_mask * dest_slice
destination[..., top:bottom, left:right] = source_portion + destination_portion
destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination
class LatentCompositeMasked(IO.ComfyNode):
@ -84,18 +95,23 @@ class ImageCompositeMasked(IO.ComfyNode):
display_name="Image Composite Masked",
category="image",
inputs=[
IO.Image.Input("destination"),
IO.Image.Input("source"),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Boolean.Input("resize_source", default=False),
IO.Image.Input("destination", optional=True),
IO.Mask.Input("mask", optional=True),
],
outputs=[IO.Image.Output()],
)
@classmethod
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
def execute(cls, source, x, y, resize_source, destination = None, mask = None) -> IO.NodeOutput:
if destination is None: # transparent rgba
B, H, W, C = source.shape
destination = torch.zeros((B, H, W, 4), dtype=source.dtype, device=source.device)
if C == 3:
source = torch.nn.functional.pad(source, (0, 1), value=1.0)
destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
@ -381,7 +397,6 @@ class GrowMask(IO.ComfyNode):
expand_mask = execute # TODO: remove
class ThresholdMask(IO.ComfyNode):
@classmethod
def define_schema(cls):

View File

@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode):
return io.Schema(
node_id="ComfyMathExpression",
display_name="Math Expression",
category="math",
category="logic",
search_aliases=[
"expression", "formula", "calculate", "calculator",
"eval", "math",

View File

@ -21,7 +21,7 @@ class NumberConvertNode(io.ComfyNode):
return io.Schema(
node_id="ComfyNumberConvert",
display_name="Number Convert",
category="math",
category="utils",
search_aliases=[
"int to float", "float to int", "number convert",
"int2float", "float2int", "cast", "parse number",

View File

@ -24,8 +24,8 @@ class PerpNeg(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PerpNeg",
display_name="Perp-Neg (DEPRECATED by PerpNegGuider)",
category="_for_testing",
display_name="Perp-Neg (DEPRECATED by Perp-Neg Guider)",
category="experimental",
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("empty_conditioning"),
@ -127,7 +127,8 @@ class PerpNegGuider(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PerpNegGuider",
category="_for_testing",
display_name="Perp-Neg Guider",
category="experimental",
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("positive"),

View File

@ -123,7 +123,7 @@ class PhotoMakerLoader(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PhotoMakerLoader",
category="_for_testing/photomaker",
category="experimental/photomaker",
inputs=[
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
],
@ -149,7 +149,7 @@ class PhotoMakerEncode(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PhotoMakerEncode",
category="_for_testing/photomaker",
category="experimental/photomaker",
inputs=[
io.Photomaker.Input("photomaker"),
io.Image.Input("image"),

View File

@ -116,6 +116,7 @@ class Quantize(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ImageQuantize",
display_name="Quantize Image",
category="image/postprocessing",
inputs=[
io.Image.Input("image"),
@ -181,6 +182,7 @@ class Sharpen(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ImageSharpen",
display_name="Sharpen Image",
category="image/postprocessing",
inputs=[
io.Image.Input("image"),
@ -436,7 +438,7 @@ class ResizeImageMaskNode(io.ComfyNode):
node_id="ResizeImageMaskNode",
display_name="Resize Image/Mask",
description="Resize an image or mask using various scaling methods.",
category="transform",
category="image/transform",
search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"],
inputs=[
io.MatchType.Input("input", template=template),

View File

@ -15,7 +15,7 @@ class RTDETR_detect(io.ComfyNode):
return io.Schema(
node_id="RTDETR_detect",
display_name="RT-DETR Detect",
category="detection/",
category="detection",
search_aliases=["bbox", "bounding box", "object detection", "coco"],
inputs=[
io.Model.Input("model", display_name="model"),
@ -71,7 +71,7 @@ class DrawBBoxes(io.ComfyNode):
return io.Schema(
node_id="DrawBBoxes",
display_name="Draw BBoxes",
category="detection/",
category="detection",
search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"],
inputs=[
io.Image.Input("image", optional=True),

View File

@ -113,7 +113,7 @@ class SelfAttentionGuidance(io.ComfyNode):
return io.Schema(
node_id="SelfAttentionGuidance",
display_name="Self-Attention Guidance",
category="_for_testing",
category="experimental",
inputs=[
io.Model.Input("model"),
io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),

View File

@ -93,7 +93,7 @@ class SAM3_Detect(io.ComfyNode):
return io.Schema(
node_id="SAM3_Detect",
display_name="SAM3 Detect",
category="detection/",
category="detection",
search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"],
inputs=[
io.Model.Input("model", display_name="model"),
@ -265,15 +265,15 @@ class SAM3_VideoTrack(io.ComfyNode):
return io.Schema(
node_id="SAM3_VideoTrack",
display_name="SAM3 Video Track",
category="detection/",
category="detection",
search_aliases=["sam3", "video", "track", "propagate"],
inputs=[
io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"),
io.Model.Input("model", display_name="model"),
io.Mask.Input("initial_mask", display_name="initial_mask", optional=True, tooltip="Mask(s) for the first frame to track (one per object)"),
io.Conditioning.Input("conditioning", display_name="conditioning", optional=True, tooltip="Text conditioning for detecting new objects during tracking"),
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection"),
io.Int.Input("max_objects", display_name="max_objects", default=0, min=0, tooltip="Max tracked objects (0=unlimited). Initial masks count toward this limit."),
io.Float.Input("detection_threshold", display_name="detection_threshold", default=0.5, min=0.0, max=1.0, step=0.01, tooltip="Score threshold for text-prompted detection."),
io.Int.Input("max_objects", display_name="max_objects", default=4, min=0, max=64, tooltip="Max tracked objects. Initial masks count toward this limit. 0 uses the internal cap of 64."),
io.Int.Input("detect_interval", display_name="detect_interval", default=1, min=1, tooltip="Run detection every N frames (1=every frame). Higher values save compute."),
],
outputs=[
@ -290,8 +290,7 @@ class SAM3_VideoTrack(io.ComfyNode):
dtype = model.model.get_dtype()
sam3_model = model.model.diffusion_model
frames = images[..., :3].movedim(-1, 1)
frames_in = comfy.utils.common_upscale(frames, 1008, 1008, "bilinear", crop="disabled").to(device=device, dtype=dtype)
frames_in = images[..., :3].movedim(-1, 1)
init_masks = None
if initial_mask is not None:
@ -308,7 +307,7 @@ class SAM3_VideoTrack(io.ComfyNode):
result = sam3_model.forward_video(
images=frames_in, initial_masks=init_masks, pbar=pbar, text_prompts=text_prompts,
new_det_thresh=detection_threshold, max_objects=max_objects,
detect_interval=detect_interval)
detect_interval=detect_interval, target_device=device, target_dtype=dtype)
result["orig_size"] = (H, W)
return io.NodeOutput(result)
@ -321,7 +320,7 @@ class SAM3_TrackPreview(io.ComfyNode):
return io.Schema(
node_id="SAM3_TrackPreview",
display_name="SAM3 Track Preview",
category="detection/",
category="detection",
inputs=[
SAM3TrackData.Input("track_data", display_name="track_data"),
io.Image.Input("images", display_name="images", optional=True),
@ -449,14 +448,18 @@ class SAM3_TrackPreview(io.ComfyNode):
cx = (bool_masks * grid_x).sum(dim=(-1, -2)) // area
has = area > 1
scores = track_data.get("scores", [])
label_scale = max(3, H // 240) # Scale font with resolutio
size_caps = (area.float().sqrt() / 15).clamp_(min=1).long().tolist() #cap per-object so the number doesn't dwarf small masks
for obj_idx in range(N_obj):
if has[obj_idx]:
_cx, _cy = int(cx[obj_idx]), int(cy[obj_idx])
color = cls.COLORS[obj_idx % len(cls.COLORS)]
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color)
obj_scale = min(label_scale, size_caps[obj_idx])
score_scale = max(1, obj_scale * 2 // 3)
SAM3_TrackPreview._draw_number_gpu(frame_gpu, obj_idx, _cx, _cy, color, scale=obj_scale)
if obj_idx < len(scores) and scores[obj_idx] < 1.0:
SAM3_TrackPreview._draw_number_gpu(frame_gpu, int(scores[obj_idx] * 100),
_cx, _cy + 5 * 3 + 3, color, scale=2)
_cx, _cy + 5 * obj_scale + 3, color, scale=score_scale)
frame_cpu.copy_(frame_gpu.clamp_(0, 1).mul_(255).byte())
else:
frame_cpu.copy_(frame.clamp_(0, 1).mul_(255).byte())
@ -475,7 +478,7 @@ class SAM3_TrackToMask(io.ComfyNode):
return io.Schema(
node_id="SAM3_TrackToMask",
display_name="SAM3 Track to Mask",
category="detection/",
category="detection",
inputs=[
SAM3TrackData.Input("track_data", display_name="track_data"),
io.String.Input("object_indices", display_name="object_indices", default="",
@ -507,9 +510,10 @@ class SAM3_TrackToMask(io.ComfyNode):
if not indices:
return io.NodeOutput(torch.zeros(N, H, W, device=comfy.model_management.intermediate_device()))
selected = packed[:, indices]
binary = unpack_masks(selected) # [N, len(indices), Hm, Wm] bool
union = binary.any(dim=1, keepdim=True).float()
union_packed = packed[:, indices[0]].clone()
for i in indices[1:]:
union_packed |= packed[:, i]
union = unpack_masks(union_packed).unsqueeze(1).float() # [N, 1, Hm, Wm]
mask_out = F.interpolate(union, size=(H, W), mode="bilinear", align_corners=False)[:, 0]
return io.NodeOutput(mask_out)

View File

@ -119,7 +119,7 @@ class StableCascade_SuperResolutionControlnet(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="StableCascade_SuperResolutionControlnet",
category="_for_testing/stable_cascade",
category="experimental/stable_cascade",
is_experimental=True,
inputs=[
io.Image.Input("image"),

View File

@ -26,7 +26,8 @@ class TextGenerate(io.ComfyNode):
return io.Schema(
node_id="TextGenerate",
category="textgen",
display_name="Generate Text",
category="text",
search_aliases=["LLM", "gemma"],
inputs=[
io.Clip.Input("clip"),
@ -157,6 +158,7 @@ class TextGenerateLTX2Prompt(TextGenerate):
parent_schema = super().define_schema()
return io.Schema(
node_id="TextGenerateLTX2Prompt",
display_name="Generate LTX2 Prompt",
category=parent_schema.category,
inputs=parent_schema.inputs,
outputs=parent_schema.outputs,

View File

@ -10,7 +10,7 @@ class TorchCompileModel(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="TorchCompileModel",
category="_for_testing",
category="experimental",
inputs=[
io.Model.Input("model"),
io.Combo.Input(

View File

@ -1361,7 +1361,7 @@ class SaveLoRA(io.ComfyNode):
node_id="SaveLoRA",
search_aliases=["export lora"],
display_name="Save LoRA Weights",
category="loaders",
category="advanced/model_merging",
is_experimental=True,
is_output_node=True,
inputs=[

View File

@ -15,7 +15,7 @@ class ImageOnlyCheckpointLoader:
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"
CATEGORY = "loaders/video_models"
CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)

View File

@ -22,7 +22,7 @@ class SaveImageWebsocket:
OUTPUT_NODE = True
CATEGORY = "api/image"
CATEGORY = "image"
def save_images(self, images):
pbar = comfy.utils.ProgressBar(images.shape[0])
@ -42,3 +42,7 @@ class SaveImageWebsocket:
NODE_CLASS_MAPPINGS = {
"SaveImageWebsocket": SaveImageWebsocket,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"SaveImageWebsocket": "Save Image (Websocket)",
}

View File

@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
folder_names_and_paths["background_removal"] = ([os.path.join(models_dir, "background_removal")], supported_pt_extensions)
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)

View File

@ -330,7 +330,7 @@ class VAEDecodeTiled:
RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode"
CATEGORY = "_for_testing"
CATEGORY = "experimental"
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
if tile_size < overlap * 4:
@ -377,7 +377,7 @@ class VAEEncodeTiled:
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
CATEGORY = "_for_testing"
CATEGORY = "experimental"
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
@ -493,7 +493,7 @@ class SaveLatent:
OUTPUT_NODE = True
CATEGORY = "_for_testing"
CATEGORY = "experimental"
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
@ -538,7 +538,7 @@ class LoadLatent:
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
return {"required": {"latent": [sorted(files), ]}, }
CATEGORY = "_for_testing"
CATEGORY = "experimental"
RETURN_TYPES = ("LATENT", )
FUNCTION = "load"
@ -1443,7 +1443,7 @@ class LatentBlend:
RETURN_TYPES = ("LATENT",)
FUNCTION = "blend"
CATEGORY = "_for_testing"
CATEGORY = "experimental"
def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
@ -2092,6 +2092,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"StyleModelLoader": "Load Style Model",
"CLIPVisionLoader": "Load CLIP Vision",
"UNETLoader": "Load Diffusion Model",
"unCLIPCheckpointLoader": "Load unCLIP Checkpoint",
"GLIGENLoader": "Load GLIGEN Model",
# Conditioning
"CLIPVisionEncode": "CLIP Vision Encode",
"StyleModelApply": "Apply Style Model",
@ -2140,7 +2142,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImageSharpen": "Sharpen Image",
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
"GetImageSize": "Get Image Size",
# _for_testing
# experimental
"VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)",
}
@ -2427,6 +2429,7 @@ async def init_builtin_extra_nodes():
"nodes_number_convert.py",
"nodes_painter.py",
"nodes_curve.py",
"nodes_bg_removal.py",
"nodes_rtdetr.py",
"nodes_frame_interpolation.py",
"nodes_sam3.py",

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,90 @@
"""Tests for NodeReplaceManager registration behavior."""
import importlib
import sys
import types
import pytest
@pytest.fixture
def NodeReplaceManager(monkeypatch):
"""Provide NodeReplaceManager with `nodes` stubbed.
`app.node_replace_manager` does `import nodes` at module level, which pulls in
torch + the full ComfyUI graph. register() doesn't actually need it, so we
stub `nodes` per-test (via monkeypatch so it's torn down) and reload the
module so it picks up the stub instead of any cached real import.
"""
fake_nodes = types.ModuleType("nodes")
fake_nodes.NODE_CLASS_MAPPINGS = {}
monkeypatch.setitem(sys.modules, "nodes", fake_nodes)
monkeypatch.delitem(sys.modules, "app.node_replace_manager", raising=False)
module = importlib.import_module("app.node_replace_manager")
yield module.NodeReplaceManager
# Drop the freshly-imported module so the next test (or a later real import
# of `nodes`) starts from a clean slate.
sys.modules.pop("app.node_replace_manager", None)
class FakeNodeReplace:
"""Lightweight stand-in for comfy_api.latest._io.NodeReplace."""
def __init__(self, new_node_id, old_node_id, old_widget_ids=None,
input_mapping=None, output_mapping=None):
self.new_node_id = new_node_id
self.old_node_id = old_node_id
self.old_widget_ids = old_widget_ids
self.input_mapping = input_mapping
self.output_mapping = output_mapping
def test_register_adds_replacement(NodeReplaceManager):
manager = NodeReplaceManager()
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
assert manager.has_replacement("OldNode")
assert len(manager.get_replacement("OldNode")) == 1
def test_register_allows_multiple_alternatives_for_same_old_node(NodeReplaceManager):
"""Different new_node_ids for the same old_node_id should all be kept."""
manager = NodeReplaceManager()
manager.register(FakeNodeReplace(new_node_id="AltA", old_node_id="OldNode"))
manager.register(FakeNodeReplace(new_node_id="AltB", old_node_id="OldNode"))
replacements = manager.get_replacement("OldNode")
assert len(replacements) == 2
assert {r.new_node_id for r in replacements} == {"AltA", "AltB"}
def test_register_is_idempotent_for_duplicate_pair(NodeReplaceManager):
"""Re-registering the same (old_node_id, new_node_id) should be a no-op."""
manager = NodeReplaceManager()
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
assert len(manager.get_replacement("OldNode")) == 1
def test_register_idempotent_preserves_first_registration(NodeReplaceManager):
"""First registration wins; later duplicates with different mappings are ignored."""
manager = NodeReplaceManager()
first = FakeNodeReplace(
new_node_id="NewNode", old_node_id="OldNode",
input_mapping=[{"new_id": "a", "old_id": "x"}],
)
second = FakeNodeReplace(
new_node_id="NewNode", old_node_id="OldNode",
input_mapping=[{"new_id": "b", "old_id": "y"}],
)
manager.register(first)
manager.register(second)
replacements = manager.get_replacement("OldNode")
assert len(replacements) == 1
assert replacements[0] is first
def test_register_dedupe_does_not_affect_other_old_nodes(NodeReplaceManager):
manager = NodeReplaceManager()
manager.register(FakeNodeReplace(new_node_id="NewA", old_node_id="OldA"))
manager.register(FakeNodeReplace(new_node_id="NewA", old_node_id="OldA"))
manager.register(FakeNodeReplace(new_node_id="NewB", old_node_id="OldB"))
assert len(manager.get_replacement("OldA")) == 1
assert len(manager.get_replacement("OldB")) == 1

View File

@ -21,7 +21,7 @@ class TestAsyncProgressUpdate(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "execute"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
async def execute(self, value, sleep_seconds):
start = time.time()
@ -51,7 +51,7 @@ class TestSyncProgressUpdate(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "execute"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
def execute(self, value, sleep_seconds):
start = time.time()

View File

@ -21,7 +21,7 @@ class TestAsyncValidation(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
@classmethod
async def VALIDATE_INPUTS(cls, value, threshold):
@ -53,7 +53,7 @@ class TestAsyncError(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "error_execution"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
async def error_execution(self, value, error_after):
await asyncio.sleep(error_after)
@ -74,7 +74,7 @@ class TestAsyncValidationError(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
@classmethod
async def VALIDATE_INPUTS(cls, value, max_value):
@ -105,7 +105,7 @@ class TestAsyncTimeout(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "timeout_execution"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
async def timeout_execution(self, value, timeout, operation_time):
try:
@ -129,7 +129,7 @@ class TestSyncError(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "sync_error"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
def sync_error(self, value):
raise RuntimeError("Intentional sync execution error for testing")
@ -150,7 +150,7 @@ class TestAsyncLazyCheck(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
async def check_lazy_status(self, condition, input1, input2):
# Simulate async checking (e.g., querying remote service)
@ -184,7 +184,7 @@ class TestDynamicAsyncGeneration(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",)
FUNCTION = "generate_async_workflow"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
g = GraphBuilder()
@ -229,7 +229,7 @@ class TestAsyncResourceUser(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "use_resource"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
async def use_resource(self, value, resource_id, duration):
# Check if resource is already in use
@ -265,7 +265,7 @@ class TestAsyncBatchProcessing(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process_batch"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
async def process_batch(self, images, process_time_per_item, unique_id):
batch_size = images.shape[0]
@ -305,7 +305,7 @@ class TestAsyncConcurrentLimit(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "limited_execution"
CATEGORY = "_for_testing/async"
CATEGORY = "experimental/async"
async def limited_execution(self, value, duration, node_id):
async with self._semaphore:

View File

@ -409,7 +409,7 @@ class TestSleep(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,)
FUNCTION = "sleep"
CATEGORY = "_for_testing"
CATEGORY = "experimental"
async def sleep(self, value, seconds, unique_id):
pbar = ProgressBar(seconds, node_id=unique_id)
@ -440,7 +440,7 @@ class TestParallelSleep(ComfyNodeABC):
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "parallel_sleep"
CATEGORY = "_for_testing"
CATEGORY = "experimental"
OUTPUT_NODE = True
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
@ -474,7 +474,7 @@ class TestOutputNodeWithSocketOutput:
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "process"
CATEGORY = "_for_testing"
CATEGORY = "experimental"
OUTPUT_NODE = True
def process(self, image, value):