mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-10 01:02:56 +08:00
Merge branch 'master' into blueprints/subgraph-description
This commit is contained in:
commit
3e9654a0e6
@ -1,5 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, TypedDict
|
from typing import TYPE_CHECKING, TypedDict
|
||||||
@ -31,8 +33,22 @@ class NodeReplaceManager:
|
|||||||
self._replacements: dict[str, list[NodeReplace]] = {}
|
self._replacements: dict[str, list[NodeReplace]] = {}
|
||||||
|
|
||||||
def register(self, node_replace: NodeReplace):
|
def register(self, node_replace: NodeReplace):
|
||||||
"""Register a node replacement mapping."""
|
"""Register a node replacement mapping.
|
||||||
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
|
|
||||||
|
Idempotent: if a replacement with the same (old_node_id, new_node_id)
|
||||||
|
is already registered, the duplicate is ignored. This prevents stale
|
||||||
|
entries from accumulating when custom nodes are reloaded in the same
|
||||||
|
process (e.g. via ComfyUI-Manager).
|
||||||
|
"""
|
||||||
|
existing = self._replacements.setdefault(node_replace.old_node_id, [])
|
||||||
|
for entry in existing:
|
||||||
|
if entry.new_node_id == node_replace.new_node_id:
|
||||||
|
logging.debug(
|
||||||
|
"Node replacement %s -> %s already registered, ignoring duplicate.",
|
||||||
|
node_replace.old_node_id, node_replace.new_node_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
existing.append(node_replace)
|
||||||
|
|
||||||
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
|
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
|
||||||
"""Get replacements for an old node ID."""
|
"""Get replacements for an old node ID."""
|
||||||
|
|||||||
7
comfy/background_removal/birefnet.json
Normal file
7
comfy/background_removal/birefnet.json
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"model_type": "birefnet",
|
||||||
|
"image_std": [1.0, 1.0, 1.0],
|
||||||
|
"image_mean": [0.0, 0.0, 0.0],
|
||||||
|
"image_size": 1024,
|
||||||
|
"resize_to_original": true
|
||||||
|
}
|
||||||
689
comfy/background_removal/birefnet.py
Normal file
689
comfy/background_removal/birefnet.py
Normal file
@ -0,0 +1,689 @@
|
|||||||
|
import torch
|
||||||
|
import comfy.ops
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
from functools import partial
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torchvision.ops import deform_conv2d
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||||
|
|
||||||
|
CXT = [3072, 1536, 768, 384][1:][::-1][-3:]
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, device=None, dtype=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
self.scale = qk_scale or head_dim ** -0.5
|
||||||
|
|
||||||
|
self.q = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
|
||||||
|
self.kv = operations.Linear(dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||||
|
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, N, C = x.shape
|
||||||
|
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
|
||||||
|
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
||||||
|
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||||
|
k, v = kv[0], kv[1]
|
||||||
|
|
||||||
|
x = optimized_attention(
|
||||||
|
q, k, v, heads=self.num_heads, skip_output_reshape=True, skip_reshape=True
|
||||||
|
).transpose(1, 2).reshape(B, N, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class Mlp(nn.Module):
|
||||||
|
def __init__(self, in_features, hidden_features=None, out_features=None, device=None, dtype=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
out_features = out_features or in_features
|
||||||
|
hidden_features = hidden_features or in_features
|
||||||
|
self.fc1 = operations.Linear(in_features, hidden_features, device=device, dtype=dtype)
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.fc2 = operations.Linear(hidden_features, out_features, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.fc2(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def window_partition(x, window_size):
|
||||||
|
B, H, W, C = x.shape
|
||||||
|
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
|
||||||
|
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
|
||||||
|
return windows
|
||||||
|
|
||||||
|
|
||||||
|
def window_reverse(windows, window_size, H, W):
|
||||||
|
B = int(windows.shape[0] / (H * W / window_size / window_size))
|
||||||
|
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
|
||||||
|
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class WindowAttention(nn.Module):
|
||||||
|
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, device=None, dtype=None, operations=None):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.window_size = window_size # Wh, Ww
|
||||||
|
self.num_heads = num_heads
|
||||||
|
head_dim = dim // num_heads
|
||||||
|
self.scale = qk_scale or head_dim ** -0.5
|
||||||
|
|
||||||
|
self.relative_position_bias_table = nn.Parameter(
|
||||||
|
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads, device=device, dtype=dtype))
|
||||||
|
|
||||||
|
coords_h = torch.arange(self.window_size[0])
|
||||||
|
coords_w = torch.arange(self.window_size[1])
|
||||||
|
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
|
||||||
|
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||||
|
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
||||||
|
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||||
|
relative_coords[:, :, 0] += self.window_size[0] - 1
|
||||||
|
relative_coords[:, :, 1] += self.window_size[1] - 1
|
||||||
|
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
|
||||||
|
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
||||||
|
self.register_buffer("relative_position_index", relative_position_index)
|
||||||
|
|
||||||
|
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
|
||||||
|
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||||
|
self.softmax = nn.Softmax(dim=-1)
|
||||||
|
|
||||||
|
def forward(self, x, mask=None):
|
||||||
|
B_, N, C = x.shape
|
||||||
|
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||||
|
q, k, v = qkv[0], qkv[1], qkv[2]
|
||||||
|
|
||||||
|
q = q * self.scale
|
||||||
|
attn = (q @ k.transpose(-2, -1))
|
||||||
|
|
||||||
|
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
|
||||||
|
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
||||||
|
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||||
|
attn = attn + relative_position_bias.unsqueeze(0)
|
||||||
|
|
||||||
|
if mask is not None:
|
||||||
|
nW = mask.shape[0]
|
||||||
|
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
|
||||||
|
attn = attn.view(-1, self.num_heads, N, N)
|
||||||
|
attn = self.softmax(attn)
|
||||||
|
else:
|
||||||
|
attn = self.softmax(attn)
|
||||||
|
|
||||||
|
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||||
|
x = self.proj(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SwinTransformerBlock(nn.Module):
|
||||||
|
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
|
||||||
|
mlp_ratio=4., qkv_bias=True, qk_scale=None,
|
||||||
|
norm_layer=nn.LayerNorm, device=None, dtype=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.window_size = window_size
|
||||||
|
self.shift_size = shift_size
|
||||||
|
self.mlp_ratio = mlp_ratio
|
||||||
|
|
||||||
|
self.norm1 = norm_layer(dim, device=device, dtype=dtype)
|
||||||
|
self.attn = WindowAttention(
|
||||||
|
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
|
||||||
|
qkv_bias=qkv_bias, qk_scale=qk_scale, device=device, dtype=dtype, operations=operations)
|
||||||
|
|
||||||
|
self.norm2 = norm_layer(dim, device=device, dtype=dtype)
|
||||||
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||||
|
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, device=device, dtype=dtype, operations=operations)
|
||||||
|
|
||||||
|
self.H = None
|
||||||
|
self.W = None
|
||||||
|
|
||||||
|
def forward(self, x, mask_matrix):
|
||||||
|
B, L, C = x.shape
|
||||||
|
H, W = self.H, self.W
|
||||||
|
|
||||||
|
shortcut = x
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = x.view(B, H, W, C)
|
||||||
|
|
||||||
|
pad_l = pad_t = 0
|
||||||
|
pad_r = (self.window_size - W % self.window_size) % self.window_size
|
||||||
|
pad_b = (self.window_size - H % self.window_size) % self.window_size
|
||||||
|
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
||||||
|
_, Hp, Wp, _ = x.shape
|
||||||
|
|
||||||
|
if self.shift_size > 0:
|
||||||
|
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
||||||
|
attn_mask = mask_matrix
|
||||||
|
else:
|
||||||
|
shifted_x = x
|
||||||
|
attn_mask = None
|
||||||
|
|
||||||
|
x_windows = window_partition(shifted_x, self.window_size)
|
||||||
|
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
|
||||||
|
|
||||||
|
attn_windows = self.attn(x_windows, mask=attn_mask)
|
||||||
|
|
||||||
|
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
||||||
|
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
|
||||||
|
|
||||||
|
if self.shift_size > 0:
|
||||||
|
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
||||||
|
else:
|
||||||
|
x = shifted_x
|
||||||
|
|
||||||
|
if pad_r > 0 or pad_b > 0:
|
||||||
|
x = x[:, :H, :W, :].contiguous()
|
||||||
|
|
||||||
|
x = x.view(B, H * W, C)
|
||||||
|
|
||||||
|
x = shortcut + x
|
||||||
|
x = x + self.mlp(self.norm2(x))
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PatchMerging(nn.Module):
|
||||||
|
def __init__(self, dim, device=None, dtype=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.reduction = operations.Linear(4 * dim, 2 * dim, bias=False, device=device, dtype=dtype)
|
||||||
|
self.norm = operations.LayerNorm(4 * dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x, H, W):
|
||||||
|
B, L, C = x.shape
|
||||||
|
x = x.view(B, H, W, C)
|
||||||
|
|
||||||
|
# padding
|
||||||
|
pad_input = (H % 2 == 1) or (W % 2 == 1)
|
||||||
|
if pad_input:
|
||||||
|
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
|
||||||
|
|
||||||
|
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
|
||||||
|
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
|
||||||
|
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
|
||||||
|
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
|
||||||
|
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
|
||||||
|
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
|
||||||
|
|
||||||
|
x = self.norm(x)
|
||||||
|
x = self.reduction(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BasicLayer(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
dim,
|
||||||
|
depth,
|
||||||
|
num_heads,
|
||||||
|
window_size=7,
|
||||||
|
mlp_ratio=4.,
|
||||||
|
qkv_bias=True,
|
||||||
|
qk_scale=None,
|
||||||
|
norm_layer=nn.LayerNorm,
|
||||||
|
downsample=None,
|
||||||
|
device=None, dtype=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
self.window_size = window_size
|
||||||
|
self.shift_size = window_size // 2
|
||||||
|
self.depth = depth
|
||||||
|
|
||||||
|
# build blocks
|
||||||
|
self.blocks = nn.ModuleList([
|
||||||
|
SwinTransformerBlock(
|
||||||
|
dim=dim,
|
||||||
|
num_heads=num_heads,
|
||||||
|
window_size=window_size,
|
||||||
|
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_scale=qk_scale,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
device=device, dtype=dtype, operations=operations)
|
||||||
|
for i in range(depth)])
|
||||||
|
|
||||||
|
# patch merging layer
|
||||||
|
if downsample is not None:
|
||||||
|
self.downsample = downsample(dim=dim, device=device, dtype=dtype, operations=operations)
|
||||||
|
else:
|
||||||
|
self.downsample = None
|
||||||
|
|
||||||
|
def forward(self, x, H, W):
|
||||||
|
Hp = int(np.ceil(H / self.window_size)) * self.window_size
|
||||||
|
Wp = int(np.ceil(W / self.window_size)) * self.window_size
|
||||||
|
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
|
||||||
|
h_slices = (slice(0, -self.window_size),
|
||||||
|
slice(-self.window_size, -self.shift_size),
|
||||||
|
slice(-self.shift_size, None))
|
||||||
|
w_slices = (slice(0, -self.window_size),
|
||||||
|
slice(-self.window_size, -self.shift_size),
|
||||||
|
slice(-self.shift_size, None))
|
||||||
|
cnt = 0
|
||||||
|
for h in h_slices:
|
||||||
|
for w in w_slices:
|
||||||
|
img_mask[:, h, w, :] = cnt
|
||||||
|
cnt += 1
|
||||||
|
|
||||||
|
mask_windows = window_partition(img_mask, self.window_size)
|
||||||
|
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
||||||
|
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||||
|
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
||||||
|
|
||||||
|
for blk in self.blocks:
|
||||||
|
blk.H, blk.W = H, W
|
||||||
|
x = blk(x, attn_mask)
|
||||||
|
if self.downsample is not None:
|
||||||
|
x_down = self.downsample(x, H, W)
|
||||||
|
Wh, Ww = (H + 1) // 2, (W + 1) // 2
|
||||||
|
return x, H, W, x_down, Wh, Ww
|
||||||
|
else:
|
||||||
|
return x, H, W, x, H, W
|
||||||
|
|
||||||
|
|
||||||
|
class PatchEmbed(nn.Module):
|
||||||
|
def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None, device=None, dtype=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
patch_size = (patch_size, patch_size)
|
||||||
|
self.patch_size = patch_size
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
|
self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype)
|
||||||
|
if norm_layer is not None:
|
||||||
|
self.norm = norm_layer(embed_dim, device=device, dtype=dtype)
|
||||||
|
else:
|
||||||
|
self.norm = None
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
_, _, H, W = x.size()
|
||||||
|
if W % self.patch_size[1] != 0:
|
||||||
|
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
|
||||||
|
if H % self.patch_size[0] != 0:
|
||||||
|
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
|
||||||
|
|
||||||
|
x = self.proj(x) # B C Wh Ww
|
||||||
|
if self.norm is not None:
|
||||||
|
Wh, Ww = x.size(2), x.size(3)
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
x = self.norm(x)
|
||||||
|
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SwinTransformer(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
pretrain_img_size=224,
|
||||||
|
patch_size=4,
|
||||||
|
in_channels=3,
|
||||||
|
embed_dim=96,
|
||||||
|
depths=[2, 2, 6, 2],
|
||||||
|
num_heads=[3, 6, 12, 24],
|
||||||
|
window_size=7,
|
||||||
|
mlp_ratio=4.,
|
||||||
|
qkv_bias=True,
|
||||||
|
qk_scale=None,
|
||||||
|
patch_norm=True,
|
||||||
|
out_indices=(0, 1, 2, 3),
|
||||||
|
frozen_stages=-1,
|
||||||
|
device=None, dtype=None, operations=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
norm_layer = partial(operations.LayerNorm, device=device, dtype=dtype)
|
||||||
|
self.pretrain_img_size = pretrain_img_size
|
||||||
|
self.num_layers = len(depths)
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.patch_norm = patch_norm
|
||||||
|
self.out_indices = out_indices
|
||||||
|
self.frozen_stages = frozen_stages
|
||||||
|
|
||||||
|
self.patch_embed = PatchEmbed(
|
||||||
|
patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
|
||||||
|
device=device, dtype=dtype, operations=operations,
|
||||||
|
norm_layer=norm_layer if self.patch_norm else None)
|
||||||
|
|
||||||
|
self.layers = nn.ModuleList()
|
||||||
|
for i_layer in range(self.num_layers):
|
||||||
|
layer = BasicLayer(
|
||||||
|
dim=int(embed_dim * 2 ** i_layer),
|
||||||
|
depth=depths[i_layer],
|
||||||
|
num_heads=num_heads[i_layer],
|
||||||
|
window_size=window_size,
|
||||||
|
mlp_ratio=mlp_ratio,
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_scale=qk_scale,
|
||||||
|
norm_layer=norm_layer,
|
||||||
|
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
|
||||||
|
device=device, dtype=dtype, operations=operations)
|
||||||
|
self.layers.append(layer)
|
||||||
|
|
||||||
|
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
|
||||||
|
self.num_features = num_features
|
||||||
|
|
||||||
|
for i_layer in out_indices:
|
||||||
|
layer = norm_layer(num_features[i_layer])
|
||||||
|
layer_name = f'norm{i_layer}'
|
||||||
|
self.add_module(layer_name, layer)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.patch_embed(x)
|
||||||
|
|
||||||
|
Wh, Ww = x.size(2), x.size(3)
|
||||||
|
|
||||||
|
outs = []
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
layer = self.layers[i]
|
||||||
|
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
|
||||||
|
|
||||||
|
if i in self.out_indices:
|
||||||
|
norm_layer = getattr(self, f'norm{i}')
|
||||||
|
x_out = norm_layer(x_out)
|
||||||
|
|
||||||
|
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
|
||||||
|
outs.append(out)
|
||||||
|
|
||||||
|
return tuple(outs)
|
||||||
|
|
||||||
|
class DeformableConv2d(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
padding=1,
|
||||||
|
bias=False, device=None, dtype=None, operations=None):
|
||||||
|
|
||||||
|
super(DeformableConv2d, self).__init__()
|
||||||
|
|
||||||
|
kernel_size = kernel_size if type(kernel_size) is tuple else (kernel_size, kernel_size)
|
||||||
|
self.stride = stride if type(stride) is tuple else (stride, stride)
|
||||||
|
self.padding = padding
|
||||||
|
|
||||||
|
self.offset_conv = operations.Conv2d(in_channels,
|
||||||
|
2 * kernel_size[0] * kernel_size[1],
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=self.padding,
|
||||||
|
bias=True, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.modulator_conv = operations.Conv2d(in_channels,
|
||||||
|
1 * kernel_size[0] * kernel_size[1],
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=self.padding,
|
||||||
|
bias=True, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.regular_conv = operations.Conv2d(in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=self.padding,
|
||||||
|
bias=bias, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
offset = self.offset_conv(x)
|
||||||
|
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
|
||||||
|
weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
|
||||||
|
|
||||||
|
x = deform_conv2d(
|
||||||
|
input=x,
|
||||||
|
offset=offset,
|
||||||
|
weight=weight,
|
||||||
|
bias=None,
|
||||||
|
padding=self.padding,
|
||||||
|
mask=modulator,
|
||||||
|
stride=self.stride,
|
||||||
|
)
|
||||||
|
comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class BasicDecBlk(nn.Module):
|
||||||
|
def __init__(self, in_channels=64, out_channels=64, inter_channels=64, device=None, dtype=None, operations=None):
|
||||||
|
super(BasicDecBlk, self).__init__()
|
||||||
|
inter_channels = 64
|
||||||
|
self.conv_in = operations.Conv2d(in_channels, inter_channels, 3, 1, padding=1, device=device, dtype=dtype)
|
||||||
|
self.relu_in = nn.ReLU(inplace=True)
|
||||||
|
self.dec_att = ASPPDeformable(in_channels=inter_channels, device=device, dtype=dtype, operations=operations)
|
||||||
|
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, padding=1, device=device, dtype=dtype)
|
||||||
|
self.bn_in = operations.BatchNorm2d(inter_channels, device=device, dtype=dtype)
|
||||||
|
self.bn_out = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv_in(x)
|
||||||
|
x = self.bn_in(x)
|
||||||
|
x = self.relu_in(x)
|
||||||
|
x = self.dec_att(x)
|
||||||
|
x = self.conv_out(x)
|
||||||
|
x = self.bn_out(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BasicLatBlk(nn.Module):
|
||||||
|
def __init__(self, in_channels=64, out_channels=64, device=None, dtype=None, operations=None):
|
||||||
|
super(BasicLatBlk, self).__init__()
|
||||||
|
self.conv = operations.Conv2d(in_channels, out_channels, 1, 1, 0, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class _ASPPModuleDeformable(nn.Module):
|
||||||
|
def __init__(self, in_channels, planes, kernel_size, padding, device, dtype, operations):
|
||||||
|
super(_ASPPModuleDeformable, self).__init__()
|
||||||
|
self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
|
||||||
|
stride=1, padding=padding, bias=False, device=device, dtype=dtype, operations=operations)
|
||||||
|
self.bn = operations.BatchNorm2d(planes, device=device, dtype=dtype)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.atrous_conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
|
||||||
|
return self.relu(x)
|
||||||
|
|
||||||
|
|
||||||
|
class ASPPDeformable(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7], device=None, dtype=None, operations=None):
|
||||||
|
super(ASPPDeformable, self).__init__()
|
||||||
|
self.down_scale = 1
|
||||||
|
if out_channels is None:
|
||||||
|
out_channels = in_channels
|
||||||
|
self.in_channelster = 256 // self.down_scale
|
||||||
|
|
||||||
|
self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0, device=device, dtype=dtype, operations=operations)
|
||||||
|
self.aspp_deforms = nn.ModuleList([
|
||||||
|
_ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2), device=device, dtype=dtype, operations=operations)
|
||||||
|
for conv_size in parallel_block_sizes
|
||||||
|
])
|
||||||
|
|
||||||
|
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
|
||||||
|
operations.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False, device=device, dtype=dtype),
|
||||||
|
operations.BatchNorm2d(self.in_channelster, device=device, dtype=dtype),
|
||||||
|
nn.ReLU(inplace=True))
|
||||||
|
self.conv1 = operations.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False, device=device, dtype=dtype)
|
||||||
|
self.bn1 = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
|
||||||
|
self.relu = nn.ReLU(inplace=True)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x1 = self.aspp1(x)
|
||||||
|
x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
|
||||||
|
x5 = self.global_avg_pool(x)
|
||||||
|
x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
|
||||||
|
x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
|
||||||
|
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.bn1(x)
|
||||||
|
x = self.relu(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
class BiRefNet(nn.Module):
|
||||||
|
def __init__(self, config=None, dtype=None, device=None, operations=None):
|
||||||
|
super(BiRefNet, self).__init__()
|
||||||
|
self.bb = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12, device=device, dtype=dtype, operations=operations)
|
||||||
|
|
||||||
|
channels = [1536, 768, 384, 192]
|
||||||
|
channels = [c * 2 for c in channels]
|
||||||
|
self.cxt = channels[1:][::-1][-3:]
|
||||||
|
self.squeeze_module = nn.Sequential(*[
|
||||||
|
BasicDecBlk(channels[0]+sum(self.cxt), channels[0], device=device, dtype=dtype, operations=operations)
|
||||||
|
for _ in range(1)
|
||||||
|
])
|
||||||
|
|
||||||
|
self.decoder = Decoder(channels, device=device, dtype=dtype, operations=operations)
|
||||||
|
|
||||||
|
def forward_enc(self, x):
|
||||||
|
x1, x2, x3, x4 = self.bb(x)
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
|
||||||
|
x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||||
|
x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||||
|
x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||||
|
x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
|
||||||
|
x4 = torch.cat(
|
||||||
|
(
|
||||||
|
*[
|
||||||
|
F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
||||||
|
F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
||||||
|
F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
|
||||||
|
][-len(CXT):],
|
||||||
|
x4
|
||||||
|
),
|
||||||
|
dim=1
|
||||||
|
)
|
||||||
|
return (x1, x2, x3, x4)
|
||||||
|
|
||||||
|
def forward_ori(self, x):
|
||||||
|
(x1, x2, x3, x4) = self.forward_enc(x)
|
||||||
|
x4 = self.squeeze_module(x4)
|
||||||
|
features = [x, x1, x2, x3, x4]
|
||||||
|
scaled_preds = self.decoder(features)
|
||||||
|
return scaled_preds
|
||||||
|
|
||||||
|
def forward(self, pixel_values, intermediate_output=None):
|
||||||
|
scaled_preds = self.forward_ori(pixel_values)
|
||||||
|
return scaled_preds
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
def __init__(self, channels, device, dtype, operations):
|
||||||
|
super(Decoder, self).__init__()
|
||||||
|
# factory kwargs
|
||||||
|
fk = {"device":device, "dtype":dtype, "operations":operations}
|
||||||
|
DecoderBlock = partial(BasicDecBlk, **fk)
|
||||||
|
LateralBlock = partial(BasicLatBlk, **fk)
|
||||||
|
DBlock = partial(SimpleConvs, **fk)
|
||||||
|
|
||||||
|
self.split = True
|
||||||
|
N_dec_ipt = 64
|
||||||
|
ic = 64
|
||||||
|
ipt_cha_opt = 1
|
||||||
|
self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
|
||||||
|
self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
|
||||||
|
self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
|
||||||
|
self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
|
||||||
|
self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
|
||||||
|
|
||||||
|
self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[1])
|
||||||
|
self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[2])
|
||||||
|
self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt]), channels[3])
|
||||||
|
self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt]), channels[3]//2)
|
||||||
|
|
||||||
|
fk = {"device":device, "dtype":dtype}
|
||||||
|
|
||||||
|
self.conv_out1 = nn.Sequential(operations.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt]), 1, 1, 1, 0, **fk))
|
||||||
|
|
||||||
|
self.lateral_block4 = LateralBlock(channels[1], channels[1])
|
||||||
|
self.lateral_block3 = LateralBlock(channels[2], channels[2])
|
||||||
|
self.lateral_block2 = LateralBlock(channels[3], channels[3])
|
||||||
|
|
||||||
|
self.conv_ms_spvn_4 = operations.Conv2d(channels[1], 1, 1, 1, 0, **fk)
|
||||||
|
self.conv_ms_spvn_3 = operations.Conv2d(channels[2], 1, 1, 1, 0, **fk)
|
||||||
|
self.conv_ms_spvn_2 = operations.Conv2d(channels[3], 1, 1, 1, 0, **fk)
|
||||||
|
|
||||||
|
_N = 16
|
||||||
|
|
||||||
|
self.gdt_convs_4 = nn.Sequential(operations.Conv2d(channels[0] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
|
||||||
|
self.gdt_convs_3 = nn.Sequential(operations.Conv2d(channels[1] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
|
||||||
|
self.gdt_convs_2 = nn.Sequential(operations.Conv2d(channels[2] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
|
||||||
|
|
||||||
|
[setattr(self, f"gdt_convs_pred_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
|
||||||
|
[setattr(self, f"gdt_convs_attn_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
|
||||||
|
|
||||||
|
def get_patches_batch(self, x, p):
|
||||||
|
_size_h, _size_w = p.shape[2:]
|
||||||
|
patches_batch = []
|
||||||
|
for idx in range(x.shape[0]):
|
||||||
|
columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
|
||||||
|
patches_x = []
|
||||||
|
for column_x in columns_x:
|
||||||
|
patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
|
||||||
|
patch_sample = torch.cat(patches_x, dim=1)
|
||||||
|
patches_batch.append(patch_sample)
|
||||||
|
return torch.cat(patches_batch, dim=0)
|
||||||
|
|
||||||
|
def forward(self, features):
|
||||||
|
x, x1, x2, x3, x4 = features
|
||||||
|
|
||||||
|
patches_batch = self.get_patches_batch(x, x4) if self.split else x
|
||||||
|
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||||
|
p4 = self.decoder_block4(x4)
|
||||||
|
p4_gdt = self.gdt_convs_4(p4)
|
||||||
|
gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
|
||||||
|
p4 = p4 * gdt_attn_4
|
||||||
|
_p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
|
||||||
|
_p3 = _p4 + self.lateral_block4(x3)
|
||||||
|
|
||||||
|
patches_batch = self.get_patches_batch(x, _p3) if self.split else x
|
||||||
|
_p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||||
|
p3 = self.decoder_block3(_p3)
|
||||||
|
|
||||||
|
p3_gdt = self.gdt_convs_3(p3)
|
||||||
|
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
|
||||||
|
p3 = p3 * gdt_attn_3
|
||||||
|
_p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
|
||||||
|
_p2 = _p3 + self.lateral_block3(x2)
|
||||||
|
|
||||||
|
patches_batch = self.get_patches_batch(x, _p2) if self.split else x
|
||||||
|
_p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||||
|
p2 = self.decoder_block2(_p2)
|
||||||
|
|
||||||
|
p2_gdt = self.gdt_convs_2(p2)
|
||||||
|
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
|
||||||
|
p2 = p2 * gdt_attn_2
|
||||||
|
|
||||||
|
_p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
|
||||||
|
_p1 = _p2 + self.lateral_block2(x1)
|
||||||
|
|
||||||
|
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
|
||||||
|
_p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||||
|
_p1 = self.decoder_block1(_p1)
|
||||||
|
_p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
|
||||||
|
|
||||||
|
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
|
||||||
|
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
|
||||||
|
p1_out = self.conv_out1(_p1)
|
||||||
|
return p1_out
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleConvs(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, in_channels: int, out_channels: int, inter_channels=64, device=None, dtype=None, operations=None
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.conv1 = operations.Conv2d(in_channels, inter_channels, 3, 1, 1, device=device, dtype=dtype)
|
||||||
|
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, 1, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.conv_out(self.conv1(x))
|
||||||
78
comfy/bg_removal_model.py
Normal file
78
comfy/bg_removal_model.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
from .utils import load_torch_file
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import torch
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import comfy.ops
|
||||||
|
import comfy.model_patcher
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.clip_model
|
||||||
|
import comfy.background_removal.birefnet
|
||||||
|
|
||||||
|
BG_REMOVAL_MODELS = {
|
||||||
|
"birefnet": comfy.background_removal.birefnet.BiRefNet
|
||||||
|
}
|
||||||
|
|
||||||
|
class BackgroundRemovalModel():
|
||||||
|
def __init__(self, json_config):
|
||||||
|
with open(json_config) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
|
||||||
|
self.image_size = config.get("image_size", 1024)
|
||||||
|
self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0])
|
||||||
|
self.image_std = config.get("image_std", [1.0, 1.0, 1.0])
|
||||||
|
self.model_type = config.get("model_type", "birefnet")
|
||||||
|
self.config = config.copy()
|
||||||
|
model_class = BG_REMOVAL_MODELS.get(self.model_type)
|
||||||
|
|
||||||
|
self.load_device = comfy.model_management.text_encoder_device()
|
||||||
|
offload_device = comfy.model_management.text_encoder_offload_device()
|
||||||
|
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
|
||||||
|
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
|
||||||
|
self.model.eval()
|
||||||
|
|
||||||
|
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||||
|
|
||||||
|
def load_sd(self, sd):
|
||||||
|
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||||
|
|
||||||
|
def get_sd(self):
|
||||||
|
return self.model.state_dict()
|
||||||
|
|
||||||
|
def encode_image(self, image):
|
||||||
|
comfy.model_management.load_model_gpu(self.patcher)
|
||||||
|
H, W = image.shape[1], image.shape[2]
|
||||||
|
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
|
||||||
|
out = self.model(pixel_values=pixel_values)
|
||||||
|
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
|
||||||
|
|
||||||
|
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||||
|
if mask.ndim == 3:
|
||||||
|
mask = mask.unsqueeze(0)
|
||||||
|
if mask.shape[1] != 1:
|
||||||
|
mask = mask.movedim(-1, 1)
|
||||||
|
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def load_background_removal_model(sd):
|
||||||
|
if "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
|
||||||
|
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal"), "birefnet.json")
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
bg_model = BackgroundRemovalModel(json_config)
|
||||||
|
m, u = bg_model.load_sd(sd)
|
||||||
|
if len(m) > 0:
|
||||||
|
logging.warning("missing background removal: {}".format(m))
|
||||||
|
u = set(u)
|
||||||
|
keys = list(sd.keys())
|
||||||
|
for k in keys:
|
||||||
|
if k not in u:
|
||||||
|
sd.pop(k)
|
||||||
|
return bg_model
|
||||||
|
|
||||||
|
def load(ckpt_path):
|
||||||
|
sd = load_torch_file(ckpt_path)
|
||||||
|
return load_background_removal_model(sd)
|
||||||
@ -93,7 +93,7 @@ class Hook:
|
|||||||
self.hook_scope = hook_scope
|
self.hook_scope = hook_scope
|
||||||
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
|
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
|
||||||
self.custom_should_register = default_should_register
|
self.custom_should_register = default_should_register
|
||||||
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
'''Can be overridden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def strength(self):
|
def strength(self):
|
||||||
|
|||||||
@ -140,7 +140,7 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
|||||||
alphas = alphacums[ddim_timesteps]
|
alphas = alphacums[ddim_timesteps]
|
||||||
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
||||||
|
|
||||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
# according to the formula provided in https://arxiv.org/abs/2010.02502
|
||||||
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
||||||
if verbose:
|
if verbose:
|
||||||
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||||
|
|||||||
22
comfy/ops.py
22
comfy/ops.py
@ -562,6 +562,25 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
|
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
|
||||||
|
def reset_parameters(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def forward_comfy_cast_weights(self, input):
|
||||||
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||||
|
running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None
|
||||||
|
running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None
|
||||||
|
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps)
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
run_every_op()
|
||||||
|
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||||
|
return self.forward_comfy_cast_weights(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
return None
|
return None
|
||||||
@ -749,6 +768,9 @@ class manual_cast(disable_weight_init):
|
|||||||
class Conv3d(disable_weight_init.Conv3d):
|
class Conv3d(disable_weight_init.Conv3d):
|
||||||
comfy_cast_weights = True
|
comfy_cast_weights = True
|
||||||
|
|
||||||
|
class BatchNorm2d(disable_weight_init.BatchNorm2d):
|
||||||
|
comfy_cast_weights = True
|
||||||
|
|
||||||
class GroupNorm(disable_weight_init.GroupNorm):
|
class GroupNorm(disable_weight_init.GroupNorm):
|
||||||
comfy_cast_weights = True
|
comfy_cast_weights = True
|
||||||
|
|
||||||
|
|||||||
@ -1390,7 +1390,7 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
|||||||
k_out = "{}.weight_scale".format(layer)
|
k_out = "{}.weight_scale".format(layer)
|
||||||
|
|
||||||
if layer is not None:
|
if layer is not None:
|
||||||
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
|
layer_conf = {"format": "float8_e4m3fn"}
|
||||||
if full_precision_matrix_mult:
|
if full_precision_matrix_mult:
|
||||||
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
|
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
|
||||||
layers[layer] = layer_conf
|
layers[layer] = layer_conf
|
||||||
|
|||||||
@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
|||||||
from spandrel import ImageModelDescriptor
|
from spandrel import ImageModelDescriptor
|
||||||
from comfy.clip_vision import ClipVisionModel
|
from comfy.clip_vision import ClipVisionModel
|
||||||
from comfy.clip_vision import Output as ClipVisionOutput_
|
from comfy.clip_vision import Output as ClipVisionOutput_
|
||||||
|
from comfy.bg_removal_model import BackgroundRemovalModel
|
||||||
from comfy.controlnet import ControlNet
|
from comfy.controlnet import ControlNet
|
||||||
from comfy.hooks import HookGroup, HookKeyframeGroup
|
from comfy.hooks import HookGroup, HookKeyframeGroup
|
||||||
from comfy.model_patcher import ModelPatcher
|
from comfy.model_patcher import ModelPatcher
|
||||||
@ -614,6 +615,11 @@ class Model(ComfyTypeIO):
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
Type = ModelPatcher
|
Type = ModelPatcher
|
||||||
|
|
||||||
|
@comfytype(io_type="BACKGROUND_REMOVAL")
|
||||||
|
class BackgroundRemoval(ComfyTypeIO):
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
Type = BackgroundRemovalModel
|
||||||
|
|
||||||
@comfytype(io_type="CLIP_VISION")
|
@comfytype(io_type="CLIP_VISION")
|
||||||
class ClipVision(ComfyTypeIO):
|
class ClipVision(ComfyTypeIO):
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -2257,6 +2263,7 @@ __all__ = [
|
|||||||
"ModelPatch",
|
"ModelPatch",
|
||||||
"ClipVision",
|
"ClipVision",
|
||||||
"ClipVisionOutput",
|
"ClipVisionOutput",
|
||||||
|
"BackgroundRemoval",
|
||||||
"AudioEncoder",
|
"AudioEncoder",
|
||||||
"AudioEncoderOutput",
|
"AudioEncoderOutput",
|
||||||
"StyleModel",
|
"StyleModel",
|
||||||
|
|||||||
@ -1271,7 +1271,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _seedance2_text_inputs(resolutions: list[str]):
|
def _seedance2_text_inputs(resolutions: list[str], default_ratio: str = "16:9"):
|
||||||
return [
|
return [
|
||||||
IO.String.Input(
|
IO.String.Input(
|
||||||
"prompt",
|
"prompt",
|
||||||
@ -1287,6 +1287,7 @@ def _seedance2_text_inputs(resolutions: list[str]):
|
|||||||
IO.Combo.Input(
|
IO.Combo.Input(
|
||||||
"ratio",
|
"ratio",
|
||||||
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
|
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
|
||||||
|
default=default_ratio,
|
||||||
tooltip="Aspect ratio of the output video.",
|
tooltip="Aspect ratio of the output video.",
|
||||||
),
|
),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
@ -1420,8 +1421,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
|||||||
IO.DynamicCombo.Input(
|
IO.DynamicCombo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[
|
options=[
|
||||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])),
|
IO.DynamicCombo.Option(
|
||||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])),
|
"Seedance 2.0",
|
||||||
|
_seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"Seedance 2.0 Fast",
|
||||||
|
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||||
),
|
),
|
||||||
@ -1588,9 +1595,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||||
|
|
||||||
|
|
||||||
def _seedance2_reference_inputs(resolutions: list[str]):
|
def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16:9"):
|
||||||
return [
|
return [
|
||||||
*_seedance2_text_inputs(resolutions),
|
*_seedance2_text_inputs(resolutions, default_ratio=default_ratio),
|
||||||
IO.Autogrow.Input(
|
IO.Autogrow.Input(
|
||||||
"reference_images",
|
"reference_images",
|
||||||
template=IO.Autogrow.TemplateNames(
|
template=IO.Autogrow.TemplateNames(
|
||||||
@ -1668,8 +1675,14 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
|||||||
IO.DynamicCombo.Input(
|
IO.DynamicCombo.Input(
|
||||||
"model",
|
"model",
|
||||||
options=[
|
options=[
|
||||||
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])),
|
IO.DynamicCombo.Option(
|
||||||
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])),
|
"Seedance 2.0",
|
||||||
|
_seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"Seedance 2.0 Fast",
|
||||||
|
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
|
||||||
|
),
|
||||||
],
|
],
|
||||||
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
|
||||||
),
|
),
|
||||||
|
|||||||
@ -488,10 +488,30 @@ async def _diagnose_connectivity() -> dict[str, bool]:
|
|||||||
"api_accessible": False,
|
"api_accessible": False,
|
||||||
}
|
}
|
||||||
timeout = aiohttp.ClientTimeout(total=5.0)
|
timeout = aiohttp.ClientTimeout(total=5.0)
|
||||||
|
|
||||||
|
# Probe Google and Baidu in parallel: Google is blocked by the GFW in mainland China, so a Baidu probe is required
|
||||||
|
# to correctly detect that Chinese users with working internet do have working internet.
|
||||||
|
internet_probe_urls = ("https://www.google.com", "https://www.baidu.com")
|
||||||
|
|
||||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||||
with contextlib.suppress(ClientError, OSError):
|
async def _probe(url: str) -> bool:
|
||||||
async with session.get("https://www.google.com") as resp:
|
try:
|
||||||
results["internet_accessible"] = resp.status < 500
|
async with session.get(url) as resp:
|
||||||
|
return resp.status < 500
|
||||||
|
except (ClientError, OSError, asyncio.TimeoutError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
probe_tasks = [asyncio.create_task(_probe(u)) for u in internet_probe_urls]
|
||||||
|
try:
|
||||||
|
for fut in asyncio.as_completed(probe_tasks):
|
||||||
|
if await fut:
|
||||||
|
results["internet_accessible"] = True
|
||||||
|
break
|
||||||
|
finally:
|
||||||
|
for t in probe_tasks:
|
||||||
|
if not t.done():
|
||||||
|
t.cancel()
|
||||||
|
await asyncio.gather(*probe_tasks, return_exceptions=True)
|
||||||
if not results["internet_accessible"]:
|
if not results["internet_accessible"]:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|||||||
@ -92,7 +92,7 @@ class SamplerEulerCFGpp(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SamplerEulerCFGpp",
|
node_id="SamplerEulerCFGpp",
|
||||||
display_name="SamplerEulerCFG++",
|
display_name="SamplerEulerCFG++",
|
||||||
category="_for_testing", # "sampling/custom_sampling/samplers"
|
category="experimental", # "sampling/custom_sampling/samplers"
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Combo.Input("version", options=["regular", "alternative"], advanced=True),
|
io.Combo.Input("version", options=["regular", "alternative"], advanced=True),
|
||||||
],
|
],
|
||||||
|
|||||||
@ -25,7 +25,7 @@ class UNetSelfAttentionMultiply(io.ComfyNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="UNetSelfAttentionMultiply",
|
node_id="UNetSelfAttentionMultiply",
|
||||||
category="_for_testing/attention_experiments",
|
category="experimental/attention_experiments",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||||
@ -48,7 +48,7 @@ class UNetCrossAttentionMultiply(io.ComfyNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="UNetCrossAttentionMultiply",
|
node_id="UNetCrossAttentionMultiply",
|
||||||
category="_for_testing/attention_experiments",
|
category="experimental/attention_experiments",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||||
@ -72,7 +72,7 @@ class CLIPAttentionMultiply(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="CLIPAttentionMultiply",
|
node_id="CLIPAttentionMultiply",
|
||||||
search_aliases=["clip attention scale", "text encoder attention"],
|
search_aliases=["clip attention scale", "text encoder attention"],
|
||||||
category="_for_testing/attention_experiments",
|
category="experimental/attention_experiments",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Clip.Input("clip"),
|
io.Clip.Input("clip"),
|
||||||
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||||
@ -106,7 +106,7 @@ class UNetTemporalAttentionMultiply(io.ComfyNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="UNetTemporalAttentionMultiply",
|
node_id="UNetTemporalAttentionMultiply",
|
||||||
category="_for_testing/attention_experiments",
|
category="experimental/attention_experiments",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||||
|
|||||||
@ -10,6 +10,7 @@ class AudioEncoderLoader(io.ComfyNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="AudioEncoderLoader",
|
node_id="AudioEncoderLoader",
|
||||||
|
display_name="Load Audio Encoder",
|
||||||
category="loaders",
|
category="loaders",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
|
|||||||
60
comfy_extras/nodes_bg_removal.py
Normal file
60
comfy_extras/nodes_bg_removal.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
import folder_paths
|
||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, IO
|
||||||
|
from comfy.bg_removal_model import load
|
||||||
|
|
||||||
|
|
||||||
|
class LoadBackgroundRemovalModel(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
files = folder_paths.get_filename_list("background_removal")
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="LoadBackgroundRemovalModel",
|
||||||
|
display_name="Load Background Removal Model",
|
||||||
|
category="loaders",
|
||||||
|
inputs=[
|
||||||
|
IO.Combo.Input("bg_removal_name", options=sorted(files), tooltip="The model used to remove backgrounds from images"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.BackgroundRemoval.Output("bg_model")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, bg_removal_name):
|
||||||
|
path = folder_paths.get_full_path_or_raise("background_removal", bg_removal_name)
|
||||||
|
bg = load(path)
|
||||||
|
if bg is None:
|
||||||
|
raise RuntimeError("ERROR: background model file is invalid and does not contain a valid background removal model.")
|
||||||
|
return IO.NodeOutput(bg)
|
||||||
|
|
||||||
|
class RemoveBackground(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="RemoveBackground",
|
||||||
|
display_name="Remove Background",
|
||||||
|
category="image/background removal",
|
||||||
|
inputs=[
|
||||||
|
IO.Image.Input("image", tooltip="Input image to remove the background from"),
|
||||||
|
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask")
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Mask.Output("mask", tooltip="Generated foreground mask")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, image, bg_removal_model):
|
||||||
|
mask = bg_removal_model.encode_image(image)
|
||||||
|
return IO.NodeOutput(mask)
|
||||||
|
|
||||||
|
class BackgroundRemovalExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
LoadBackgroundRemovalModel,
|
||||||
|
RemoveBackground
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> BackgroundRemovalExtension:
|
||||||
|
return BackgroundRemovalExtension()
|
||||||
@ -153,7 +153,7 @@ class WanCameraEmbedding(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="WanCameraEmbedding",
|
node_id="WanCameraEmbedding",
|
||||||
category="camera",
|
category="conditioning/video_models",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"camera_pose",
|
"camera_pose",
|
||||||
|
|||||||
@ -203,7 +203,7 @@ class JoinImageWithAlpha(io.ComfyNode):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
|
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
|
||||||
batch_size = max(len(image), len(alpha))
|
batch_size = max(len(image), len(alpha))
|
||||||
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
|
alpha = 1.0 - resize_mask(alpha.to(image), image.shape[1:])
|
||||||
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
|
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
|
||||||
image = comfy.utils.repeat_to_batch_size(image, batch_size)
|
image = comfy.utils.repeat_to_batch_size(image, batch_size)
|
||||||
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
|
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
|
||||||
|
|||||||
@ -8,7 +8,7 @@ class CLIPTextEncodeControlnet(io.ComfyNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="CLIPTextEncodeControlnet",
|
node_id="CLIPTextEncodeControlnet",
|
||||||
category="_for_testing/conditioning",
|
category="experimental/conditioning",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Clip.Input("clip"),
|
io.Clip.Input("clip"),
|
||||||
io.Conditioning.Input("conditioning"),
|
io.Conditioning.Input("conditioning"),
|
||||||
@ -35,7 +35,7 @@ class T5TokenizerOptions(io.ComfyNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="T5TokenizerOptions",
|
node_id="T5TokenizerOptions",
|
||||||
category="_for_testing/conditioning",
|
category="experimental/conditioning",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Clip.Input("clip"),
|
io.Clip.Input("clip"),
|
||||||
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
|
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),
|
||||||
|
|||||||
@ -10,7 +10,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ContextWindowsManual",
|
node_id="ContextWindowsManual",
|
||||||
display_name="Context Windows (Manual)",
|
display_name="Context Windows (Manual)",
|
||||||
category="context",
|
category="model_patches",
|
||||||
description="Manually set context windows.",
|
description="Manually set context windows.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||||
|
|||||||
@ -984,7 +984,7 @@ class AddNoise(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="AddNoise",
|
node_id="AddNoise",
|
||||||
category="_for_testing/custom_sampling/noise",
|
category="experimental/custom_sampling/noise",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
@ -1034,7 +1034,7 @@ class ManualSigmas(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ManualSigmas",
|
node_id="ManualSigmas",
|
||||||
search_aliases=["custom noise schedule", "define sigmas"],
|
search_aliases=["custom noise schedule", "define sigmas"],
|
||||||
category="_for_testing/custom_sampling",
|
category="experimental/custom_sampling",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input("sigmas", default="1, 0.5", multiline=False)
|
io.String.Input("sigmas", default="1, 0.5", multiline=False)
|
||||||
|
|||||||
@ -13,7 +13,7 @@ class DifferentialDiffusion(io.ComfyNode):
|
|||||||
node_id="DifferentialDiffusion",
|
node_id="DifferentialDiffusion",
|
||||||
search_aliases=["inpaint gradient", "variable denoise strength"],
|
search_aliases=["inpaint gradient", "variable denoise strength"],
|
||||||
display_name="Differential Diffusion",
|
display_name="Differential Diffusion",
|
||||||
category="_for_testing",
|
category="experimental",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Float.Input(
|
io.Float.Input(
|
||||||
|
|||||||
@ -102,7 +102,7 @@ class FluxDisableGuidance(io.ComfyNode):
|
|||||||
append = execute # TODO: remove
|
append = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
PREFERED_KONTEXT_RESOLUTIONS = [
|
PREFERRED_KONTEXT_RESOLUTIONS = [
|
||||||
(672, 1568),
|
(672, 1568),
|
||||||
(688, 1504),
|
(688, 1504),
|
||||||
(720, 1456),
|
(720, 1456),
|
||||||
@ -143,7 +143,7 @@ class FluxKontextImageScale(io.ComfyNode):
|
|||||||
width = image.shape[2]
|
width = image.shape[2]
|
||||||
height = image.shape[1]
|
height = image.shape[1]
|
||||||
aspect_ratio = width / height
|
aspect_ratio = width / height
|
||||||
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
|
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
|
||||||
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
|
||||||
return io.NodeOutput(image)
|
return io.NodeOutput(image)
|
||||||
|
|
||||||
|
|||||||
@ -60,7 +60,7 @@ class FreSca(io.ComfyNode):
|
|||||||
node_id="FreSca",
|
node_id="FreSca",
|
||||||
search_aliases=["frequency guidance"],
|
search_aliases=["frequency guidance"],
|
||||||
display_name="FreSca",
|
display_name="FreSca",
|
||||||
category="_for_testing",
|
category="experimental",
|
||||||
description="Applies frequency-dependent scaling to the guidance",
|
description="Applies frequency-dependent scaling to the guidance",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
|
|||||||
@ -131,6 +131,8 @@ class HunyuanVideo15SuperResolution(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="HunyuanVideo15SuperResolution",
|
node_id="HunyuanVideo15SuperResolution",
|
||||||
|
display_name="Hunyuan Video 1.5 Super Resolution",
|
||||||
|
category="conditioning/video_models",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Conditioning.Input("positive"),
|
io.Conditioning.Input("positive"),
|
||||||
io.Conditioning.Input("negative"),
|
io.Conditioning.Input("negative"),
|
||||||
@ -381,6 +383,8 @@ class HunyuanRefinerLatent(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="HunyuanRefinerLatent",
|
node_id="HunyuanRefinerLatent",
|
||||||
|
display_name="Hunyuan Latent Refiner",
|
||||||
|
category="conditioning/video_models",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Conditioning.Input("positive"),
|
io.Conditioning.Input("positive"),
|
||||||
io.Conditioning.Input("negative"),
|
io.Conditioning.Input("negative"),
|
||||||
|
|||||||
@ -40,7 +40,7 @@ class Hunyuan3Dv2Conditioning(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="Hunyuan3Dv2Conditioning",
|
node_id="Hunyuan3Dv2Conditioning",
|
||||||
category="conditioning/video_models",
|
category="conditioning/3d_models",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.ClipVisionOutput.Input("clip_vision_output"),
|
IO.ClipVisionOutput.Input("clip_vision_output"),
|
||||||
],
|
],
|
||||||
@ -65,7 +65,7 @@ class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="Hunyuan3Dv2ConditioningMultiView",
|
node_id="Hunyuan3Dv2ConditioningMultiView",
|
||||||
category="conditioning/video_models",
|
category="conditioning/3d_models",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.ClipVisionOutput.Input("front", optional=True),
|
IO.ClipVisionOutput.Input("front", optional=True),
|
||||||
IO.ClipVisionOutput.Input("left", optional=True),
|
IO.ClipVisionOutput.Input("left", optional=True),
|
||||||
@ -424,6 +424,7 @@ class VoxelToMeshBasic(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="VoxelToMeshBasic",
|
node_id="VoxelToMeshBasic",
|
||||||
|
display_name="Voxel to Mesh (Basic)",
|
||||||
category="3d",
|
category="3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Voxel.Input("voxel"),
|
IO.Voxel.Input("voxel"),
|
||||||
@ -453,6 +454,7 @@ class VoxelToMesh(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="VoxelToMesh",
|
node_id="VoxelToMesh",
|
||||||
|
display_name="Voxel to Mesh",
|
||||||
category="3d",
|
category="3d",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Voxel.Input("voxel"),
|
IO.Voxel.Input("voxel"),
|
||||||
|
|||||||
@ -102,6 +102,7 @@ class HypernetworkLoader(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="HypernetworkLoader",
|
node_id="HypernetworkLoader",
|
||||||
|
display_name="Load Hypernetwork",
|
||||||
category="loaders",
|
category="loaders",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Model.Input("model"),
|
IO.Model.Input("model"),
|
||||||
|
|||||||
@ -91,7 +91,7 @@ class LoraSave(io.ComfyNode):
|
|||||||
node_id="LoraSave",
|
node_id="LoraSave",
|
||||||
search_aliases=["export lora"],
|
search_aliases=["export lora"],
|
||||||
display_name="Extract and Save Lora",
|
display_name="Extract and Save Lora",
|
||||||
category="_for_testing",
|
category="experimental",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
|
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
|
||||||
io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True),
|
io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True),
|
||||||
|
|||||||
@ -106,12 +106,12 @@ class LTXVImgToVideoInplace(io.ComfyNode):
|
|||||||
if bypass:
|
if bypass:
|
||||||
return (latent,)
|
return (latent,)
|
||||||
|
|
||||||
samples = latent["samples"]
|
samples = latent["samples"].clone()
|
||||||
_, height_scale_factor, width_scale_factor = (
|
_, height_scale_factor, width_scale_factor = (
|
||||||
vae.downscale_index_formula
|
vae.downscale_index_formula
|
||||||
)
|
)
|
||||||
|
|
||||||
batch, _, latent_frames, latent_height, latent_width = samples.shape
|
_, _, _, latent_height, latent_width = samples.shape
|
||||||
width = latent_width * width_scale_factor
|
width = latent_width * width_scale_factor
|
||||||
height = latent_height * height_scale_factor
|
height = latent_height * height_scale_factor
|
||||||
|
|
||||||
@ -124,11 +124,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
|
|||||||
|
|
||||||
samples[:, :, :t.shape[2]] = t
|
samples[:, :, :t.shape[2]] = t
|
||||||
|
|
||||||
conditioning_latent_frames_mask = torch.ones(
|
conditioning_latent_frames_mask = get_noise_mask(latent)
|
||||||
(batch, 1, latent_frames, 1, 1),
|
|
||||||
dtype=torch.float32,
|
|
||||||
device=samples.device,
|
|
||||||
)
|
|
||||||
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
|
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
|
||||||
|
|
||||||
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
|
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
|
||||||
@ -236,7 +232,7 @@ class LTXVAddGuide(io.ComfyNode):
|
|||||||
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
|
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
|
||||||
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
|
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
|
||||||
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
|
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
|
||||||
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
|
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
|
||||||
encode_pixels = pixels[:, :, :, :3]
|
encode_pixels = pixels[:, :, :, :3]
|
||||||
t = vae.encode(encode_pixels)
|
t = vae.encode(encode_pixels)
|
||||||
return encode_pixels, t
|
return encode_pixels, t
|
||||||
@ -594,7 +590,8 @@ class LTXVPreprocess(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="LTXVPreprocess",
|
node_id="LTXVPreprocess",
|
||||||
category="image",
|
display_name="LTXV Preprocess",
|
||||||
|
category="video/preprocessors",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
io.Int.Input(
|
io.Int.Input(
|
||||||
|
|||||||
@ -11,7 +11,7 @@ class Mahiro(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="Mahiro",
|
node_id="Mahiro",
|
||||||
display_name="Positive-Biased Guidance",
|
display_name="Positive-Biased Guidance",
|
||||||
category="_for_testing",
|
category="experimental",
|
||||||
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
|
|||||||
@ -40,10 +40,21 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
|
|||||||
|
|
||||||
inverse_mask = torch.ones_like(mask) - mask
|
inverse_mask = torch.ones_like(mask) - mask
|
||||||
|
|
||||||
source_portion = mask * source[..., :visible_height, :visible_width]
|
source_rgb = source[:, :3, :visible_height, :visible_width]
|
||||||
destination_portion = inverse_mask * destination[..., top:bottom, left:right]
|
dest_slice = destination[..., top:bottom, left:right]
|
||||||
|
|
||||||
|
if destination.shape[1] == 4:
|
||||||
|
if torch.max(dest_slice) == 0:
|
||||||
|
destination[:, :3, top:bottom, left:right] = source_rgb
|
||||||
|
destination[:, 3:4, top:bottom, left:right] = mask
|
||||||
|
else:
|
||||||
|
destination[:, :3, top:bottom, left:right] = (mask * source_rgb) + (inverse_mask * dest_slice[:, :3])
|
||||||
|
destination[:, 3:4, top:bottom, left:right] = torch.max(mask, dest_slice[:, 3:4])
|
||||||
|
else:
|
||||||
|
source_portion = mask * source_rgb
|
||||||
|
destination_portion = inverse_mask * dest_slice
|
||||||
|
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
||||||
|
|
||||||
destination[..., top:bottom, left:right] = source_portion + destination_portion
|
|
||||||
return destination
|
return destination
|
||||||
|
|
||||||
class LatentCompositeMasked(IO.ComfyNode):
|
class LatentCompositeMasked(IO.ComfyNode):
|
||||||
@ -84,18 +95,23 @@ class ImageCompositeMasked(IO.ComfyNode):
|
|||||||
display_name="Image Composite Masked",
|
display_name="Image Composite Masked",
|
||||||
category="image",
|
category="image",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Image.Input("destination"),
|
|
||||||
IO.Image.Input("source"),
|
IO.Image.Input("source"),
|
||||||
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
|
||||||
IO.Boolean.Input("resize_source", default=False),
|
IO.Boolean.Input("resize_source", default=False),
|
||||||
|
IO.Image.Input("destination", optional=True),
|
||||||
IO.Mask.Input("mask", optional=True),
|
IO.Mask.Input("mask", optional=True),
|
||||||
],
|
],
|
||||||
outputs=[IO.Image.Output()],
|
outputs=[IO.Image.Output()],
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
|
def execute(cls, source, x, y, resize_source, destination = None, mask = None) -> IO.NodeOutput:
|
||||||
|
if destination is None: # transparent rgba
|
||||||
|
B, H, W, C = source.shape
|
||||||
|
destination = torch.zeros((B, H, W, 4), dtype=source.dtype, device=source.device)
|
||||||
|
if C == 3:
|
||||||
|
source = torch.nn.functional.pad(source, (0, 1), value=1.0)
|
||||||
destination, source = node_helpers.image_alpha_fix(destination, source)
|
destination, source = node_helpers.image_alpha_fix(destination, source)
|
||||||
destination = destination.clone().movedim(-1, 1)
|
destination = destination.clone().movedim(-1, 1)
|
||||||
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
|
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
|
||||||
@ -381,7 +397,6 @@ class GrowMask(IO.ComfyNode):
|
|||||||
|
|
||||||
expand_mask = execute # TODO: remove
|
expand_mask = execute # TODO: remove
|
||||||
|
|
||||||
|
|
||||||
class ThresholdMask(IO.ComfyNode):
|
class ThresholdMask(IO.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
|
|||||||
@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ComfyMathExpression",
|
node_id="ComfyMathExpression",
|
||||||
display_name="Math Expression",
|
display_name="Math Expression",
|
||||||
category="math",
|
category="logic",
|
||||||
search_aliases=[
|
search_aliases=[
|
||||||
"expression", "formula", "calculate", "calculator",
|
"expression", "formula", "calculate", "calculator",
|
||||||
"eval", "math",
|
"eval", "math",
|
||||||
|
|||||||
@ -21,7 +21,7 @@ class NumberConvertNode(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ComfyNumberConvert",
|
node_id="ComfyNumberConvert",
|
||||||
display_name="Number Convert",
|
display_name="Number Convert",
|
||||||
category="math",
|
category="utils",
|
||||||
search_aliases=[
|
search_aliases=[
|
||||||
"int to float", "float to int", "number convert",
|
"int to float", "float to int", "number convert",
|
||||||
"int2float", "float2int", "cast", "parse number",
|
"int2float", "float2int", "cast", "parse number",
|
||||||
|
|||||||
@ -24,8 +24,8 @@ class PerpNeg(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="PerpNeg",
|
node_id="PerpNeg",
|
||||||
display_name="Perp-Neg (DEPRECATED by PerpNegGuider)",
|
display_name="Perp-Neg (DEPRECATED by Perp-Neg Guider)",
|
||||||
category="_for_testing",
|
category="experimental",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Conditioning.Input("empty_conditioning"),
|
io.Conditioning.Input("empty_conditioning"),
|
||||||
@ -127,7 +127,8 @@ class PerpNegGuider(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="PerpNegGuider",
|
node_id="PerpNegGuider",
|
||||||
category="_for_testing",
|
display_name="Perp-Neg Guider",
|
||||||
|
category="experimental",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Conditioning.Input("positive"),
|
io.Conditioning.Input("positive"),
|
||||||
|
|||||||
@ -123,7 +123,7 @@ class PhotoMakerLoader(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="PhotoMakerLoader",
|
node_id="PhotoMakerLoader",
|
||||||
category="_for_testing/photomaker",
|
category="experimental/photomaker",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
|
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
|
||||||
],
|
],
|
||||||
@ -149,7 +149,7 @@ class PhotoMakerEncode(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="PhotoMakerEncode",
|
node_id="PhotoMakerEncode",
|
||||||
category="_for_testing/photomaker",
|
category="experimental/photomaker",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Photomaker.Input("photomaker"),
|
io.Photomaker.Input("photomaker"),
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
|
|||||||
@ -116,6 +116,7 @@ class Quantize(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ImageQuantize",
|
node_id="ImageQuantize",
|
||||||
|
display_name="Quantize Image",
|
||||||
category="image/postprocessing",
|
category="image/postprocessing",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
@ -181,6 +182,7 @@ class Sharpen(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ImageSharpen",
|
node_id="ImageSharpen",
|
||||||
|
display_name="Sharpen Image",
|
||||||
category="image/postprocessing",
|
category="image/postprocessing",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
@ -436,7 +438,7 @@ class ResizeImageMaskNode(io.ComfyNode):
|
|||||||
node_id="ResizeImageMaskNode",
|
node_id="ResizeImageMaskNode",
|
||||||
display_name="Resize Image/Mask",
|
display_name="Resize Image/Mask",
|
||||||
description="Resize an image or mask using various scaling methods.",
|
description="Resize an image or mask using various scaling methods.",
|
||||||
category="transform",
|
category="image/transform",
|
||||||
search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"],
|
search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.MatchType.Input("input", template=template),
|
io.MatchType.Input("input", template=template),
|
||||||
|
|||||||
@ -15,7 +15,7 @@ class RTDETR_detect(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="RTDETR_detect",
|
node_id="RTDETR_detect",
|
||||||
display_name="RT-DETR Detect",
|
display_name="RT-DETR Detect",
|
||||||
category="detection/",
|
category="detection",
|
||||||
search_aliases=["bbox", "bounding box", "object detection", "coco"],
|
search_aliases=["bbox", "bounding box", "object detection", "coco"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model", display_name="model"),
|
io.Model.Input("model", display_name="model"),
|
||||||
@ -71,7 +71,7 @@ class DrawBBoxes(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="DrawBBoxes",
|
node_id="DrawBBoxes",
|
||||||
display_name="Draw BBoxes",
|
display_name="Draw BBoxes",
|
||||||
category="detection/",
|
category="detection",
|
||||||
search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"],
|
search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image", optional=True),
|
io.Image.Input("image", optional=True),
|
||||||
|
|||||||
@ -113,7 +113,7 @@ class SelfAttentionGuidance(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SelfAttentionGuidance",
|
node_id="SelfAttentionGuidance",
|
||||||
display_name="Self-Attention Guidance",
|
display_name="Self-Attention Guidance",
|
||||||
category="_for_testing",
|
category="experimental",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),
|
io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),
|
||||||
|
|||||||
@ -93,7 +93,7 @@ class SAM3_Detect(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SAM3_Detect",
|
node_id="SAM3_Detect",
|
||||||
display_name="SAM3 Detect",
|
display_name="SAM3 Detect",
|
||||||
category="detection/",
|
category="detection",
|
||||||
search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"],
|
search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model", display_name="model"),
|
io.Model.Input("model", display_name="model"),
|
||||||
@ -265,7 +265,7 @@ class SAM3_VideoTrack(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SAM3_VideoTrack",
|
node_id="SAM3_VideoTrack",
|
||||||
display_name="SAM3 Video Track",
|
display_name="SAM3 Video Track",
|
||||||
category="detection/",
|
category="detection",
|
||||||
search_aliases=["sam3", "video", "track", "propagate"],
|
search_aliases=["sam3", "video", "track", "propagate"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"),
|
io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"),
|
||||||
@ -320,7 +320,7 @@ class SAM3_TrackPreview(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SAM3_TrackPreview",
|
node_id="SAM3_TrackPreview",
|
||||||
display_name="SAM3 Track Preview",
|
display_name="SAM3 Track Preview",
|
||||||
category="detection/",
|
category="detection",
|
||||||
inputs=[
|
inputs=[
|
||||||
SAM3TrackData.Input("track_data", display_name="track_data"),
|
SAM3TrackData.Input("track_data", display_name="track_data"),
|
||||||
io.Image.Input("images", display_name="images", optional=True),
|
io.Image.Input("images", display_name="images", optional=True),
|
||||||
@ -478,7 +478,7 @@ class SAM3_TrackToMask(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SAM3_TrackToMask",
|
node_id="SAM3_TrackToMask",
|
||||||
display_name="SAM3 Track to Mask",
|
display_name="SAM3 Track to Mask",
|
||||||
category="detection/",
|
category="detection",
|
||||||
inputs=[
|
inputs=[
|
||||||
SAM3TrackData.Input("track_data", display_name="track_data"),
|
SAM3TrackData.Input("track_data", display_name="track_data"),
|
||||||
io.String.Input("object_indices", display_name="object_indices", default="",
|
io.String.Input("object_indices", display_name="object_indices", default="",
|
||||||
|
|||||||
@ -119,7 +119,7 @@ class StableCascade_SuperResolutionControlnet(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="StableCascade_SuperResolutionControlnet",
|
node_id="StableCascade_SuperResolutionControlnet",
|
||||||
category="_for_testing/stable_cascade",
|
category="experimental/stable_cascade",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
|
|||||||
@ -26,7 +26,8 @@ class TextGenerate(io.ComfyNode):
|
|||||||
|
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="TextGenerate",
|
node_id="TextGenerate",
|
||||||
category="textgen",
|
display_name="Generate Text",
|
||||||
|
category="text",
|
||||||
search_aliases=["LLM", "gemma"],
|
search_aliases=["LLM", "gemma"],
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Clip.Input("clip"),
|
io.Clip.Input("clip"),
|
||||||
@ -157,6 +158,7 @@ class TextGenerateLTX2Prompt(TextGenerate):
|
|||||||
parent_schema = super().define_schema()
|
parent_schema = super().define_schema()
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="TextGenerateLTX2Prompt",
|
node_id="TextGenerateLTX2Prompt",
|
||||||
|
display_name="Generate LTX2 Prompt",
|
||||||
category=parent_schema.category,
|
category=parent_schema.category,
|
||||||
inputs=parent_schema.inputs,
|
inputs=parent_schema.inputs,
|
||||||
outputs=parent_schema.outputs,
|
outputs=parent_schema.outputs,
|
||||||
|
|||||||
@ -10,7 +10,7 @@ class TorchCompileModel(io.ComfyNode):
|
|||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="TorchCompileModel",
|
node_id="TorchCompileModel",
|
||||||
category="_for_testing",
|
category="experimental",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
io.Model.Input("model"),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
|
|||||||
@ -1361,7 +1361,7 @@ class SaveLoRA(io.ComfyNode):
|
|||||||
node_id="SaveLoRA",
|
node_id="SaveLoRA",
|
||||||
search_aliases=["export lora"],
|
search_aliases=["export lora"],
|
||||||
display_name="Save LoRA Weights",
|
display_name="Save LoRA Weights",
|
||||||
category="loaders",
|
category="advanced/model_merging",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
is_output_node=True,
|
is_output_node=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
|
|||||||
@ -15,7 +15,7 @@ class ImageOnlyCheckpointLoader:
|
|||||||
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
|
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
|
||||||
FUNCTION = "load_checkpoint"
|
FUNCTION = "load_checkpoint"
|
||||||
|
|
||||||
CATEGORY = "loaders/video_models"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)
|
||||||
|
|||||||
@ -22,7 +22,7 @@ class SaveImageWebsocket:
|
|||||||
|
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
CATEGORY = "api/image"
|
CATEGORY = "image"
|
||||||
|
|
||||||
def save_images(self, images):
|
def save_images(self, images):
|
||||||
pbar = comfy.utils.ProgressBar(images.shape[0])
|
pbar = comfy.utils.ProgressBar(images.shape[0])
|
||||||
@ -42,3 +42,7 @@ class SaveImageWebsocket:
|
|||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"SaveImageWebsocket": SaveImageWebsocket,
|
"SaveImageWebsocket": SaveImageWebsocket,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||||
|
"SaveImageWebsocket": "Save Image (Websocket)",
|
||||||
|
}
|
||||||
@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc
|
|||||||
|
|
||||||
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
|
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
|
||||||
|
|
||||||
|
folder_names_and_paths["background_removal"] = ([os.path.join(models_dir, "background_removal")], supported_pt_extensions)
|
||||||
|
|
||||||
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
|
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
|
||||||
|
|
||||||
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
|
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)
|
||||||
|
|||||||
15
nodes.py
15
nodes.py
@ -330,7 +330,7 @@ class VAEDecodeTiled:
|
|||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "decode"
|
FUNCTION = "decode"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
|
|
||||||
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
|
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
|
||||||
if tile_size < overlap * 4:
|
if tile_size < overlap * 4:
|
||||||
@ -377,7 +377,7 @@ class VAEEncodeTiled:
|
|||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "encode"
|
FUNCTION = "encode"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
|
|
||||||
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
|
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
|
||||||
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
|
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
|
||||||
@ -493,7 +493,7 @@ class SaveLatent:
|
|||||||
|
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
|
|
||||||
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
|
||||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||||
@ -538,7 +538,7 @@ class LoadLatent:
|
|||||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
|
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
|
||||||
return {"required": {"latent": [sorted(files), ]}, }
|
return {"required": {"latent": [sorted(files), ]}, }
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
|
|
||||||
RETURN_TYPES = ("LATENT", )
|
RETURN_TYPES = ("LATENT", )
|
||||||
FUNCTION = "load"
|
FUNCTION = "load"
|
||||||
@ -1443,7 +1443,7 @@ class LatentBlend:
|
|||||||
RETURN_TYPES = ("LATENT",)
|
RETURN_TYPES = ("LATENT",)
|
||||||
FUNCTION = "blend"
|
FUNCTION = "blend"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
|
|
||||||
def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
|
def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
|
||||||
|
|
||||||
@ -2092,6 +2092,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"StyleModelLoader": "Load Style Model",
|
"StyleModelLoader": "Load Style Model",
|
||||||
"CLIPVisionLoader": "Load CLIP Vision",
|
"CLIPVisionLoader": "Load CLIP Vision",
|
||||||
"UNETLoader": "Load Diffusion Model",
|
"UNETLoader": "Load Diffusion Model",
|
||||||
|
"unCLIPCheckpointLoader": "Load unCLIP Checkpoint",
|
||||||
|
"GLIGENLoader": "Load GLIGEN Model",
|
||||||
# Conditioning
|
# Conditioning
|
||||||
"CLIPVisionEncode": "CLIP Vision Encode",
|
"CLIPVisionEncode": "CLIP Vision Encode",
|
||||||
"StyleModelApply": "Apply Style Model",
|
"StyleModelApply": "Apply Style Model",
|
||||||
@ -2140,7 +2142,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"ImageSharpen": "Sharpen Image",
|
"ImageSharpen": "Sharpen Image",
|
||||||
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
||||||
"GetImageSize": "Get Image Size",
|
"GetImageSize": "Get Image Size",
|
||||||
# _for_testing
|
# experimental
|
||||||
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
"VAEDecodeTiled": "VAE Decode (Tiled)",
|
||||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||||
}
|
}
|
||||||
@ -2427,6 +2429,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_number_convert.py",
|
"nodes_number_convert.py",
|
||||||
"nodes_painter.py",
|
"nodes_painter.py",
|
||||||
"nodes_curve.py",
|
"nodes_curve.py",
|
||||||
|
"nodes_bg_removal.py",
|
||||||
"nodes_rtdetr.py",
|
"nodes_rtdetr.py",
|
||||||
"nodes_frame_interpolation.py",
|
"nodes_frame_interpolation.py",
|
||||||
"nodes_sam3.py",
|
"nodes_sam3.py",
|
||||||
|
|||||||
4712
openapi.yaml
4712
openapi.yaml
File diff suppressed because it is too large
Load Diff
90
tests-unit/app_test/node_replace_manager_test.py
Normal file
90
tests-unit/app_test/node_replace_manager_test.py
Normal file
@ -0,0 +1,90 @@
|
|||||||
|
"""Tests for NodeReplaceManager registration behavior."""
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def NodeReplaceManager(monkeypatch):
|
||||||
|
"""Provide NodeReplaceManager with `nodes` stubbed.
|
||||||
|
|
||||||
|
`app.node_replace_manager` does `import nodes` at module level, which pulls in
|
||||||
|
torch + the full ComfyUI graph. register() doesn't actually need it, so we
|
||||||
|
stub `nodes` per-test (via monkeypatch so it's torn down) and reload the
|
||||||
|
module so it picks up the stub instead of any cached real import.
|
||||||
|
"""
|
||||||
|
fake_nodes = types.ModuleType("nodes")
|
||||||
|
fake_nodes.NODE_CLASS_MAPPINGS = {}
|
||||||
|
monkeypatch.setitem(sys.modules, "nodes", fake_nodes)
|
||||||
|
monkeypatch.delitem(sys.modules, "app.node_replace_manager", raising=False)
|
||||||
|
module = importlib.import_module("app.node_replace_manager")
|
||||||
|
yield module.NodeReplaceManager
|
||||||
|
# Drop the freshly-imported module so the next test (or a later real import
|
||||||
|
# of `nodes`) starts from a clean slate.
|
||||||
|
sys.modules.pop("app.node_replace_manager", None)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeNodeReplace:
|
||||||
|
"""Lightweight stand-in for comfy_api.latest._io.NodeReplace."""
|
||||||
|
def __init__(self, new_node_id, old_node_id, old_widget_ids=None,
|
||||||
|
input_mapping=None, output_mapping=None):
|
||||||
|
self.new_node_id = new_node_id
|
||||||
|
self.old_node_id = old_node_id
|
||||||
|
self.old_widget_ids = old_widget_ids
|
||||||
|
self.input_mapping = input_mapping
|
||||||
|
self.output_mapping = output_mapping
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_adds_replacement(NodeReplaceManager):
|
||||||
|
manager = NodeReplaceManager()
|
||||||
|
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
|
||||||
|
assert manager.has_replacement("OldNode")
|
||||||
|
assert len(manager.get_replacement("OldNode")) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_allows_multiple_alternatives_for_same_old_node(NodeReplaceManager):
|
||||||
|
"""Different new_node_ids for the same old_node_id should all be kept."""
|
||||||
|
manager = NodeReplaceManager()
|
||||||
|
manager.register(FakeNodeReplace(new_node_id="AltA", old_node_id="OldNode"))
|
||||||
|
manager.register(FakeNodeReplace(new_node_id="AltB", old_node_id="OldNode"))
|
||||||
|
replacements = manager.get_replacement("OldNode")
|
||||||
|
assert len(replacements) == 2
|
||||||
|
assert {r.new_node_id for r in replacements} == {"AltA", "AltB"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_is_idempotent_for_duplicate_pair(NodeReplaceManager):
|
||||||
|
"""Re-registering the same (old_node_id, new_node_id) should be a no-op."""
|
||||||
|
manager = NodeReplaceManager()
|
||||||
|
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
|
||||||
|
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
|
||||||
|
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
|
||||||
|
assert len(manager.get_replacement("OldNode")) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_idempotent_preserves_first_registration(NodeReplaceManager):
|
||||||
|
"""First registration wins; later duplicates with different mappings are ignored."""
|
||||||
|
manager = NodeReplaceManager()
|
||||||
|
first = FakeNodeReplace(
|
||||||
|
new_node_id="NewNode", old_node_id="OldNode",
|
||||||
|
input_mapping=[{"new_id": "a", "old_id": "x"}],
|
||||||
|
)
|
||||||
|
second = FakeNodeReplace(
|
||||||
|
new_node_id="NewNode", old_node_id="OldNode",
|
||||||
|
input_mapping=[{"new_id": "b", "old_id": "y"}],
|
||||||
|
)
|
||||||
|
manager.register(first)
|
||||||
|
manager.register(second)
|
||||||
|
replacements = manager.get_replacement("OldNode")
|
||||||
|
assert len(replacements) == 1
|
||||||
|
assert replacements[0] is first
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_dedupe_does_not_affect_other_old_nodes(NodeReplaceManager):
|
||||||
|
manager = NodeReplaceManager()
|
||||||
|
manager.register(FakeNodeReplace(new_node_id="NewA", old_node_id="OldA"))
|
||||||
|
manager.register(FakeNodeReplace(new_node_id="NewA", old_node_id="OldA"))
|
||||||
|
manager.register(FakeNodeReplace(new_node_id="NewB", old_node_id="OldB"))
|
||||||
|
assert len(manager.get_replacement("OldA")) == 1
|
||||||
|
assert len(manager.get_replacement("OldB")) == 1
|
||||||
@ -21,7 +21,7 @@ class TestAsyncProgressUpdate(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "execute"
|
FUNCTION = "execute"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
async def execute(self, value, sleep_seconds):
|
async def execute(self, value, sleep_seconds):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@ -51,7 +51,7 @@ class TestSyncProgressUpdate(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "execute"
|
FUNCTION = "execute"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
def execute(self, value, sleep_seconds):
|
def execute(self, value, sleep_seconds):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|||||||
@ -21,7 +21,7 @@ class TestAsyncValidation(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def VALIDATE_INPUTS(cls, value, threshold):
|
async def VALIDATE_INPUTS(cls, value, threshold):
|
||||||
@ -53,7 +53,7 @@ class TestAsyncError(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "error_execution"
|
FUNCTION = "error_execution"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
async def error_execution(self, value, error_after):
|
async def error_execution(self, value, error_after):
|
||||||
await asyncio.sleep(error_after)
|
await asyncio.sleep(error_after)
|
||||||
@ -74,7 +74,7 @@ class TestAsyncValidationError(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def VALIDATE_INPUTS(cls, value, max_value):
|
async def VALIDATE_INPUTS(cls, value, max_value):
|
||||||
@ -105,7 +105,7 @@ class TestAsyncTimeout(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "timeout_execution"
|
FUNCTION = "timeout_execution"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
async def timeout_execution(self, value, timeout, operation_time):
|
async def timeout_execution(self, value, timeout, operation_time):
|
||||||
try:
|
try:
|
||||||
@ -129,7 +129,7 @@ class TestSyncError(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "sync_error"
|
FUNCTION = "sync_error"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
def sync_error(self, value):
|
def sync_error(self, value):
|
||||||
raise RuntimeError("Intentional sync execution error for testing")
|
raise RuntimeError("Intentional sync execution error for testing")
|
||||||
@ -150,7 +150,7 @@ class TestAsyncLazyCheck(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
async def check_lazy_status(self, condition, input1, input2):
|
async def check_lazy_status(self, condition, input1, input2):
|
||||||
# Simulate async checking (e.g., querying remote service)
|
# Simulate async checking (e.g., querying remote service)
|
||||||
@ -184,7 +184,7 @@ class TestDynamicAsyncGeneration(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "generate_async_workflow"
|
FUNCTION = "generate_async_workflow"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
|
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
|
||||||
g = GraphBuilder()
|
g = GraphBuilder()
|
||||||
@ -229,7 +229,7 @@ class TestAsyncResourceUser(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "use_resource"
|
FUNCTION = "use_resource"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
async def use_resource(self, value, resource_id, duration):
|
async def use_resource(self, value, resource_id, duration):
|
||||||
# Check if resource is already in use
|
# Check if resource is already in use
|
||||||
@ -265,7 +265,7 @@ class TestAsyncBatchProcessing(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "process_batch"
|
FUNCTION = "process_batch"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
async def process_batch(self, images, process_time_per_item, unique_id):
|
async def process_batch(self, images, process_time_per_item, unique_id):
|
||||||
batch_size = images.shape[0]
|
batch_size = images.shape[0]
|
||||||
@ -305,7 +305,7 @@ class TestAsyncConcurrentLimit(ComfyNodeABC):
|
|||||||
|
|
||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "limited_execution"
|
FUNCTION = "limited_execution"
|
||||||
CATEGORY = "_for_testing/async"
|
CATEGORY = "experimental/async"
|
||||||
|
|
||||||
async def limited_execution(self, value, duration, node_id):
|
async def limited_execution(self, value, duration, node_id):
|
||||||
async with self._semaphore:
|
async with self._semaphore:
|
||||||
|
|||||||
@ -409,7 +409,7 @@ class TestSleep(ComfyNodeABC):
|
|||||||
RETURN_TYPES = (IO.ANY,)
|
RETURN_TYPES = (IO.ANY,)
|
||||||
FUNCTION = "sleep"
|
FUNCTION = "sleep"
|
||||||
|
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
|
|
||||||
async def sleep(self, value, seconds, unique_id):
|
async def sleep(self, value, seconds, unique_id):
|
||||||
pbar = ProgressBar(seconds, node_id=unique_id)
|
pbar = ProgressBar(seconds, node_id=unique_id)
|
||||||
@ -440,7 +440,7 @@ class TestParallelSleep(ComfyNodeABC):
|
|||||||
}
|
}
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "parallel_sleep"
|
FUNCTION = "parallel_sleep"
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
|
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
|
||||||
@ -474,7 +474,7 @@ class TestOutputNodeWithSocketOutput:
|
|||||||
}
|
}
|
||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
CATEGORY = "_for_testing"
|
CATEGORY = "experimental"
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
def process(self, image, value):
|
def process(self, image, value):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user