Compare commits

...

12 Commits

Author SHA1 Message Date
Alexander Piskun
412762b03c
Merge 28d538ddf9 into 87878f354f 2026-05-09 04:49:24 +09:00
Matt Miller
87878f354f
Add cloud-runtime FE-facing operations to spec (#13734)
* Add cloud-runtime FE-facing operations to openapi.yaml

Add ~67 cloud-runtime FE-facing path operations to the core OpenAPI spec,
each tagged with x-runtime: [cloud] at the operation level. These operations
are served by the cloud runtime; the local runtime returns 404 for all of
these paths.

Domain groups added:
- Jobs / prompts: /api/job/*, /api/jobs/*/cancel, /api/prompt/*, etc.
- History v2: /api/history_v2, /api/history_v2/{prompt_id}
- Cloud logs: /api/logs
- Asset extensions: /api/assets/download, export, import, etc.
- Custom nodes: /api/experiment/nodes (cloud install/uninstall)
- Hub: /api/hub/profiles, /api/hub/workflows, /api/hub/labels, etc.
- Workflows: /api/workflows CRUD, versioning, fork, publish
- Auth/session: /api/auth/session, /api/auth/token, /.well-known/jwks.json
- Billing: /api/billing/balance, plans, subscribe, topup, etc.
- Workspace: /api/workspace/*, /api/workspaces/*
- User/settings/misc: /api/user, /api/secrets, /api/feedback, etc.

Also adds corresponding cloud-only component schemas (CloudJob, CloudWorkflow,
BillingPlan, Workspace, HubProfile, AuthSession, etc.), all tagged with
x-runtime: [cloud].

Spectral lint passes under the existing ruleset with zero new warnings.

* Add job_id field to Asset schema and deprecate prompt_id (#13736)

- Add job_id as a nullable UUID field to the Asset schema
- Mark prompt_id as deprecated with note pointing to job_id
- No x-runtime tag needed as both runtimes populate the field

* Add hash field to Asset schemas and deprecate asset_hash (#13738)

- Add 'hash' as a nullable string field to Asset and AssetUpdated schemas
- Mark 'asset_hash' as deprecated with a note pointing to 'hash'
- AssetCreated inherits 'hash' via allOf from Asset
- Spectral lint clean (no new warnings)

* Fix method drift on cloud-runtime endpoints

Three PUT operations were added that should be PATCH (cloud serves
PATCH for partial updates):

- /api/workflows/{workflow_id}
- /api/workspaces/{id}
- /api/workspace/members/{userId}

Two POST operations were added that should be GET (cloud serves GET
with query params):

- /api/assets/remote-metadata (url moves to query param)
- /api/files/mask-layers (response shape replaced — operation queries
  related mask layer filenames, not file uploads)

* Add missing cloud-runtime operations and schemas

PR review surfaced operations the cloud runtime serves that weren't
covered by the initial spec push, plus one path family missed entirely.

New methods on existing paths:

- /api/auth/session: add POST (create session cookie) and DELETE (logout)
- /api/secrets/{id}: add GET (read metadata) and PATCH (update)
- /api/hub/profiles: add POST (create profile)
- /api/hub/workflows: add POST (publish to hub)
- /api/hub/workflows/{share_id}: add DELETE (unpublish)
- /api/workspaces/{id}: add DELETE (soft-delete workspace)
- /api/workspace/members/{user_id}/api-keys: add DELETE (bulk revoke)
- /api/workflows/{workflow_id}/versions: add POST (create new version)
- /api/userdata/{file}/publish: add GET (read publish info)

New path family:

- /api/tasks (GET list) and /api/tasks/{task_id} (GET detail) for the
  background task framework

New component schemas (all tagged x-runtime: [cloud]):

CreateSessionResponse, DeleteSessionResponse, UpdateSecretRequest,
BulkRevokeAPIKeysResponse, CreateHubProfileRequest, PublishHubWorkflowRequest,
HubWorkflowDetail, AssetInfo, CreateWorkflowVersionRequest,
WorkflowVersionResponse, WorkflowPublishInfo, TaskEntry, TaskResponse,
TasksListResponse. Existing SecretMeta extended with provider and
last_used_at fields the cloud runtime actually returns.

New tag: task. Spectral lint passes with zero errors.

* Add job_id and prompt_id to AssetUpdated schema

Mirrors the Asset schema's deprecation pattern: prompt_id is marked
deprecated with a description pointing to job_id; job_id is the new
preferred field. PUT /api/assets/{id} responses can now carry both fields
consistent with the other Asset-returning endpoints.

* feat: add width and height fields to Asset schema (#13745)

Add nullable integer fields 'width' and 'height' to the Asset schema
in openapi.yaml. These expose original image dimensions in pixels for
clients that need pre-thumbnail size info. Both fields are null for
non-image assets or assets ingested before dimension extraction.

Co-authored-by: Matt Miller <MillerMedia@users.noreply.github.com>

* Remove /api/job/{job_id} and /api/job/{job_id}/outputs

These two paths are not actually served by the cloud runtime — they
return 404 with a redirect message pointing callers to the canonical
`/api/jobs/{job_id}` (plural). Declaring them with `x-runtime: [cloud]`
and a 200 response schema is incorrect.

`/api/job/{job_id}/status` stays — it is a real cloud-served endpoint.

Also drops the now-orphaned `CloudJob` and `CloudJobOutputs` component
schemas. `CloudJobStatus` is retained.
2026-05-08 12:39:16 -07:00
Alexis Rolland
c5ecd231a2
fix: Fix bug when mask not on same device (CORE-181) (#13801)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
2026-05-08 23:06:29 +08:00
drozbay
9864f5ac86
fix: Stop LTXVImgToVideoInplace from mutating input latents and dropping noise_mask (#13793) 2026-05-08 23:02:17 +08:00
drozbay
05cd076bc1
fix: Make LTXVAddGuide center-crop guide images to match other LTXV nodes (#13794) 2026-05-08 22:48:59 +08:00
Yousef R. Gamaleldin
d3c18c1636
Add support for BiRefNet background remove model (CORE-46) (#12747) 2026-05-08 17:59:24 +08:00
omahs
bac6fc35fb
Fix typos (#10986) 2026-05-08 17:14:45 +08:00
Alexander Piskun
28d538ddf9
Merge branch 'master' into feat/string-min-max-length
Some checks failed
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
2026-03-25 08:30:29 +02:00
Alexander Piskun
e326b41d62
Merge branch 'master' into feat/string-min-max-length
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-03-14 07:37:05 +02:00
Alexander Piskun
aa9e7a84bc
Merge branch 'master' into feat/string-min-max-length
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
2026-03-13 19:21:36 +02:00
Alexander Piskun
9b7a2a3248
Merge branch 'master' into feat/string-min-max-length
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
2026-03-13 06:51:41 +02:00
bigcat88
5b913f0377
feat: add minLength/maxLength validation for String inputs
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-03-09 09:00:33 +02:00
19 changed files with 5762 additions and 26 deletions

View File

@ -0,0 +1,7 @@
{
"model_type": "birefnet",
"image_std": [1.0, 1.0, 1.0],
"image_mean": [0.0, 0.0, 0.0],
"image_size": 1024,
"resize_to_original": true
}

View File

@ -0,0 +1,689 @@
import torch
import comfy.ops
import numpy as np
import torch.nn as nn
from functools import partial
import torch.nn.functional as F
from torchvision.ops import deform_conv2d
from comfy.ldm.modules.attention import optimized_attention_for_device
CXT = [3072, 1536, 768, 384][1:][::-1][-3:]
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.q = operations.Linear(dim, dim, bias=qkv_bias, device=device, dtype=dtype)
self.kv = operations.Linear(dim, dim * 2, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
def forward(self, x):
B, N, C = x.shape
optimized_attention = optimized_attention_for_device(x.device, mask=False, small_input=True)
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
x = optimized_attention(
q, k, v, heads=self.num_heads, skip_output_reshape=True, skip_reshape=True
).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, device=None, dtype=None, operations=None):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = operations.Linear(in_features, hidden_features, device=device, dtype=dtype)
self.act = nn.GELU()
self.fc2 = operations.Linear(hidden_features, out_features, device=device, dtype=dtype)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads, device=device, dtype=dtype))
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, device=device, dtype=dtype)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
return x
class SwinTransformerBlock(nn.Module):
def __init__(self, dim, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None,
norm_layer=nn.LayerNorm, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
self.norm1 = norm_layer(dim, device=device, dtype=dtype)
self.attn = WindowAttention(
dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, device=device, dtype=dtype, operations=operations)
self.norm2 = norm_layer(dim, device=device, dtype=dtype)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, device=device, dtype=dtype, operations=operations)
self.H = None
self.W = None
def forward(self, x, mask_matrix):
B, L, C = x.shape
H, W = self.H, self.W
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
attn_mask = mask_matrix
else:
shifted_x = x
attn_mask = None
x_windows = window_partition(shifted_x, self.window_size)
x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
attn_windows = self.attn(x_windows, mask=attn_mask)
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
class PatchMerging(nn.Module):
def __init__(self, dim, device=None, dtype=None, operations=None):
super().__init__()
self.dim = dim
self.reduction = operations.Linear(4 * dim, 2 * dim, bias=False, device=device, dtype=dtype)
self.norm = operations.LayerNorm(4 * dim, device=device, dtype=dtype)
def forward(self, x, H, W):
B, L, C = x.shape
x = x.view(B, H, W, C)
# padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
class BasicLayer(nn.Module):
def __init__(self,
dim,
depth,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
norm_layer=nn.LayerNorm,
downsample=None,
device=None, dtype=None, operations=None):
super().__init__()
self.window_size = window_size
self.shift_size = window_size // 2
self.depth = depth
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
norm_layer=norm_layer,
device=device, dtype=dtype, operations=operations)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, device=device, dtype=dtype, operations=operations)
else:
self.downsample = None
def forward(self, x, H, W):
Hp = int(np.ceil(H / self.window_size)) * self.window_size
Wp = int(np.ceil(W / self.window_size)) * self.window_size
img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
for blk in self.blocks:
blk.H, blk.W = H, W
x = blk(x, attn_mask)
if self.downsample is not None:
x_down = self.downsample(x, H, W)
Wh, Ww = (H + 1) // 2, (W + 1) // 2
return x, H, W, x_down, Wh, Ww
else:
return x, H, W, x, H, W
class PatchEmbed(nn.Module):
def __init__(self, patch_size=4, in_channels=3, embed_dim=96, norm_layer=None, device=None, dtype=None, operations=None):
super().__init__()
patch_size = (patch_size, patch_size)
self.patch_size = patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
self.proj = operations.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype)
if norm_layer is not None:
self.norm = norm_layer(embed_dim, device=device, dtype=dtype)
else:
self.norm = None
def forward(self, x):
_, _, H, W = x.size()
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x) # B C Wh Ww
if self.norm is not None:
Wh, Ww = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
return x
class SwinTransformer(nn.Module):
def __init__(self,
pretrain_img_size=224,
patch_size=4,
in_channels=3,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
qk_scale=None,
patch_norm=True,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
device=None, dtype=None, operations=None):
super().__init__()
norm_layer = partial(operations.LayerNorm, device=device, dtype=dtype)
self.pretrain_img_size = pretrain_img_size
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.patch_embed = PatchEmbed(
patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim,
device=device, dtype=dtype, operations=operations,
norm_layer=norm_layer if self.patch_norm else None)
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
device=device, dtype=dtype, operations=operations)
self.layers.append(layer)
num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
self.num_features = num_features
for i_layer in out_indices:
layer = norm_layer(num_features[i_layer])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
def forward(self, x):
x = self.patch_embed(x)
Wh, Ww = x.size(2), x.size(3)
outs = []
x = x.flatten(2).transpose(1, 2)
for i in range(self.num_layers):
layer = self.layers[i]
x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
if i in self.out_indices:
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(x_out)
out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
return tuple(outs)
class DeformableConv2d(nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False, device=None, dtype=None, operations=None):
super(DeformableConv2d, self).__init__()
kernel_size = kernel_size if type(kernel_size) is tuple else (kernel_size, kernel_size)
self.stride = stride if type(stride) is tuple else (stride, stride)
self.padding = padding
self.offset_conv = operations.Conv2d(in_channels,
2 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True, device=device, dtype=dtype)
self.modulator_conv = operations.Conv2d(in_channels,
1 * kernel_size[0] * kernel_size[1],
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=True, device=device, dtype=dtype)
self.regular_conv = operations.Conv2d(in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=self.padding,
bias=bias, device=device, dtype=dtype)
def forward(self, x):
offset = self.offset_conv(x)
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
x = deform_conv2d(
input=x,
offset=offset,
weight=weight,
bias=None,
padding=self.padding,
mask=modulator,
stride=self.stride,
)
comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
return x
class BasicDecBlk(nn.Module):
def __init__(self, in_channels=64, out_channels=64, inter_channels=64, device=None, dtype=None, operations=None):
super(BasicDecBlk, self).__init__()
inter_channels = 64
self.conv_in = operations.Conv2d(in_channels, inter_channels, 3, 1, padding=1, device=device, dtype=dtype)
self.relu_in = nn.ReLU(inplace=True)
self.dec_att = ASPPDeformable(in_channels=inter_channels, device=device, dtype=dtype, operations=operations)
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, padding=1, device=device, dtype=dtype)
self.bn_in = operations.BatchNorm2d(inter_channels, device=device, dtype=dtype)
self.bn_out = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
def forward(self, x):
x = self.conv_in(x)
x = self.bn_in(x)
x = self.relu_in(x)
x = self.dec_att(x)
x = self.conv_out(x)
x = self.bn_out(x)
return x
class BasicLatBlk(nn.Module):
def __init__(self, in_channels=64, out_channels=64, device=None, dtype=None, operations=None):
super(BasicLatBlk, self).__init__()
self.conv = operations.Conv2d(in_channels, out_channels, 1, 1, 0, device=device, dtype=dtype)
def forward(self, x):
x = self.conv(x)
return x
class _ASPPModuleDeformable(nn.Module):
def __init__(self, in_channels, planes, kernel_size, padding, device, dtype, operations):
super(_ASPPModuleDeformable, self).__init__()
self.atrous_conv = DeformableConv2d(in_channels, planes, kernel_size=kernel_size,
stride=1, padding=padding, bias=False, device=device, dtype=dtype, operations=operations)
self.bn = operations.BatchNorm2d(planes, device=device, dtype=dtype)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x = self.atrous_conv(x)
x = self.bn(x)
return self.relu(x)
class ASPPDeformable(nn.Module):
def __init__(self, in_channels, out_channels=None, parallel_block_sizes=[1, 3, 7], device=None, dtype=None, operations=None):
super(ASPPDeformable, self).__init__()
self.down_scale = 1
if out_channels is None:
out_channels = in_channels
self.in_channelster = 256 // self.down_scale
self.aspp1 = _ASPPModuleDeformable(in_channels, self.in_channelster, 1, padding=0, device=device, dtype=dtype, operations=operations)
self.aspp_deforms = nn.ModuleList([
_ASPPModuleDeformable(in_channels, self.in_channelster, conv_size, padding=int(conv_size//2), device=device, dtype=dtype, operations=operations)
for conv_size in parallel_block_sizes
])
self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
operations.Conv2d(in_channels, self.in_channelster, 1, stride=1, bias=False, device=device, dtype=dtype),
operations.BatchNorm2d(self.in_channelster, device=device, dtype=dtype),
nn.ReLU(inplace=True))
self.conv1 = operations.Conv2d(self.in_channelster * (2 + len(self.aspp_deforms)), out_channels, 1, bias=False, device=device, dtype=dtype)
self.bn1 = operations.BatchNorm2d(out_channels, device=device, dtype=dtype)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
x1 = self.aspp1(x)
x_aspp_deforms = [aspp_deform(x) for aspp_deform in self.aspp_deforms]
x5 = self.global_avg_pool(x)
x5 = F.interpolate(x5, size=x1.size()[2:], mode='bilinear', align_corners=True)
x = torch.cat((x1, *x_aspp_deforms, x5), dim=1)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
return x
class BiRefNet(nn.Module):
def __init__(self, config=None, dtype=None, device=None, operations=None):
super(BiRefNet, self).__init__()
self.bb = SwinTransformer(embed_dim=192, depths=[2, 2, 18, 2], num_heads=[6, 12, 24, 48], window_size=12, device=device, dtype=dtype, operations=operations)
channels = [1536, 768, 384, 192]
channels = [c * 2 for c in channels]
self.cxt = channels[1:][::-1][-3:]
self.squeeze_module = nn.Sequential(*[
BasicDecBlk(channels[0]+sum(self.cxt), channels[0], device=device, dtype=dtype, operations=operations)
for _ in range(1)
])
self.decoder = Decoder(channels, device=device, dtype=dtype, operations=operations)
def forward_enc(self, x):
x1, x2, x3, x4 = self.bb(x)
B, C, H, W = x.shape
x1_, x2_, x3_, x4_ = self.bb(F.interpolate(x, size=(H//2, W//2), mode='bilinear', align_corners=True))
x1 = torch.cat([x1, F.interpolate(x1_, size=x1.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x2 = torch.cat([x2, F.interpolate(x2_, size=x2.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x3 = torch.cat([x3, F.interpolate(x3_, size=x3.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x4 = torch.cat([x4, F.interpolate(x4_, size=x4.shape[2:], mode='bilinear', align_corners=True)], dim=1)
x4 = torch.cat(
(
*[
F.interpolate(x1, size=x4.shape[2:], mode='bilinear', align_corners=True),
F.interpolate(x2, size=x4.shape[2:], mode='bilinear', align_corners=True),
F.interpolate(x3, size=x4.shape[2:], mode='bilinear', align_corners=True),
][-len(CXT):],
x4
),
dim=1
)
return (x1, x2, x3, x4)
def forward_ori(self, x):
(x1, x2, x3, x4) = self.forward_enc(x)
x4 = self.squeeze_module(x4)
features = [x, x1, x2, x3, x4]
scaled_preds = self.decoder(features)
return scaled_preds
def forward(self, pixel_values, intermediate_output=None):
scaled_preds = self.forward_ori(pixel_values)
return scaled_preds
class Decoder(nn.Module):
def __init__(self, channels, device, dtype, operations):
super(Decoder, self).__init__()
# factory kwargs
fk = {"device":device, "dtype":dtype, "operations":operations}
DecoderBlock = partial(BasicDecBlk, **fk)
LateralBlock = partial(BasicLatBlk, **fk)
DBlock = partial(SimpleConvs, **fk)
self.split = True
N_dec_ipt = 64
ic = 64
ipt_cha_opt = 1
self.ipt_blk5 = DBlock(2**10*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk4 = DBlock(2**8*3 if self.split else 3, [N_dec_ipt, channels[0]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk3 = DBlock(2**6*3 if self.split else 3, [N_dec_ipt, channels[1]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk2 = DBlock(2**4*3 if self.split else 3, [N_dec_ipt, channels[2]//8][ipt_cha_opt], inter_channels=ic)
self.ipt_blk1 = DBlock(2**0*3 if self.split else 3, [N_dec_ipt, channels[3]//8][ipt_cha_opt], inter_channels=ic)
self.decoder_block4 = DecoderBlock(channels[0]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[1])
self.decoder_block3 = DecoderBlock(channels[1]+([N_dec_ipt, channels[0]//8][ipt_cha_opt]), channels[2])
self.decoder_block2 = DecoderBlock(channels[2]+([N_dec_ipt, channels[1]//8][ipt_cha_opt]), channels[3])
self.decoder_block1 = DecoderBlock(channels[3]+([N_dec_ipt, channels[2]//8][ipt_cha_opt]), channels[3]//2)
fk = {"device":device, "dtype":dtype}
self.conv_out1 = nn.Sequential(operations.Conv2d(channels[3]//2+([N_dec_ipt, channels[3]//8][ipt_cha_opt]), 1, 1, 1, 0, **fk))
self.lateral_block4 = LateralBlock(channels[1], channels[1])
self.lateral_block3 = LateralBlock(channels[2], channels[2])
self.lateral_block2 = LateralBlock(channels[3], channels[3])
self.conv_ms_spvn_4 = operations.Conv2d(channels[1], 1, 1, 1, 0, **fk)
self.conv_ms_spvn_3 = operations.Conv2d(channels[2], 1, 1, 1, 0, **fk)
self.conv_ms_spvn_2 = operations.Conv2d(channels[3], 1, 1, 1, 0, **fk)
_N = 16
self.gdt_convs_4 = nn.Sequential(operations.Conv2d(channels[0] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
self.gdt_convs_3 = nn.Sequential(operations.Conv2d(channels[1] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
self.gdt_convs_2 = nn.Sequential(operations.Conv2d(channels[2] // 2, _N, 3, 1, 1, **fk), operations.BatchNorm2d(_N, **fk), nn.ReLU(inplace=True))
[setattr(self, f"gdt_convs_pred_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
[setattr(self, f"gdt_convs_attn_{i}", nn.Sequential(operations.Conv2d(_N, 1, 1, 1, 0, **fk))) for i in range(2, 5)]
def get_patches_batch(self, x, p):
_size_h, _size_w = p.shape[2:]
patches_batch = []
for idx in range(x.shape[0]):
columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
patches_x = []
for column_x in columns_x:
patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
patch_sample = torch.cat(patches_x, dim=1)
patches_batch.append(patch_sample)
return torch.cat(patches_batch, dim=0)
def forward(self, features):
x, x1, x2, x3, x4 = features
patches_batch = self.get_patches_batch(x, x4) if self.split else x
x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
p4 = self.decoder_block4(x4)
p4_gdt = self.gdt_convs_4(p4)
gdt_attn_4 = self.gdt_convs_attn_4(p4_gdt).sigmoid()
p4 = p4 * gdt_attn_4
_p4 = F.interpolate(p4, size=x3.shape[2:], mode='bilinear', align_corners=True)
_p3 = _p4 + self.lateral_block4(x3)
patches_batch = self.get_patches_batch(x, _p3) if self.split else x
_p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
p3 = self.decoder_block3(_p3)
p3_gdt = self.gdt_convs_3(p3)
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
p3 = p3 * gdt_attn_3
_p3 = F.interpolate(p3, size=x2.shape[2:], mode='bilinear', align_corners=True)
_p2 = _p3 + self.lateral_block3(x2)
patches_batch = self.get_patches_batch(x, _p2) if self.split else x
_p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
p2 = self.decoder_block2(_p2)
p2_gdt = self.gdt_convs_2(p2)
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
p2 = p2 * gdt_attn_2
_p2 = F.interpolate(p2, size=x1.shape[2:], mode='bilinear', align_corners=True)
_p1 = _p2 + self.lateral_block2(x1)
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
_p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
_p1 = self.decoder_block1(_p1)
_p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
patches_batch = self.get_patches_batch(x, _p1) if self.split else x
_p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
p1_out = self.conv_out1(_p1)
return p1_out
class SimpleConvs(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, inter_channels=64, device=None, dtype=None, operations=None
) -> None:
super().__init__()
self.conv1 = operations.Conv2d(in_channels, inter_channels, 3, 1, 1, device=device, dtype=dtype)
self.conv_out = operations.Conv2d(inter_channels, out_channels, 3, 1, 1, device=device, dtype=dtype)
def forward(self, x):
return self.conv_out(self.conv1(x))

78
comfy/bg_removal_model.py Normal file
View File

@ -0,0 +1,78 @@
from .utils import load_torch_file
import os
import json
import torch
import logging
import comfy.ops
import comfy.model_patcher
import comfy.model_management
import comfy.clip_model
import comfy.background_removal.birefnet
BG_REMOVAL_MODELS = {
"birefnet": comfy.background_removal.birefnet.BiRefNet
}
class BackgroundRemovalModel():
def __init__(self, json_config):
with open(json_config) as f:
config = json.load(f)
self.image_size = config.get("image_size", 1024)
self.image_mean = config.get("image_mean", [0.0, 0.0, 0.0])
self.image_std = config.get("image_std", [1.0, 1.0, 1.0])
self.model_type = config.get("model_type", "birefnet")
self.config = config.copy()
model_class = BG_REMOVAL_MODELS.get(self.model_type)
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
def get_sd(self):
return self.model.state_dict()
def encode_image(self, image):
comfy.model_management.load_model_gpu(self.patcher)
H, W = image.shape[1], image.shape[2]
pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=False)
out = self.model(pixel_values=pixel_values)
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
if mask.ndim == 3:
mask = mask.unsqueeze(0)
if mask.shape[1] != 1:
mask = mask.movedim(-1, 1)
return mask
def load_background_removal_model(sd):
if "bb.layers.1.blocks.0.attn.relative_position_index" in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "background_removal"), "birefnet.json")
else:
return None
bg_model = BackgroundRemovalModel(json_config)
m, u = bg_model.load_sd(sd)
if len(m) > 0:
logging.warning("missing background removal: {}".format(m))
u = set(u)
keys = list(sd.keys())
for k in keys:
if k not in u:
sd.pop(k)
return bg_model
def load(ckpt_path):
sd = load_torch_file(ckpt_path)
return load_background_removal_model(sd)

View File

@ -93,7 +93,7 @@ class Hook:
self.hook_scope = hook_scope
'''Scope of where this hook should apply in terms of the conds used in sampling run.'''
self.custom_should_register = default_should_register
'''Can be overriden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
'''Can be overridden with a compatible function to decide if this hook should be registered without the need to override .should_register'''
@property
def strength(self):

View File

@ -140,7 +140,7 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
# according the the formula provided in https://arxiv.org/abs/2010.02502
# according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')

View File

@ -562,6 +562,25 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class BatchNorm2d(torch.nn.BatchNorm2d, CastWeightBiasOp):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
running_mean = self.running_mean.to(device=input.device, dtype=weight.dtype) if self.running_mean is not None else None
running_var = self.running_var.to(device=input.device, dtype=weight.dtype) if self.running_var is not None else None
x = torch.nn.functional.batch_norm(input, running_mean, running_var, weight, bias, self.training, self.momentum, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
@ -749,6 +768,9 @@ class manual_cast(disable_weight_init):
class Conv3d(disable_weight_init.Conv3d):
comfy_cast_weights = True
class BatchNorm2d(disable_weight_init.BatchNorm2d):
comfy_cast_weights = True
class GroupNorm(disable_weight_init.GroupNorm):
comfy_cast_weights = True

View File

@ -17,6 +17,7 @@ if TYPE_CHECKING:
from spandrel import ImageModelDescriptor
from comfy.clip_vision import ClipVisionModel
from comfy.clip_vision import Output as ClipVisionOutput_
from comfy.bg_removal_model import BackgroundRemovalModel
from comfy.controlnet import ControlNet
from comfy.hooks import HookGroup, HookKeyframeGroup
from comfy.model_patcher import ModelPatcher
@ -326,11 +327,14 @@ class String(ComfyTypeIO):
'''String input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
multiline=False, placeholder: str=None, default: str=None, dynamic_prompts: bool=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None,
min_length: int=None, max_length: int=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.multiline = multiline
self.placeholder = placeholder
self.dynamic_prompts = dynamic_prompts
self.min_length = min_length
self.max_length = max_length
self.default: str
def as_dict(self):
@ -338,6 +342,8 @@ class String(ComfyTypeIO):
"multiline": self.multiline,
"placeholder": self.placeholder,
"dynamicPrompts": self.dynamic_prompts,
"minLength": self.min_length,
"maxLength": self.max_length,
})
@comfytype(io_type="COMBO")
@ -614,6 +620,11 @@ class Model(ComfyTypeIO):
if TYPE_CHECKING:
Type = ModelPatcher
@comfytype(io_type="BACKGROUND_REMOVAL")
class BackgroundRemoval(ComfyTypeIO):
if TYPE_CHECKING:
Type = BackgroundRemovalModel
@comfytype(io_type="CLIP_VISION")
class ClipVision(ComfyTypeIO):
if TYPE_CHECKING:
@ -2257,6 +2268,7 @@ __all__ = [
"ModelPatch",
"ClipVision",
"ClipVisionOutput",
"BackgroundRemoval",
"AudioEncoder",
"AudioEncoderOutput",
"StyleModel",

View File

@ -0,0 +1,60 @@
import folder_paths
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO
from comfy.bg_removal_model import load
class LoadBackgroundRemovalModel(IO.ComfyNode):
@classmethod
def define_schema(cls):
files = folder_paths.get_filename_list("background_removal")
return IO.Schema(
node_id="LoadBackgroundRemovalModel",
display_name="Load Background Removal Model",
category="loaders",
inputs=[
IO.Combo.Input("bg_removal_name", options=sorted(files), tooltip="The model used to remove backgrounds from images"),
],
outputs=[
IO.BackgroundRemoval.Output("bg_model")
]
)
@classmethod
def execute(cls, bg_removal_name):
path = folder_paths.get_full_path_or_raise("background_removal", bg_removal_name)
bg = load(path)
if bg is None:
raise RuntimeError("ERROR: background model file is invalid and does not contain a valid background removal model.")
return IO.NodeOutput(bg)
class RemoveBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="RemoveBackground",
display_name="Remove Background",
category="image/background removal",
inputs=[
IO.Image.Input("image", tooltip="Input image to remove the background from"),
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask")
],
outputs=[
IO.Mask.Output("mask", tooltip="Generated foreground mask")
]
)
@classmethod
def execute(cls, image, bg_removal_model):
mask = bg_removal_model.encode_image(image)
return IO.NodeOutput(mask)
class BackgroundRemovalExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
LoadBackgroundRemovalModel,
RemoveBackground
]
async def comfy_entrypoint() -> BackgroundRemovalExtension:
return BackgroundRemovalExtension()

View File

@ -203,7 +203,7 @@ class JoinImageWithAlpha(io.ComfyNode):
@classmethod
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = max(len(image), len(alpha))
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
alpha = 1.0 - resize_mask(alpha.to(image), image.shape[1:])
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
image = comfy.utils.repeat_to_batch_size(image, batch_size)
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))

View File

@ -102,7 +102,7 @@ class FluxDisableGuidance(io.ComfyNode):
append = execute # TODO: remove
PREFERED_KONTEXT_RESOLUTIONS = [
PREFERRED_KONTEXT_RESOLUTIONS = [
(672, 1568),
(688, 1504),
(720, 1456),
@ -143,7 +143,7 @@ class FluxKontextImageScale(io.ComfyNode):
width = image.shape[2]
height = image.shape[1]
aspect_ratio = width / height
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERED_KONTEXT_RESOLUTIONS)
_, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS)
image = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "lanczos", "center").movedim(1, -1)
return io.NodeOutput(image)

View File

@ -106,12 +106,12 @@ class LTXVImgToVideoInplace(io.ComfyNode):
if bypass:
return (latent,)
samples = latent["samples"]
samples = latent["samples"].clone()
_, height_scale_factor, width_scale_factor = (
vae.downscale_index_formula
)
batch, _, latent_frames, latent_height, latent_width = samples.shape
_, _, _, latent_height, latent_width = samples.shape
width = latent_width * width_scale_factor
height = latent_height * height_scale_factor
@ -124,11 +124,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
samples[:, :, :t.shape[2]] = t
conditioning_latent_frames_mask = torch.ones(
(batch, 1, latent_frames, 1, 1),
dtype=torch.float32,
device=samples.device,
)
conditioning_latent_frames_mask = get_noise_mask(latent)
conditioning_latent_frames_mask[:, :, :t.shape[2]] = 1.0 - strength
return io.NodeOutput({"samples": samples, "noise_mask": conditioning_latent_frames_mask})
@ -236,7 +232,7 @@ class LTXVAddGuide(io.ComfyNode):
def encode(cls, vae, latent_width, latent_height, images, scale_factors):
time_scale_factor, width_scale_factor, height_scale_factor = scale_factors
images = images[:(images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1]
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="disabled").movedim(1, -1)
pixels = comfy.utils.common_upscale(images.movedim(-1, 1), latent_width * width_scale_factor, latent_height * height_scale_factor, "bilinear", crop="center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
return encode_pixels, t

View File

@ -40,10 +40,21 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou
inverse_mask = torch.ones_like(mask) - mask
source_portion = mask * source[..., :visible_height, :visible_width]
destination_portion = inverse_mask * destination[..., top:bottom, left:right]
source_rgb = source[:, :3, :visible_height, :visible_width]
dest_slice = destination[..., top:bottom, left:right]
if destination.shape[1] == 4:
if torch.max(dest_slice) == 0:
destination[:, :3, top:bottom, left:right] = source_rgb
destination[:, 3:4, top:bottom, left:right] = mask
else:
destination[:, :3, top:bottom, left:right] = (mask * source_rgb) + (inverse_mask * dest_slice[:, :3])
destination[:, 3:4, top:bottom, left:right] = torch.max(mask, dest_slice[:, 3:4])
else:
source_portion = mask * source_rgb
destination_portion = inverse_mask * dest_slice
destination[..., top:bottom, left:right] = source_portion + destination_portion
destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination
class LatentCompositeMasked(IO.ComfyNode):
@ -84,18 +95,23 @@ class ImageCompositeMasked(IO.ComfyNode):
display_name="Image Composite Masked",
category="image",
inputs=[
IO.Image.Input("destination"),
IO.Image.Input("source"),
IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1),
IO.Boolean.Input("resize_source", default=False),
IO.Image.Input("destination", optional=True),
IO.Mask.Input("mask", optional=True),
],
outputs=[IO.Image.Output()],
)
@classmethod
def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput:
def execute(cls, source, x, y, resize_source, destination = None, mask = None) -> IO.NodeOutput:
if destination is None: # transparent rgba
B, H, W, C = source.shape
destination = torch.zeros((B, H, W, 4), dtype=source.dtype, device=source.device)
if C == 3:
source = torch.nn.functional.pad(source, (0, 1), value=1.0)
destination, source = node_helpers.image_alpha_fix(destination, source)
destination = destination.clone().movedim(-1, 1)
output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1)
@ -381,7 +397,6 @@ class GrowMask(IO.ComfyNode):
expand_mask = execute # TODO: remove
class ThresholdMask(IO.ComfyNode):
@classmethod
def define_schema(cls):

View File

@ -83,7 +83,7 @@ class IsChangedCache:
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
input_data_all, _, v3_data, _ = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name, v3_data=v3_data)
is_changed = await resolve_map_node_over_list_results(is_changed)
@ -215,7 +215,35 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
v3_data["hidden_inputs"] = hidden_inputs_v3
return input_data_all, missing_keys, v3_data
return input_data_all, missing_keys, v3_data, valid_inputs
def validate_resolved_inputs(input_data_all, class_def, valid_inputs):
"""Validate resolved input values against schema constraints.
This is needed because validate_inputs() only sees direct widget values.
Linked inputs aren't resolved during validate_inputs(), so this runs after resolution to catch any violations.
"""
for x, values in input_data_all.items():
input_type, input_category, extra_info = get_input_info(class_def, x, valid_inputs)
if input_type != "STRING":
continue
min_length = extra_info.get("minLength")
max_length = extra_info.get("maxLength")
if min_length is None and max_length is None:
continue
for val in values:
if val is None or not isinstance(val, str):
continue
if min_length is not None and len(val) < min_length:
raise ValueError(
f"Input '{x}': value length {len(val)} is shorter than "
f"minimum length of {min_length}"
)
if max_length is not None and len(val) > max_length:
raise ValueError(
f"Input '{x}': value length {len(val)} is longer than "
f"maximum length of {max_length}"
)
map_node_over_list = None #Don't hook this please
@ -480,7 +508,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
input_data_all, missing_keys, v3_data, valid_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@ -509,6 +537,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
execution_list.make_input_strong_link(unique_id, i)
return (ExecutionResult.PENDING, None, None)
validate_resolved_inputs(input_data_all, class_def, valid_inputs)
def execution_block_cb(block):
if block.message is not None:
mes = {
@ -1014,6 +1044,34 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
errors.append(error)
continue
if input_type == "STRING":
if "minLength" in extra_info and len(val) < extra_info["minLength"]:
error = {
"type": "value_shorter_than_min_length",
"message": "Value length {} shorter than min length of {}".format(len(val), extra_info["minLength"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if "maxLength" in extra_info and len(val) > extra_info["maxLength"]:
error = {
"type": "value_longer_than_max_length",
"message": "Value length {} longer than max length of {}".format(len(val), extra_info["maxLength"]),
"details": f"{x}",
"extra_info": {
"input_name": x,
"input_config": info,
"received_value": val,
}
}
errors.append(error)
continue
if isinstance(input_type, list) or input_type == io.Combo.io_type:
if input_type == io.Combo.io_type:
combo_options = extra_info.get("options", [])
@ -1050,7 +1108,7 @@ async def validate_inputs(prompt_id, prompt, item, validated, visiting=None):
continue
if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
input_data_all, _, v3_data, _ = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs:

View File

@ -52,6 +52,8 @@ folder_names_and_paths["model_patches"] = ([os.path.join(models_dir, "model_patc
folder_names_and_paths["audio_encoders"] = ([os.path.join(models_dir, "audio_encoders")], supported_pt_extensions)
folder_names_and_paths["background_removal"] = ([os.path.join(models_dir, "background_removal")], supported_pt_extensions)
folder_names_and_paths["frame_interpolation"] = ([os.path.join(models_dir, "frame_interpolation")], supported_pt_extensions)
folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions)

View File

@ -2429,6 +2429,7 @@ async def init_builtin_extra_nodes():
"nodes_number_convert.py",
"nodes_painter.py",
"nodes_curve.py",
"nodes_bg_removal.py",
"nodes_rtdetr.py",
"nodes_frame_interpolation.py",
"nodes_sam3.py",

File diff suppressed because it is too large Load Diff

View File

@ -1011,3 +1011,49 @@ class TestExecution:
"""Test getting a non-existent job returns 404"""
job = client.get_job("nonexistent-job-id")
assert job is None, "Non-existent job should return None"
@pytest.mark.parametrize("text, expect_error", [
("hello", False), # 5 chars, within [3, 10]
("abc", False), # 3 chars, exact min boundary
("abcdefghij", False), # 10 chars, exact max boundary
("ab", True), # 2 chars, below min
("abcdefghijk", True), # 11 chars, above max
("", True), # 0 chars, below min
])
def test_string_length_widget_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder):
"""Test minLength/maxLength validation for direct widget values (validate_inputs path)."""
g = builder
node = g.node("StubStringWithLength", text=text)
g.node("SaveImage", images=node.out(0))
if expect_error:
with pytest.raises(urllib.error.HTTPError) as exc_info:
client.run(g)
assert exc_info.value.code == 400
else:
client.run(g)
@pytest.mark.parametrize("text, expect_error", [
("hello", False), # 5 chars, within [3, 10]
("abc", False), # 3 chars, exact min boundary
("abcdefghij", False), # 10 chars, exact max boundary
("ab", True), # 2 chars, below min
("abcdefghijk", True), # 11 chars, above max
("", True), # 0 chars, below min
])
def test_string_length_linked_validation(self, text, expect_error, client: ComfyClient, builder: GraphBuilder):
"""Test minLength/maxLength validation for linked inputs (validate_resolved_inputs path)."""
g = builder
str_node = g.node("StubStringOutput", value=text)
node = g.node("StubStringWithLength", text=str_node.out(0))
g.node("SaveImage", images=node.out(0))
if expect_error:
try:
client.run(g)
assert False, "Should have raised an error"
except Exception as e:
assert 'prompt_id' in e.args[0], f"Did not get proper error message: {e}"
else:
client.run(g)

View File

@ -113,12 +113,48 @@ class StubFloat:
def stub_float(self, value):
return (value,)
class StubStringOutput:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("STRING", {"default": ""}),
},
}
RETURN_TYPES = ("STRING",)
FUNCTION = "stub_string"
CATEGORY = "Testing/Stub Nodes"
def stub_string(self, value):
return (value,)
class StubStringWithLength:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"text": ("STRING", {"default": "hello", "minLength": 3, "maxLength": 10}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_string_with_length"
CATEGORY = "Testing/Stub Nodes"
def stub_string_with_length(self, text):
return (torch.zeros(1, 64, 64, 3),)
TEST_STUB_NODE_CLASS_MAPPINGS = {
"StubImage": StubImage,
"StubConstantImage": StubConstantImage,
"StubMask": StubMask,
"StubInt": StubInt,
"StubFloat": StubFloat,
"StubStringOutput": StubStringOutput,
"StubStringWithLength": StubStringWithLength,
}
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
"StubImage": "Stub Image",
@ -126,4 +162,6 @@ TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
"StubMask": "Stub Mask",
"StubInt": "Stub Int",
"StubFloat": "Stub Float",
"StubStringOutput": "Stub String Output",
"StubStringWithLength": "Stub String With Length",
}