Compare commits

...

15 Commits

Author SHA1 Message Date
Jukka Seppänen
bdae19b47f
Merge f559a749e9 into 05cd076bc1 2026-05-08 16:54:28 +02:00
drozbay
05cd076bc1
fix: Make LTXVAddGuide center-crop guide images to match other LTXV nodes (#13794) 2026-05-08 22:48:59 +08:00
Yousef R. Gamaleldin
d3c18c1636
Add support for BiRefNet background remove model (CORE-46) (#12747) 2026-05-08 17:59:24 +08:00
omahs
bac6fc35fb
Fix typos (#10986) 2026-05-08 17:14:45 +08:00
Alexander Piskun
56c74094c7
[Partner Nodes] use "adaptive" aspect ratio for SD2 nodes (#13800)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Signed-off-by: bigcat88 <bigcat88@icloud.com>
2026-05-07 23:39:13 -07:00
Alexis Rolland
594de378fe
Update nodes categories and display names (CORE-89) (#13786) 2026-05-08 01:02:55 -04:00
Jedrzej Kosinski
c8673542f7
fix: make NodeReplaceManager.register() idempotent (#13596) 2026-05-07 19:21:12 -07:00
comfyanonymous
df7bf1d3dc
Update warning message for ComfyUI frontend installation. (#13796) 2026-05-07 19:04:30 -07:00
Talmaj
ef8f25601a
Add I2V for causal forcing model. (#13719) 2026-05-07 18:38:36 -07:00
kijai
f559a749e9 Merge remote-tracking branch 'upstream/master' into ltxv_self_attn_mask 2026-05-07 15:02:56 +03:00
kijai
989dea8c40 Allow strength above 1.0 2026-05-06 23:56:21 +03:00
kijai
848880c3d3 Merge remote-tracking branch 'upstream/master' into ltxv_self_attn_mask 2026-05-06 21:45:41 +03:00
kijai
6b97e3f4cb Only fall to pytorch attention from sage for guide mask 2026-05-06 21:31:49 +03:00
kijai
f2beaa5802 Reduce peak VRAM by handling self_attn_mask more efficiently 2026-05-06 21:08:15 +03:00
kijai
e6e3e6f628 Alternative self_attn_mask
Drastically lower memory use, different effect, for testing
2026-05-06 16:34:01 +03:00
53 changed files with 1279 additions and 171 deletions

View File

@ -27,7 +27,7 @@ def frontend_install_warning_message():
return f""" return f"""
{get_missing_requirements_message()} {get_missing_requirements_message()}
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead. The ComfyUI frontend is shipped in a pip package so it needs to be updated separately from the ComfyUI code.
""".strip() """.strip()
def parse_version(version: str) -> tuple[int, int, int]: def parse_version(version: str) -> tuple[int, int, int]:

View File

@ -1,5 +1,7 @@
from __future__ import annotations from __future__ import annotations
import logging
from aiohttp import web from aiohttp import web
from typing import TYPE_CHECKING, TypedDict from typing import TYPE_CHECKING, TypedDict
@ -31,8 +33,22 @@ class NodeReplaceManager:
self._replacements: dict[str, list[NodeReplace]] = {} self._replacements: dict[str, list[NodeReplace]] = {}
def register(self, node_replace: NodeReplace): def register(self, node_replace: NodeReplace):
"""Register a node replacement mapping.""" """Register a node replacement mapping.
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
Idempotent: if a replacement with the same (old_node_id, new_node_id)
is already registered, the duplicate is ignored. This prevents stale
entries from accumulating when custom nodes are reloaded in the same
process (e.g. via ComfyUI-Manager).
"""
existing = self._replacements.setdefault(node_replace.old_node_id, [])
for entry in existing:
if entry.new_node_id == node_replace.new_node_id:
logging.debug(
"Node replacement %s -> %s already registered, ignoring duplicate.",
node_replace.old_node_id, node_replace.new_node_id,
)
return
existing.append(node_replace)
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None: def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
"""Get replacements for an old node ID.""" """Get replacements for an old node ID."""

View File

@ -0,0 +1,7 @@
{
"model_type": "birefnet",
"image_std": [1.0, 1.0, 1.0],
"image_mean": [0.0, 0.0, 0.0],
"image_size": 1024,
"resize_to_original": true
}

View File

@ -0,0 +1,689 @@
import torch
import comfy.ops
import numpy as np
import torch.nn as nn
from functools import partial
import torch.nn.functional as F
from torchvision.ops import deform_conv2d
from comfy.ldm.modules.attention import optimized_attention_for_device
CXT = [3072, 1536, 768, 384][1:][::-1][-3:]
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
self.kv = operations.Linear(dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
def forward(self, x):
B, N, C = x.shape
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
x = optimized_attention(
q, k, v, heads=self.num_heads, skip_output_reshape=True, skip_reshape=True
).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, device=None, dtype=None, operations=None):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = operations.Linear(in_features, hidden_features, device=device, dtype=dtype)
self.act = nn.GELU()
self.fc2 = operations.Linear(hidden_features, out_features, device=device, dtype=dtype)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads, device=device, dtype=dtype))
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
return x
class SwinTransformerBlock(nn.Module):
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None,
norm_layer=nn.LayerNorm, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.norm1 = norm_layer(dim, device=device, dtype=dtype)
self.attn = WindowAttention(
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, device=device, dtype=dtype, operations=operations)
self.norm2 = norm_layer(dim, device=device, dtype=dtype)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, device=device, dtype=dtype, operations=operations)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
B, L, C = x.shape
H, W = self.H, self.W
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
x_windows = window_partition(shifted_x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
attn_windows = self.attn(x_windows, mask=attn_mask)
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class PatchMerging(nn.Module):
def __init__(self, dim, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.reduction = operations.Linear(4 * dim, 2 * dim, bias=False, device=device, dtype=dtype)
self.norm = operations.LayerNorm(4 * dim, device=device, dtype=dtype)
def forward(self, x, H, W):
B, L, C = x.shape
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class BasicLayer(nn.Module):
def __init__(self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
norm_layer=nn.LayerNorm,
downsample=None,
device=None, dtype=None, operations=None):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
norm_layer=norm_layer,
device=device, dtype=dtype, operations=operations)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, device=device, dtype=dtype, operations=operations)
else:
self.downsample = None
def forward(self, x, H, W):
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
blk.H, blk.W = H, W
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None, device=None, dtype=None, operations=None):
super().__init__()
patch_size = (patch_size, patch_size)
self.patch_size = patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype)
if norm_layer is not None:
self.norm = norm_layer(embed_dim, device=device, dtype=dtype)
else:
self.norm = None
def forward(self, x):
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
class SwinTransformer(nn.Module):
def __init__(self,
pretrain_img_size=224,
patch_size=4,
in_channels=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
device=None, dtype=None, operations=None):
super().__init__()
norm_layer = partial(operations.LayerNorm, device=device, dtype=dtype)
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.patch_embed = PatchEmbed(
patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
device=device, dtype=dtype, operations=operations,
norm_layer=norm_layer if self.patch_norm else None)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
device=device, dtype=dtype, operations=operations)
self.layers.append(layer)
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
def forward(self, x):
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
outs = []
x = x.flatten(2).transpose(1, 2)
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return tuple(outs)
class DeformableConv2d(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False, device=None, dtype=None, operations=None):
super(DeformableConv2d, self).__init__()
kernel_size = kernel_size if type(kernel_size) is tuple else (kernel_size, kernel_size)
self.stride = stride if type(stride) is tuple else (stride, stride)
self.padding = padding
self.offset_conv = operations.Conv2d(in_channels,
2 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True, device=device, dtype=dtype)
self.modulator_conv = operations.Conv2d(in_channels,
1 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True, device=device, dtype=dtype)
self.regular_conv = operations.Conv2d(in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=bias, device=device, dtype=dtype)
def forward(self, x):
offset = self.offset_conv(x)
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
x = deform_conv2d(
input=x,
offset=offset,
weight=weight,
bias=None,
padding=self.padding,
mask=modulator,
stride=self.stride,
)
comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
return x
class BasicDecBlk(nn.Module):
def __init__(self, in_channels=64, out_channels=64, inter_channels=64, device=None, dtype=None, operations=None):
super(BasicDecBlk, self).__init__()
inter_channels = 64
self.conv_in = operations.Conv2d(in_channels, inter_channels, 3, 1, padding=1, device=device, dtype=dtype)
self.relu_in = nn.ReLU(inplace=True)
self.dec_att = ASPPDeformable(in_channels=inter_channels, device=device, dtype=dtype, operations=operations)
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, padding=1, device=device, dtype=dtype)
self.bn_in = operations.BatchNorm2d(inter_channels, device=device, dtype=dtype)
self.bn_out = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
def forward(self, x):
x = self.conv_in(x)
x = self.bn_in(x)
x = self.relu_in(x)
x = self.dec_att(x)
x = self.conv_out(x)
x = self.bn_out(x)
return x
class BasicLatBlk(nn.Module):
def __init__(self, in_channels=64, out_channels=64, device=None, dtype=None, operations=None):
super(BasicLatBlk, self).__init__()
self.conv = operations.Conv2d(in_channels, out_channels, 1, 1, 0, device=device, dtype=dtype)
def forward(self, x):
x = self.conv(x)
return x
class _ASPPModuleDeformable(nn.Module):
def __init__(self, in_channels, planes, kernel_size, padding, device, dtype, operations):
super(_ASPPModuleDeformable, self).__init__()
self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
stride=1, padding=padding, bias=False, device=device, dtype=dtype, operations=operations)
self.bn = operations.BatchNorm2d(planes, device=device, dtype=dtype)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.atrous_conv(x)
x = self.bn(x)
return self.relu(x)
class ASPPDeformable(nn.Module):
def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7], device=None, dtype=None, operations=None):
super(ASPPDeformable, self).__init__()
self.down_scale = 1
if out_channels is None:
out_channels = in_channels
self.in_channelster = 256 // self.down_scale
self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0, device=device, dtype=dtype, operations=operations)
self.aspp_deforms = nn.ModuleList([
_ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2), device=device, dtype=dtype, operations=operations)
for conv_size in parallel_block_sizes
])
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
operations.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False, device=device, dtype=dtype),
operations.BatchNorm2d(self.in_channelster, device=device, dtype=dtype),
nn.ReLU(inplace=True))
self.conv1 = operations.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False, device=device, dtype=dtype)
self.bn1 = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x1 = self.aspp1(x)
x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
x5 = self.global_avg_pool(x)
x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
return x
class BiRefNet(nn.Module):
def __init__(self, config=None, dtype=None, device=None, operations=None):
super(BiRefNet, self).__init__()
self.bb = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12, device=device, dtype=dtype, operations=operations)
channels = [1536, 768, 384, 192]
channels = [c * 2 for c in channels]
self.cxt = channels[1:][::-1][-3:]
self.squeeze_module = nn.Sequential(*[
BasicDecBlk(channels[0]+sum(self.cxt), channels[0], device=device, dtype=dtype, operations=operations)
for _ in range(1)
])
self.decoder = Decoder(channels, device=device, dtype=dtype, operations=operations)
def forward_enc(self, x):
x1, x2, x3, x4 = self.bb(x)
B, C, H, W = x.shape
x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x4 = torch.cat(
(
*[
F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
][-len(CXT):],
x4
),
dim=1
)
return (x1, x2, x3, x4)
def forward_ori(self, x):
(x1, x2, x3, x4) = self.forward_enc(x)
x4 = self.squeeze_module(x4)
features = [x, x1, x2, x3, x4]
scaled_preds = self.decoder(features)
return scaled_preds
def forward(self, pixel_values, intermediate_output=None):
scaled_preds = self.forward_ori(pixel_values)
return scaled_preds
class Decoder(nn.Module):
def __init__(self, channels, device, dtype, operations):
super(Decoder, self).__init__()
# factory kwargs
fk = {"device":device, "dtype":dtype, "operations":operations}
DecoderBlock = partial(BasicDecBlk, **fk)
LateralBlock = partial(BasicLatBlk, **fk)
DBlock = partial(SimpleConvs, **fk)
self.split = True
N_dec_ipt = 64
ic = 64
ipt_cha_opt = 1
self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[1])
self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[2])
self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt]), channels[3])
self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt]), channels[3]//2)
fk = {"device":device, "dtype":dtype}
self.conv_out1 = nn.Sequential(operations.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt]), 1, 1, 1, 0, **fk))
self.lateral_block4 = LateralBlock(channels[1], channels[1])
self.lateral_block3 = LateralBlock(channels[2], channels[2])
self.lateral_block2 = LateralBlock(channels[3], channels[3])
self.conv_ms_spvn_4 = operations.Conv2d(channels[1], 1, 1, 1, 0, **fk)
self.conv_ms_spvn_3 = operations.Conv2d(channels[2], 1, 1, 1, 0, **fk)
self.conv_ms_spvn_2 = operations.Conv2d(channels[3], 1, 1, 1, 0, **fk)
_N = 16
self.gdt_convs_4 = nn.Sequential(operations.Conv2d(channels[0] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
self.gdt_convs_3 = nn.Sequential(operations.Conv2d(channels[1] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
self.gdt_convs_2 = nn.Sequential(operations.Conv2d(channels[2] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
[setattr(self, f"gdt_convs_pred_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
[setattr(self, f"gdt_convs_attn_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
def get_patches_batch(self, x, p):
_size_h, _size_w = p.shape[2:]
patches_batch = []
for idx in range(x.shape[0]):
columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
patches_x = []
for column_x in columns_x:
patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
patch_sample = torch.cat(patches_x, dim=1)
patches_batch.append(patch_sample)
return torch.cat(patches_batch, dim=0)
def forward(self, features):
x, x1, x2, x3, x4 = features
patches_batch = self.get_patches_batch(x, x4) if self.split else x
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
p4 = self.decoder_block4(x4)
p4_gdt = self.gdt_convs_4(p4)
gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
p4 = p4 * gdt_attn_4
_p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
_p3 = _p4 + self.lateral_block4(x3)
patches_batch = self.get_patches_batch(x, _p3) if self.split else x
_p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
p3 = self.decoder_block3(_p3)
p3_gdt = self.gdt_convs_3(p3)
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
p3 = p3 * gdt_attn_3
_p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
_p2 = _p3 + self.lateral_block3(x2)
patches_batch = self.get_patches_batch(x, _p2) if self.split else x
_p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
p2 = self.decoder_block2(_p2)
p2_gdt = self.gdt_convs_2(p2)
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
p2 = p2 * gdt_attn_2
_p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
_p1 = _p2 + self.lateral_block2(x1)
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
_p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
_p1 = self.decoder_block1(_p1)
_p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
p1_out = self.conv_out1(_p1)
return p1_out
class SimpleConvs(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, inter_channels=64, device=None, dtype=None, operations=None
) -> None:
super().__init__()
self.conv1 = operations.Conv2d(in_channels, inter_channels, 3, 1, 1, device=device, dtype=dtype)
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, 1, device=device, dtype=dtype)
def forward(self, x):
return self.conv_out(self.conv1(x))

78
comfy/bg_removal_model.py Normal file
View File

@ -0,0 +1,78 @@
from .utils import load_torch_file
import os
import json
import torch
import logging
import comfy.ops
import comfy.model_patcher
import comfy.model_management
import comfy.clip_model
import comfy.background_removal.birefnet
BG_REMOVAL_MODELS = {
"birefnet": comfy.background_removal.birefnet.BiRefNet
}
class BackgroundRemovalModel():
def __init__(self, json_config):
with open(json_config) as f:
config = json.load(f)
self.image_size = config.get("image_size", 1024)
self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0])
self.image_std = config.get("image_std", [1.0, 1.0, 1.0])
self.model_type = config.get("model_type", "birefnet")
self.config = config.copy()
model_class = BG_REMOVAL_MODELS.get(self.model_type)
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image):
comfy.model_management.load_model_gpu(self.patcher)
H, W = image.shape[1], image.shape[2]
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
out = self.model(pixel_values=pixel_values)
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
mask = out.sigmoid()
if mask.ndim == 3:
mask = mask.unsqueeze(0)
if mask.shape[1] != 1:
mask = mask.movedim(-1, 1)
return mask
def load_background_removal_model(sd):
if "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal"), "birefnet.json")
else:
return None
bg_model = BackgroundRemovalModel(json_config)
m, u = bg_model.load_sd(sd)
if len(m) > 0:
logging.warning("missing background removal: {}".format(m))
u = set(u)
keys = list(sd.keys())
for k in keys:
if k not in u:
sd.pop(k)
return bg_model
def load(ckpt_path):
sd = load_torch_file(ckpt_path)
return load_background_removal_model(sd)

View File

@ -93,7 +93,7 @@ class Hook:
self.hook_scope = hook_scope self.hook_scope = hook_scope
'''Scope of where this hook should apply in terms of the conds used in sampling run.''' '''Scope of where this hook should apply in terms of the conds used in sampling run.'''
self.custom_should_register = default_should_register self.custom_should_register = default_should_register
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register''' '''Can be overridden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
@property @property
def strength(self): def strength(self):

View File

@ -1859,6 +1859,23 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
output = torch.zeros_like(x) output = torch.zeros_like(x)
s_in = x.new_ones([x.shape[0]]) s_in = x.new_ones([x.shape[0]])
current_start_frame = 0 current_start_frame = 0
# I2V: seed KV cache with the initial image latent before the denoising loop
initial_latent = transformer_options.get("ar_config", {}).get("initial_latent", None)
if initial_latent is not None:
initial_latent = inner_model.process_latent_in(initial_latent).to(device=device, dtype=model_dtype)
n_init = initial_latent.shape[2]
output[:, :, :n_init] = initial_latent
ar_state = {"start_frame": 0, "kv_caches": kv_caches, "crossattn_caches": crossattn_caches}
transformer_options["ar_state"] = ar_state
zero_sigma = sigmas.new_zeros([1])
_ = model(initial_latent, zero_sigma * s_in, **extra_args)
current_start_frame = n_init
remaining = lat_t - n_init
num_blocks = -(-remaining // num_frame_per_block)
num_sigma_steps = len(sigmas) - 1 num_sigma_steps = len(sigmas) - 1
total_real_steps = num_blocks * num_sigma_steps total_real_steps = num_blocks * num_sigma_steps
step_count = 0 step_count = 0

View File

@ -22,26 +22,25 @@ class CompressedTimestep:
"""Store video timestep embeddings in compressed form using per-frame indexing.""" """Store video timestep embeddings in compressed form using per-frame indexing."""
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim') __slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
def __init__(self, tensor: torch.Tensor, patches_per_frame: int): def __init__(self, tensor: torch.Tensor, patches_per_frame: int, per_frame: bool = False):
""" """
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame tensor: [batch, num_tokens, feature_dim] (per-token, default) or
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression [batch, num_frames, feature_dim] (per_frame=True, already compressed).
patches_per_frame: spatial patches per frame; pass None to disable compression.
""" """
self.batch_size, num_tokens, self.feature_dim = tensor.shape self.batch_size, n, self.feature_dim = tensor.shape
if per_frame:
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
self.patches_per_frame = patches_per_frame self.patches_per_frame = patches_per_frame
self.num_frames = num_tokens // patches_per_frame self.num_frames = n
self.data = tensor
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame elif patches_per_frame is not None and n >= patches_per_frame and n % patches_per_frame == 0:
# All patches in a frame are identical, so we only keep the first one self.patches_per_frame = patches_per_frame
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim) self.num_frames = n // patches_per_frame
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim] # All patches in a frame are identical — keep only the first.
self.data = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)[:, :, 0, :].contiguous()
else: else:
# Not divisible or too small - store directly without compression
self.patches_per_frame = 1 self.patches_per_frame = 1
self.num_frames = num_tokens self.num_frames = n
self.data = tensor self.data = tensor
def expand(self): def expand(self):
@ -716,32 +715,35 @@ class LTXAVModel(LTXVModel):
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs): def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
"""Prepare timestep embeddings.""" """Prepare timestep embeddings."""
# TODO: some code reuse is needed here.
grid_mask = kwargs.get("grid_mask", None) grid_mask = kwargs.get("grid_mask", None)
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep_scaled = timestep * self.timestep_scale_multiplier
v_timestep, v_embedded_timestep = self.adaln_single(
timestep_scaled.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
orig_shape = kwargs.get("orig_shape") orig_shape = kwargs.get("orig_shape")
has_spatial_mask = kwargs.get("has_spatial_mask", None) has_spatial_mask = kwargs.get("has_spatial_mask", None)
v_patches_per_frame = None v_patches_per_frame = None
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5: if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
v_patches_per_frame = orig_shape[3] * orig_shape[4] v_patches_per_frame = orig_shape[3] * orig_shape[4]
# Reshape to [batch_size, num_tokens, dim] and compress for storage # Used by compute_prompt_timestep and the audio cross-attention paths.
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame) timestep_scaled = (timestep[:, grid_mask] if grid_mask is not None else timestep) * self.timestep_scale_multiplier
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
# When patches in a frame share a timestep (no spatial mask), project one row per frame instead of one per token
per_frame_path = v_patches_per_frame is not None and (timestep.numel() // batch_size) % v_patches_per_frame == 0
if per_frame_path:
per_frame = timestep.reshape(batch_size, -1, v_patches_per_frame)[:, :, 0]
if grid_mask is not None:
# All-or-nothing per frame when has_spatial_mask=False.
per_frame = per_frame[:, grid_mask[::v_patches_per_frame]]
ts_input = per_frame * self.timestep_scale_multiplier
else:
ts_input = timestep_scaled
v_timestep, v_embedded_timestep = self.adaln_single(
ts_input.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
v_prompt_timestep = compute_prompt_timestep( v_prompt_timestep = compute_prompt_timestep(
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype

View File

@ -358,6 +358,63 @@ def apply_split_rotary_emb(input_tensor, cos, sin):
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
class GuideAttentionMask:
"""Holds the two per-group masks for LTXV guide self-attention.
_attention_with_guide_mask splits queries into noisy and tracked-guide
groups, so the largest mask is (1, 1, tracked_count, T).
"""
__slots__ = ("guide_start", "tracked_count", "noisy_mask", "tracked_mask")
def __init__(self, total_tokens, guide_start, tracked_count, tracked_weights):
device = tracked_weights.device
dtype = tracked_weights.dtype
finfo = torch.finfo(dtype)
pos = tracked_weights > 0
log_w = torch.full_like(tracked_weights, finfo.min)
log_w[pos] = torch.log(tracked_weights[pos].clamp(min=finfo.tiny))
self.guide_start = guide_start
self.tracked_count = tracked_count
self.noisy_mask = torch.zeros((1, 1, 1, total_tokens), device=device, dtype=dtype)
self.noisy_mask[:, :, :, guide_start:guide_start + tracked_count] = log_w.view(1, 1, 1, -1)
self.tracked_mask = torch.zeros((1, 1, tracked_count, total_tokens), device=device, dtype=dtype)
self.tracked_mask[:, :, :, :guide_start] = log_w.view(1, 1, -1, 1)
def to(self, *args, **kwargs):
new = GuideAttentionMask.__new__(GuideAttentionMask)
new.guide_start = self.guide_start
new.tracked_count = self.tracked_count
new.noisy_mask = self.noisy_mask.to(*args, **kwargs)
new.tracked_mask = self.tracked_mask.to(*args, **kwargs)
return new
def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
"""Apply the guide mask by partitioning Q into noisy and tracked-guide
groups, so each group needs only its own sub-mask. Avoids materializing
the (1,1,T,T) dense mask.
"""
guide_start = guide_mask.guide_start
tracked_end = guide_start + guide_mask.tracked_count
out = torch.empty_like(q)
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
attn_precision=attn_precision, transformer_options=transformer_options,
low_precision_attention=False, # sageattn mask support is unreliable
)
out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
attn_precision=attn_precision, transformer_options=transformer_options,
low_precision_attention=False,
)
return out
class CrossAttention(nn.Module): class CrossAttention(nn.Module):
def __init__( def __init__(
self, self,
@ -412,8 +469,10 @@ class CrossAttention(nn.Module):
if mask is None: if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
elif isinstance(mask, GuideAttentionMask):
out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
else: else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options) out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
# Apply per-head gating if enabled # Apply per-head gating if enabled
if self.to_gate_logits is not None: if self.to_gate_logits is not None:
@ -1063,7 +1122,9 @@ class LTXVModel(LTXBaseModel):
additional_args["resolved_guide_entries"] = resolved_entries additional_args["resolved_guide_entries"] = resolved_entries
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :] keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
if keyframe_idxs.shape[2] > 0: # Guard for the case of no keyframes surviving
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
# Total surviving guide tokens (all guides) # Total surviving guide tokens (all guides)
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2] additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
@ -1099,12 +1160,12 @@ class LTXVModel(LTXBaseModel):
if not resolved_entries: if not resolved_entries:
return None return None
# Check if any attenuation is actually needed # strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
needs_attenuation = any( needs_mask = any(
e["strength"] < 1.0 or e.get("pixel_mask") is not None e["strength"] != 1.0 or e.get("pixel_mask") is not None
for e in resolved_entries for e in resolved_entries
) )
if not needs_attenuation: if not needs_mask:
return None return None
# Build per-guide-token weights for all tracked guide tokens. # Build per-guide-token weights for all tracked guide tokens.
@ -1159,16 +1220,11 @@ class LTXVModel(LTXBaseModel):
# Concatenate per-token weights for all tracked guides # Concatenate per-token weights for all tracked guides
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked) tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
# Check if any weight is actually < 1.0 (otherwise no attenuation needed) # Skip when every weight is exactly 1.0 (additive bias would be 0).
if (tracked_weights >= 1.0).all(): if (tracked_weights == 1.0).all():
return None return None
# Build the mask: guide tokens are at the end of the sequence. return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)
# Tracked guides come first (in order), untracked follow.
return self._build_self_attention_mask(
total_tokens, num_guide_tokens, total_tracked,
tracked_weights, guide_start, device, dtype,
)
@staticmethod @staticmethod
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat): def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
@ -1234,45 +1290,6 @@ class LTXVModel(LTXBaseModel):
return rearrange(latent_mask, "b 1 f h w -> b (f h w)") return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
@staticmethod
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
tracked_weights, guide_start, device, dtype):
"""Build a log-space additive self-attention bias mask.
Attenuates attention between noisy tokens and tracked guide tokens.
Untracked guide tokens (at the end of the guide portion) keep full attention.
Args:
total_tokens: Total sequence length.
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
tracked_count: Number of tracked guide tokens (first in the guide portion).
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
guide_start: Index where guide tokens begin in the sequence.
device: Target device.
dtype: Target dtype.
Returns:
(1, 1, total_tokens, total_tokens) additive bias mask.
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
"""
finfo = torch.finfo(dtype)
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
tracked_end = guide_start + tracked_count
# Convert weights to log-space bias
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
log_w = torch.full_like(w, finfo.min)
positive_mask = w > 0
if positive_mask.any():
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
# noisy → tracked guides: each noisy row gets the same per-guide weight
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
return mask
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs): def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
"""Process transformer blocks for LTXV.""" """Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})

View File

@ -140,7 +140,7 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
alphas = alphacums[ddim_timesteps] alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
# according the the formula provided in https://arxiv.org/abs/2010.02502 # according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose: if verbose:
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')

View File

@ -562,6 +562,25 @@ class disable_weight_init:
else: else:
return super().forward(*args, **kwargs) return super().forward(*args, **kwargs)
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None
running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp): class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self): def reset_parameters(self):
return None return None
@ -749,6 +768,9 @@ class manual_cast(disable_weight_init):
class Conv3d(disable_weight_init.Conv3d): class Conv3d(disable_weight_init.Conv3d):
comfy_cast_weights = True comfy_cast_weights = True
class BatchNorm2d(disable_weight_init.BatchNorm2d):
comfy_cast_weights = True
class GroupNorm(disable_weight_init.GroupNorm): class GroupNorm(disable_weight_init.GroupNorm):
comfy_cast_weights = True comfy_cast_weights = True

View File

@ -17,6 +17,7 @@ if TYPE_CHECKING:
from spandrel import ImageModelDescriptor from spandrel import ImageModelDescriptor
from comfy.clip_vision import ClipVisionModel from comfy.clip_vision import ClipVisionModel
from comfy.clip_vision import Output as ClipVisionOutput_ from comfy.clip_vision import Output as ClipVisionOutput_
from comfy.bg_removal_model import BackgroundRemovalModel
from comfy.controlnet import ControlNet from comfy.controlnet import ControlNet
from comfy.hooks import HookGroup, HookKeyframeGroup from comfy.hooks import HookGroup, HookKeyframeGroup
from comfy.model_patcher import ModelPatcher from comfy.model_patcher import ModelPatcher
@ -614,6 +615,11 @@ class Model(ComfyTypeIO):
if TYPE_CHECKING: if TYPE_CHECKING:
Type = ModelPatcher Type = ModelPatcher
@comfytype(io_type="BACKGROUND_REMOVAL")
class BackgroundRemoval(ComfyTypeIO):
if TYPE_CHECKING:
Type = BackgroundRemovalModel
@comfytype(io_type="CLIP_VISION") @comfytype(io_type="CLIP_VISION")
class ClipVision(ComfyTypeIO): class ClipVision(ComfyTypeIO):
if TYPE_CHECKING: if TYPE_CHECKING:
@ -2257,6 +2263,7 @@ __all__ = [
"ModelPatch", "ModelPatch",
"ClipVision", "ClipVision",
"ClipVisionOutput", "ClipVisionOutput",
"BackgroundRemoval",
"AudioEncoder", "AudioEncoder",
"AudioEncoderOutput", "AudioEncoderOutput",
"StyleModel", "StyleModel",

View File

@ -1271,7 +1271,7 @@ PRICE_BADGE_VIDEO = IO.PriceBadge(
) )
def _seedance2_text_inputs(resolutions: list[str]): def _seedance2_text_inputs(resolutions: list[str], default_ratio: str = "16:9"):
return [ return [
IO.String.Input( IO.String.Input(
"prompt", "prompt",
@ -1287,6 +1287,7 @@ def _seedance2_text_inputs(resolutions: list[str]):
IO.Combo.Input( IO.Combo.Input(
"ratio", "ratio",
options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"], options=["16:9", "4:3", "1:1", "3:4", "9:16", "21:9", "adaptive"],
default=default_ratio,
tooltip="Aspect ratio of the output video.", tooltip="Aspect ratio of the output video.",
), ),
IO.Int.Input( IO.Int.Input(
@ -1420,8 +1421,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"model", "model",
options=[ options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_text_inputs(["480p", "720p", "1080p"])), IO.DynamicCombo.Option(
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_text_inputs(["480p", "720p"])), "Seedance 2.0",
_seedance2_text_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Fast",
_seedance2_text_inputs(["480p", "720p"], default_ratio="adaptive"),
),
], ],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
), ),
@ -1588,9 +1595,9 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
def _seedance2_reference_inputs(resolutions: list[str]): def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16:9"):
return [ return [
*_seedance2_text_inputs(resolutions), *_seedance2_text_inputs(resolutions, default_ratio=default_ratio),
IO.Autogrow.Input( IO.Autogrow.Input(
"reference_images", "reference_images",
template=IO.Autogrow.TemplateNames( template=IO.Autogrow.TemplateNames(
@ -1668,8 +1675,14 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"model", "model",
options=[ options=[
IO.DynamicCombo.Option("Seedance 2.0", _seedance2_reference_inputs(["480p", "720p", "1080p"])), IO.DynamicCombo.Option(
IO.DynamicCombo.Option("Seedance 2.0 Fast", _seedance2_reference_inputs(["480p", "720p"])), "Seedance 2.0",
_seedance2_reference_inputs(["480p", "720p", "1080p"], default_ratio="adaptive"),
),
IO.DynamicCombo.Option(
"Seedance 2.0 Fast",
_seedance2_reference_inputs(["480p", "720p"], default_ratio="adaptive"),
),
], ],
tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.", tooltip="Seedance 2.0 for maximum quality; Seedance 2.0 Fast for speed optimization.",
), ),

View File

@ -92,7 +92,7 @@ class SamplerEulerCFGpp(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="SamplerEulerCFGpp", node_id="SamplerEulerCFGpp",
display_name="SamplerEulerCFG++", display_name="SamplerEulerCFG++",
category="_for_testing", # "sampling/custom_sampling/samplers" category="experimental", # "sampling/custom_sampling/samplers"
inputs=[ inputs=[
io.Combo.Input("version", options=["regular", "alternative"], advanced=True), io.Combo.Input("version", options=["regular", "alternative"], advanced=True),
], ],

View File

@ -2,6 +2,7 @@
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.). ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors - EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop - SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
- ARVideoI2V: image-to-video conditioning for AR models (seeds KV cache with start image)
""" """
import torch import torch
@ -9,6 +10,7 @@ from typing_extensions import override
import comfy.model_management import comfy.model_management
import comfy.samplers import comfy.samplers
import comfy.utils
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
@ -71,12 +73,62 @@ class SamplerARVideo(io.ComfyNode):
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options)) return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
class ARVideoI2V(io.ComfyNode):
"""Image-to-video setup for AR video models (Causal Forcing, Self-Forcing).
VAE-encodes the start image and stores it in the model's transformer_options
so that sample_ar_video can seed the KV cache before denoising.
Uses the same T2V model checkpoint -- no separate I2V architecture needed.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="ARVideoI2V",
category="conditioning/video_models",
inputs=[
io.Model.Input("model"),
io.Vae.Input("vae"),
io.Image.Input("start_image"),
io.Int.Input("width", default=832, min=16, max=8192, step=16),
io.Int.Input("height", default=480, min=16, max=8192, step=16),
io.Int.Input("length", default=81, min=1, max=1024, step=4),
io.Int.Input("batch_size", default=1, min=1, max=64),
],
outputs=[
io.Model.Output(display_name="MODEL"),
io.Latent.Output(display_name="LATENT"),
],
)
@classmethod
def execute(cls, model, vae, start_image, width, height, length, batch_size) -> io.NodeOutput:
start_image = comfy.utils.common_upscale(
start_image[:1].movedim(-1, 1), width, height, "bilinear", "center"
).movedim(1, -1)
initial_latent = vae.encode(start_image[:, :, :, :3])
m = model.clone()
to = m.model_options.setdefault("transformer_options", {})
ar_cfg = to.setdefault("ar_config", {})
ar_cfg["initial_latent"] = initial_latent
lat_t = ((length - 1) // 4) + 1
latent = torch.zeros(
[batch_size, 16, lat_t, height // 8, width // 8],
device=comfy.model_management.intermediate_device(),
)
return io.NodeOutput(m, {"samples": latent})
class ARVideoExtension(ComfyExtension): class ARVideoExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [ return [
EmptyARVideoLatent, EmptyARVideoLatent,
SamplerARVideo, SamplerARVideo,
ARVideoI2V,
] ]

View File

@ -25,7 +25,7 @@ class UNetSelfAttentionMultiply(io.ComfyNode):
def define_schema(cls) -> io.Schema: def define_schema(cls) -> io.Schema:
return io.Schema( return io.Schema(
node_id="UNetSelfAttentionMultiply", node_id="UNetSelfAttentionMultiply",
category="_for_testing/attention_experiments", category="experimental/attention_experiments",
inputs=[ inputs=[
io.Model.Input("model"), io.Model.Input("model"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
@ -48,7 +48,7 @@ class UNetCrossAttentionMultiply(io.ComfyNode):
def define_schema(cls) -> io.Schema: def define_schema(cls) -> io.Schema:
return io.Schema( return io.Schema(
node_id="UNetCrossAttentionMultiply", node_id="UNetCrossAttentionMultiply",
category="_for_testing/attention_experiments", category="experimental/attention_experiments",
inputs=[ inputs=[
io.Model.Input("model"), io.Model.Input("model"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
@ -72,7 +72,7 @@ class CLIPAttentionMultiply(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="CLIPAttentionMultiply", node_id="CLIPAttentionMultiply",
search_aliases=["clip attention scale", "text encoder attention"], search_aliases=["clip attention scale", "text encoder attention"],
category="_for_testing/attention_experiments", category="experimental/attention_experiments",
inputs=[ inputs=[
io.Clip.Input("clip"), io.Clip.Input("clip"),
io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), io.Float.Input("q", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),
@ -106,7 +106,7 @@ class UNetTemporalAttentionMultiply(io.ComfyNode):
def define_schema(cls) -> io.Schema: def define_schema(cls) -> io.Schema:
return io.Schema( return io.Schema(
node_id="UNetTemporalAttentionMultiply", node_id="UNetTemporalAttentionMultiply",
category="_for_testing/attention_experiments", category="experimental/attention_experiments",
inputs=[ inputs=[
io.Model.Input("model"), io.Model.Input("model"),
io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True), io.Float.Input("self_structural", default=1.0, min=0.0, max=10.0, step=0.01, advanced=True),

View File

@ -10,6 +10,7 @@ class AudioEncoderLoader(io.ComfyNode):
def define_schema(cls) -> io.Schema: def define_schema(cls) -> io.Schema:
return io.Schema( return io.Schema(
node_id="AudioEncoderLoader", node_id="AudioEncoderLoader",
display_name="Load Audio Encoder",
category="loaders", category="loaders",
inputs=[ inputs=[
io.Combo.Input( io.Combo.Input(

View File

@ -0,0 +1,60 @@
import folder_paths
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy.bg_removal_model import load
class LoadBackgroundRemovalModel(IO.ComfyNode):
@classmethod
def define_schema(cls):
files = folder_paths.get_filename_list("background_removal")
return IO.Schema(
node_id="LoadBackgroundRemovalModel",
display_name="Load Background Removal Model",
category="loaders",
inputs=[
IO.Combo.Input("bg_removal_name", options=sorted(files), tooltip="The model used to remove backgrounds from images"),
],
outputs=[
IO.BackgroundRemoval.Output("bg_model")
]
)
@classmethod
def execute(cls, bg_removal_name):
path = folder_paths.get_full_path_or_raise("background_removal", bg_removal_name)
bg = load(path)
if bg is None:
raise RuntimeError("ERROR: background model file is invalid and does not contain a valid background removal model.")
return IO.NodeOutput(bg)
class RemoveBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="RemoveBackground",
display_name="Remove Background",
category="image/background removal",
inputs=[
IO.Image.Input("image", tooltip="Input image to remove the background from"),
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask")
],
outputs=[
IO.Mask.Output("mask", tooltip="Generated foreground mask")
]
)
@classmethod
def execute(cls, image, bg_removal_model):
mask = bg_removal_model.encode_image(image)
return IO.NodeOutput(mask)
class BackgroundRemovalExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
LoadBackgroundRemovalModel,
RemoveBackground
]
async def comfy_entrypoint() -> BackgroundRemovalExtension:
return BackgroundRemovalExtension()

View File

@ -153,7 +153,7 @@ class WanCameraEmbedding(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="WanCameraEmbedding", node_id="WanCameraEmbedding",
category="camera", category="conditioning/video_models",
inputs=[ inputs=[
io.Combo.Input( io.Combo.Input(
"camera_pose", "camera_pose",

View File

@ -8,7 +8,7 @@ class CLIPTextEncodeControlnet(io.ComfyNode):
def define_schema(cls) -> io.Schema: def define_schema(cls) -> io.Schema:
return io.Schema( return io.Schema(
node_id="CLIPTextEncodeControlnet", node_id="CLIPTextEncodeControlnet",
category="_for_testing/conditioning", category="experimental/conditioning",
inputs=[ inputs=[
io.Clip.Input("clip"), io.Clip.Input("clip"),
io.Conditioning.Input("conditioning"), io.Conditioning.Input("conditioning"),
@ -35,7 +35,7 @@ class T5TokenizerOptions(io.ComfyNode):
def define_schema(cls) -> io.Schema: def define_schema(cls) -> io.Schema:
return io.Schema( return io.Schema(
node_id="T5TokenizerOptions", node_id="T5TokenizerOptions",
category="_for_testing/conditioning", category="experimental/conditioning",
inputs=[ inputs=[
io.Clip.Input("clip"), io.Clip.Input("clip"),
io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True), io.Int.Input("min_padding", default=0, min=0, max=10000, step=1, advanced=True),

View File

@ -10,7 +10,7 @@ class ContextWindowsManualNode(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="ContextWindowsManual", node_id="ContextWindowsManual",
display_name="Context Windows (Manual)", display_name="Context Windows (Manual)",
category="context", category="model_patches",
description="Manually set context windows.", description="Manually set context windows.",
inputs=[ inputs=[
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),

View File

@ -984,7 +984,7 @@ class AddNoise(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="AddNoise", node_id="AddNoise",
category="_for_testing/custom_sampling/noise", category="experimental/custom_sampling/noise",
is_experimental=True, is_experimental=True,
inputs=[ inputs=[
io.Model.Input("model"), io.Model.Input("model"),
@ -1034,7 +1034,7 @@ class ManualSigmas(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="ManualSigmas", node_id="ManualSigmas",
search_aliases=["custom noise schedule", "define sigmas"], search_aliases=["custom noise schedule", "define sigmas"],
category="_for_testing/custom_sampling", category="experimental/custom_sampling",
is_experimental=True, is_experimental=True,
inputs=[ inputs=[
io.String.Input("sigmas", default="1, 0.5", multiline=False) io.String.Input("sigmas", default="1, 0.5", multiline=False)

View File

@ -13,7 +13,7 @@ class DifferentialDiffusion(io.ComfyNode):
node_id="DifferentialDiffusion", node_id="DifferentialDiffusion",
search_aliases=["inpaint gradient", "variable denoise strength"], search_aliases=["inpaint gradient", "variable denoise strength"],
display_name="Differential Diffusion", display_name="Differential Diffusion",
category="_for_testing", category="experimental",
inputs=[ inputs=[
io.Model.Input("model"), io.Model.Input("model"),
io.Float.Input( io.Float.Input(

View File

@ -102,7 +102,7 @@ class FluxDisableGuidance(io.ComfyNode):
append = execute # TODO: remove append = execute # TODO: remove
PREFERED_KONTEXT_RESOLUTIONS = [ PREFERRED_KONTEXT_RESOLUTIONS = [
(672, 1568), (672, 1568),
(688, 1504), (688, 1504),
(720, 1456), (720, 1456),
@ -143,7 +143,7 @@ class FluxKontextImageScale(io.ComfyNode):
width = image.shape[2] width = image.shape[2]
height = image.shape[1] height = image.shape[1]
aspect_ratio = width / height aspect_ratio = width / height
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS) _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1) image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
return io.NodeOutput(image) return io.NodeOutput(image)

View File

@ -60,7 +60,7 @@ class FreSca(io.ComfyNode):
node_id="FreSca", node_id="FreSca",
search_aliases=["frequency guidance"], search_aliases=["frequency guidance"],
display_name="FreSca", display_name="FreSca",
category="_for_testing", category="experimental",
description="Applies frequency-dependent scaling to the guidance", description="Applies frequency-dependent scaling to the guidance",
inputs=[ inputs=[
io.Model.Input("model"), io.Model.Input("model"),

View File

@ -131,6 +131,8 @@ class HunyuanVideo15SuperResolution(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="HunyuanVideo15SuperResolution", node_id="HunyuanVideo15SuperResolution",
display_name="Hunyuan Video 1.5 Super Resolution",
category="conditioning/video_models",
inputs=[ inputs=[
io.Conditioning.Input("positive"), io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"), io.Conditioning.Input("negative"),
@ -381,6 +383,8 @@ class HunyuanRefinerLatent(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="HunyuanRefinerLatent", node_id="HunyuanRefinerLatent",
display_name="Hunyuan Latent Refiner",
category="conditioning/video_models",
inputs=[ inputs=[
io.Conditioning.Input("positive"), io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"), io.Conditioning.Input("negative"),

View File

@ -40,7 +40,7 @@ class Hunyuan3Dv2Conditioning(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="Hunyuan3Dv2Conditioning", node_id="Hunyuan3Dv2Conditioning",
category="conditioning/video_models", category="conditioning/3d_models",
inputs=[ inputs=[
IO.ClipVisionOutput.Input("clip_vision_output"), IO.ClipVisionOutput.Input("clip_vision_output"),
], ],
@ -65,7 +65,7 @@ class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="Hunyuan3Dv2ConditioningMultiView", node_id="Hunyuan3Dv2ConditioningMultiView",
category="conditioning/video_models", category="conditioning/3d_models",
inputs=[ inputs=[
IO.ClipVisionOutput.Input("front", optional=True), IO.ClipVisionOutput.Input("front", optional=True),
IO.ClipVisionOutput.Input("left", optional=True), IO.ClipVisionOutput.Input("left", optional=True),
@ -424,6 +424,7 @@ class VoxelToMeshBasic(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="VoxelToMeshBasic", node_id="VoxelToMeshBasic",
display_name="Voxel to Mesh (Basic)",
category="3d", category="3d",
inputs=[ inputs=[
IO.Voxel.Input("voxel"), IO.Voxel.Input("voxel"),
@ -453,6 +454,7 @@ class VoxelToMesh(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="VoxelToMesh", node_id="VoxelToMesh",
display_name="Voxel to Mesh",
category="3d", category="3d",
inputs=[ inputs=[
IO.Voxel.Input("voxel"), IO.Voxel.Input("voxel"),

View File

@ -102,6 +102,7 @@ class HypernetworkLoader(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="HypernetworkLoader", node_id="HypernetworkLoader",
display_name="Load Hypernetwork",
category="loaders", category="loaders",
inputs=[ inputs=[
IO.Model.Input("model"), IO.Model.Input("model"),

View File

@ -91,7 +91,7 @@ class LoraSave(io.ComfyNode):
node_id="LoraSave", node_id="LoraSave",
search_aliases=["export lora"], search_aliases=["export lora"],
display_name="Extract and Save Lora", display_name="Extract and Save Lora",
category="_for_testing", category="experimental",
inputs=[ inputs=[
io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"), io.String.Input("filename_prefix", default="loras/ComfyUI_extracted_lora"),
io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True), io.Int.Input("rank", default=8, min=1, max=4096, step=1, advanced=True),

View File

@ -223,7 +223,7 @@ class LTXVAddGuide(io.ComfyNode):
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded " "For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
"down to the nearest multiple of 8. Negative values are counted from the end of the video.", "down to the nearest multiple of 8. Negative values are counted from the end of the video.",
), ),
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01), io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
], ],
outputs=[ outputs=[
io.Conditioning.Output(display_name="positive"), io.Conditioning.Output(display_name="positive"),
@ -236,7 +236,7 @@ class LTXVAddGuide(io.ComfyNode):
def encode(cls, vae, latent_width, latent_height, images, scale_factors): def encode(cls, vae, latent_width, latent_height, images, scale_factors):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1] images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1) pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3] encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels) t = vae.encode(encode_pixels)
return encode_pixels, t return encode_pixels, t
@ -302,7 +302,7 @@ class LTXVAddGuide(io.ComfyNode):
else: else:
mask = torch.full( mask = torch.full(
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]), (noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
1.0 - strength, max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
dtype=noise_mask.dtype, dtype=noise_mask.dtype,
device=noise_mask.device, device=noise_mask.device,
) )
@ -322,7 +322,7 @@ class LTXVAddGuide(io.ComfyNode):
mask = torch.full( mask = torch.full(
(noise_mask.shape[0], 1, cond_length, 1, 1), (noise_mask.shape[0], 1, cond_length, 1, 1),
1.0 - strength, max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
dtype=noise_mask.dtype, dtype=noise_mask.dtype,
device=noise_mask.device, device=noise_mask.device,
) )
@ -594,7 +594,8 @@ class LTXVPreprocess(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="LTXVPreprocess", node_id="LTXVPreprocess",
category="image", display_name="LTXV Preprocess",
category="video/preprocessors",
inputs=[ inputs=[
io.Image.Input("image"), io.Image.Input("image"),
io.Int.Input( io.Int.Input(

View File

@ -11,7 +11,7 @@ class Mahiro(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="Mahiro", node_id="Mahiro",
display_name="Positive-Biased Guidance", display_name="Positive-Biased Guidance",
category="_for_testing", category="experimental",
description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.", description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
inputs=[ inputs=[
io.Model.Input("model"), io.Model.Input("model"),

View File

@ -40,10 +40,21 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
inverse_mask = torch.ones_like(mask) - mask inverse_mask = torch.ones_like(mask) - mask
source_portion = mask * source[..., :visible_height, :visible_width] source_rgb = source[:, :3, :visible_height, :visible_width]
destination_portion = inverse_mask * destination[..., top:bottom, left:right] dest_slice = destination[..., top:bottom, left:right]
if destination.shape[1] == 4:
if torch.max(dest_slice) == 0:
destination[:, :3, top:bottom, left:right] = source_rgb
destination[:, 3:4, top:bottom, left:right] = mask
else:
destination[:, :3, top:bottom, left:right] = (mask * source_rgb) + (inverse_mask * dest_slice[:, :3])
destination[:, 3:4, top:bottom, left:right] = torch.max(mask, dest_slice[:, 3:4])
else:
source_portion = mask * source_rgb
destination_portion = inverse_mask * dest_slice
destination[..., top:bottom, left:right] = source_portion + destination_portion
destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination return destination
class LatentCompositeMasked(IO.ComfyNode): class LatentCompositeMasked(IO.ComfyNode):
@ -84,18 +95,23 @@ class ImageCompositeMasked(IO.ComfyNode):
display_name="Image Composite Masked", display_name="Image Composite Masked",
category="image", category="image",
inputs=[ inputs=[
IO.Image.Input("destination"),
IO.Image.Input("source"), IO.Image.Input("source"),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Boolean.Input("resize_source", default=False), IO.Boolean.Input("resize_source", default=False),
IO.Image.Input("destination", optional=True),
IO.Mask.Input("mask", optional=True), IO.Mask.Input("mask", optional=True),
], ],
outputs=[IO.Image.Output()], outputs=[IO.Image.Output()],
) )
@classmethod @classmethod
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput: def execute(cls, source, x, y, resize_source, destination = None, mask = None) -> IO.NodeOutput:
if destination is None: # transparent rgba
B, H, W, C = source.shape
destination = torch.zeros((B, H, W, 4), dtype=source.dtype, device=source.device)
if C == 3:
source = torch.nn.functional.pad(source, (0, 1), value=1.0)
destination, source = node_helpers.image_alpha_fix(destination, source) destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1) destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
@ -381,7 +397,6 @@ class GrowMask(IO.ComfyNode):
expand_mask = execute # TODO: remove expand_mask = execute # TODO: remove
class ThresholdMask(IO.ComfyNode): class ThresholdMask(IO.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):

View File

@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="ComfyMathExpression", node_id="ComfyMathExpression",
display_name="Math Expression", display_name="Math Expression",
category="math", category="logic",
search_aliases=[ search_aliases=[
"expression", "formula", "calculate", "calculator", "expression", "formula", "calculate", "calculator",
"eval", "math", "eval", "math",

View File

@ -21,7 +21,7 @@ class NumberConvertNode(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="ComfyNumberConvert", node_id="ComfyNumberConvert",
display_name="Number Convert", display_name="Number Convert",
category="math", category="utils",
search_aliases=[ search_aliases=[
"int to float", "float to int", "number convert", "int to float", "float to int", "number convert",
"int2float", "float2int", "cast", "parse number", "int2float", "float2int", "cast", "parse number",

View File

@ -24,8 +24,8 @@ class PerpNeg(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="PerpNeg", node_id="PerpNeg",
display_name="Perp-Neg (DEPRECATED by PerpNegGuider)", display_name="Perp-Neg (DEPRECATED by Perp-Neg Guider)",
category="_for_testing", category="experimental",
inputs=[ inputs=[
io.Model.Input("model"), io.Model.Input("model"),
io.Conditioning.Input("empty_conditioning"), io.Conditioning.Input("empty_conditioning"),
@ -127,7 +127,8 @@ class PerpNegGuider(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="PerpNegGuider", node_id="PerpNegGuider",
category="_for_testing", display_name="Perp-Neg Guider",
category="experimental",
inputs=[ inputs=[
io.Model.Input("model"), io.Model.Input("model"),
io.Conditioning.Input("positive"), io.Conditioning.Input("positive"),

View File

@ -123,7 +123,7 @@ class PhotoMakerLoader(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="PhotoMakerLoader", node_id="PhotoMakerLoader",
category="_for_testing/photomaker", category="experimental/photomaker",
inputs=[ inputs=[
io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")), io.Combo.Input("photomaker_model_name", options=folder_paths.get_filename_list("photomaker")),
], ],
@ -149,7 +149,7 @@ class PhotoMakerEncode(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="PhotoMakerEncode", node_id="PhotoMakerEncode",
category="_for_testing/photomaker", category="experimental/photomaker",
inputs=[ inputs=[
io.Photomaker.Input("photomaker"), io.Photomaker.Input("photomaker"),
io.Image.Input("image"), io.Image.Input("image"),

View File

@ -116,6 +116,7 @@ class Quantize(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="ImageQuantize", node_id="ImageQuantize",
display_name="Quantize Image",
category="image/postprocessing", category="image/postprocessing",
inputs=[ inputs=[
io.Image.Input("image"), io.Image.Input("image"),
@ -181,6 +182,7 @@ class Sharpen(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="ImageSharpen", node_id="ImageSharpen",
display_name="Sharpen Image",
category="image/postprocessing", category="image/postprocessing",
inputs=[ inputs=[
io.Image.Input("image"), io.Image.Input("image"),
@ -436,7 +438,7 @@ class ResizeImageMaskNode(io.ComfyNode):
node_id="ResizeImageMaskNode", node_id="ResizeImageMaskNode",
display_name="Resize Image/Mask", display_name="Resize Image/Mask",
description="Resize an image or mask using various scaling methods.", description="Resize an image or mask using various scaling methods.",
category="transform", category="image/transform",
search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"], search_aliases=["resize", "resize image", "resize mask", "scale", "scale image", "scale mask", "image resize", "change size", "dimensions", "shrink", "enlarge"],
inputs=[ inputs=[
io.MatchType.Input("input", template=template), io.MatchType.Input("input", template=template),

View File

@ -15,7 +15,7 @@ class RTDETR_detect(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="RTDETR_detect", node_id="RTDETR_detect",
display_name="RT-DETR Detect", display_name="RT-DETR Detect",
category="detection/", category="detection",
search_aliases=["bbox", "bounding box", "object detection", "coco"], search_aliases=["bbox", "bounding box", "object detection", "coco"],
inputs=[ inputs=[
io.Model.Input("model", display_name="model"), io.Model.Input("model", display_name="model"),
@ -71,7 +71,7 @@ class DrawBBoxes(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="DrawBBoxes", node_id="DrawBBoxes",
display_name="Draw BBoxes", display_name="Draw BBoxes",
category="detection/", category="detection",
search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"], search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"],
inputs=[ inputs=[
io.Image.Input("image", optional=True), io.Image.Input("image", optional=True),

View File

@ -113,7 +113,7 @@ class SelfAttentionGuidance(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="SelfAttentionGuidance", node_id="SelfAttentionGuidance",
display_name="Self-Attention Guidance", display_name="Self-Attention Guidance",
category="_for_testing", category="experimental",
inputs=[ inputs=[
io.Model.Input("model"), io.Model.Input("model"),
io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01), io.Float.Input("scale", default=0.5, min=-2.0, max=5.0, step=0.01),

View File

@ -93,7 +93,7 @@ class SAM3_Detect(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="SAM3_Detect", node_id="SAM3_Detect",
display_name="SAM3 Detect", display_name="SAM3 Detect",
category="detection/", category="detection",
search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"], search_aliases=["sam3", "segment anything", "open vocabulary", "text detection", "segment"],
inputs=[ inputs=[
io.Model.Input("model", display_name="model"), io.Model.Input("model", display_name="model"),
@ -265,7 +265,7 @@ class SAM3_VideoTrack(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="SAM3_VideoTrack", node_id="SAM3_VideoTrack",
display_name="SAM3 Video Track", display_name="SAM3 Video Track",
category="detection/", category="detection",
search_aliases=["sam3", "video", "track", "propagate"], search_aliases=["sam3", "video", "track", "propagate"],
inputs=[ inputs=[
io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"), io.Image.Input("images", display_name="images", tooltip="Video frames as batched images"),
@ -320,7 +320,7 @@ class SAM3_TrackPreview(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="SAM3_TrackPreview", node_id="SAM3_TrackPreview",
display_name="SAM3 Track Preview", display_name="SAM3 Track Preview",
category="detection/", category="detection",
inputs=[ inputs=[
SAM3TrackData.Input("track_data", display_name="track_data"), SAM3TrackData.Input("track_data", display_name="track_data"),
io.Image.Input("images", display_name="images", optional=True), io.Image.Input("images", display_name="images", optional=True),
@ -478,7 +478,7 @@ class SAM3_TrackToMask(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="SAM3_TrackToMask", node_id="SAM3_TrackToMask",
display_name="SAM3 Track to Mask", display_name="SAM3 Track to Mask",
category="detection/", category="detection",
inputs=[ inputs=[
SAM3TrackData.Input("track_data", display_name="track_data"), SAM3TrackData.Input("track_data", display_name="track_data"),
io.String.Input("object_indices", display_name="object_indices", default="", io.String.Input("object_indices", display_name="object_indices", default="",

View File

@ -119,7 +119,7 @@ class StableCascade_SuperResolutionControlnet(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="StableCascade_SuperResolutionControlnet", node_id="StableCascade_SuperResolutionControlnet",
category="_for_testing/stable_cascade", category="experimental/stable_cascade",
is_experimental=True, is_experimental=True,
inputs=[ inputs=[
io.Image.Input("image"), io.Image.Input("image"),

View File

@ -26,7 +26,8 @@ class TextGenerate(io.ComfyNode):
return io.Schema( return io.Schema(
node_id="TextGenerate", node_id="TextGenerate",
category="textgen", display_name="Generate Text",
category="text",
search_aliases=["LLM", "gemma"], search_aliases=["LLM", "gemma"],
inputs=[ inputs=[
io.Clip.Input("clip"), io.Clip.Input("clip"),
@ -157,6 +158,7 @@ class TextGenerateLTX2Prompt(TextGenerate):
parent_schema = super().define_schema() parent_schema = super().define_schema()
return io.Schema( return io.Schema(
node_id="TextGenerateLTX2Prompt", node_id="TextGenerateLTX2Prompt",
display_name="Generate LTX2 Prompt",
category=parent_schema.category, category=parent_schema.category,
inputs=parent_schema.inputs, inputs=parent_schema.inputs,
outputs=parent_schema.outputs, outputs=parent_schema.outputs,

View File

@ -10,7 +10,7 @@ class TorchCompileModel(io.ComfyNode):
def define_schema(cls) -> io.Schema: def define_schema(cls) -> io.Schema:
return io.Schema( return io.Schema(
node_id="TorchCompileModel", node_id="TorchCompileModel",
category="_for_testing", category="experimental",
inputs=[ inputs=[
io.Model.Input("model"), io.Model.Input("model"),
io.Combo.Input( io.Combo.Input(

View File

@ -1361,7 +1361,7 @@ class SaveLoRA(io.ComfyNode):
node_id="SaveLoRA", node_id="SaveLoRA",
search_aliases=["export lora"], search_aliases=["export lora"],
display_name="Save LoRA Weights", display_name="Save LoRA Weights",
category="loaders", category="advanced/model_merging",
is_experimental=True, is_experimental=True,
is_output_node=True, is_output_node=True,
inputs=[ inputs=[

View File

@ -15,7 +15,7 @@ class ImageOnlyCheckpointLoader:
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint" FUNCTION = "load_checkpoint"
CATEGORY = "loaders/video_models" CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name)

View File

@ -22,7 +22,7 @@ class SaveImageWebsocket:
OUTPUT_NODE = True OUTPUT_NODE = True
CATEGORY = "api/image" CATEGORY = "image"
def save_images(self, images): def save_images(self, images):
pbar = comfy.utils.ProgressBar(images.shape[0]) pbar = comfy.utils.ProgressBar(images.shape[0])
@ -42,3 +42,7 @@ class SaveImageWebsocket:
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"SaveImageWebsocket": SaveImageWebsocket, "SaveImageWebsocket": SaveImageWebsocket,
} }
NODE_DISPLAY_NAME_MAPPINGS = {
"SaveImageWebsocket": "Save Image (Websocket)",
}

View File

@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions) folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
folder_names_and_paths["background_removal"] = ([os.path.join(models_dir, "background_removal")], supported_pt_extensions)
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions) folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions) folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)

View File

@ -330,7 +330,7 @@ class VAEDecodeTiled:
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode" FUNCTION = "decode"
CATEGORY = "_for_testing" CATEGORY = "experimental"
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8): def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
if tile_size < overlap * 4: if tile_size < overlap * 4:
@ -377,7 +377,7 @@ class VAEEncodeTiled:
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "encode" FUNCTION = "encode"
CATEGORY = "_for_testing" CATEGORY = "experimental"
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8): def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
@ -493,7 +493,7 @@ class SaveLatent:
OUTPUT_NODE = True OUTPUT_NODE = True
CATEGORY = "_for_testing" CATEGORY = "experimental"
def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): def save(self, samples, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
@ -538,7 +538,7 @@ class LoadLatent:
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")] files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
return {"required": {"latent": [sorted(files), ]}, } return {"required": {"latent": [sorted(files), ]}, }
CATEGORY = "_for_testing" CATEGORY = "experimental"
RETURN_TYPES = ("LATENT", ) RETURN_TYPES = ("LATENT", )
FUNCTION = "load" FUNCTION = "load"
@ -1443,7 +1443,7 @@ class LatentBlend:
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "blend" FUNCTION = "blend"
CATEGORY = "_for_testing" CATEGORY = "experimental"
def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"): def blend(self, samples1, samples2, blend_factor:float, blend_mode: str="normal"):
@ -2092,6 +2092,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"StyleModelLoader": "Load Style Model", "StyleModelLoader": "Load Style Model",
"CLIPVisionLoader": "Load CLIP Vision", "CLIPVisionLoader": "Load CLIP Vision",
"UNETLoader": "Load Diffusion Model", "UNETLoader": "Load Diffusion Model",
"unCLIPCheckpointLoader": "Load unCLIP Checkpoint",
"GLIGENLoader": "Load GLIGEN Model",
# Conditioning # Conditioning
"CLIPVisionEncode": "CLIP Vision Encode", "CLIPVisionEncode": "CLIP Vision Encode",
"StyleModelApply": "Apply Style Model", "StyleModelApply": "Apply Style Model",
@ -2140,7 +2142,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImageSharpen": "Sharpen Image", "ImageSharpen": "Sharpen Image",
"ImageScaleToTotalPixels": "Scale Image to Total Pixels", "ImageScaleToTotalPixels": "Scale Image to Total Pixels",
"GetImageSize": "Get Image Size", "GetImageSize": "Get Image Size",
# _for_testing # experimental
"VAEDecodeTiled": "VAE Decode (Tiled)", "VAEDecodeTiled": "VAE Decode (Tiled)",
"VAEEncodeTiled": "VAE Encode (Tiled)", "VAEEncodeTiled": "VAE Encode (Tiled)",
} }
@ -2427,6 +2429,7 @@ async def init_builtin_extra_nodes():
"nodes_number_convert.py", "nodes_number_convert.py",
"nodes_painter.py", "nodes_painter.py",
"nodes_curve.py", "nodes_curve.py",
"nodes_bg_removal.py",
"nodes_rtdetr.py", "nodes_rtdetr.py",
"nodes_frame_interpolation.py", "nodes_frame_interpolation.py",
"nodes_sam3.py", "nodes_sam3.py",

View File

@ -0,0 +1,90 @@
"""Tests for NodeReplaceManager registration behavior."""
import importlib
import sys
import types
import pytest
@pytest.fixture
def NodeReplaceManager(monkeypatch):
"""Provide NodeReplaceManager with `nodes` stubbed.
`app.node_replace_manager` does `import nodes` at module level, which pulls in
torch + the full ComfyUI graph. register() doesn't actually need it, so we
stub `nodes` per-test (via monkeypatch so it's torn down) and reload the
module so it picks up the stub instead of any cached real import.
"""
fake_nodes = types.ModuleType("nodes")
fake_nodes.NODE_CLASS_MAPPINGS = {}
monkeypatch.setitem(sys.modules, "nodes", fake_nodes)
monkeypatch.delitem(sys.modules, "app.node_replace_manager", raising=False)
module = importlib.import_module("app.node_replace_manager")
yield module.NodeReplaceManager
# Drop the freshly-imported module so the next test (or a later real import
# of `nodes`) starts from a clean slate.
sys.modules.pop("app.node_replace_manager", None)
class FakeNodeReplace:
"""Lightweight stand-in for comfy_api.latest._io.NodeReplace."""
def __init__(self, new_node_id, old_node_id, old_widget_ids=None,
input_mapping=None, output_mapping=None):
self.new_node_id = new_node_id
self.old_node_id = old_node_id
self.old_widget_ids = old_widget_ids
self.input_mapping = input_mapping
self.output_mapping = output_mapping
def test_register_adds_replacement(NodeReplaceManager):
manager = NodeReplaceManager()
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
assert manager.has_replacement("OldNode")
assert len(manager.get_replacement("OldNode")) == 1
def test_register_allows_multiple_alternatives_for_same_old_node(NodeReplaceManager):
"""Different new_node_ids for the same old_node_id should all be kept."""
manager = NodeReplaceManager()
manager.register(FakeNodeReplace(new_node_id="AltA", old_node_id="OldNode"))
manager.register(FakeNodeReplace(new_node_id="AltB", old_node_id="OldNode"))
replacements = manager.get_replacement("OldNode")
assert len(replacements) == 2
assert {r.new_node_id for r in replacements} == {"AltA", "AltB"}
def test_register_is_idempotent_for_duplicate_pair(NodeReplaceManager):
"""Re-registering the same (old_node_id, new_node_id) should be a no-op."""
manager = NodeReplaceManager()
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
manager.register(FakeNodeReplace(new_node_id="NewNode", old_node_id="OldNode"))
assert len(manager.get_replacement("OldNode")) == 1
def test_register_idempotent_preserves_first_registration(NodeReplaceManager):
"""First registration wins; later duplicates with different mappings are ignored."""
manager = NodeReplaceManager()
first = FakeNodeReplace(
new_node_id="NewNode", old_node_id="OldNode",
input_mapping=[{"new_id": "a", "old_id": "x"}],
)
second = FakeNodeReplace(
new_node_id="NewNode", old_node_id="OldNode",
input_mapping=[{"new_id": "b", "old_id": "y"}],
)
manager.register(first)
manager.register(second)
replacements = manager.get_replacement("OldNode")
assert len(replacements) == 1
assert replacements[0] is first
def test_register_dedupe_does_not_affect_other_old_nodes(NodeReplaceManager):
manager = NodeReplaceManager()
manager.register(FakeNodeReplace(new_node_id="NewA", old_node_id="OldA"))
manager.register(FakeNodeReplace(new_node_id="NewA", old_node_id="OldA"))
manager.register(FakeNodeReplace(new_node_id="NewB", old_node_id="OldB"))
assert len(manager.get_replacement("OldA")) == 1
assert len(manager.get_replacement("OldB")) == 1

View File

@ -21,7 +21,7 @@ class TestAsyncProgressUpdate(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,) RETURN_TYPES = (IO.ANY,)
FUNCTION = "execute" FUNCTION = "execute"
CATEGORY = "_for_testing/async" CATEGORY = "experimental/async"
async def execute(self, value, sleep_seconds): async def execute(self, value, sleep_seconds):
start = time.time() start = time.time()
@ -51,7 +51,7 @@ class TestSyncProgressUpdate(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,) RETURN_TYPES = (IO.ANY,)
FUNCTION = "execute" FUNCTION = "execute"
CATEGORY = "_for_testing/async" CATEGORY = "experimental/async"
def execute(self, value, sleep_seconds): def execute(self, value, sleep_seconds):
start = time.time() start = time.time()

View File

@ -21,7 +21,7 @@ class TestAsyncValidation(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "process" FUNCTION = "process"
CATEGORY = "_for_testing/async" CATEGORY = "experimental/async"
@classmethod @classmethod
async def VALIDATE_INPUTS(cls, value, threshold): async def VALIDATE_INPUTS(cls, value, threshold):
@ -53,7 +53,7 @@ class TestAsyncError(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,) RETURN_TYPES = (IO.ANY,)
FUNCTION = "error_execution" FUNCTION = "error_execution"
CATEGORY = "_for_testing/async" CATEGORY = "experimental/async"
async def error_execution(self, value, error_after): async def error_execution(self, value, error_after):
await asyncio.sleep(error_after) await asyncio.sleep(error_after)
@ -74,7 +74,7 @@ class TestAsyncValidationError(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "process" FUNCTION = "process"
CATEGORY = "_for_testing/async" CATEGORY = "experimental/async"
@classmethod @classmethod
async def VALIDATE_INPUTS(cls, value, max_value): async def VALIDATE_INPUTS(cls, value, max_value):
@ -105,7 +105,7 @@ class TestAsyncTimeout(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,) RETURN_TYPES = (IO.ANY,)
FUNCTION = "timeout_execution" FUNCTION = "timeout_execution"
CATEGORY = "_for_testing/async" CATEGORY = "experimental/async"
async def timeout_execution(self, value, timeout, operation_time): async def timeout_execution(self, value, timeout, operation_time):
try: try:
@ -129,7 +129,7 @@ class TestSyncError(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,) RETURN_TYPES = (IO.ANY,)
FUNCTION = "sync_error" FUNCTION = "sync_error"
CATEGORY = "_for_testing/async" CATEGORY = "experimental/async"
def sync_error(self, value): def sync_error(self, value):
raise RuntimeError("Intentional sync execution error for testing") raise RuntimeError("Intentional sync execution error for testing")
@ -150,7 +150,7 @@ class TestAsyncLazyCheck(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "process" FUNCTION = "process"
CATEGORY = "_for_testing/async" CATEGORY = "experimental/async"
async def check_lazy_status(self, condition, input1, input2): async def check_lazy_status(self, condition, input1, input2):
# Simulate async checking (e.g., querying remote service) # Simulate async checking (e.g., querying remote service)
@ -184,7 +184,7 @@ class TestDynamicAsyncGeneration(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "generate_async_workflow" FUNCTION = "generate_async_workflow"
CATEGORY = "_for_testing/async" CATEGORY = "experimental/async"
def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration): def generate_async_workflow(self, image1, image2, num_async_nodes, sleep_duration):
g = GraphBuilder() g = GraphBuilder()
@ -229,7 +229,7 @@ class TestAsyncResourceUser(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,) RETURN_TYPES = (IO.ANY,)
FUNCTION = "use_resource" FUNCTION = "use_resource"
CATEGORY = "_for_testing/async" CATEGORY = "experimental/async"
async def use_resource(self, value, resource_id, duration): async def use_resource(self, value, resource_id, duration):
# Check if resource is already in use # Check if resource is already in use
@ -265,7 +265,7 @@ class TestAsyncBatchProcessing(ComfyNodeABC):
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "process_batch" FUNCTION = "process_batch"
CATEGORY = "_for_testing/async" CATEGORY = "experimental/async"
async def process_batch(self, images, process_time_per_item, unique_id): async def process_batch(self, images, process_time_per_item, unique_id):
batch_size = images.shape[0] batch_size = images.shape[0]
@ -305,7 +305,7 @@ class TestAsyncConcurrentLimit(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,) RETURN_TYPES = (IO.ANY,)
FUNCTION = "limited_execution" FUNCTION = "limited_execution"
CATEGORY = "_for_testing/async" CATEGORY = "experimental/async"
async def limited_execution(self, value, duration, node_id): async def limited_execution(self, value, duration, node_id):
async with self._semaphore: async with self._semaphore:

View File

@ -409,7 +409,7 @@ class TestSleep(ComfyNodeABC):
RETURN_TYPES = (IO.ANY,) RETURN_TYPES = (IO.ANY,)
FUNCTION = "sleep" FUNCTION = "sleep"
CATEGORY = "_for_testing" CATEGORY = "experimental"
async def sleep(self, value, seconds, unique_id): async def sleep(self, value, seconds, unique_id):
pbar = ProgressBar(seconds, node_id=unique_id) pbar = ProgressBar(seconds, node_id=unique_id)
@ -440,7 +440,7 @@ class TestParallelSleep(ComfyNodeABC):
} }
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "parallel_sleep" FUNCTION = "parallel_sleep"
CATEGORY = "_for_testing" CATEGORY = "experimental"
OUTPUT_NODE = True OUTPUT_NODE = True
def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id): def parallel_sleep(self, image1, image2, image3, sleep1, sleep2, sleep3, unique_id):
@ -474,7 +474,7 @@ class TestOutputNodeWithSocketOutput:
} }
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "process" FUNCTION = "process"
CATEGORY = "_for_testing" CATEGORY = "experimental"
OUTPUT_NODE = True OUTPUT_NODE = True
def process(self, image, value): def process(self, image, value):