Merge branch 'Comfy-Org:master' into fix/jobs-preview-fallback-priority

This commit is contained in:
Deniz Şafak 2026-04-02 17:03:41 +03:00 committed by GitHub
commit a0c84262f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
28 changed files with 1090 additions and 94 deletions

View File

@ -61,6 +61,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
## Features ## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- NOTE: There are many more models supported than the list below, if you want to see what is supported see our templates list inside ComfyUI.
- Image Models - Image Models
- SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/)) - SD1.x, SD2.x ([unCLIP](https://comfyanonymous.github.io/ComfyUI_examples/unclip/))
- [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/) - [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
@ -232,7 +233,7 @@ Put your VAE in: models/vae
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1``` ```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2```
This is the command to install the nightly with ROCm 7.2 which might have some performance improvements: This is the command to install the nightly with ROCm 7.2 which might have some performance improvements:

View File

@ -38,9 +38,12 @@ void main() {
// GIMP order: per-channel curves first, then RGB master curve. // GIMP order: per-channel curves first, then RGB master curve.
// See gimp_curve_map_pixels() default case in gimpcurve-map.c: // See gimp_curve_map_pixels() default case in gimpcurve-map.c:
// dest = colors_curve( channel_curve( src ) ) // dest = colors_curve( channel_curve( src ) )
color.r = applyCurve(u_curve0, applyCurve(u_curve1, color.r)); float tmp_r = applyCurve(u_curve1, color.r);
color.g = applyCurve(u_curve0, applyCurve(u_curve2, color.g)); float tmp_g = applyCurve(u_curve2, color.g);
color.b = applyCurve(u_curve0, applyCurve(u_curve3, color.b)); float tmp_b = applyCurve(u_curve3, color.b);
color.r = applyCurve(u_curve0, tmp_r);
color.g = applyCurve(u_curve0, tmp_g);
color.b = applyCurve(u_curve0, tmp_b);
fragColor0 = vec4(color.rgb, color.a); fragColor0 = vec4(color.rgb, color.a);
} }

File diff suppressed because one or more lines are too long

View File

@ -110,11 +110,13 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent
parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
CACHE_RAM_AUTO_GB = -1.0
cache_group = parser.add_mutually_exclusive_group() cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB") cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).")
attn_group = parser.add_mutually_exclusive_group() attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")

View File

@ -3,12 +3,9 @@ from ..diffusionmodules.openaimodel import Timestep
import torch import torch
class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation): class CLIPEmbeddingNoiseAugmentation(ImageConcatWithNoiseAugmentation):
def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs): def __init__(self, *args, timestep_dim=256, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if clip_stats_path is None: clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
else:
clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
self.register_buffer("data_mean", clip_mean[None, :], persistent=False) self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
self.register_buffer("data_std", clip_std[None, :], persistent=False) self.register_buffer("data_std", clip_std[None, :], persistent=False)
self.time_embed = Timestep(timestep_dim) self.time_embed = Timestep(timestep_dim)

View File

@ -0,0 +1,725 @@
from collections import OrderedDict
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
COCO_CLASSES = [
'person','bicycle','car','motorcycle','airplane','bus','train','truck','boat',
'traffic light','fire hydrant','stop sign','parking meter','bench','bird','cat',
'dog','horse','sheep','cow','elephant','bear','zebra','giraffe','backpack',
'umbrella','handbag','tie','suitcase','frisbee','skis','snowboard','sports ball',
'kite','baseball bat','baseball glove','skateboard','surfboard','tennis racket',
'bottle','wine glass','cup','fork','knife','spoon','bowl','banana','apple',
'sandwich','orange','broccoli','carrot','hot dog','pizza','donut','cake','chair',
'couch','potted plant','bed','dining table','toilet','tv','laptop','mouse',
'remote','keyboard','cell phone','microwave','oven','toaster','sink',
'refrigerator','book','clock','vase','scissors','teddy bear','hair drier','toothbrush',
]
# ---------------------------------------------------------------------------
# HGNetv2 backbone
# ---------------------------------------------------------------------------
class ConvBNAct(nn.Module):
"""Conv→BN→ReLU. padding='same' adds asymmetric zero-pad (stem)."""
def __init__(self, ic, oc, k=3, s=1, groups=1, use_act=True, device=None, dtype=None, operations=None):
super().__init__()
self.conv = operations.Conv2d(ic, oc, k, s, (k - 1) // 2, groups=groups, bias=False, device=device, dtype=dtype)
self.bn = nn.BatchNorm2d(oc, device=device, dtype=dtype)
self.act = nn.ReLU() if use_act else nn.Identity()
def forward(self, x):
return self.act(self.bn(self.conv(x)))
class LightConvBNAct(nn.Module):
def __init__(self, ic, oc, k, device=None, dtype=None, operations=None):
super().__init__()
self.conv1 = ConvBNAct(ic, oc, 1, use_act=False, device=device, dtype=dtype, operations=operations)
self.conv2 = ConvBNAct(oc, oc, k, groups=oc, use_act=True, device=device, dtype=dtype, operations=operations)
def forward(self, x):
return self.conv2(self.conv1(x))
class _StemBlock(nn.Module):
def __init__(self, ic, mc, oc, device=None, dtype=None, operations=None):
super().__init__()
self.stem1 = ConvBNAct(ic, mc, 3, 2, device=device, dtype=dtype, operations=operations)
# stem2a/stem2b use kernel=2, stride=1, no internal padding;
# padding is applied manually in forward (matching PaddlePaddle original)
self.stem2a = ConvBNAct(mc, mc//2, 2, 1, device=device, dtype=dtype, operations=operations)
self.stem2b = ConvBNAct(mc//2, mc, 2, 1, device=device, dtype=dtype, operations=operations)
self.stem3 = ConvBNAct(mc*2, mc, 3, 2, device=device, dtype=dtype, operations=operations)
self.stem4 = ConvBNAct(mc, oc, 1, device=device, dtype=dtype, operations=operations)
self.pool = nn.MaxPool2d(2, 1, ceil_mode=True)
def forward(self, x):
x = self.stem1(x)
x = F.pad(x, (0, 1, 0, 1)) # pad before pool and stem2a
x2 = self.stem2a(x)
x2 = F.pad(x2, (0, 1, 0, 1)) # pad before stem2b
x2 = self.stem2b(x2)
x1 = self.pool(x)
return self.stem4(self.stem3(torch.cat([x1, x2], 1)))
class _HG_Block(nn.Module):
def __init__(self, ic, mc, oc, layer_num, k=3, residual=False, light=False, device=None, dtype=None, operations=None):
super().__init__()
self.residual = residual
if light:
self.layers = nn.ModuleList(
[LightConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)])
else:
self.layers = nn.ModuleList(
[ConvBNAct(ic if i == 0 else mc, mc, k, device=device, dtype=dtype, operations=operations) for i in range(layer_num)])
total = ic + layer_num * mc
self.aggregation = nn.Sequential(
ConvBNAct(total, oc // 2, 1, device=device, dtype=dtype, operations=operations),
ConvBNAct(oc // 2, oc, 1, device=device, dtype=dtype, operations=operations))
def forward(self, x):
identity = x
outs = [x]
for layer in self.layers:
x = layer(x)
outs.append(x)
x = self.aggregation(torch.cat(outs, 1))
return x + identity if self.residual else x
class _HG_Stage(nn.Module):
# config order: ic, mc, oc, num_blocks, downsample, light, k, layer_num
def __init__(self, ic, mc, oc, num_blocks, downsample=True, light=False, k=3, layer_num=6, device=None, dtype=None, operations=None):
super().__init__()
if downsample:
self.downsample = ConvBNAct(ic, ic, 3, 2, groups=ic, use_act=False, device=device, dtype=dtype, operations=operations)
else:
self.downsample = nn.Identity()
self.blocks = nn.Sequential(*[
_HG_Block(ic if i == 0 else oc, mc, oc, layer_num,
k=k, residual=(i != 0), light=light, device=device, dtype=dtype, operations=operations)
for i in range(num_blocks)
])
def forward(self, x):
return self.blocks(self.downsample(x))
class HGNetv2(nn.Module):
# B5 config: stem=[3,32,64], stages=[ic, mc, oc, blocks, down, light, k, layers]
_STAGE_CFGS = [[64, 64, 128, 1, False, False, 3, 6],
[128, 128, 512, 2, True, False, 3, 6],
[512, 256, 1024, 5, True, True, 5, 6],
[1024,512, 2048, 2, True, True, 5, 6]]
def __init__(self, return_idx=(1, 2, 3), device=None, dtype=None, operations=None):
super().__init__()
self.stem = _StemBlock(3, 32, 64, device=device, dtype=dtype, operations=operations)
self.stages = nn.ModuleList([_HG_Stage(*cfg, device=device, dtype=dtype, operations=operations) for cfg in self._STAGE_CFGS])
self.return_idx = list(return_idx)
self.out_channels = [self._STAGE_CFGS[i][2] for i in return_idx]
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
x = self.stem(x)
outs = []
for i, stage in enumerate(self.stages):
x = stage(x)
if i in self.return_idx:
outs.append(x)
return outs
# ---------------------------------------------------------------------------
# Encoder — HybridEncoder (dfine version: RepNCSPELAN4 + SCDown PAN)
# ---------------------------------------------------------------------------
class ConvNormLayer(nn.Module):
"""Conv→act (expects pre-fused BN weights)."""
def __init__(self, ic, oc, k, s, g=1, padding=None, act=None, device=None, dtype=None, operations=None):
super().__init__()
p = (k - 1) // 2 if padding is None else padding
self.conv = operations.Conv2d(ic, oc, k, s, p, groups=g, bias=True, device=device, dtype=dtype)
self.act = nn.SiLU() if act == 'silu' else nn.Identity()
def forward(self, x):
return self.act(self.conv(x))
class VGGBlock(nn.Module):
"""Rep-VGG block (expects pre-fused weights)."""
def __init__(self, ic, oc, device=None, dtype=None, operations=None):
super().__init__()
self.conv = operations.Conv2d(ic, oc, 3, 1, padding=1, bias=True, device=device, dtype=dtype)
self.act = nn.SiLU()
def forward(self, x):
return self.act(self.conv(x))
class CSPLayer(nn.Module):
def __init__(self, ic, oc, num_blocks=3, expansion=1.0, act='silu', device=None, dtype=None, operations=None):
super().__init__()
h = int(oc * expansion)
self.conv1 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
self.conv2 = ConvNormLayer(ic, h, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
self.bottlenecks = nn.Sequential(*[VGGBlock(h, h, device=device, dtype=dtype, operations=operations) for _ in range(num_blocks)])
self.conv3 = ConvNormLayer(h, oc, 1, 1, act=act, device=device, dtype=dtype, operations=operations) if h != oc else nn.Identity()
def forward(self, x):
return self.conv3(self.bottlenecks(self.conv1(x)) + self.conv2(x))
class RepNCSPELAN4(nn.Module):
"""CSP-ELAN block — the FPN/PAN block in RTv4's HybridEncoder."""
def __init__(self, c1, c2, c3, c4, n=3, act='silu', device=None, dtype=None, operations=None):
super().__init__()
self.c = c3 // 2
self.cv1 = ConvNormLayer(c1, c3, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
self.cv2 = nn.Sequential(CSPLayer(c3 // 2, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
self.cv3 = nn.Sequential(CSPLayer(c4, c4, n, 1.0, act=act, device=device, dtype=dtype, operations=operations), ConvNormLayer(c4, c4, 3, 1, act=act, device=device, dtype=dtype, operations=operations))
self.cv4 = ConvNormLayer(c3 + 2 * c4, c2, 1, 1, act=act, device=device, dtype=dtype, operations=operations)
def forward(self, x):
y = list(self.cv1(x).split((self.c, self.c), 1))
y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
return self.cv4(torch.cat(y, 1))
class SCDown(nn.Module):
"""Separable conv downsampling used in HybridEncoder PAN bottom-up path."""
def __init__(self, ic, oc, k, s, device=None, dtype=None, operations=None):
super().__init__()
self.cv1 = ConvNormLayer(ic, oc, 1, 1, device=device, dtype=dtype, operations=operations)
self.cv2 = ConvNormLayer(oc, oc, k, s, g=oc, device=device, dtype=dtype, operations=operations)
def forward(self, x):
return self.cv2(self.cv1(x))
class SelfAttention(nn.Module):
def __init__(self, embed_dim, num_heads, device=None, dtype=None, operations=None):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.q_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.k_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.v_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
self.out_proj = operations.Linear(embed_dim, embed_dim, device=device, dtype=dtype)
def forward(self, query, key, value, attn_mask=None):
optimized_attention = optimized_attention_for_device(query.device, False, small_input=True)
q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)
out = optimized_attention(q, k, v, heads=self.num_heads, mask=attn_mask)
return self.out_proj(out)
class _TransformerEncoderLayer(nn.Module):
"""Single AIFI encoder layer (pre- or post-norm, GELU by default)."""
def __init__(self, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.norm2 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.activation = nn.GELU()
def forward(self, src, src_mask=None, pos_embed=None):
q = k = src if pos_embed is None else src + pos_embed
src2 = self.self_attn(q, k, value=src, attn_mask=src_mask)
src = self.norm1(src + src2)
src2 = self.linear2(self.activation(self.linear1(src)))
return self.norm2(src + src2)
class _TransformerEncoder(nn.Module):
"""Thin wrapper so state-dict keys are encoder.0.layers.N.*"""
def __init__(self, num_layers, d_model, nhead, dim_feedforward, device=None, dtype=None, operations=None):
super().__init__()
self.layers = nn.ModuleList([
_TransformerEncoderLayer(d_model, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
for _ in range(num_layers)
])
def forward(self, src, src_mask=None, pos_embed=None):
for layer in self.layers:
src = layer(src, src_mask=src_mask, pos_embed=pos_embed)
return src
class HybridEncoder(nn.Module):
def __init__(self, in_channels=(512, 1024, 2048), feat_strides=(8, 16, 32), hidden_dim=256, nhead=8, dim_feedforward=2048, use_encoder_idx=(2,), num_encoder_layers=1,
pe_temperature=10000, expansion=1.0, depth_mult=1.0, act='silu', eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
super().__init__()
self.in_channels = list(in_channels)
self.feat_strides = list(feat_strides)
self.hidden_dim = hidden_dim
self.use_encoder_idx = list(use_encoder_idx)
self.pe_temperature = pe_temperature
self.eval_spatial_size = eval_spatial_size
self.out_channels = [hidden_dim] * len(in_channels)
self.out_strides = list(feat_strides)
# channel projection (expects pre-fused weights)
self.input_proj = nn.ModuleList([
nn.Sequential(OrderedDict([('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))]))
for ch in in_channels
])
# AIFI transformer — use _TransformerEncoder so keys are encoder.0.layers.N.*
self.encoder = nn.ModuleList([
_TransformerEncoder(num_encoder_layers, hidden_dim, nhead, dim_feedforward, device=device, dtype=dtype, operations=operations)
for _ in range(len(use_encoder_idx))
])
nb = round(3 * depth_mult)
exp = expansion
# top-down FPN (dfine: lateral conv has no act)
self.lateral_convs = nn.ModuleList(
[ConvNormLayer(hidden_dim, hidden_dim, 1, 1, device=device, dtype=dtype, operations=operations)
for _ in range(len(in_channels) - 1)])
self.fpn_blocks = nn.ModuleList(
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
for _ in range(len(in_channels) - 1)])
# bottom-up PAN (dfine: nn.Sequential(SCDown) — keeps checkpoint key .0.cv1/.0.cv2)
self.downsample_convs = nn.ModuleList(
[nn.Sequential(SCDown(hidden_dim, hidden_dim, 3, 2, device=device, dtype=dtype, operations=operations))
for _ in range(len(in_channels) - 1)])
self.pan_blocks = nn.ModuleList(
[RepNCSPELAN4(hidden_dim * 2, hidden_dim, hidden_dim * 2, round(exp * hidden_dim // 2), nb, act=act, device=device, dtype=dtype, operations=operations)
for _ in range(len(in_channels) - 1)])
# cache positional embeddings for fixed spatial size
if eval_spatial_size:
for idx in self.use_encoder_idx:
stride = self.feat_strides[idx]
pe = self._build_pe(eval_spatial_size[1] // stride,
eval_spatial_size[0] // stride,
hidden_dim, pe_temperature)
setattr(self, f'pos_embed{idx}', pe)
@staticmethod
def _build_pe(w, h, dim=256, temp=10000.):
assert dim % 4 == 0
gw = torch.arange(w, dtype=torch.float32)
gh = torch.arange(h, dtype=torch.float32)
gw, gh = torch.meshgrid(gw, gh, indexing='ij')
pdim = dim // 4
omega = 1. / (temp ** (torch.arange(pdim, dtype=torch.float32) / pdim))
ow = gw.flatten()[:, None] @ omega[None]
oh = gh.flatten()[:, None] @ omega[None]
return torch.cat([ow.sin(), ow.cos(), oh.sin(), oh.cos()], 1)[None]
def forward(self, feats: List[torch.Tensor]) -> List[torch.Tensor]:
proj = [self.input_proj[i](f) for i, f in enumerate(feats)]
for i, enc_idx in enumerate(self.use_encoder_idx):
h, w = proj[enc_idx].shape[2:]
src = proj[enc_idx].flatten(2).permute(0, 2, 1)
pe = getattr(self, f'pos_embed{enc_idx}').to(device=src.device, dtype=src.dtype)
for layer in self.encoder[i].layers:
src = layer(src, pos_embed=pe)
proj[enc_idx] = src.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()
n = len(self.in_channels)
inner = [proj[-1]]
for k in range(n - 1, 0, -1):
j = n - 1 - k
top = self.lateral_convs[j](inner[0])
inner[0] = top
up = F.interpolate(top, scale_factor=2., mode='nearest')
inner.insert(0, self.fpn_blocks[j](torch.cat([up, proj[k - 1]], 1)))
outs = [inner[0]]
for k in range(n - 1):
outs.append(self.pan_blocks[k](
torch.cat([self.downsample_convs[k](outs[-1]), inner[k + 1]], 1)))
return outs
# ---------------------------------------------------------------------------
# Decoder — DFINETransformer
# ---------------------------------------------------------------------------
def _deformable_attn_v2(value: list, spatial_shapes, sampling_locations: torch.Tensor, attention_weights: torch.Tensor, num_points_list: List[int]) -> torch.Tensor:
"""
value : list of per-level tensors [bs*n_head, c, h_l, w_l]
sampling_locations: [bs, Lq, n_head, sum(pts), 2] in [0,1]
attention_weights : [bs, Lq, n_head, sum(pts)]
"""
_, c = value[0].shape[:2] # bs*n_head, c
_, Lq, n_head, _, _ = sampling_locations.shape
bs = sampling_locations.shape[0]
n_h = n_head
grids = (2 * sampling_locations - 1) # [bs, Lq, n_head, sum_pts, 2]
grids = grids.permute(0, 2, 1, 3, 4).flatten(0, 1) # [bs*n_head, Lq, sum_pts, 2]
grids_per_lvl = grids.split(num_points_list, dim=2) # list of [bs*n_head, Lq, pts_l, 2]
sampled = []
for lvl, (h, w) in enumerate(spatial_shapes):
val_l = value[lvl].reshape(bs * n_h, c, h, w)
sv = F.grid_sample(val_l, grids_per_lvl[lvl], mode='bilinear', padding_mode='zeros', align_corners=False)
sampled.append(sv) # sv: [bs*n_head, c, Lq, pts_l]
attn = attention_weights.permute(0, 2, 1, 3) # [bs, n_head, Lq, sum_pts]
attn = attn.flatten(0, 1).unsqueeze(1) # [bs*n_head, 1, Lq, sum_pts]
out = (torch.cat(sampled, -1) * attn).sum(-1) # [bs*n_head, c, Lq]
out = out.reshape(bs, n_h * c, Lq)
return out.permute(0, 2, 1) # [bs, Lq, hidden]
class MSDeformableAttention(nn.Module):
def __init__(self, embed_dim=256, num_heads=8, num_levels=3, num_points=4, offset_scale=0.5, device=None, dtype=None, operations=None):
super().__init__()
self.embed_dim, self.num_heads = embed_dim, num_heads
self.head_dim = embed_dim // num_heads
pts = num_points if isinstance(num_points, list) else [num_points] * num_levels
self.num_points_list = pts
self.offset_scale = offset_scale
total = num_heads * sum(pts)
self.register_buffer('num_points_scale', torch.tensor([1. / n for n in pts for _ in range(n)], dtype=torch.float32))
self.sampling_offsets = operations.Linear(embed_dim, total * 2, device=device, dtype=dtype)
self.attention_weights = operations.Linear(embed_dim, total, device=device, dtype=dtype)
def forward(self, query, ref_pts, value, spatial_shapes):
bs, Lq = query.shape[:2]
offsets = self.sampling_offsets(query).reshape(
bs, Lq, self.num_heads, sum(self.num_points_list), 2)
attn_w = F.softmax(
self.attention_weights(query).reshape(
bs, Lq, self.num_heads, sum(self.num_points_list)), -1)
scale = self.num_points_scale.to(query).unsqueeze(-1)
offset = offsets * scale * ref_pts[:, :, None, :, 2:] * self.offset_scale
locs = ref_pts[:, :, None, :, :2] + offset # [bs, Lq, n_head, sum_pts, 2]
return _deformable_attn_v2(value, spatial_shapes, locs, attn_w, self.num_points_list)
class Gate(nn.Module):
def __init__(self, d_model, device=None, dtype=None, operations=None):
super().__init__()
self.gate = operations.Linear(2 * d_model, 2 * d_model, device=device, dtype=dtype)
self.norm = operations.LayerNorm(d_model, device=device, dtype=dtype)
def forward(self, x1, x2):
g1, g2 = torch.sigmoid(self.gate(torch.cat([x1, x2], -1))).chunk(2, -1)
return self.norm(g1 * x1 + g2 * x2)
class MLP(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, num_layers, device=None, dtype=None, operations=None):
super().__init__()
dims = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
self.layers = nn.ModuleList(operations.Linear(dims[i], dims[i + 1], device=device, dtype=dtype) for i in range(num_layers))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = nn.SiLU()(layer(x)) if i < len(self.layers) - 1 else layer(x)
return x
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model=256, nhead=8, dim_feedforward=1024, num_levels=3, num_points=4, device=None, dtype=None, operations=None):
super().__init__()
self.self_attn = SelfAttention(d_model, nhead, device=device, dtype=dtype, operations=operations)
self.norm1 = operations.LayerNorm(d_model, device=device, dtype=dtype)
self.cross_attn = MSDeformableAttention(d_model, nhead, num_levels, num_points, device=device, dtype=dtype, operations=operations)
self.gateway = Gate(d_model, device=device, dtype=dtype, operations=operations)
self.linear1 = operations.Linear(d_model, dim_feedforward, device=device, dtype=dtype)
self.activation = nn.ReLU()
self.linear2 = operations.Linear(dim_feedforward, d_model, device=device, dtype=dtype)
self.norm3 = operations.LayerNorm(d_model, device=device, dtype=dtype)
def forward(self, target, ref_pts, value, spatial_shapes, attn_mask=None, query_pos=None):
q = k = target if query_pos is None else target + query_pos
t2 = self.self_attn(q, k, value=target, attn_mask=attn_mask)
target = self.norm1(target + t2)
t2 = self.cross_attn(
target if query_pos is None else target + query_pos,
ref_pts, value, spatial_shapes)
target = self.gateway(target, t2)
t2 = self.linear2(self.activation(self.linear1(target)))
target = self.norm3((target + t2).clamp(-65504, 65504))
return target
# ---------------------------------------------------------------------------
# FDR utilities
# ---------------------------------------------------------------------------
def weighting_function(reg_max, up, reg_scale):
"""Non-uniform weighting function W(n) for FDR box regression."""
ub1 = (abs(up[0]) * abs(reg_scale)).item()
ub2 = ub1 * 2
step = (ub1 + 1) ** (2 / (reg_max - 2))
left = [-(step ** i) + 1 for i in range(reg_max // 2 - 1, 0, -1)]
right = [ (step ** i) - 1 for i in range(1, reg_max // 2)]
vals = [-ub2] + left + [0] + right + [ub2]
return torch.tensor(vals, dtype=up.dtype, device=up.device)
def distance2bbox(points, distance, reg_scale):
"""Decode edge-distances → cxcywh boxes."""
rs = abs(reg_scale).to(dtype=points.dtype)
x1 = points[..., 0] - (0.5 * rs + distance[..., 0]) * (points[..., 2] / rs)
y1 = points[..., 1] - (0.5 * rs + distance[..., 1]) * (points[..., 3] / rs)
x2 = points[..., 0] + (0.5 * rs + distance[..., 2]) * (points[..., 2] / rs)
y2 = points[..., 1] + (0.5 * rs + distance[..., 3]) * (points[..., 3] / rs)
x0, y0, x1_, y1_ = (x1 + x2) / 2, (y1 + y2) / 2, x2 - x1, y2 - y1
return torch.stack([x0, y0, x1_, y1_], -1)
class Integral(nn.Module):
"""Sum Pr(n)·W(n) over the distribution bins."""
def __init__(self, reg_max=32):
super().__init__()
self.reg_max = reg_max
def forward(self, x, project):
shape = x.shape
x = F.softmax(x.reshape(-1, self.reg_max + 1), 1)
x = F.linear(x, project.to(device=x.device, dtype=x.dtype)).reshape(-1, 4)
return x.reshape(list(shape[:-1]) + [-1])
class LQE(nn.Module):
"""Location Quality Estimator — refines class scores using corner distribution."""
def __init__(self, k=4, hidden_dim=64, num_layers=2, reg_max=32, device=None, dtype=None, operations=None):
super().__init__()
self.k, self.reg_max = k, reg_max
self.reg_conf = MLP(4 * (k + 1), hidden_dim, 1, num_layers, device=device, dtype=dtype, operations=operations)
def forward(self, scores, pred_corners):
B, L, _ = pred_corners.shape
prob = F.softmax(pred_corners.reshape(B, L, 4, self.reg_max + 1), -1)
topk, _ = prob.topk(self.k, -1)
stat = torch.cat([topk, topk.mean(-1, keepdim=True)], -1)
return scores + self.reg_conf(stat.reshape(B, L, -1))
class TransformerDecoder(nn.Module):
def __init__(self, hidden_dim, nhead, dim_feedforward, num_levels, num_points, num_layers, reg_max, reg_scale, up, eval_idx=-1, device=None, dtype=None, operations=None):
super().__init__()
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.nhead = nhead
self.eval_idx = eval_idx if eval_idx >= 0 else num_layers + eval_idx
self.up, self.reg_scale, self.reg_max = up, reg_scale, reg_max
self.layers = nn.ModuleList([
TransformerDecoderLayer(hidden_dim, nhead, dim_feedforward, num_levels, num_points, device=device, dtype=dtype, operations=operations)
for _ in range(self.eval_idx + 1)
])
self.lqe_layers = nn.ModuleList([LQE(4, 64, 2, reg_max, device=device, dtype=dtype, operations=operations) for _ in range(self.eval_idx + 1)])
self.register_buffer('project', weighting_function(reg_max, up, reg_scale))
def _value_op(self, memory, spatial_shapes):
"""Reshape memory to per-level value tensors for deformable attention."""
c = self.hidden_dim // self.nhead
split = [h * w for h, w in spatial_shapes]
val = memory.reshape(memory.shape[0], memory.shape[1], self.nhead, c) # memory: [bs, sum(h*w), hidden_dim]
# → [bs, n_head, c, sum_hw]
val = val.permute(0, 2, 3, 1).flatten(0, 1) # [bs*n_head, c, sum_hw]
return val.split(split, dim=-1) # list of [bs*n_head, c, h_l*w_l]
def forward(self, target, ref_pts_unact, memory, spatial_shapes, bbox_head, score_head, query_pos_head, pre_bbox_head, integral):
val_split_flat = self._value_op(memory, spatial_shapes) # pre-split value for deformable attention
# reshape to [bs*n_head, c, h_l, w_l]
value = []
for lvl, (h, w) in enumerate(spatial_shapes):
v = val_split_flat[lvl] # [bs*n_head, c, h*w]
value.append(v.reshape(v.shape[0], v.shape[1], h, w))
ref_pts = F.sigmoid(ref_pts_unact)
output = target
output_detach = pred_corners_undetach = 0
dec_bboxes, dec_logits = [], []
for i, layer in enumerate(self.layers):
ref_input = ref_pts.unsqueeze(2) # [bs, Lq, 1, 4]
query_pos = query_pos_head(ref_pts).clamp(-10, 10)
output = layer(output, ref_input, value, spatial_shapes, query_pos=query_pos)
if i == 0:
ref_unact = ref_pts.clamp(1e-5, 1 - 1e-5)
ref_unact = torch.log(ref_unact / (1 - ref_unact))
pre_bboxes = F.sigmoid(pre_bbox_head(output) + ref_unact)
ref_pts_initial = pre_bboxes.detach()
pred_corners = bbox_head[i](output + output_detach) + pred_corners_undetach
inter_ref_bbox = distance2bbox(ref_pts_initial, integral(pred_corners, self.project), self.reg_scale)
if i == self.eval_idx:
scores = score_head[i](output)
scores = self.lqe_layers[i](scores, pred_corners)
dec_bboxes.append(inter_ref_bbox)
dec_logits.append(scores)
break
pred_corners_undetach = pred_corners
ref_pts = inter_ref_bbox.detach()
output_detach = output.detach()
return torch.stack(dec_bboxes), torch.stack(dec_logits)
class DFINETransformer(nn.Module):
def __init__(self, num_classes=80, hidden_dim=256, num_queries=300, feat_channels=[256, 256, 256], feat_strides=[8, 16, 32],
num_levels=3, num_points=[3, 6, 3], nhead=8, num_layers=6, dim_feedforward=1024, eval_idx=-1, eps=1e-2, reg_max=32,
reg_scale=8.0, eval_spatial_size=(640, 640), device=None, dtype=None, operations=None):
super().__init__()
assert len(feat_strides) == len(feat_channels)
self.hidden_dim = hidden_dim
self.num_queries = num_queries
self.num_levels = num_levels
self.eps = eps
self.eval_spatial_size = eval_spatial_size
self.feat_strides = list(feat_strides)
for i in range(num_levels - len(feat_strides)):
self.feat_strides.append(feat_strides[-1] * 2 ** (i + 1))
# input projection (expects pre-fused weights)
self.input_proj = nn.ModuleList()
for ch in feat_channels:
if ch == hidden_dim:
self.input_proj.append(nn.Identity())
else:
self.input_proj.append(nn.Sequential(OrderedDict([
('conv', operations.Conv2d(ch, hidden_dim, 1, bias=True, device=device, dtype=dtype))])))
in_ch = feat_channels[-1]
for i in range(num_levels - len(feat_channels)):
self.input_proj.append(nn.Sequential(OrderedDict([
('conv', operations.Conv2d(in_ch if i == 0 else hidden_dim,
hidden_dim, 3, 2, 1, bias=True, device=device, dtype=dtype))])))
in_ch = hidden_dim
# FDR parameters (non-trainable placeholders, set from config)
self.up = nn.Parameter(torch.tensor([0.5]), requires_grad=False)
self.reg_scale = nn.Parameter(torch.tensor([reg_scale]), requires_grad=False)
pts = num_points if isinstance(num_points, (list, tuple)) else [num_points] * num_levels
self.decoder = TransformerDecoder(hidden_dim, nhead, dim_feedforward, num_levels, pts,
num_layers, reg_max, self.reg_scale, self.up, eval_idx, device=device, dtype=dtype, operations=operations)
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, 2, device=device, dtype=dtype, operations=operations)
self.enc_output = nn.Sequential(OrderedDict([
('proj', operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype)),
('norm', operations.LayerNorm(hidden_dim, device=device, dtype=dtype))]))
self.enc_score_head = operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype)
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
self.eval_idx_ = eval_idx if eval_idx >= 0 else num_layers + eval_idx
self.dec_score_head = nn.ModuleList(
[operations.Linear(hidden_dim, num_classes, device=device, dtype=dtype) for _ in range(self.eval_idx_ + 1)])
self.pre_bbox_head = MLP(hidden_dim, hidden_dim, 4, 3, device=device, dtype=dtype, operations=operations)
self.dec_bbox_head = nn.ModuleList(
[MLP(hidden_dim, hidden_dim, 4 * (reg_max + 1), 3, device=device, dtype=dtype, operations=operations)
for _ in range(self.eval_idx_ + 1)])
self.integral = Integral(reg_max)
if eval_spatial_size:
# Register as buffers so checkpoint values override the freshly-computed defaults
anchors, valid_mask = self._gen_anchors()
self.register_buffer('anchors', anchors)
self.register_buffer('valid_mask', valid_mask)
def _gen_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device='cpu'):
if spatial_shapes is None:
h0, w0 = self.eval_spatial_size
spatial_shapes = [[int(h0 / s), int(w0 / s)] for s in self.feat_strides]
anchors = []
for lvl, (h, w) in enumerate(spatial_shapes):
gy, gx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
gxy = (torch.stack([gx, gy], -1).float() + 0.5) / torch.tensor([w, h], dtype=dtype)
wh = torch.ones_like(gxy) * grid_size * (2. ** lvl)
anchors.append(torch.cat([gxy, wh], -1).reshape(-1, h * w, 4))
anchors = torch.cat(anchors, 1).to(device)
valid_mask = ((anchors > self.eps) & (anchors < 1 - self.eps)).all(-1, keepdim=True)
anchors = torch.log(anchors / (1 - anchors))
anchors = torch.where(valid_mask, anchors, torch.full_like(anchors, float('inf')))
return anchors, valid_mask
def _encoder_input(self, feats: List[torch.Tensor]):
proj = [self.input_proj[i](f) for i, f in enumerate(feats)]
for i in range(len(feats), self.num_levels):
proj.append(self.input_proj[i](feats[-1] if i == len(feats) else proj[-1]))
flat, shapes = [], []
for f in proj:
_, _, h, w = f.shape
flat.append(f.flatten(2).permute(0, 2, 1))
shapes.append([h, w])
return torch.cat(flat, 1), shapes
def _decoder_input(self, memory: torch.Tensor):
anchors, valid_mask = self.anchors.to(memory), self.valid_mask
if memory.shape[0] > 1:
anchors = anchors.repeat(memory.shape[0], 1, 1)
mem = valid_mask.to(memory) * memory
out_mem = self.enc_output(mem)
logits = self.enc_score_head(out_mem)
_, idx = torch.topk(logits.max(-1).values, self.num_queries, dim=-1)
idx_e = idx.unsqueeze(-1)
topk_mem = out_mem.gather(1, idx_e.expand(-1, -1, out_mem.shape[-1]))
topk_anc = anchors.gather(1, idx_e.expand(-1, -1, anchors.shape[-1]))
topk_ref = self.enc_bbox_head(topk_mem) + topk_anc
return topk_mem.detach(), topk_ref.detach()
def forward(self, feats: List[torch.Tensor]):
memory, shapes = self._encoder_input(feats)
content, ref = self._decoder_input(memory)
out_bboxes, out_logits = self.decoder(
content, ref, memory, shapes,
self.dec_bbox_head, self.dec_score_head,
self.query_pos_head, self.pre_bbox_head, self.integral)
return {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}
# ---------------------------------------------------------------------------
# Main model
# ---------------------------------------------------------------------------
class RTv4(nn.Module):
def __init__(self, num_classes=80, num_queries=300, enc_h=256, dec_h=256, enc_ff=2048, dec_ff=1024, feat_strides=[8, 16, 32], device=None, dtype=None, operations=None, **kwargs):
super().__init__()
self.device = device
self.dtype = dtype
self.operations = operations
self.backbone = HGNetv2(device=device, dtype=dtype, operations=operations)
self.encoder = HybridEncoder(hidden_dim=enc_h, dim_feedforward=enc_ff, device=device, dtype=dtype, operations=operations)
self.decoder = DFINETransformer(num_classes=num_classes, hidden_dim=dec_h, num_queries=num_queries,
feat_channels=[enc_h] * len(feat_strides), feat_strides=feat_strides, dim_feedforward=dec_ff, device=device, dtype=dtype, operations=operations)
self.num_classes = num_classes
self.num_queries = num_queries
self.load_device = comfy.model_management.get_torch_device()
def _forward(self, x: torch.Tensor):
return self.decoder(self.encoder(self.backbone(x)))
def postprocess(self, outputs, orig_size: tuple = (640, 640)) -> List[dict]:
logits = outputs['pred_logits']
boxes = torchvision.ops.box_convert(outputs['pred_boxes'], 'cxcywh', 'xyxy')
boxes = boxes * torch.tensor(orig_size, device=boxes.device, dtype=boxes.dtype).repeat(1, 2).unsqueeze(1)
scores = F.sigmoid(logits)
scores, idx = torch.topk(scores.flatten(1), self.num_queries, dim=-1)
labels = idx % self.num_classes
boxes = boxes.gather(1, (idx // self.num_classes).unsqueeze(-1).expand(-1, -1, 4))
return [{'labels': lbl, 'boxes': b, 'scores': s} for lbl, b, s in zip(labels, boxes, scores)]
def forward(self, x: torch.Tensor, orig_size: tuple = (640, 640), **kwargs):
outputs = self._forward(x.to(device=self.load_device, dtype=self.dtype))
return self.postprocess(outputs, orig_size)

View File

@ -141,3 +141,17 @@ def interpret_gathered_like(tensors, gathered):
return dest_views return dest_views
aimdo_enabled = False aimdo_enabled = False
extra_ram_release_callback = None
RAM_CACHE_HEADROOM = 0
def set_ram_cache_release_state(callback, headroom):
global extra_ram_release_callback
global RAM_CACHE_HEADROOM
extra_ram_release_callback = callback
RAM_CACHE_HEADROOM = max(0, int(headroom))
def extra_ram_release(target):
if extra_ram_release_callback is None:
return 0
return extra_ram_release_callback(target)

View File

@ -52,6 +52,7 @@ import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15 import comfy.ldm.ace.ace_step15
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.model_management import comfy.model_management
import comfy.patcher_extension import comfy.patcher_extension
@ -890,7 +891,7 @@ class Flux(BaseModel):
return torch.cat((image, mask), dim=1) return torch.cat((image, mask), dim=1)
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
return kwargs["pooled_output"] return kwargs.get("pooled_output", None)
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs) out = super().extra_conds(**kwargs)
@ -1957,3 +1958,7 @@ class Kandinsky5Image(Kandinsky5):
def concat_cond(self, **kwargs): def concat_cond(self, **kwargs):
return None return None
class RT_DETR_v4(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)

View File

@ -698,6 +698,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["audio_model"] = "ace1.5" dit_config["audio_model"] = "ace1.5"
return dit_config return dit_config
if '{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix) in state_dict_keys: # RT-DETR_v4
dit_config = {}
dit_config["image_model"] = "RT_DETR_v4"
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None return None

View File

@ -669,7 +669,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for i in range(len(current_loaded_models) -1, -1, -1): for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i] shift_model = current_loaded_models[i]
if shift_model.device == device: if device is None or shift_model.device == device:
if shift_model not in keep_loaded and not shift_model.is_dead(): if shift_model not in keep_loaded and not shift_model.is_dead():
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False shift_model.currently_used = False
@ -679,8 +679,8 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
i = x[-1] i = x[-1]
memory_to_free = 1e32 memory_to_free = 1e32
pins_to_free = 1e32 pins_to_free = 1e32
if not DISABLE_SMART_MEMORY: if not DISABLE_SMART_MEMORY or device is None:
memory_to_free = memory_required - get_free_memory(device) memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
pins_to_free = pins_required - get_free_ram() pins_to_free = pins_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic: if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models #don't actually unload dynamic models for the sake of other dynamic models
@ -708,7 +708,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
if len(unloaded_model) > 0: if len(unloaded_model) > 0:
soft_empty_cache() soft_empty_cache()
else: elif device is not None:
if vram_state != VRAMState.HIGH_VRAM: if vram_state != VRAMState.HIGH_VRAM:
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True) mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25: if mem_free_torch > mem_free_total * 0.25:
@ -1326,9 +1326,9 @@ MAX_PINNED_MEMORY = -1
if not args.disable_pinned_memory: if not args.disable_pinned_memory:
if is_nvidia() or is_amd(): if is_nvidia() or is_amd():
if WINDOWS: if WINDOWS:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50%
else: else:
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
@ -1403,8 +1403,6 @@ def unpin_memory(tensor):
if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: if torch.cuda.cudart().cudaHostUnregister(ptr) == 0:
TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr)
if len(PINNED_MEMORY) == 0:
TOTAL_PINNED_MEMORY = 0
return True return True
else: else:
logging.warning("Unpin error.") logging.warning("Unpin error.")

View File

@ -300,9 +300,6 @@ class ModelPatcher:
def model_mmap_residency(self, free=False): def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free) return comfy.model_management.module_mmap_residency(self.model, free=free)
def get_ram_usage(self):
return self.model_size()
def loaded_size(self): def loaded_size(self):
return self.model.model_loaded_weight_memory return self.model.model_loaded_weight_memory

View File

@ -2,6 +2,7 @@ import comfy.model_management
import comfy.memory_management import comfy.memory_management
import comfy_aimdo.host_buffer import comfy_aimdo.host_buffer
import comfy_aimdo.torch import comfy_aimdo.torch
import psutil
from comfy.cli_args import args from comfy.cli_args import args
@ -12,6 +13,11 @@ def pin_memory(module):
if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None:
return return
#FIXME: This is a RAM cache trigger event #FIXME: This is a RAM cache trigger event
ram_headroom = comfy.memory_management.RAM_CACHE_HEADROOM
#we split the difference and assume half the RAM cache headroom is for us
if ram_headroom > 0 and psutil.virtual_memory().available < (ram_headroom * 0.5):
comfy.memory_management.extra_ram_release(ram_headroom)
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY: if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:

View File

@ -280,9 +280,6 @@ class CLIP:
n.apply_hooks_to_conds = self.apply_hooks_to_conds n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n return n
def get_ram_usage(self):
return self.patcher.get_ram_usage()
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model) return self.patcher.add_patches(patches, strength_patch, strength_model)
@ -840,9 +837,6 @@ class VAE:
self.size = comfy.model_management.module_size(self.first_stage_model) self.size = comfy.model_management.module_size(self.first_stage_model)
return self.size return self.size
def get_ram_usage(self):
return self.model_size()
def throw_exception_if_invalid(self): def throw_exception_if_invalid(self):
if self.first_stage_model is None: if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
@ -1742,15 +1736,18 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
""" """
dtype = model_options.get("dtype", None) dtype = model_options.get("dtype", None)
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
#Allow loading unets from checkpoint files #Allow loading unets from checkpoint files
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True) temp_sd = comfy.utils.state_dict_prefix_replace(sd, {diffusion_model_prefix: ""}, filter_keys=True)
if len(temp_sd) > 0: if len(temp_sd) > 0:
sd = temp_sd sd = temp_sd
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
custom_operations = model_options.get("custom_operations", None)
if custom_operations is None:
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
parameters = comfy.utils.calculate_parameters(sd) parameters = comfy.utils.calculate_parameters(sd)
weight_dtype = comfy.utils.weight_dtype(sd) weight_dtype = comfy.utils.weight_dtype(sd)

View File

@ -1734,6 +1734,21 @@ class LongCatImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
class RT_DETR_v4(supported_models_base.BASE):
unet_config = {
"image_model": "RT_DETR_v4",
}
supported_inference_dtypes = [torch.float16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.RT_DETR_v4(self, device=device)
return out
def clip_target(self, state_dict={}):
return None
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
models += [SVD_img2vid] models += [SVD_img2vid]

View File

@ -1373,6 +1373,7 @@ class NodeInfoV1:
price_badge: dict | None = None price_badge: dict | None = None
search_aliases: list[str]=None search_aliases: list[str]=None
essentials_category: str=None essentials_category: str=None
has_intermediate_output: bool=None
@dataclass @dataclass
@ -1496,6 +1497,16 @@ class Schema:
"""When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema.""" """When True, all inputs from the prompt will be passed to the node as kwargs, even if not defined in the schema."""
essentials_category: str | None = None essentials_category: str | None = None
"""Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing').""" """Optional category for the Essentials tab. Path-based like category field (e.g., 'Basic', 'Image Tools/Editing')."""
has_intermediate_output: bool=False
"""Flags this node as having intermediate output that should persist across page refreshes.
Nodes with this flag behave like output nodes (their UI results are cached and resent
to the frontend) but do NOT automatically get added to the execution list. This means
they will only execute if they are on the dependency path of a real output node.
Use this for nodes with interactive/operable UI regions that produce intermediate outputs
(e.g., Image Crop, Painter) rather than final outputs (e.g., Save Image).
"""
def validate(self): def validate(self):
'''Validate the schema: '''Validate the schema:
@ -1595,6 +1606,7 @@ class Schema:
category=self.category, category=self.category,
description=self.description, description=self.description,
output_node=self.is_output_node, output_node=self.is_output_node,
has_intermediate_output=self.has_intermediate_output,
deprecated=self.is_deprecated, deprecated=self.is_deprecated,
experimental=self.is_experimental, experimental=self.is_experimental,
dev_only=self.is_dev_only, dev_only=self.is_dev_only,
@ -1886,6 +1898,14 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls.GET_SCHEMA() cls.GET_SCHEMA()
return cls._OUTPUT_NODE return cls._OUTPUT_NODE
_HAS_INTERMEDIATE_OUTPUT = None
@final
@classproperty
def HAS_INTERMEDIATE_OUTPUT(cls): # noqa
if cls._HAS_INTERMEDIATE_OUTPUT is None:
cls.GET_SCHEMA()
return cls._HAS_INTERMEDIATE_OUTPUT
_INPUT_IS_LIST = None _INPUT_IS_LIST = None
@final @final
@classproperty @classproperty
@ -1978,6 +1998,8 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
cls._API_NODE = schema.is_api_node cls._API_NODE = schema.is_api_node
if cls._OUTPUT_NODE is None: if cls._OUTPUT_NODE is None:
cls._OUTPUT_NODE = schema.is_output_node cls._OUTPUT_NODE = schema.is_output_node
if cls._HAS_INTERMEDIATE_OUTPUT is None:
cls._HAS_INTERMEDIATE_OUTPUT = schema.has_intermediate_output
if cls._INPUT_IS_LIST is None: if cls._INPUT_IS_LIST is None:
cls._INPUT_IS_LIST = schema.is_input_list cls._INPUT_IS_LIST = schema.is_input_list
if cls._NOT_IDEMPOTENT is None: if cls._NOT_IDEMPOTENT is None:

View File

@ -132,7 +132,7 @@ class TencentTextToModelNode(IO.ComfyNode):
tooltip="The LowPoly option is unavailable for the `3.1` model.", tooltip="The LowPoly option is unavailable for the `3.1` model.",
), ),
IO.String.Input("prompt", multiline=True, default="", tooltip="Supports up to 1024 characters."), IO.String.Input("prompt", multiline=True, default="", tooltip="Supports up to 1024 characters."),
IO.Int.Input("face_count", default=500000, min=40000, max=1500000), IO.Int.Input("face_count", default=500000, min=3000, max=1500000),
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"generate_type", "generate_type",
options=[ options=[
@ -251,7 +251,7 @@ class TencentImageToModelNode(IO.ComfyNode):
IO.Image.Input("image_left", optional=True), IO.Image.Input("image_left", optional=True),
IO.Image.Input("image_right", optional=True), IO.Image.Input("image_right", optional=True),
IO.Image.Input("image_back", optional=True), IO.Image.Input("image_back", optional=True),
IO.Int.Input("face_count", default=500000, min=40000, max=1500000), IO.Int.Input("face_count", default=500000, min=3000, max=1500000),
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"generate_type", "generate_type",
options=[ options=[
@ -422,6 +422,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
outputs=[ outputs=[
IO.File3DOBJ.Output(display_name="OBJ"), IO.File3DOBJ.Output(display_name="OBJ"),
IO.File3DFBX.Output(display_name="FBX"), IO.File3DFBX.Output(display_name="FBX"),
IO.Image.Output(display_name="uv_image"),
], ],
hidden=[ hidden=[
IO.Hidden.auth_token_comfy_org, IO.Hidden.auth_token_comfy_org,
@ -468,9 +469,16 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse, response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status, status_extractor=lambda r: r.Status,
) )
uv_image_file = get_file_from_response(result.ResultFile3Ds, "uv_image", raise_if_not_found=False)
uv_image = (
await download_url_to_image_tensor(uv_image_file.Url)
if uv_image_file is not None
else torch.zeros(1, 1, 1, 3)
)
return IO.NodeOutput( return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"), await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
uv_image,
) )

View File

@ -1,6 +1,5 @@
import asyncio import asyncio
import bisect import bisect
import gc
import itertools import itertools
import psutil import psutil
import time import time
@ -475,6 +474,10 @@ class LRUCache(BasicCache):
self._mark_used(node_id) self._mark_used(node_id)
return await self._set_immediate(node_id, value) return await self._set_immediate(node_id, value)
def set_local(self, node_id, value):
self._mark_used(node_id)
BasicCache.set_local(self, node_id, value)
async def ensure_subcache_for(self, node_id, children_ids): async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes # Just uses subcaches for tracking 'live' nodes
await super()._ensure_subcache(node_id, children_ids) await super()._ensure_subcache(node_id, children_ids)
@ -489,15 +492,10 @@ class LRUCache(BasicCache):
return self return self
#Iterating the cache for usage analysis might be expensive, so if we trigger make sure #Small baseline weight used when a cache entry has no measurable CPU tensors.
#to take a chunk out to give breathing space on high-node / low-ram-per-node flows. #Keeps unknown-sized entries in eviction scoring without dominating tensor-backed entries.
RAM_CACHE_HYSTERESIS = 1.1 RAM_CACHE_DEFAULT_RAM_USAGE = 0.05
#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
#Exponential bias towards evicting older workflows so garbage will be taken out #Exponential bias towards evicting older workflows so garbage will be taken out
#in constantly changing setups. #in constantly changing setups.
@ -521,19 +519,17 @@ class RAMPressureCache(LRUCache):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return await super().get(node_id) return await super().get(node_id)
def poll(self, ram_headroom): def set_local(self, node_id, value):
def _ram_gb(): self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return psutil.virtual_memory().available / (1024**3) super().set_local(node_id, value)
if _ram_gb() > ram_headroom: def ram_release(self, target):
return if psutil.virtual_memory().available >= target:
gc.collect()
if _ram_gb() > ram_headroom:
return return
clean_list = [] clean_list = []
for key, (outputs, _), in self.cache.items(): for key, cache_entry in self.cache.items():
oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
@ -542,22 +538,20 @@ class RAMPressureCache(LRUCache):
if outputs is None: if outputs is None:
return return
for output in outputs: for output in outputs:
if isinstance(output, list): if isinstance(output, (list, tuple)):
scan_list_for_ram_usage(output) scan_list_for_ram_usage(output)
elif isinstance(output, torch.Tensor) and output.device.type == 'cpu': elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
#score Tensors at a 50% discount for RAM usage as they are likely to ram_usage += output.numel() * output.element_size()
#be high value intermediates scan_list_for_ram_usage(cache_entry.outputs)
ram_usage += (output.numel() * output.element_size()) * 0.5
elif hasattr(output, "get_ram_usage"):
ram_usage += output.get_ram_usage()
scan_list_for_ram_usage(outputs)
oom_score *= ram_usage oom_score *= ram_usage
#In the case where we have no information on the node ram usage at all, #In the case where we have no information on the node ram usage at all,
#break OOM score ties on the last touch timestamp (pure LRU) #break OOM score ties on the last touch timestamp (pure LRU)
bisect.insort(clean_list, (oom_score, self.timestamps[key], key)) bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list: while psutil.virtual_memory().available < target and clean_list:
_, _, key = clean_list.pop() _, _, key = clean_list.pop()
del self.cache[key] del self.cache[key]
gc.collect() self.used_generation.pop(key, None)
self.timestamps.pop(key, None)
self.children.pop(key, None)

View File

@ -813,6 +813,7 @@ class GLSLShader(io.ComfyNode):
"u_resolution (vec2) is always available." "u_resolution (vec2) is always available."
), ),
is_experimental=True, is_experimental=True,
has_intermediate_output=True,
inputs=[ inputs=[
io.String.Input( io.String.Input(
"fragment_shader", "fragment_shader",

View File

@ -59,6 +59,7 @@ class ImageCropV2(IO.ComfyNode):
display_name="Image Crop", display_name="Image Crop",
category="image/transform", category="image/transform",
essentials_category="Image Tools", essentials_category="Image Tools",
has_intermediate_output=True,
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.BoundingBox.Input("crop_region", component="ImageCrop"), IO.BoundingBox.Input("crop_region", component="ImageCrop"),

View File

@ -30,6 +30,7 @@ class PainterNode(io.ComfyNode):
node_id="Painter", node_id="Painter",
display_name="Painter", display_name="Painter",
category="image", category="image",
has_intermediate_output=True,
inputs=[ inputs=[
io.Image.Input( io.Image.Input(
"image", "image",

View File

@ -0,0 +1,154 @@
from typing_extensions import override
import torch
from comfy.ldm.rt_detr.rtdetr_v4 import COCO_CLASSES
import comfy.model_management
import comfy.utils
from comfy_api.latest import ComfyExtension, io
from torchvision.transforms import ToPILImage, ToTensor
from PIL import ImageDraw, ImageFont
class RTDETR_detect(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="RTDETR_detect",
display_name="RT-DETR Detect",
category="detection/",
search_aliases=["bbox", "bounding box", "object detection", "coco"],
inputs=[
io.Model.Input("model", display_name="model"),
io.Image.Input("image", display_name="image"),
io.Float.Input("threshold", display_name="threshold", default=0.5),
io.Combo.Input("class_name", options=["all"] + COCO_CLASSES, default="all", tooltip="Filter detections by class. Set to 'all' to disable filtering."),
io.Int.Input("max_detections", display_name="max_detections", default=100, tooltip="Maximum number of detections to return per image. In order of descending confidence score."),
],
outputs=[
io.BoundingBox.Output("bboxes")],
)
@classmethod
def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput:
B, H, W, C = image.shape
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
comfy.model_management.load_model_gpu(model)
results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts
all_bbox_dicts = []
for det in results:
keep = det['scores'] > threshold
boxes = det['boxes'][keep].cpu()
labels = det['labels'][keep].cpu()
scores = det['scores'][keep].cpu()
bbox_dicts = [
{
"x": float(box[0]),
"y": float(box[1]),
"width": float(box[2] - box[0]),
"height": float(box[3] - box[1]),
"label": COCO_CLASSES[int(label)],
"score": float(score)
}
for box, label, score in zip(boxes, labels, scores)
if class_name == "all" or COCO_CLASSES[int(label)] == class_name
]
bbox_dicts.sort(key=lambda d: d["score"], reverse=True)
all_bbox_dicts.append(bbox_dicts[:max_detections])
return io.NodeOutput(all_bbox_dicts)
class DrawBBoxes(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DrawBBoxes",
display_name="Draw BBoxes",
category="detection/",
search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"],
inputs=[
io.Image.Input("image", optional=True),
io.BoundingBox.Input("bboxes", force_input=True),
],
outputs=[
io.Image.Output("out_image"),
],
)
@classmethod
def execute(cls, bboxes, image=None) -> io.NodeOutput:
# Normalise to list[list[dict]], then fit to batch size B.
B = image.shape[0] if image is not None else 1
if isinstance(bboxes, dict):
bboxes = [[bboxes]]
elif not isinstance(bboxes, list) or not bboxes:
bboxes = [[]]
elif isinstance(bboxes[0], dict):
bboxes = [bboxes] # flat list → same detections for every image
if len(bboxes) == 1:
bboxes = bboxes * B
bboxes = (bboxes + [[]] * B)[:B]
if image is None:
B = len(bboxes)
max_w = max((int(d["x"] + d["width"]) for frame in bboxes for d in frame), default=640)
max_h = max((int(d["y"] + d["height"]) for frame in bboxes for d in frame), default=640)
image = torch.zeros((B, max_h, max_w, 3), dtype=torch.float32)
all_out_images = []
for i in range(B):
detections = bboxes[i]
if detections:
boxes = torch.tensor([[d["x"], d["y"], d["x"] + d["width"], d["y"] + d["height"]] for d in detections])
labels = [d.get("label") if d.get("label") in COCO_CLASSES else None for d in detections]
scores = torch.tensor([d.get("score", 1.0) for d in detections])
else:
boxes = torch.zeros((0, 4))
labels = []
scores = torch.zeros((0,))
pil_image = image[i].movedim(-1, 0)
img = ToPILImage()(pil_image)
if detections:
img = cls.draw_detections(img, boxes, labels, scores)
all_out_images.append(ToTensor()(img).unsqueeze(0).movedim(1, -1))
out_images = torch.cat(all_out_images, dim=0).to(comfy.model_management.intermediate_device())
return io.NodeOutput(out_images)
@classmethod
def draw_detections(cls, img, boxes, labels, scores):
draw = ImageDraw.Draw(img)
try:
font = ImageFont.truetype('arial.ttf', 16)
except Exception:
font = ImageFont.load_default()
colors = [(255,0,0),(0,200,0),(0,0,255),(255,165,0),(128,0,128),
(0,255,255),(255,20,147),(100,149,237)]
for box, label, score in sorted(zip(boxes, labels, scores), key=lambda x: x[2].item()):
x1, y1, x2, y2 = box.tolist()
color_idx = COCO_CLASSES.index(label) if label is not None else 0
c = colors[color_idx % len(colors)]
draw.rectangle([x1, y1, x2, y2], outline=c, width=3)
if label is not None:
draw.text((x1 + 2, y1 + 2), f'{label} {score:.2f}', fill=c, font=font)
return img
class RTDETRExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
RTDETR_detect,
DrawBBoxes,
]
async def comfy_entrypoint() -> RTDETRExtension:
return RTDETRExtension()

View File

@ -661,6 +661,7 @@ class CropByBBoxes(io.ComfyNode):
io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."), io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."),
io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."), io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."),
io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."), io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."),
io.Combo.Input("keep_aspect", options=["stretch", "pad"], default="stretch", tooltip="Whether to stretch the crop to fit the output size, or pad with black pixels to preserve aspect ratio."),
], ],
outputs=[ outputs=[
io.Image.Output(tooltip="All crops stacked into a single image batch."), io.Image.Output(tooltip="All crops stacked into a single image batch."),
@ -668,7 +669,7 @@ class CropByBBoxes(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, image, bboxes, output_width, output_height, padding) -> io.NodeOutput: def execute(cls, image, bboxes, output_width, output_height, padding, keep_aspect="stretch") -> io.NodeOutput:
total_frames = image.shape[0] total_frames = image.shape[0]
img_h = image.shape[1] img_h = image.shape[1]
img_w = image.shape[2] img_w = image.shape[2]
@ -716,7 +717,19 @@ class CropByBBoxes(io.ComfyNode):
x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2 x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2
crop_chw = frame_chw[:, :, y1:y2, x1:x2] # (1, C, crop_h, crop_w) crop_chw = frame_chw[:, :, y1:y2, x1:x2] # (1, C, crop_h, crop_w)
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
if keep_aspect == "pad":
crop_h, crop_w = y2 - y1, x2 - x1
scale = min(output_width / crop_w, output_height / crop_h)
scaled_w = int(round(crop_w * scale))
scaled_h = int(round(crop_h * scale))
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
pad_left = (output_width - scaled_w) // 2
pad_top = (output_height - scaled_h) // 2
resized = torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)
resized[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
else: # "stretch"
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
crops.append(resized) crops.append(resized)
if not crops: if not crops:

View File

@ -9,9 +9,9 @@ class StringConcatenate(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="StringConcatenate", node_id="StringConcatenate",
display_name="Concatenate", display_name="Text Concatenate",
category="utils/string", category="utils/string",
search_aliases=["text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"], search_aliases=["Concatenate", "text concat", "join text", "merge text", "combine strings", "concat", "concatenate", "append text", "combine text", "string"],
inputs=[ inputs=[
io.String.Input("string_a", multiline=True), io.String.Input("string_a", multiline=True),
io.String.Input("string_b", multiline=True), io.String.Input("string_b", multiline=True),
@ -32,8 +32,8 @@ class StringSubstring(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="StringSubstring", node_id="StringSubstring",
search_aliases=["extract text", "text portion"], search_aliases=["Substring", "extract text", "text portion"],
display_name="Substring", display_name="Text Substring",
category="utils/string", category="utils/string",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
@ -55,8 +55,8 @@ class StringLength(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="StringLength", node_id="StringLength",
search_aliases=["character count", "text size"], search_aliases=["character count", "text size", "string length"],
display_name="Length", display_name="Text Length",
category="utils/string", category="utils/string",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
@ -76,8 +76,8 @@ class CaseConverter(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="CaseConverter", node_id="CaseConverter",
search_aliases=["text case", "uppercase", "lowercase", "capitalize"], search_aliases=["Case Converter", "text case", "uppercase", "lowercase", "capitalize"],
display_name="Case Converter", display_name="Text Case Converter",
category="utils/string", category="utils/string",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
@ -109,8 +109,8 @@ class StringTrim(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="StringTrim", node_id="StringTrim",
search_aliases=["clean whitespace", "remove whitespace"], search_aliases=["Trim", "clean whitespace", "remove whitespace", "strip"],
display_name="Trim", display_name="Text Trim",
category="utils/string", category="utils/string",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
@ -140,8 +140,8 @@ class StringReplace(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="StringReplace", node_id="StringReplace",
search_aliases=["find and replace", "substitute", "swap text"], search_aliases=["Replace", "find and replace", "substitute", "swap text"],
display_name="Replace", display_name="Text Replace",
category="utils/string", category="utils/string",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
@ -163,8 +163,8 @@ class StringContains(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="StringContains", node_id="StringContains",
search_aliases=["text includes", "string includes"], search_aliases=["Contains", "text includes", "string includes"],
display_name="Contains", display_name="Text Contains",
category="utils/string", category="utils/string",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
@ -191,8 +191,8 @@ class StringCompare(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="StringCompare", node_id="StringCompare",
search_aliases=["text match", "string equals", "starts with", "ends with"], search_aliases=["Compare", "text match", "string equals", "starts with", "ends with"],
display_name="Compare", display_name="Text Compare",
category="utils/string", category="utils/string",
inputs=[ inputs=[
io.String.Input("string_a", multiline=True), io.String.Input("string_a", multiline=True),
@ -227,8 +227,8 @@ class RegexMatch(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="RegexMatch", node_id="RegexMatch",
search_aliases=["pattern match", "text contains", "string match"], search_aliases=["Regex Match", "regex", "pattern match", "text contains", "string match"],
display_name="Regex Match", display_name="Text Match",
category="utils/string", category="utils/string",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
@ -268,8 +268,8 @@ class RegexExtract(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="RegexExtract", node_id="RegexExtract",
search_aliases=["pattern extract", "text parser", "parse text"], search_aliases=["Regex Extract", "regex", "pattern extract", "text parser", "parse text"],
display_name="Regex Extract", display_name="Text Extract Substring",
category="utils/string", category="utils/string",
inputs=[ inputs=[
io.String.Input("string", multiline=True), io.String.Input("string", multiline=True),
@ -343,8 +343,8 @@ class RegexReplace(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="RegexReplace", node_id="RegexReplace",
search_aliases=["pattern replace", "find and replace", "substitution"], search_aliases=["Regex Replace", "regex", "pattern replace", "regex replace", "substitution"],
display_name="Regex Replace", display_name="Text Replace (Regex)",
category="utils/string", category="utils/string",
description="Find and replace text using regex patterns.", description="Find and replace text using regex patterns.",
inputs=[ inputs=[

View File

@ -411,6 +411,19 @@ def format_value(x):
else: else:
return str(x) return str(x)
def _is_intermediate_output(dynprompt, node_id):
class_type = dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
if server.client_id is None:
return
cached_ui = cached.ui or {}
server.send_sync("executed", { "node": node_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id }, server.client_id)
if cached.ui is not None:
ui_outputs[node_id] = cached.ui
async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs): async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
unique_id = current_item unique_id = current_item
real_node_id = dynprompt.get_real_node_id(unique_id) real_node_id = dynprompt.get_real_node_id(unique_id)
@ -421,11 +434,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
cached = await caches.outputs.get(unique_id) cached = await caches.outputs.get(unique_id)
if cached is not None: if cached is not None:
if server.client_id is not None: _send_cached_ui(server, unique_id, display_node_id, cached, prompt_id, ui_outputs)
cached_ui = cached.ui or {}
server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
if cached.ui is not None:
ui_outputs[unique_id] = cached.ui
get_progress_state().finish_progress(unique_id) get_progress_state().finish_progress(unique_id)
execution_list.cache_update(unique_id, cached) execution_list.cache_update(unique_id, cached)
return (ExecutionResult.SUCCESS, None, None) return (ExecutionResult.SUCCESS, None, None)
@ -715,6 +724,9 @@ class PromptExecutor:
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
self._notify_prompt_lifecycle("start", prompt_id) self._notify_prompt_lifecycle("start", prompt_id)
ram_headroom = int(self.cache_args["ram"] * (1024 ** 3))
ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None
comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom)
try: try:
with torch.inference_mode(): with torch.inference_mode():
@ -764,9 +776,22 @@ class PromptExecutor:
execution_list.unstage_node_execution() execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS: else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution() execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
if self.cache_type == CacheType.RAM_PRESSURE:
comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom)
comfy.memory_management.extra_ram_release(ram_headroom)
else: else:
# Only execute when the while-loop ends without break # Only execute when the while-loop ends without break
# Send cached UI for intermediate output nodes that weren't executed
for node_id in dynamic_prompt.all_node_ids():
if node_id in executed:
continue
if not _is_intermediate_output(dynamic_prompt, node_id):
continue
cached = await self.caches.outputs.get(node_id)
if cached is not None:
display_node_id = dynamic_prompt.get_display_node_id(node_id)
_send_cached_ui(self.server, node_id, display_node_id, cached, prompt_id, ui_node_outputs)
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
ui_outputs = {} ui_outputs = {}
@ -782,6 +807,7 @@ class PromptExecutor:
if comfy.model_management.DISABLE_SMART_MEMORY: if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models() comfy.model_management.unload_all_models()
finally: finally:
comfy.memory_management.set_ram_cache_release_state(None, 0)
self._notify_prompt_lifecycle("end", prompt_id) self._notify_prompt_lifecycle("end", prompt_id)

View File

@ -275,15 +275,19 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]:
def prompt_worker(q, server_instance): def prompt_worker(q, server_instance):
current_time: float = 0.0 current_time: float = 0.0
cache_ram = args.cache_ram
if cache_ram < 0:
cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0))
cache_type = execution.CacheType.CLASSIC cache_type = execution.CacheType.CLASSIC
if args.cache_lru > 0: if args.cache_lru > 0:
cache_type = execution.CacheType.LRU cache_type = execution.CacheType.LRU
elif args.cache_ram > 0: elif cache_ram > 0:
cache_type = execution.CacheType.RAM_PRESSURE cache_type = execution.CacheType.RAM_PRESSURE
elif args.cache_none: elif args.cache_none:
cache_type = execution.CacheType.NONE cache_type = execution.CacheType.NONE
e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } ) e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } )
last_gc_collect = 0 last_gc_collect = 0
need_gc = False need_gc = False
gc_collect_interval = 10.0 gc_collect_interval = 10.0

View File

@ -2457,6 +2457,7 @@ async def init_builtin_extra_nodes():
"nodes_number_convert.py", "nodes_number_convert.py",
"nodes_painter.py", "nodes_painter.py",
"nodes_curve.py", "nodes_curve.py",
"nodes_rtdetr.py"
] ]
import_failed = [] import_failed = []

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.42.8 comfyui-frontend-package==1.42.8
comfyui-workflow-templates==0.9.39 comfyui-workflow-templates==0.9.41
comfyui-embedded-docs==0.4.3 comfyui-embedded-docs==0.4.3
torch torch
torchsde torchsde

View File

@ -709,6 +709,11 @@ class PromptServer():
else: else:
info['output_node'] = False info['output_node'] = False
if hasattr(obj_class, 'HAS_INTERMEDIATE_OUTPUT') and obj_class.HAS_INTERMEDIATE_OUTPUT == True:
info['has_intermediate_output'] = True
else:
info['has_intermediate_output'] = False
if hasattr(obj_class, 'CATEGORY'): if hasattr(obj_class, 'CATEGORY'):
info['category'] = obj_class.CATEGORY info['category'] = obj_class.CATEGORY