From 6887165a9d657ced4f0122c0ca5368dc74125d80 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Tue, 19 May 2026 16:55:04 -0700 Subject: [PATCH 01/45] docs(openapi): tighten workspace API key description field (BE-1004) (#13996) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Aligns the OSS spec with the cloud-side BE-1004 contract: - createWorkspaceApiKey request body: add maxLength: 5000 to the description property (matches cloud's hub_profile.description MaxLen(5000) convention; enforced cloud-side via handler check). - WorkspaceApiKey + WorkspaceApiKeyCreated response schemas: mark description as required (cloud's handler always populates the field, defaulting to empty string when not supplied on create), drop nullable: true, add maxLength: 5000 for symmetry, and clarify the doc string ("Always present in responses; empty string when no description was supplied on create"). Both schemas are tagged x-runtime: [cloud] at the schema level so the tightening is correctly scoped — OSS-only implementations are not required to honor the workspace API keys endpoints at all. Related cloud PR: Comfy-Org/cloud#3747 --- openapi.yaml | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index bc1ae16fa..2658b9b86 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -4162,7 +4162,8 @@ paths: description: Display name for the API key description: type: string - description: User-provided description for the key + description: User-provided description of the key's purpose + maxLength: 5000 responses: "201": description: API key created @@ -7680,6 +7681,7 @@ components: required: - id - name + - description properties: id: type: string @@ -7687,8 +7689,8 @@ components: type: string description: type: string - nullable: true - description: User-provided description + maxLength: 5000 + description: User-provided description of the key's purpose. Always present in responses; empty string when no description was supplied on create. prefix: type: string description: First few characters of the key for identification @@ -7709,6 +7711,7 @@ components: required: - id - name + - description - key properties: id: @@ -7717,8 +7720,8 @@ components: type: string description: type: string - nullable: true - description: User-provided description + maxLength: 5000 + description: User-provided description of the key's purpose. Always present in responses; empty string when no description was supplied on create. key: type: string description: Full API key value (only returned on creation) From 7ec7b6ffe93bb47d70c5fa1b702e387e4d545dae Mon Sep 17 00:00:00 2001 From: Pauan Date: Tue, 19 May 2026 19:25:49 -0700 Subject: [PATCH 02/45] Adding new StringFormat node (#13997) --- comfy_extras/nodes_string.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/comfy_extras/nodes_string.py b/comfy_extras/nodes_string.py index 925a40da8..97485c8c5 100644 --- a/comfy_extras/nodes_string.py +++ b/comfy_extras/nodes_string.py @@ -1,10 +1,41 @@ import re import json +import string from typing_extensions import override from comfy_api.latest import ComfyExtension, io +class StringFormat(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + autogrow = io.Autogrow.TemplateNames( + input=io.AnyType.Input("value"), + names=list(string.ascii_lowercase), + min=0, + ) + return io.Schema( + node_id="StringFormat", + display_name="Format Text", + category="text", + search_aliases=["string", "format"], + description="Same as Python's string format method. Supports all of Python's format options and features.", + inputs=[ + io.Autogrow.Input("values", template=autogrow), + io.String.Input("f_string", default="{a}", multiline=True), + ], + outputs=[ + io.String.Output(), + ], + ) + + @classmethod + def execute( + cls, values: io.Autogrow.Type, f_string: str + ) -> io.NodeOutput: + return io.NodeOutput(f_string.format(**values)) + + class StringConcatenate(io.ComfyNode): @classmethod def define_schema(cls): @@ -413,6 +444,7 @@ class StringExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ + StringFormat, StringConcatenate, StringSubstring, StringLength, From 72e3f6081ccf8853baede1308f16e0e9ebcc09dc Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 19 May 2026 20:28:06 -0700 Subject: [PATCH 03/45] Add downscale ratio to empty ltxv latent. (#13999) --- comfy_extras/nodes_lt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 675de4f81..51cf7951f 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -77,7 +77,7 @@ class EmptyLTXVLatentVideo(io.ComfyNode): @classmethod def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device()) - return io.NodeOutput({"samples": latent}) + return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 32}) generate = execute # TODO: remove From 78b5dec6b6beefb9fb40f917d33d2f10a40d9e53 Mon Sep 17 00:00:00 2001 From: Cezarijus Kivylius Date: Wed, 20 May 2026 12:58:49 +0100 Subject: [PATCH 04/45] fix: Hunyuan3D 2.1 batch size crashes in attention and forward pass (#13699) --- comfy/ldm/hunyuan3dv2_1/hunyuandit.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index f67ba84e9..bc36b8998 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -328,7 +328,7 @@ class CrossAttention(nn.Module): kv = torch.cat((k, v), dim=-1) split_size = kv.shape[-1] // self.num_heads // 2 - kv = kv.view(1, -1, self.num_heads, split_size * 2) + kv = kv.view(b, -1, self.num_heads, split_size * 2) k, v = torch.split(kv, split_size, dim=-1) q = q.view(b, s1, self.num_heads, self.head_dim) @@ -398,7 +398,7 @@ class Attention(nn.Module): qkv_combined = torch.cat((query, key, value), dim=-1) split_size = qkv_combined.shape[-1] // self.num_heads // 3 - qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3) + qkv = qkv_combined.view(B, -1, self.num_heads, split_size * 3) query, key, value = torch.split(qkv, split_size, dim=-1) query = query.reshape(B, N, self.num_heads, self.head_dim) @@ -607,9 +607,9 @@ class HunYuanDiTPlain(nn.Module): def forward(self, x, t, context, transformer_options = {}, **kwargs): x = x.movedim(-1, -2) - uncond_emb, cond_emb = context.chunk(2, dim = 0) - - context = torch.cat([cond_emb, uncond_emb], dim = 0) + if context.shape[0] >= 2: + uncond_emb, cond_emb = context.chunk(2, dim = 0) + context = torch.cat([cond_emb, uncond_emb], dim = 0) main_condition = context t = 1.0 - t @@ -657,5 +657,8 @@ class HunYuanDiTPlain(nn.Module): output = self.final_layer(combined) output = output.movedim(-2, -1) * (-1.0) - cond_emb, uncond_emb = output.chunk(2, dim = 0) - return torch.cat([uncond_emb, cond_emb]) + if output.shape[0] >= 2: + cond_emb, uncond_emb = output.chunk(2, dim = 0) + return torch.cat([uncond_emb, cond_emb]) + else: + return output From f9c84c94b43855ec46e26afc944c06fb121cb9a8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 May 2026 08:34:22 -0700 Subject: [PATCH 05/45] Support Stable Audio 3 model. (#14010) --- comfy/latent_formats.py | 5 + comfy/ldm/audio/dit.py | 250 +++++++++++++--- comfy/ldm/audio/embedders.py | 31 +- comfy/ldm/audio/vae_sa3.py | 533 +++++++++++++++++++++++++++++++++++ comfy/model_base.py | 79 ++++++ comfy/model_detection.py | 39 +++ comfy/sd.py | 37 +++ comfy/supported_models.py | 25 ++ comfy/text_encoders/sa3.py | 207 ++++++++++++++ 9 files changed, 1161 insertions(+), 45 deletions(-) create mode 100644 comfy/ldm/audio/vae_sa3.py create mode 100644 comfy/text_encoders/sa3.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 6e37080bb..75d459b59 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -152,6 +152,11 @@ class StableAudio1(LatentFormat): latent_dimensions = 1 temporal_downscale_ratio = 2048 +class StableAudio3(LatentFormat): + latent_channels = 256 + latent_dimensions = 1 + temporal_downscale_ratio = 4096 + class Flux(SD3): latent_channels = 16 def __init__(self): diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index ca865189e..a6258b755 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -10,6 +10,17 @@ from torch import nn from torch.nn import functional as F import math import comfy.ops +from .embedders import ExpoFourierFeatures + + +def _left_pad_to_match(emb, target_len): + emb_len = emb.shape[-2] + if emb_len < target_len: + return F.pad(emb, (0, 0, target_len - emb_len, 0), value=0.) + elif emb_len > target_len: + return emb[:, -target_len:, :] + return emb + class FourierFeatures(nn.Module): def __init__(self, in_features, out_features, std=1., dtype=None, device=None): @@ -22,6 +33,7 @@ class FourierFeatures(nn.Module): f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input) return torch.cat([f.cos(), f.sin()], dim=-1) + # norms class LayerNorm(nn.Module): def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None): @@ -43,6 +55,16 @@ class LayerNorm(nn.Module): beta = comfy.ops.cast_to_input(beta, x) return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta) + +class RMSNorm(nn.Module): + def __init__(self, dim, dtype=None, device=None): + super().__init__() + self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) + + def forward(self, x): + return F.rms_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x)) + + class GLU(nn.Module): def __init__( self, @@ -236,13 +258,6 @@ class FeedForward(nn.Module): linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device) - # # init last linear layer to 0 - # if zero_init_output: - # nn.init.zeros_(linear_out.weight) - # if not no_bias: - # nn.init.zeros_(linear_out.bias) - - self.ff = nn.Sequential( linear_in, rearrange('b d n -> b n d') if use_conv else nn.Identity(), @@ -261,8 +276,10 @@ class Attention(nn.Module): dim_context = None, causal = False, zero_init_output=True, - qk_norm = False, + qk_norm = "none", + differential = False, natten_kernel_size = None, + feat_scale = False, dtype=None, device=None, operations=None, @@ -271,6 +288,7 @@ class Attention(nn.Module): self.dim = dim self.dim_heads = dim_heads self.causal = causal + self.differential = differential dim_kv = dim_context if dim_context is not None else dim @@ -278,18 +296,37 @@ class Attention(nn.Module): self.kv_heads = dim_kv // dim_heads if dim_context is not None: - self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) - self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device) + if differential: + self.to_q = operations.Linear(dim, dim * 2, bias=False, dtype=dtype, device=device) + self.to_kv = operations.Linear(dim_kv, dim_kv * 3, bias=False, dtype=dtype, device=device) + else: + self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device) else: - self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device) + if differential: + self.to_qkv = operations.Linear(dim, dim * 5, bias=False, dtype=dtype, device=device) + else: + self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device) self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) - # if zero_init_output: - # nn.init.zeros_(self.to_out.weight) - + # Accept bool for backward compat + if isinstance(qk_norm, bool): + qk_norm = "l2" if qk_norm else "none" self.qk_norm = qk_norm + if self.qk_norm == "ln": + self.q_norm = operations.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device) + self.k_norm = operations.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6, dtype=dtype, device=device) + elif self.qk_norm == "rms": + self.q_norm = RMSNorm(dim_heads, dtype=dtype, device=device) + self.k_norm = RMSNorm(dim_heads, dtype=dtype, device=device) + + self.feat_scale = feat_scale + + if self.feat_scale: + self.lambda_dc = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) + self.lambda_hf = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) def forward( self, @@ -306,22 +343,51 @@ class Attention(nn.Module): kv_input = context if has_context else x if hasattr(self, 'to_q'): - # Use separate linear projections for q and k/v - q = self.to_q(x) - q = rearrange(q, 'b n (h d) -> b h n d', h = h) + if self.differential: + # cross-attention differential: to_q → (q, q_diff), to_kv → (k, k_diff, v) + q, q_diff = self.to_q(x).chunk(2, dim=-1) + q = rearrange(q, 'b n (h d) -> b h n d', h=h) + q_diff = rearrange(q_diff, 'b n (h d) -> b h n d', h=h) + q = torch.stack([q, q_diff], dim=1) # (B, 2, H, N, D) + k, k_diff, v = self.to_kv(kv_input).chunk(3, dim=-1) + k = rearrange(k, 'b n (h d) -> b h n d', h=kv_h) + k_diff = rearrange(k_diff, 'b n (h d) -> b h n d', h=kv_h) + v = rearrange(v, 'b n (h d) -> b h n d', h=kv_h) + k = torch.stack([k, k_diff], dim=1) # (B, 2, H, M, D) + else: + # Use separate linear projections for q and k/v + q = self.to_q(x) + q = rearrange(q, 'b n (h d) -> b h n d', h = h) - k, v = self.to_kv(kv_input).chunk(2, dim=-1) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) - k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) + k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v)) else: - # Use fused linear projection - q, k, v = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) + if self.differential: + # self-attention differential: to_qkv → (q, k, v, q_diff, k_diff) + q, k, v, q_diff, k_diff = self.to_qkv(x).chunk(5, dim=-1) + q, k, v, q_diff, k_diff = map( + lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), + (q, k, v, q_diff, k_diff) + ) + q = torch.stack([q, q_diff], dim=1) # (B, 2, H, N, D) + k = torch.stack([k, k_diff], dim=1) + else: + # Use fused linear projection + q, k, v = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) # Normalize q and k for cosine sim attention - if self.qk_norm: + if self.qk_norm == "l2": q = F.normalize(q, dim=-1) k = F.normalize(k, dim=-1) + elif self.qk_norm == "rms": + q_type, k_type = q.dtype, k.dtype + q = self.q_norm(q).to(q_type) + k = self.k_norm(k).to(k_type) + elif self.qk_norm != 'none': + q = self.q_norm(q) + k = self.k_norm(k) if rotary_pos_emb is not None and not has_context: freqs, _ = rotary_pos_emb @@ -364,9 +430,24 @@ class Attention(nn.Module): heads_per_kv_head = h // kv_h k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v)) - out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) + if self.differential: + q, q_diff = q.unbind(dim=1) + k, k_diff = k.unbind(dim=1) + out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) + out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options) + out = out - out_diff + else: + out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) + out = self.to_out(out) + if self.feat_scale: + out_dc = out.mean(dim=-2, keepdim=True) + out_hf = out - out_dc + + # Selectively modulate DC and high frequency components + out = out + comfy.ops.cast_to_input(self.lambda_dc, out) * out_dc + comfy.ops.cast_to_input(self.lambda_hf, out) * out_hf + if mask is not None: mask = rearrange(mask, 'b n -> b n 1') out = out.masked_fill(~mask, 0.) @@ -417,11 +498,14 @@ class TransformerBlock(nn.Module): cross_attend = False, dim_context = None, global_cond_dim = None, + global_cond_shared_embed = False, + local_add_cond_dim = None, causal = False, zero_init_branch_outputs = True, conformer = False, layer_ix = -1, remove_norms = False, + norm_type = "layer_norm", attn_kwargs = {}, ff_kwargs = {}, norm_kwargs = {}, @@ -436,8 +520,20 @@ class TransformerBlock(nn.Module): self.cross_attend = cross_attend self.dim_context = dim_context self.causal = causal + self.global_cond_shared_embed = global_cond_shared_embed - self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity() + norm_layer_map = { + "layer_norm": LayerNorm, + "rms_norm": RMSNorm, + } + norm_cls = norm_layer_map.get(norm_type, LayerNorm) + + def make_norm(): + if remove_norms: + return nn.Identity() + return norm_cls(dim, dtype=dtype, device=device, **norm_kwargs) + + self.pre_norm = make_norm() self.self_attn = Attention( dim, @@ -451,7 +547,7 @@ class TransformerBlock(nn.Module): ) if cross_attend: - self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity() + self.cross_attend_norm = make_norm() self.cross_attn = Attention( dim, dim_heads = dim_heads, @@ -464,37 +560,56 @@ class TransformerBlock(nn.Module): **attn_kwargs ) - self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity() - self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs) + self.ff_norm = make_norm() + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations, **ff_kwargs) self.layer_ix = layer_ix self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None - self.global_cond_dim = global_cond_dim + # Global conditioning + self.has_global_cond = (global_cond_dim is not None) or global_cond_shared_embed - if global_cond_dim is not None: + if global_cond_shared_embed: + # SA3 style: learnable per-block additive bias; global_cond is pre-projected to (B, dim*6) + self.to_scale_shift_gate = nn.Parameter(torch.empty(dim * 6, device=device, dtype=dtype)) + elif global_cond_dim is not None: + # SA1 style: per-block MLP projects global_cond → (B, dim*6) self.to_scale_shift_gate = nn.Sequential( nn.SiLU(), - nn.Linear(global_cond_dim, dim * 6, bias=False) + operations.Linear(global_cond_dim, dim * 6, bias=False, device=device, dtype=dtype) ) - nn.init.zeros_(self.to_scale_shift_gate[1].weight) - #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias) + # Local additive conditioning (e.g. inpaint mask + masked latent) + self.local_add_cond_dim = local_add_cond_dim + if local_add_cond_dim is not None: + self.to_local_embed = nn.Sequential( + operations.Linear(local_add_cond_dim, dim, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(dim, dim, bias=True, dtype=dtype, device=device), + ) + else: + self.to_local_embed = None def forward( self, x, context = None, global_cond=None, + local_add_cond=None, mask = None, context_mask = None, rotary_pos_emb = None, transformer_options={} ): - if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None: + if self.has_global_cond and global_cond is not None: + if self.global_cond_shared_embed: + # global_cond already has shape (B, dim*6) + ssg = (comfy.ops.cast_to_input(self.to_scale_shift_gate, global_cond) + global_cond).unsqueeze(1) + else: + ssg = self.to_scale_shift_gate(global_cond).unsqueeze(1) - scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1) + scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = ssg.chunk(6, dim = -1) # self-attention with adaLN residual = x @@ -510,6 +625,9 @@ class TransformerBlock(nn.Module): if self.conformer is not None: x = x + self.conformer(x) + if local_add_cond is not None and self.to_local_embed is not None: + x = x + _left_pad_to_match(self.to_local_embed(local_add_cond), x.shape[-2]) + # feedforward with adaLN residual = x x = self.ff_norm(x) @@ -527,6 +645,9 @@ class TransformerBlock(nn.Module): if self.conformer is not None: x = x + self.conformer(x) + if local_add_cond is not None and self.to_local_embed is not None: + x = x + _left_pad_to_match(self.to_local_embed(local_add_cond), x.shape[-2]) + x = x + self.ff(self.ff_norm(x)) return x @@ -543,6 +664,8 @@ class ContinuousTransformer(nn.Module): cross_attend=False, cond_token_dim=None, global_cond_dim=None, + global_cond_shared_embed=False, + local_add_cond_dim=None, causal=False, rotary_pos_emb=True, zero_init_branch_outputs=True, @@ -550,6 +673,7 @@ class ContinuousTransformer(nn.Module): use_sinusoidal_emb=False, use_abs_pos_emb=False, abs_pos_emb_max_length=10000, + num_memory_tokens=0, dtype=None, device=None, operations=None, @@ -562,6 +686,8 @@ class ContinuousTransformer(nn.Module): self.depth = depth self.causal = causal self.layers = nn.ModuleList([]) + self.num_memory_tokens = num_memory_tokens + self.global_cond_shared_embed = global_cond_shared_embed self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity() self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity() @@ -577,7 +703,22 @@ class ContinuousTransformer(nn.Module): self.use_abs_pos_emb = use_abs_pos_emb if use_abs_pos_emb: - self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length) + self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length + num_memory_tokens) + + if num_memory_tokens > 0: + self.memory_tokens = nn.Parameter(torch.empty(num_memory_tokens, dim, device=device, dtype=dtype)) + + # Shared global-cond embedder (SA3 style): projects (B, global_cond_dim) → (B, dim*6) + self.global_cond_embedder = None + if global_cond_shared_embed and global_cond_dim is not None: + self.global_cond_embedder = nn.Sequential( + operations.Linear(global_cond_dim, dim, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(dim, dim * 6, bias=True, dtype=dtype, device=device), + ) + + # When using shared embed, TransformerBlocks use per-block Parameter (not per-block MLP) + block_global_cond_dim = None if global_cond_shared_embed else global_cond_dim for i in range(depth): self.layers.append( @@ -586,7 +727,9 @@ class ContinuousTransformer(nn.Module): dim_heads = dim_heads, cross_attend = cross_attend, dim_context = cond_token_dim, - global_cond_dim = global_cond_dim, + global_cond_dim = block_global_cond_dim, + global_cond_shared_embed = global_cond_shared_embed, + local_add_cond_dim = local_add_cond_dim, causal = causal, zero_init_branch_outputs = zero_init_branch_outputs, conformer=conformer, @@ -605,6 +748,7 @@ class ContinuousTransformer(nn.Module): prepend_embeds = None, prepend_mask = None, global_cond = None, + local_add_cond = None, return_info = False, **kwargs ): @@ -632,7 +776,9 @@ class ContinuousTransformer(nn.Module): mask = torch.cat((prepend_mask, mask), dim = -1) - # Attention layers + if self.num_memory_tokens > 0: + memory_tokens = comfy.ops.cast_to_input(self.memory_tokens, x).expand(batch, -1, -1) + x = torch.cat((memory_tokens, x), dim=1) if self.rotary_pos_emb is not None: rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=torch.float, device=x.device) @@ -642,6 +788,10 @@ class ContinuousTransformer(nn.Module): if self.use_sinusoidal_emb or self.use_abs_pos_emb: x = x + self.pos_emb(x) + # Project global_cond once (SA3 shared-embed path) + if global_cond is not None and self.global_cond_embedder is not None: + global_cond = self.global_cond_embedder(global_cond) + blocks_replace = patches_replace.get("dit", {}) # Iterate over the transformer layers for i, layer in enumerate(self.layers): @@ -654,12 +804,17 @@ class ContinuousTransformer(nn.Module): out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": global_cond, "pe": rotary_pos_emb, "transformer_options": transformer_options}, {"original_block": block_wrap}) x = out["img"] else: - x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, context=context, transformer_options=transformer_options) - # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs) + x = layer(x, rotary_pos_emb=rotary_pos_emb, global_cond=global_cond, + local_add_cond=local_add_cond, context=context, + transformer_options=transformer_options) if return_info: info["hidden_states"].append(x) + # Strip memory tokens before projecting out + if self.num_memory_tokens > 0: + x = x[:, self.num_memory_tokens:, :] + x = self.project_out(x) if return_info: @@ -682,6 +837,7 @@ class AudioDiffusionTransformer(nn.Module): num_heads=24, transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer", global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend", + timestep_features_type: str = "learned", audio_model="", dtype=None, device=None, @@ -696,7 +852,10 @@ class AudioDiffusionTransformer(nn.Module): # Timestep embeddings timestep_features_dim = 256 - self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device) + if timestep_features_type == "expo": + self.timestep_features = ExpoFourierFeatures(timestep_features_dim, 0.5, 10000.0) + else: + self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device) self.to_timestep_embed = nn.Sequential( operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device), @@ -781,6 +940,7 @@ class AudioDiffusionTransformer(nn.Module): cross_attn_cond=None, cross_attn_cond_mask=None, input_concat_cond=None, + local_add_cond=None, global_embed=None, prepend_cond=None, prepend_cond_mask=None, @@ -802,9 +962,13 @@ class AudioDiffusionTransformer(nn.Module): prepend_cond = self.to_prepend_embed(prepend_cond) prepend_inputs = prepend_cond + prepend_length = prepend_cond.shape[1] if prepend_cond_mask is not None: prepend_mask = prepend_cond_mask + if local_add_cond is not None and local_add_cond.dim() == 3: + local_add_cond = local_add_cond.permute(0, 2, 1) + if input_concat_cond is not None: # Interpolate input_concat_cond to the same length as x @@ -850,7 +1014,7 @@ class AudioDiffusionTransformer(nn.Module): if self.transformer_type == "x-transformers": output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs) elif self.transformer_type == "continuous_transformer": - output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs) + output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, local_add_cond=local_add_cond, **extra_args, **kwargs) if return_info: output, info = output @@ -876,6 +1040,7 @@ class AudioDiffusionTransformer(nn.Module): context=None, context_mask=None, input_concat_cond=None, + local_add_cond=None, global_embed=None, negative_global_embed=None, prepend_cond=None, @@ -890,6 +1055,7 @@ class AudioDiffusionTransformer(nn.Module): cross_attn_cond=context, cross_attn_cond_mask=context_mask, input_concat_cond=input_concat_cond, + local_add_cond=local_add_cond, global_embed=global_embed, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, diff --git a/comfy/ldm/audio/embedders.py b/comfy/ldm/audio/embedders.py index 20edb365a..ba9a62837 100644 --- a/comfy/ldm/audio/embedders.py +++ b/comfy/ldm/audio/embedders.py @@ -31,15 +31,39 @@ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module: ) +class ExpoFourierFeatures(nn.Module): + """Exponentially-spaced Fourier features (no learnable parameters).""" + def __init__(self, dim, min_freq=0.5, max_freq=10000.0): + super().__init__() + self.dim = dim + self.min_freq = min_freq + self.max_freq = max_freq + + def forward(self, t): + in_dtype = t.dtype + t = t.float() + if t.dim() == 1: + t = t.unsqueeze(-1) + half_dim = self.dim // 2 + ramp = torch.linspace(0, 1, half_dim, device=t.device, dtype=torch.float32) + freqs = torch.exp(ramp * (math.log(self.max_freq) - math.log(self.min_freq)) + math.log(self.min_freq)) + args = t * freqs * 2 * math.pi + return torch.cat([args.cos(), args.sin()], dim=-1).to(in_dtype) + + class NumberEmbedder(nn.Module): def __init__( self, features: int, dim: int = 256, + fourier_features_type="learned", ): super().__init__() self.features = features - self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) + if fourier_features_type == "expo": + self.embedding = nn.Sequential(ExpoFourierFeatures(dim=dim), comfy.ops.manual_cast.Linear(in_features=dim, out_features=features)) + else: + self.embedding = TimePositionalEmbedding(dim=dim, out_features=features) def forward(self, x: Union[List[float], Tensor]) -> Tensor: if not torch.is_tensor(x): @@ -77,14 +101,15 @@ class NumberConditioner(Conditioner): def __init__(self, output_dim: int, min_val: float=0, - max_val: float=1 + max_val: float=1, + fourier_features_type: str = "learned", ): super().__init__(output_dim, output_dim) self.min_val = min_val self.max_val = max_val - self.embedder = NumberEmbedder(features=output_dim) + self.embedder = NumberEmbedder(features=output_dim, fourier_features_type=fourier_features_type) def forward(self, floats, device=None): # Cast the inputs to floats diff --git a/comfy/ldm/audio/vae_sa3.py b/comfy/ldm/audio/vae_sa3.py new file mode 100644 index 000000000..276846444 --- /dev/null +++ b/comfy/ldm/audio/vae_sa3.py @@ -0,0 +1,533 @@ +import torch +import torch.nn as nn + +import comfy.ops +import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.audio.autoencoder import WNConv1d + +ops = comfy.ops.disable_weight_init + +class Transpose(nn.Module): + def forward(self, x, **kwargs): + return x.transpose(-2, -1) + + +def _zero_pad_modulo_sequence(x, size, dim=-2): + input_len = x.shape[dim] + pad_len = (size - input_len % size) % size + if pad_len > 0: + pad_shape = list(x.shape) + pad_shape[dim] = pad_len + x = torch.cat([x, torch.zeros(pad_shape, device=x.device, dtype=x.dtype)], dim=dim) + return x + + +def _sliding_window_mask(seq_len, window, device, dtype): + """Additive attention mask enforcing a ±window local window (matches flash_attn window_size).""" + i = torch.arange(seq_len, device=device).unsqueeze(1) + j = torch.arange(seq_len, device=device).unsqueeze(0) + out_of_window = (j - i).abs() > window + return torch.where( + out_of_window, + torch.full((1,), torch.finfo(dtype).min / 4, device=device, dtype=dtype), + torch.zeros(1, device=device, dtype=dtype), + ) + + +class DynamicTanh(nn.Module): + def __init__(self, dim, init_alpha=4.0, dtype=None, device=None, **kwargs): + super().__init__() + self.alpha = nn.Parameter(torch.empty(1, dtype=dtype, device=device)) + self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) + self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device)) + + def forward(self, x): + alpha = comfy.ops.cast_to_input(self.alpha, x) + gamma = comfy.ops.cast_to_input(self.gamma, x) + beta = comfy.ops.cast_to_input(self.beta, x) + return gamma * torch.tanh(alpha * x) + beta + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, base=10000, base_rescale_factor=1., dtype=None, device=None): + super().__init__() + base = base * base_rescale_factor ** (dim / (dim - 2)) + self.register_buffer("inv_freq", torch.empty(dim // 2, dtype=dtype, device=device)) + + def forward_from_seq_len(self, seq_len, device, dtype=None): + t = torch.arange(seq_len, device=device, dtype=torch.float32) + return self.forward(t) + + def forward(self, t): + freqs = torch.outer(t.float(), comfy.model_management.cast_to(self.inv_freq, dtype=torch.float32, device=t.device)) + freqs = torch.cat((freqs, freqs), dim=-1) + return freqs, 1. + + +def _rotate_half(x): + d = x.shape[-1] // 2 + return torch.cat((-x[..., d:], x[..., :d]), dim=-1) + + +def _apply_rotary_pos_emb(t, freqs): + out_dtype = t.dtype + rot_dim = freqs.shape[-1] + seq_len = t.shape[-2] + freqs = freqs[-seq_len:] + t_rot, t_pass = t[..., :rot_dim], t[..., rot_dim:] + t_rot = t_rot * freqs.cos() + _rotate_half(t_rot) * freqs.sin() + return torch.cat((t_rot.to(out_dtype), t_pass.to(out_dtype)), dim=-1) + + +class Attention(nn.Module): + def __init__(self, dim, dim_heads=64, qk_norm="none", qk_norm_eps=1e-6, + differential=False, zero_init_output=True, + dtype=None, device=None, operations=None, **kwargs): + super().__init__() + self.num_heads = dim // dim_heads + self.differential = differential + self.qk_norm = qk_norm + + self.to_qkv = operations.Linear( + dim, dim * (5 if differential else 3), bias=False, dtype=dtype, device=device) + self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device) + + if qk_norm == "dyt": + self.q_norm = DynamicTanh(dim_heads, dtype=dtype, device=device) + self.k_norm = DynamicTanh(dim_heads, dtype=dtype, device=device) + elif qk_norm == "rms": + self.q_norm = operations.RMSNorm(dim_heads, eps=qk_norm_eps, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(dim_heads, eps=qk_norm_eps, dtype=dtype, device=device) + + def forward(self, x, rotary_pos_emb=None, mask=None, **kwargs): + B, N, _ = x.shape + h = self.num_heads + + qkv = self.to_qkv(x) + if self.differential: + q, k, v, q_diff, k_diff = qkv.chunk(5, dim=-1) + del qkv + q = q.view(B, N, h, -1).transpose(1, 2) + k = k.view(B, N, h, -1).transpose(1, 2) + v = v.view(B, N, h, -1).transpose(1, 2) + q_diff = q_diff.view(B, N, h, -1).transpose(1, 2) + k_diff = k_diff.view(B, N, h, -1).transpose(1, 2) + else: + q, k, v = qkv.chunk(3, dim=-1) + del qkv + q = q.view(B, N, h, -1).transpose(1, 2) + k = k.view(B, N, h, -1).transpose(1, 2) + v = v.view(B, N, h, -1).transpose(1, 2) + + if self.qk_norm != "none": + q_dtype, k_dtype = q.dtype, k.dtype + q = self.q_norm(q).to(q_dtype) + k = self.k_norm(k).to(k_dtype) + if self.differential: + q_diff = self.q_norm(q_diff).to(q_dtype) + k_diff = self.k_norm(k_diff).to(k_dtype) + + if rotary_pos_emb is not None: + freqs, _ = rotary_pos_emb + q_dtype, k_dtype = q.dtype, k.dtype + q = _apply_rotary_pos_emb(q.float(), freqs).to(q_dtype) + k = _apply_rotary_pos_emb(k.float(), freqs).to(k_dtype) + if self.differential: + q_diff = _apply_rotary_pos_emb(q_diff.float(), freqs).to(q_dtype) + k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype) + + if self.differential: + out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) + - optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True)) + del q, k, v, q_diff, k_diff + else: + out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) + del q, k, v + + return self.to_out(out) + + +class _Sin(nn.Module): + def forward(self, x): + return torch.sin(3.14159265359 * x) + + +class _GLU(nn.Module): + def __init__(self, dim_in, dim_out, activation, dtype=None, device=None, operations=None): + super().__init__() + self.act = activation + self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) + + def forward(self, x): + x = self.proj(x) + x, gate = x.chunk(2, dim=-1) + return x * self.act(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, mult=4, no_bias=False, zero_init_output=True, + sinusoidal=False, dtype=None, device=None, operations=None, **kwargs): + super().__init__() + inner_dim = int(dim * mult) + act = _Sin() if sinusoidal else nn.SiLU() + self.ff = nn.Sequential( + _GLU(dim, inner_dim, act, dtype=dtype, device=device, operations=operations), + nn.Identity(), + operations.Linear(inner_dim, dim, bias=not no_bias, dtype=dtype, device=device), + nn.Identity(), + ) + + def forward(self, x, **kwargs): + return self.ff(x) + + +class TransformerBlock(nn.Module): + def __init__(self, dim, dim_heads=64, causal=False, zero_init_branch_outputs=True, + norm_type="dyt", add_rope=False, attn_kwargs=None, ff_kwargs=None, + norm_kwargs=None, dtype=None, device=None, operations=None, **kwargs): + super().__init__() + if attn_kwargs is None: + attn_kwargs = {} + if ff_kwargs is None: + ff_kwargs = {} + if norm_kwargs is None: + norm_kwargs = {} + dim_heads = min(dim_heads, dim) + + Norm = DynamicTanh if norm_type == "dyt" else operations.RMSNorm + norm_kw = {**norm_kwargs, "dtype": dtype, "device": device} + + self.pre_norm = Norm(dim, **norm_kw) + self.self_attn = Attention(dim, dim_heads=dim_heads, + zero_init_output=zero_init_branch_outputs, + dtype=dtype, device=device, operations=operations, + **attn_kwargs) + self.ff_norm = Norm(dim, **norm_kw) + self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, + dtype=dtype, device=device, operations=operations, **ff_kwargs) + self.rope = RotaryEmbedding(dim_heads // 2, dtype=dtype, device=device) if add_rope else None + + def forward(self, x, mask=None, **kwargs): + rope = self.rope.forward_from_seq_len(x.shape[-2], device=x.device) \ + if self.rope is not None else None + x = x + self.self_attn(self.pre_norm(x), rotary_pos_emb=rope, mask=mask) + x = x + self.ff(self.ff_norm(x)) + return x + + +class TransformerResamplingBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, type="encoder", + transformer_depth=3, dim_heads=128, differential=True, + sliding_window=None, chunk_size=128, chunk_midpoint_shift=False, + dyt=True, ff_mult=3, mapping_bias=True, variable_stride=False, + sinusoidal_blocks=0, conv_mapping=False, dtype=None, device=None, operations=None, **kwargs): + super().__init__() + if type not in ("encoder", "decoder"): + raise ValueError(f"type must be 'encoder' or 'decoder', got {type!r}") + + self.type = type + self.stride = stride + self.chunk_size = chunk_size + self.chunk_midpoint_shift = chunk_midpoint_shift + self.variable_stride = variable_stride + self.transformer_depth = transformer_depth + + transformer_dim = out_channels if type == "encoder" else in_channels + + self.mapping = (WNConv1d(in_channels, out_channels, 3 if conv_mapping else 1, padding="same", bias=mapping_bias) + if in_channels != out_channels else nn.Identity()) + + self.sliding_window_latents = sliding_window + self.sliding_window_seq = self._get_sliding_window_size(sliding_window, stride) + self.input_seg_size, self.output_seg_size, self.sub_chunk_size = self._get_seg_sizes(stride) + + token_seq = 1 if variable_stride else self.output_seg_size + self.new_tokens = nn.Parameter(torch.empty(1, token_seq, transformer_dim, dtype=dtype, device=device)) + + norm_type = "dyt" if dyt else "rms_norm" + attn_kwargs = {"qk_norm": "dyt" if dyt else "rms", "qk_norm_eps": 1e-3, + "differential": differential} + norm_kwargs = {"eps": 1e-3} + transformers = [] + for i in range(transformer_depth): + sinusoidal = (transformer_depth - i) < sinusoidal_blocks + transformers.append(TransformerBlock( + transformer_dim, + dim_heads=dim_heads, + causal=False, + zero_init_branch_outputs=True, + norm_type=norm_type, + add_rope=True, + attn_kwargs=attn_kwargs, + ff_kwargs={"mult": ff_mult, "no_bias": False, "sinusoidal": sinusoidal}, + norm_kwargs=norm_kwargs, + dtype=dtype, device=device, operations=operations, + )) + self.transformers = nn.ModuleList(transformers) + + def _get_sliding_window_size(self, window, stride, prepend_cond_length=0): + if window is None: + return None + return [w * (stride + 1 + prepend_cond_length) for w in window] + + def _get_seg_sizes(self, stride, prepend_cond_length=0): + sub_chunk_size = stride + 1 + prepend_cond_length + input_seg_size = stride if self.type == "encoder" else 1 + output_seg_size = 1 if self.type == "encoder" else stride + return input_seg_size, output_seg_size, sub_chunk_size + + def forward(self, x, stride=None, **kwargs): + B = x.shape[0] + + if stride is None: + input_seg = self.input_seg_size + output_seg = self.output_seg_size + sub_chunk = self.sub_chunk_size + sliding_window = self.sliding_window_seq + else: + input_seg, output_seg, sub_chunk = self._get_seg_sizes(stride) + sliding_window = self._get_sliding_window_size(self.sliding_window_latents, stride) + + if self.type == "encoder": + if self.transformer_depth > 0: + pad_mod = self.chunk_size if sliding_window is None else input_seg + x = _zero_pad_modulo_sequence(x, pad_mod, dim=-1) + x = self.mapping(x) + + if self.transformer_depth > 0: + x = x.permute(0, 2, 1) + + if self.type != "encoder": + pad_mod = 1 if sliding_window is not None else ( + self.chunk_size // (stride if stride is not None else self.stride)) + x = _zero_pad_modulo_sequence(x, pad_mod) + + C = x.shape[2] + x = x.reshape(-1, input_seg, C) + + new_tokens = self.new_tokens.expand(x.shape[0], output_seg, -1) + x = torch.cat([x, comfy.ops.cast_to_input(new_tokens, x)], dim=-2) + del new_tokens + + x = x.reshape(B, -1, C) + + if sliding_window is None: + eff_chunk = self.chunk_size + self.chunk_size // (stride if stride is not None else self.stride) + + if sliding_window is None and self.chunk_midpoint_shift: + split = self.transformer_depth // 2 + shift = eff_chunk // 2 + + x = x.reshape(-1, eff_chunk, C) + for layer in self.transformers[:split]: + x = layer(x) + x = x.reshape(B, -1, C) + + shifted = torch.cat([x[:, :shift, :], x, x[:, -shift:, :]], dim=1) + del x + x = shifted.reshape(-1, eff_chunk, C) + del shifted + for layer in self.transformers[split:]: + x = layer(x) + x = x.reshape(B, -1, C) + x = x[:, shift:-shift, :] + elif sliding_window is None: + x = x.reshape(-1, eff_chunk, C) + for layer in self.transformers: + x = layer(x) + x = x.reshape(B, -1, C) + else: + attn_mask = _sliding_window_mask(x.shape[1], sliding_window[0], x.device, x.dtype) + for layer in self.transformers: + x = layer(x, mask=attn_mask) + + x = x.reshape(-1, sub_chunk, C) + x = x[:, -output_seg:, :] + x = x.reshape(B, -1, C).transpose(1, 2) + + if self.type == "decoder": + x = self.mapping(x) + + return x + + +class SAMEEncoder(nn.Module): + def __init__(self, in_channels=2, channels=128, latent_dim=32, + c_mults=(1, 2, 4, 8), strides=(2, 4, 8, 8), + transformer_depths=(3, 3, 3, 3), + dtype=None, device=None, operations=None, **kwargs): + super().__init__() + channel_dims = [in_channels] + [channels * c for c in c_mults] + layers = [] + for i in range(len(c_mults)): + layers.append(TransformerResamplingBlock( + in_channels=channel_dims[i], out_channels=channel_dims[i + 1], + stride=strides[i], type="encoder", + transformer_depth=transformer_depths[i], + dtype=dtype, device=device, operations=operations, **kwargs)) + layers += [ + Transpose(), + operations.Linear(channel_dims[-1], latent_dim, dtype=dtype, device=device), + Transpose(), + ] + self.layers = nn.ModuleList(layers) + + def forward(self, x, **kwargs): + for layer in self.layers: + x = layer(x) + return x + + +class SAMEDecoder(nn.Module): + def __init__(self, out_channels=2, channels=128, latent_dim=32, + c_mults=(1, 2, 4, 8), strides=(2, 4, 8, 8), + transformer_depths=(3, 3, 3, 3), sinusoidal_blocks=None, + dtype=None, device=None, operations=None, **kwargs): + super().__init__() + if sinusoidal_blocks is None: + sinusoidal_blocks = [0] * len(c_mults) + channel_dims = [out_channels] + [channels * c for c in c_mults] + layers = [ + Transpose(), + operations.Linear(latent_dim, channel_dims[-1], dtype=dtype, device=device), + Transpose(), + ] + for i in range(len(c_mults) - 1, -1, -1): + layers.append(TransformerResamplingBlock( + in_channels=channel_dims[i + 1], out_channels=channel_dims[i], + stride=strides[i], type="decoder", + transformer_depth=transformer_depths[i], + sinusoidal_blocks=sinusoidal_blocks[i], + dtype=dtype, device=device, operations=operations, **kwargs)) + self.layers = nn.ModuleList(layers) + + def forward(self, x, **kwargs): + for layer in self.layers: + x = layer(x) + return x + + +class SoftNormBottleneck(nn.Module): + def __init__(self, dim=32, noise_augment_dim=0, noise_regularize=False, + auto_scale=False, freeze=False, dtype=None, device=None, **kwargs): + super().__init__() + self.noise_augment_dim = noise_augment_dim + self.noise_regularize = noise_regularize + self.scaling_factor = nn.Parameter(torch.empty(1, dim, 1, dtype=dtype, device=device)) + self.bias = nn.Parameter(torch.empty(1, dim, 1, dtype=dtype, device=device)) + self.noise_scaling_factor = nn.Parameter(torch.empty(1, noise_augment_dim, 1, dtype=dtype, device=device)) + if auto_scale: + self.register_parameter("running_std", nn.Parameter( + torch.empty(1, dtype=dtype, device=device), requires_grad=False)) + if freeze: + for p in self.parameters(): + p.requires_grad = False + + def encode(self, x, return_info=False, **kwargs): + x = x * comfy.ops.cast_to_input(self.scaling_factor, x) \ + + comfy.ops.cast_to_input(self.bias, x) + if hasattr(self, "running_std"): + x = x / comfy.ops.cast_to_input(self.running_std, x) + if return_info: + return x, {} + return x + + def decode(self, x, **kwargs): + if hasattr(self, "running_std"): + x = x * comfy.ops.cast_to_input(self.running_std, x) + if self.noise_regularize: + scaling = self.running_std if hasattr(self, "running_std") \ + else x.std(dim=-1, keepdim=True) + noise = torch.randn_like(x) * comfy.ops.cast_to_input(scaling, x) * 1e-3 + x = x + noise + if self.noise_augment_dim > 0: + noise = comfy.ops.cast_to_input(self.noise_scaling_factor, x) * torch.randn( + x.shape[0], self.noise_augment_dim, x.shape[-1], device=x.device, dtype=x.dtype) + x = torch.cat([x, noise], dim=1) + return x + + +class PatchedPretransform(nn.Module): + def __init__(self, channels, patch_size, **kwargs): + super().__init__() + self.channels = channels + self.patch_size = patch_size + self.enable_grad = False + + def _pad(self, x): + pad_len = (self.patch_size - x.shape[-1] % self.patch_size) % self.patch_size + if pad_len > 0: + x = torch.cat([x, torch.zeros_like(x[:, :, :pad_len])], dim=-1) + return x + + def encode(self, x): + x = self._pad(x) + B, C, T = x.shape + h = self.patch_size + L = T // h + # b c (l h) -> b (c h) l + return x.reshape(B, C, L, h).permute(0, 1, 3, 2).reshape(B, C * h, L) + + def decode(self, x): + B, Ch, L = x.shape + h = self.patch_size + C = Ch // h + # b (c h) l -> b c (l h) + return x.reshape(B, C, h, L).permute(0, 1, 3, 2).reshape(B, C, L * h) + + +class SA3AudioVAE(nn.Module): + """SA3 VAE. State dict keys match checkpoint after stripping 'pretransform.model.'""" + + def __init__(self, channels=256, transformer_depths=12, sinusoidal_blocks=8, + sliding_window=None, decoder_conv_mapping=False, + chunk_size=128, chunk_midpoint_shift=False, + dtype=None, device=None, operations=None): + super().__init__() + if operations is None: + operations = ops + + self.pretransform = PatchedPretransform(channels=2, patch_size=256) + + common_kwargs = dict( + differential=True, dyt=True, dim_heads=64, + sliding_window=sliding_window, variable_stride=True, + chunk_size=chunk_size, chunk_midpoint_shift=chunk_midpoint_shift, + dtype=dtype, device=device, operations=operations, + ) + self.encoder = SAMEEncoder( + in_channels=512, channels=channels, c_mults=[6], strides=[16], + latent_dim=256, transformer_depths=[transformer_depths], + conv_mapping=False, **common_kwargs, + ) + self.decoder = SAMEDecoder( + out_channels=512, channels=channels, c_mults=[6], strides=[16], + latent_dim=256, transformer_depths=[transformer_depths], sinusoidal_blocks=[sinusoidal_blocks], + conv_mapping=decoder_conv_mapping, **common_kwargs, + ) + self.bottleneck = SoftNormBottleneck( + dim=256, noise_augment_dim=0, noise_regularize=True, + auto_scale=True, freeze=True, + dtype=dtype, device=device, + ) + + @torch.no_grad() + def _pretransform_encode(self, x): + return self.pretransform.encode(x) + + @torch.no_grad() + def _pretransform_decode(self, x): + return self.pretransform.decode(x) + + def encode(self, x): + x = self._pretransform_encode(x) + x = self.encoder(x) + x = self.bottleneck.encode(x) + return x + + def decode(self, x): + x = self.bottleneck.decode(x) + x = self.decoder(x) + x = self._pretransform_decode(x) + return x diff --git a/comfy/model_base.py b/comfy/model_base.py index c22705655..d81f13c69 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -813,6 +813,85 @@ class StableAudio1(BaseModel): sd["{}{}".format(k, l)] = s[l] return sd +class StableAudio3(BaseModel): + def __init__(self, model_config, seconds_total_embedder_weights, padding_embedding=None, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.audio.dit.AudioDiffusionTransformer) + self.seconds_total_embedder = comfy.ldm.audio.embedders.NumberConditioner(768, min_val=0, max_val=384, fourier_features_type=model_config.unet_config["timestep_features_type"]) + self.seconds_total_embedder.load_state_dict(seconds_total_embedder_weights) + if padding_embedding is not None: + self.padding_embedding = torch.nn.Parameter(padding_embedding, requires_grad=False) + else: + self.padding_embedding = None + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + image = kwargs.get("concat_latent_image", None) + + if image is None: + shape_image = list(noise.shape) + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + image = self.process_latent_in(image) + # TODO: scale if not match + image = utils.resize_to_batch_size(image, noise.shape[0]) + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + mask = torch.zeros_like(noise)[:, :1] + else: + if mask.shape[1] != 1: + mask = torch.mean(mask, dim=1, keepdim=True) + mask = 1.0 - mask + # TODO: scale if not match + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + + return torch.cat((mask, image), dim=1) + + def extra_conds(self, **kwargs): + out = {} + + concat_cond = self.concat_cond(**kwargs) + if concat_cond is not None: + out['local_add_cond'] = comfy.conds.CONDNoiseShape(concat_cond) + + noise = kwargs.get("noise", None) + device = kwargs["device"] + + seconds_total = kwargs.get("seconds_total", int(noise.shape[-1] / 10.7666)) + seconds_total_embed = self.seconds_total_embedder([seconds_total])[0].to(device) + + global_embed = seconds_total_embed.reshape((1, -1)) + out['global_embed'] = comfy.conds.CONDRegular(global_embed) + + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + cross_attn = cross_attn.to(device) + if self.padding_embedding is not None: + pe = self.padding_embedding.to(device=device, dtype=cross_attn.dtype) + max_text_tokens = self.model_config.unet_config.get("max_text_tokens", 256) + n_text = cross_attn.shape[1] + if n_text < max_text_tokens: + pad = pe.view(1, 1, -1).expand(cross_attn.shape[0], max_text_tokens - n_text, -1) + cross_attn = torch.cat([cross_attn, pad], dim=1) + cross_attn = torch.cat([cross_attn, seconds_total_embed.repeat((cross_attn.shape[0], 1, 1))], dim=1) + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + return out + + def state_dict_for_saving(self, unet_state_dict, clip_state_dict=None, vae_state_dict=None, clip_vision_state_dict=None): + sd = super().state_dict_for_saving(unet_state_dict, clip_state_dict=clip_state_dict, vae_state_dict=vae_state_dict, clip_vision_state_dict=clip_vision_state_dict) + + d = {"conditioner.conditioners.seconds_total.": self.seconds_total_embedder.state_dict()} + + for k in d: + s = d[k] + for l in s: + sd["{}{}".format(k, l)] = s[l] + + if self.padding_embedding is not None: + sd["conditioner.conditioners.prompt.padding_embedding"] = self.padding_embedding.data + return sd + class HunyuanDiT(BaseModel): def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index bc0b933bc..70b4df8b3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -116,6 +116,45 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit unet_config = {} unet_config["audio_model"] = "dit1.0" + unet_config["global_cond_dim"] = state_dict['{}to_global_embed.0.weight'.format(key_prefix)].shape[1] + cond_embed = state_dict['{}to_cond_embed.0.weight'.format(key_prefix)] + unet_config["project_cond_tokens"] = cond_embed.shape[0] != cond_embed.shape[1] + unet_config["embed_dim"] = state_dict['{}to_timestep_embed.0.weight'.format(key_prefix)].shape[0] + mem_tokens = state_dict.get('{}transformer.memory_tokens'.format(key_prefix), None) + to_qkv = state_dict.get('{}transformer.layers.0.self_attn.to_qkv.weight'.format(key_prefix), None) + differential = False + if to_qkv is not None: + if to_qkv.shape[0] == to_qkv.shape[1] * 5: + differential = True + if mem_tokens is not None: + unet_config["num_memory_tokens"] = mem_tokens.shape[0] + if '{}transformer.layers.0.self_attn.q_norm.weight'.format(key_prefix) in state_dict: + unet_config["attn_kwargs"] = {"qk_norm": "ln", "feat_scale": True} + rms_norm = state_dict.get('{}transformer.layers.0.self_attn.q_norm.gamma'.format(key_prefix), None) + if rms_norm is not None: + unet_config["attn_kwargs"] = {"qk_norm": "rms", "differential": differential} + unet_config["norm_type"] = "rms_norm" + unet_config["num_heads"] = unet_config["embed_dim"] // rms_norm.shape[0] + + if '{}timestep_features.weight'.format(key_prefix) in state_dict: + unet_config["timestep_features_type"] = "learned" + else: + unet_config["timestep_features_type"] = "expo" + + io_channels = state_dict['{}postprocess_conv.weight'.format(key_prefix)].shape[0] + unet_config["io_channels"] = io_channels + unet_config["input_concat_dim"] = state_dict['{}transformer.project_in.weight'.format(key_prefix)].shape[1] - io_channels + + local_add_cond = state_dict.get('{}transformer.layers.0.to_local_embed.0.weight'.format(key_prefix), None) + if local_add_cond is not None: + unet_config["local_add_cond_dim"] = local_add_cond.shape[1] + + global_cond_embed = state_dict.get('{}transformer.global_cond_embedder.0.weight'.format(key_prefix), None) + if global_cond_embed is not None: + unet_config["global_cond_shared_embed"] = True + unet_config["global_cond_type"] = "adaLN" + + unet_config["depth"] = count_blocks(state_dict_keys, '{}transformer.layers.'.format(key_prefix) + '{}.') return unet_config if '{}double_layers.0.attn.w1q.weight'.format(key_prefix) in state_dict_keys: #aura flow dit diff --git a/comfy/sd.py b/comfy/sd.py index 2443353a4..7bd07ed3a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -21,6 +21,7 @@ import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.cogvideo.vae import comfy.ldm.hunyuan_video.vae import comfy.ldm.mmaudio.vae.autoencoder +import comfy.ldm.audio.vae_sa3 import comfy.pixel_space_convert import comfy.weight_adapter import yaml @@ -67,6 +68,7 @@ import comfy.text_encoders.qwen35 import comfy.text_encoders.ernie import comfy.text_encoders.gemma4 import comfy.text_encoders.cogvideo +import comfy.text_encoders.sa3 import comfy.model_patcher import comfy.lora @@ -854,6 +856,34 @@ class VAE: self.working_dtypes = [torch.float32] self.disable_offload = True self.extra_1d_channel = 16 + elif "decoder.layers.3.transformers.0.pre_norm.alpha" in sd: # Stable Audio 3 VAE + if "decoder.layers.3.transformers.11.self_attn.to_out.weight" in sd: + config = {"channels": 256, "transformer_depths": 12, "sinusoidal_blocks": 8, + "sliding_window": [1, 1], "decoder_conv_mapping": False, + "chunk_size": 128, "chunk_midpoint_shift": False} + self.memory_used_encode = lambda shape, dtype: (1500 * shape[2]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * 4096) * model_management.dtype_size(dtype) + else: + config = {"channels": 128, "transformer_depths": 6, "sinusoidal_blocks": 0, + "sliding_window": None, "decoder_conv_mapping": True, + "chunk_size": 32, "chunk_midpoint_shift": True} + self.memory_used_encode = lambda shape, dtype: (72 * shape[2]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (72 * shape[2] * 4096) * model_management.dtype_size(dtype) + + self.first_stage_model = comfy.ldm.audio.vae_sa3.SA3AudioVAE(**config) + self.latent_channels = 256 + self.output_channels = 2 + self.upscale_ratio = 4096 + self.downscale_ratio = 4096 + self.latent_dim = 1 + self.audio_sample_rate = 44100 + self.process_output = lambda audio: audio + self.process_input = lambda audio: audio + self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] + #This VAE has Parameters and Buffers the non-dynamic caster cannot handle + #Force cast it for --disable-dynamic-vram users until there is a true core fix. + if not comfy.memory_management.aimdo_enabled: + self.disable_offload = True else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -1290,6 +1320,7 @@ class TEModel(Enum): GEMMA_4_E4B = 29 GEMMA_4_E2B = 30 GEMMA_4_31B = 31 + T5_GEMMA = 32 def detect_te_model(sd): @@ -1314,6 +1345,8 @@ def detect_te_model(sd): if weight.shape[0] == 384: return TEModel.BYT5_SMALL_GLYPH return TEModel.T5_BASE + if "model.encoder.layers.0.pre_self_attn_layernorm.weight" in sd: + return TEModel.T5_GEMMA if 'model.layers.0.post_feedforward_layernorm.weight' in sd: if 'model.layers.59.self_attn.q_norm.weight' in sd: return TEModel.GEMMA_4_31B @@ -1463,6 +1496,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip else: clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer + elif te_model == TEModel.T5_GEMMA: + clip_target.clip = comfy.text_encoders.sa3.SAT5GemmaModel + clip_target.tokenizer = comfy.text_encoders.sa3.SAT5GemmaTokenizer + tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B): variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B, TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B, diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1e4434fd5..617db4f28 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -7,6 +7,7 @@ from . import sdxl_clip import comfy.text_encoders.sd2_clip import comfy.text_encoders.sd3_clip import comfy.text_encoders.sa_t5 +import comfy.text_encoders.sa3 import comfy.text_encoders.aura_t5 import comfy.text_encoders.pixart_t5 import comfy.text_encoders.hydit @@ -603,6 +604,29 @@ class StableAudio(supported_models_base.BASE): def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.sa_t5.SAT5Tokenizer, comfy.text_encoders.sa_t5.SAT5Model) +class StableAudio3(StableAudio): + unet_config = { + "audio_model": "dit1.0", + "global_cond_shared_embed": True, + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 2.0, + } + + latent_format = latent_formats.StableAudio3 + + memory_usage_factor = 7 + + def get_model(self, state_dict, prefix="", device=None): + seconds_total_sd = utils.state_dict_prefix_replace(state_dict, {"conditioner.conditioners.seconds_total.": ""}, filter_keys=True) + padding_embedding = state_dict.get("conditioner.conditioners.prompt.padding_embedding", None) + return model_base.StableAudio3(self, seconds_total_embedder_weights=seconds_total_sd, padding_embedding=padding_embedding, device=device) + + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget(comfy.text_encoders.sa3.SAT5GemmaTokenizer, comfy.text_encoders.sa3.SAT5GemmaModel) + class AuraFlow(supported_models_base.BASE): unet_config = { "cond_seq_dim": 2048, @@ -2018,6 +2042,7 @@ models = [ SV3D_u, SV3D_p, SD3, + StableAudio3, StableAudio, AuraFlow, PixArtAlpha, diff --git a/comfy/text_encoders/sa3.py b/comfy/text_encoders/sa3.py new file mode 100644 index 000000000..0a1c73ec1 --- /dev/null +++ b/comfy/text_encoders/sa3.py @@ -0,0 +1,207 @@ +import torch +import torch.nn as nn +from comfy import sd1_clip +from comfy.text_encoders.llama import Attention as LlamaAttention, RMSNorm, MLP, precompute_freqs_cis, apply_rope, _make_scaled_embedding +from comfy.text_encoders.spiece_tokenizer import SPieceTokenizer + + +class T5GemmaEncoderConfig: + def __init__(self): + self.vocab_size = 256000 + self.hidden_size = 768 + self.intermediate_size = 2048 + self.num_hidden_layers = 12 + self.num_attention_heads = 12 + self.num_key_value_heads = 12 + self.head_dim = 64 + self.rms_norm_eps = 1e-6 + self.rms_norm_add = False + self.rope_theta = 10000.0 + self.attn_logit_softcapping = 50.0 + self.query_pre_attn_scalar = 64 + self.sliding_window = 4096 + self.mlp_activation = "gelu_pytorch_tanh" + self.layer_types = ["sliding_attention", "full_attention"] * 6 + self.qkv_bias = False + self.q_norm = None + self.k_norm = None + self.rms_norm_add = True + + +class T5GemmaAttention(LlamaAttention): + """Reuses LlamaAttention projection setup; overrides forward for softcap attention. + + T5Gemma applies tanh(QK^T * scale / cap) * cap between the matmul and softmax. + This nonlinearity is incompatible with fused SDPA kernels, so attention is + computed manually. Everything else (projections, RoPE, GQA expansion) is identical + to LlamaAttention so __init__ is inherited unchanged. + """ + + def __init__(self, config, device=None, dtype=None, ops=None): + super().__init__(config, device=device, dtype=dtype, ops=ops) + self.scale = config.query_pre_attn_scalar ** -0.5 + self.softcap = config.attn_logit_softcapping + + def forward(self, hidden_states, attention_mask=None, freqs_cis=None, **kwargs): + B, S, _ = hidden_states.shape + xq = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) + xk = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) + xv = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) + xq, xk = apply_rope(xq, xk, freqs_cis) + xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) + attn = torch.matmul(xq * self.scale, xk.transpose(-2, -1)) + attn = torch.tanh(attn / self.softcap) * self.softcap + if attention_mask is not None: + attn = attn + attention_mask + attn = torch.nn.functional.softmax(attn.float(), dim=-1).to(xq.dtype) + out = torch.matmul(attn, xv).transpose(1, 2).reshape(B, S, self.inner_size) + return self.o_proj(out), None + + +class T5GemmaBlock(nn.Module): + def __init__(self, config, layer_type, device=None, dtype=None, ops=None): + super().__init__() + self.self_attn = T5GemmaAttention(config, device=device, dtype=dtype, ops=ops) + self.mlp = MLP(config, device=device, dtype=dtype, ops=ops) + # Names match checkpoint keys: model.encoder.layers.X..weight + self.pre_self_attn_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype) + self.post_self_attn_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype) + self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype) + self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype) + self.is_sliding = (layer_type == "sliding_attention") + self.sliding_window = config.sliding_window + + def forward(self, x, attention_mask=None, freqs_cis=None): + attn_mask = attention_mask + if self.is_sliding and x.shape[1] > self.sliding_window: + S = x.shape[1] + pos = torch.arange(S, device=x.device) + dist = (pos.unsqueeze(0) - pos.unsqueeze(1)).abs() + sw_mask = torch.zeros(S, S, dtype=x.dtype, device=x.device) + sw_mask.masked_fill_(dist > self.sliding_window, -torch.finfo(x.dtype).max) + sw_mask = sw_mask.unsqueeze(0).unsqueeze(0) + attn_mask = (attention_mask + sw_mask) if attention_mask is not None else sw_mask + residual = x + x = self.pre_self_attn_layernorm(x) + x, _ = self.self_attn(x, attention_mask=attn_mask, freqs_cis=freqs_cis) + x = self.post_self_attn_layernorm(x) + x = residual + x + residual = x + x = self.pre_feedforward_layernorm(x) + x = self.mlp(x) + x = self.post_feedforward_layernorm(x) + x = residual + x + return x + + +class T5GemmaEncoder(nn.Module): + """Encoder stack: embed_tokens, layers, norm. + Keys: embed_tokens.*, layers.X.*, norm.*""" + + def __init__(self, config, device, dtype, ops): + super().__init__() + self.config = config + # Gemma-style scaled embedding: output *= sqrt(hidden_size) + self.embed_tokens = _make_scaled_embedding( + ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype) + self.layers = nn.ModuleList([ + T5GemmaBlock(config, config.layer_types[i], device=device, dtype=dtype, ops=ops) + for i in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=True, device=device, dtype=dtype) + + def forward(self, input_ids, attention_mask=None, embeds=None, intermediate_output=None, + final_layer_norm_intermediate=True, dtype=None, num_layers=None): + x = embeds if embeds is not None else self.embed_tokens(input_ids, out_dtype=dtype or torch.float32) + seq_len = x.shape[1] + position_ids = torch.arange(seq_len, device=x.device).unsqueeze(0) + freqs_cis = precompute_freqs_cis(self.config.head_dim, position_ids, self.config.rope_theta, device=x.device) + mask = None + if attention_mask is not None: + mask = 1.0 - attention_mask.to(x.dtype).reshape( + (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) + ).expand(attention_mask.shape[0], 1, seq_len, attention_mask.shape[-1]) + mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max) + intermediate = None + for i, layer in enumerate(self.layers): + x = layer(x, attention_mask=mask, freqs_cis=freqs_cis) + if i == intermediate_output: + intermediate = x.clone() + x = self.norm(x) + if intermediate is not None and final_layer_norm_intermediate: + intermediate = self.norm(intermediate) + return x, intermediate + + +class T5GemmaBody(nn.Module): + """Provides the 'encoder' sub-module. + Keys: encoder.*""" + + def __init__(self, config, device, dtype, ops): + super().__init__() + self.encoder = T5GemmaEncoder(config, device, dtype, ops) + + +class T5GemmaModel(nn.Module): + """Top-level model class passed to SDClipModel as model_class. + Module layout: self.model.encoder.* → matches checkpoint keys model.encoder.*""" + + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = T5GemmaEncoderConfig() + self.num_layers = config.num_hidden_layers + self.dtype = dtype + self.model = T5GemmaBody(config, device, dtype, operations) + + def get_input_embeddings(self): + return self.model.encoder.embed_tokens + + def set_input_embeddings(self, embeddings): + self.model.encoder.embed_tokens = embeddings + + def forward(self, input_ids, attention_mask=None, embeds=None, num_tokens=None, + intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, **kwargs): + if intermediate_output is not None and intermediate_output < 0: + intermediate_output = self.num_layers + intermediate_output + return self.model.encoder( + input_ids, attention_mask=attention_mask, embeds=embeds, + intermediate_output=intermediate_output, + final_layer_norm_intermediate=final_layer_norm_intermediate, + dtype=dtype, num_layers=self.num_layers) + + +class T5GemmaSDClipModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, + textmodel_json_config={}, dtype=dtype, + special_tokens={"pad": 0}, + model_class=T5GemmaModel, + enable_attention_masks=True, zero_out_masked=True, + model_options=model_options) + + +class T5GemmaSDTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_model = tokenizer_data.get("spiece_model", None) + super().__init__(tokenizer_model, pad_with_end=False, embedding_size=768, + embedding_key="t5gemma", tokenizer_class=SPieceTokenizer, + has_start_token=False, has_end_token=False, pad_to_max_length=False, + max_length=99999999, min_length=1, pad_token=0, + tokenizer_data=tokenizer_data, + tokenizer_args={"add_bos": False, "add_eos": False}) + + def state_dict(self): + return {"spiece_model": self.tokenizer.serialize_model()} + + +class SAT5GemmaTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, + tokenizer_data=tokenizer_data, clip_name="t5gemma", tokenizer=T5GemmaSDTokenizer) + + +class SAT5GemmaModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): + super().__init__(device=device, dtype=dtype, model_options=model_options, + name="t5gemma", clip_model=T5GemmaSDClipModel, **kwargs) From 4efe1ddb5c7933e9db28d5ecb0be4fa77e0edcde Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Wed, 20 May 2026 23:46:20 +0800 Subject: [PATCH 06/45] chore: update workflow templates to v0.9.79 (#14011) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index f499a10ae..1c87690da 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.43.18 -comfyui-workflow-templates==0.9.77 +comfyui-workflow-templates==0.9.79 comfyui-embedded-docs==0.5.0 torch torchsde From a8d2519058ea766ca3b14916bcc01ecef5efd235 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Wed, 20 May 2026 13:49:36 -0400 Subject: [PATCH 07/45] ComfyUI v0.22.0 --- comfyui_version.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfyui_version.py b/comfyui_version.py index 4c6f5eb2a..0bb0f780c 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.21.1" +__version__ = "0.22.0" diff --git a/pyproject.toml b/pyproject.toml index 0a1554428..1e449b4a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.21.1" +version = "0.22.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" From 4d6a058bf1dd18fb6d4594081c3f9a7575c97256 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Thu, 21 May 2026 02:07:48 +0300 Subject: [PATCH 08/45] feat: MediaPipe face detection (CORE-235) (#14009) * Initial mediapipe face detection support * Update face_geometry.py * Account for diff sized batch input * Model folder placeholder --- comfy_extras/mediapipe/face_geometry.py | 111 ++++ comfy_extras/mediapipe/face_landmarker.py | 682 +++++++++++++++++++++ comfy_extras/nodes_mediapipe.py | 502 +++++++++++++++ folder_paths.py | 2 + models/mediapipe/put_mediapipe_models_here | 0 nodes.py | 1 + 6 files changed, 1298 insertions(+) create mode 100644 comfy_extras/mediapipe/face_geometry.py create mode 100644 comfy_extras/mediapipe/face_landmarker.py create mode 100644 comfy_extras/nodes_mediapipe.py create mode 100644 models/mediapipe/put_mediapipe_models_here diff --git a/comfy_extras/mediapipe/face_geometry.py b/comfy_extras/mediapipe/face_geometry.py new file mode 100644 index 000000000..04b2b0557 --- /dev/null +++ b/comfy_extras/mediapipe/face_geometry.py @@ -0,0 +1,111 @@ +"""Pure-numpy port of MediaPipe's face_geometry (FACE_LANDMARK_PIPELINE mode) ++ weighted Procrustes solver. Computes the 4x4 facial transformation matrix. +""" + +from __future__ import annotations + +import math +import numpy as np + + +def _solve_weighted_orthogonal_problem(src: np.ndarray, tgt: np.ndarray, weights: np.ndarray) -> np.ndarray: + """Weighted orthogonal Procrustes (similarity). Returns 4x4 M with + `target ≈ M @ homogeneous(source)` in the weighted LS sense. fp64 for + SVD stability. Port of procrustes_solver.cc.""" + sqrt_w = np.sqrt(weights.astype(np.float64)) + w_total = float((sqrt_w ** 2).sum()) + ws = src.astype(np.float64) * sqrt_w + wt = tgt.astype(np.float64) * sqrt_w + + c_w = (ws @ sqrt_w) / w_total + centered = ws - np.outer(c_w, sqrt_w) + U, _S, Vt = np.linalg.svd(wt @ centered.T, full_matrices=True) + # Disallow reflection: flip the least-significant axis when det(U)·det(V)<0. + post, pre = U.copy(), Vt.T.copy() + if np.linalg.det(post) * np.linalg.det(pre) < 0: + post[:, 2] *= -1.0 + R = post @ pre.T + + denom = float((centered * ws).sum()) + if denom < 1e-12: + raise ValueError("Procrustes denominator collapsed (degenerate source).") + scale = float((R @ centered * wt).sum()) / denom + translation = ((wt - scale * (R @ ws)) @ sqrt_w) / w_total + + M = np.eye(4, dtype=np.float64) + M[:3, :3] = scale * R + M[:3, 3] = translation + return M + + +def _estimate_scale(canonical: np.ndarray, runtime: np.ndarray, weights: np.ndarray) -> float: + """scale = ‖first column of M[:3]‖ per geometry_pipeline.cc::EstimateScale.""" + return float(np.linalg.norm(_solve_weighted_orthogonal_problem(canonical, runtime, weights)[:3, 0])) + + +def solve_facial_transformation_matrix( + landmarks_normalized: np.ndarray, + canonical_vertices: np.ndarray, + procrustes_indices: np.ndarray, + procrustes_weights: np.ndarray, + image_width: int, + image_height: int, + # face_geometry_calculator_options.pbtxt defaults + vertical_fov_degrees: float = 63.0, + near: float = 1.0, +) -> np.ndarray: + """4x4 facial transformation matrix via two-pass scale recovery + `landmarks_normalized` is (N, 3) in MediaPipe normalized convention: x, y + in [0,1] with TOP-LEFT origin, z in width-scaled units. + """ + + h_near = 2.0 * near * math.tan(0.5 * math.radians(vertical_fov_degrees)) + w_near = image_width * h_near / image_height + + sub = procrustes_indices.astype(np.int64) + screen = landmarks_normalized[sub].T.astype(np.float64).copy() + canon = canonical_vertices[sub].T.astype(np.float64).copy() + weights = procrustes_weights.astype(np.float64) + + # ProjectXY (TOP_LEFT y-flip, then scale all 3 axes; z uses x-scale). + screen[1] = 1.0 - screen[1] + screen[0] = screen[0] * w_near - 0.5 * w_near + screen[1] = screen[1] * h_near - 0.5 * h_near + screen[2] = screen[2] * w_near + depth_offset = float(screen[2].mean()) + + def _unproject(s: np.ndarray, scale: float) -> np.ndarray: + s = s.copy() + s[2] = (s[2] - depth_offset + near) / scale + s[0] *= s[2] / near + s[1] *= s[2] / near + s[2] *= -1.0 + return s + + first = screen.copy() + first[2] *= -1.0 + s1 = _estimate_scale(canon, first, weights) # 1st pass: Procrustes on projected XY + s2 = _estimate_scale(canon, _unproject(screen, s1), weights) # 2nd pass: rescale z by s1, un-project XY + return _solve_weighted_orthogonal_problem(canon, _unproject(screen, s1 * s2), weights).astype(np.float32) + + +def transformation_matrix_from_detection(face_dict: dict, image_width: int, image_height: int, canonical_data: dict) -> np.ndarray: + """Adapt a FaceLandmarker face dict to MP's normalized convention and solve. + FaceMesh emits (x, y, z) in 192-canonical units; MP's geometry expects + z_norm = z_canonical * scale_x / image_width""" + + lmks_xy, lmks_3d = face_dict["landmarks_xy"], face_dict["landmarks_3d"] + aug = np.concatenate([lmks_3d[:, :2].astype(np.float64), np.ones((lmks_xy.shape[0], 1))], axis=1) + M, *_ = np.linalg.lstsq(aug, lmks_xy.astype(np.float64), rcond=None) + scale_x = float(np.linalg.norm(M[0])) + z_scale = scale_x / image_width if scale_x > 1e-6 else 1.0 / image_width + + normalized = np.empty((lmks_xy.shape[0], 3), dtype=np.float32) + normalized[:, 0] = lmks_xy[:, 0] / image_width + normalized[:, 1] = lmks_xy[:, 1] / image_height + normalized[:, 2] = lmks_3d[:, 2] * z_scale + return solve_facial_transformation_matrix( + normalized, canonical_data["canonical_vertices"], + canonical_data["procrustes_indices"], canonical_data["procrustes_weights"], + image_width=image_width, image_height=image_height, + ) diff --git a/comfy_extras/mediapipe/face_landmarker.py b/comfy_extras/mediapipe/face_landmarker.py new file mode 100644 index 000000000..a792b6046 --- /dev/null +++ b/comfy_extras/mediapipe/face_landmarker.py @@ -0,0 +1,682 @@ +"""Pure-PyTorch port of MediaPipe's face_landmarker_v2_with_blendshapes.task: +BlazeFace detector → FaceMesh v2 → ARKit-52 blendshapes.""" + +from __future__ import annotations + +import math +from functools import lru_cache +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from scipy.special import expit +from torch import Tensor, nn + + +# Values below must stay verbatim with the published face_landmarker_v2 graph + +# face_blendshapes_graph.cc::kLandmarksSubsetIdxs +_BS_INPUT_INDICES: Tuple[int, ...] = ( + 0, 1, 4, 5, 6, 7, 8, 10, 13, 14, 17, 21, 33, 37, 39, 40, 46, 52, 53, 54, + 55, 58, 61, 63, 65, 66, 67, 70, 78, 80, 81, 82, 84, 87, 88, 91, 93, 95, + 103, 105, 107, 109, 127, 132, 133, 136, 144, 145, 146, 148, 149, 150, 152, + 153, 154, 155, 157, 158, 159, 160, 161, 162, 163, 168, 172, 173, 176, 178, + 181, 185, 191, 195, 197, 234, 246, 249, 251, 263, 267, 269, 270, 276, 282, + 283, 284, 285, 288, 291, 293, 295, 296, 297, 300, 308, 310, 311, 312, 314, + 317, 318, 321, 323, 324, 332, 334, 336, 338, 356, 361, 362, 365, 373, 374, + 375, 377, 378, 379, 380, 381, 382, 384, 385, 386, 387, 388, 389, 390, 397, + 398, 400, 402, 405, 409, 415, 454, 466, 468, 469, 470, 471, 472, 473, 474, + 475, 476, 477, +) + +# face_blendshapes_graph.cc::kCategoryNames +BLENDSHAPE_NAMES: Tuple[str, ...] = ( + "_neutral", "browDownLeft", "browDownRight", "browInnerUp", "browOuterUpLeft", + "browOuterUpRight", "cheekPuff", "cheekSquintLeft", "cheekSquintRight", + "eyeBlinkLeft", "eyeBlinkRight", "eyeLookDownLeft", "eyeLookDownRight", + "eyeLookInLeft", "eyeLookInRight", "eyeLookOutLeft", "eyeLookOutRight", + "eyeLookUpLeft", "eyeLookUpRight", "eyeSquintLeft", "eyeSquintRight", + "eyeWideLeft", "eyeWideRight", "jawForward", "jawLeft", "jawOpen", + "jawRight", "mouthClose", "mouthDimpleLeft", "mouthDimpleRight", + "mouthFrownLeft", "mouthFrownRight", "mouthFunnel", "mouthLeft", + "mouthLowerDownLeft", "mouthLowerDownRight", "mouthPressLeft", + "mouthPressRight", "mouthPucker", "mouthRight", "mouthRollLower", + "mouthRollUpper", "mouthShrugLower", "mouthShrugUpper", "mouthSmileLeft", + "mouthSmileRight", "mouthStretchLeft", "mouthStretchRight", + "mouthUpperUpLeft", "mouthUpperUpRight", "noseSneerLeft", "noseSneerRight", +) + +# face_detection.pbtxt — short-range BlazeFace. +_BF_NUM_LAYERS = 4 +_BF_INPUT_SIZE = 128 +_BF_STRIDES = (8, 16, 16, 16) +_BF_ANCHOR_OFFSET_X = 0.5 +_BF_ANCHOR_OFFSET_Y = 0.5 +_BF_ASPECT_RATIOS = (1.0,) +_BF_INTERP_SCALE_AR = 1.0 +_BF_BOX_SCALE = 128.0 +_BF_KP_OFFSET = 4 +_BF_SCORE_CLIP = 100.0 +_BF_MIN_SCORE = 0.5 + +# face_detection_full_range.pbtxt — 48x48 grid at stride 4, 1 anchor/cell. +_BF_FR_INPUT_SIZE = 192 +_BF_FR_GRID = 48 +_BF_FR_NUM_ANCHORS = _BF_FR_GRID * _BF_FR_GRID +_BF_FR_BOX_SCALE = 192.0 +_BF_FR_SCORE_CLIP = 100.0 + +_FM_INPUT_SIZE = 192 + +# Face ROI: 1.5xbbox rect warped anisotropically into 192x192. +_FACE_LEFT_EYE_KP = 0 +_FACE_RIGHT_EYE_KP = 1 +_FACE_ROI_SCALE_X = 1.5 +_FACE_ROI_SCALE_Y = 1.5 +_FACE_ROI_TARGET_ANGLE = 0.0 + + +def _tf_same_pad(x: Tensor, kernel: int, stride: int) -> Tensor: + """TF SAME pad (asymmetric on stride-2; PyTorch's symmetric pad undershoots by 1 px).""" + H, W = x.shape[-2], x.shape[-1] + pad_h = max(((H + stride - 1) // stride - 1) * stride + kernel - H, 0) + pad_w = max(((W + stride - 1) // stride - 1) * stride + kernel - W, 0) + if pad_h == 0 and pad_w == 0: + return x + return F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)) + + +# BlazeFace short-range: stem 5x5/s2 → 16 BlazeBlocks → parallel heads at +# 16²x88 (2 anchors/cell) and 8²x96 (6/cell) = 896 anchors. (in, out, stride): +_BLAZEFACE_BLOCKS = [ + (24, 24, 1), (24, 28, 1), (28, 32, 2), (32, 36, 1), + (36, 42, 1), (42, 48, 2), (48, 56, 1), (56, 64, 1), + (64, 72, 1), (72, 80, 1), (80, 88, 1), (88, 96, 2), + (96, 96, 1), (96, 96, 1), (96, 96, 1), (96, 96, 1), +] + + +class BlazeFaceBlock(nn.Module): + """DW 3x3 + PW + residual. Residual max-pools on stride>1, channel-pads on out_ch>in_ch.""" + + def __init__(self, in_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride + self.depthwise = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, device=device, dtype=dtype) + self.pointwise = ops.Conv2d(in_ch, out_ch, 1, padding=0, bias=True, device=device, dtype=dtype) + + def forward(self, x: Tensor) -> Tensor: + residual = F.max_pool2d(x, 2, 2) if self.stride > 1 else x + if self.out_ch > self.in_ch: + residual = F.pad(residual, (0, 0, 0, 0, 0, self.out_ch - self.in_ch)) + x = _tf_same_pad(x, 3, self.stride) if self.stride > 1 else F.pad(x, (1, 1, 1, 1)) + return F.relu(self.pointwise(self.depthwise(x)) + residual) + + +class BlazeFace(nn.Module): + """Short-range BlazeFace: (B, 3, 128, 128) in [-1, 1] → 896 anchors x 17.""" + + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.stem = ops.Conv2d(3, 24, 5, stride=2, padding=0, bias=True, **kw) + self.blocks = nn.ModuleList(BlazeFaceBlock(i, o, s, device=device, dtype=dtype, operations=operations) + for (i, o, s) in _BLAZEFACE_BLOCKS) + # 16²x2 + 8²x6 = 512 + 384 = 896 anchors. + self.cls_16 = ops.Conv2d(88, 2, 1, padding=0, bias=True, **kw) + self.cls_8 = ops.Conv2d(96, 6, 1, padding=0, bias=True, **kw) + self.reg_16 = ops.Conv2d(88, 32, 1, padding=0, bias=True, **kw) + self.reg_8 = ops.Conv2d(96, 96, 1, padding=0, bias=True, **kw) + + def forward(self, image_chw_normalized: Tensor) -> tuple[Tensor, Tensor]: + x = F.relu(self.stem(_tf_same_pad(image_chw_normalized, 5, 2))) + # 16x16 tap is block-10 output (before the 88→96 stride-2 in block 11). + for i in range(11): + x = self.blocks[i](x) + feat_16 = x + for i in range(11, 16): + x = self.blocks[i](x) + feat_8 = x + + def flat(t, a, k): # NHWC flatten → (B, H*W*A, K) + B, _, H, W = t.shape + return t.permute(0, 2, 3, 1).reshape(B, H * W * a, k) + + cls = torch.cat([flat(self.cls_16(feat_16), 2, 1), flat(self.cls_8(feat_8), 6, 1)], dim=1) + reg = torch.cat([flat(self.reg_16(feat_16), 2, 16), flat(self.reg_8(feat_8), 6, 16)], dim=1) + return reg, cls + + +# BlazeFace full-range (face_detection_full_range_sparse.tflite): MobileNetV2-ish +# backbone + top-down FPN, 192² input → 2304 anchors at the 48x48 grid. +class FRBlock(nn.Module): + """Double inverted residual: DW → PW(mid) → DW → PW(out) [+ residual]. + + Per source tflite: dw* have no fused activation, pw1 is always ReLU, pw2 + is ReLU only when no residual (else ReLU fuses into the ADD). + """ + + def __init__(self, in_ch: int, mid_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.has_residual = (in_ch == out_ch and stride == 1) + self.dw1 = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, **kw) + self.pw1 = ops.Conv2d(in_ch, mid_ch, 1, padding=0, bias=True, **kw) + self.dw2 = ops.Conv2d(mid_ch, mid_ch, 3, stride=1, padding=0, groups=mid_ch, bias=True, **kw) + self.pw2 = ops.Conv2d(mid_ch, out_ch, 1, padding=0, bias=True, **kw) + + def forward(self, x: Tensor) -> Tensor: + residual = x if self.has_residual else None + x = F.relu(self.pw1(self.dw1(F.pad(x, (1, 1, 1, 1))))) + x = self.pw2(self.dw2(F.pad(x, (1, 1, 1, 1)))) + return F.relu(x + residual) if residual is not None else F.relu(x) + + +# (in_ch, mid_ch, out_ch, stride). Stages downsample 96²x32 → 48²x64 → 24²x128 +# → 12²x192 → 6²x384. Lateral taps at indices 4, 7, 10 (see _FR_LATERAL_*). +_FR_BACKBONE_BLOCKS = [ + (32, 8, 32, 1), (32, 8, 32, 1), # 96²x32 + (32, 16, 64, 2), (64, 16, 64, 1), (64, 16, 64, 1), # 48²x64 — tap[0] + (64, 32, 128, 2), (128, 32, 128, 1), (128, 32, 128, 1), # 24²x128 — tap[1] + (128, 48, 192, 2), (192, 48, 192, 1), (192, 48, 192, 1), # 12²x192 — tap[2] + (192, 96, 384, 2), (384, 96, 384, 1), (384, 96, 384, 1), (384, 96, 384, 1), # 6²x384 +] +_FR_LATERAL_TAP_INDICES = (4, 7, 10) +_FR_LATERAL_CHANNELS = ((64, 48), (128, 64), (192, 96)) # (in, out) per side-conv + +# Decoder blocks per FPN level (after upsample-and-merge with the lateral). +_FR_DECODER_BLOCKS = [ + [(96, 48, 96, 1), (96, 48, 96, 1)], # 12²x96 + [(64, 32, 64, 1), (64, 32, 64, 1)], # 24²x64 + [(48, 24, 48, 1)], # 48²x48 — feeds the heads +] + + +def _dcr_depth_to_space(t: Tensor, r: int, c_out: int) -> Tensor: + """TF DEPTH_TO_SPACE in DCR layout (input channels = (i, j, c_out)). + pixel_shuffle uses CRD which permutes output channels for c_out > 1.""" + B_, _, H_, W_ = t.shape + t = t.reshape(B_, r, r, c_out, H_, W_) + t = t.permute(0, 3, 4, 1, 5, 2).contiguous() + return t.reshape(B_, c_out, H_ * r, W_ * r) + + +class BlazeFaceFullRange(nn.Module): + """Full-range face detector: (B, 3, 192, 192) in [-1, 1] → 2304 anchors x 17 values.""" + + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + mk_block = lambda i, m, o, s: FRBlock(i, m, o, s, device=device, dtype=dtype, operations=operations) + self.stem = ops.Conv2d(3, 32, 3, stride=2, padding=0, bias=True, **kw) + self.backbone = nn.ModuleList(mk_block(i, m, o, s) for (i, m, o, s) in _FR_BACKBONE_BLOCKS) + self.lateral_convs = nn.ModuleList(ops.Conv2d(i, o, 1, padding=0, bias=True, **kw) for (i, o) in _FR_LATERAL_CHANNELS) + self.top_conv = ops.Conv2d(384, 96, 1, padding=0, bias=True, **kw) + self.decoder_levels = nn.ModuleList( + nn.ModuleList(mk_block(i, m, o, s) for (i, m, o, s) in lvl) for lvl in _FR_DECODER_BLOCKS + ) + # 96→64 before 12→24, 64→48 before 24→48. + self.decoder_reduce_convs = nn.ModuleList([ + ops.Conv2d(96, 64, 1, padding=0, bias=True, **kw), + ops.Conv2d(64, 48, 1, padding=0, bias=True, **kw), + ]) + # Heads mix 2x2-cell info via DW-stride-2 + depth_to_space block_size=2. + self.cls_conv = ops.Conv2d(48, 4, 1, padding=0, bias=True, **kw) + self.cls_dw = ops.Conv2d(4, 4, 3, stride=2, padding=0, groups=4, bias=True, **kw) + self.reg_conv = ops.Conv2d(48, 64, 1, padding=0, bias=True, **kw) + self.reg_dw = ops.Conv2d(64, 64, 3, stride=2, padding=0, groups=64, bias=True, **kw) + + def forward(self, image_chw_normalized: Tensor) -> tuple[Tensor, Tensor]: + # Symmetric pad-1 throughout (full-range tflite uses explicit TF PAD, not SAME). + x = F.relu(self.stem(F.pad(image_chw_normalized, (1, 1, 1, 1)))) + tap_set = set(_FR_LATERAL_TAP_INDICES) + laterals: list[Tensor] = [] + for i, blk in enumerate(self.backbone): + x = blk(x) + if i in tap_set: + laterals.append(x) + + # top_conv / lateral_convs / decoder_reduce_convs all have fused ReLU in the tflite. + p = F.relu(self.top_conv(x)) + laterals_rev = list(reversed(laterals)) + lateral_convs_rev = list(reversed(self.lateral_convs)) + for level in range(len(self.decoder_levels)): + lateral = laterals_rev[level] + p = F.interpolate(p, size=lateral.shape[-2:], mode="bilinear", align_corners=False) + p = p + F.relu(lateral_convs_rev[level](lateral)) + for blk in self.decoder_levels[level]: + p = blk(p) + if level < len(self.decoder_reduce_convs): + p = F.relu(self.decoder_reduce_convs[level](p)) + + c = self.cls_dw(F.pad(self.cls_conv(p), (1, 1, 1, 1))) + c = _dcr_depth_to_space(c, r=2, c_out=1) + r = self.reg_dw(F.pad(self.reg_conv(p), (1, 1, 1, 1))) + r = _dcr_depth_to_space(r, r=2, c_out=16) + B = c.shape[0] + cls_out = c.permute(0, 2, 3, 1).reshape(B, _BF_FR_NUM_ANCHORS, 1) + reg_out = r.permute(0, 2, 3, 1).reshape(B, _BF_FR_NUM_ANCHORS, 16) + return reg_out, cls_out + + +@lru_cache(maxsize=1) +def _blazeface_full_range_anchors() -> np.ndarray: + """2304 anchors over 48x48; anchor_w=anchor_h=1 (fixed_anchor_size).""" + feat = _BF_FR_GRID + yy, xx = np.meshgrid(np.arange(feat, dtype=np.float32), np.arange(feat, dtype=np.float32), indexing="ij") + cx, cy, ones = (xx + 0.5) / feat, (yy + 0.5) / feat, np.ones_like(xx) + return np.stack([cx, cy, ones, ones], axis=-1).reshape(_BF_FR_NUM_ANCHORS, 4) + + +def _decode_blazeface_full_range(regressors: np.ndarray, classificators: np.ndarray, + score_thresh: float = _BF_MIN_SCORE) -> np.ndarray: + """Same decode as short-range with 2304-anchor grid and box_scale=192.""" + scores = expit(np.clip(classificators[:, 0], -_BF_FR_SCORE_CLIP, _BF_FR_SCORE_CLIP)) + keep = scores >= score_thresh + if not keep.any(): + return np.empty((0, 17), dtype=np.float32) + r = regressors[keep] / _BF_FR_BOX_SCALE + a = _blazeface_full_range_anchors()[keep] + cxs, cys, aws, ahs = a[:, 0:1], a[:, 1:2], a[:, 2:3], a[:, 3:4] + xc, yc = r[:, 0:1] * aws + cxs, r[:, 1:2] * ahs + cys + w, h = r[:, 2:3] * aws, r[:, 3:4] * ahs + out = np.empty((r.shape[0], 17), dtype=np.float32) + out[:, 0:1], out[:, 1:2], out[:, 2:3], out[:, 3:4] = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2 + out[:, 4:16:2] = r[:, _BF_KP_OFFSET::2] * aws + cxs + out[:, 5:16:2] = r[:, _BF_KP_OFFSET + 1::2] * ahs + cys + out[:, 16] = scores[keep] + return out + + +# FaceMesh (face_landmarks_detector.tflite): PReLU variant of BlazeBlock, +# 17 blocks, heads for 478x3 landmarks + presence. +_FACEMESH_BLOCKS = [ # (in_ch, out_ch, stride) + (16, 16, 1), (16, 16, 1), (16, 32, 2), (32, 32, 1), (32, 32, 1), (32, 64, 2), + (64, 64, 1), (64, 64, 1), (64, 128, 2), (128, 128, 1), (128, 128, 1), (128, 128, 2), + (128, 128, 1), (128, 128, 1), (128, 128, 2), (128, 128, 1), (128, 128, 1), +] + + +class FaceMeshBlock(nn.Module): + """PReLU BlazeBlock: PReLU between DW and PW, and after the residual add.""" + + def __init__(self, in_ch: int, out_ch: int, stride: int, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.in_ch, self.out_ch, self.stride = in_ch, out_ch, stride + self.depthwise = ops.Conv2d(in_ch, in_ch, 3, stride=stride, padding=0, groups=in_ch, bias=True, **kw) + self.prelu_dwise = nn.PReLU(num_parameters=in_ch, **kw) + self.pointwise = ops.Conv2d(in_ch, out_ch, 1, padding=0, bias=True, **kw) + self.prelu_out = nn.PReLU(num_parameters=out_ch, **kw) + + def forward(self, x: Tensor) -> Tensor: + residual = F.max_pool2d(x, 2, 2) if self.stride > 1 else x + if self.out_ch > self.in_ch: + residual = F.pad(residual, (0, 0, 0, 0, 0, self.out_ch - self.in_ch)) + x = _tf_same_pad(x, 3, self.stride) if self.stride > 1 else F.pad(x, (1, 1, 1, 1)) + return self.prelu_out(self.pointwise(self.prelu_dwise(self.depthwise(x))) + residual) + + +class FaceMesh(nn.Module): + NUM_LANDMARKS = 478 + + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.stem = ops.Conv2d(3, 16, 3, stride=2, padding=0, bias=True, **kw) + self.prelu_stem = nn.PReLU(num_parameters=16, **kw) + self.blocks = nn.ModuleList(FaceMeshBlock(i, o, s, device=device, dtype=dtype, operations=operations) + for (i, o, s) in _FACEMESH_BLOCKS) + self.head_reduce = ops.Conv2d(128, 8, 1, padding=0, bias=True, **kw) + self.prelu_head_reduce = nn.PReLU(num_parameters=8, **kw) + self.head_block = FaceMeshBlock(8, 8, 1, device=device, dtype=dtype, operations=operations) + self.head_presence = ops.Conv2d(8, 1, 3, padding=0, bias=True, **kw) + self.head_landmarks = ops.Conv2d(8, self.NUM_LANDMARKS * 3, 3, padding=0, bias=True, **kw) + + def forward(self, face_chw_normalized: Tensor) -> tuple[Tensor, Tensor]: + """(B, 3, 192, 192) in [0, 1] → ((B, 478, 3) landmarks in 192-canonical, (B,) presence).""" + x = self.prelu_stem(self.stem(_tf_same_pad(face_chw_normalized, 3, 2))) + for blk in self.blocks: + x = blk(x) + x = self.prelu_head_reduce(self.head_reduce(x)) + x = self.head_block(x) + B = x.shape[0] + presence = self.head_presence(x).reshape(B) + lmks = self.head_landmarks(x).reshape(B, self.NUM_LANDMARKS, 3) + return lmks, presence + + +# FaceBlendshapes (MLP-Mixer "GhumMarkerPoserMlpMixerGeneral"): +# 146x2 → token-reduce 146→96 → embed 2→64 → +cls token → 4x mixer → cls→52. +_BS_NUM_INPUT_LANDMARKS = 146 +_BS_NUM_TOKENS_REDUCED = 96 +_BS_NUM_TOKENS = 97 # +1 cls +_BS_TOKEN_DIM = 64 +_BS_TOKEN_MIX_HIDDEN = 384 +_BS_CHANNEL_MIX_HIDDEN = 256 +_BS_NUM_BLENDSHAPES = 52 +_BS_LN_EPS = 1e-6 + + +class MlpMixerBlock(nn.Module): + """MLP-Mixer block: token-mixing MLP (over tokens) → channel-mixing MLP (over dim). + Both pre-LN, both residual. LN has no beta (bias=False) to match MP.""" + + def __init__(self, num_tokens: int, token_dim: int, token_hidden: int, channel_hidden: int, + device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + # bias=False → no LN beta (matches MP). + self.ln1 = ops.LayerNorm(token_dim, eps=_BS_LN_EPS, bias=False, **kw) + self.ln2 = ops.LayerNorm(token_dim, eps=_BS_LN_EPS, bias=False, **kw) + self.token_mlp1 = ops.Linear(num_tokens, token_hidden, bias=True, **kw) + self.token_mlp2 = ops.Linear(token_hidden, num_tokens, bias=True, **kw) + self.channel_mlp1 = ops.Linear(token_dim, channel_hidden, bias=True, **kw) + self.channel_mlp2 = ops.Linear(channel_hidden, token_dim, bias=True, **kw) + + def forward(self, x: Tensor) -> Tensor: + y = self.ln1(x).transpose(1, 2) + x = x + self.token_mlp2(F.relu(self.token_mlp1(y))).transpose(1, 2) + return x + self.channel_mlp2(F.relu(self.channel_mlp1(self.ln2(x)))) + + +class FaceBlendshapes(nn.Module): + def __init__(self, device=None, dtype=None, operations=None): + super().__init__() + ops = operations if operations is not None else nn + kw = dict(device=device, dtype=dtype) + self.token_reduce = ops.Linear(_BS_NUM_INPUT_LANDMARKS, _BS_NUM_TOKENS_REDUCED, bias=True, **kw) + self.token_embed = ops.Linear(2, _BS_TOKEN_DIM, bias=True, **kw) + self.cls_token = nn.Parameter(torch.zeros(1, 1, _BS_TOKEN_DIM, **kw)) + self.blocks = nn.ModuleList( + MlpMixerBlock(_BS_NUM_TOKENS, _BS_TOKEN_DIM, _BS_TOKEN_MIX_HIDDEN, _BS_CHANNEL_MIX_HIDDEN, + device=device, dtype=dtype, operations=operations) for _ in range(4) + ) + self.head = ops.Linear(_BS_TOKEN_DIM, _BS_NUM_BLENDSHAPES, bias=True, **kw) + + @staticmethod + def _input_normalize(landmarks_2d: Tensor) -> Tensor: + # Centroid-subtract → L2 scale → x0.5. The 0.5 is baked into training. + centroid = landmarks_2d.mean(dim=1, keepdim=True) + x = landmarks_2d - centroid + mag = torch.sqrt((x * x).sum(dim=-1, keepdim=True)) + scale = mag.mean(dim=1, keepdim=True) + return (x / scale.clamp(min=1e-12)) * 0.5 + + def forward(self, landmarks_2d: Tensor) -> Tensor: + """(B, 146, 2) → (B, 52) in [0, 1]. Input units don't matter (centroid + L2 normalize).""" + x = self._input_normalize(landmarks_2d) + x = self.token_reduce(x.transpose(1, 2)).transpose(1, 2) + x = self.token_embed(x) + cls = self.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat([cls, x], dim=1) + for blk in self.blocks: + x = blk(x) + return torch.sigmoid(self.head(x[:, 0])) + + +@lru_cache(maxsize=1) +def _blazeface_anchors() -> np.ndarray: + """896 anchors per SsdAnchorsCalculator (fixed_anchor_size → anchor_w=anchor_h=1).""" + per_ar = len(_BF_ASPECT_RATIOS) + (1 if _BF_INTERP_SCALE_AR > 0 else 0) + layer_anchors: List[np.ndarray] = [] + layer = 0 + while layer < _BF_NUM_LAYERS: + stride = _BF_STRIDES[layer] + last = layer + while last < _BF_NUM_LAYERS and _BF_STRIDES[last] == stride: + last += 1 + per_cell = per_ar * (last - layer) + feat = (_BF_INPUT_SIZE + stride - 1) // stride + yy, xx = np.meshgrid(np.arange(feat, dtype=np.float32), np.arange(feat, dtype=np.float32), indexing="ij") + cx, cy, ones = (xx + _BF_ANCHOR_OFFSET_X) / feat, (yy + _BF_ANCHOR_OFFSET_Y) / feat, np.ones_like(xx) + cell = np.stack([cx, cy, ones, ones], axis=-1).reshape(-1, 4) + layer_anchors.append(np.repeat(cell, per_cell, axis=0)) + layer = last + out = np.concatenate(layer_anchors, axis=0) + assert out.shape == (896, 4), out.shape + return out + + +def _decode_blazeface(regressors: np.ndarray, classificators: np.ndarray, + score_thresh: float = _BF_MIN_SCORE) -> np.ndarray: + """Decode (regs (896,16), cls (896,1)) → (N, 17) = [xyxy, kp0x..kp5y, score] in [0, 1].""" + scores = expit(np.clip(classificators[:, 0], -_BF_SCORE_CLIP, _BF_SCORE_CLIP)) + keep = scores >= score_thresh + if not keep.any(): + return np.empty((0, 17), dtype=np.float32) + r = regressors[keep] / _BF_BOX_SCALE + a = _blazeface_anchors()[keep] # (N, 4) cx, cy, 1, 1 + cxs, cys, aws, ahs = a[:, 0:1], a[:, 1:2], a[:, 2:3], a[:, 3:4] + xc, yc = r[:, 0:1] * aws + cxs, r[:, 1:2] * ahs + cys + w, h = r[:, 2:3] * aws, r[:, 3:4] * ahs + out = np.empty((r.shape[0], 17), dtype=np.float32) + out[:, 0:1], out[:, 1:2], out[:, 2:3], out[:, 3:4] = xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2 + out[:, 4:16:2] = r[:, _BF_KP_OFFSET::2] * aws + cxs + out[:, 5:16:2] = r[:, _BF_KP_OFFSET + 1::2] * ahs + cys + out[:, 16] = scores[keep] + return out + + +def _weighted_nms(detections: np.ndarray, iou_thresh: float = 0.5) -> np.ndarray: + """MP weighted NMS — kept boxes are score-weighted averages of overlapping detections.""" + if detections.shape[0] == 0: + return detections + dets = detections[np.argsort(-detections[:, 16])] + N = dets.shape[0] + areas = np.clip(dets[:, 2] - dets[:, 0], 0, None) * np.clip(dets[:, 3] - dets[:, 1], 0, None) + kept: List[np.ndarray] = [] + used = np.zeros(N, dtype=bool) + for i in range(N): + if used[i]: + continue + ax1, ay1, ax2, ay2 = dets[i, 0:4] + merge_idx = [i] + for j in range(i + 1, N): + if used[j]: + continue + bx1, by1, bx2, by2 = dets[j, 0:4] + iw = max(0.0, min(ax2, bx2) - max(ax1, bx1)) + ih = max(0.0, min(ay2, by2) - max(ay1, by1)) + inter = iw * ih + union = areas[i] + areas[j] - inter + if union > 0 and inter / union > iou_thresh: # strict > matches MP + merge_idx.append(j) + used[j] = True + used[i] = True + cluster = dets[merge_idx] + ws = cluster[:, 16:17] + ws_sum = ws.sum() + merged = np.copy(cluster[0]) + if ws_sum > 0: + merged[:16] = (cluster[:, :16] * ws).sum(axis=0) / ws_sum + kept.append(merged) + return np.stack(kept, axis=0) if kept else np.empty((0, 17), dtype=np.float32) + + +def _detection_to_face_rect(detection: np.ndarray, image_w: int, image_h: int) -> Tuple[float, float, float, float, float]: + """Detection (normalized) → rotated 1.5xbbox ROI in image pixels (anisotropic).""" + xmin, ymin, xmax, ymax = detection[0:4] + lx = detection[4 + _FACE_LEFT_EYE_KP * 2 + 0] * image_w + ly = detection[4 + _FACE_LEFT_EYE_KP * 2 + 1] * image_h + rx = detection[4 + _FACE_RIGHT_EYE_KP * 2 + 0] * image_w + ry = detection[4 + _FACE_RIGHT_EYE_KP * 2 + 1] * image_h + # Image-y-down convention: angle = target - atan2(-dy, dx). + angle = _FACE_ROI_TARGET_ANGLE - math.atan2(ly - ry, rx - lx) + return (float((xmin + xmax) * 0.5 * image_w), + float((ymin + ymax) * 0.5 * image_h), + float((xmax - xmin) * image_w * _FACE_ROI_SCALE_X), + float((ymax - ymin) * image_h * _FACE_ROI_SCALE_Y), + float(angle)) + + +def _sample_warp(image_chw: Tensor, src_x: Tensor, src_y: Tensor, padding_mode: str) -> Tensor: + """Bilinear-sample image_chw at corner-aligned (src_x, src_y).""" + H, W = int(image_chw.shape[-2]), int(image_chw.shape[-1]) + grid = torch.stack([(2.0 * src_x + 1.0) / W - 1.0, + (2.0 * src_y + 1.0) / H - 1.0], dim=-1).unsqueeze(0) + return F.grid_sample(image_chw.unsqueeze(0), grid, mode="bilinear", + align_corners=False, padding_mode=padding_mode).squeeze(0) + + +def _warp_face_crop(image_chw: Tensor, cx: float, cy: float, width: float, height: float, + angle: float, output_size: int = _FM_INPUT_SIZE) -> Tensor: + """Rotated rect → output_size² with BORDER_REPLICATE. image_chw must be in [0, 1].""" + s_x, s_y = width / output_size, height / output_size + cos_a, sin_a = math.cos(angle), math.sin(angle) + arange = torch.arange(output_size, dtype=image_chw.dtype, device=image_chw.device) - output_size * 0.5 + v_grid, u_grid = torch.meshgrid(arange, arange, indexing="ij") + src_x = cx + u_grid * s_x * cos_a - v_grid * s_y * sin_a + src_y = cy + u_grid * s_x * sin_a + v_grid * s_y * cos_a + return _sample_warp(image_chw, src_x, src_y, "border") + + +def _blazeface_input_warp(image_chw_raw: Tensor, target: int = _BF_INPUT_SIZE) -> Tuple[Tensor, float, float, float]: + """Centered max(W,H) square → target² with BORDER_ZERO + [-1, 1] norm. + + Sub-pixel grid_sample matters; integer-pad-then-resize drifts the bbox ~5%. + Returns (warped, sub_rect_cx, sub_rect_cy, sub_rect_size) — the triplet maps + tensor-normalized [0,1] detections back to image pixels. + """ + H, W = int(image_chw_raw.shape[1]), int(image_chw_raw.shape[2]) + sub_rect_size = float(max(W, H)) + sub_rect_cx, sub_rect_cy = W * 0.5, H * 0.5 + s = sub_rect_size / target + arange = torch.arange(target, dtype=image_chw_raw.dtype, device=image_chw_raw.device) - target * 0.5 + v_grid, u_grid = torch.meshgrid(arange, arange, indexing="ij") + out = _sample_warp(image_chw_raw, sub_rect_cx + u_grid * s, sub_rect_cy + v_grid * s, "zeros") + return (out / 127.5) - 1.0, sub_rect_cx, sub_rect_cy, sub_rect_size + + +class FaceLandmarker(nn.Module): + """BlazeFace → FaceMesh v2 → blendshapes. `detector_variant` selects 'short' + (128², ≤2m) or 'full' (192² FPN, ≤5m). State dict uses inner-module prefixes + `detector.*` / `mesh.*` / `blendshapes.*`; the outer FaceLandmarkerModel + wrapper rewrites `detector_{variant}.*` keys to `detector.*` before loading. + """ + + def __init__(self, device=None, dtype=None, operations=None, detector_variant: str = "short"): + super().__init__() + det_cls = {"short": BlazeFace, "full": BlazeFaceFullRange}.get(detector_variant) + + self.detector_variant = detector_variant + self.detector = det_cls(device=device, dtype=dtype, operations=operations) + self.mesh = FaceMesh(device=device, dtype=dtype, operations=operations) + self.blendshapes = FaceBlendshapes(device=device, dtype=dtype, operations=operations) + self.register_buffer("_bs_idx", torch.tensor(_BS_INPUT_INDICES, dtype=torch.long), persistent=False) + + def run_detector_batch(self, images_rgb_uint8: List[np.ndarray], + score_thresh: float = _BF_MIN_SCORE, + iou_thresh: float = 0.5): + """Batched detector pass. Returns (img_raws, sub_rects, sizes, per_frame_decoded) + where per_frame_decoded[b] is (N, 17) in tensor-normalized [0,1] coords.""" + if not images_rgb_uint8: + return [], [], [], [] + device, dtype = self.detector.stem.weight.device, self.detector.stem.weight.dtype + det_input_size, decode_fn = ((_BF_FR_INPUT_SIZE, _decode_blazeface_full_range) + if self.detector_variant == "full" + else (_BF_INPUT_SIZE, _decode_blazeface)) + + # Same-size frames: stack once and transfer once. Variable size falls back + # to per-image (only triggers for SAM3DBody's head crops). + sizes = [tuple(img.shape[:2]) for img in images_rgb_uint8] + if len(set(sizes)) == 1: + batch_chw = torch.from_numpy(np.stack(images_rgb_uint8, axis=0)).to(device, dtype).movedim(-1, -3).contiguous() + img_raws = [batch_chw[bi] for bi in range(batch_chw.shape[0])] + else: + img_raws = [torch.from_numpy(img).to(device, dtype).movedim(-1, -3).contiguous() for img in images_rgb_uint8] + + warps = [_blazeface_input_warp(img_raw, det_input_size) for img_raw in img_raws] + det_crops = [w[0] for w in warps] + sub_rects = [(w[1], w[2], w[3]) for w in warps] + + regs_b, cls_b = self.detector(torch.stack(det_crops, dim=0)) + regs_np, cls_np = regs_b.float().cpu().numpy(), cls_b.float().cpu().numpy() + per_frame = [] + for b in range(len(images_rgb_uint8)): + decoded = decode_fn(regs_np[b], cls_np[b], score_thresh=score_thresh) + per_frame.append(_weighted_nms(decoded, iou_thresh=iou_thresh) if decoded.shape[0] > 0 else decoded) + return img_raws, sub_rects, sizes, per_frame + + def detect_batch(self, images_rgb_uint8: List[np.ndarray], num_faces: int = 1, + score_thresh: float = _BF_MIN_SCORE) -> List[List[dict]]: + """Full pipeline batched across `images_rgb_uint8`. Returns one face-dict + list per image (empty if nothing detected). Face dict: + bbox_xyxy (4,) image pixels, blendshapes {52} ∈ [0,1], + landmarks_xy (478, 2) image pixels, landmarks_3d (478, 3) in + 192-canonical (pre-transformation) units, presence float (raw logit). + """ + img_raws, sub_rects, sizes, per_frame_dets = self.run_detector_batch( + images_rgb_uint8, score_thresh=score_thresh, + ) + # tensor-normalized → image-normalized [0,1] for _detection_to_face_rect. + for b, decoded in enumerate(per_frame_dets): + if decoded.shape[0] == 0: + continue + cx, cy, size = sub_rects[b] + H, W = sizes[b] + sx0, sy0 = cx - size * 0.5, cy - size * 0.5 + decoded[:, 0:16:2] = (sx0 + size * decoded[:, 0:16:2]) / W + decoded[:, 1:16:2] = (sy0 + size * decoded[:, 1:16:2]) / H + if num_faces > 0: + per_frame_dets[b] = decoded[: int(num_faces)] + + # Collect every detected face across all frames into one mesh input. + face_params: List[Tuple[int, float, float, float, float, float, float]] = [] + mesh_crops: List[Tensor] = [] + for b, dets in enumerate(per_frame_dets): + if dets.shape[0] == 0: + continue + H, W = sizes[b] + img_for_mesh = img_raws[b] / 255.0 + for det in dets: + cx, cy, w, h, angle = _detection_to_face_rect(det, W, H) + mesh_crops.append(_warp_face_crop(img_for_mesh, cx, cy, w, h, angle, _FM_INPUT_SIZE)) + face_params.append((b, float(det[16]), cx, cy, w, h, angle)) + + results: List[List[dict]] = [[] for _ in range(len(images_rgb_uint8))] + if not mesh_crops: + return results + + lmks_canon_b, presence_b = self.mesh(torch.stack(mesh_crops, dim=0)) + bs_out_b = self.blendshapes(lmks_canon_b[:, self._bs_idx, :2]) + + # Batched canonical→image affine + params_t = torch.tensor( + [(cx, cy, w, h, math.cos(a), math.sin(a)) for (_b, _s, cx, cy, w, h, a) in face_params], + device=lmks_canon_b.device, dtype=lmks_canon_b.dtype, + ) + cxs, cys, ws, hs, cos_a, sin_a = params_t.unbind(dim=1) + inv = 1.0 / _FM_INPUT_SIZE + u = lmks_canon_b[..., 0] - _FM_INPUT_SIZE * 0.5 + v = lmks_canon_b[..., 1] - _FM_INPUT_SIZE * 0.5 + lmks_xy_t = torch.stack([ + cxs[:, None] + u * (ws * inv * cos_a)[:, None] - v * (hs * inv * sin_a)[:, None], + cys[:, None] + u * (ws * inv * sin_a)[:, None] + v * (hs * inv * cos_a)[:, None], + ], dim=-1) + + lmks_xy_np = lmks_xy_t.float().cpu().numpy() + lmks_canon_np = lmks_canon_b.float().cpu().numpy() + presence_np = presence_b.float().cpu().numpy() + bs_np = bs_out_b.float().cpu().numpy() + + for i, (b, score, *_) in enumerate(face_params): + lmks_xy = lmks_xy_np[i] + mn, mx = lmks_xy.min(0), lmks_xy.max(0) + results[b].append({ + "bbox_xyxy": np.array([mn[0], mn[1], mx[0], mx[1]], dtype=np.float32), + "blendshapes": dict(zip(BLENDSHAPE_NAMES, bs_np[i].tolist())), + "landmarks_xy": lmks_xy, + "landmarks_3d": lmks_canon_np[i], + "presence": float(presence_np[i]), + "score": score, + }) + return results diff --git a/comfy_extras/nodes_mediapipe.py b/comfy_extras/nodes_mediapipe.py new file mode 100644 index 000000000..2e67ae83f --- /dev/null +++ b/comfy_extras/nodes_mediapipe.py @@ -0,0 +1,502 @@ +"""ComfyUI nodes for the pure-PyTorch MediaPipe Face Landmarker port. + +Custom IO types: + FACE_LANDMARKER — FaceLandmarkerModel wrapper (ModelPatcher inside) + FACE_LANDMARKS — {"frames": List[List[face_dict]], "image_size": (H, W), + "connection_sets": dict[str, frozenset[(int, int)]]} + face_dict: bbox_xyxy, blendshapes, landmarks_xy, + landmarks_3d, presence, score, transformation_matrix + +MediaPipeFaceLandmarker also emits the core BOUNDING_BOX type — pair with DrawBBoxes. +""" + +from __future__ import annotations + +import numpy as np +import torch +from PIL import Image, ImageColor, ImageDraw +from tqdm.auto import tqdm +from typing_extensions import override + +import comfy.model_management +import comfy.model_patcher +import comfy.utils +import folder_paths +from comfy_api.latest import ComfyExtension, io + +from comfy_extras.mediapipe.face_landmarker import FaceLandmarker +from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection + + +FaceLandmarkerType = io.Custom("FACE_LANDMARKER") +FaceLandmarksType = io.Custom("FACE_LANDMARKS") + +_CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights") +_CONTOUR_PARTS = ("face_oval", "left_eye", "right_eye", "left_eyebrow", "right_eyebrow", "lips") + + +class FaceLandmarkerModel: + """Loaded FaceLandmarker variants + ModelPatcher per variant. + + Safetensors layout: `detector_short.*` / `detector_full.*` plus shared + `mesh.*`, `blendshapes.*`, `canonical_*`, and `topology.*`. + PReLU forces plain-nn / fp32 (manual_cast strands buffers across devices). + """ + + def __init__(self, state_dict: dict): + self.load_device = comfy.model_management.text_encoder_device() + offload_device = comfy.model_management.text_encoder_offload_device() + self.dtype = torch.float32 + + # FACEMESH_* connection sets, embedded as int32 (N, 2) under topology.*. + base: dict[str, frozenset] = {} + for k in [k for k in state_dict if k.startswith("topology.")]: + base[k[len("topology."):]] = frozenset(map(tuple, state_dict.pop(k).tolist())) + base["contours"] = frozenset().union(*(base[p] for p in _CONTOUR_PARTS)) + base["all"] = base["contours"] | base["irises"] | base["nose"] + + self.connection_sets: dict[str, frozenset] = base + self.canonical_data: dict[str, np.ndarray] = {k: state_dict.pop(k).numpy() for k in _CANONICAL_KEYS} + + shared = {k: v for k, v in state_dict.items() if k.startswith(("mesh.", "blendshapes."))} + + self.models: dict[str, FaceLandmarker] = {} + self.patchers: dict[str, comfy.model_patcher.ModelPatcher] = {} + for variant in ("short", "full"): + prefix = f"detector_{variant}." + sub = dict(shared) + sub.update({f"detector.{k[len(prefix):]}": v for k, v in state_dict.items() if k.startswith(prefix)}) + fl = FaceLandmarker(device=offload_device, dtype=self.dtype, operations=None, detector_variant=variant).eval() + fl.load_state_dict(sub, strict=False) + + self.models[variant] = fl + self.patchers[variant] = comfy.model_patcher.CoreModelPatcher( + fl, load_device=self.load_device, offload_device=offload_device, + size=comfy.model_management.module_size(fl), + ) + + def detect_batch(self, images, num_faces: int, score_thresh: float, variant: str): + comfy.model_management.load_model_gpu(self.patchers[variant]) + return self.models[variant].detect_batch(images, num_faces=num_faces, score_thresh=score_thresh) + + +def _image_to_uint8(image: torch.Tensor) -> np.ndarray: + return image[..., :3].mul(255.0).add_(0.5).clamp_(0, 255).to(torch.uint8).cpu().numpy() + + +def _parse_color(color: str) -> tuple[int, int, int]: + try: + return ImageColor.getrgb(color)[:3] + except ValueError: + return (0, 255, 0) + + +def _copy_face(face: dict) -> dict: + """Shallow copy of a face_dict with array-fields cloned so callers can mutate.""" + return { + "bbox_xyxy": face["bbox_xyxy"].copy(), + "blendshapes": dict(face["blendshapes"]), + "landmarks_xy": face["landmarks_xy"].copy(), + "landmarks_3d": face["landmarks_3d"].copy(), + "presence": face["presence"], + "score": face["score"], + } + + +def _lerp_face(a: dict, b: dict, t: float) -> dict: + return { + "bbox_xyxy": (1 - t) * a["bbox_xyxy"] + t * b["bbox_xyxy"], + "blendshapes": {k: (1 - t) * a["blendshapes"][k] + t * b["blendshapes"][k] for k in a["blendshapes"]}, + "landmarks_xy": (1 - t) * a["landmarks_xy"] + t * b["landmarks_xy"], + "landmarks_3d": (1 - t) * a["landmarks_3d"] + t * b["landmarks_3d"], + "presence": (1 - t) * a["presence"] + t * b["presence"], + "score": (1 - t) * a["score"] + t * b["score"], + } + + +def _match_faces(a: list[dict], b: list[dict]) -> list[tuple[int, int]]: + """Greedy nearest-neighbour pairing of faces between two frames by bbox + centre distance. Unmatched (when counts differ) are dropped.""" + if not a or not b: + return [] + centers_a = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]), + 0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in a]) + centers_b = np.array([(0.5 * (f["bbox_xyxy"][0] + f["bbox_xyxy"][2]), + 0.5 * (f["bbox_xyxy"][1] + f["bbox_xyxy"][3])) for f in b]) + dists = np.linalg.norm(centers_a[:, None] - centers_b[None], axis=-1) + pairs: list[tuple[int, int]] = [] + used_a: set[int] = set() + used_b: set[int] = set() + candidates = sorted((dists[ia, ib], ia, ib) for ia in range(len(a)) for ib in range(len(b))) + for _, ia, ib in candidates: + if ia in used_a or ib in used_b: + continue + pairs.append((ia, ib)) + used_a.add(ia) + used_b.add(ib) + return pairs + + +def _fill_missing_frames(frames: list[list[dict]], mode: str) -> None: + """In-place fill empty frame slots from neighbouring detections. Multi-face + aware: pairs faces across bracketing frames by greedy bbox-centre NN. + When counts differ, unmatched faces are dropped from the synthesised frame.""" + if mode == "empty": + return + valid = [i for i, fr in enumerate(frames) if fr] + if not valid: + return # nothing to fill from + if mode == "previous": + last: list[dict] = [] + for i, fr in enumerate(frames): + if fr: + last = fr + elif last: + frames[i] = [_copy_face(f) for f in last] + return + # interpolate: lerp between bracketing valid frames; clamp at ends. + for i in range(len(frames)): + if frames[i]: + continue + prev_i = max((v for v in valid if v < i), default=None) + next_i = min((v for v in valid if v > i), default=None) + if prev_i is None: + frames[i] = [_copy_face(f) for f in frames[next_i]] + elif next_i is None: + frames[i] = [_copy_face(f) for f in frames[prev_i]] + else: + t = (i - prev_i) / (next_i - prev_i) + pairs = _match_faces(frames[prev_i], frames[next_i]) + frames[i] = [_lerp_face(frames[prev_i][a], frames[next_i][b], t) for a, b in pairs] + + +def _ordered_rings(edges: frozenset[tuple[int, int]]) -> list[list[int]]: + """Walk an unordered edge set into one or more closed-loop vertex rings + (handles multi-loop sets like FACEMESH_LIPS: outer + inner).""" + adj: dict[int, set[int]] = {} + for a, b in edges: + adj.setdefault(a, set()).add(b) + adj.setdefault(b, set()).add(a) + visited: set[int] = set() + rings: list[list[int]] = [] + for start in adj: + if start in visited: + continue + ring = [start] + visited.add(start) + prev, cur = -1, start + while True: + nxt = next((v for v in adj[cur] if v != prev), None) + if nxt is None or nxt == start: + break + ring.append(nxt) + visited.add(nxt) + prev, cur = cur, nxt + rings.append(ring) + return rings + + +class LoadMediaPipeFaceLandmarker(io.ComfyNode): + """Load MediaPipe Face Landmarker v2 weights. Contains both detector variants + (short / full), shared mesh, blendshapes, and canonical geometry.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadMediaPipeFaceLandmarker", + display_name="Load MediaPipe Face Landmarker", + category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("mediapipe"), + tooltip="Face Landmarker safetensors from models/mediapipe/."), + ], + outputs=[FaceLandmarkerType.Output()], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("mediapipe", model_name), safe_load=True) + wrapper = FaceLandmarkerModel(sd) + return io.NodeOutput(wrapper) + + +# Per-frame fallback modes for detection failures in a batch. +_FALLBACK_MODES = ("empty", "previous", "interpolate") + + +class MediaPipeFaceLandmarker(io.ComfyNode): + """BlazeFace → FaceMesh v2 → ARKit-52 blendshapes, batched across the + input. Also emits a BOUNDING_BOX list (landmark-extent bbox per face) — + pair with DrawBBoxes for detector-only viz or MediaPipeFaceMeshVisualize + for the mesh overlay.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MediaPipeFaceLandmarker", + display_name="MediaPipe Face Landmarker", + category="image/detection", + inputs=[ + FaceLandmarkerType.Input("face_landmarker"), + io.Image.Input("image"), + io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short", + tooltip="Face detector range. 'short' is tuned for close-up faces " + "(within ~2 m of the camera); 'full' covers farther / smaller " + "faces (up to ~5 m) but is slower. 'both' runs both detectors and " + "keeps whichever found more faces per frame (~2× detection cost)."), + io.Int.Input("num_faces", default=1, min=0, max=16, step=1, + tooltip="Maximum faces to return per frame. 0 = no cap (return all detected)."), + io.Float.Input("min_confidence", default=0.5, min=0.0, max=1.0, step=0.01, advanced=True, + tooltip="BlazeFace score threshold. Lower to catch small/occluded faces."), + io.Combo.Input("missing_frame_fallback", options=list(_FALLBACK_MODES), default="empty", advanced=True, + tooltip="Per-frame behaviour when detection fails in a batch. " + "'empty' leaves the frame faceless. 'previous' copies the most recent successful " + "detection. 'interpolate' lerps landmarks/bbox/blendshapes between bracketing " + "successful frames. Multi-face: pairs faces across frames by greedy bbox-centre NN."), + ], + outputs=[ + FaceLandmarksType.Output(display_name="face_landmarks"), + io.BoundingBox.Output("bboxes"), + ], + ) + + @classmethod + def execute(cls, face_landmarker, image, detector_variant, num_faces, min_confidence, + missing_frame_fallback) -> io.NodeOutput: + canonical = face_landmarker.canonical_data + img_np = _image_to_uint8(image) + B, H, W = img_np.shape[:3] + chunk = 16 + is_both = detector_variant == "both" + total_work = 2 * B if is_both else B + pbar = comfy.utils.ProgressBar(total_work) + + def _run(variant: str) -> list[list[dict]]: + res: list[list[dict]] = [] + with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq: + for i in range(0, B, chunk): + end = min(i + chunk, B) + res.extend(face_landmarker.detect_batch( + [img_np[bi] for bi in range(i, end)], + num_faces=int(num_faces), + score_thresh=float(min_confidence), + variant=variant, + )) + pbar.update_absolute(min(pbar.current + (end - i), total_work)) + tq.update(end - i) + return res + + if is_both: + short_res = _run("short") + full_res = _run("full") + # Per-frame keep whichever found more faces (tie → short). + frames: list[list[dict]] = [ + short_res[bi] if len(short_res[bi]) >= len(full_res[bi]) else full_res[bi] + for bi in range(B) + ] + else: + frames = _run(detector_variant) + _fill_missing_frames(frames, missing_frame_fallback) + bboxes = [] + for per_frame in frames: + per_bb = [] + for f in per_frame: + f["transformation_matrix"] = transformation_matrix_from_detection(f, W, H, canonical) + x1, y1, x2, y2 = (float(v) for v in f["bbox_xyxy"]) + per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])}) + bboxes.append(per_bb) + return io.NodeOutput({"frames": frames, "image_size": (H, W), + "connection_sets": face_landmarker.connection_sets}, bboxes) + + +# Topology keys unioned by the 'all' connections preset (contour parts + irises + nose). +_ALL_CONNECTION_PARTS: tuple[str, ...] = (*_CONTOUR_PARTS, "irises", "nose") +_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = ( + ("face_oval", True), + ("lips", True), + ("left_eye", True), + ("right_eye", True), + ("left_eyebrow", True), + ("right_eyebrow", True), + ("irises", True), + ("nose", True), + ("tesselation", False), +) + + +class MediaPipeFaceMeshVisualize(io.ComfyNode): + """Draw a FACEMESH_* subset over an image. Topology travels with the + FACE_LANDMARKS payload (set at detection time).""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MediaPipeFaceMeshVisualize", + display_name="MediaPipe Face Mesh Visualize", + category="image/detection", + inputs=[ + FaceLandmarksType.Input("face_landmarks"), + io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."), + io.DynamicCombo.Input( + "connections", + tooltip="'all' = oval+eyes+brows+lips+irises+nose. 'fill' = solid face_oval polygon (silhouette mask). 'custom' = toggle each feature individually (including 'tesselation', the full 2547-edge wireframe).", + options=[ + io.DynamicCombo.Option("all", []), + io.DynamicCombo.Option("fill", []), + io.DynamicCombo.Option("custom", [ + io.Boolean.Input(feat, default=default, + tooltip=f"Draw the '{feat}' connection set.") + for feat, default in _CUSTOM_FEATURES + ]), + ], + ), + io.Color.Input("color", default="#00ff00"), + io.Int.Input("thickness", default=1, min=0, max=8, step=1, + tooltip="Edge line thickness in pixels. 0 disables edge drawing."), + io.Int.Input("point_size", default=2, min=0, max=16, step=1, + tooltip="Landmark dot radius in pixels. 0 disables point drawing."), + ], + outputs=[io.Image.Output()], + ) + + @classmethod + def execute(cls, face_landmarks, connections, color, thickness, point_size, image=None) -> io.NodeOutput: + sets = face_landmarks["connection_sets"] + sel = connections["connections"] + fill_rings: list[list[int]] | None = None + if sel == "fill": + fill_rings = _ordered_rings(sets["face_oval"]) + edges = frozenset() + elif sel == "custom": + parts = [feat for feat, _ in _CUSTOM_FEATURES if connections.get(feat, False)] + edges = frozenset().union(*(sets[p] for p in parts)) + else: # "all" + edges = frozenset().union(*(sets[p] for p in _ALL_CONNECTION_PARTS)) + rgb, thick, psize = _parse_color(color), int(thickness), int(point_size) + frames = face_landmarks["frames"] + if image is None: + H, W = face_landmarks["image_size"] + img_np = np.zeros((len(frames), H, W, 3), dtype=np.uint8) + else: + img_np = _image_to_uint8(image) + B = img_np.shape[0] + n_frames = len(frames) + pbar = comfy.utils.ProgressBar(B) + out = np.empty_like(img_np) + for bi in range(B): + faces = frames[bi] if bi < n_frames else [] + out[bi] = _draw_mesh(img_np[bi], faces, edges, rgb, thick, psize, fill_rings) + pbar.update_absolute(bi + 1) + return io.NodeOutput(torch.from_numpy(out).to( + device=comfy.model_management.intermediate_device(), + dtype=comfy.model_management.intermediate_dtype(), + ).div_(255.0)) + + +def _draw_mesh(image_rgb: np.ndarray, faces: list, edges, + rgb: tuple[int, int, int], thickness: int, + point_size: int, fill_rings: list[list[int]] | None = None) -> np.ndarray: + draw_edges = thickness > 0 and edges + if not faces or (fill_rings is None and not draw_edges and point_size <= 0): + return image_rgb.copy() + pil = Image.fromarray(image_rgb) + draw = ImageDraw.Draw(pil) + r = point_size * 0.5 + if fill_rings is not None: + for f in faces: + lmks = f["landmarks_xy"] + for ring in fill_rings: + draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=rgb) + return np.asarray(pil) + for f in faces: + lmks = f["landmarks_xy"] + n = lmks.shape[0] + if draw_edges: + for a, b in edges: + if a < n and b < n: + draw.line([(float(lmks[a, 0]), float(lmks[a, 1])), + (float(lmks[b, 0]), float(lmks[b, 1]))], fill=rgb, width=thickness) + if point_size == 1: + draw.point(lmks.flatten().tolist(), fill=rgb) + elif point_size > 1: + for x, y in lmks: + draw.ellipse((float(x) - r, float(y) - r, float(x) + r, float(y) + r), fill=rgb) + return np.asarray(pil) + + +# Mask region presets — closed-loop topologies only. +_MASK_REGIONS: tuple[str, ...] = ("face_oval", "lips", "left_eye", "right_eye", "irises") +_MASK_CUSTOM_FEATURES: tuple[tuple[str, bool], ...] = ( + ("face_oval", True), + ("lips", False), + ("left_eye", False), + ("right_eye", False), + ("irises", False), +) + + +class MediaPipeFaceMask(io.ComfyNode): + """Binary mask from face landmarks, filled polygon per face. One mask per + frame in the batch; faces in the same frame composite (union).""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MediaPipeFaceMask", + display_name="MediaPipe Face Mask", + category="image/detection", + inputs=[ + FaceLandmarksType.Input("face_landmarks"), + io.DynamicCombo.Input( + "regions", + tooltip="'all' = union of face_oval+lips+eyes+irises (which collapses to face_oval since it encloses the rest). 'custom' = toggle each region individually for combos like lips+eyes.", + options=[ + io.DynamicCombo.Option("all", []), + io.DynamicCombo.Option("custom", [ + io.Boolean.Input(reg, default=default, + tooltip=f"Include the '{reg}' region in the mask.") + for reg, default in _MASK_CUSTOM_FEATURES + ]), + ], + ), + ], + outputs=[io.Mask.Output()], + ) + + @classmethod + def execute(cls, face_landmarks, regions) -> io.NodeOutput: + sets = face_landmarks["connection_sets"] + sel = regions["regions"] + if sel == "custom": + picked = [reg for reg, _ in _MASK_CUSTOM_FEATURES if regions.get(reg, False)] + else: + picked = list(_MASK_REGIONS) + rings = [r for reg in picked for r in _ordered_rings(sets[reg])] + frames = face_landmarks["frames"] + H, W = face_landmarks["image_size"] + masks = np.zeros((len(frames), H, W), dtype=np.uint8) + pbar = comfy.utils.ProgressBar(len(frames)) + for bi, per_frame in enumerate(frames): + if per_frame: + pil = Image.new("L", (W, H), 0) + draw = ImageDraw.Draw(pil) + for f in per_frame: + lmks = f["landmarks_xy"] + for ring in rings: + draw.polygon([(float(lmks[i, 0]), float(lmks[i, 1])) for i in ring], fill=255) + masks[bi] = np.asarray(pil) + pbar.update_absolute(bi + 1) + return io.NodeOutput(torch.from_numpy(masks).to( + device=comfy.model_management.intermediate_device(), + dtype=comfy.model_management.intermediate_dtype(), + ).div_(255.0)) + + +class MediaPipeFaceExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [LoadMediaPipeFaceLandmarker, MediaPipeFaceLandmarker, MediaPipeFaceMeshVisualize, MediaPipeFaceMask] + + +async def comfy_entrypoint() -> MediaPipeFaceExtension: + return MediaPipeFaceExtension() diff --git a/folder_paths.py b/folder_paths.py index ad7f0f4fc..ce152eb37 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -60,6 +60,8 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions) +folder_names_and_paths["mediapipe"] = ([os.path.join(models_dir, "mediapipe")], supported_pt_extensions) + output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") input_directory = os.path.join(base_path, "input") diff --git a/models/mediapipe/put_mediapipe_models_here b/models/mediapipe/put_mediapipe_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/nodes.py b/nodes.py index fdd6eeb5f..13e46ac8a 100644 --- a/nodes.py +++ b/nodes.py @@ -2444,6 +2444,7 @@ async def init_builtin_extra_nodes(): "nodes_hidream_o1.py", "nodes_save_3d.py", "nodes_moge.py", + "nodes_mediapipe.py", ] import_failed = [] From 5aa5ccc9e02aec94cf43e0f71d4b2f62b204b5b6 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 21 May 2026 10:03:58 +1000 Subject: [PATCH 09/45] Multi-threaded load of models from disk (big load time speedups & Offload to disk) (CORE-43,CORE-152,CORE-164,CORE-165,CORE-117) (#13802) * model_management: disable non-dynamic smart memory Disable smart memory outright for non dynamic models. This is a minor step towards deprecation of --disable-dynamic-vram and the legacy ModelPatcher. This is needed for estimate-free model development, where new models can opt-out of supplying a memory estimate and not have to worry about hard VRAM allocations due to legacy non-dynamic model patchers This is also a general stability increase for a lot of stray use cases where estimates may still be off and going forward we are not going to accurately maintain such estimates. * pinned_memory: implement with aimdo growable buffer Use a single growable buffer so we can do threaded pre-warming on pinned memory. * mm: use aimdo to do transfer from disk to pin Aimdo implements a faster threaded loader. * Add stream host pin buffer for AIMDO casts Introduce per-offload-stream HostBuffer reuse for pinned staging, include it in cast buffer reset synchronization. Defer actual casts that go via this pin path to a separate pass such that the buffer can be allocated monolithically (to avoid cudaHostRegister thrash). * remove old pin path * Implement JIT pinned memory pressure Replace the predictive pin pressure mechanism with JIT PIN memory pressure. * LowVRAMPatch: change to two-phase visit * lora: re-implement as inplace swiss-army-knife operation * prepare for multiple pin sets * implement pinned loras * requirements: comfy-aimdo 0.4.0 * ops: remove unused arg This was defeatured in aimdo iteration * ops: sync the CPU with only the offload stream activity This was syncing with the offload stream which itself is synced with the compute stream, so this was syncing CPU with compute transitively. Define the event to sync it more gently. * pins: implement freeing intermediate for pinned memory Pinning is more important than inactive intermediates and the stream pin buffer is more important than even active intermediates. * execution: implement pin eviction on RAM presure Add back proper pin freeing on RAM pressure * implement pin registration swaps Uncap the windows pins from 50% by extending the pool and have a pressure mechanism to move the pin reservations om demand. This unfortunately implies a GPU sync to do the freeing so significant hysterisis needs to be added to consolidate these pressure events. * cli_args/execution: Implement lower background cache-ram threshold Limit the amount of RAM background intermediates can use, so that switching workflows doesn't degrade performance too much. * make default * bump aimdo * model-patcher: force-cast tiny weights Flux 2 gets crazy stalls due to a mix of tiny and giant weights creating lopsided steam buffer rotations which creates stalls. * ops: refactor in prep for chunking * mm: delegate pin-on-the-way to aimdo Aimdo is able to chunk and slice this on the way for better CPU->GPU overlap. The main advantage is the ability to shorten the bus contention window between previous weight transfer and the next weights vbar fault. * bump aimdo * pinning updates * specify hostbuf max allocation size There a signs of virtual memory exhaustion on some linux systems when throwing 128GB for every little piece. Pass the actual to save aimdo from over-estimates * tests: update execution tests for caching The default caching changed to ram-cache so update these tests accordingly. Remove the LRU 0 test as this also falls through to RAM cache. --- comfy/cli_args.py | 7 +- comfy/lora.py | 19 ++- comfy/memory_management.py | 24 +++- comfy/model_management.py | 189 +++++++++++++++++----------- comfy/model_patcher.py | 138 +++++++++++++++----- comfy/ops.py | 88 +++++++++++-- comfy/pinned_memory.py | 68 ++++++---- comfy/utils.py | 2 - comfy/windows.py | 52 -------- execution.py | 12 +- main.py | 20 +-- requirements.txt | 2 +- tests/execution/test_async_nodes.py | 3 +- tests/execution/test_execution.py | 3 +- 14 files changed, 408 insertions(+), 219 deletions(-) delete mode 100644 comfy/windows.py diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 76faed3ad..9d88c8517 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -110,13 +110,11 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") -CACHE_RAM_AUTO_GB = -1.0 - cache_group = parser.add_mutually_exclusive_group() +cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metavar="GB", help="Use RAM pressure caching with the specified headroom thresholds. This is the default caching mode. The first value sets the active-cache threshold; the optional second value sets the inactive-cache/pin threshold. Defaults when no values are provided: active 25%% of system RAM (min 4GB, max 32GB), inactive 75%% of system RAM (min 12GB, max 96GB).") cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") -cache_group.add_argument("--cache-ram", nargs='?', const=CACHE_RAM_AUTO_GB, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default (when no value is provided): 25%% of system RAM (min 4GB, max 32GB).") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.") @@ -245,6 +243,9 @@ if comfy.options.args_parsing: else: args = parser.parse_args([]) +if args.cache_ram is not None and len(args.cache_ram) > 2: + parser.error("--cache-ram accepts at most two values: active GB and inactive GB") + if args.windows_standalone_build: args.auto_launch = True diff --git a/comfy/lora.py b/comfy/lora.py index f11e26ec9..c0e8b865c 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -484,16 +484,23 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori return weight -def prefetch_prepared_value(value, allocate_buffer, stream): +def prefetch_prepared_value(value, counter, destination, stream, copy): if isinstance(value, torch.Tensor): - dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value)) - comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) + size = comfy.memory_management.vram_aligned_size(value) + offset = counter[0] + counter[0] += size + if destination is None: + return value + + dest = destination[offset:offset + size] + if copy: + comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream) return comfy.memory_management.interpret_gathered_like([value], dest)[0] elif isinstance(value, weight_adapter.WeightAdapterBase): - return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream)) + return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, counter, destination, stream, copy)) elif isinstance(value, tuple): - return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value) + return tuple(prefetch_prepared_value(item, counter, destination, stream, copy) for item in value) elif isinstance(value, list): - return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value] + return [prefetch_prepared_value(item, counter, destination, stream, copy) for item in value] return value diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 48e3c11da..c43f0c4a2 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -15,7 +15,7 @@ class TensorFileSlice(NamedTuple): size: int -def read_tensor_file_slice_into(tensor, destination): +def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None): if isinstance(tensor, QuantizedTensor): if not isinstance(destination, QuantizedTensor): @@ -23,12 +23,17 @@ def read_tensor_file_slice_into(tensor, destination): if tensor._layout_cls != destination._layout_cls: return False - if not read_tensor_file_slice_into(tensor._qdata, destination._qdata): + if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream, + destination2=(destination2._qdata if destination2 is not None else None)): return False dst_orig_dtype = destination._params.orig_dtype destination._params.copy_from(tensor._params, non_blocking=False) destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype) + if destination2 is not None: + dst_orig_dtype = destination2._params.orig_dtype + destination2._params.copy_from(destination._params, non_blocking=True) + destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype) return True info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None) @@ -48,6 +53,17 @@ def read_tensor_file_slice_into(tensor, destination): if info.size == 0: return True + hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None) + if hostbuf is not None: + stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 + device_ptr = destination2.data_ptr() if destination2 is not None else 0 + hostbuf.read_file_slice(file_obj, info.offset, info.size, + offset=destination.data_ptr() - hostbuf.get_raw_address(), + stream=stream_ptr, + device_ptr=device_ptr, + device=None if destination2 is None else destination2.device.index) + return True + buf_type = ctypes.c_ubyte * info.size view = memoryview(buf_type.from_address(destination.data_ptr())) @@ -151,7 +167,7 @@ def set_ram_cache_release_state(callback, headroom): extra_ram_release_callback = callback RAM_CACHE_HEADROOM = max(0, int(headroom)) -def extra_ram_release(target): +def extra_ram_release(target, free_active=False): if extra_ram_release_callback is None: return 0 - return extra_ram_release_callback(target) + return extra_ram_release_callback(target, free_active=free_active) diff --git a/comfy/model_management.py b/comfy/model_management.py index 21738a4c7..3894dfa9c 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -31,6 +31,7 @@ from contextlib import nullcontext import comfy.memory_management import comfy.utils import comfy.quant_ops +import comfy_aimdo.host_buffer import comfy_aimdo.vram_buffer class VRAMState(Enum): @@ -495,6 +496,14 @@ except: current_loaded_models = [] +DIRTY_MMAPS = set() + +PIN_PRESSURE_HYSTERESIS = 256 * 1024 * 1024 + +#Freeing registerables on pressure does imply a GPU sync, so go big on +#the hysteresis so each expensive sync gives us back a good chunk. +REGISTERABLE_PIN_HYSTERESIS = 2048 * 1024 * 1024 + def module_size(module): module_mem = 0 sd = module.state_dict() @@ -503,27 +512,46 @@ def module_size(module): module_mem += t.nbytes return module_mem -def module_mmap_residency(module, free=False): - mmap_touched_mem = 0 - module_mem = 0 - bounced_mmaps = set() - sd = module.state_dict() - for k in sd: - t = sd[k] - module_mem += t.nbytes - storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage() - if not getattr(storage, "_comfy_tensor_mmap_touched", False): - continue - mmap_touched_mem += t.nbytes - if not free: - continue - storage._comfy_tensor_mmap_touched = False - mmap_obj = storage._comfy_tensor_mmap_refs[0] - if mmap_obj in bounced_mmaps: - continue - mmap_obj.bounce() - bounced_mmaps.add(mmap_obj) - return mmap_touched_mem, module_mem +def mark_mmap_dirty(storage): + mmap_refs = getattr(storage, "_comfy_tensor_mmap_refs", None) + if mmap_refs is not None: + DIRTY_MMAPS.add(mmap_refs[0]) + +def free_pins(size, evict_active=False): + freed_total = 0 + for loaded_model in reversed(current_loaded_models): + if size <= 0: + return freed_total + model = loaded_model.model + if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]): + freed = model.partially_unload_ram(size) + freed_total += freed + size -= freed + return freed_total + +def ensure_pin_budget(size, evict_active=False): + shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available + if shortfall <= 0: + return True + + to_free = shortfall + PIN_PRESSURE_HYSTERESIS + return free_pins(to_free, evict_active=evict_active) >= shortfall + +def ensure_pin_registerable(size, evict_active=False): + shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY + if MAX_PINNED_MEMORY <= 0: + return False + if shortfall <= 0: + return True + + shortfall += REGISTERABLE_PIN_HYSTERESIS + for loaded_model in reversed(current_loaded_models): + model = loaded_model.model + if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]): + shortfall -= model.unregister_inactive_pins(shortfall) + if shortfall <= 0: + return True + return shortfall <= REGISTERABLE_PIN_HYSTERESIS class LoadedModel: def __init__(self, model): @@ -553,9 +581,6 @@ class LoadedModel: def model_memory(self): return self.model.model_size() - def model_mmap_residency(self, free=False): - return self.model.model_mmap_residency(free=free) - def model_loaded_memory(self): return self.model.loaded_size() @@ -635,15 +660,9 @@ WINDOWS = any(platform.win32_ver()) EXTRA_RESERVED_VRAM = 400 * 1024 * 1024 if WINDOWS: - import comfy.windows EXTRA_RESERVED_VRAM = 600 * 1024 * 1024 #Windows is higher because of the shared vram issue if total_vram > (15 * 1024): # more extra reserved vram on 16GB+ cards EXTRA_RESERVED_VRAM += 100 * 1024 * 1024 - def get_free_ram(): - return comfy.windows.get_free_ram() -else: - def get_free_ram(): - return psutil.virtual_memory().available if args.reserve_vram is not None: EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024 @@ -657,7 +676,6 @@ def minimum_inference_memory(): def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0): cleanup_models_gc() - comfy.memory_management.extra_ram_release(max(pins_required, ram_required)) unloaded_model = [] can_unload = [] unloaded_models = [] @@ -673,11 +691,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins for x in can_unload_sorted: i = x[-1] memory_to_free = 1e32 - pins_to_free = 1e32 - if not DISABLE_SMART_MEMORY or device is None: + if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None): memory_to_free = 0 if device is None else memory_required - get_free_memory(device) - pins_to_free = pins_required - get_free_ram() - if current_loaded_models[i].model.is_dynamic() and for_dynamic: + if for_dynamic: #don't actually unload dynamic models for the sake of other dynamic models #as that works on-demand. memory_required -= current_loaded_models[i].model.loaded_size() @@ -685,18 +701,6 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) - if pins_to_free > 0: - logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}") - current_loaded_models[i].model.partially_unload_ram(pins_to_free) - - for x in can_unload_sorted: - i = x[-1] - ram_to_free = ram_required - psutil.virtual_memory().available - if ram_to_free <= 0 and i not in unloaded_model: - continue - resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True) - if resident_memory > 0: - logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) @@ -762,29 +766,16 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu model_to_unload.model.detach(unpatch_all=False) model_to_unload.model_finalizer.detach() - total_memory_required = {} - total_pins_required = {} - total_ram_required = {} for loaded_model in models_to_load: device = loaded_model.device total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device) - resident_memory, model_memory = loaded_model.model_mmap_residency() - pinned_memory = loaded_model.model.pinned_memory_size() - #FIXME: This can over-free the pins as it budgets to pin the entire model. We should - #make this JIT to keep as much pinned as possible. - pins_required = model_memory - pinned_memory - ram_required = model_memory - resident_memory - total_pins_required[device] = total_pins_required.get(device, 0) + pins_required - total_ram_required[device] = total_ram_required.get(device, 0) + ram_required for device in total_memory_required: if device != torch.device("cpu"): free_memory(total_memory_required[device] * 1.1 + extra_mem, device, - for_dynamic=free_for_dynamic, - pins_required=total_pins_required[device], - ram_required=total_ram_required[device]) + for_dynamic=free_for_dynamic) for device in total_memory_required: if device != torch.device("cpu"): @@ -1180,6 +1171,7 @@ STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) STREAM_AIMDO_CAST_BUFFERS = {} LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) +STREAM_PIN_BUFFERS = {} DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 @@ -1220,21 +1212,66 @@ def get_aimdo_cast_buffer(offload_stream, device): if cast_buffer is None: cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index) STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer - return cast_buffer + +def get_pin_buffer(offload_stream): + pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None) + if pin_buffer is None: + pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3)) + STREAM_PIN_BUFFERS[offload_stream] = pin_buffer + elif offload_stream is not None: + event = getattr(pin_buffer, "_comfy_event", None) + if event is not None: + event.synchronize() + delattr(pin_buffer, "_comfy_event") + return pin_buffer + +def resize_pin_buffer(pin_buffer, size): + global TOTAL_PINNED_MEMORY + old_size = pin_buffer.size + if size <= old_size: + return True + growth = size - old_size + comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) + ensure_pin_budget(growth, evict_active=True) + ensure_pin_registerable(growth, evict_active=True) + try: + pin_buffer.extend(size=size, reallocate=True) + except RuntimeError: + return False + TOTAL_PINNED_MEMORY += pin_buffer.size - old_size + return True + def reset_cast_buffers(): + global TOTAL_PINNED_MEMORY global LARGEST_CASTED_WEIGHT global LARGEST_AIMDO_CASTED_WEIGHT LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) - for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): + for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS): if offload_stream is not None: offload_stream.synchronize() synchronize() + for mmap_obj in DIRTY_MMAPS: + mmap_obj.bounce() + DIRTY_MMAPS.clear() + + for pin_buffer in STREAM_PIN_BUFFERS.values(): + TOTAL_PINNED_MEMORY -= pin_buffer.size + TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY) + + for loaded_model in current_loaded_models: + model = loaded_model.model + if model is not None and model.is_dynamic(): + model.model.dynamic_pins[model.load_device]["active"] = False + model.partially_unload_ram(1e30, subsets=[ "patches" ]) + model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0]) + STREAM_CAST_BUFFERS.clear() STREAM_AIMDO_CAST_BUFFERS.clear() + STREAM_PIN_BUFFERS.clear() soft_empty_cache() def get_offload_stream(device): @@ -1280,7 +1317,7 @@ def sync_stream(device, stream): current_stream(device).wait_stream(stream) -def cast_to_gathered(tensors, r, non_blocking=False, stream=None): +def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None): wf_context = nullcontext() if stream is not None: wf_context = stream @@ -1288,17 +1325,20 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None): wf_context = wf_context.as_context(stream) dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None with wf_context: for tensor in tensors: dest_view = dest_views.pop(0) + dest2_view = dest2_views.pop(0) if dest2_views is not None else None if tensor is None: continue - if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view): + if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view, stream=stream, destination2=dest2_view): continue storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage() - if hasattr(storage, "_comfy_tensor_mmap_touched"): - storage._comfy_tensor_mmap_touched = True + mark_mmap_dirty(storage) dest_view.copy_(tensor, non_blocking=non_blocking) + if dest2_view is not None: + dest2_view.copy_(dest_view, non_blocking=non_blocking) def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): @@ -1339,14 +1379,18 @@ TOTAL_PINNED_MEMORY = 0 MAX_PINNED_MEMORY = -1 if not args.disable_pinned_memory: if is_nvidia() or is_amd(): + ram = get_total_memory(torch.device("cpu")) if WINDOWS: - MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.40 # Windows limit is apparently 50% + MAX_PINNED_MEMORY = ram * 0.40 # Windows limit is apparently 50% else: - MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.90 + MAX_PINNED_MEMORY = ram * 0.90 logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"]) +def pinned_hostbuf_size(size): + return max(0, int(min(size, MAX_PINNED_MEMORY) * 2)) + def discard_cuda_async_error(): try: a = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) @@ -1378,8 +1422,8 @@ def pin_memory(tensor): return False size = tensor.nbytes - if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: - return False + comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) + ensure_pin_registerable(size) ptr = tensor.data_ptr() if ptr == 0: @@ -1416,7 +1460,8 @@ def unpin_memory(tensor): return False if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: - TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) + size = PINNED_MEMORY.pop(ptr) + TOTAL_PINNED_MEMORY -= size return True else: logging.warning("Unpin error.") diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4f9d8403e..c8ed02e70 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -35,6 +35,7 @@ import comfy.model_management import comfy.ops import comfy.patcher_extension import comfy.utils +import comfy_aimdo.host_buffer from comfy.comfy_types import UnetWrapperFunction from comfy.quant_ops import QuantizedTensor from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP @@ -117,6 +118,8 @@ def string_to_seed(data): return comfy.utils.string_to_seed(data) class LowVramPatch: + is_lowvram_patch = True + def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key self.patches = patches @@ -124,11 +127,21 @@ class LowVramPatch: self.set_func = set_func self.prepared_patches = None - def prepare(self, allocate_buffer, stream): - self.prepared_patches = [ - (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4]) + def memory_required(self): + counter = [0] + for patch in self.patches[self.key]: + comfy.lora.prefetch_prepared_value(patch[1], counter, None, None, False) + return counter[0] + + def prepare(self, destination, stream, copy=True, commit=True): + counter = [0] + prepared_patches = [ + (patch[0], comfy.lora.prefetch_prepared_value(patch[1], counter, destination, stream, copy), patch[2], patch[3], patch[4]) for patch in self.patches[self.key] ] + if commit: + self.prepared_patches = prepared_patches + return prepared_patches def clear_prepared(self): self.prepared_patches = None @@ -341,9 +354,6 @@ class ModelPatcher: self.size = comfy.model_management.module_size(self.model) return self.size - def model_mmap_residency(self, free=False): - return comfy.model_management.module_mmap_residency(self.model, free=free) - def loaded_size(self): return self.model.model_loaded_weight_memory @@ -1118,8 +1128,12 @@ class ModelPatcher: # Pinned memory pressure tracking is only implemented for DynamicVram loading return 0 + def loaded_ram_size(self): + # Loaded RAM pressure tracking is only implemented for DynamicVram loading + return 0 + def partially_unload_ram(self, ram_to_unload): - pass + return 0 def detach(self, unpatch_all=True): self.eject_model() @@ -1550,6 +1564,16 @@ class ModelPatcherDynamic(ModelPatcher): super().__init__(model, load_device, offload_device, size, weight_inplace_update) if not hasattr(self.model, "dynamic_vbars"): self.model.dynamic_vbars = {} + if not hasattr(self.model, "dynamic_pins"): + self.model.dynamic_pins = {} + if self.load_device not in self.model.dynamic_pins: + self.model.dynamic_pins[self.load_device] = { + "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), + "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), + "hostbufs_initialized": False, + "failed": False, + "active": False, + } self.non_dynamic_delegate_model = None assert load_device is not None @@ -1611,6 +1635,14 @@ class ModelPatcherDynamic(ModelPatcher): self.unpatch_hooks() vbar = self._vbar_get(create=True) + pin_state = self.model.dynamic_pins[self.load_device] + if not pin_state["hostbufs_initialized"]: + hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size()) + pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0]) + pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0]) + pin_state["hostbufs_initialized"] = True + pin_state["failed"] = False + pin_state["active"] = True if vbar is not None: vbar.prioritize() @@ -1636,7 +1668,9 @@ class ModelPatcherDynamic(ModelPatcher): if key in self.patches: if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape: return (True, 0) - setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) + lowvram_patch = LowVramPatch(key, self.patches) + lowvram_patch._pin_state = pin_state + setattr(m, param_key + "_lowvram_function", lowvram_patch) num_patches += 1 else: setattr(m, param_key + "_lowvram_function", None) @@ -1653,6 +1687,9 @@ class ModelPatcherDynamic(ModelPatcher): def force_load_param(self, param_key, device_to): key = key_param_name_to_key(n, param_key) + weight, _, _ = get_key_weight(self.model, key) + if weight is None: + return if key in self.backup: comfy.utils.set_attr_param(self.model, key, self.backup[key].weight) self.patch_weight_to_device(key, device_to=device_to, force_cast=True) @@ -1662,17 +1699,23 @@ class ModelPatcherDynamic(ModelPatcher): if hasattr(m, "comfy_cast_weights"): m.comfy_cast_weights = True - m.pin_failed = False m.seed_key = n + m._pin_state = pin_state set_dirty(m, dirty) - force_load, v_weight_size = setup_param(self, m, n, "weight") - force_load_bias, v_weight_bias = setup_param(self, m, n, "bias") - force_load = force_load or force_load_bias - v_weight_size += v_weight_bias + #Models that mix tiny and giant weights can causing lopsided stream buffer + #rotations and stall. force the tinys over. + if module_mem > 16 * 1024: + force_load, v_weight_size = setup_param(self, m, n, "weight") + force_load_bias, v_weight_bias = setup_param(self, m, n, "bias") + force_load = force_load or force_load_bias + v_weight_size += v_weight_bias + if force_load: + logging.info(f"Module {n} has resizing Lora - force loading") + else: + force_load=True if force_load: - logging.info(f"Module {n} has resizing Lora - force loading") force_load_param(self, "weight", device_to) force_load_param(self, "bias", device_to) else: @@ -1740,23 +1783,58 @@ class ModelPatcherDynamic(ModelPatcher): return freed - def pinned_memory_size(self): - total = 0 - loading = self._load_list(for_dynamic=True) - for x in loading: - _, _, _, _, m, _ = x - pin = comfy.pinned_memory.get_pin(m) - if pin is not None: - total += pin.numel() * pin.element_size() - return total + def loaded_ram_size(self): + return (self.model.dynamic_pins[self.load_device]["weights"][0].size + + self.model.dynamic_pins[self.load_device]["patches"][0].size) - def partially_unload_ram(self, ram_to_unload): - loading = self._load_list(for_dynamic=True, default_device=self.offload_device) - for x in loading: - *_, m, _ = x - ram_to_unload -= comfy.pinned_memory.unpin_memory(m) - if ram_to_unload <= 0: - return + def pinned_memory_size(self): + return (self.model.dynamic_pins[self.load_device]["weights"][3][0] + + self.model.dynamic_pins[self.load_device]["patches"][3][0]) + + def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]): + freed = 0 + pin_state = self.model.dynamic_pins[self.load_device] + for subset in subsets: + hostbuf, stack, stack_split, pinned_size = pin_state[subset] + split = stack_split[0] + while split >= 0: + module, offset = stack[split] + split -= 1 + stack_split[0] = split + if not module._pin_registered: + continue + size = module._pin.numel() * module._pin.element_size() + if torch.cuda.cudart().cudaHostUnregister(module._pin.data_ptr()) != 0: + comfy.model_management.discard_cuda_async_error() + continue + module._pin_registered = False + comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size) + pinned_size[0] = max(0, pinned_size[0] - size) + freed += size + ram_to_unload -= size + if ram_to_unload <= 0: + return freed + return freed + + def partially_unload_ram(self, ram_to_unload, subsets=[ "weights", "patches" ]): + freed = 0 + pin_state = self.model.dynamic_pins[self.load_device] + for subset in subsets: + hostbuf, stack, stack_split, pinned_size = pin_state[subset] + while len(stack) > 0: + module, offset = stack.pop() + size = module._pin.numel() * module._pin.element_size() + del module._pin + hostbuf.truncate(offset, do_unregister=module._pin_registered) + stack_split[0] = min(stack_split[0], len(stack) - 1) + if module._pin_registered: + comfy.model_management.TOTAL_PINNED_MEMORY = max(0, comfy.model_management.TOTAL_PINNED_MEMORY - size) + pinned_size[0] = max(0, pinned_size[0] - size) + freed += size + ram_to_unload -= size + if ram_to_unload <= 0: + return freed + return freed def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False): #This isn't used by the core at all and can only be to load a model out of diff --git a/comfy/ops.py b/comfy/ops.py index eae3bd873..9bcd6c900 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -75,6 +75,8 @@ except: cast_to = comfy.model_management.cast_to #TODO: remove once no more references +STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024 + def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) @@ -91,6 +93,9 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin offload_stream = None cast_buffer = None cast_buffer_offset = 0 + stream_pin_hostbuf = None + stream_pin_offset = 0 + stream_pin_queue = [] def ensure_offload_stream(module, required_size, check_largest): nonlocal offload_stream @@ -124,6 +129,22 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin cast_buffer_offset += buffer_size return buffer + def get_stream_pin_buffer_offset(buffer_size): + nonlocal stream_pin_hostbuf + nonlocal stream_pin_offset + + if buffer_size == 0 or offload_stream is None: + return None + + if stream_pin_hostbuf is None: + stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream) + if stream_pin_hostbuf is None: + return None + + offset = stream_pin_offset + stream_pin_offset += buffer_size + return offset + for s in comfy_modules: signature = comfy_aimdo.model_vbar.vbar_fault(s._v) resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) @@ -162,23 +183,47 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin if xfer_dest is None: xfer_dest = get_cast_buffer(dest_size) - if signature is None and pin is None: - comfy.pinned_memory.pin_memory(s) - pin = comfy.pinned_memory.get_pin(s) - else: - pin = None + def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream): + if xfer_source is not None: + if getattr(xfer_source, "is_lowvram_patch", False): + xfer_source.prepare(xfer_dest, stream, copy=True, commit=False) + else: + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream) - if pin is not None: - comfy.model_management.cast_to_gathered(xfer_source, pin) - xfer_source = [ pin ] - #send it over - comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) + def handle_pin(m, pin, source, dest, subset="weights", size=None): + if pin is not None: + cast_maybe_lowvram_patch([pin], dest, offload_stream) + return + if signature is None: + comfy.pinned_memory.pin_memory(m, subset=subset, size=size) + pin = comfy.pinned_memory.get_pin(m, subset=subset) + if pin is not None: + if isinstance(source, list): + comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest) + else: + cast_maybe_lowvram_patch(source, pin, None) + cast_maybe_lowvram_patch([ pin ], dest, offload_stream) + return + if pin is None: + pin_offset = get_stream_pin_buffer_offset(size) + if pin_offset is not None: + stream_pin_queue.append((source, pin_offset, size, dest)) + return + cast_maybe_lowvram_patch(source, dest, offload_stream) + + handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size) for param_key in ("weight", "bias"): - lowvram_fn = getattr(s, param_key + "_lowvram_function", None) - if lowvram_fn is not None: + lowvram_source = getattr(s, param_key + "_lowvram_function", None) + if lowvram_source is not None: ensure_offload_stream(s, cast_buffer_offset, False) - lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream) + lowvram_size = lowvram_source.memory_required() + lowvram_dest = get_cast_buffer(lowvram_size) + lowvram_source.prepare(lowvram_dest, None, copy=False, commit=True) + + pin = comfy.pinned_memory.get_pin(lowvram_source, subset="patches") + handle_pin(lowvram_source, pin, lowvram_source, lowvram_dest, subset="patches", size=lowvram_size) + prefetch["xfer_dest"] = xfer_dest prefetch["cast_dest"] = cast_dest @@ -186,6 +231,23 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin prefetch["needs_cast"] = needs_cast s._prefetch = prefetch + if stream_pin_offset > 0: + if stream_pin_hostbuf.size < stream_pin_offset: + if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM): + for xfer_source, _, _, xfer_dest in stream_pin_queue: + cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream) + return offload_stream + stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf) + stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf + for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue: + pin = stream_pin_tensor[pin_offset:pin_offset + pin_size] + if isinstance(xfer_source, list): + comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest) + else: + cast_maybe_lowvram_patch(xfer_source, pin, None) + comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream) + stream_pin_hostbuf._comfy_event = offload_stream.record_event() + return offload_stream diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 6d3ba367a..0e8f573ba 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -2,42 +2,62 @@ import comfy.model_management import comfy.memory_management import comfy_aimdo.host_buffer import comfy_aimdo.torch +import torch from comfy.cli_args import args -def get_pin(module): - return getattr(module, "_pin", None) +def get_pin(module, subset="weights"): + pin = getattr(module, "_pin", None) + if pin is None or module._pin_registered or args.disable_pinned_memory: + return pin -def pin_memory(module): - if module.pin_failed or args.disable_pinned_memory or get_pin(module) is not None: + _, _, stack_split, pinned_size = module._pin_state[subset] + size = pin.nbytes + comfy.model_management.ensure_pin_registerable(size) + + if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0: + comfy.model_management.discard_cuda_async_error() + return pin + + module._pin_registered = True + stack_split[0] = max(stack_split[0], module._pin_stack_index) + comfy.model_management.TOTAL_PINNED_MEMORY += size + pinned_size[0] += size + return pin + +def pin_memory(module, subset="weights", size=None): + pin_state = module._pin_state + if args.disable_pinned_memory: return - size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) + pin = get_pin(module, subset) + if pin is not None or pin_state["failed"]: + return - if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY: - module.pin_failed = True + hostbuf, stack, stack_split, pinned_size = pin_state[subset] + if size is None: + size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) + offset = hostbuf.size + registerable_size = size + max(0, hostbuf.size - pinned_size[0]) + + comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) + if (not comfy.model_management.ensure_pin_budget(size) or + not comfy.model_management.ensure_pin_registerable(registerable_size)): + pin_state["failed"] = True return False try: - hostbuf = comfy_aimdo.host_buffer.HostBuffer(size) + hostbuf.extend(size=size) except RuntimeError: - module.pin_failed = True + pin_state["failed"] = True return False - module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf) - module._pin_hostbuf = hostbuf + module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size] + module._pin.untyped_storage()._comfy_hostbuf = hostbuf + stack.append((module, offset)) + module._pin_registered = True + module._pin_stack_index = len(stack) - 1 + stack_split[0] = max(stack_split[0], module._pin_stack_index) comfy.model_management.TOTAL_PINNED_MEMORY += size + pinned_size[0] += size return True - -def unpin_memory(module): - if get_pin(module) is None: - return 0 - size = module._pin.numel() * module._pin.element_size() - - comfy.model_management.TOTAL_PINNED_MEMORY -= size - if comfy.model_management.TOTAL_PINNED_MEMORY < 0: - comfy.model_management.TOTAL_PINNED_MEMORY = 0 - - del module._pin - del module._pin_hostbuf - return size diff --git a/comfy/utils.py b/comfy/utils.py index 66682690a..00e382fac 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -113,7 +113,6 @@ def load_safetensors(ckpt): "_comfy_tensor_file_slice", comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start)) setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv)) - setattr(storage, "_comfy_tensor_mmap_touched", False) sd[name] = tensor return sd, header.get("__metadata__", {}), @@ -1451,4 +1450,3 @@ def deepcopy_list_dict(obj, memo=None): memo[obj_id] = res return res - diff --git a/comfy/windows.py b/comfy/windows.py deleted file mode 100644 index 213dc481d..000000000 --- a/comfy/windows.py +++ /dev/null @@ -1,52 +0,0 @@ -import ctypes -import logging -import psutil -from ctypes import wintypes - -import comfy_aimdo.control - -psapi = ctypes.WinDLL("psapi") -kernel32 = ctypes.WinDLL("kernel32") - -class PERFORMANCE_INFORMATION(ctypes.Structure): - _fields_ = [ - ("cb", wintypes.DWORD), - ("CommitTotal", ctypes.c_size_t), - ("CommitLimit", ctypes.c_size_t), - ("CommitPeak", ctypes.c_size_t), - ("PhysicalTotal", ctypes.c_size_t), - ("PhysicalAvailable", ctypes.c_size_t), - ("SystemCache", ctypes.c_size_t), - ("KernelTotal", ctypes.c_size_t), - ("KernelPaged", ctypes.c_size_t), - ("KernelNonpaged", ctypes.c_size_t), - ("PageSize", ctypes.c_size_t), - ("HandleCount", wintypes.DWORD), - ("ProcessCount", wintypes.DWORD), - ("ThreadCount", wintypes.DWORD), - ] - -def get_free_ram(): - #Windows is way too conservative and chalks recently used uncommitted model RAM - #as "in-use". So, calculate free RAM for the sake of general use as the greater of: - # - #1: What psutil says - #2: Total Memory - (Committed Memory - VRAM in use) - # - #We have to subtract VRAM in use from the comitted memory as WDDM creates a naked - #commit charge for all VRAM used just incase it wants to page it all out. This just - #isn't realistic so "overcommit" on our calculations by just subtracting it off. - - pi = PERFORMANCE_INFORMATION() - pi.cb = ctypes.sizeof(pi) - - if not psapi.GetPerformanceInfo(ctypes.byref(pi), pi.cb): - logging.warning("WARNING: Failed to query windows performance info. RAM usage may be sub optimal") - return psutil.virtual_memory().available - - committed = pi.CommitTotal * pi.PageSize - total = pi.PhysicalTotal * pi.PageSize - - return max(psutil.virtual_memory().available, - total - (committed - comfy_aimdo.control.get_total_vram_usage())) - diff --git a/execution.py b/execution.py index 4c7de2e84..5246d651c 100644 --- a/execution.py +++ b/execution.py @@ -2,6 +2,7 @@ import copy import heapq import inspect import logging +import psutil import sys import threading import time @@ -727,6 +728,7 @@ class PromptExecutor: self._notify_prompt_lifecycle("start", prompt_id) ram_headroom = int(self.cache_args["ram"] * (1024 ** 3)) + ram_inactive_headroom = int(self.cache_args["ram_inactive"] * (1024 ** 3)) ram_release_callback = self.caches.outputs.ram_release if self.cache_type == CacheType.RAM_PRESSURE else None comfy.memory_management.set_ram_cache_release_state(ram_release_callback, ram_headroom) @@ -780,8 +782,14 @@ class PromptExecutor: execution_list.complete_node_execution() if self.cache_type == CacheType.RAM_PRESSURE: - comfy.model_management.free_memory(0, None, pins_required=ram_headroom, ram_required=ram_headroom) - ram_release_callback(ram_headroom, free_active=True) + ram_release_callback(ram_inactive_headroom) + ram_shortfall = ram_headroom - psutil.virtual_memory().available + freed = comfy.model_management.free_pins(ram_shortfall + 512 * (1024 ** 2)) + if freed < ram_shortfall: + if freed > 64 * (1024 ** 2): + # AIMDO MEM_DECOMMIT can outrun psutil.available catching up. + time.sleep(0.05) + ram_release_callback(ram_headroom, free_active=True) else: # Only execute when the while-loop ends without break # Send cached UI for intermediate output nodes that weren't executed diff --git a/main.py b/main.py index a6fdaf43c..1e47cab84 100644 --- a/main.py +++ b/main.py @@ -283,19 +283,25 @@ def _collect_output_absolute_paths(history_result: dict) -> list[str]: def prompt_worker(q, server_instance): current_time: float = 0.0 - cache_ram = args.cache_ram - if cache_ram < 0: + cache_ram = 0 + cache_ram_inactive = 0 + if not args.cache_classic and not args.cache_none and args.cache_lru <= 0: cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0)) + cache_ram_inactive = min(96.0, max(12.0, comfy.model_management.total_ram * 0.75 / 1024.0)) + if len(args.cache_ram) > 0: + cache_ram = args.cache_ram[0] + if len(args.cache_ram) > 1: + cache_ram_inactive = args.cache_ram[1] - cache_type = execution.CacheType.CLASSIC - if args.cache_lru > 0: + cache_type = execution.CacheType.RAM_PRESSURE + if args.cache_classic: + cache_type = execution.CacheType.CLASSIC + elif args.cache_lru > 0: cache_type = execution.CacheType.LRU - elif cache_ram > 0: - cache_type = execution.CacheType.RAM_PRESSURE elif args.cache_none: cache_type = execution.CacheType.NONE - e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram } ) + e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : cache_ram, "ram_inactive" : cache_ram_inactive } ) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 diff --git a/requirements.txt b/requirements.txt index 1c87690da..d2986eda8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0 filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo==0.3.0 +comfy-aimdo==0.4.3 requests simpleeval>=1.0.0 blake3 diff --git a/tests/execution/test_async_nodes.py b/tests/execution/test_async_nodes.py index c771b4b36..54660c112 100644 --- a/tests/execution/test_async_nodes.py +++ b/tests/execution/test_async_nodes.py @@ -14,7 +14,6 @@ from tests.execution.test_execution import ComfyClient, run_warmup class TestAsyncNodes: @fixture(scope="class", autouse=True, params=[ (False, 0), - (True, 0), (True, 100), ]) def _server(self, args_pytest, request): @@ -29,6 +28,8 @@ class TestAsyncNodes: use_lru, lru_size = request.param if use_lru: pargs += ['--cache-lru', str(lru_size)] + else: + pargs += ['--cache-classic'] # Running server with args: pargs p = subprocess.Popen(pargs) yield diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index f73ca7e3c..15e2304fc 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -183,8 +183,7 @@ class TestExecution: # Initialize server and client # @fixture(scope="class", autouse=True, params=[ - { "extra_args" : [], "should_cache_results" : True }, - { "extra_args" : ["--cache-lru", 0], "should_cache_results" : True }, + { "extra_args" : ["--cache-classic"], "should_cache_results" : True }, { "extra_args" : ["--cache-lru", 100], "should_cache_results" : True }, { "extra_args" : ["--cache-none"], "should_cache_results" : False }, ]) From 95fdc6cf910f809e39edc3254470e619ffa9dbf8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Wed, 20 May 2026 17:17:55 -0700 Subject: [PATCH 10/45] Repo security stuff. (#14019) --- CODEOWNERS | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CODEOWNERS b/CODEOWNERS index 946dbf946..043c0ec75 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,2 +1,5 @@ -# Admins * @comfyanonymous @kosinkadink @guill @alexisrolland @rattus128 @kijai + +/CODEOWNERS @comfyanonymous +/.ci/ @comfyanonymous +/.github/ @comfyanonymous From 9f9b32ed978045262b71e6b27093e4ae80c29804 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Wed, 20 May 2026 21:22:12 -0700 Subject: [PATCH 11/45] feat: add OAuth 2.1 + RFC 7591 DCR endpoints to openapi.yaml (#14026) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add the OAuth 2.1 authorization flow and RFC 7591 Dynamic Client Registration endpoints to the shared spec, alongside the existing auth-tagged operations (/api/auth/session, /api/auth/token, /.well-known/jwks.json). All tagged x-runtime: [cloud] with a [cloud-only] description prefix, following the established convention for cloud-runtime-only operations. Endpoints: - GET /.well-known/oauth-authorization-server (RFC 8414 metadata) - GET /.well-known/oauth-protected-resource (RFC 9728 metadata) - GET /oauth/authorize (consent challenge) - POST /oauth/authorize (consent submission) - POST /oauth/token (RFC 6749 §3.2) - POST /oauth/register (RFC 7591 §3.1 DCR) Component schemas added: - OAuthAuthorizationServerMetadata - OAuthProtectedResourceMetadata - OAuthConsentChallenge, OAuthConsentChallengeWorkspace - OAuthAuthorizeRedirectResponse - OAuthTokenResponse, OAuthTokenError - OAuthRegisterRequest, OAuthRegisterResponse, OAuthRegisterError These endpoints are implemented in the cloud runtime today and are called by browser frontends rendering the consent UI and by MCP-spec-compliant clients (Claude Desktop, Cursor, etc.) doing auto-discovery + self-registration. Documenting them in the shared spec lets the cloud frontend generate types directly from this spec instead of maintaining a parallel definition. Spectral lints clean (0 errors). The hint-level findings on OAuthTokenError / OAuthRegisterError ("standard error schema") match the same hint on CloudError — these are protocol-specific RFC-shaped errors, not generic application errors. --- openapi.yaml | 608 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 608 insertions(+) diff --git a/openapi.yaml b/openapi.yaml index 2658b9b86..92f7eaccc 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -3790,6 +3790,295 @@ paths: schema: $ref: "#/components/schemas/JwksResponse" + # --------------------------------------------------------------------------- + # OAuth 2.1 / RFC 7591 Dynamic Client Registration (cloud) + # --------------------------------------------------------------------------- + /.well-known/oauth-authorization-server: + get: + operationId: getOAuthAuthorizationServer + tags: [auth] + summary: "[cloud-only] OAuth 2.1 authorization-server metadata (RFC 8414)" + description: "[cloud-only] Public metadata document for OAuth 2.1 clients. Cached 5 minutes." + x-runtime: [cloud] + security: [] + responses: + "200": + description: Authorization-server metadata + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthAuthorizationServerMetadata" + "404": + description: OAuth disabled + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /.well-known/oauth-protected-resource: + get: + operationId: getOAuthProtectedResource + tags: [auth] + summary: "[cloud-only] OAuth 2.1 protected-resource metadata (RFC 9728)" + description: "[cloud-only] Public metadata describing the currently advertised protected resource. Cached 5 minutes." + x-runtime: [cloud] + security: [] + responses: + "200": + description: Protected-resource metadata + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthProtectedResourceMetadata" + "404": + description: OAuth disabled or no active resource configured + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /oauth/authorize: + get: + operationId: getOAuthAuthorize + tags: [auth] + summary: "[cloud-only] Begin or resume an OAuth 2.1 authorization request" + description: | + [cloud-only] Two modes: + - **Initial entry** (OAuth params present): validates client/redirect/resource/scopes, persists a server-side authorization-request row, and either redirects (no session / unverified email) to the configured frontend login URL carrying only the opaque `oauth_request_id`, or returns the JSON consent challenge for the frontend to render. + - **Resume** (`oauth_request_id` present): loads the server-side row, fails closed if expired/consumed/unknown, returns the JSON consent challenge. Browser-replayed OAuth params are intentionally ignored. + + The frontend renders the consent UI from the JSON payload and POSTs the user's decision back to this endpoint. + x-runtime: [cloud] + security: [] + parameters: + - { name: response_type, in: query, required: false, schema: { type: string } } + - { name: client_id, in: query, required: false, schema: { type: string } } + - { name: redirect_uri, in: query, required: false, schema: { type: string } } + - { name: scope, in: query, required: false, schema: { type: string } } + - name: state + in: query + required: false + schema: { type: string } + description: | + RFC 6749 §10.12 marks `state` as RECOMMENDED. Cloud hardening makes it REQUIRED on the initial-entry path (omitted only on the resume path where `oauth_request_id` is supplied instead). This parameter is `required: false` at the spec level only because the operation is dual-mode (initial entry vs. resume); the runtime rejects empty `state` on the initial-entry path with a stable `invalid_request` 400. + - { name: code_challenge, in: query, required: false, schema: { type: string } } + - { name: code_challenge_method, in: query, required: false, schema: { type: string } } + - { name: resource, in: query, required: false, schema: { type: string } } + - { name: oauth_request_id, in: query, required: false, schema: { type: string } } + responses: + "200": + description: Consent challenge payload (session present, email verified). Frontend renders the consent UI from this payload and POSTs back to /oauth/authorize. + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthConsentChallenge" + "302": + description: Redirect to login (no session / unverified email) or to registered redirect_uri (pre-validated client error) + headers: + Location: + schema: + type: string + "400": + description: Invalid authorize request (pre-redirect failure — unknown client, redirect mismatch, malformed params) + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: OAuth disabled + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + post: + operationId: postOAuthAuthorize + tags: [auth] + summary: "[cloud-only] Submit OAuth consent decision" + description: | + [cloud-only] JSON-only consent submission. The handler verifies the per-row CSRF token, atomically marks the authorization request consumed (single-use covers both allow and deny paths), then returns the redirect URL the browser must navigate to. The URL contains either `code` + original `state` for allow, or the RFC 6749 §5.2 error and `state` for deny. + + Workspace membership is re-checked at submission time. Consent is persisted keyed by `(user_id, client_id, resource_id, workspace_id)`; broadening the previously approved scope set requires a fresh consent flow. + x-runtime: [cloud] + security: [] + requestBody: + required: true + content: + application/json: + schema: + type: object + required: [oauth_request_id, csrf_token, decision, workspace_id] + properties: + oauth_request_id: { type: string, format: uuid } + csrf_token: { type: string } + decision: { type: string, enum: [allow, deny] } + workspace_id: { type: string } + responses: + "200": + description: Redirect URL for the frontend to navigate to (allow → with code+state; deny → with error+state) + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthAuthorizeRedirectResponse" + "400": + description: Bad request (CSRF mismatch, expired/consumed request, inaccessible workspace) + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "403": + description: Scope broadening on consent re-grant — fresh consent flow required + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "404": + description: OAuth disabled + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /oauth/token: + post: + operationId: postOAuthToken + tags: [auth] + summary: "[cloud-only] Exchange authorization code or refresh token for a resource-bound access token" + description: | + [cloud-only] OAuth 2.1 token endpoint (RFC 6749 §3.2). Public clients only — `client_secret` is rejected. + + Two grant types are supported: + - `authorization_code` — exchanges the code minted by `/oauth/authorize` (with PKCE verifier) for an access token + first refresh token. Single-use; reuse fails closed. + - `refresh_token` — rotates the refresh token. Old token immediately invalid; presenting an already-rotated token revokes the entire token family and emits a security metric. + + Both grant types re-validate canonical user state, current workspace membership, and the resource's active flag at every mint. A code or refresh token bound to a deactivated resource fails closed. + + Errors follow RFC 6749 §5.2. Logs never contain raw codes, refresh tokens, or minted tokens. + + Per RFC 6749 §5.1, every 200 and 400 response carries `Cache-Control: no-store` and `Pragma: no-cache` so intermediaries cannot cache token-bearing or state-change-reason responses. + x-runtime: [cloud] + security: [] + requestBody: + required: true + content: + application/x-www-form-urlencoded: + schema: + type: object + required: [grant_type, client_id] + properties: + grant_type: { type: string, enum: [authorization_code, refresh_token] } + client_id: { type: string } + code: { type: string } + redirect_uri: { type: string } + code_verifier: { type: string } + refresh_token: { type: string } + scope: { type: string } + client_secret: { type: string } + responses: + "200": + description: New token pair + headers: + Cache-Control: + schema: + type: string + description: 'Always "no-store" per RFC 6749 §5.1' + Pragma: + schema: + type: string + description: 'Always "no-cache" per RFC 6749 §5.1' + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthTokenResponse" + "400": + description: RFC 6749 §5.2 error + headers: + Cache-Control: + schema: + type: string + description: 'Always "no-store" per RFC 6749 §5.1' + Pragma: + schema: + type: string + description: 'Always "no-cache" per RFC 6749 §5.1' + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthTokenError" + "404": + description: OAuth disabled + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + + /oauth/register: + post: + operationId: postOAuthRegister + tags: [auth] + summary: "[cloud-only] Dynamic Client Registration (RFC 7591)" + description: | + [cloud-only] Public, unauthenticated, insert-only RFC 7591 §3.1 client registration. Used by MCP-spec-compliant clients to self-register a public OAuth client without operator involvement. + + Policy: + + - Public clients only — `token_endpoint_auth_method` is forced to `none`. Confidential-client registration is out of scope this phase. + - Server-owned `resource_grants`. Caller-supplied `scope` or `resource_grants` is rejected as `invalid_client_metadata` (would be a privilege-escalation surface). Dynamic clients receive the same scopes the active resource publishes. + - Application-type-aware redirect URI policy. `application_type=native` accepts loopback (`127.0.0.1`, `::1`, `localhost`) and reverse-DNS-shaped custom schemes; `application_type=web` accepts HTTPS to hosts in an operator-controlled allowlist only. `application_type` is REQUIRED on the request — missing or empty rejects with `invalid_client_metadata`. + - Anti-impersonation: reserved client names are rejected from third parties via NFKC-folded compare. + - Generated `client_id` carries a stable prefix to distinguish dynamic from seeded clients in audit logs. + - Cache-Control: `no-store` on every 201 and 400 response (the response carries fresh credentials and rejection reasons). + x-runtime: [cloud] + security: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthRegisterRequest" + responses: + "201": + description: Registered. Body echoes the metadata RFC 7591 §3.2.1 requires. + headers: + Cache-Control: + schema: + type: string + description: 'Always "no-store"' + Pragma: + schema: + type: string + description: 'Always "no-cache"' + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthRegisterResponse" + "400": + description: RFC 7591 §3.2.2 invalid client metadata + headers: + Cache-Control: + schema: + type: string + description: 'Always "no-store"' + Pragma: + schema: + type: string + description: 'Always "no-cache"' + content: + application/json: + schema: + $ref: "#/components/schemas/OAuthRegisterError" + "404": + description: OAuth disabled + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + "503": + description: No active resource is configured — DCR cannot mint a usable client until an active resource row is seeded. + content: + application/json: + schema: + $ref: "#/components/schemas/CloudError" + # --------------------------------------------------------------------------- # Billing (cloud) # --------------------------------------------------------------------------- @@ -7424,6 +7713,325 @@ components: description: RSA exponent (base64url) additionalProperties: true + OAuthAuthorizationServerMetadata: + type: object + x-runtime: [cloud] + description: "[cloud-only] OAuth 2.1 authorization-server metadata (RFC 8414)." + required: + - issuer + - authorization_endpoint + - token_endpoint + - jwks_uri + - response_types_supported + - grant_types_supported + - code_challenge_methods_supported + - token_endpoint_auth_methods_supported + properties: + issuer: + type: string + format: uri + authorization_endpoint: + type: string + format: uri + token_endpoint: + type: string + format: uri + jwks_uri: + type: string + format: uri + registration_endpoint: + type: string + format: uri + description: "[cloud-only] RFC 7591 §3.1 Dynamic Client Registration endpoint. Advertised so MCP-spec-compliant clients can auto-discover and self-register without operator involvement. Present only when DCR is enabled." + response_types_supported: + type: array + items: + type: string + grant_types_supported: + type: array + items: + type: string + code_challenge_methods_supported: + type: array + items: + type: string + token_endpoint_auth_methods_supported: + type: array + items: + type: string + scopes_supported: + type: array + items: + type: string + + OAuthProtectedResourceMetadata: + type: object + x-runtime: [cloud] + description: "[cloud-only] OAuth 2.1 protected-resource metadata (RFC 9728)." + required: + - resource + - authorization_servers + - scopes_supported + properties: + resource: + type: string + format: uri + authorization_servers: + type: array + items: + type: string + format: uri + scopes_supported: + type: array + items: + type: string + bearer_methods_supported: + type: array + items: + type: string + + OAuthConsentChallenge: + type: object + x-runtime: [cloud] + description: "[cloud-only] Server-side state describing the OAuth consent decision the user is being asked to make. Returned by GET /oauth/authorize when a valid session exists; the frontend renders the consent UI from this payload and POSTs the decision back. Browser never sees the original OAuth params on resume." + required: + - oauth_request_id + - csrf_token + - client_display_name + - resource_display_name + - scopes + - workspaces + properties: + oauth_request_id: + type: string + format: uuid + description: Opaque server-side identifier for the authorization-request row. Carried back unchanged in the consent submission. + csrf_token: + type: string + description: Per-row CSRF token bound to this authorization request (not to the session). Must be echoed back on POST. + client_display_name: + type: string + description: Human-readable name of the OAuth client requesting authorization. + resource_display_name: + type: string + description: Human-readable name of the protected resource. + scopes: + type: array + description: Scopes the client is requesting for this resource. The frontend should present these for the user to approve. + items: + type: string + workspaces: + type: array + description: Workspaces the user can select from. Membership is re-checked on POST. + items: + $ref: "#/components/schemas/OAuthConsentChallengeWorkspace" + + OAuthConsentChallengeWorkspace: + type: object + x-runtime: [cloud] + description: "[cloud-only] One workspace option presented in the OAuth consent challenge." + required: [id, name, type, role] + properties: + id: { type: string } + name: { type: string } + type: { type: string, enum: [personal, team] } + role: { type: string, enum: [owner, member] } + + OAuthAuthorizeRedirectResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] Redirect target produced after a JSON consent submission. The frontend must navigate the browser to this URL so custom-scheme client callbacks work without relying on fetch-visible 302 headers." + required: + - redirect_url + properties: + redirect_url: + type: string + format: uri + description: OAuth client redirect URI with either code+state for allow, or error+state for deny. + + OAuthTokenResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] RFC 6749 §5.1 successful token response." + required: [access_token, token_type, expires_in, refresh_token, scope] + properties: + access_token: + type: string + description: Resource-bound access token (audience matches the protected resource). + token_type: + type: string + enum: [Bearer] + expires_in: + type: integer + description: Access token lifetime in seconds. + refresh_token: + type: string + description: Opaque refresh token. Rotates on every successful refresh; presenting an already-rotated token revokes the entire family. + scope: + type: string + description: Space-delimited scopes granted with this token. + + OAuthTokenError: + type: object + x-runtime: [cloud] + description: "[cloud-only] RFC 6749 §5.2 error response." + required: [error] + properties: + error: + type: string + description: 'RFC 6749 §5.2 error code: invalid_request, invalid_client, invalid_grant, unauthorized_client, unsupported_grant_type, invalid_scope.' + error_description: + type: string + description: Human-readable, no leak of internal storage state. + + OAuthRegisterRequest: + type: object + x-runtime: [cloud] + additionalProperties: false + description: "[cloud-only] RFC 7591 §2 client metadata document. Only the fields the server honors are listed; presence of `scope` or `resource_grants` in the request is rejected (`invalid_client_metadata`) because those are server-owned for dynamic clients." + required: + - redirect_uris + - application_type + properties: + redirect_uris: + type: array + items: + type: string + minItems: 1 + maxItems: 5 + description: 1–5 redirect URIs. Validated against `application_type` policy. + client_name: + type: string + maxLength: 100 + description: Human-readable name shown in the consent UI. Reserved-name list rejects impersonation of major clients. + application_type: + type: string + enum: [native, web] + description: | + RFC 7591 §2 application_type. **REQUIRED** — clients MUST declare intent; the server does not default this field. `native` for desktop / CLI / MCP-spec-strict clients (loopback redirects); `web` for hosted clients (HTTPS only, host must be allowlisted). A missing or explicitly empty `application_type` rejects with `invalid_client_metadata`. + token_endpoint_auth_method: + type: string + enum: [none] + description: 'Public clients only this phase — must be `none` if present. The server forces `none` regardless.' + grant_types: + type: array + items: + type: string + enum: [authorization_code, refresh_token] + description: Optional. Defaults to `["authorization_code","refresh_token"]`. + response_types: + type: array + items: + type: string + enum: [code] + description: Optional. Defaults to `["code"]`. + scope: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Dynamic clients do not pick scopes — the server assigns scopes from the active resource's published list. Sending `scope` in the registration body is treated as a privilege-escalation attempt and returns `invalid_client_metadata`." + resource_grants: + type: object + nullable: true + additionalProperties: + type: array + items: + type: string + description: "**REJECTED IF PRESENT.** Same reason as `scope`. The set of resources and scopes a dynamic client may request is server-policy, not request-driven." + client_uri: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + logo_uri: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + tos_uri: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + policy_uri: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + software_id: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + software_version: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + contacts: + type: array + nullable: true + items: + type: string + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + jwks: + type: object + nullable: true + additionalProperties: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + jwks_uri: + type: string + nullable: true + description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." + + OAuthRegisterResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] RFC 7591 §3.2.1 successful registration response." + required: + - client_id + - client_id_issued_at + - redirect_uris + - grant_types + - response_types + - token_endpoint_auth_method + - application_type + properties: + client_id: + type: string + description: Server-generated client_id. + client_id_issued_at: + type: integer + format: int64 + description: Unix timestamp (seconds) when the client was registered. + client_name: + type: string + redirect_uris: + type: array + items: + type: string + grant_types: + type: array + items: + type: string + response_types: + type: array + items: + type: string + token_endpoint_auth_method: + type: string + enum: [none] + application_type: + type: string + enum: [native, web] + + OAuthRegisterError: + type: object + x-runtime: [cloud] + description: "[cloud-only] RFC 7591 §3.2.2 error response." + required: + - error + properties: + error: + type: string + enum: [invalid_redirect_uri, invalid_client_metadata] + error_description: + type: string + nullable: true + BillingBalance: type: object x-runtime: [cloud] From ea174d3f120bf43c0219eb341e9373834036083c Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Wed, 20 May 2026 21:28:16 -0700 Subject: [PATCH 12/45] fix(openapi): correct POST /api/assets/import to importPublishedAssets (#14027) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The operation at POST /api/assets/import was defined as `importAssets` with a URL-list body shape, but no runtime actually serves that operation at this path. The cloud runtime serves a different operation here — `importPublishedAssets` — which imports published-workflow assets into the caller's library by ID, not by URL. Cloud's URL-based asset ingestion lives at separate paths (POST /assets/download + GET /assets/remote-metadata) tracked elsewhere; nothing in this PR affects that work. Changes: - Replace the operation at POST /api/assets/import with `importPublishedAssets`, taking ImportPublishedAssetsRequest (published_asset_ids + optional share_id) and returning ImportPublishedAssetsResponse (list of AssetInfo). - Remove the unused AssetImportRequest component schema (no other references in the spec). - Operation and schemas tagged x-runtime: [cloud] with [cloud-only] description prefix, matching the existing convention for cloud-runtime-only operations elsewhere in the spec. Spectral lint passes (0 errors); the two hint-level findings on the spec are pre-existing and unrelated. No FE consumer references AssetImportRequest today; this is a pure spec correction to match what the cloud runtime actually serves. --- openapi.yaml | 59 ++++++++++++++++++++++++++-------------------------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index 92f7eaccc..0e7a9b4a7 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -2514,37 +2514,25 @@ paths: /api/assets/import: post: - operationId: importAssets + operationId: importPublishedAssets tags: [assets] - summary: Import assets from external URLs - description: "[cloud-only] Imports one or more assets from external URLs into the cloud asset store." + summary: "[cloud-only] Import published assets into the caller's library" + description: | + [cloud-only] Imports the specified published assets into the caller's asset library. New DB records reference the same storage objects; no file copying occurs. Assets the caller already owns (by hash) are deduplicated. The `id` field on each returned `AssetInfo` is the caller's newly-created private asset ID, not the published asset ID supplied in the request. x-runtime: [cloud] requestBody: required: true content: application/json: schema: - type: object - required: - - imports - properties: - imports: - type: array - items: - $ref: "#/components/schemas/AssetImportRequest" - description: Assets to import + $ref: "#/components/schemas/ImportPublishedAssetsRequest" responses: "200": - description: Import initiated + description: Successfully imported assets content: application/json: schema: - type: object - properties: - assets: - type: array - items: - $ref: "#/components/schemas/Asset" + $ref: "#/components/schemas/ImportPublishedAssetsResponse" "400": description: Bad request content: @@ -7379,24 +7367,35 @@ components: type: string description: Target path on the runtime filesystem - AssetImportRequest: + ImportPublishedAssetsRequest: type: object x-runtime: [cloud] - description: "[cloud-only] A single asset to import from an external URL." + description: "[cloud-only] Request body for importing published assets into the caller's library." required: - - url + - published_asset_ids properties: - url: - type: string - format: uri - description: URL of the asset to import - name: - type: string - description: Display name for the imported asset - tags: + published_asset_ids: type: array + description: IDs of published assets (inputs and models) to import. items: type: string + share_id: + type: string + nullable: true + description: | + Optional. Share ID of the published workflow these assets belong to. When provided (non-null, non-empty): all `published_asset_ids` must belong to this share's workflow version; returns 400 if the share is not found or any asset does not belong to it. When omitted, null, or empty string: no share-scoped validation is performed and the assets are validated only against global rules (preserved for clients that have not yet adopted `share_id`). + + ImportPublishedAssetsResponse: + type: object + x-runtime: [cloud] + description: "[cloud-only] Response after importing published assets. Each returned `AssetInfo.id` is the caller's newly-created private asset ID, not the published asset ID supplied in the request." + required: + - assets + properties: + assets: + type: array + items: + $ref: "#/components/schemas/AssetInfo" RemoteAssetMetadata: type: object From 1668aaf0378db1fe8ddd2c0572e7312a9ebbdd41 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Wed, 20 May 2026 21:32:08 -0700 Subject: [PATCH 13/45] openapi: remove cloud-only job_ids query param from GET /api/assets (#14016) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The job_ids query parameter on GET /api/assets is tagged x-runtime: [cloud] and only exists for cloud's variant of this endpoint. Cloud removed all consumers and the cloud-side handler/codegen/tests in Comfy-Org/cloud#3778. With cloud no longer accepting this parameter, the [cloud-only] documentation here is wrong — drop it so the daily sync to cloud/services/ingest/vendor/openapi.yaml propagates the removal. --- openapi.yaml | 6 ------ 1 file changed, 6 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index 0e7a9b4a7..885231acc 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -1556,12 +1556,6 @@ paths: type: string enum: [asc, desc] description: Sort direction - - name: job_ids - in: query - schema: - type: string - x-runtime: [cloud] - description: "[cloud-only] Comma-separated UUIDs to filter assets by associated job." - name: include_public in: query schema: From 7b7c5fed7ce978b05da27b13e26ef340d284b60e Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Thu, 21 May 2026 14:39:30 +0800 Subject: [PATCH 14/45] Update MediaPipe nodes to standardize with existing code base (CORE-242) (#14025) --- comfy_extras/nodes_mediapipe.py | 35 +++++++++++-------- folder_paths.py | 2 +- .../put_detection_models_here} | 0 3 files changed, 22 insertions(+), 15 deletions(-) rename models/{mediapipe/put_mediapipe_models_here => detection/put_detection_models_here} (100%) diff --git a/comfy_extras/nodes_mediapipe.py b/comfy_extras/nodes_mediapipe.py index 2e67ae83f..6b7916aee 100644 --- a/comfy_extras/nodes_mediapipe.py +++ b/comfy_extras/nodes_mediapipe.py @@ -28,7 +28,7 @@ from comfy_extras.mediapipe.face_landmarker import FaceLandmarker from comfy_extras.mediapipe.face_geometry import transformation_matrix_from_detection -FaceLandmarkerType = io.Custom("FACE_LANDMARKER") +FaceDetectionType = io.Custom("FACE_DETECTION_MODEL") FaceLandmarksType = io.Custom("FACE_LANDMARKS") _CANONICAL_KEYS = ("canonical_vertices", "procrustes_indices", "procrustes_weights") @@ -204,18 +204,19 @@ class LoadMediaPipeFaceLandmarker(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LoadMediaPipeFaceLandmarker", - display_name="Load MediaPipe Face Landmarker", + search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"], + display_name="Load Face Detection Model (MediaPipe)", category="loaders", inputs=[ - io.Combo.Input("model_name", options=folder_paths.get_filename_list("mediapipe"), - tooltip="Face Landmarker safetensors from models/mediapipe/."), + io.Combo.Input("model_name", options=folder_paths.get_filename_list("detection"), + tooltip="Face detection model from models/detection/."), ], - outputs=[FaceLandmarkerType.Output()], + outputs=[FaceDetectionType.Output()], ) @classmethod def execute(cls, model_name) -> io.NodeOutput: - sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("mediapipe", model_name), safe_load=True) + sd = comfy.utils.load_torch_file(folder_paths.get_full_path_or_raise("detection", model_name), safe_load=True) wrapper = FaceLandmarkerModel(sd) return io.NodeOutput(wrapper) @@ -234,10 +235,12 @@ class MediaPipeFaceLandmarker(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="MediaPipeFaceLandmarker", - display_name="MediaPipe Face Landmarker", + search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"], + display_name="Detect Face Landmarks (MediaPipe)", category="image/detection", + description="Detects facial landmarks using MediaPipe model.", inputs=[ - FaceLandmarkerType.Input("face_landmarker"), + FaceDetectionType.Input("face_detection_model"), io.Image.Input("image"), io.Combo.Input("detector_variant", options=["short", "full", "both"], default="short", tooltip="Face detector range. 'short' is tuned for close-up faces " @@ -261,9 +264,9 @@ class MediaPipeFaceLandmarker(io.ComfyNode): ) @classmethod - def execute(cls, face_landmarker, image, detector_variant, num_faces, min_confidence, + def execute(cls, face_detection_model, image, detector_variant, num_faces, min_confidence, missing_frame_fallback) -> io.NodeOutput: - canonical = face_landmarker.canonical_data + canonical = face_detection_model.canonical_data img_np = _image_to_uint8(image) B, H, W = img_np.shape[:3] chunk = 16 @@ -276,7 +279,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode): with tqdm(total=B, desc=f"MediaPipe Face Landmarker ({variant})") as tq: for i in range(0, B, chunk): end = min(i + chunk, B) - res.extend(face_landmarker.detect_batch( + res.extend(face_detection_model.detect_batch( [img_np[bi] for bi in range(i, end)], num_faces=int(num_faces), score_thresh=float(min_confidence), @@ -306,7 +309,7 @@ class MediaPipeFaceLandmarker(io.ComfyNode): per_bb.append({"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1, "label": "face", "score": float(f["score"])}) bboxes.append(per_bb) return io.NodeOutput({"frames": frames, "image_size": (H, W), - "connection_sets": face_landmarker.connection_sets}, bboxes) + "connection_sets": face_detection_model.connection_sets}, bboxes) # Topology keys unioned by the 'all' connections preset (contour parts + irises + nose). @@ -332,8 +335,10 @@ class MediaPipeFaceMeshVisualize(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="MediaPipeFaceMeshVisualize", - display_name="MediaPipe Face Mesh Visualize", + search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection", "visualize"], + display_name="Visualize Face Landmarks (MediaPipe)", category="image/detection", + description="Draws face landmarks mesh on the input image.", inputs=[ FaceLandmarksType.Input("face_landmarks"), io.Image.Input("image", optional=True, tooltip="If not connected, a black canvas will be used."), @@ -443,8 +448,10 @@ class MediaPipeFaceMask(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="MediaPipeFaceMask", - display_name="MediaPipe Face Mask", + search_aliases=["face", "facial", "mediapipe", "face mask", "blazeface", "face detection", "visualize"], + display_name="Draw Face Mask (MediaPipe)", category="image/detection", + description="Draws a mask from face landmarks.", inputs=[ FaceLandmarksType.Input("face_landmarks"), io.DynamicCombo.Input( diff --git a/folder_paths.py b/folder_paths.py index ce152eb37..36d61fcd0 100644 --- a/folder_paths.py +++ b/folder_paths.py @@ -60,7 +60,7 @@ folder_names_and_paths["geometry_estimation"] = ([os.path.join(models_dir, "geom folder_names_and_paths["optical_flow"] = ([os.path.join(models_dir, "optical_flow")], supported_pt_extensions) -folder_names_and_paths["mediapipe"] = ([os.path.join(models_dir, "mediapipe")], supported_pt_extensions) +folder_names_and_paths["detection"] = ([os.path.join(models_dir, "detection")], supported_pt_extensions) output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") diff --git a/models/mediapipe/put_mediapipe_models_here b/models/detection/put_detection_models_here similarity index 100% rename from models/mediapipe/put_mediapipe_models_here rename to models/detection/put_detection_models_here From af3d9b60afddbe6f7c82e31ee688f7f5c9af39d0 Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Thu, 21 May 2026 15:14:16 +0800 Subject: [PATCH 15/45] chore: Dataset nodes clean-up (CORE-237) (#14002) --- comfy_extras/nodes_audio.py | 7 +- comfy_extras/nodes_dataset.py | 188 ++++++++++++++++++++++---------- comfy_extras/nodes_hunyuan3d.py | 9 +- comfy_extras/nodes_images.py | 3 +- comfy_extras/nodes_lt_audio.py | 8 +- 5 files changed, 145 insertions(+), 70 deletions(-) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 2d6b3c7ea..d5084497e 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -543,7 +543,7 @@ class AudioConcat(IO.ComfyNode): return IO.Schema( node_id="AudioConcat", search_aliases=["join audio", "combine audio", "append audio"], - display_name="Audio Concat", + display_name="Concatenate Audio", description="Concatenates the audio1 to audio2 in the specified direction.", category="audio", inputs=[ @@ -597,7 +597,7 @@ class AudioMerge(IO.ComfyNode): return IO.Schema( node_id="AudioMerge", search_aliases=["mix audio", "overlay audio", "layer audio"], - display_name="Audio Merge", + display_name="Merge Audio", description="Combine two audio tracks by overlaying their waveforms.", category="audio", inputs=[ @@ -667,8 +667,9 @@ class AudioAdjustVolume(IO.ComfyNode): return IO.Schema( node_id="AudioAdjustVolume", search_aliases=["audio gain", "loudness", "audio level"], - display_name="Audio Adjust Volume", + display_name="Adjust Audio Volume", category="audio", + description="Adjust the volume of the audio by a specified amount in decibels (dB).", inputs=[ IO.Audio.Input("audio"), IO.Int.Input( diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 98ed25d7e..22f5ff203 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -47,8 +47,10 @@ class LoadImageDataSetFromFolderNode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LoadImageDataSetFromFolder", - display_name="Load Image Dataset from Folder", - category="dataset", + search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"], + display_name="Load Image (from Folder)", + category="image", + description="Load a dataset of images from a specified folder and return a list of images. Supported formats: PNG, JPG, JPEG, WEBP.", is_experimental=True, inputs=[ io.Combo.Input( @@ -84,14 +86,16 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LoadImageTextDataSetFromFolder", - display_name="Load Image and Text Dataset from Folder", - category="dataset", + search_aliases=["load folder", "load from folder", "load dataset", "load images", "import dataset"], + display_name="Load Image-Text (from Folder)", + category="image", + description="Load a dataset of pairs of images and text captions from a specified folder and return them as a list. Supported formats: PNG, JPG, JPEG, WEBP.", is_experimental=True, inputs=[ io.Combo.Input( "folder", options=folder_paths.get_input_subfolders(), - tooltip="The folder to load images from.", + tooltip="The folder to load images and text captions from.", ) ], outputs=[ @@ -206,8 +210,10 @@ class SaveImageDataSetToFolderNode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SaveImageDataSetToFolder", - display_name="Save Image Dataset to Folder", - category="dataset", + search_aliases=["save folder", "save to folder", "save dataset", "save images", "export dataset"], + display_name="Save Image (to Folder) (DEPRECATED)", + category="image", + description="Save a dataset of images to a specified folder. Supported formats: PNG.", is_experimental=True, is_output_node=True, is_input_list=True, # Receive images as list @@ -226,6 +232,7 @@ class SaveImageDataSetToFolderNode(io.ComfyNode): ), ], outputs=[], + is_deprecated=True, # This node is redundant and superseded by existing Save Image nodes where the target folder can be specified in the filename_prefix ) @classmethod @@ -246,14 +253,20 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SaveImageTextDataSetToFolder", - display_name="Save Image and Text Dataset to Folder", - category="dataset", + search_aliases=["save folder", "save to folder", "save dataset", "save images", "save text", "export dataset"], + display_name="Save Image-Text (to Folder)", + category="image", + description="Save a dataset of pairs of images and text captions to a specified folder. Images are saved as PNG files and captions are saved as TXT files with the same filename_prefix.", is_experimental=True, is_output_node=True, is_input_list=True, # Receive both images and texts as lists inputs=[ io.Image.Input("images", tooltip="List of images to save."), - io.String.Input("texts", tooltip="List of text captions to save."), + io.String.Input("texts", + optional=True, + force_input=True, + tooltip="List of text captions to save." + ), io.String.Input( "folder_name", default="dataset", @@ -270,7 +283,7 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode): ) @classmethod - def execute(cls, images, texts, folder_name, filename_prefix): + def execute(cls, images, folder_name, filename_prefix, texts=None): # Extract scalar values folder_name = folder_name[0] filename_prefix = filename_prefix[0] @@ -279,11 +292,12 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode): saved_files = save_images_to_folder(images, output_dir, filename_prefix) # Save captions - for idx, (filename, caption) in enumerate(zip(saved_files, texts)): - caption_filename = filename.replace(".png", ".txt") - caption_path = os.path.join(output_dir, caption_filename) - with open(caption_path, "w", encoding="utf-8") as f: - f.write(caption) + if texts: + for idx, (filename, caption) in enumerate(zip(saved_files, texts)): + caption_filename = filename.replace(".png", ".txt") + caption_path = os.path.join(output_dir, caption_filename) + with open(caption_path, "w", encoding="utf-8") as f: + f.write(caption) logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.") return io.NodeOutput() @@ -314,11 +328,13 @@ class ImageProcessingNode(io.ComfyNode): Child classes should set: node_id: Unique node identifier (required) + search_aliases: List of search aliases (optional) display_name: Display name (optional, defaults to node_id) description: Node description (optional) extra_inputs: List of additional io.Input objects beyond "images" (optional) is_group_process: None (auto-detect), True (group), or False (individual) (optional) is_output_list: True (list output) or False (single output) (optional, default True) + is_deprecated: True if the node is deprecated (optional, default False) Child classes must implement ONE of: _process(cls, image, **kwargs) -> tensor (for single-item processing) @@ -326,12 +342,13 @@ class ImageProcessingNode(io.ComfyNode): """ node_id = None + search_aliases = [] display_name = None description = None extra_inputs = [] is_group_process = None # None = auto-detect, True/False = explicit is_output_list = None # None = auto-detect based on processing mode - + is_deprecated = False @classmethod def _detect_processing_mode(cls): """Detect whether this node uses group or individual processing. @@ -402,8 +419,10 @@ class ImageProcessingNode(io.ComfyNode): return io.Schema( node_id=cls.node_id, + search_aliases=cls.search_aliases, display_name=cls.display_name or cls.node_id, - category="dataset/image", + category=cls.category, + description=cls.description, is_experimental=True, is_input_list=is_group, # True for group, False for individual inputs=inputs, @@ -472,11 +491,13 @@ class TextProcessingNode(io.ComfyNode): Child classes should set: node_id: Unique node identifier (required) + search_aliases: List of search aliases (optional) display_name: Display name (optional, defaults to node_id) description: Node description (optional) extra_inputs: List of additional io.Input objects beyond "texts" (optional) is_group_process: None (auto-detect), True (group), or False (individual) (optional) is_output_list: True (list output) or False (single output) (optional, default True) + is_deprecated: True if the node is deprecated (optional, default False) Child classes must implement ONE of: _process(cls, text, **kwargs) -> str (for single-item processing) @@ -484,12 +505,13 @@ class TextProcessingNode(io.ComfyNode): """ node_id = None + search_aliases = [] display_name = None description = None extra_inputs = [] is_group_process = None # None = auto-detect, True/False = explicit is_output_list = None # None = auto-detect based on processing mode - + is_deprecated = False @classmethod def _detect_processing_mode(cls): """Detect whether this node uses group or individual processing. @@ -627,15 +649,17 @@ class TextProcessingNode(io.ComfyNode): class ResizeImagesByShorterEdgeNode(ImageProcessingNode): node_id = "ResizeImagesByShorterEdge" - display_name = "Resize Images by Shorter Edge" - description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio." + display_name = "Resize Images by Shorter Edge (DEPRECATED)" + category = "image/transform" + description = "Resize images so that the shorter edge matches the specified dimension while preserving aspect ratio." + is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale shorter dimension extra_inputs = [ io.Int.Input( "shorter_edge", default=512, min=1, max=8192, - tooltip="Target length for the shorter edge.", + tooltip="Target dimension for the shorter edge.", ), ] @@ -655,15 +679,17 @@ class ResizeImagesByShorterEdgeNode(ImageProcessingNode): class ResizeImagesByLongerEdgeNode(ImageProcessingNode): node_id = "ResizeImagesByLongerEdge" - display_name = "Resize Images by Longer Edge" - description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio." + display_name = "Resize Images by Longer Edge (DEPRECATED)" + category = "image/transform" + description = "Resize images so that the longer edge matches the specified dimension while preserving aspect ratio." + is_deprecated = True # This node is superseded by Resize Image/Mask with resize_type = scale longer dimension extra_inputs = [ io.Int.Input( "longer_edge", default=1024, min=1, max=8192, - tooltip="Target length for the longer edge.", + tooltip="Target dimension for the longer edge.", ), ] @@ -686,8 +712,10 @@ class ResizeImagesByLongerEdgeNode(ImageProcessingNode): class CenterCropImagesNode(ImageProcessingNode): node_id = "CenterCropImages" - display_name = "Center Crop Images" - description = "Center crop all images to the specified dimensions." + search_aliases=["crop", "cut", "trim"] + display_name="Crop Image (Center)" + category="image/transform" + description = "Center crop an image to the specified dimensions." extra_inputs = [ io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), @@ -706,10 +734,11 @@ class CenterCropImagesNode(ImageProcessingNode): class RandomCropImagesNode(ImageProcessingNode): node_id = "RandomCropImages" - display_name = "Random Crop Images" - description = ( - "Randomly crop all images to the specified dimensions (for data augmentation)." - ) + search_aliases=["crop", "cut", "trim"] + display_name = "Crop Image (Random)" + category="image/transform" + description = "Randomly crop an image to the specified dimensions." + extra_inputs = [ io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), @@ -734,7 +763,9 @@ class RandomCropImagesNode(ImageProcessingNode): class NormalizeImagesNode(ImageProcessingNode): node_id = "NormalizeImages" - display_name = "Normalize Images" + search_aliases=["normalize", "normalize colors"] + display_name = "Normalize Image Colors" + category = "image/color" description = "Normalize images using mean and standard deviation." extra_inputs = [ io.Float.Input( @@ -762,8 +793,10 @@ class NormalizeImagesNode(ImageProcessingNode): class AdjustBrightnessNode(ImageProcessingNode): node_id = "AdjustBrightness" + search_aliases=["brightness"] display_name = "Adjust Brightness" - description = "Adjust brightness of all images." + category="image/adjustments" + description = "Adjust the brightness of an image." extra_inputs = [ io.Float.Input( "factor", @@ -781,8 +814,10 @@ class AdjustBrightnessNode(ImageProcessingNode): class AdjustContrastNode(ImageProcessingNode): node_id = "AdjustContrast" + search_aliases=["contrast"] display_name = "Adjust Contrast" - description = "Adjust contrast of all images." + category="image/adjustments" + description = "Adjust the contrast of an image." extra_inputs = [ io.Float.Input( "factor", @@ -800,8 +835,10 @@ class AdjustContrastNode(ImageProcessingNode): class ShuffleDatasetNode(ImageProcessingNode): node_id = "ShuffleDataset" - display_name = "Shuffle Image Dataset" - description = "Randomly shuffle the order of images in the dataset." + search_aliases=["shuffle", "randomize", "mix"] + display_name = "Shuffle Images List" + category = "image/batch" + description = "Randomly shuffle the order of images in a list." is_group_process = True # Requires full list to shuffle extra_inputs = [ io.Int.Input( @@ -823,13 +860,15 @@ class ShuffleImageTextDatasetNode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ShuffleImageTextDataset", - display_name="Shuffle Image-Text Dataset", - category="dataset/image", + search_aliases=["shuffle", "randomize", "mix"], + display_name = "Shuffle Pairs of Image-Text", + category = "image/batch", + description = "Randomly shuffle the order of pairs of image-text in a list.", is_experimental=True, is_input_list=True, inputs=[ io.Image.Input("images", tooltip="List of images to shuffle."), - io.String.Input("texts", tooltip="List of texts to shuffle."), + io.String.Input("texts", tooltip="List of texts to shuffle.", force_input=True), io.Int.Input( "seed", default=0, @@ -865,8 +904,11 @@ class ShuffleImageTextDatasetNode(io.ComfyNode): class TextToLowercaseNode(TextProcessingNode): node_id = "TextToLowercase" - display_name = "Text to Lowercase" - description = "Convert all texts to lowercase." + search_aliases=["lowercase"] + display_name = "Convert Text to Lowercase (DEPRECATED)" + category = "text" + description = "Convert text to lowercase." + is_deprecated = True # This node is superseded by the Convert Text Case node @classmethod def _process(cls, text): @@ -875,8 +917,11 @@ class TextToLowercaseNode(TextProcessingNode): class TextToUppercaseNode(TextProcessingNode): node_id = "TextToUppercase" - display_name = "Text to Uppercase" - description = "Convert all texts to uppercase." + search_aliases=["uppercase"] + display_name = "Convert Text to Uppercase (DEPRECATED)" + category = "text" + description = "Convert text to uppercase." + is_deprecated = True # This node is superseded by the Convert Text Case node @classmethod def _process(cls, text): @@ -885,8 +930,10 @@ class TextToUppercaseNode(TextProcessingNode): class TruncateTextNode(TextProcessingNode): node_id = "TruncateText" + search_aliases=["truncate", "cut", "shorten"] display_name = "Truncate Text" - description = "Truncate all texts to a maximum length." + category = "text" + description = "Truncate text to a maximum length." extra_inputs = [ io.Int.Input( "max_length", default=77, min=1, max=10000, tooltip="Maximum text length." @@ -900,8 +947,10 @@ class TruncateTextNode(TextProcessingNode): class AddTextPrefixNode(TextProcessingNode): node_id = "AddTextPrefix" - display_name = "Add Text Prefix" + display_name = "Add Text Prefix (DEPRECATED)" + category = "text" description = "Add a prefix to all texts." + is_deprecated = True # This node is superseded by the Concatenate Text node extra_inputs = [ io.String.Input("prefix", default="", tooltip="Prefix to add."), ] @@ -913,8 +962,10 @@ class AddTextPrefixNode(TextProcessingNode): class AddTextSuffixNode(TextProcessingNode): node_id = "AddTextSuffix" - display_name = "Add Text Suffix" + display_name = "Add Text Suffix (DEPRECATED)" + category = "text" description = "Add a suffix to all texts." + is_deprecated = True # This node is superseded by the Concatenate Text node extra_inputs = [ io.String.Input("suffix", default="", tooltip="Suffix to add."), ] @@ -926,8 +977,10 @@ class AddTextSuffixNode(TextProcessingNode): class ReplaceTextNode(TextProcessingNode): node_id = "ReplaceText" - display_name = "Replace Text" + display_name = "Replace Text (DEPRECATED)" + category = "text" description = "Replace text in all texts." + is_deprecated = True # This node is superseded by the other Replace Text node extra_inputs = [ io.String.Input("find", default="", tooltip="Text to find."), io.String.Input("replace", default="", tooltip="Text to replace with."), @@ -940,8 +993,10 @@ class ReplaceTextNode(TextProcessingNode): class StripWhitespaceNode(TextProcessingNode): node_id = "StripWhitespace" - display_name = "Strip Whitespace" + display_name = "Strip Whitespace (DEPRECATED)" + category = "text" description = "Strip leading and trailing whitespace from all texts." + is_deprecated = True # This node is superseded by the Trim Text node @classmethod def _process(cls, text): @@ -952,11 +1007,13 @@ class StripWhitespaceNode(TextProcessingNode): class ImageDeduplicationNode(ImageProcessingNode): - """Remove duplicate or very similar images from the dataset using perceptual hashing.""" + """Remove duplicate or very similar images from a list using perceptual hashing.""" node_id = "ImageDeduplication" - display_name = "Image Deduplication" - description = "Remove duplicate or very similar images from the dataset." + search_aliases=["deduplicate", "remove duplicates", "similarity filter"] + display_name = "Deduplicate Images" + category = "image/batch" + description = "Remove duplicate or very similar images from a list." is_group_process = True # Requires full list to compare images extra_inputs = [ io.Float.Input( @@ -1026,7 +1083,9 @@ class ImageGridNode(ImageProcessingNode): """Combine multiple images into a single grid/collage.""" node_id = "ImageGrid" - display_name = "Image Grid" + search_aliases=["grid", "collage", "combine"] + display_name = "Make Image Grid" + category="image/batch" description = "Arrange multiple images into a grid layout." is_group_process = True # Requires full list to create grid is_output_list = False # Outputs single grid image @@ -1102,9 +1161,12 @@ class MergeImageListsNode(ImageProcessingNode): """Merge multiple image lists into a single list.""" node_id = "MergeImageLists" - display_name = "Merge Image Lists" + search_aliases=["list", "merge list", "make list"] + display_name = "Merge Image Lists (DEPRECATED)" + category = "image/batch" description = "Concatenate multiple image lists into one." is_group_process = True # Receives images as list + is_deprecated = True # This node is superseded by the Create List node @classmethod def _group_process(cls, images): @@ -1119,9 +1181,11 @@ class MergeTextListsNode(TextProcessingNode): """Merge multiple text lists into a single list.""" node_id = "MergeTextLists" - display_name = "Merge Text Lists" + display_name = "Merge Text Lists (DEPRECATED)" + category = "text" description = "Concatenate multiple text lists into one." is_group_process = True # Receives texts as list + is_deprecated = True # This node is superseded by the Create List node @classmethod def _group_process(cls, texts): @@ -1142,8 +1206,10 @@ class ResolutionBucket(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ResolutionBucket", + search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"], display_name="Resolution Bucket", - category="dataset", + category="training", + description="Group latents and conditionings into buckets", is_experimental=True, is_input_list=True, inputs=[ @@ -1236,7 +1302,8 @@ class MakeTrainingDataset(io.ComfyNode): node_id="MakeTrainingDataset", search_aliases=["encode dataset"], display_name="Make Training Dataset", - category="dataset", + category="training", + description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.", is_experimental=True, is_input_list=True, # images and texts as lists inputs=[ @@ -1251,6 +1318,7 @@ class MakeTrainingDataset(io.ComfyNode): "texts", optional=True, tooltip="List of text captions. Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).", + force_input=True ), ], outputs=[ @@ -1320,9 +1388,10 @@ class SaveTrainingDataset(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SaveTrainingDataset", - search_aliases=["export training data"], + search_aliases=["export dataset", "save dataset"], display_name="Save Training Dataset", - category="dataset", + category="training", + description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.", is_experimental=True, is_output_node=True, is_input_list=True, # Receive lists @@ -1424,7 +1493,8 @@ class LoadTrainingDataset(io.ComfyNode): node_id="LoadTrainingDataset", search_aliases=["import dataset", "training data"], display_name="Load Training Dataset", - category="dataset", + category="training", + description="Load encoded training dataset (latents + conditioning) from disk for use in training.", is_experimental=True, inputs=[ io.String.Input( diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index 403eb855b..bcd3f9198 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -419,15 +419,17 @@ class VoxelToMeshBasic(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="VoxelToMeshBasic", - display_name="Voxel to Mesh (Basic)", + display_name="Voxel to Mesh (Basic) (DEPRECATED)", category="3d", + description="Converts a voxel grid to a mesh.", + is_deprecated=True, # This node is superseded by the Voxel To Mesh node inputs=[ IO.Voxel.Input("voxel"), IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01), ], outputs=[ IO.Mesh.Output(), - ] + ], ) @classmethod @@ -453,9 +455,10 @@ class VoxelToMesh(IO.ComfyNode): node_id="VoxelToMesh", display_name="Voxel to Mesh", category="3d", + description="Converts a voxel grid to a mesh.", inputs=[ IO.Voxel.Input("voxel"), - IO.Combo.Input("algorithm", options=["surface net", "basic"], advanced=True), + IO.Combo.Input("algorithm", options=["surface net", "basic"]), IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01), ], outputs=[ diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 6326c5be8..4856346d7 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -55,9 +55,10 @@ class ImageCropV2(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ImageCropV2", - search_aliases=["trim"], + search_aliases=["crop", "cut", "trim"], display_name="Crop Image", category="image/transform", + description = "Crop an image to the specified dimensions.", essentials_category="Image Tools", has_intermediate_output=True, inputs=[ diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 2c1f63afb..51ddf584a 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -11,8 +11,8 @@ class LTXVAudioVAELoader(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="LTXVAudioVAELoader", - display_name="LTXV Audio VAE Loader", - category="audio", + display_name="Load LTXV Audio VAE", + category="loaders", inputs=[ io.Combo.Input( "ckpt_name", @@ -40,7 +40,7 @@ class LTXVAudioVAEEncode(VAEEncodeAudio): return io.Schema( node_id="LTXVAudioVAEEncode", display_name="LTXV Audio VAE Encode", - category="audio", + category="latent/audio", inputs=[ io.Audio.Input("audio", tooltip="The audio to be encoded."), io.Vae.Input( @@ -63,7 +63,7 @@ class LTXVAudioVAEDecode(io.ComfyNode): return io.Schema( node_id="LTXVAudioVAEDecode", display_name="LTXV Audio VAE Decode", - category="audio", + category="latent/audio", inputs=[ io.Latent.Input("samples", tooltip="The latent to be decoded."), io.Vae.Input( From 4259a0c7c3b805e3dd1f178e603e6d725780583a Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Thu, 21 May 2026 16:50:09 +0800 Subject: [PATCH 16/45] Update MoGe nodes display names, search aliases and descriptions (#14030) --- comfy_extras/nodes_moge.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/comfy_extras/nodes_moge.py b/comfy_extras/nodes_moge.py index d9a08ebc7..3508781a0 100644 --- a/comfy_extras/nodes_moge.py +++ b/comfy_extras/nodes_moge.py @@ -103,8 +103,10 @@ class MoGePanoramaInference(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="MoGePanoramaInference", - display_name="MoGe Panorama Inference", + search_aliases=["moge", "panorama", "depth", "geometry", "depth estimation", "geometry estimation"], + display_name="Run MoGe Panorama Inference", category="image/geometry_estimation", + description="Run MoGe on an equirectangular panorama by splitting it into 12 perspective views, running inference on each, and merging the results into a single depth map.", inputs=[ MoGeModelType.Input("moge_model"), io.Image.Input("image", tooltip="Equirectangular panorama (any aspect)."), @@ -222,7 +224,9 @@ class MoGeInference(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="MoGeInference", - display_name="MoGe Inference", + search_aliases=["moge", "depth", "geometry", "depth estimation", "geometry estimation"], + display_name="Run MoGe Inference", + description="Run MoGe on a single image to estimate depth and geometry.", category="image/geometry_estimation", inputs=[ MoGeModelType.Input("moge_model"), @@ -277,7 +281,9 @@ class MoGeRender(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="MoGeRender", - display_name="MoGe Render", + search_aliases=["moge", "render", "geometry", "depth", "normal"], + display_name="Render MoGe Geometry", + description="Render a depth map or normal map from geometry data", category="image/geometry_estimation", inputs=[ MoGeGeometry.Input("moge_geometry"), @@ -342,7 +348,9 @@ class MoGePointMapToMesh(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="MoGePointMapToMesh", - display_name="MoGe Point Map to Mesh", + search_aliases=["moge", "mesh", "geometry", "point map"], + display_name="Convert MoGe Point Map to Mesh", + description="Convert a MoGe point map into a 3D mesh.", category="image/geometry_estimation", inputs=[ MoGeGeometry.Input("moge_geometry"), From aab41a9ddb3cb586024a75141fcc2f5e838da12c Mon Sep 17 00:00:00 2001 From: Edoardo Carmignani Date: Thu, 21 May 2026 17:47:20 +0200 Subject: [PATCH 17/45] fix(lanczos): correct dimension transposition for single-channel tensors (#12679) --- comfy/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/comfy/utils.py b/comfy/utils.py index 00e382fac..31052714a 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1019,10 +1019,11 @@ def bislerp(samples, width, height): def lanczos(samples, width, height): #the below API is strict and expects grayscale to be squeezed - samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1) + if samples.ndim == 4: + samples = samples.squeeze(1) if samples.shape[1] == 1 else samples.movedim(1, -1) images = [Image.fromarray(np.clip(255. * image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in samples] images = [image.resize((width, height), resample=Image.Resampling.LANCZOS) for image in images] - images = [torch.from_numpy(np.array(image).astype(np.float32) / 255.0).movedim(-1, 0) for image in images] + images = [torch.from_numpy(t).movedim(-1, 0) if (t := np.array(image).astype(np.float32) / 255.0).ndim == 3 else torch.from_numpy(t) for image in images] result = torch.stack(images) return result.to(samples.device, samples.dtype) From 03e511862ee783fec84ef14fe306ee30d4240e2c Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 22 May 2026 02:47:16 +1000 Subject: [PATCH 18/45] Fix reshaping lora application (#14031) * ModelPatcherDyanmic: purge stale vbar allocs on force cast * ModelPatcherDynamic: restore backups before load If doing a clean reload, mutative changes (lora application) could be applied on-top of the already loaded weight. Restore from backup unconditionally so that the new load is clean. --- comfy/model_patcher.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c8ed02e70..b44b99e4a 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1613,6 +1613,16 @@ class ModelPatcherDynamic(ModelPatcher): #use all ModelPatcherDynamic this is ignored and its all done dynamically. return super().memory_required(input_shape=input_shape) * 1.3 + (1024 ** 3) + def restore_loaded_backups(self): + restored = self.model.model_loaded_weight_memory + for key in list(self.backup.keys()): + bk = self.backup.pop(key) + comfy.utils.set_attr_param(self.model, key, bk.weight) + for key in list(self.backup_buffers.keys()): + comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key)) + self.model.model_loaded_weight_memory = 0 + return restored + def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False, dirty=False): @@ -1629,7 +1639,7 @@ class ModelPatcherDynamic(ModelPatcher): num_patches = 0 allocated_size = 0 - self.model.model_loaded_weight_memory = 0 + self.restore_loaded_backups() with self.use_ejected(): self.unpatch_hooks() @@ -1716,6 +1726,9 @@ class ModelPatcherDynamic(ModelPatcher): force_load=True if force_load: + if hasattr(m, "_v"): + comfy_aimdo.model_vbar.vbar_unpin(m._v) + delattr(m, "_v") force_load_param(self, "weight", device_to) force_load_param(self, "bias", device_to) else: @@ -1773,13 +1786,7 @@ class ModelPatcherDynamic(ModelPatcher): freed = 0 if vbar is None else vbar.free_memory(memory_to_free) if freed < memory_to_free: - for key in list(self.backup.keys()): - bk = self.backup.pop(key) - comfy.utils.set_attr_param(self.model, key, bk.weight) - for key in list(self.backup_buffers.keys()): - comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key)) - freed += self.model.model_loaded_weight_memory - self.model.model_loaded_weight_memory = 0 + freed += self.restore_loaded_backups() return freed From 6ecf5eca7ac6e5a78af96650c2da33ab8c44bb40 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 21 May 2026 21:36:11 +0300 Subject: [PATCH 19/45] [Partner Nodes] add OpenRouter LLM node (#14007) * [Partner Nodes] add reasoning widget to Anthropic node Signed-off-by: bigcat88 * [Partner Nodes] add new OpenRouterLLM node Signed-off-by: bigcat88 * [Partner Nodes] fix passing images to Grok LLM Signed-off-by: bigcat88 --------- Signed-off-by: bigcat88 --- comfy_api_nodes/apis/anthropic.py | 25 +- comfy_api_nodes/apis/openrouter.py | 93 +++++++ comfy_api_nodes/nodes_anthropic.py | 83 +++++- comfy_api_nodes/nodes_openrouter.py | 374 ++++++++++++++++++++++++++++ 4 files changed, 563 insertions(+), 12 deletions(-) create mode 100644 comfy_api_nodes/apis/openrouter.py create mode 100644 comfy_api_nodes/nodes_openrouter.py diff --git a/comfy_api_nodes/apis/anthropic.py b/comfy_api_nodes/apis/anthropic.py index 6cac537ea..46a5bb428 100644 --- a/comfy_api_nodes/apis/anthropic.py +++ b/comfy_api_nodes/apis/anthropic.py @@ -35,6 +35,19 @@ class AnthropicMessage(BaseModel): content: list[AnthropicTextContent | AnthropicImageContent] = Field(...) +class AnthropicThinkingConfig(BaseModel): + type: Literal["enabled", "disabled", "adaptive"] = Field(...) + budget_tokens: int | None = Field( + None, ge=1024, + description="Reasoning budget in tokens. Used when type is 'enabled'. Must be less than max_tokens.", + ) + + +class AnthropicOutputConfig(BaseModel): + """Used with `thinking.type='adaptive'` on models like Opus 4.7.""" + effort: Literal["low", "medium", "high"] | None = Field(None) + + class AnthropicMessagesRequest(BaseModel): model: str = Field(...) messages: list[AnthropicMessage] = Field(...) @@ -44,6 +57,8 @@ class AnthropicMessagesRequest(BaseModel): top_p: float | None = Field(None, ge=0.0, le=1.0) top_k: int | None = Field(None, ge=0) stop_sequences: list[str] | None = Field(None) + thinking: AnthropicThinkingConfig | None = Field(None) + output_config: AnthropicOutputConfig | None = Field(None) class AnthropicResponseTextBlock(BaseModel): @@ -51,6 +66,14 @@ class AnthropicResponseTextBlock(BaseModel): text: str = Field(...) +class AnthropicResponseThinkingBlock(BaseModel): + type: Literal["thinking"] = "thinking" + thinking: str = Field(...) + + +AnthropicResponseBlock = AnthropicResponseTextBlock | AnthropicResponseThinkingBlock + + class AnthropicCacheCreationUsage(BaseModel): ephemeral_5m_input_tokens: int | None = Field(None) ephemeral_1h_input_tokens: int | None = Field(None) @@ -69,7 +92,7 @@ class AnthropicMessagesResponse(BaseModel): type: str | None = Field(None) role: str | None = Field(None) model: str | None = Field(None) - content: list[AnthropicResponseTextBlock] | None = Field(None) + content: list[AnthropicResponseBlock] | None = Field(None) stop_reason: str | None = Field(None) stop_sequence: str | None = Field(None) usage: AnthropicMessagesUsage | None = Field(None) diff --git a/comfy_api_nodes/apis/openrouter.py b/comfy_api_nodes/apis/openrouter.py new file mode 100644 index 000000000..e30d9bcfb --- /dev/null +++ b/comfy_api_nodes/apis/openrouter.py @@ -0,0 +1,93 @@ +"""Pydantic models for the OpenRouter chat completions API. + +See: https://openrouter.ai/docs/api/api-reference/chat/send-chat-completion-request +""" + +from typing import Literal + +from pydantic import BaseModel, Field + + +class OpenRouterTextContent(BaseModel): + type: Literal["text"] = "text" + text: str = Field(...) + + +class OpenRouterImageUrl(BaseModel): + url: str = Field(...) + + +class OpenRouterImageContent(BaseModel): + type: Literal["image_url"] = "image_url" + image_url: OpenRouterImageUrl = Field(...) + + +class OpenRouterVideoUrl(BaseModel): + url: str = Field(...) + + +class OpenRouterVideoContent(BaseModel): + type: Literal["video_url"] = "video_url" + video_url: OpenRouterVideoUrl = Field(...) + + +OpenRouterContentBlock = OpenRouterTextContent | OpenRouterImageContent | OpenRouterVideoContent + + +class OpenRouterMessage(BaseModel): + role: Literal["system", "user", "assistant"] = Field(...) + content: str | list[OpenRouterContentBlock] = Field(...) + + +class OpenRouterReasoningConfig(BaseModel): + effort: str | None = Field(None) + exclude: bool | None = Field(None, description="If true, model reasons but reasoning is excluded from response.") + + +class OpenRouterWebSearchOptions(BaseModel): + search_context_size: str | None = Field(None) + + +class OpenRouterChatRequest(BaseModel): + model: str = Field(...) + messages: list[OpenRouterMessage] = Field(...) + seed: int | None = Field(None) + reasoning: OpenRouterReasoningConfig | None = Field(None) + web_search_options: OpenRouterWebSearchOptions | None = Field(None) + stream: bool = Field(False) + + +class OpenRouterUsage(BaseModel): + prompt_tokens: int | None = Field(None) + completion_tokens: int | None = Field(None) + total_tokens: int | None = Field(None) + cost: float | None = Field(None, description="Server-side authoritative USD cost of the call.") + + +class OpenRouterResponseMessage(BaseModel): + role: str | None = Field(None) + content: str | None = Field(None) + reasoning: str | None = Field(None) + refusal: str | None = Field(None) + + +class OpenRouterChoice(BaseModel): + index: int | None = Field(None) + message: OpenRouterResponseMessage | None = Field(None) + finish_reason: str | None = Field(None) + + +class OpenRouterError(BaseModel): + code: int | str | None = Field(None) + message: str | None = Field(None) + metadata: dict | None = Field(None) + + +class OpenRouterChatResponse(BaseModel): + id: str | None = Field(None) + model: str | None = Field(None) + object: str | None = Field(None) + provider: str | None = Field(None) + choices: list[OpenRouterChoice] | None = Field(None) + usage: OpenRouterUsage | None = Field(None) + error: OpenRouterError | None = Field(None) diff --git a/comfy_api_nodes/nodes_anthropic.py b/comfy_api_nodes/nodes_anthropic.py index 28dd70d4e..42ec5708f 100644 --- a/comfy_api_nodes/nodes_anthropic.py +++ b/comfy_api_nodes/nodes_anthropic.py @@ -9,8 +9,11 @@ from comfy_api_nodes.apis.anthropic import ( AnthropicMessage, AnthropicMessagesRequest, AnthropicMessagesResponse, + AnthropicOutputConfig, + AnthropicResponseTextBlock, AnthropicRole, AnthropicTextContent, + AnthropicThinkingConfig, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -32,15 +35,29 @@ CLAUDE_MODELS: dict[str, str] = { "Haiku 4.5": "claude-haiku-4-5-20251001", } +_THINKING_UNSUPPORTED = {"Haiku 4.5"} +# Models that use the newer "adaptive" thinking mode (Opus 4.7 requires it; older models keep the explicit budget API). +# Anthropic decides the actual budget when adaptive is used, based on the `output_config.effort` hint. +_ADAPTIVE_THINKING_MODELS = {"Opus 4.7", "Opus 4.6", "Sonnet 4.6"} -def _claude_model_inputs(): - return [ +# Budget mode (Sonnet 4.5): effort -> reasoning budget in tokens. Must be < max_tokens. +# Sized so even the "high" budget fits comfortably under the default max_tokens=32768. +_REASONING_BUDGET: dict[str, int] = { + "low": 2048, + "medium": 8192, + "high": 16384, +} +_REASONING_EFFORTS = ["off", "low", "medium", "high"] + + +def _claude_model_inputs(model_label: str): + inputs: list = [ IO.Int.Input( "max_tokens", - default=16000, - min=32, - max=32000, - tooltip="Maximum number of tokens to generate before stopping.", + default=32768, + min=4096, + max=64000, + tooltip="Maximum number of tokens to generate (includes reasoning tokens when enabled).", advanced=True, ), IO.Float.Input( @@ -49,10 +66,24 @@ def _claude_model_inputs(): min=0.0, max=1.0, step=0.01, - tooltip="Controls randomness. 0.0 is deterministic, 1.0 is most random. Ignored for Opus 4.7.", + tooltip=( + "Controls randomness. 0.0 is deterministic, 1.0 is most random. " + "Ignored for Opus 4.7 and any model when reasoning_effort is set." + ), advanced=True, ), ] + if model_label not in _THINKING_UNSUPPORTED: + inputs.append( + IO.Combo.Input( + "reasoning_effort", + options=_REASONING_EFFORTS, + default="off", + tooltip="Extended thinking effort. 'off' disables reasoning.", + advanced=True, + ) + ) + return inputs def _model_price_per_million(model: str) -> tuple[float, float] | None: @@ -95,7 +126,11 @@ def calculate_tokens_price(response: AnthropicMessagesResponse) -> float | None: def _get_text_from_response(response: AnthropicMessagesResponse) -> str: if not response.content: return "" - return "\n".join(block.text for block in response.content if block.text) + # Thinking blocks are silently dropped — we never want reasoning in the output. + return "\n".join( + block.text for block in response.content + if isinstance(block, AnthropicResponseTextBlock) and block.text + ) async def _build_image_content_blocks( @@ -133,7 +168,10 @@ class ClaudeNode(IO.ComfyNode): ), IO.DynamicCombo.Input( "model", - options=[IO.DynamicCombo.Option(label, _claude_model_inputs()) for label in CLAUDE_MODELS], + options=[ + IO.DynamicCombo.Option(label, _claude_model_inputs(label)) + for label in CLAUDE_MODELS + ], tooltip="The Claude model used to generate the response.", ), IO.Int.Input( @@ -207,8 +245,29 @@ class ClaudeNode(IO.ComfyNode): ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) model_label = model["model"] - max_tokens = model["max_tokens"] - temperature = None if model_label == "Opus 4.7" else model["temperature"] + max_tokens = model.get("max_tokens", 32768) + reasoning_effort = model.get("reasoning_effort", "off") + thinking_enabled = reasoning_effort not in ("off", None) and model_label not in _THINKING_UNSUPPORTED + + # Anthropic requires temperature to be unset (defaults to 1.0) when thinking is enabled. + # Opus 4.7 also rejects user-supplied temperature. + if thinking_enabled or model_label == "Opus 4.7": + temperature = None + else: + temperature = model.get("temperature", 1.0) + + thinking_cfg: AnthropicThinkingConfig | None = None + output_cfg: AnthropicOutputConfig | None = None + if thinking_enabled: + if model_label in _ADAPTIVE_THINKING_MODELS: + # Adaptive mode - Anthropic chooses the budget based on effort hint + thinking_cfg = AnthropicThinkingConfig(type="adaptive") + output_cfg = AnthropicOutputConfig(effort=reasoning_effort) + else: + # Budget mode (Sonnet 4.5). Leave at least 1024 tokens for the actual response + budget = _REASONING_BUDGET[reasoning_effort] + budget = min(budget, max(1024, max_tokens - 1024)) + thinking_cfg = AnthropicThinkingConfig(type="enabled", budget_tokens=budget) image_tensors: list[Input.Image] = [t for t in (images or {}).values() if t is not None] if sum(get_number_of_images(t) for t in image_tensors) > CLAUDE_MAX_IMAGES: @@ -229,6 +288,8 @@ class ClaudeNode(IO.ComfyNode): messages=[AnthropicMessage(role=AnthropicRole.user, content=content)], system=system_prompt or None, temperature=temperature, + thinking=thinking_cfg, + output_config=output_cfg, ), price_extractor=calculate_tokens_price, ) diff --git a/comfy_api_nodes/nodes_openrouter.py b/comfy_api_nodes/nodes_openrouter.py new file mode 100644 index 000000000..031301870 --- /dev/null +++ b/comfy_api_nodes/nodes_openrouter.py @@ -0,0 +1,374 @@ +"""API Nodes for OpenRouter LLM chat completions.""" + +from dataclasses import dataclass +from typing import Literal + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.openrouter import ( + OpenRouterChatRequest, + OpenRouterChatResponse, + OpenRouterContentBlock, + OpenRouterImageContent, + OpenRouterImageUrl, + OpenRouterMessage, + OpenRouterReasoningConfig, + OpenRouterTextContent, + OpenRouterVideoContent, + OpenRouterVideoUrl, + OpenRouterWebSearchOptions, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + get_number_of_images, + sync_op, + upload_images_to_comfyapi, + upload_video_to_comfyapi, + validate_string, +) + +OPENROUTER_CHAT_ENDPOINT = "/proxy/openrouter/api/v1/chat/completions" + + +Profile = Literal["standard", "reasoning", "frontier_reasoning", "perplexity", "perplexity_reasoning"] + + +@dataclass(frozen=True) +class _ModelSpec: + slug: str # exact OpenRouter model id + profile: Profile + price_in: float # USD per token (prompt) + price_out: float # USD per token (completion) + max_images: int = 0 # 0 = no image input; otherwise max URL-passed images supported + max_videos: int = 0 # 0 = no video input; otherwise max URL-passed videos supported + + +MODELS: list[_ModelSpec] = [ + _ModelSpec("anthropic/claude-opus-4.7", "frontier_reasoning", 0.000005, 0.000025, max_images=20), + _ModelSpec("openai/gpt-5.5-pro", "frontier_reasoning", 0.00003, 0.00018, max_images=20), + _ModelSpec("openai/gpt-5.5", "frontier_reasoning", 0.000005, 0.00003, max_images=20), + _ModelSpec("google/gemini-3.5-flash", "reasoning", 0.0000015, 0.000009, max_images=20, max_videos=4), + _ModelSpec("x-ai/grok-4.20", "reasoning", 0.00000125, 0.0000025, max_images=20), + _ModelSpec("x-ai/grok-4.3", "reasoning", 0.00000125, 0.0000025, max_images=20), + _ModelSpec("deepseek/deepseek-v4-pro", "reasoning", 0.000000435, 0.00000087), + _ModelSpec("deepseek/deepseek-v4-flash", "reasoning", 0.000000112, 0.000000224), + _ModelSpec("deepseek/deepseek-v3.2", "reasoning", 0.000000252, 0.000000378), + _ModelSpec("qwen/qwen3.6-max-preview", "reasoning", 0.00000104, 0.00000624), + _ModelSpec("qwen/qwen3.6-plus", "reasoning", 0.000000325, 0.00000195, max_images=10, max_videos=4), + _ModelSpec("qwen/qwen3.6-flash", "reasoning", 0.0000001875, 0.000001125, max_images=10, max_videos=4), + _ModelSpec("mistralai/mistral-large-2512", "standard", 0.0000005, 0.0000015, max_images=8), + _ModelSpec("mistralai/mistral-medium-3-5", "reasoning", 0.0000015, 0.0000075, max_images=8), + _ModelSpec("z-ai/glm-4.6", "reasoning", 0.00000043, 0.00000174), + _ModelSpec("z-ai/glm-5", "reasoning", 0.0000006, 0.00000192), + _ModelSpec("moonshotai/kimi-k2.6", "reasoning", 0.00000073, 0.00000349, max_images=10), + _ModelSpec("moonshotai/kimi-k2-thinking", "reasoning", 0.0000006, 0.0000025), + _ModelSpec("perplexity/sonar-pro", "perplexity", 0.000003, 0.000015), + _ModelSpec("perplexity/sonar-reasoning-pro", "perplexity_reasoning", 0.000002, 0.000008), + _ModelSpec("perplexity/sonar-deep-research", "perplexity_reasoning", 0.000002, 0.000008), +] + +_MODELS_BY_SLUG: dict[str, _ModelSpec] = {m.slug: m for m in MODELS} +_REASONING_EFFORTS = ["off", "low", "medium", "high"] +_SEARCH_CONTEXT_SIZES = ["low", "medium", "high"] + + +def _reasoning_extra_inputs() -> list: + return [ + IO.Combo.Input( + "reasoning_effort", + options=_REASONING_EFFORTS, + default="off", + tooltip="Reasoning effort. 'off' disables reasoning entirely.", + advanced=True, + ), + ] + + +def _perplexity_extra_inputs() -> list: + return [ + IO.Combo.Input( + "search_context_size", + options=_SEARCH_CONTEXT_SIZES, + default="medium", + tooltip="How much web search context to retrieve. Larger = more grounded but slower/pricier.", + advanced=True, + ), + ] + + +def _profile_inputs(profile: Profile) -> list: + if profile == "standard": + return [] + if profile in ("reasoning", "frontier_reasoning"): + return _reasoning_extra_inputs() + if profile == "perplexity": + return _perplexity_extra_inputs() + if profile == "perplexity_reasoning": + return _perplexity_extra_inputs() + _reasoning_extra_inputs() + raise ValueError(f"Unknown profile: {profile}") + + +def _media_inputs(spec: _ModelSpec) -> list: + extras: list = [] + if spec.max_images > 0: + extras.append( + IO.Autogrow.Input( + "images", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("image"), + names=[f"image_{i}" for i in range(1, spec.max_images + 1)], + min=0, + ), + tooltip=f"Optional reference image(s) — up to {spec.max_images}. Sent as URLs.", + ) + ) + if spec.max_videos > 0: + extras.append( + IO.Autogrow.Input( + "videos", + template=IO.Autogrow.TemplateNames( + IO.Video.Input("video"), + names=[f"video_{i}" for i in range(1, spec.max_videos + 1)], + min=0, + ), + tooltip=f"Optional reference video(s) — up to {spec.max_videos}. Sent as URLs.", + ) + ) + return extras + + +def _inputs_for_model(spec: _ModelSpec) -> list: + return _profile_inputs(spec.profile) + _media_inputs(spec) + + +def _build_model_options() -> list[IO.DynamicCombo.Option]: + return [IO.DynamicCombo.Option(spec.slug, _inputs_for_model(spec)) for spec in MODELS] + + +def _calculate_price(response: OpenRouterChatResponse) -> float | None: + if response.usage and response.usage.cost is not None: + return float(response.usage.cost) + return None + + +def _price_badge_jsonata() -> str: + rates_pairs = [] + for spec in MODELS: + prompt_per_1k = spec.price_in * 1000 + completion_per_1k = spec.price_out * 1000 + rates_pairs.append(f' "{spec.slug}": [{prompt_per_1k:.8g}, {completion_per_1k:.8g}]') + rates_block = ",\n".join(rates_pairs) + return ( + "(\n" + " $rates := {\n" + f"{rates_block}\n" + " };\n" + " $r := $lookup($rates, widgets.model);\n" + " $r ? {\n" + ' "type": "list_usd",\n' + ' "usd": $r,\n' + ' "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }\n' + ' } : {"type": "text", "text": "Token-based"}\n' + ")" + ) + + +async def _build_image_blocks( + cls: type[IO.ComfyNode], spec: _ModelSpec, images: list[Input.Image] +) -> list[OpenRouterImageContent]: + urls = await upload_images_to_comfyapi( + cls, + images, + max_images=spec.max_images, + total_pixels=2048 * 2048, + mime_type="image/png", + wait_label="Uploading reference images", + ) + return [OpenRouterImageContent(image_url=OpenRouterImageUrl(url=url)) for url in urls] + + +async def _build_video_blocks(cls: type[IO.ComfyNode], videos: list[Input.Video]) -> list[OpenRouterVideoContent]: + blocks: list[OpenRouterVideoContent] = [] + total = len(videos) + for idx, video in enumerate(videos): + label = "Uploading reference video" + if total > 1: + label = f"{label} ({idx + 1}/{total})" + url = await upload_video_to_comfyapi(cls, video, wait_label=label) + blocks.append(OpenRouterVideoContent(video_url=OpenRouterVideoUrl(url=url))) + return blocks + + +def _user_message(prompt: str, media_blocks: list[OpenRouterContentBlock]) -> OpenRouterMessage: + if not media_blocks: + return OpenRouterMessage(role="user", content=prompt) + blocks: list[OpenRouterContentBlock] = list(media_blocks) + blocks.append(OpenRouterTextContent(text=prompt)) + return OpenRouterMessage(role="user", content=blocks) + + +def _build_messages( + system_prompt: str, prompt: str, media_blocks: list[OpenRouterContentBlock] +) -> list[OpenRouterMessage]: + messages: list[OpenRouterMessage] = [] + if system_prompt: + messages.append(OpenRouterMessage(role="system", content=system_prompt)) + messages.append(_user_message(prompt, media_blocks)) + return messages + + +def _build_request( + slug: str, + system_prompt: str, + prompt: str, + media_blocks: list[OpenRouterContentBlock], + *, + seed: int, + reasoning_effort: str | None, + search_context_size: str | None, +) -> OpenRouterChatRequest: + reasoning_cfg: OpenRouterReasoningConfig | None = None + if reasoning_effort and reasoning_effort != "off": + # exclude=True asks providers to reason internally but not return the trace + reasoning_cfg = OpenRouterReasoningConfig(effort=reasoning_effort, exclude=True) + web_search_cfg: OpenRouterWebSearchOptions | None = None + if search_context_size: + web_search_cfg = OpenRouterWebSearchOptions(search_context_size=search_context_size) + return OpenRouterChatRequest( + model=slug, + messages=_build_messages(system_prompt, prompt, media_blocks), + seed=seed if seed > 0 else None, + reasoning=reasoning_cfg, + web_search_options=web_search_cfg, + ) + + +def _extract_text(response: OpenRouterChatResponse) -> str: + if response.error: + code = response.error.code if response.error.code is not None else "unknown" + raise ValueError(f"OpenRouter error ({code}): {response.error.message or 'no message'}") + if not response.choices: + raise ValueError("Empty response from OpenRouter (no choices).") + message = response.choices[0].message + if not message: + raise ValueError("Empty response from OpenRouter (no message).") + if message.refusal: + raise ValueError(f"Model refused to respond: {message.refusal}") + return message.content or "" + + +class OpenRouterLLMNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="OpenRouterLLMNode", + display_name="OpenRouter LLM", + category="api node/text/OpenRouter", + essentials_category="Text Generation", + description=( + "Generate text responses through OpenRouter. Routes to a curated set of popular " + "models from xAI, DeepSeek, Qwen, Mistral, Z.AI (GLM), Moonshot (Kimi), and " + "Perplexity Sonar." + ), + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text input to the model.", + ), + IO.DynamicCombo.Input( + "model", + options=_build_model_options(), + tooltip="The OpenRouter model used to generate the response.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed for sampling. Set to 0 to omit. Most models treat this as a hint only.", + ), + IO.String.Input( + "system_prompt", + multiline=True, + default="", + optional=True, + advanced=True, + tooltip="Foundational instructions that dictate the model's behavior.", + ), + ], + outputs=[IO.String.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=_price_badge_jsonata(), + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + system_prompt: str = "", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + slug: str = model["model"] + spec = _MODELS_BY_SLUG.get(slug) + if spec is None: + raise ValueError(f"Unknown OpenRouter model: {slug}") + + reasoning_effort: str | None = model.get("reasoning_effort") + search_context_size: str | None = model.get("search_context_size") + + image_tensors: list[Input.Image] = [t for t in (model.get("images") or {}).values() if t is not None] + if image_tensors and sum(get_number_of_images(t) for t in image_tensors) > spec.max_images: + raise ValueError(f"Up to {spec.max_images} images are supported for {slug}.") + video_inputs: list[Input.Video] = [v for v in (model.get("videos") or {}).values() if v is not None] + if video_inputs and len(video_inputs) > spec.max_videos: + raise ValueError(f"Up to {spec.max_videos} videos are supported for {slug}.") + + media_blocks: list[OpenRouterContentBlock] = [] + if image_tensors: + media_blocks.extend(await _build_image_blocks(cls, spec, image_tensors)) + if video_inputs: + media_blocks.extend(await _build_video_blocks(cls, video_inputs)) + + request = _build_request( + slug, + system_prompt, + prompt, + media_blocks, + seed=seed, + reasoning_effort=reasoning_effort, + search_context_size=search_context_size, + ) + + response = await sync_op( + cls, + ApiEndpoint(path=OPENROUTER_CHAT_ENDPOINT, method="POST"), + response_model=OpenRouterChatResponse, + data=request, + price_extractor=_calculate_price, + ) + return IO.NodeOutput(_extract_text(response)) + + +class OpenRouterExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [OpenRouterLLMNode] + + +async def comfy_entrypoint() -> OpenRouterExtension: + return OpenRouterExtension() From 2ca1480f9198b04aba5fb03d7584e2fb1a30065f Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Fri, 22 May 2026 02:48:20 +0800 Subject: [PATCH 20/45] chore: update workflow templates to v0.9.82 (#14034) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index d2986eda8..e20b6e044 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.43.18 -comfyui-workflow-templates==0.9.79 +comfyui-workflow-templates==0.9.82 comfyui-embedded-docs==0.5.0 torch torchsde From b293f8cefd18b2f8be061e33cb985149ec2ee872 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Thu, 21 May 2026 21:58:03 +0300 Subject: [PATCH 21/45] [Partner Nodes] add widget for automatic upscaling for the ByteDance2Reference node (#14032) Signed-off-by: bigcat88 --- comfy_api_nodes/nodes_bytedance.py | 33 ++++++++++++++++++------ comfy_api_nodes/util/__init__.py | 6 +++-- comfy_api_nodes/util/conversions.py | 40 ++++++++++++++++++++++++++--- 3 files changed, 66 insertions(+), 13 deletions(-) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index d6b479336..e08fc0b01 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -43,15 +43,16 @@ from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, download_url_to_video_output, + downscale_video_to_max_pixels, get_number_of_images, image_tensor_pair_to_batch, poll_op, - resize_video_to_pixel_budget, sync_op, upload_audio_to_comfyapi, upload_image_to_comfyapi, upload_images_to_comfyapi, upload_video_to_comfyapi, + upscale_video_to_min_pixels, validate_image_aspect_ratio, validate_image_dimensions, validate_string, @@ -110,12 +111,13 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st max_px = limits.get("max") if min_px and pixels < min_px: raise ValueError( - f"Reference video {index} is too small: {w}x{h} = {pixels:,}px. " f"Minimum is {min_px:,}px for this model." + f"Reference video {index} is too small: {w}x{h} = {pixels:,} total pixels. " + f"Minimum for this model is {min_px:,} total pixels." ) if max_px and pixels > max_px: raise ValueError( - f"Reference video {index} is too large: {w}x{h} = {pixels:,}px. " - f"Maximum is {max_px:,}px for this model. Try downscaling the video." + f"Reference video {index} is too large: {w}x{h} = {pixels:,} total pixels. " + f"Maximum for this model is {max_px:,} total pixels. Try downscaling the video." ) @@ -1676,14 +1678,14 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): "first_frame_asset_id", default="", tooltip="Seedance asset_id to use as the first frame. " - "Mutually exclusive with the first_frame image input.", + "Mutually exclusive with the first_frame image input.", optional=True, ), IO.String.Input( "last_frame_asset_id", default="", tooltip="Seedance asset_id to use as the last frame. " - "Mutually exclusive with the last_frame image input.", + "Mutually exclusive with the last_frame image input.", optional=True, ), IO.Int.Input( @@ -1865,11 +1867,20 @@ def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16 IO.Boolean.Input( "auto_downscale", default=False, - advanced=True, optional=True, tooltip="Automatically downscale reference videos that exceed the model's pixel budget " "for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.", ), + IO.Boolean.Input( + "auto_upscale", + default=False, + advanced=True, + optional=True, + tooltip="Automatically upscale reference videos that are below the model's minimum pixel count " + "for the selected resolution. Aspect ratio is preserved; videos already meeting the minimum are " + "untouched. Note: upscaling a low-resolution source does not add real detail and may produce " + "lower-quality generations.", + ), IO.Autogrow.Input( "reference_assets", template=IO.Autogrow.TemplateNames( @@ -2030,7 +2041,13 @@ class ByteDance2ReferenceNode(IO.ComfyNode): max_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("max") if max_px: for key in reference_videos: - reference_videos[key] = resize_video_to_pixel_budget(reference_videos[key], max_px) + reference_videos[key] = downscale_video_to_max_pixels(reference_videos[key], max_px) + + if model.get("auto_upscale") and reference_videos: + min_px = SEEDANCE2_REF_VIDEO_PIXEL_LIMITS.get(model_id, {}).get(model["resolution"], {}).get("min") + if min_px: + for key in reference_videos: + reference_videos[key] = upscale_video_to_min_pixels(reference_videos[key], min_px) total_video_duration = 0.0 for i, key in enumerate(reference_videos, 1): diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index f3584aba9..25cb88869 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -16,16 +16,17 @@ from .conversions import ( convert_mask_to_image, downscale_image_tensor, downscale_image_tensor_by_max_side, + downscale_video_to_max_pixels, image_tensor_pair_to_batch, pil_to_bytesio, resize_mask_to_image, - resize_video_to_pixel_budget, tensor_to_base64_string, tensor_to_bytesio, tensor_to_pil, text_filepath_to_base64_string, text_filepath_to_data_uri, trim_video, + upscale_video_to_min_pixels, video_to_base64_string, ) from .download_helpers import ( @@ -88,16 +89,17 @@ __all__ = [ "convert_mask_to_image", "downscale_image_tensor", "downscale_image_tensor_by_max_side", + "downscale_video_to_max_pixels", "image_tensor_pair_to_batch", "pil_to_bytesio", "resize_mask_to_image", - "resize_video_to_pixel_budget", "tensor_to_base64_string", "tensor_to_bytesio", "tensor_to_pil", "text_filepath_to_base64_string", "text_filepath_to_data_uri", "trim_video", + "upscale_video_to_min_pixels", "video_to_base64_string", # Validation utilities "get_image_dimensions", diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index be5d5719b..5738df57f 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -415,14 +415,48 @@ def trim_video(video: Input.Video, duration_sec: float) -> Input.Video: raise RuntimeError(f"Failed to trim video: {str(e)}") from e -def resize_video_to_pixel_budget(video: Input.Video, total_pixels: int) -> Input.Video: - """Downscale a video to fit within ``total_pixels`` (w * h), preserving aspect ratio. +def downscale_video_to_max_pixels(video: Input.Video, max_pixels: int) -> Input.Video: + """Downscale a video to fit within ``max_pixels`` (w * h), preserving aspect ratio. Returns the original video object untouched when it already fits. Preserves frame rate, duration, and audio. Aspect ratio is preserved up to a fraction of a percent (even-dim rounding). """ src_w, src_h = video.get_dimensions() - scale_dims = _compute_downscale_dims(src_w, src_h, total_pixels) + scale_dims = _compute_downscale_dims(src_w, src_h, max_pixels) + if scale_dims is None: + return video + return _apply_video_scale(video, scale_dims) + + +def _compute_upscale_dims(src_w: int, src_h: int, total_pixels: int) -> tuple[int, int] | None: + """Return upscaled (w, h) with even dims meeting at least ``total_pixels``, or None if already large enough. + + Source aspect ratio is preserved; output may drift by a fraction of a percent because both dimensions + are rounded up to even values (many codecs require divisible-by-2). The result is guaranteed to be at + least ``total_pixels``. + """ + pixels = src_w * src_h + if pixels >= total_pixels: + return None + scale = math.sqrt(total_pixels / pixels) + new_w = math.ceil(src_w * scale) + new_h = math.ceil(src_h * scale) + if new_w % 2: + new_w += 1 + if new_h % 2: + new_h += 1 + return new_w, new_h + + +def upscale_video_to_min_pixels(video: Input.Video, min_pixels: int) -> Input.Video: + """Upscale a video to meet at least ``min_pixels`` (w * h), preserving aspect ratio. + + Returns the original video object untouched when it already meets the minimum. Preserves frame rate, + duration, and audio. Aspect ratio is preserved up to a fraction of a percent (even-dim rounding). + Note: upscaling a low-resolution source does not add real detail; downstream model quality may suffer. + """ + src_w, src_h = video.get_dimensions() + scale_dims = _compute_upscale_dims(src_w, src_h, min_pixels) if scale_dims is None: return video return _apply_video_scale(video, scale_dims) From 32e58393b8c329c1a3fb1ddf74902c182c5064d5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 21 May 2026 14:49:55 -0700 Subject: [PATCH 22/45] Add backport release workflow. (#14038) --- .github/workflows/backport_release.yaml | 401 ++++++++++++++++++++++++ 1 file changed, 401 insertions(+) create mode 100644 .github/workflows/backport_release.yaml diff --git a/.github/workflows/backport_release.yaml b/.github/workflows/backport_release.yaml new file mode 100644 index 000000000..ba1f70e58 --- /dev/null +++ b/.github/workflows/backport_release.yaml @@ -0,0 +1,401 @@ +name: Backport Release + +on: + workflow_dispatch: + inputs: + branch: + description: 'Source branch containing the backported commits (PR source branch into master)' + required: true + type: string + +permissions: + contents: read + pull-requests: read + checks: read + +jobs: + backport-release: + name: Create backport release + runs-on: ubuntu-latest + environment: backport release + + steps: + - name: Generate GitHub App token + id: app-token + uses: actions/create-github-app-token@bcd2ba49218906704ab6c1aa796996da409d3eb1 + with: + app-id: ${{ secrets.FEN_RELEASE_APP_ID }} + private-key: ${{ secrets.FEN_RELEASE_PRIVATE_KEY }} + + - name: Checkout repository + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd + with: + token: ${{ steps.app-token.outputs.token }} + fetch-depth: 0 + fetch-tags: true + + - name: Configure git + run: | + git config user.name "fen-release[bot]" + git config user.email "fen-release[bot]@users.noreply.github.com" + + - name: Validate source branch exists + env: + SOURCE_BRANCH: ${{ inputs.branch }} + run: | + set -euo pipefail + git fetch origin "refs/heads/${SOURCE_BRANCH}:refs/remotes/origin/${SOURCE_BRANCH}" + if ! git show-ref --verify --quiet "refs/remotes/origin/${SOURCE_BRANCH}"; then + echo "::error::Source branch '${SOURCE_BRANCH}' not found on origin." + exit 1 + fi + + - name: Determine latest stable release + id: latest + env: + GH_TOKEN: ${{ steps.app-token.outputs.token }} + run: | + set -euo pipefail + + # List all tags matching vMAJOR.MINOR.PATCH and pick the highest by numeric + # comparison of each component. We DO NOT use `sort -V` because it treats + # v0.19.99 as higher than v0.20.1. + latest_tag="$( + git tag --list 'v[0-9]*.[0-9]*.[0-9]*' \ + | grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \ + | awk -F'[v.]' '{ printf "%010d %010d %010d %s\n", $2, $3, $4, $0 }' \ + | sort -k1,1n -k2,2n -k3,3n \ + | tail -n1 \ + | awk '{print $4}' + )" + + if [[ -z "${latest_tag}" ]]; then + echo "::error::No stable release tags (vMAJOR.MINOR.PATCH) were found." + exit 1 + fi + + # Parse components + ver="${latest_tag#v}" + major="${ver%%.*}" + rest="${ver#*.}" + minor="${rest%%.*}" + patch="${rest#*.}" + + new_patch=$((patch + 1)) + new_version="v${major}.${minor}.${new_patch}" + release_branch="release/v${major}.${minor}" + + latest_sha="$(git rev-list -n 1 "refs/tags/${latest_tag}")" + + echo "latest_tag=${latest_tag}" >> "$GITHUB_OUTPUT" + echo "latest_sha=${latest_sha}" >> "$GITHUB_OUTPUT" + echo "major=${major}" >> "$GITHUB_OUTPUT" + echo "minor=${minor}" >> "$GITHUB_OUTPUT" + echo "patch=${patch}" >> "$GITHUB_OUTPUT" + echo "new_version=${new_version}" >> "$GITHUB_OUTPUT" + echo "new_version_no_v=${major}.${minor}.${new_patch}" >> "$GITHUB_OUTPUT" + echo "release_branch=${release_branch}" >> "$GITHUB_OUTPUT" + + echo "Latest stable release: ${latest_tag} (${latest_sha})" + echo "New version will be: ${new_version}" + echo "Release branch: ${release_branch}" + + - name: Validate source branch is cut directly from the latest stable release + env: + SOURCE_BRANCH: ${{ inputs.branch }} + LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }} + LATEST_TAG: ${{ steps.latest.outputs.latest_tag }} + run: | + set -euo pipefail + + source_sha="$(git rev-parse "refs/remotes/origin/${SOURCE_BRANCH}")" + + # The source branch must be cut directly off the latest stable tag. + # "Cut directly off" means: walking first-parent from the source tip + # eventually reaches LATEST_TAG_SHA. This rejects branches that were + # cut from master after the tag (which would carry unrelated commits), + # while accepting a branch rooted at the tag with N backport commits + # on top (each of which may itself be a merge — first-parent walks + # through the mainline of the branch). + if ! git rev-list --first-parent "${source_sha}" \ + | grep -qx "${LATEST_TAG_SHA}"; then + echo "::error::Source branch '${SOURCE_BRANCH}' is not cut from '${LATEST_TAG}'." + echo "::error::Its first-parent history does not include ${LATEST_TAG_SHA}." + exit 1 + fi + + # Additionally, every commit added on top of the tag (the set we are + # about to publish) must itself be a descendant of the tag along + # first-parent — i.e. no sibling commits from master sneak in via a + # non-first-parent path. Enforce by requiring that the symmetric + # difference is empty in one direction: commits in source that are + # NOT first-parent-reachable from source starting at the tag. + # We do this by intersecting: + # A = commits reachable from source but not from tag (full DAG) + # B = commits on the first-parent chain from source down to tag + # and requiring A == B. + all_added="$(git rev-list "${LATEST_TAG_SHA}..${source_sha}" | sort)" + first_parent_added="$( + git rev-list --first-parent "${LATEST_TAG_SHA}..${source_sha}" | sort + )" + + if [[ "${all_added}" != "${first_parent_added}" ]]; then + echo "::error::Source branch '${SOURCE_BRANCH}' contains commits not on its first-parent chain from '${LATEST_TAG}'." + echo "::error::This usually means the branch was cut from master (not from the tag) or contains a merge from master." + echo "Commits reachable but not on first-parent chain:" + comm -23 <(printf '%s\n' "${all_added}") <(printf '%s\n' "${first_parent_added}") \ + | while read -r sha; do + echo " $(git log -1 --format='%h %s' "${sha}")" + done + exit 1 + fi + + added_count="$(printf '%s\n' "${all_added}" | grep -c . || true)" + echo "Source branch is cut directly from ${LATEST_TAG} with ${added_count} commit(s) on top." + + - name: Validate PR exists, is named correctly, and checks pass + env: + GH_TOKEN: ${{ steps.app-token.outputs.token }} + SOURCE_BRANCH: ${{ inputs.branch }} + NEW_VERSION: ${{ steps.latest.outputs.new_version }} + REPO: ${{ github.repository }} + run: | + set -euo pipefail + + expected_title="ComfyUI backport release ${NEW_VERSION}" + + # Find open PRs from this branch into master + pr_json="$( + gh pr list \ + --repo "${REPO}" \ + --state open \ + --head "${SOURCE_BRANCH}" \ + --base master \ + --json number,title,headRefOid \ + --limit 10 + )" + + pr_count="$(echo "${pr_json}" | jq 'length')" + if [[ "${pr_count}" -eq 0 ]]; then + echo "::error::No open PR found from '${SOURCE_BRANCH}' into 'master'." + exit 1 + fi + + # Pick the PR matching the expected title + pr_number="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" ' + map(select(.title == $t)) | .[0].number // empty + ')" + pr_head_sha="$(echo "${pr_json}" | jq -r --arg t "${expected_title}" ' + map(select(.title == $t)) | .[0].headRefOid // empty + ')" + + if [[ -z "${pr_number}" ]]; then + echo "::error::No open PR from '${SOURCE_BRANCH}' into 'master' is titled '${expected_title}'." + echo "Found PRs:" + echo "${pr_json}" | jq -r '.[] | " #\(.number): \(.title)"' + exit 1 + fi + + echo "Found PR #${pr_number} titled '${expected_title}' (head ${pr_head_sha})." + + # Verify all check runs on the head commit have completed successfully. + # A check is considered passing if conclusion is success, neutral, or skipped. + checks_json="$( + gh api \ + --paginate \ + "repos/${REPO}/commits/${pr_head_sha}/check-runs" \ + --jq '.check_runs[] | {name: .name, status: .status, conclusion: .conclusion}' + )" + + if [[ -z "${checks_json}" ]]; then + echo "::error::No check runs found on PR head commit ${pr_head_sha}." + exit 1 + fi + + echo "Check runs on ${pr_head_sha}:" + echo "${checks_json}" | jq -s '.' + + failing="$(echo "${checks_json}" | jq -s ' + map(select( + .status != "completed" + or (.conclusion as $c + | ["success","neutral","skipped"] + | index($c) | not) + )) + ')" + + failing_count="$(echo "${failing}" | jq 'length')" + if [[ "${failing_count}" -gt 0 ]]; then + echo "::error::One or more checks have not passed on PR head commit ${pr_head_sha}:" + echo "${failing}" | jq -r '.[] | " - \(.name): status=\(.status) conclusion=\(.conclusion)"' + exit 1 + fi + + echo "All checks have passed on ${pr_head_sha}." + + - name: Prepare release branch + id: prepare + env: + GH_TOKEN: ${{ steps.app-token.outputs.token }} + REPO: ${{ github.repository }} + SOURCE_BRANCH: ${{ inputs.branch }} + RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }} + LATEST_TAG: ${{ steps.latest.outputs.latest_tag }} + LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }} + PATCH: ${{ steps.latest.outputs.patch }} + run: | + set -euo pipefail + + # Try to fetch the release branch. If patch == 0, it shouldn't exist yet + # and we'll create it from the latest stable tag. If patch > 0, it must + # already exist and its tip must equal the latest stable tag commit (i.e. + # the previous patch release). + if git ls-remote --exit-code --heads origin "${RELEASE_BRANCH}" >/dev/null 2>&1; then + echo "Release branch '${RELEASE_BRANCH}' already exists on origin." + git fetch origin "refs/heads/${RELEASE_BRANCH}:refs/remotes/origin/${RELEASE_BRANCH}" + git checkout -B "${RELEASE_BRANCH}" "refs/remotes/origin/${RELEASE_BRANCH}" + + current_tip="$(git rev-parse HEAD)" + if [[ "${current_tip}" != "${LATEST_TAG_SHA}" ]]; then + echo "::error::Release branch '${RELEASE_BRANCH}' tip (${current_tip}) is not at the latest stable release '${LATEST_TAG}' (${LATEST_TAG_SHA})." + echo "::error::Refusing to release on top of a divergent branch." + exit 1 + fi + echo "branch_existed=true" >> "$GITHUB_OUTPUT" + else + if [[ "${PATCH}" != "0" ]]; then + echo "::error::Release branch '${RELEASE_BRANCH}' does not exist on origin, but the latest stable release '${LATEST_TAG}' has patch=${PATCH} (>0). This is inconsistent." + exit 1 + fi + echo "Release branch '${RELEASE_BRANCH}' does not exist. Creating from ${LATEST_TAG}." + git checkout -B "${RELEASE_BRANCH}" "refs/tags/${LATEST_TAG}" + echo "branch_existed=false" >> "$GITHUB_OUTPUT" + fi + + - name: Fast-forward merge source branch into release branch + env: + SOURCE_BRANCH: ${{ inputs.branch }} + RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }} + run: | + set -euo pipefail + + # --ff-only guarantees no merge commit is created. If a fast-forward is + # not possible (i.e. the release branch has commits the source branch + # doesn't), the merge will fail and we abort. Because we already validated + # that the source branch is rooted on the latest stable tag, and the + # release branch tip equals that same tag, this fast-forward should + # always succeed for a well-formed backport branch. + if ! git merge --ff-only "refs/remotes/origin/${SOURCE_BRANCH}"; then + echo "::error::Cannot fast-forward '${RELEASE_BRANCH}' to '${SOURCE_BRANCH}'. A merge commit would be required. Aborting." + exit 1 + fi + + echo "Fast-forwarded '${RELEASE_BRANCH}' to tip of '${SOURCE_BRANCH}'." + + - name: Bump version files + env: + NEW_VERSION_NO_V: ${{ steps.latest.outputs.new_version_no_v }} + run: | + set -euo pipefail + + if [[ ! -f comfyui_version.py ]]; then + echo "::error::comfyui_version.py not found in repo root." + exit 1 + fi + if [[ ! -f pyproject.toml ]]; then + echo "::error::pyproject.toml not found in repo root." + exit 1 + fi + + # Replace the version string in comfyui_version.py. + # Expected format: __version__ = "X.Y.Z" + python3 - "$NEW_VERSION_NO_V" <<'PY' + import re, sys, pathlib + new = sys.argv[1] + + p = pathlib.Path("comfyui_version.py") + src = p.read_text() + new_src, n = re.subn( + r'(__version__\s*=\s*[\'"])[^\'"]+([\'"])', + lambda m: f'{m.group(1)}{new}{m.group(2)}', + src, + count=1, + ) + if n != 1: + sys.exit("Could not find __version__ assignment in comfyui_version.py") + p.write_text(new_src) + + p = pathlib.Path("pyproject.toml") + src = p.read_text() + # Replace the first `version = "..."` inside [project] or [tool.poetry]. + new_src, n = re.subn( + r'(?m)^(version\s*=\s*")[^"]+(")', + lambda m: f'{m.group(1)}{new}{m.group(2)}', + src, + count=1, + ) + if n != 1: + sys.exit("Could not find version assignment in pyproject.toml") + p.write_text(new_src) + PY + + echo "Updated version to ${NEW_VERSION_NO_V} in comfyui_version.py and pyproject.toml." + git --no-pager diff -- comfyui_version.py pyproject.toml + + - name: Commit version bump and tag release + env: + NEW_VERSION: ${{ steps.latest.outputs.new_version }} + run: | + set -euo pipefail + + git add comfyui_version.py pyproject.toml + git commit -m "ComfyUI ${NEW_VERSION}" + + if git rev-parse -q --verify "refs/tags/${NEW_VERSION}" >/dev/null; then + echo "::error::Tag ${NEW_VERSION} already exists locally." + exit 1 + fi + git tag "${NEW_VERSION}" + + - name: Verify tag does not already exist on origin + env: + NEW_VERSION: ${{ steps.latest.outputs.new_version }} + run: | + set -euo pipefail + if git ls-remote --exit-code --tags origin "refs/tags/${NEW_VERSION}" >/dev/null 2>&1; then + echo "::error::Tag ${NEW_VERSION} already exists on origin. Aborting." + exit 1 + fi + + - name: Push release branch and tag + env: + RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }} + NEW_VERSION: ${{ steps.latest.outputs.new_version }} + run: | + set -euo pipefail + + # Push the branch first, then the tag. Atomic-ish: if the branch push + # fails we never publish the tag. + git push origin "refs/heads/${RELEASE_BRANCH}:refs/heads/${RELEASE_BRANCH}" + git push origin "refs/tags/${NEW_VERSION}" + + echo "Released ${NEW_VERSION} on ${RELEASE_BRANCH}." + + - name: Summary + if: always() + env: + NEW_VERSION: ${{ steps.latest.outputs.new_version }} + RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }} + LATEST_TAG: ${{ steps.latest.outputs.latest_tag }} + SOURCE_BRANCH: ${{ inputs.branch }} + run: | + { + echo "## Backport release" + echo "" + echo "| Field | Value |" + echo "|---|---|" + echo "| Source branch | \`${SOURCE_BRANCH}\` |" + echo "| Previous stable | \`${LATEST_TAG}\` |" + echo "| New version | \`${NEW_VERSION}\` |" + echo "| Release branch | \`${RELEASE_BRANCH}\` |" + } >> "$GITHUB_STEP_SUMMARY" From 5d681a5420fea4772a4a9c9426c2fce7a88a3d24 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Thu, 21 May 2026 16:29:08 -0700 Subject: [PATCH 23/45] Fix SIGPIPE false negative in backport release validation (#14041) --- .github/workflows/backport_release.yaml | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/.github/workflows/backport_release.yaml b/.github/workflows/backport_release.yaml index ba1f70e58..b28d62656 100644 --- a/.github/workflows/backport_release.yaml +++ b/.github/workflows/backport_release.yaml @@ -110,15 +110,13 @@ jobs: source_sha="$(git rev-parse "refs/remotes/origin/${SOURCE_BRANCH}")" - # The source branch must be cut directly off the latest stable tag. - # "Cut directly off" means: walking first-parent from the source tip - # eventually reaches LATEST_TAG_SHA. This rejects branches that were - # cut from master after the tag (which would carry unrelated commits), - # while accepting a branch rooted at the tag with N backport commits - # on top (each of which may itself be a merge — first-parent walks - # through the mainline of the branch). - if ! git rev-list --first-parent "${source_sha}" \ - | grep -qx "${LATEST_TAG_SHA}"; then + # Walking first-parent from the source tip must reach LATEST_TAG_SHA. + # We capture rev-list into a variable and grep against a here-string + # rather than piping `rev-list | grep -q`: under `set -o pipefail`, + # `grep -q` would exit on first match and SIGPIPE the still-streaming + # `rev-list`, propagating exit 141 as a spurious "not found". + first_parent_chain="$(git rev-list --first-parent "${source_sha}")" + if ! grep -Fxq "${LATEST_TAG_SHA}" <<< "${first_parent_chain}"; then echo "::error::Source branch '${SOURCE_BRANCH}' is not cut from '${LATEST_TAG}'." echo "::error::Its first-parent history does not include ${LATEST_TAG_SHA}." exit 1 From 8fecef0686275c9ce334ac7b6780d35eb93b836f Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 21 May 2026 16:39:19 -0700 Subject: [PATCH 24/45] Add validation for source branch in backport workflow (#14042) --- .github/workflows/backport_release.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/backport_release.yaml b/.github/workflows/backport_release.yaml index b28d62656..03788dd48 100644 --- a/.github/workflows/backport_release.yaml +++ b/.github/workflows/backport_release.yaml @@ -42,8 +42,13 @@ jobs: - name: Validate source branch exists env: SOURCE_BRANCH: ${{ inputs.branch }} + DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} run: | set -euo pipefail + if [[ "${SOURCE_BRANCH}" == "${DEFAULT_BRANCH}" ]]; then + echo "::error::Source branch must not be the default branch ('${DEFAULT_BRANCH}')." + exit 1 + fi git fetch origin "refs/heads/${SOURCE_BRANCH}:refs/remotes/origin/${SOURCE_BRANCH}" if ! git show-ref --verify --quiet "refs/remotes/origin/${SOURCE_BRANCH}"; then echo "::error::Source branch '${SOURCE_BRANCH}' not found on origin." From 8edff549e3195442b4ee2da4f79076ee26d4c653 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 21 May 2026 18:22:47 -0700 Subject: [PATCH 25/45] Update backport workflow to use commit SHA input (#14043) --- .github/workflows/backport_release.yaml | 130 +++++++++++++++++++----- 1 file changed, 105 insertions(+), 25 deletions(-) diff --git a/.github/workflows/backport_release.yaml b/.github/workflows/backport_release.yaml index 03788dd48..474e7045b 100644 --- a/.github/workflows/backport_release.yaml +++ b/.github/workflows/backport_release.yaml @@ -3,8 +3,8 @@ name: Backport Release on: workflow_dispatch: inputs: - branch: - description: 'Source branch containing the backported commits (PR source branch into master)' + commit: + description: 'Full 40-char SHA of the tip commit of the backport source branch (the PR head commit that passed tests). The branch is resolved from this SHA and must be unique.' required: true type: string @@ -39,21 +39,71 @@ jobs: git config user.name "fen-release[bot]" git config user.email "fen-release[bot]@users.noreply.github.com" - - name: Validate source branch exists + - name: Resolve source branch from commit SHA + id: resolve env: - SOURCE_BRANCH: ${{ inputs.branch }} + SOURCE_COMMIT: ${{ inputs.commit }} DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} run: | set -euo pipefail - if [[ "${SOURCE_BRANCH}" == "${DEFAULT_BRANCH}" ]]; then + + # Require a full 40-char lowercase-hex SHA. Short SHAs are ambiguous + # and we will be comparing this value against API responses (PR head + # SHA, ref tips) that always return the full form. + if [[ ! "${SOURCE_COMMIT}" =~ ^[0-9a-f]{40}$ ]]; then + echo "::error::Input commit '${SOURCE_COMMIT}' is not a full 40-char lowercase hex SHA." + exit 1 + fi + + # Fetch all remote branches so we can search for which one(s) point + # at this SHA. `actions/checkout` with fetch-depth: 0 fetches full + # history of the checked-out ref but does not necessarily populate + # every refs/remotes/origin/*, so do it explicitly. + git fetch --prune origin '+refs/heads/*:refs/remotes/origin/*' + + # Verify the commit actually exists in this repo's object DB. + if ! git cat-file -e "${SOURCE_COMMIT}^{commit}" 2>/dev/null; then + echo "::error::Commit ${SOURCE_COMMIT} was not found in the repository." + exit 1 + fi + + # Find every remote branch whose tip == SOURCE_COMMIT. Exactly one + # branch must point at it. If zero, the commit isn't anyone's tip + # (likely stale, force-pushed past, or never the PR head). If more + # than one, the (branch -> SHA) mapping is ambiguous and we refuse + # to guess — the operator must give us a unique branch to release. + mapfile -t matching_branches < <( + git for-each-ref \ + --format='%(refname:strip=3)' \ + --points-at="${SOURCE_COMMIT}" \ + refs/remotes/origin/ \ + | grep -vx 'HEAD' || true + ) + + if [[ "${#matching_branches[@]}" -eq 0 ]]; then + echo "::error::No branch on origin has ${SOURCE_COMMIT} as its tip." + echo "::error::Either the branch was updated after you copied this SHA, or this commit was never the head of a branch." + exit 1 + fi + + if [[ "${#matching_branches[@]}" -gt 1 ]]; then + echo "::error::More than one branch on origin has ${SOURCE_COMMIT} as its tip; cannot pick one:" + for b in "${matching_branches[@]}"; do + echo "::error:: - ${b}" + done + echo "::error::Refusing to proceed with an ambiguous source branch." + exit 1 + fi + + source_branch="${matching_branches[0]}" + + if [[ "${source_branch}" == "${DEFAULT_BRANCH}" ]]; then echo "::error::Source branch must not be the default branch ('${DEFAULT_BRANCH}')." exit 1 fi - git fetch origin "refs/heads/${SOURCE_BRANCH}:refs/remotes/origin/${SOURCE_BRANCH}" - if ! git show-ref --verify --quiet "refs/remotes/origin/${SOURCE_BRANCH}"; then - echo "::error::Source branch '${SOURCE_BRANCH}' not found on origin." - exit 1 - fi + + echo "Resolved commit ${SOURCE_COMMIT} to branch '${source_branch}'." + echo "source_branch=${source_branch}" >> "$GITHUB_OUTPUT" - name: Determine latest stable release id: latest @@ -107,13 +157,18 @@ jobs: - name: Validate source branch is cut directly from the latest stable release env: - SOURCE_BRANCH: ${{ inputs.branch }} + SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }} + SOURCE_COMMIT: ${{ inputs.commit }} LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }} LATEST_TAG: ${{ steps.latest.outputs.latest_tag }} run: | set -euo pipefail - source_sha="$(git rev-parse "refs/remotes/origin/${SOURCE_BRANCH}")" + # Use the user-provided SHA directly rather than re-resolving the branch + # tip — the resolve step already proved the branch tip equals SOURCE_COMMIT, + # and pinning to the SHA here makes the rest of the job TOCTOU-safe against + # someone pushing to the branch mid-run. + source_sha="${SOURCE_COMMIT}" # Walking first-parent from the source tip must reach LATEST_TAG_SHA. # We capture rev-list into a variable and grep against a here-string @@ -156,10 +211,11 @@ jobs: added_count="$(printf '%s\n' "${all_added}" | grep -c . || true)" echo "Source branch is cut directly from ${LATEST_TAG} with ${added_count} commit(s) on top." - - name: Validate PR exists, is named correctly, and checks pass + - name: Validate PR exists, is open, named correctly, has latest commit, and checks pass env: GH_TOKEN: ${{ steps.app-token.outputs.token }} - SOURCE_BRANCH: ${{ inputs.branch }} + SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }} + SOURCE_COMMIT: ${{ inputs.commit }} NEW_VERSION: ${{ steps.latest.outputs.new_version }} REPO: ${{ github.repository }} run: | @@ -167,20 +223,22 @@ jobs: expected_title="ComfyUI backport release ${NEW_VERSION}" - # Find open PRs from this branch into master + # Find open PRs from this branch into master. The --state open filter + # is load-bearing: a closed/merged PR with passing checks must not be + # accepted as authorization for a new release. pr_json="$( gh pr list \ --repo "${REPO}" \ --state open \ --head "${SOURCE_BRANCH}" \ --base master \ - --json number,title,headRefOid \ + --json number,title,headRefOid,state \ --limit 10 )" pr_count="$(echo "${pr_json}" | jq 'length')" if [[ "${pr_count}" -eq 0 ]]; then - echo "::error::No open PR found from '${SOURCE_BRANCH}' into 'master'." + echo "::error::No open PR found from '${SOURCE_BRANCH}' into 'master'. The PR must exist and be open." exit 1 fi @@ -199,7 +257,19 @@ jobs: exit 1 fi - echo "Found PR #${pr_number} titled '${expected_title}' (head ${pr_head_sha})." + # The PR's current head commit must equal the SHA the operator gave us. + # This is what closes the door on releasing stale code: if anyone has + # pushed to the branch since the operator validated tests passed, the + # PR head will have advanced past SOURCE_COMMIT and we abort. (The + # resolve step already proved the branch tip == SOURCE_COMMIT; this + # ties that same SHA to the PR that authorizes the release.) + if [[ "${pr_head_sha}" != "${SOURCE_COMMIT}" ]]; then + echo "::error::PR #${pr_number} head commit is ${pr_head_sha}, but the operator-provided commit is ${SOURCE_COMMIT}." + echo "::error::The PR has new commits since this release was authorized. Re-run with the new head SHA after verifying its checks." + exit 1 + fi + + echo "Found open PR #${pr_number} titled '${expected_title}' at head ${pr_head_sha} (matches operator-provided commit)." # Verify all check runs on the head commit have completed successfully. # A check is considered passing if conclusion is success, neutral, or skipped. @@ -241,7 +311,6 @@ jobs: env: GH_TOKEN: ${{ steps.app-token.outputs.token }} REPO: ${{ github.repository }} - SOURCE_BRANCH: ${{ inputs.branch }} RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }} LATEST_TAG: ${{ steps.latest.outputs.latest_tag }} LATEST_TAG_SHA: ${{ steps.latest.outputs.latest_sha }} @@ -277,7 +346,8 @@ jobs: - name: Fast-forward merge source branch into release branch env: - SOURCE_BRANCH: ${{ inputs.branch }} + SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }} + SOURCE_COMMIT: ${{ inputs.commit }} RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }} run: | set -euo pipefail @@ -288,12 +358,16 @@ jobs: # that the source branch is rooted on the latest stable tag, and the # release branch tip equals that same tag, this fast-forward should # always succeed for a well-formed backport branch. - if ! git merge --ff-only "refs/remotes/origin/${SOURCE_BRANCH}"; then - echo "::error::Cannot fast-forward '${RELEASE_BRANCH}' to '${SOURCE_BRANCH}'. A merge commit would be required. Aborting." + # + # We merge the operator-provided SHA, not the branch ref, so a push to + # the branch in the window between resolve and now cannot smuggle new + # commits into the release. + if ! git merge --ff-only "${SOURCE_COMMIT}"; then + echo "::error::Cannot fast-forward '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}'). A merge commit would be required. Aborting." exit 1 fi - echo "Fast-forwarded '${RELEASE_BRANCH}' to tip of '${SOURCE_BRANCH}'." + echo "Fast-forwarded '${RELEASE_BRANCH}' to ${SOURCE_COMMIT} (tip of '${SOURCE_BRANCH}')." - name: Bump version files env: @@ -390,14 +464,20 @@ jobs: NEW_VERSION: ${{ steps.latest.outputs.new_version }} RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }} LATEST_TAG: ${{ steps.latest.outputs.latest_tag }} - SOURCE_BRANCH: ${{ inputs.branch }} + SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }} + SOURCE_COMMIT: ${{ inputs.commit }} run: | + # SOURCE_BRANCH is empty if the resolve step never produced an output + # (e.g. the workflow failed in or before that step). Show a placeholder + # in that case so the summary table still renders cleanly. + source_branch_display="${SOURCE_BRANCH:-(unresolved)}" { echo "## Backport release" echo "" echo "| Field | Value |" echo "|---|---|" - echo "| Source branch | \`${SOURCE_BRANCH}\` |" + echo "| Source commit | \`${SOURCE_COMMIT}\` |" + echo "| Source branch | \`${source_branch_display}\` |" echo "| Previous stable | \`${LATEST_TAG}\` |" echo "| New version | \`${NEW_VERSION}\` |" echo "| Release branch | \`${RELEASE_BRANCH}\` |" From f48c32871b1d07a25715675b6b943c14da2ad501 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 22 May 2026 12:18:13 +1000 Subject: [PATCH 26/45] fe: Consolidate warnings (#13970) --- app/frontend_management.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/app/frontend_management.py b/app/frontend_management.py index d0596b276..483da2d29 100644 --- a/app/frontend_management.py +++ b/app/frontend_management.py @@ -62,6 +62,8 @@ def get_comfy_package_versions(): def check_comfy_packages_versions(): """Warn for every comfy* package whose installed version is below requirements.txt.""" from packaging.version import InvalidVersion, parse as parse_pep440 + outdated_packages = [] + for pkg in get_comfy_package_versions(): installed_str = pkg["installed"] required_str = pkg["required"] @@ -73,19 +75,26 @@ def check_comfy_packages_versions(): logging.error(f"Failed to check {pkg['name']} version: {e}") continue if outdated: - app.logger.log_startup_warning( - f""" + outdated_packages.append((pkg["name"], installed_str, required_str)) + else: + logging.info("{} version: {}".format(pkg["name"], installed_str)) + + if outdated_packages: + package_warnings = "\n".join( + f"Installed {name} version {installed} is lower than the recommended version {required}." + for name, installed, required in outdated_packages + ) + app.logger.log_startup_warning( + f""" ________________________________________________________________________ WARNING WARNING WARNING WARNING WARNING -Installed {pkg["name"]} version {installed_str} is lower than the recommended version {required_str}. +{package_warnings} {get_missing_requirements_message()} ________________________________________________________________________ """.strip() - ) - else: - logging.info("{} version: {}".format(pkg["name"], installed_str)) + ) REQUEST_TIMEOUT = 10 # seconds From 965057037818bc3ee24e034308c8a716f6654a65 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 21 May 2026 19:52:38 -0700 Subject: [PATCH 27/45] Update Discord invite link in README.md (#14045) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 0eecd8a4b..5125bad14 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ [website-url]: https://www.comfy.org/ [discord-shield]: https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fdiscord.com%2Fapi%2Finvites%2Fcomfyorg%3Fwith_counts%3Dtrue&query=%24.approximate_member_count&logo=discord&logoColor=white&label=Discord&color=green&suffix=%20total -[discord-url]: https://www.comfy.org/discord +[discord-url]: https://discord.com/invite/comfyorg [twitter-shield]: https://img.shields.io/twitter/follow/ComfyUI [twitter-url]: https://x.com/ComfyUI From 38ebc19037cb4f341a5f21c676486dd42299d8ed Mon Sep 17 00:00:00 2001 From: Pauan Date: Thu, 21 May 2026 20:01:12 -0700 Subject: [PATCH 28/45] Adding in And, Or, and Not nodes. (#14004) --- comfy_extras/nodes_logic.py | 79 +++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index c066064ac..65c7eebca 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -8,6 +8,82 @@ from comfy_api.latest import _io MISSING = object() +class NotNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ComfyNotNode", + display_name="Not", + category="utils/logic", + description="Logical NOT operation. Returns true if the value is falsy. Uses Python's rules for truthiness.", + search_aliases=["invert", "toggle", "negate", "flip boolean"], + inputs=[ + io.AnyType.Input("value"), + ], + outputs=[ + io.Boolean.Output(), + ], + ) + + @classmethod + def execute(cls, value) -> io.NodeOutput: + return io.NodeOutput(not value) + + +class AndNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = io.Autogrow.TemplatePrefix( + input=io.AnyType.Input("value"), + prefix="value", + min=1, + ) + return io.Schema( + node_id="ComfyAndNode", + display_name="And", + category="utils/logic", + description="Logical AND operation. Returns true if all of the values are truthy. Uses Python's rules for truthiness.", + search_aliases=["all", "every"], + inputs=[ + io.Autogrow.Input("values", template=template), + ], + outputs=[ + io.Boolean.Output(), + ], + ) + + @classmethod + def execute(cls, values: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(all(values.values())) + + +class OrNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = io.Autogrow.TemplatePrefix( + input=io.AnyType.Input("value"), + prefix="value", + min=1, + ) + return io.Schema( + node_id="ComfyOrNode", + display_name="Or", + category="utils/logic", + description="Logical OR operation. Returns true if any of the values are truthy. Uses Python's rules for truthiness.", + search_aliases=["any", "some"], + inputs=[ + io.Autogrow.Input("values", template=template), + ], + outputs=[ + io.Boolean.Output(), + ], + ) + + @classmethod + def execute(cls, values: io.Autogrow.Type) -> io.NodeOutput: + return io.NodeOutput(any(values.values())) + + class SwitchNode(io.ComfyNode): @classmethod def define_schema(cls): @@ -261,6 +337,9 @@ class LogicExtension(ComfyExtension): return [ SwitchNode, CustomComboNode, + NotNode, + AndNode, + OrNode, # SoftSwitchNode, # ConvertStringToComboNode, # DCTestNode, From 93888ae8e3618a4f1aad0004ca0bfe5f80c9127a Mon Sep 17 00:00:00 2001 From: Alexis Rolland Date: Fri, 22 May 2026 13:32:08 +0800 Subject: [PATCH 29/45] Move logic nodes into utils category (#14033) --- comfy_extras/nodes_logic.py | 16 ++++++++-------- comfy_extras/nodes_math.py | 2 +- comfy_extras/nodes_toolkit.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 65c7eebca..342cadb69 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -91,7 +91,7 @@ class SwitchNode(io.ComfyNode): return io.Schema( node_id="ComfySwitchNode", display_name="Switch", - category="logic", + category="utils/logic", is_experimental=True, inputs=[ io.Boolean.Input("switch"), @@ -122,7 +122,7 @@ class SoftSwitchNode(io.ComfyNode): return io.Schema( node_id="ComfySoftSwitchNode", display_name="Soft Switch", - category="logic", + category="utils/logic", is_experimental=True, inputs=[ io.Boolean.Input("switch"), @@ -212,7 +212,7 @@ class DCTestNode(io.ComfyNode): return io.Schema( node_id="DCTestNode", display_name="DCTest", - category="logic", + category="utils/logic", is_output_node=True, inputs=[io.DynamicCombo.Input("combo", options=[ io.DynamicCombo.Option("option1", [io.String.Input("string")]), @@ -250,7 +250,7 @@ class AutogrowNamesTestNode(io.ComfyNode): return io.Schema( node_id="AutogrowNamesTestNode", display_name="AutogrowNamesTest", - category="logic", + category="utils/logic", inputs=[ _io.Autogrow.Input("autogrow", template=template) ], @@ -270,7 +270,7 @@ class AutogrowPrefixTestNode(io.ComfyNode): return io.Schema( node_id="AutogrowPrefixTestNode", display_name="AutogrowPrefixTest", - category="logic", + category="utils/logic", inputs=[ _io.Autogrow.Input("autogrow", template=template) ], @@ -289,7 +289,7 @@ class ComboOutputTestNode(io.ComfyNode): return io.Schema( node_id="ComboOptionTestNode", display_name="ComboOptionTest", - category="logic", + category="utils/logic", inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]), io.Combo.Input("combo2", options=["option4", "option5", "option6"])], outputs=[io.Combo.Output(), io.Combo.Output()], @@ -306,7 +306,7 @@ class ConvertStringToComboNode(io.ComfyNode): node_id="ConvertStringToComboNode", search_aliases=["string to dropdown", "text to combo"], display_name="Convert String to Combo", - category="logic", + category="utils/logic", inputs=[io.String.Input("string")], outputs=[io.Combo.Output()], ) @@ -322,7 +322,7 @@ class InvertBooleanNode(io.ComfyNode): node_id="InvertBooleanNode", search_aliases=["not", "toggle", "negate", "flip boolean"], display_name="Invert Boolean", - category="logic", + category="utils/logic", inputs=[io.Boolean.Input("boolean")], outputs=[io.Boolean.Output()], ) diff --git a/comfy_extras/nodes_math.py b/comfy_extras/nodes_math.py index 6030ee9d8..06aefa475 100644 --- a/comfy_extras/nodes_math.py +++ b/comfy_extras/nodes_math.py @@ -70,7 +70,7 @@ class MathExpressionNode(io.ComfyNode): return io.Schema( node_id="ComfyMathExpression", display_name="Math Expression", - category="logic", + category="utils", search_aliases=[ "expression", "formula", "calculate", "calculator", "eval", "math", diff --git a/comfy_extras/nodes_toolkit.py b/comfy_extras/nodes_toolkit.py index 71faf7226..ae802896b 100644 --- a/comfy_extras/nodes_toolkit.py +++ b/comfy_extras/nodes_toolkit.py @@ -14,7 +14,7 @@ class CreateList(io.ComfyNode): return io.Schema( node_id="CreateList", display_name="Create List", - category="logic", + category="utils", is_input_list=True, search_aliases=["Image Iterator", "Text Iterator", "Iterator"], inputs=[io.Autogrow.Input("inputs", template=template_autogrow)], From 1579bbb52de5b439bef0717dce723c39849c6b37 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Fri, 22 May 2026 19:07:21 +0300 Subject: [PATCH 30/45] [Partner Nodes] add new Rodin2.5 nodes (#14051) * [Partner Nodes] add new Rodin2.5 nodes Signed-off-by: bigcat88 * [Partner Nodes] fixed Quality Mesh Options Signed-off-by: bigcat88 * [Partner Nodes] fix: remove non-supported "usdz" Signed-off-by: bigcat88 * [Partner Nodes] fix: always pass seed to server Signed-off-by: bigcat88 * [Partner Nodes] fix: set the default "material" value to "Shaded" Signed-off-by: bigcat88 --------- Signed-off-by: bigcat88 --- comfy_api_nodes/apis/rodin.py | 56 ++- comfy_api_nodes/nodes_rodin.py | 671 ++++++++++++++++++++++++++++++--- 2 files changed, 661 insertions(+), 66 deletions(-) diff --git a/comfy_api_nodes/apis/rodin.py b/comfy_api_nodes/apis/rodin.py index fc26a6e73..24524d642 100644 --- a/comfy_api_nodes/apis/rodin.py +++ b/comfy_api_nodes/apis/rodin.py @@ -1,7 +1,5 @@ -from __future__ import annotations - from enum import Enum -from typing import Optional, List + from pydantic import BaseModel, Field @@ -11,44 +9,76 @@ class Rodin3DGenerateRequest(BaseModel): material: str = Field(..., description="The material type.") quality_override: int = Field(..., description="The poly count of the mesh.") mesh_mode: str = Field(..., description="It controls the type of faces of generated models.") - TAPose: Optional[bool] = Field(None, description="") + TAPose: bool | None = Field(None, description="") + + +class Rodin3DGen25Request(BaseModel): + + tier: str = Field(..., description="Gen-2.5 tier (e.g. Gen-2.5-High).") + prompt: str | None = Field(None, description="Required for Text-to-3D; ignored otherwise.") + seed: int | None = Field(None, description="0-65535.") + material: str | None = Field(None, description="PBR | Shaded | All | None.") + geometry_file_format: str | None = Field(None, description="glb | usdz | fbx | obj | stl.") + texture_mode: str | None = Field(None, description="legacy | extreme-low | low | medium | high.") + mesh_mode: str | None = Field(None, description="Raw (triangular) | Quad.") + quality_override: int | None = Field(None, description="Mesh face count override.") + geometry_instruct_mode: str | None = Field(None, description="faithful | creative.") + bbox_condition: list[int] | None = Field(None, description="Bounding box [Width(Y), Height(Z), Length(X)] in cm.") + height: int | None = Field(None, description="Approximate model height in cm.") + TAPose: bool | None = Field(None, description="T/A pose for human-like models.") + hd_texture: bool | None = Field(None, description="Enhanced texture quality.") + texture_delight: bool | None = Field(None, description="Remove baked lighting from textures.") + is_micro: bool | None = Field(None, description="Micro detail (Extreme-High only).") + use_original_alpha: bool | None = Field(None, description="Preserve image transparency.") + preview_render: bool | None = Field(None, description="Generate high-quality preview render.") + addons: list[str] | None = Field(None, description='Optional addons, e.g. ["HighPack"].') + class GenerateJobsData(BaseModel): - uuids: List[str] = Field(..., description="str LIST") + uuids: list[str] = Field(..., description="str LIST") subscription_key: str = Field(..., description="subscription key") + class Rodin3DGenerateResponse(BaseModel): - message: Optional[str] = Field(None, description="Return message.") - prompt: Optional[str] = Field(None, description="Generated Prompt from image.") - submit_time: Optional[str] = Field(None, description="Submit Time") - uuid: Optional[str] = Field(None, description="Task str") - jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs") + message: str | None = Field(None, description="Return message.") + prompt: str | None = Field(None, description="Generated Prompt from image.") + submit_time: str | None = Field(None, description="Submit Time") + uuid: str | None = Field(None, description="Task str") + jobs: GenerateJobsData | None = Field(None, description="Details of jobs") + class JobStatus(str, Enum): """ Status for jobs """ + Done = "Done" Failed = "Failed" Generating = "Generating" Waiting = "Waiting" + class Rodin3DCheckStatusRequest(BaseModel): subscription_key: str = Field(..., description="subscription from generate endpoint") + class JobItem(BaseModel): uuid: str = Field(..., description="uuid") - status: JobStatus = Field(...,description="Status Currently") + status: JobStatus = Field(..., description="Status Currently") + class Rodin3DCheckStatusResponse(BaseModel): - jobs: List[JobItem] = Field(..., description="Job status List") + jobs: list[JobItem] = Field(..., description="Job status List") + class Rodin3DDownloadRequest(BaseModel): task_uuid: str = Field(..., description="Task str") + class RodinResourceItem(BaseModel): url: str = Field(..., description="Download Url") name: str = Field(..., description="File name with ext") + class Rodin3DDownloadResponse(BaseModel): - list: List[RodinResourceItem] = Field(..., description="Source List") + items: list[RodinResourceItem] = Field(..., alias="list", description="Source List") diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 2b829b8db..2df5a3e13 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -5,32 +5,37 @@ Rodin API docs: https://developer.hyper3d.ai/ """ -from inspect import cleandoc -import folder_paths as comfy_paths -import os import logging import math +import os +from inspect import cleandoc from io import BytesIO -from typing_extensions import override +from typing import Any + +import aiohttp from PIL import Image +from typing_extensions import override + +import folder_paths as comfy_paths +from comfy_api.latest import IO, ComfyExtension, Types from comfy_api_nodes.apis.rodin import ( - Rodin3DGenerateRequest, - Rodin3DGenerateResponse, + JobStatus, Rodin3DCheckStatusRequest, Rodin3DCheckStatusResponse, Rodin3DDownloadRequest, Rodin3DDownloadResponse, - JobStatus, + Rodin3DGen25Request, + Rodin3DGenerateRequest, + Rodin3DGenerateResponse, ) from comfy_api_nodes.util import ( - sync_op, - poll_op, ApiEndpoint, download_url_to_bytesio, download_url_to_file_3d, + poll_op, + sync_op, + validate_string, ) -from comfy_api.latest import ComfyExtension, IO, Types - COMMON_PARAMETERS = [ IO.Int.Input( @@ -51,40 +56,30 @@ COMMON_PARAMETERS = [ ] -def get_quality_mode(poly_count): - polycount = poly_count.split("-") - poly = polycount[1] - count = polycount[0] - if poly == "Triangle": - mesh_mode = "Raw" - elif poly == "Quad": - mesh_mode = "Quad" - else: - mesh_mode = "Quad" - - if count == "4K": - quality_override = 4000 - elif count == "8K": - quality_override = 8000 - elif count == "18K": - quality_override = 18000 - elif count == "50K": - quality_override = 50000 - elif count == "2K": - quality_override = 2000 - elif count == "20K": - quality_override = 20000 - elif count == "150K": - quality_override = 150000 - elif count == "500K": - quality_override = 500000 - else: - quality_override = 18000 - - return mesh_mode, quality_override +_QUALITY_MESH_OPTIONS: dict[str, tuple[str, int]] = { + "4K-Quad": ("Quad", 4000), + "8K-Quad": ("Quad", 8000), + "18K-Quad": ("Quad", 18000), + "50K-Quad": ("Quad", 50000), + "200K-Quad": ("Quad", 200000), + "2K-Triangle": ("Raw", 2000), + "20K-Triangle": ("Raw", 20000), + "150K-Triangle": ("Raw", 150000), + "200K-Triangle": ("Raw", 200000), + "500K-Triangle": ("Raw", 500000), + "1M-Triangle": ("Raw", 1000000), +} -def tensor_to_filelike(tensor, max_pixels: int = 2048*2048): +def get_quality_mode(poly_count: str) -> tuple[str, int]: + """Map a polygon-count preset like '18K-Quad' to (mesh_mode, quality_override). + + Falls back to ('Quad', 18000) for unknown labels; legacy parity. + """ + return _QUALITY_MESH_OPTIONS.get(poly_count, ("Quad", 18000)) + + +def tensor_to_filelike(tensor, max_pixels: int = 2048 * 2048): """ Converts a PyTorch tensor to a file-like object. @@ -96,8 +91,8 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048): - io.BytesIO: A file-like object containing the image data. """ array = tensor.cpu().numpy() - array = (array * 255).astype('uint8') - image = Image.fromarray(array, 'RGB') + array = (array * 255).astype("uint8") + image = Image.fromarray(array, "RGB") original_width, original_height = image.size original_pixels = original_width * original_height @@ -112,7 +107,7 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048*2048): image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) img_byte_arr = BytesIO() - image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression + image.save(img_byte_arr, format="PNG") # PNG is used for lossless compression img_byte_arr.seek(0) return img_byte_arr @@ -145,11 +140,9 @@ async def create_generate_task( TAPose=ta_pose, ), files=[ - ( - "images", - open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image) - ) - for image in images if image is not None + ("images", open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image)) + for image in images + if image is not None ], content_type="multipart/form-data", ) @@ -177,6 +170,7 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str: return "DONE" return "Generating" + def extract_progress(response: Rodin3DCheckStatusResponse) -> int | None: if not response.jobs: return None @@ -214,7 +208,7 @@ async def download_files(url_list, task_uuid: str) -> tuple[str | None, Types.Fi model_file_path = None file_3d = None - for i in url_list.list: + for i in url_list.items: file_path = os.path.join(save_path, i.name) if i.name.lower().endswith(".glb"): model_file_path = os.path.join(result_folder_name, i.name) @@ -489,7 +483,16 @@ class Rodin3D_Gen2(IO.ComfyNode): IO.Combo.Input("Material_Type", options=["PBR", "Shaded"], default="PBR", optional=True), IO.Combo.Input( "Polygon_count", - options=["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"], + options=[ + "4K-Quad", + "8K-Quad", + "18K-Quad", + "50K-Quad", + "2K-Triangle", + "20K-Triangle", + "150K-Triangle", + "500K-Triangle", + ], default="500K-Triangle", optional=True, ), @@ -542,6 +545,566 @@ class Rodin3D_Gen2(IO.ComfyNode): return IO.NodeOutput(model_path, file_3d) +def _rodin_multipart_parser(data: dict[str, Any]) -> aiohttp.FormData: + """Convert a Rodin request dict to an aiohttp form, fixing bool/list serialization. + + Booleans --> "true"/"false". Lists --> one field per element. + """ + form = aiohttp.FormData(default_to_multipart=True) + for key, value in data.items(): + if value is None: + continue + if isinstance(value, bool): + form.add_field(key, "true" if value else "false") + elif isinstance(value, list): + for item in value: + form.add_field(key, str(item)) + elif isinstance(value, (bytes, bytearray)): + form.add_field(key, value) + else: + form.add_field(key, str(value)) + return form + + +async def _create_gen25_task( + cls: type[IO.ComfyNode], + request: Rodin3DGen25Request, + images: list | None, +) -> tuple[str, str]: + """Submit a Gen-2.5 generate job; returns (task_uuid, subscription_key).""" + + if images is not None and len(images) > 5: + raise ValueError("Rodin Gen-2.5 supports at most 5 input images.") + + files = None + if images: + files = [ + ( + "images", + open(image, "rb") if isinstance(image, str) else tensor_to_filelike(image), + ) + for image in images + if image is not None + ] + + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"), + response_model=Rodin3DGenerateResponse, + data=request, + files=files, + content_type="multipart/form-data", + multipart_parser=_rodin_multipart_parser, + ) + + if not response.uuid or not response.jobs or not response.jobs.subscription_key: + raise RuntimeError(f"Rodin Gen-2.5 submit failed: message={response.message!r}") + return response.uuid, response.jobs.subscription_key + + +_PREVIEWABLE_3D_EXTS = {".glb", ".obj", ".fbx", ".stl", ".gltf"} + + +async def _download_gen25_files( + download_list: Rodin3DDownloadResponse, + task_uuid: str, + geometry_file_format: str, +) -> Types.File3D | None: + """Download every file in the list; return the File3D matching the chosen format.""" + + folder_name = f"Rodin3D_Gen25_{task_uuid}" + save_dir = os.path.join(comfy_paths.get_output_directory(), folder_name) + os.makedirs(save_dir, exist_ok=True) + + target_ext = f".{geometry_file_format.lower().lstrip('.')}" + file_3d: Types.File3D | None = None + + for item in download_list.items: + file_path = os.path.join(save_dir, item.name) + ext = os.path.splitext(item.name.lower())[1] + # Prefer the file matching the user's chosen format; fall back below. + if file_3d is None and ext == target_ext and ext in _PREVIEWABLE_3D_EXTS: + file_3d = await download_url_to_file_3d(item.url, target_ext.lstrip(".")) + with open(file_path, "wb") as f: + f.write(file_3d.get_bytes()) + continue + await download_url_to_bytesio(item.url, file_path) + + # If the chosen format wasn't found, surface any model file we did get. + if file_3d is None: + for item in download_list.items: + ext = os.path.splitext(item.name.lower())[1] + if ext in _PREVIEWABLE_3D_EXTS: + file_3d = await download_url_to_file_3d(item.url, ext.lstrip(".")) + break + return file_3d + + +_MODE_REGULAR = "Regular" +_MODE_FAST = "Fast" +_MODE_EXTREME_HIGH = "Extreme-High" + +_REGULAR_POLY_OPTIONS = [ + "Default", + "4K-Quad", + "8K-Quad", + "18K-Quad", + "50K-Quad", + "2K-Triangle", + "20K-Triangle", + "150K-Triangle", + "500K-Triangle", + "1M-Triangle", +] + +_TEXTURE_MODE_OPTIONS = ["Default", "legacy", "extreme-low", "low", "medium", "high"] +_GEOMETRY_FORMAT_OPTIONS = ["glb", "fbx", "obj", "stl"] +_MATERIAL_OPTIONS = ["PBR", "Shaded", "All", "None"] + + +def _build_mode_input(name: str = "mode") -> IO.DynamicCombo.Input: + return IO.DynamicCombo.Input( + name, + options=[ + IO.DynamicCombo.Option( + _MODE_REGULAR, + [ + IO.Combo.Input( + "tier", + options=["Gen-2.5-Low", "Gen-2.5-Medium", "Gen-2.5-High"], + default="Gen-2.5-High", + tooltip="Quality tier. Higher tiers produce higher-fidelity geometry.", + ), + IO.Combo.Input( + "polygon_count", + options=_REGULAR_POLY_OPTIONS, + default="Default", + tooltip="Preset face count. 'Default' uses the server's default for the selected tier.", + ), + IO.Boolean.Input( + "creative", + default=False, + tooltip="Creative mode (Medium/High only). Enhances generative robustness.", + ), + ], + ), + IO.DynamicCombo.Option( + _MODE_FAST, + [ + IO.Combo.Input( + "tier", + options=[ + "Gen-2.5-Extreme-Low", + "Gen-2.5-Low", + "Gen-2.5-Medium", + "Gen-2.5-High", + ], + default="Gen-2.5-Low", + ), + IO.Int.Input( + "mesh_faces", + default=20000, + min=1000, + max=20000, + display_mode=IO.NumberDisplay.number, + tooltip="Mesh face count (1K-20K in Fast mode).", + ), + ], + ), + IO.DynamicCombo.Option( + _MODE_EXTREME_HIGH, + [ + IO.Combo.Input("mesh_mode", options=["Raw", "Quad"], default="Raw"), + IO.Int.Input( + "mesh_faces", + default=1000000, + min=20000, + max=2000000, + display_mode=IO.NumberDisplay.number, + tooltip=( + "Mesh face count. Raw mode: 20K-2M. " + "Quad mode: keep under 200K (upstream may reject higher values)." + ), + ), + IO.Boolean.Input( + "is_micro", + default=False, + tooltip="Enable micro detail (Extreme-High only).", + ), + IO.Boolean.Input( + "creative", + default=False, + tooltip="Creative mode. Enhances generative robustness.", + ), + ], + ), + ], + tooltip=( + "Generation mode. Regular = balanced. Fast = 1K-20K faces for rapid prototyping. " + "Extreme-High = 20K-2M faces with optional micro details." + ), + ) + + +def _build_common_inputs(*, include_image_only: bool) -> list: + inputs: list = [ + IO.Combo.Input("material", options=_MATERIAL_OPTIONS, default="Shaded"), + IO.Combo.Input("geometry_file_format", options=_GEOMETRY_FORMAT_OPTIONS, default="glb"), + IO.Combo.Input( + "texture_mode", + options=_TEXTURE_MODE_OPTIONS, + default="Default", + optional=True, + tooltip="Texture quality preset. 'Default' uses the server's default for the selected tier.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=65535, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + optional=True, + ), + IO.Boolean.Input( + "TAPose", default=False, optional=True, advanced=True, tooltip="T/A pose for human-like models." + ), + IO.Boolean.Input( + "hd_texture", default=False, optional=True, advanced=True, tooltip="High-quality texture enhancement." + ), + IO.Boolean.Input( + "texture_delight", + default=False, + optional=True, + advanced=True, + tooltip="Remove baked lighting from textures.", + ), + ] + if include_image_only: + inputs.append( + IO.Boolean.Input( + "use_original_alpha", + default=False, + optional=True, + advanced=True, + tooltip="Preserve image transparency.", + ) + ) + inputs.extend( + [ + IO.Boolean.Input( + "addon_highpack", + default=False, + optional=True, + advanced=True, + tooltip="HighPack addon: 4K textures and ~16x faces in Quad mode.", + ), + IO.Int.Input( + "bbox_width", + default=0, + min=0, + max=300, + display_mode=IO.NumberDisplay.number, + optional=True, + advanced=True, + tooltip="Bounding-box width (Y axis). Set to 0 with the others to skip bbox.", + ), + IO.Int.Input( + "bbox_height", + default=0, + min=0, + max=300, + display_mode=IO.NumberDisplay.number, + optional=True, + advanced=True, + tooltip="Bounding-box height (Z axis).", + ), + IO.Int.Input( + "bbox_length", + default=0, + min=0, + max=300, + display_mode=IO.NumberDisplay.number, + optional=True, + advanced=True, + tooltip="Bounding-box length (X axis).", + ), + IO.Int.Input( + "height_cm", + default=0, + min=0, + max=10000, + display_mode=IO.NumberDisplay.number, + optional=True, + advanced=True, + tooltip="Approximate model height in centimeters (0 to skip).", + ), + ] + ) + return inputs + + +_PRICE_EXPR = """ +( + $baseCredits := widgets.mode = "extreme-high" ? 1.0 : 0.5; + $addonCredits := widgets.addon_highpack ? 1.0 : 0.0; + $total := ($baseCredits * 1.5) + ($addonCredits * 0.8); + {"type":"usd","usd": $total} +) +""" + + +def _resolve_mode_params(mode_input: dict) -> dict: + """Translate the DynamicCombo `mode` payload into Gen-2.5 request fields. + + Returns a dict with: tier, quality_override, mesh_mode, geometry_instruct_mode, is_micro. + Missing keys mean "do not send" (so we don't override server defaults). + """ + selected = mode_input["mode"] + out: dict = {} + + if selected == _MODE_REGULAR: + out["tier"] = mode_input["tier"] + polygon = mode_input.get("polygon_count", "Default") + if polygon != "Default": + mesh_mode, faces = get_quality_mode(polygon) + out["mesh_mode"] = mesh_mode + out["quality_override"] = faces + if mode_input.get("creative"): + out["geometry_instruct_mode"] = "creative" + + elif selected == _MODE_FAST: + out["tier"] = mode_input["tier"] + out["mesh_mode"] = "Raw" + out["quality_override"] = int(mode_input["mesh_faces"]) + + elif selected == _MODE_EXTREME_HIGH: + out["tier"] = "Gen-2.5-Extreme-High" + out["mesh_mode"] = mode_input["mesh_mode"] + out["quality_override"] = int(mode_input["mesh_faces"]) + if mode_input.get("is_micro"): + out["is_micro"] = True + if mode_input.get("creative"): + out["geometry_instruct_mode"] = "creative" + return out + + +def _build_request( + *, + mode_input: dict, + material: str, + geometry_file_format: str, + texture_mode: str, + seed: int, + TAPose: bool, + hd_texture: bool, + texture_delight: bool, + addon_highpack: bool, + bbox_width: int, + bbox_height: int, + bbox_length: int, + height_cm: int, + prompt: str | None = None, + use_original_alpha: bool = False, +) -> Rodin3DGen25Request: + mode_params = _resolve_mode_params(mode_input) + + bbox = None + if bbox_width and bbox_height and bbox_length: + bbox = [bbox_width, bbox_height, bbox_length] + + return Rodin3DGen25Request( + tier=mode_params["tier"], + prompt=prompt or None, + seed=seed, + material=material, + geometry_file_format=geometry_file_format, + texture_mode=None if texture_mode == "Default" else texture_mode, + mesh_mode=mode_params.get("mesh_mode"), + quality_override=mode_params.get("quality_override"), + geometry_instruct_mode=mode_params.get("geometry_instruct_mode"), + bbox_condition=bbox, + height=height_cm or None, + TAPose=TAPose or None, + hd_texture=hd_texture or None, + texture_delight=texture_delight or None, + is_micro=mode_params.get("is_micro"), + use_original_alpha=use_original_alpha or None, + addons=["HighPack"] if addon_highpack else None, + ) + + +class Rodin3D_Gen25_Image(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Rodin3D_Gen25_Image", + display_name="Rodin 3D Gen-2.5 - Image to 3D", + category="api node/3d/Rodin", + description=( + "Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. " + "Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost." + ), + inputs=[ + IO.Autogrow.Input( + "images", + template=IO.Autogrow.TemplatePrefix(IO.Image.Input("image"), prefix="image", min=1, max=5), + tooltip="1-5 images. The first image is used for materials when multi-view.", + ), + _build_mode_input(), + *_build_common_inputs(include_image_only=True), + ], + outputs=[IO.File3DAny.Output(display_name="model_file")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode", "addon_highpack"]), + expr=_PRICE_EXPR, + ), + ) + + @classmethod + async def execute( + cls, + images: IO.Autogrow.Type, + mode: dict, + material: str, + geometry_file_format: str, + texture_mode: str, + seed: int, + TAPose: bool, + hd_texture: bool, + texture_delight: bool, + use_original_alpha: bool, + addon_highpack: bool, + bbox_width: int, + bbox_height: int, + bbox_length: int, + height_cm: int, + ) -> IO.NodeOutput: + image_tensors = [img for img in images.values() if img is not None] + if not image_tensors: + raise ValueError("Rodin Gen-2.5 Image-to-3D requires at least one image.") + + # Flatten multi-image tensors into individual frames; the API accepts each as a separate part. + flat_images: list = [] + for tensor in image_tensors: + if hasattr(tensor, "shape") and len(tensor.shape) == 4: + for i in range(tensor.shape[0]): + flat_images.append(tensor[i]) + else: + flat_images.append(tensor) + + if len(flat_images) > 5: + raise ValueError(f"Rodin Gen-2.5 accepts at most 5 images; received {len(flat_images)}.") + + request = _build_request( + mode_input=mode, + material=material, + geometry_file_format=geometry_file_format, + texture_mode=texture_mode, + seed=seed, + TAPose=TAPose, + hd_texture=hd_texture, + texture_delight=texture_delight, + addon_highpack=addon_highpack, + bbox_width=bbox_width, + bbox_height=bbox_height, + bbox_length=bbox_length, + height_cm=height_cm, + prompt=None, + use_original_alpha=use_original_alpha, + ) + + task_uuid, subscription_key = await _create_gen25_task(cls, request, flat_images) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) + file_3d = await _download_gen25_files(download_list, task_uuid, geometry_file_format) + return IO.NodeOutput(file_3d) + + +class Rodin3D_Gen25_Text(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Rodin3D_Gen25_Text", + display_name="Rodin 3D Gen-2.5 - Text to 3D", + category="api node/3d/Rodin", + description=( + "Generate a 3D model from a text prompt via Rodin Gen-2.5. " + "Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost." + ), + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the 3D model.", + ), + _build_mode_input(), + *_build_common_inputs(include_image_only=False), + ], + outputs=[IO.File3DAny.Output(display_name="model_file")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["mode", "addon_highpack"]), + expr=_PRICE_EXPR, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + mode: dict, + material: str, + geometry_file_format: str, + texture_mode: str, + seed: int, + TAPose: bool, + hd_texture: bool, + texture_delight: bool, + addon_highpack: bool, + bbox_width: int, + bbox_height: int, + bbox_length: int, + height_cm: int, + ) -> IO.NodeOutput: + validate_string(prompt, field_name="prompt", min_length=1, max_length=2500) + request = _build_request( + mode_input=mode, + material=material, + geometry_file_format=geometry_file_format, + texture_mode=texture_mode, + seed=seed, + TAPose=TAPose, + hd_texture=hd_texture, + texture_delight=texture_delight, + addon_highpack=addon_highpack, + bbox_width=bbox_width, + bbox_height=bbox_height, + bbox_length=bbox_length, + height_cm=height_cm, + prompt=prompt, + ) + task_uuid, subscription_key = await _create_gen25_task(cls, request, images=None) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) + file_3d = await _download_gen25_files(download_list, task_uuid, geometry_file_format) + return IO.NodeOutput(file_3d) + + class Rodin3DExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -551,6 +1114,8 @@ class Rodin3DExtension(ComfyExtension): Rodin3D_Smooth, Rodin3D_Sketch, Rodin3D_Gen2, + Rodin3D_Gen25_Image, + Rodin3D_Gen25_Text, ] From 112fcd5f3b86771d25b74a97e092856375c96daa Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Fri, 22 May 2026 14:31:43 -0700 Subject: [PATCH 31/45] openapi: align response declarations with implementation (5 endpoints) (#14058) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * openapi: align response declarations with implementation (5 endpoints) - POST /api/assets/download: replace 200 with 202 + tracking-task body (endpoint runs asynchronously and returns task_id/status/message). - POST /api/assets/export: same 200 → 202 + tracking-task body. - POST /api/assets/from-workflow: change 201 → 200 (handler responds 200, not 201; no Location header emitted). - POST /api/feedback: change 200 → 201 (creates a feedback record). - /api/jobs and /api/jobs/{job_id}: change timestamp fields from type: number to type: integer + format: int64. Values are Unix milliseconds — number causes oapi-codegen to emit float64, losing precision and producing the wrong Go type. Affected fields: create_time, update_time, execution_start_time, execution_end_time. Verification: each change reflects what the endpoint observably returns; no handler changes required. Backwards-compatible for existing clients (integer is a subset of number; status code shifts within 2xx). * openapi: align asset download/export 202 status enum with runtime + sibling schemas CodeRabbit caught a vocabulary mismatch: the two new 202 response schemas declared `[pending, running, completed, failed]` while the rest of the same spec uses `[created, running, completed, failed]` for the identical task lifecycle (download/export progress WebSocket events, /api/tasks, TaskEntry, TaskResponse — 4 sites total). Cloud's runtime emits `created` on initial creation (AssetDownloadResponseStatusCreated; task.Status sourced from the DB enum whose initial value is Created). `pending` would have introduced a fifth, contradictory vocabulary for the same lifecycle and pushed the spec further from the implementation it is meant to align with. Followup tracked separately: extract a shared TaskStatus enum so all five sites move in lockstep instead of needing per-site edits. --- openapi.yaml | 70 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 51 insertions(+), 19 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index 885231acc..8fb769bc8 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -2342,16 +2342,27 @@ paths: $ref: "#/components/schemas/AssetDownloadRequest" description: Assets to download responses: - "200": - description: Download initiated + "202": + description: Download task accepted content: application/json: schema: type: object + required: + - task_id + - status properties: task_id: type: string - description: Task ID for tracking progress via WebSocket + format: uuid + description: ID of the download task; use to poll status. + status: + type: string + enum: [created, running, completed, failed] + description: Current task status (typically `created` on initial creation). + message: + type: string + description: Human-readable task message. "400": description: Bad request content: @@ -2391,17 +2402,27 @@ paths: type: string description: Name for the export archive responses: - "200": - description: Export initiated + "202": + description: Export task accepted content: application/json: schema: type: object + required: + - task_id + - status properties: task_id: type: string - export_name: + format: uuid + description: ID of the export task; use to poll status. + status: type: string + enum: [created, running, completed, failed] + description: Current task status (typically `created` on initial creation). + message: + type: string + description: Human-readable task message. "400": description: Bad request content: @@ -2476,8 +2497,8 @@ paths: type: string description: Tags to apply to the created assets responses: - "201": - description: Assets created + "200": + description: Assets created or referenced content: application/json: schema: @@ -5056,7 +5077,7 @@ paths: additionalProperties: true description: Additional context metadata responses: - "200": + "201": description: Feedback submitted content: application/json: @@ -6102,14 +6123,17 @@ components: type: string description: Current job status create_time: - type: number - description: Job creation timestamp + type: integer + format: int64 + description: Job creation timestamp (Unix milliseconds). execution_start_time: - type: number - description: Workflow execution start timestamp + type: integer + format: int64 + description: Workflow execution start timestamp (Unix milliseconds, terminal states only). execution_end_time: - type: number - description: Workflow execution end timestamp + type: integer + format: int64 + description: Workflow execution end timestamp (Unix milliseconds, terminal states only). preview_output: type: object additionalProperties: true @@ -6141,13 +6165,21 @@ components: execution_error: $ref: "#/components/schemas/ExecutionError" create_time: - type: number + type: integer + format: int64 + description: Job creation timestamp (Unix milliseconds). update_time: - type: number + type: integer + format: int64 + description: Last state-change timestamp (Unix milliseconds). execution_start_time: - type: number + type: integer + format: int64 + description: Workflow execution start timestamp (Unix milliseconds, terminal states only). execution_end_time: - type: number + type: integer + format: int64 + description: Workflow execution end timestamp (Unix milliseconds, terminal states only). preview_output: type: object additionalProperties: true From e75b739c1d416923e5c391775838f2f9ce9e327c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 22 May 2026 15:47:03 -0700 Subject: [PATCH 32/45] Delete the source branch after doing the backport. (#14062) --- .github/workflows/backport_release.yaml | 35 +++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/.github/workflows/backport_release.yaml b/.github/workflows/backport_release.yaml index 474e7045b..ede6bde33 100644 --- a/.github/workflows/backport_release.yaml +++ b/.github/workflows/backport_release.yaml @@ -458,6 +458,41 @@ jobs: echo "Released ${NEW_VERSION} on ${RELEASE_BRANCH}." + - name: Delete remote source branch + env: + GH_TOKEN: ${{ steps.app-token.outputs.token }} + REPO: ${{ github.repository }} + SOURCE_BRANCH: ${{ steps.resolve.outputs.source_branch }} + SOURCE_COMMIT: ${{ inputs.commit }} + RELEASE_BRANCH: ${{ steps.latest.outputs.release_branch }} + DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} + run: | + set -euo pipefail + + # Belt-and-braces: the resolve step already refuses the default branch, + # but never delete the default or the release branch under any + # circumstances. + if [[ "${SOURCE_BRANCH}" == "${DEFAULT_BRANCH}" || "${SOURCE_BRANCH}" == "${RELEASE_BRANCH}" ]]; then + echo "::error::Refusing to delete '${SOURCE_BRANCH}' (matches default or release branch)." + exit 1 + fi + + # Delete the source branch on origin, but only if its tip is still the + # SHA we released from. If someone pushed new commits to it after we + # resolved it, leave it alone — those commits would be silently lost. + current_tip="$(git ls-remote origin "refs/heads/${SOURCE_BRANCH}" | awk '{print $1}')" + if [[ -z "${current_tip}" ]]; then + echo "Source branch '${SOURCE_BRANCH}' no longer exists on origin; nothing to delete." + exit 0 + fi + if [[ "${current_tip}" != "${SOURCE_COMMIT}" ]]; then + echo "::warning::Source branch '${SOURCE_BRANCH}' tip (${current_tip}) no longer matches released commit (${SOURCE_COMMIT}). Leaving it in place." + exit 0 + fi + + git push origin --delete "refs/heads/${SOURCE_BRANCH}" + echo "Deleted remote branch '${SOURCE_BRANCH}'." + - name: Summary if: always() env: From 7984a6a38eba7418dcbe6d2c977d461a84ac80f6 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Fri, 22 May 2026 16:15:18 -0700 Subject: [PATCH 33/45] openapi: rename 55 cloud-side operationIds to match runtime (PR A of 3) (#14060) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * openapi: rename 55 cloud-side operationIds to match runtime handlers For the 55 operations below, vendor's operationId did not match the name cloud's runtime handlers expect. Generated types from vendor therefore had different names (e.g. CreateSubscription200JSONResponse) than what cloud handlers reference (Subscribe200JSONResponse), which blocks the post-cutover combined-spec codegen. All 55 renames target the cloud-runtime-authoritative name. Several of these endpoints are shared concepts (queue, settings, userdata, object_info) that OSS local also serves — the rename aligns vendor with the longstanding cloud handler-side convention to unblock the shared codegen. No request/response *shape* changes in this PR; only operationId labels. Notable categories: - Billing/subscriptions: 7 renames (subscribe, getBillingPlans, ...) - Workspace + workflows: 13 renames (createWorkflow, ...) - Hub: 3 renames - Auth/users: 5 renames - Shared OSS surface (settings, queue, view, userdata): 12 renames - Misc cloud-only: 15 renames Identified via Comfy-Org/cloud's TestCutoverSafe build-safety gate (BE-1106), which compares handler type references against codegen output from the combined spec. * fix(openapi): resolve getHistory operationId collision Spectral flagged: both /api/history (OSS local) and /api/history_v2 (cloud) had operationId 'getHistory' after the rename. Rename vendor's /api/history to 'getPromptHistory' to disambiguate. Cloud's runtime denies /api/history at the overlay level so combined codegen is unaffected by this change. * openapi: add 41 cloud-runtime schemas to components.schemas (PR B of 3) (#14061) * openapi: add 41 cloud-runtime schemas to components.schemas (cutover prep) Adds schemas that exist in Comfy-Org/cloud's hand-written ingest spec but not yet in this vendored OSS spec. All tagged x-runtime: [cloud] per the field-drift convention and prefixed with [cloud-only] in the description. These schemas are referenced by cloud's Go handlers via the generated ingest. Go type names. Codegen from the vendored spec didn't produce those types because the schemas weren't declared here. Adding them unblocks the post-cutover combined-spec codegen. Schemas added (alphabetical): AssetDownloadResponse, AssetMetadataResponse, BillingBalanceResponse, BillingPlansResponse, BillingStatusResponse, GetUserDataResponseFull, HistoryDetailEntry, HistoryDetailResponse, HistoryResponse, HubLabelInfo, HubProfileSummary, HubWorkflowListResponse, HubWorkflowStatus, HubWorkflowSummary, HubWorkflowTemplateEntry, JobStatusResponse, JobsListResponse, LabelRef, LogsResponse, Member, OAuthRegisterBadRequestResponse, PendingInvite, Plan, PlanAvailability, PlanAvailabilityReason, PlanSeatSummary, PreviewPlanInfo, PreviewSubscribeResponse, PublishedWorkflowDetail, SecretResponse, SubscriptionDuration, SubscriptionTier, UserDataResponseFull, ValidationError, ValidationResult, WorkflowForkedFrom, WorkflowResponse, WorkflowVersionContentResponse, WorkspaceAPIKeyInfo, WorkspaceSummary, WorkspaceWithRole Identified via Comfy-Org/cloud's TestCutoverSafe build-safety gate (BE-1106). Companion to PR #14060 (operationId renames). * fix(openapi): add BindingErrorResponse schema OAuthRegisterBadRequestResponse references BindingErrorResponse but that schema wasn't in the original add. Adding it now as a cloud-only schema matching the cloud runtime's binding-error shape (single 'message' string field). * openapi: add missing 4xx/5xx response bodies for cloud-emitting endpoints (#14063) Vendor declares shared endpoints (e.g. /api/queue, /api/settings, /api/assets/*, /api/billing/*) with success responses but is missing many of the 4xx/5xx error response bodies that Comfy-Org/cloud's runtime actually emits. Cloud's Go handlers reference the generated ingest.OpJSONResponse types for these missing statuses, which currently fail to resolve when codegen runs against the vendored spec. This PR adds 237 response entries across 117 operations, restoring the documented error responses that cloud emits. Bodies are copied verbatim from Comfy-Org/cloud's hand-written ingest spec (services/ingest/openapi.yaml) and reference a new ErrorResponse schema also added in this PR (matches cloud's {code, message} runtime shape, tagged x-runtime: [cloud]). ErrorResponse is intentionally separate from the existing CloudError schema. CloudError's shape ({error}) describes one runtime; cloud emits a different shape ({code, message}). Existing CloudError refs in vendor are untouched; new cloud-emitting error references use ErrorResponse. Identified via Comfy-Org/cloud's TestCutoverSafe build-safety gate (BE-1106). Companion to PR #14060 (operationId renames) and PR #14061 (cloud-only schema additions). --- openapi.yaml | 2737 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 2680 insertions(+), 57 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index 8fb769bc8..59b6817e5 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -104,6 +104,8 @@ paths: responses: "101": description: WebSocket upgrade successful + '401': + description: Unauthorized x-websocket-messages: - type: status schema: @@ -170,6 +172,18 @@ paths: application/json: schema: $ref: "#/components/schemas/PromptInfo" + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' post: operationId: executePrompt tags: [prompt] @@ -195,12 +209,36 @@ paths: schema: $ref: "#/components/schemas/PromptErrorResponse" + '402': + description: Payment required - Insufficient credits + content: + application/json: + schema: + $ref: '#/components/schemas/PromptErrorResponse' + '429': + description: Payment required - User has not paid + content: + application/json: + schema: + $ref: '#/components/schemas/PromptErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/PromptErrorResponse' + '503': + description: Service unavailable + content: + application/json: + schema: + $ref: '#/components/schemas/PromptErrorResponse' # --------------------------------------------------------------------------- # Queue # --------------------------------------------------------------------------- /api/queue: get: - operationId: getQueue + operationId: getQueueInfo tags: [queue] summary: Get running and pending queue items description: Returns the server's current execution queue, split into the currently-running prompt and the list of pending prompts. @@ -211,6 +249,18 @@ paths: application/json: schema: $ref: "#/components/schemas/QueueInfo" + '400': + description: Invalid request parameters + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Invalid request parameters + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' post: operationId: manageQueue tags: [queue] @@ -226,9 +276,27 @@ paths: "200": description: Queue updated + '400': + description: Invalid request parameters + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/interrupt: post: - operationId: interruptExecution + operationId: interruptJob tags: [queue] summary: Interrupt current execution description: Interrupts the prompt that is currently executing. The next queued prompt (if any) will start immediately after. @@ -247,6 +315,18 @@ paths: "200": description: Interrupt signal sent + '401': + description: Unauthorized - Authentication required + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/free: post: operationId: freeMemory @@ -327,9 +407,21 @@ paths: pagination: $ref: "#/components/schemas/PaginationInfo" + '401': + description: Unauthorized - Authentication required + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/jobs/{job_id}: get: - operationId: getJob + operationId: getJobDetail tags: [queue] summary: Get a single job by ID description: Returns the full record for a single completed prompt execution, including its outputs, status, and metadata. @@ -351,12 +443,30 @@ paths: "404": description: Job not found + '401': + description: Unauthorized - Authentication required + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '403': + description: Forbidden - Job does not belong to user + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' # --------------------------------------------------------------------------- # History # --------------------------------------------------------------------------- /api/history: get: - operationId: getHistory + operationId: getPromptHistory tags: [history] summary: Get execution history deprecated: true @@ -388,6 +498,8 @@ paths: type: object additionalProperties: $ref: "#/components/schemas/HistoryEntry" + '404': + description: "Not Found \u2014 use /api/history_v2 instead" post: operationId: manageHistory tags: [history] @@ -409,6 +521,24 @@ paths: "200": description: History updated + '400': + description: Invalid request parameters + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized - Authentication required + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/history/{prompt_id}: get: operationId: getHistoryByPromptId @@ -438,6 +568,8 @@ paths: additionalProperties: $ref: "#/components/schemas/HistoryEntry" + '404': + description: "Not Found \u2014 use /api/jobs/{prompt_id} instead" # --------------------------------------------------------------------------- # Upload # --------------------------------------------------------------------------- @@ -481,6 +613,18 @@ paths: "400": description: No file provided or invalid request + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/upload/mask: post: operationId: uploadMask @@ -539,6 +683,18 @@ paths: "400": description: No file provided or invalid request + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' # --------------------------------------------------------------------------- # View # --------------------------------------------------------------------------- @@ -601,6 +757,33 @@ paths: "404": description: File not found + '302': + description: Redirect to GCS signed URL + headers: + Location: + description: Signed URL to access the file in GCS + schema: + type: string + Cache-Control: + description: Cache directive for the redirect response + schema: + type: string + Vary: + description: Headers that affect response caching + schema: + type: string + '400': + description: Invalid request parameters + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/view_metadata/{folder_name}: get: operationId: viewMetadata @@ -648,6 +831,12 @@ paths: schema: $ref: "#/components/schemas/SystemStatsResponse" + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/features: get: operationId: getFeatures @@ -706,7 +895,7 @@ paths: # --------------------------------------------------------------------------- /api/object_info: get: - operationId: getObjectInfo + operationId: getNodeInfo tags: [node] summary: Get all node definitions description: | @@ -782,6 +971,8 @@ paths: items: type: string + '404': + description: "Not Found \u2014 use /api/experiment/models instead" /api/models/{folder}: get: operationId: getModelsByFolder @@ -809,7 +1000,7 @@ paths: /api/experiment/models: get: - operationId: getExperimentModels + operationId: getModelFolders tags: [model] summary: List model folders with paths description: Returns an array of model folder objects with name and folder paths. @@ -823,9 +1014,15 @@ paths: items: $ref: "#/components/schemas/ModelFolder" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/experiment/models/{folder}: get: - operationId: getExperimentModelsByFolder + operationId: getModelsInFolder tags: [model] summary: List model files with metadata description: Returns the model files in the given folder with richer metadata (path index, mtime, size) than the legacy `/api/models/{folder}` endpoint. @@ -848,6 +1045,12 @@ paths: "404": description: Unknown folder type + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/experiment/models/preview/{folder}/{path_index}/{filename}: get: operationId: getModelPreview @@ -884,12 +1087,18 @@ paths: "404": description: Preview not found + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' # --------------------------------------------------------------------------- # Users # --------------------------------------------------------------------------- /api/users: get: - operationId: getUsers + operationId: getUsersInfo tags: [user] summary: Get user storage info description: | @@ -917,6 +1126,12 @@ paths: additionalProperties: type: string description: Map of user_id to directory name (multi-user) + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' post: operationId: createUser tags: [user] @@ -952,7 +1167,7 @@ paths: # --------------------------------------------------------------------------- /api/userdata: get: - operationId: listUserdata + operationId: getUserdata tags: [userdata] summary: List files in a userdata directory description: Lists files in the authenticated user's data directory. Returns either filename strings or full objects depending on the `full_info` query parameter. @@ -989,6 +1204,24 @@ paths: "404": description: Directory not found + '400': + description: Bad request (e.g., invalid filename). + content: + text/plain: + schema: + type: string + '401': + description: Unauthorized. + content: + text/plain: + schema: + type: string + '500': + description: General error + content: + text/plain: + schema: + type: string /api/v2/userdata: get: operationId: listUserdataV2 @@ -1025,6 +1258,8 @@ paths: type: number description: Unix timestamp + '404': + description: "Not Found \u2014 use /api/userdata instead" /api/userdata/{file}: get: operationId: getUserdataFile @@ -1049,8 +1284,26 @@ paths: format: binary "404": description: File not found + '400': + description: Bad request (e.g., invalid filename). + content: + text/plain: + schema: + type: string + '401': + description: Unauthorized. + content: + text/plain: + schema: + type: string + '500': + description: General error + content: + text/plain: + schema: + type: string post: - operationId: writeUserdataFile + operationId: postUserdataFile tags: [userdata] summary: Write or create a userdata file description: Writes (creates or replaces) a file in the authenticated user's data directory. @@ -1090,6 +1343,30 @@ paths: $ref: "#/components/schemas/UserDataResponse" "409": description: File exists and overwrite not set + '400': + description: Missing or invalid 'file' parameter. + content: + text/plain: + schema: + type: string + '401': + description: Unauthorized. + content: + text/plain: + schema: + type: string + '403': + description: The requested path is not allowed. + content: + text/plain: + schema: + type: string + '500': + description: General error + content: + text/plain: + schema: + type: string delete: operationId: deleteUserdataFile tags: [userdata] @@ -1109,6 +1386,18 @@ paths: "404": description: File not found + '401': + description: Unauthorized. + content: + text/plain: + schema: + type: string + '500': + description: Internal server error. + content: + text/plain: + schema: + type: string /api/userdata/{file}/move/{dest}: post: operationId: moveUserdataFile @@ -1151,12 +1440,30 @@ paths: "409": description: Destination exists and overwrite not set + '400': + description: Missing or invalid parameters. + content: + text/plain: + schema: + type: string + '401': + description: Unauthorized. + content: + text/plain: + schema: + type: string + '500': + description: General error + content: + text/plain: + schema: + type: string # --------------------------------------------------------------------------- # Settings # --------------------------------------------------------------------------- /api/settings: get: - operationId: getSettings + operationId: getAllSettings tags: [settings] summary: Get all user settings description: Returns all settings for the authenticated user. @@ -1170,8 +1477,14 @@ paths: schema: type: object additionalProperties: true + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' post: - operationId: updateSettings + operationId: updateMultipleSettings tags: [settings] summary: Update user settings (partial merge) description: Replaces the authenticated user's settings with the provided object. @@ -1189,9 +1502,21 @@ paths: "200": description: Settings updated + '400': + description: Invalid request + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/settings/{id}: get: - operationId: getSetting + operationId: getSettingById tags: [settings] summary: Get a single setting by key description: Returns the value of a single setting, identified by key. @@ -1211,8 +1536,20 @@ paths: schema: nullable: true description: The setting value (any JSON type), or null if not set + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Setting not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' post: - operationId: updateSetting + operationId: updateSettingById tags: [settings] summary: Set a single setting value description: Sets the value of a single setting, identified by key. @@ -1234,6 +1571,18 @@ paths: "200": description: Setting updated + '400': + description: Invalid request + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' # --------------------------------------------------------------------------- # Extensions / Templates / i18n # --------------------------------------------------------------------------- @@ -1308,6 +1657,12 @@ paths: additionalProperties: $ref: "#/components/schemas/GlobalSubgraphInfo" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/global_subgraphs/{id}: get: operationId: getGlobalSubgraph @@ -1331,6 +1686,12 @@ paths: "404": description: Subgraph not found + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' # --------------------------------------------------------------------------- # Node Replacements # --------------------------------------------------------------------------- @@ -1351,6 +1712,12 @@ paths: type: object additionalProperties: true + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' # --------------------------------------------------------------------------- # Internal (x-internal: true) # --------------------------------------------------------------------------- @@ -1454,7 +1821,7 @@ paths: /internal/files/{directory_type}: get: - operationId: getInternalFiles + operationId: getFiles tags: [internal] summary: List files in a directory type description: Lists the files present in one of ComfyUI's known directories (input, output, or temp). @@ -1476,6 +1843,12 @@ paths: items: type: string + '400': + description: Invalid directory type + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' # --------------------------------------------------------------------------- # Assets (x-feature-gate: enable-assets) # --------------------------------------------------------------------------- @@ -1499,6 +1872,24 @@ paths: "404": description: No asset with this hash + '400': + description: Invalid hash format + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/assets: get: operationId: listAssets @@ -1575,8 +1966,26 @@ paths: application/json: schema: $ref: "#/components/schemas/ListAssetsResponse" + '400': + description: Invalid request parameters + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' post: - operationId: createAsset + operationId: uploadAsset tags: [assets] summary: Upload a new asset description: Uploads a new asset (binary content plus metadata) and registers it in the asset database. @@ -1664,6 +2073,60 @@ paths: schema: $ref: "#/components/schemas/AssetCreated" + '200': + description: Asset already exists (returned existing asset) + content: + application/json: + schema: + $ref: '#/components/schemas/AssetCreated' + '400': + description: Invalid request (bad file, invalid URL, invalid content type, etc.) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '403': + description: Source URL requires authentication or access denied + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Source URL not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '413': + description: File too large + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '415': + description: Unsupported media type + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '422': + description: Download failed due to network error or timeout + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/assets/from-hash: post: operationId: createAssetFromHash @@ -1707,9 +2170,39 @@ paths: schema: $ref: "#/components/schemas/AssetCreated" + '200': + description: Asset reference already exists (returned existing) + content: + application/json: + schema: + $ref: '#/components/schemas/AssetCreated' + '400': + description: Invalid request (bad hash format, invalid tags, etc.) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Source asset with given hash not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/assets/{id}: get: - operationId: getAsset + operationId: getAssetById tags: [assets] summary: Get asset metadata description: Returns the metadata for a single asset. @@ -1731,6 +2224,18 @@ paths: $ref: "#/components/schemas/Asset" "404": description: Asset not found + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' put: operationId: updateAsset tags: [assets] @@ -1775,6 +2280,30 @@ paths: application/json: schema: $ref: "#/components/schemas/AssetUpdated" + '400': + description: Invalid request (no fields provided) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Asset not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' delete: operationId: deleteAsset tags: [assets] @@ -1798,6 +2327,30 @@ paths: "204": description: Asset deleted + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Asset not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '409': + description: Asset cannot be deleted because it is referenced by another resource (e.g., workflow version) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/assets/{id}/content: get: operationId: getAssetContent @@ -1859,6 +2412,36 @@ paths: application/json: schema: $ref: "#/components/schemas/TagsModificationResponse" + '400': + description: Invalid request + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Asset not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '422': + description: Validation error (e.g., reserved tag) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' delete: operationId: removeAssetTags tags: [assets] @@ -1894,6 +2477,36 @@ paths: schema: $ref: "#/components/schemas/TagsModificationResponse" + '400': + description: Invalid request + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Asset not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '422': + description: Validation error (e.g., reserved tag) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/tags: get: operationId: listTags @@ -1923,9 +2536,27 @@ paths: schema: $ref: "#/components/schemas/ListTagsResponse" + '400': + description: Invalid request parameters + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/assets/tags/refine: get: - operationId: refineAssetTags + operationId: getAssetTagHistogram tags: [assets] summary: Get tag counts for assets matching current filters description: Returns suggested additional tags that would refine a filtered asset query, together with the count of assets each tag would select. @@ -1986,6 +2617,24 @@ paths: schema: $ref: "#/components/schemas/AssetTagHistogramResponse" + '400': + description: Invalid request parameters + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/assets/seed: post: operationId: seedAssets @@ -2117,9 +2766,21 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '400': + description: Bad Request - job_id is not a valid UUID (emitted by request validation before the handler runs) + content: + application/json: + schema: + $ref: '#/components/schemas/BindingErrorResponse' + '500': + description: Internal server error - cancellation failed + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/job/{job_id}/status: get: - operationId: getCloudJobStatus + operationId: getJobStatus tags: [queue] summary: Get status of a cloud job deprecated: true @@ -2156,6 +2817,18 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '403': + description: Forbidden - job belongs to another user + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/prompt/{prompt_id}: get: operationId: getCloudPrompt @@ -2193,7 +2866,7 @@ paths: /api/history_v2: get: - operationId: getHistoryV2 + operationId: getHistory tags: [history] summary: Get paginated execution history (v2) deprecated: true @@ -2234,9 +2907,15 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/history_v2/{prompt_id}: get: - operationId: getHistoryV2ByPromptId + operationId: getHistoryForPrompt tags: [history] summary: Get v2 history for a specific prompt deprecated: true @@ -2273,9 +2952,15 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/logs: get: - operationId: getCloudLogs + operationId: getLogs tags: [system] summary: Get cloud execution logs deprecated: true @@ -2322,7 +3007,7 @@ paths: # --------------------------------------------------------------------------- /api/assets/download: post: - operationId: downloadAssets + operationId: createAssetDownload tags: [assets] summary: Download assets to cloud runtime description: "[cloud-only] Initiates a download of one or more assets to the cloud runtime environment. Returns a task ID for tracking download progress via WebSocket." @@ -2376,9 +3061,27 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '200': + description: File already exists in storage - asset created/returned immediately + content: + application/json: + schema: + $ref: '#/components/schemas/AssetCreated' + '422': + description: Validation errors + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/assets/export: post: - operationId: exportAssets + operationId: createAssetExport tags: [assets] summary: Export assets as a downloadable archive description: "[cloud-only] Initiates a bulk export of assets. Returns a task ID for tracking progress via WebSocket. When complete, the export can be downloaded via the exports endpoint." @@ -2436,6 +3139,12 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/assets/exports/{exportName}: get: operationId: getAssetExport @@ -2471,9 +3180,21 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '400': + description: Invalid export name + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/assets/from-workflow: post: - operationId: createAssetsFromWorkflow + operationId: postAssetsFromWorkflow tags: [assets] summary: Create asset records from a workflow execution description: "[cloud-only] Registers output files from a workflow execution as assets in the asset database." @@ -2527,6 +3248,12 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/assets/import: post: operationId: importPublishedAssets @@ -2561,9 +3288,15 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/assets/remote-metadata: get: - operationId: getAssetRemoteMetadata + operationId: getRemoteAssetMetadata tags: [assets] summary: Fetch metadata for a remote asset URL description: "[cloud-only] Fetches and returns metadata (content type, size, filename) for a remote URL without downloading the full content." @@ -2596,6 +3329,18 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '422': + description: Failed to retrieve metadata from source + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' # --------------------------------------------------------------------------- # Custom nodes / hub (cloud) # --------------------------------------------------------------------------- @@ -2751,7 +3496,7 @@ paths: /api/hub/assets/upload-url: post: - operationId: getHubAssetUploadUrl + operationId: createHubAssetUploadUrl tags: [hub] summary: Get a pre-signed upload URL for a hub asset description: "[cloud-only] Returns a pre-signed URL that can be used to upload an asset file directly to storage." @@ -2805,6 +3550,18 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '404': + description: Not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/hub/labels: get: operationId: listHubLabels @@ -2822,6 +3579,18 @@ paths: items: $ref: "#/components/schemas/HubLabel" + '400': + description: Bad request (e.g. invalid type parameter) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/hub/profiles: get: operationId: listHubProfiles @@ -2905,6 +3674,12 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/hub/profiles/{username}: get: operationId: getHubProfile @@ -2933,9 +3708,15 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/hub/profiles/check: get: - operationId: checkHubProfileUsername + operationId: checkHubUsername tags: [hub] summary: Check if a hub username is available description: "[cloud-only] Returns whether the given username is available for registration." @@ -2960,6 +3741,24 @@ paths: username: type: string + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/hub/profiles/me: get: operationId: getMyHubProfile @@ -2980,6 +3779,18 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '404': + description: No hub profile exists + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' put: operationId: updateMyHubProfile tags: [hub] @@ -3079,6 +3890,24 @@ paths: application/json: schema: $ref: "#/components/schemas/HubWorkflowList" + '400': + description: Bad request (e.g. malformed pagination cursor) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '404': + description: Profile not found (when filtering by username) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' post: operationId: publishHubWorkflow tags: [hub] @@ -3117,6 +3946,12 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/hub/workflows/{share_id}: get: operationId: getHubWorkflow @@ -3144,6 +3979,18 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '413': + description: Workflow JSON too large + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' delete: operationId: deleteHubWorkflow tags: [hub] @@ -3173,9 +4020,15 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/hub/workflows/index: get: - operationId: getHubWorkflowIndex + operationId: listHubWorkflowIndex tags: [hub] summary: Get the hub workflow index description: "[cloud-only] Returns the lightweight index of all hub workflows for client-side search and navigation." @@ -3190,12 +4043,18 @@ paths: items: $ref: "#/components/schemas/HubWorkflowIndexEntry" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' # --------------------------------------------------------------------------- # Workflows (cloud) # --------------------------------------------------------------------------- /api/workflows: get: - operationId: listCloudWorkflows + operationId: listWorkflows tags: [workflows] summary: List cloud workflows description: "[cloud-only] Returns a paginated list of the authenticated user's cloud workflows." @@ -3240,8 +4099,14 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' post: - operationId: createCloudWorkflow + operationId: createWorkflow tags: [workflows] summary: Create a new cloud workflow description: "[cloud-only] Creates a new cloud workflow with the provided name and optional initial content." @@ -3285,9 +4150,21 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '422': + description: Validation error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/workflows/{workflow_id}: get: - operationId: getCloudWorkflow + operationId: getWorkflow tags: [workflows] summary: Get a cloud workflow by ID description: "[cloud-only] Returns the metadata for a cloud workflow." @@ -3319,8 +4196,20 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '403': + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' patch: - operationId: updateCloudWorkflow + operationId: updateWorkflow tags: [workflows] summary: Update a cloud workflow description: "[cloud-only] Updates the metadata (name, description) of an existing cloud workflow." @@ -3369,8 +4258,20 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '422': + description: Validation error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' delete: - operationId: deleteCloudWorkflow + operationId: deleteWorkflow tags: [workflows] summary: Delete a cloud workflow description: "[cloud-only] Deletes a cloud workflow and all its versions." @@ -3399,9 +4300,15 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/workflows/{workflow_id}/content: get: - operationId: getCloudWorkflowContent + operationId: getWorkflowContent tags: [workflows] summary: Get the content of a cloud workflow description: "[cloud-only] Returns the full workflow graph JSON for the latest version of a cloud workflow." @@ -3440,6 +4347,18 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '403': + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' put: operationId: updateCloudWorkflowContent tags: [workflows] @@ -3490,7 +4409,7 @@ paths: /api/workflows/{workflow_id}/fork: post: - operationId: forkCloudWorkflow + operationId: forkWorkflow tags: [workflows] summary: Fork a cloud workflow description: "[cloud-only] Creates a copy of a cloud workflow under the authenticated user's account." @@ -3533,6 +4452,24 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '403': + description: Forbidden + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '422': + description: Validation error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/workflows/{workflow_id}/versions: get: operationId: listCloudWorkflowVersions @@ -3587,7 +4524,7 @@ paths: schema: $ref: "#/components/schemas/CloudError" post: - operationId: createCloudWorkflowVersion + operationId: createWorkflowVersion tags: [workflows] summary: Create a new cloud workflow version description: "[cloud-only] Creates a new workflow version with updated workflow JSON. Uses optimistic concurrency via base_version." @@ -3638,6 +4575,18 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '422': + description: Validation error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/workflows/published/{share_id}: get: operationId: getPublishedWorkflow @@ -3666,6 +4615,24 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '413': + description: Workflow JSON too large + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' # --------------------------------------------------------------------------- # Auth / session (cloud) # --------------------------------------------------------------------------- @@ -3690,7 +4657,7 @@ paths: schema: $ref: "#/components/schemas/CloudError" post: - operationId: createAuthSession + operationId: createSession tags: [auth] summary: Create a session cookie description: "[cloud-only] Creates a session cookie from the bearer token in the Authorization header. Returns a Set-Cookie header with a secure HttpOnly session cookie. Cookie authentication is not allowed for this endpoint." @@ -3714,8 +4681,14 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' delete: - operationId: deleteAuthSession + operationId: deleteSession tags: [auth] summary: Delete session cookie (logout) description: "[cloud-only] Clears the session cookie and optionally revokes the session on the server." @@ -3728,9 +4701,15 @@ paths: schema: $ref: "#/components/schemas/DeleteSessionResponse" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/auth/token: post: - operationId: createAuthToken + operationId: exchangeToken tags: [auth] summary: Exchange credentials for an access token description: "[cloud-only] Exchanges authentication credentials (e.g. an authorization code) for an access token." @@ -3778,6 +4757,18 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '404': + description: Workspace not found or user not a member + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /.well-known/jwks.json: get: operationId: getJwks @@ -4106,9 +5097,15 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/billing/events: get: - operationId: listBillingEvents + operationId: getBillingEvents tags: [billing] summary: List billing events description: "[cloud-only] Returns a paginated list of billing events (charges, credits, refunds) for the authenticated user." @@ -4143,9 +5140,15 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/billing/ops/{id}: get: - operationId: getBillingOp + operationId: getBillingOpStatus tags: [billing] summary: Get a billing operation by ID description: "[cloud-only] Returns details of a specific billing operation." @@ -4177,9 +5180,15 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/billing/payment-portal: post: - operationId: createPaymentPortalSession + operationId: getPaymentPortal tags: [billing] summary: Create a payment portal session description: "[cloud-only] Creates a Stripe customer portal session for managing payment methods and invoices. Returns a URL to redirect the user to." @@ -4203,9 +5212,21 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '400': + description: Bad request (e.g., missing return_url) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/billing/plans: get: - operationId: listBillingPlans + operationId: getBillingPlans tags: [billing] summary: List available billing plans description: "[cloud-only] Returns the list of available subscription plans and their pricing." @@ -4220,9 +5241,21 @@ paths: items: $ref: "#/components/schemas/BillingPlan" + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/billing/preview-subscribe: post: - operationId: previewSubscription + operationId: previewSubscribe tags: [billing] summary: Preview a subscription change description: "[cloud-only] Returns a preview of what a subscription change would cost, including prorations." @@ -4259,6 +5292,12 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/billing/status: get: operationId: getBillingStatus @@ -4280,9 +5319,21 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '404': + description: Workspace not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/billing/subscribe: post: - operationId: createSubscription + operationId: subscribe tags: [billing] summary: Subscribe to a billing plan description: "[cloud-only] Creates a new subscription to the specified billing plan." @@ -4322,6 +5373,12 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/billing/subscription/cancel: post: operationId: cancelSubscription @@ -4343,6 +5400,18 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '400': + description: Invalid request (e.g., no active subscription) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/billing/subscription/resubscribe: post: operationId: resubscribe @@ -4364,9 +5433,21 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '400': + description: Invalid request (e.g., no active subscription, not in cancellation grace period) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/billing/topup: post: - operationId: topUpCredits + operationId: createTopup tags: [billing] summary: Purchase additional credits description: "[cloud-only] Purchases a one-time credit top-up using the user's payment method on file." @@ -4403,12 +5484,18 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' # --------------------------------------------------------------------------- # Workspace (cloud) # --------------------------------------------------------------------------- /api/workspace/api-keys: get: - operationId: listWorkspaceApiKeys + operationId: listWorkspaceAPIKeys tags: [workspace] summary: List workspace API keys description: "[cloud-only] Returns the list of API keys for the current workspace." @@ -4434,8 +5521,14 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' post: - operationId: createWorkspaceApiKey + operationId: createWorkspaceAPIKey tags: [workspace] summary: Create a workspace API key description: "[cloud-only] Creates a new API key for the current workspace." @@ -4482,9 +5575,33 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '404': + description: Workspace not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '422': + description: Validation error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '429': + description: Key limit reached + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/workspace/api-keys/{id}: delete: - operationId: deleteWorkspaceApiKey + operationId: revokeWorkspaceAPIKey tags: [workspace] summary: Delete a workspace API key description: "[cloud-only] Revokes and deletes a workspace API key." @@ -4518,6 +5635,12 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/workspace/invites: get: operationId: listWorkspaceInvites @@ -4546,6 +5669,12 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' post: operationId: createWorkspaceInvite tags: [workspace] @@ -4601,9 +5730,27 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '404': + description: Workspace not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '422': + description: Validation error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/workspace/invites/{inviteId}: delete: - operationId: deleteWorkspaceInvite + operationId: revokeWorkspaceInvite tags: [workspace] summary: Cancel a workspace invite description: "[cloud-only] Cancels a pending workspace invitation." @@ -4637,6 +5784,12 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/workspace/leave: post: operationId: leaveWorkspace @@ -4660,6 +5813,18 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '404': + description: Workspace not found or not a member + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/workspace/members: get: operationId: listWorkspaceMembers @@ -4689,6 +5854,24 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '404': + description: Workspace not found + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '422': + description: Validation error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/workspace/members/{user_id}/api-keys: get: operationId: listMemberApiKeys @@ -4731,7 +5914,7 @@ paths: schema: $ref: "#/components/schemas/CloudError" delete: - operationId: bulkRevokeMemberApiKeys + operationId: bulkRevokeWorkspaceMemberAPIKeys tags: [workspace] summary: Bulk revoke a member's API keys description: "[cloud-only] Revokes all active API keys for a specific workspace member. Only workspace owners can perform this action." @@ -4764,6 +5947,18 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '422': + description: Validation error (e.g. empty user_id) + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/workspace/members/{userId}: patch: operationId: updateWorkspaceMember @@ -4857,6 +6052,12 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/workspaces: get: operationId: listWorkspaces @@ -4879,6 +6080,18 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '404': + description: Feature not enabled for user + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' post: operationId: createWorkspace tags: [workspace] @@ -4917,6 +6130,24 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '404': + description: Feature not enabled for user + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '422': + description: Validation error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/workspaces/{id}: get: operationId: getWorkspace @@ -4956,6 +6187,12 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' patch: operationId: updateWorkspace tags: [workspace] @@ -5010,6 +6247,18 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '422': + description: Validation error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' delete: operationId: deleteWorkspace tags: [workspace] @@ -5045,6 +6294,12 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' # --------------------------------------------------------------------------- # User / settings / misc (cloud) # --------------------------------------------------------------------------- @@ -5101,6 +6356,12 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/files/mask-layers: get: operationId: getMaskLayers @@ -5199,9 +6460,15 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/invites/{token}/accept: post: - operationId: acceptInvite + operationId: acceptWorkspaceInvite tags: [workspace] summary: Accept a workspace invitation description: "[cloud-only] Accepts a workspace invitation using the invite token. The authenticated user is added to the workspace." @@ -5239,6 +6506,24 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '403': + description: Email does not match invite + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '409': + description: Already a member of this workspace + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/secrets: get: operationId: listSecrets @@ -5261,6 +6546,18 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '503': + description: Service unavailable - feature is disabled + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' post: operationId: createSecret tags: [settings] @@ -5303,6 +6600,30 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '409': + description: Conflict - secret with this name or provider already exists + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '422': + description: Validation error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '503': + description: Service unavailable - secrets feature disabled + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/secrets/{id}: get: operationId: getSecret @@ -5337,6 +6658,24 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '403': + description: Forbidden - user does not own this secret + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '503': + description: Service unavailable - secrets feature disabled + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' patch: operationId: updateSecret tags: [settings] @@ -5388,6 +6727,24 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '403': + description: Forbidden - user does not own this secret + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '503': + description: Service unavailable - secrets feature disabled + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' delete: operationId: deleteSecret tags: [settings] @@ -5417,9 +6774,27 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '403': + description: Forbidden - user does not own this secret + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '503': + description: Service unavailable - secrets feature disabled + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/user: get: - operationId: getCloudUser + operationId: getUser tags: [user] summary: Get the authenticated cloud user description: "[cloud-only] Returns the profile and account information for the currently authenticated user." @@ -5508,8 +6883,14 @@ paths: application/json: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' post: - operationId: publishUserdataFile + operationId: postUserdataFilePublish tags: [userdata] summary: Publish a userdata file to the cloud description: "[cloud-only] Makes a userdata file available via a public URL for sharing or embedding." @@ -5546,9 +6927,21 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '400': + description: Bad request + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/vhs/queryvideo: get: - operationId: queryVhsVideo + operationId: getVhsQueryVideo tags: [view] summary: Query VHS video metadata description: "[cloud-only] Returns metadata about a video file processed by the VHS (Video Helper Suite) integration." @@ -5592,6 +6985,15 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '400': + description: 'Missing required query parameter. Produced by the oapi-codegen + wrapper via echo.NewHTTPError, so the body shape matches Echo''s + default HTTPError serialization rather than ErrorResponse. + ' + content: + application/json: + schema: + $ref: '#/components/schemas/BindingErrorResponse' /api/vhs/viewaudio: get: operationId: viewVhsAudio @@ -5812,6 +7214,12 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' /api/tasks/{task_id}: get: operationId: getTask @@ -5847,6 +7255,12 @@ paths: schema: $ref: "#/components/schemas/CloudError" + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' components: parameters: ComfyUserHeader: @@ -8755,4 +10169,1213 @@ components: items: $ref: "#/components/schemas/TaskEntry" pagination: - $ref: "#/components/schemas/PaginationInfo" \ No newline at end of file + $ref: "#/components/schemas/PaginationInfo" + + # ===== Cloud-only schemas (Comfy-Org/cloud runtime, BE-1106) ===== + AssetDownloadResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Acknowledgement of an async asset download task; clients poll GET /api/tasks/{task_id} for status.' + required: + - task_id + - status + properties: + task_id: + type: string + format: uuid + description: Task ID for tracking download progress via GET /api/tasks/{task_id} + status: + type: string + enum: + - created + - running + - completed + - failed + description: Current task status + message: + type: string + description: Human-readable message + example: Download task created. Use task_id to track progress. + + AssetMetadataResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Metadata for a remotely hosted asset resolved by URL.' + required: + - content_length + properties: + content_length: + type: integer + format: int64 + description: Size of the asset in bytes (-1 if unknown) + example: 4294967296 + content_type: + type: string + description: MIME type of the asset + example: application/octet-stream + filename: + type: string + description: Suggested filename for the asset from source + example: realistic-vision-v5.safetensors + name: + type: string + description: Display name or title for the asset from source + example: Realistic Vision v5.0 + tags: + type: array + items: + type: string + description: Tags for categorization from source + example: + - models + - checkpoint + preview_image: + type: string + description: Preview image as base64-encoded data URL + example: data:image/jpeg;base64,/9j/4AAQSkZJRg... + validation: + description: Validation results for the file + allOf: + - $ref: '#/components/schemas/ValidationResult' + + BillingBalanceResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Current credit balance and usage details for a workspace.' + required: + - amount_micros + - currency + properties: + amount_micros: + type: number + format: double + description: The total remaining balance in microamount (1/1,000,000 of the currency unit) + prepaid_balance_micros: + type: number + format: double + description: The remaining balance from prepaid commits in microamount + cloud_credit_balance_micros: + type: number + format: double + description: The remaining balance from cloud credits in microamount + pending_charges_micros: + type: number + format: double + description: The total amount of pending/unbilled charges from draft invoices in microamount + effective_balance_micros: + type: number + format: double + description: The effective balance (total balance minus pending charges). Can be negative if pending charges exceed + the balance. + currency: + type: string + example: usd + description: Currency code + + BillingPlansResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] List of available billing plans for subscription.' + required: + - plans + properties: + current_plan_slug: + type: string + description: Current plan slug if subscribed + plans: + type: array + items: + $ref: '#/components/schemas/Plan' + + BillingStatusResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Current billing and subscription status for a workspace.' + required: + - is_active + - has_funds + properties: + is_active: + type: boolean + description: Whether the workspace has an active subscription + subscription_status: + type: string + enum: + - active + - ended + - canceled + description: Subscription activity status (scheduled subscriptions are not returned) + subscription_tier: + $ref: '#/components/schemas/SubscriptionTier' + subscription_duration: + $ref: '#/components/schemas/SubscriptionDuration' + plan_slug: + type: string + description: Plan identifier (e.g., standard-monthly, team-pro-annual) + billing_status: + $ref: '#/components/schemas/BillingStatus' + has_funds: + type: boolean + description: Whether the workspace has available credits + cancel_at: + type: string + format: date-time + description: When the subscription will become inactive (if canceled) + renewal_date: + type: string + format: date-time + description: When the current billing period ends and the next one begins + + GetUserDataResponseFull: + type: array + x-runtime: [cloud] + description: '[cloud-only] List of user data file entries (each with path, size, and modification time) returned when full_info=true.' + items: + $ref: '#/components/schemas/GetUserDataResponseFullFile' + + HistoryDetailEntry: + type: object + x-runtime: [cloud] + description: '[cloud-only] History entry with full prompt data' + properties: + prompt: + type: object + description: Full prompt execution data + properties: + priority: + type: number + format: double + description: Execution priority + prompt_id: + type: string + description: The prompt ID + prompt: + type: object + description: The workflow nodes + additionalProperties: true + extra_data: + type: object + description: Additional execution data + additionalProperties: true + outputs_to_execute: + type: array + items: + type: string + description: Output nodes to execute + outputs: + type: object + description: Output data from execution (generated images, files, etc.) + additionalProperties: true + status: + type: object + description: Execution status and timeline information + additionalProperties: true + meta: + type: object + description: Metadata about the execution and nodes + additionalProperties: true + + HistoryDetailResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Detailed execution history response for a specific prompt. + + Returns a dictionary with prompt_id as key and full history data as value. + + ' + additionalProperties: + $ref: '#/components/schemas/HistoryDetailEntry' + + HistoryResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Execution history response with history array. + + Returns an object with a "history" key containing an array of history entries. + + Each entry includes prompt_id as a property along with execution data. + + ' + required: + - history + properties: + history: + type: array + description: Array of history entries ordered by creation time (newest first) + items: + $ref: '#/components/schemas/HistoryEntry' + + HubLabelInfo: + type: object + x-runtime: [cloud] + description: '[cloud-only] Metadata for a single Hub label.' + required: + - name + - display_name + - type + properties: + name: + type: string + description: Slug identifier. + display_name: + type: string + description: Human-readable display name. + description: + type: string + description: Optional description of the label. + type: + type: string + enum: + - tag + - model + - custom_node + description: Label category. + + HubProfileSummary: + type: object + x-runtime: [cloud] + description: '[cloud-only] Abbreviated Hub profile used in workflow listings.' + required: + - username + properties: + username: + type: string + display_name: + type: string + avatar_url: + type: string + description: Public URL of the profile avatar image. + + HubWorkflowListResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Paginated list of Hub workflows matching search criteria.' + required: + - workflows + properties: + workflows: + type: array + items: + anyOf: + - $ref: '#/components/schemas/HubWorkflowSummary' + - $ref: '#/components/schemas/HubWorkflowDetail' + description: Array of HubWorkflowSummary (default) or HubWorkflowDetail (when detail=true). + next_cursor: + type: string + description: Cursor for the next page, empty if no more results. + + HubWorkflowStatus: + type: string + x-runtime: [cloud] + description: '[cloud-only] Public workflow status. NULL in the database is represented as pending in API responses.' + enum: + - pending + - approved + - rejected + - deprecated + + HubWorkflowSummary: + type: object + x-runtime: [cloud] + description: '[cloud-only] Abbreviated Hub workflow metadata used in search and listing results.' + required: + - share_id + - name + - profile + - status + properties: + share_id: + type: string + name: + type: string + status: + $ref: '#/components/schemas/HubWorkflowStatus' + description: + type: string + tags: + type: array + items: + $ref: '#/components/schemas/LabelRef' + models: + type: array + items: + $ref: '#/components/schemas/LabelRef' + custom_nodes: + type: array + items: + $ref: '#/components/schemas/LabelRef' + thumbnail_type: + type: string + enum: + - image + - video + - image_comparison + thumbnail_url: + type: string + thumbnail_comparison_url: + type: string + publish_time: + type: string + format: date-time + nullable: true + profile: + $ref: '#/components/schemas/HubProfileSummary' + metadata: + type: object + additionalProperties: true + tutorial_url: + type: string + sample_image_urls: + type: array + items: + type: string + + HubWorkflowTemplateEntry: + type: object + x-runtime: [cloud] + description: '[cloud-only] Entry in the curated workflow template gallery shown on the home page.' + required: + - name + - title + - status + properties: + name: + type: string + description: Slug identifier for the template + title: + type: string + status: + $ref: '#/components/schemas/HubWorkflowStatus' + description: + type: string + tags: + type: array + items: + type: string + models: + type: array + items: + type: string + requiresCustomNodes: + type: array + items: + type: string + thumbnailVariant: + type: string + mediaType: + type: string + mediaSubtype: + type: string + size: + type: integer + format: int64 + description: Workflow asset size in bytes. + vram: + type: integer + format: int64 + description: Approximate VRAM requirement in bytes. + usage: + type: integer + format: int64 + description: Usage count reported upstream. + searchRank: + type: integer + format: int64 + description: Search ranking score reported upstream. + isEssential: + type: boolean + description: Whether the template belongs to a module marked as essential. + openSource: + type: boolean + profile: + $ref: '#/components/schemas/HubProfileSummary' + tutorialUrl: + type: string + logos: + type: array + items: + type: object + additionalProperties: true + date: + type: string + description: Publication date in YYYY-MM-DD format + io: + type: object + properties: + inputs: + type: array + items: + type: object + additionalProperties: true + outputs: + type: array + items: + type: object + additionalProperties: true + includeOnDistributions: + type: array + items: + type: string + thumbnailUrl: + type: string + description: Public URL of the primary thumbnail + thumbnailComparisonUrl: + type: string + description: Public URL of the comparison thumbnail + shareId: + type: string + description: Share ID for linking to the hub workflow detail + extendedDescription: + type: string + description: AI-generated extended description of the workflow + metaDescription: + type: string + description: AI-generated SEO meta description (under 160 chars) + howToUse: + type: array + items: + type: string + description: AI-generated step-by-step usage instructions + suggestedUseCases: + type: array + items: + type: string + description: AI-generated suggested use cases + faqItems: + type: array + items: + type: object + required: + - question + - answer + properties: + question: + type: string + answer: + type: string + description: AI-generated FAQ items + contentTemplate: + type: string + description: Content template used for generation (tutorial, showcase, comparison, breakthrough) + + JobStatusResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Job status information' + properties: + id: + type: string + format: uuid + description: The job ID + status: + type: string + enum: + - waiting_to_dispatch + - pending + - in_progress + - completed + - error + - cancelled + description: Current job status + created_at: + type: string + format: date-time + description: When the job was created + updated_at: + type: string + format: date-time + description: When the job was last updated + last_state_update: + type: string + format: date-time + description: When the job status was last changed + assigned_inference: + type: string + nullable: true + description: The inference instance assigned to this job (if any) + error_message: + type: string + nullable: true + description: Error message if the job failed + required: + - id + - status + - created_at + - updated_at + + JobsListResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Paginated list of jobs for the authenticated user.' + required: + - jobs + - pagination + properties: + jobs: + type: array + description: Array of jobs ordered by specified sort field + items: + $ref: '#/components/schemas/JobEntry' + pagination: + $ref: '#/components/schemas/PaginationInfo' + + LabelRef: + type: object + x-runtime: [cloud] + description: '[cloud-only] Reference to a Hub label by ID.' + required: + - name + - display_name + properties: + name: + type: string + description: Slug identifier (e.g. "video-generation", "flux"). + display_name: + type: string + description: Human-readable display name (e.g. "Video Generation", "Flux"). + + LogsResponse: + type: array + x-runtime: [cloud] + description: '[cloud-only] System logs response' + items: + type: object + properties: + timestamp: + type: string + format: date-time + description: When the log entry was created + level: + type: string + enum: + - debug + - info + - warn + - error + description: Log level + message: + type: string + description: Log message + source: + type: string + description: Source of the log entry + metadata: + type: object + additionalProperties: true + description: Additional log metadata + + Member: + type: object + x-runtime: [cloud] + description: '[cloud-only] Workspace member with profile and role information.' + required: + - id + - name + - email + - role + - joined_at + properties: + id: + type: string + description: User ID + name: + type: string + description: User's display name + email: + type: string + format: email + description: User's email address + role: + type: string + enum: + - owner + - member + description: User's role in the workspace + joined_at: + type: string + format: date-time + description: When the user joined the workspace + + OAuthRegisterBadRequestResponse: + x-runtime: [cloud] + description: "[cloud-only] Union of the two 400 shapes /oauth/register can emit. `OAuthRegisterError` is the handler-shaped\ + \ RFC 7591 \xA73.2.2 error; `BindingErrorResponse` is the strict-server binding-layer error fired when the request body\ + \ fails OpenAPI-schema validation before the handler runs.\n" + oneOf: + - $ref: '#/components/schemas/OAuthRegisterError' + - $ref: '#/components/schemas/BindingErrorResponse' + + PendingInvite: + type: object + x-runtime: [cloud] + description: '[cloud-only] An outstanding workspace invitation that has not yet been accepted.' + required: + - id + - email + - invited_at + - expires_at + properties: + id: + type: string + description: Invite ID + email: + type: string + format: email + description: Email address of the invited user + token: + type: string + description: Invite token for constructing invite links. Empty for expired invites. + invited_at: + type: string + format: date-time + description: When the invite was created + expires_at: + type: string + format: date-time + description: When the invite expires + + Plan: + type: object + x-runtime: [cloud] + description: '[cloud-only] Billing plan details including pricing, limits, and features.' + required: + - slug + - tier + - duration + - price_cents + - credits_cents + - max_seats + - availability + - seat_summary + properties: + slug: + type: string + description: Plan identifier (e.g., "pro-monthly", "team-standard-annual") + example: pro-monthly + tier: + $ref: '#/components/schemas/SubscriptionTier' + duration: + $ref: '#/components/schemas/SubscriptionDuration' + price_cents: + type: integer + format: int64 + description: Per-member price in cents (base + one seat) + example: 10000 + credits_cents: + type: integer + format: int64 + description: Per-member credits in cents (base + one seat) + example: 10000 + max_seats: + type: integer + format: int64 + description: Maximum number of seats allowed for this plan + example: 20 + availability: + $ref: '#/components/schemas/PlanAvailability' + seat_summary: + $ref: '#/components/schemas/PlanSeatSummary' + + PlanAvailability: + type: object + x-runtime: [cloud] + description: '[cloud-only] Availability and eligibility information for a billing plan.' + required: + - available + properties: + available: + type: boolean + description: Whether the workspace can subscribe to this plan + reason: + $ref: '#/components/schemas/PlanAvailabilityReason' + + PlanAvailabilityReason: + type: string + x-runtime: [cloud] + enum: + - same_plan + - incompatible_transition + - requires_team + - requires_personal + - exceeds_max_seats + description: '[cloud-only] Reason why a plan is unavailable' + + PlanSeatSummary: + type: object + x-runtime: [cloud] + description: '[cloud-only] Summary of seat costs based on current workspace members' + required: + - seat_count + - total_cost_cents + - total_credits_cents + properties: + seat_count: + type: integer + description: Total number of seats (owner + members) that would be charged + example: 5 + total_cost_cents: + type: integer + format: int64 + description: Total cost for all seats in cents + example: 50000 + total_credits_cents: + type: integer + format: int64 + description: Total credits granted for all seats in cents + example: 50000 + + PreviewPlanInfo: + type: object + x-runtime: [cloud] + description: '[cloud-only] Plan information for preview display' + required: + - slug + - tier + - duration + - price_cents + - credits_cents + - seat_summary + properties: + slug: + type: string + description: Plan slug + example: team-pro-monthly + tier: + $ref: '#/components/schemas/SubscriptionTier' + duration: + $ref: '#/components/schemas/SubscriptionDuration' + price_cents: + type: integer + format: int64 + description: Per-seat price in cents + example: 10000 + credits_cents: + type: integer + format: int64 + description: Per-seat credits in cents + example: 10000 + seat_summary: + $ref: '#/components/schemas/PlanSeatSummary' + period_start: + type: string + format: date-time + description: Current billing period start (only for current_plan) + period_end: + type: string + format: date-time + description: Current billing period end (only for current_plan) + + PreviewSubscribeResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Itemized cost preview for a pending subscription change.' + required: + - allowed + - transition_type + - effective_at + - is_immediate + - cost_today_cents + - cost_next_period_cents + - credits_today_cents + - credits_next_period_cents + - new_plan + properties: + allowed: + type: boolean + description: Whether this subscription change is allowed + reason: + type: string + description: Reason why the change is not allowed (only present if allowed=false) + transition_type: + type: string + enum: + - new_subscription + - upgrade + - downgrade + - duration_change + description: Type of subscription transition + effective_at: + type: string + format: date-time + description: When the change takes effect + is_immediate: + type: boolean + description: Whether the change takes effect immediately (true) or at period end (false) + cost_today_cents: + type: integer + format: int64 + description: Amount to charge today in cents (0 for downgrades) + example: 5000 + cost_next_period_cents: + type: integer + format: int64 + description: Amount that will be charged at next billing period in cents + example: 10000 + credits_today_cents: + type: integer + format: int64 + description: Credits granted today in cents (prorated for mid-period upgrades) + example: 5000 + credits_next_period_cents: + type: integer + format: int64 + description: Credits that will be granted at next billing period in cents + example: 10000 + current_plan: + $ref: '#/components/schemas/PreviewPlanInfo' + new_plan: + $ref: '#/components/schemas/PreviewPlanInfo' + + PublishedWorkflowDetail: + type: object + x-runtime: [cloud] + description: '[cloud-only] Full detail of a publicly published workflow on the Hub.' + required: + - share_id + - workflow_id + - name + - listed + - workflow_json + - assets + properties: + share_id: + type: string + workflow_id: + type: string + name: + type: string + description: Human-readable workflow name. + listed: + type: boolean + publish_time: + type: string + format: date-time + nullable: true + workflow_json: + type: object + additionalProperties: true + description: The workflow JSON content at publish time. + assets: + type: array + description: Published assets with their library status for the caller. + items: + $ref: '#/components/schemas/AssetInfo' + + SecretResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] User secret metadata (the secret value itself is never returned after creation).' + required: + - id + - name + - created_at + - updated_at + properties: + id: + type: string + format: uuid + description: Unique identifier for the secret + name: + type: string + description: User-provided label for the secret + provider: + type: string + description: Provider identifier (e.g., huggingface, civitai) + last_used_at: + type: string + format: date-time + description: When the secret was last used for decryption + created_at: + type: string + format: date-time + description: When the secret was created + updated_at: + type: string + format: date-time + description: When the secret was last updated + + SubscriptionDuration: + type: string + x-runtime: [cloud] + enum: + - MONTHLY + - ANNUAL + description: '[cloud-only] Billing period (uppercase to match comfy-api)' + + SubscriptionTier: + type: string + x-runtime: [cloud] + enum: + - FREE + - STANDARD + - CREATOR + - PRO + - FOUNDERS_EDITION + description: '[cloud-only] Subscription tier (uppercase to match comfy-api)' + + UserDataResponseFull: + type: object + x-runtime: [cloud] + description: '[cloud-only] User data listing entry with file metadata (path, size, modification time).' + properties: + path: + type: string + size: + type: integer + modified: + type: integer + format: int64 + description: UNIX timestamp of the last modification in milliseconds. + + ValidationError: + type: object + x-runtime: [cloud] + description: '[cloud-only] Details of a single validation error encountered during asset operations.' + required: + - code + - message + - field + properties: + code: + type: string + description: Machine-readable error code + example: FORMAT_NOT_ALLOWED + message: + type: string + description: Human-readable error message + example: 'File format "PickleTensor" is not allowed. Allowed formats: [SafeTensor]' + field: + type: string + description: Field that failed validation + example: format + + ValidationResult: + type: object + x-runtime: [cloud] + description: '[cloud-only] Result of validating a set of asset operations.' + required: + - is_valid + properties: + is_valid: + type: boolean + description: Overall validation status (true if all checks passed) + example: true + errors: + type: array + items: + $ref: '#/components/schemas/ValidationError' + description: Blocking validation errors that prevent download + warnings: + type: array + items: + $ref: '#/components/schemas/ValidationError' + description: Non-blocking validation warnings (informational only) + + WorkflowForkedFrom: + type: object + x-runtime: [cloud] + description: '[cloud-only] Reference to the parent workflow from which this workflow was forked.' + properties: + workflow_id: + type: string + workflow_version_id: + type: string + + WorkflowResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Full workflow entity including metadata and version history.' + required: + - id + - latest_version + - created_by + - created_at + - updated_at + properties: + id: + type: string + name: + type: string + description: + type: string + default_view: + type: string + enum: + - workflow + - app + latest_version: + type: integer + forked_from: + $ref: '#/components/schemas/WorkflowForkedFrom' + created_by: + type: string + created_at: + type: string + format: date-time + updated_at: + type: string + format: date-time + + WorkflowVersionContentResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Full workflow version including the serialized workflow JSON.' + required: + - id + - version + - workflow_json + - created_by + - created_at + properties: + id: + type: string + version: + type: integer + workflow_json: + type: object + additionalProperties: true + created_by: + type: string + created_at: + type: string + format: date-time + dependency_asset_ids: + type: array + items: + type: string + + WorkspaceAPIKeyInfo: + type: object + x-runtime: [cloud] + description: '[cloud-only] Metadata for a workspace-scoped API key (secret is never returned).' + required: + - id + - workspace_id + - user_id + - name + - description + - key_prefix + - created_at + properties: + id: + type: string + format: uuid + description: API key ID + workspace_id: + type: string + description: Workspace this key belongs to + user_id: + type: string + description: User who created this key + name: + type: string + description: User-provided label + description: + type: string + description: User-provided description of the key's purpose. Limit is byte-based (UTF-8 encoding); 5000 bytes equals + 5000 ASCII characters or fewer multi-byte characters. + maxLength: 5000 + key_prefix: + type: string + description: First 8 chars after prefix for display + expires_at: + type: string + format: date-time + description: When the key expires (if set) + last_used_at: + type: string + format: date-time + description: Last time the key was used + revoked_at: + type: string + format: date-time + description: When the key was revoked (if revoked) + created_at: + type: string + format: date-time + description: When the key was created + + WorkspaceSummary: + type: object + x-runtime: [cloud] + description: '[cloud-only] Abbreviated workspace metadata used in list responses.' + required: + - id + - name + - type + properties: + id: + type: string + example: w-a1b2c3d4-5678-90ab-cdef-1234567890ab + name: + type: string + example: My Team + type: + type: string + enum: + - personal + - team + + WorkspaceWithRole: + type: object + x-runtime: [cloud] + description: '[cloud-only] Workspace entity annotated with the requesting user''s role.' + required: + - id + - name + - type + - role + - created_at + - joined_at + properties: + id: + type: string + example: w-a1b2c3d4-5678-90ab-cdef-1234567890ab + name: + type: string + example: My Team + type: + type: string + enum: + - personal + - team + role: + type: string + enum: + - owner + - member + created_at: + type: string + format: date-time + description: When the workspace was created + joined_at: + type: string + format: date-time + description: When the user joined the workspace (same as created_at for the workspace creator) + subscription_tier: + $ref: '#/components/schemas/SubscriptionTier' + + BindingErrorResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Error shape returned when request binding or validation fails before the handler runs.' + required: + - message + properties: + message: + type: string + + ErrorResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Standard error response from cloud endpoints with a machine-readable code and human-readable message.' + required: + - code + - message + properties: + code: + type: string + description: Machine-readable error code + message: + type: string + description: Human-readable error message From c3c881f37b1cad344d400e16fd3293012556c8dc Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Fri, 22 May 2026 16:34:52 -0700 Subject: [PATCH 34/45] openapi: rename cloud-side response schemas to match runtime (PR D) (#14065) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * openapi: rename cloud-side response schemas to match runtime (PR D) Follow-up to the BE-1106 stack (#14060/61/63). Cloud's Go handlers reference response schemas by name (e.g., ingest.WorkflowResponse, ingest.SubscribeResponse), but vendor's matching operations were declaring those responses against differently-named vendor-side schemas (CloudWorkflow, BillingSubscription, etc.). After the stack landed, schemas like WorkflowResponse exist in vendor but weren't referenced by any path, so codegen pruned the unreferenced types. This PR: 1. Updates 34 operation $refs in cloud-runtime paths to point to the schema names cloud's handlers expect (e.g., CloudWorkflow → WorkflowResponse on /api/workflows/{workflow_id}). 2. Adds 12 cloud-only schemas that weren't in vendor yet but are referenced by these renames (e.g., SubscribeResponse, CancelSubscriptionResponse, BillingOpStatusResponse). Each copied verbatim from Comfy-Org/cloud's hand-written ingest spec and tagged x-runtime: [cloud] with a [cloud-only] description prefix. Schema renames span the same domains as the operationId renames in PR A: billing/subscriptions (7 schemas), workflows (5), userdata (3), jobs (2), hub (2), history (2), auth/workspace (4), and misc cloud endpoints (9). Convergent safety check after this lands (against cloud's TestCutoverSafe gate, BE-1106): Pre-PR D: 205 missing handler refs Post-PR D: 105 missing handler refs (-49%) Cumulative since the original 938-ref baseline: -89% The remaining 105 are a Phase 3 follow-up (response headers, text/plain responses, codegen-derived enum sub-types, and a small set of inline-response-schema operations that vendor declares inline where cloud has named-schema $refs). * openapi: drop PR-label comment from new schemas block PR-internal labels don't belong in committed code — future readers won't know what 'PR D' means and the marker stops being useful the moment this PR merges. --- openapi.yaml | 355 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 321 insertions(+), 34 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index 59b6817e5..bbe5b3562 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -1200,7 +1200,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/ListUserdataResponse" + $ref: "#/components/schemas/GetUserDataResponseFull" "404": description: Directory not found @@ -1340,7 +1340,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/UserDataResponse" + $ref: "#/components/schemas/UserDataResponseFull" "409": description: File exists and overwrite not set '400': @@ -1434,7 +1434,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/UserDataResponse" + $ref: "#/components/schemas/UserDataResponseFull" "404": description: Source file not found "409": @@ -2752,7 +2752,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/CloudJobStatus" + $ref: "#/components/schemas/JobCancelResponse" "401": description: Unauthorized content: @@ -2803,7 +2803,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/CloudJobStatus" + $ref: "#/components/schemas/JobStatusResponse" "401": description: Unauthorized content: @@ -2899,7 +2899,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/HistoryV2Response" + $ref: "#/components/schemas/HistoryResponse" "401": description: Unauthorized content: @@ -2938,7 +2938,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/HistoryV2Entry" + $ref: "#/components/schemas/HistoryDetailResponse" "401": description: Unauthorized content: @@ -2994,7 +2994,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/CloudLogsResponse" + $ref: "#/components/schemas/LogsResponse" "401": description: Unauthorized content: @@ -3315,7 +3315,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/RemoteAssetMetadata" + $ref: "#/components/schemas/AssetMetadataResponse" "400": description: Bad request content: @@ -3889,7 +3889,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/HubWorkflowList" + $ref: "#/components/schemas/HubWorkflowListResponse" '400': description: Bad request (e.g. malformed pagination cursor) content: @@ -3972,7 +3972,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/HubWorkflow" + $ref: "#/components/schemas/HubWorkflowDetail" "404": description: Not found content: @@ -4092,7 +4092,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/CloudWorkflowList" + $ref: "#/components/schemas/WorkflowListResponse" "401": description: Unauthorized content: @@ -4136,7 +4136,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/CloudWorkflow" + $ref: "#/components/schemas/WorkflowResponse" "400": description: Bad request content: @@ -4183,7 +4183,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/CloudWorkflow" + $ref: "#/components/schemas/WorkflowResponse" "401": description: Unauthorized content: @@ -4239,7 +4239,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/CloudWorkflow" + $ref: "#/components/schemas/WorkflowResponse" "400": description: Bad request content: @@ -4438,7 +4438,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/CloudWorkflow" + $ref: "#/components/schemas/WorkflowResponse" "401": description: Unauthorized content: @@ -4607,7 +4607,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/CloudWorkflow" + $ref: "#/components/schemas/PublishedWorkflowDetail" "404": description: Not found content: @@ -4743,7 +4743,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/AuthTokenResponse" + $ref: "#/components/schemas/ExchangeTokenResponse" "400": description: Bad request content: @@ -5089,7 +5089,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/BillingBalance" + $ref: "#/components/schemas/BillingBalanceResponse" "401": description: Unauthorized content: @@ -5132,7 +5132,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/BillingEventList" + $ref: "#/components/schemas/BillingEventsResponse" "401": description: Unauthorized content: @@ -5166,7 +5166,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/BillingOp" + $ref: "#/components/schemas/BillingOpStatusResponse" "401": description: Unauthorized content: @@ -5278,7 +5278,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/SubscriptionPreview" + $ref: "#/components/schemas/PreviewSubscribeResponse" "400": description: Bad request content: @@ -5311,7 +5311,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/BillingStatus" + $ref: "#/components/schemas/BillingStatusResponse" "401": description: Unauthorized content: @@ -5359,7 +5359,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/BillingSubscription" + $ref: "#/components/schemas/SubscribeResponse" "400": description: Bad request content: @@ -5392,7 +5392,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/BillingSubscription" + $ref: "#/components/schemas/CancelSubscriptionResponse" "401": description: Unauthorized content: @@ -5425,7 +5425,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/BillingSubscription" + $ref: "#/components/schemas/ResubscribeResponse" "401": description: Unauthorized content: @@ -5470,7 +5470,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/BillingBalance" + $ref: "#/components/schemas/CreateTopupResponse" "400": description: Bad request content: @@ -5555,7 +5555,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/WorkspaceApiKeyCreated" + $ref: "#/components/schemas/CreateWorkspaceAPIKeyResponse" "400": description: Bad request content: @@ -5704,7 +5704,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/WorkspaceInvite" + $ref: "#/components/schemas/PendingInvite" "400": description: Bad request content: @@ -6486,7 +6486,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/Workspace" + $ref: "#/components/schemas/AcceptInviteResponse" "400": description: Bad request content: @@ -6586,7 +6586,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/SecretMeta" + $ref: "#/components/schemas/SecretResponse" "400": description: Bad request content: @@ -6645,7 +6645,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/SecretMeta" + $ref: "#/components/schemas/SecretResponse" "401": description: Unauthorized content: @@ -6702,7 +6702,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/SecretMeta" + $ref: "#/components/schemas/SecretResponse" "400": description: Bad request content: @@ -6805,7 +6805,7 @@ paths: content: application/json: schema: - $ref: "#/components/schemas/CloudUser" + $ref: "#/components/schemas/UserResponse" "401": description: Unauthorized content: @@ -11379,3 +11379,290 @@ components: message: type: string description: Human-readable error message + + AcceptInviteResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Response returned after successfully accepting a workspace invitation.' + required: + - workspace_id + - workspace_name + properties: + workspace_id: + type: string + description: ID of the workspace joined + workspace_name: + type: string + description: Name of the workspace joined + + BillingEventsResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Paginated list of billing events for a workspace.' + required: + - total + - events + - page + - limit + - totalPages + properties: + total: + type: integer + description: Total number of events + events: + type: array + items: + $ref: '#/components/schemas/BillingEvent' + page: + type: integer + description: Current page number (1-indexed) + limit: + type: integer + description: Items per page + totalPages: + type: integer + description: Total number of pages + + BillingOpStatusResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Status of an asynchronous billing operation.' + required: + - id + - status + - started_at + properties: + id: + type: string + description: Unique identifier for the billing operation + status: + type: string + enum: + - pending + - succeeded + - failed + description: Current status of the operation + error_message: + type: string + description: Error message if status is failed + started_at: + type: string + format: date-time + description: When the operation was initiated + completed_at: + type: string + format: date-time + description: When the operation completed (success or failure) + + CancelSubscriptionResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Response after successfully cancelling a subscription.' + required: + - cancel_at + - billing_op_id + properties: + billing_op_id: + type: string + description: Billing operation ID to poll for status via GET /api/billing/ops/{id} + cancel_at: + type: string + format: date-time + description: The date when the subscription will end (end of current billing period) + + CreateTopupResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Response after successfully purchasing a credit top-up.' + required: + - topup_id + - status + - amount_cents + - billing_op_id + properties: + billing_op_id: + type: string + description: Billing operation ID to poll for status via GET /api/billing/ops/{id} + topup_id: + type: string + description: Unique identifier for the top-up request (same as billing_op_id, deprecated) + status: + type: string + enum: + - pending + - completed + - failed + description: Current status of the top-up + amount_cents: + type: integer + format: int64 + description: Amount being charged in cents + + CreateWorkspaceAPIKeyResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Response containing the newly created workspace API key.' + required: + - id + - name + - description + - key + - key_prefix + - created_at + properties: + id: + type: string + format: uuid + description: API key ID + name: + type: string + description: User-provided label + description: + type: string + description: User-provided description of the key's purpose. Limit is byte-based (UTF-8 encoding); 5000 bytes equals + 5000 ASCII characters or fewer multi-byte characters. + maxLength: 5000 + key: + type: string + description: The full plaintext API key (only shown once) + key_prefix: + type: string + description: First 8 chars after prefix for display + expires_at: + type: string + format: date-time + description: When the key expires (if set) + created_at: + type: string + format: date-time + description: When the key was created + + ExchangeTokenResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Response containing the issued Cloud JWT and its expiry.' + required: + - token + - expires_at + - workspace + - role + - permissions + properties: + token: + type: string + description: Cloud JWT token + expires_at: + type: string + format: date-time + description: Token expiration time (RFC 3339) + workspace: + $ref: '#/components/schemas/WorkspaceSummary' + role: + type: string + enum: + - owner + - member + description: User's role in the workspace + permissions: + type: array + items: + type: string + description: Permission strings for the role + example: + - owner:* + + JobCancelResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Response for POST /api/jobs/{job_id}/cancel. Returned on both fresh cancels and idempotent no-ops.' + required: + - cancelled + properties: + cancelled: + type: boolean + description: "True when a cancel event was successfully dispatched by this call.\nFalse when the job was already in\ + \ a terminal or cancelling state,\nin which case the call is a no-op (still 200 \u2014 idempotent).\n" + + ResubscribeResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Response after successfully resubscribing to a billing plan.' + required: + - status + - billing_op_id + properties: + billing_op_id: + type: string + description: Billing operation ID to poll for status via GET /api/billing/ops/{id} + status: + type: string + enum: + - active + description: The subscription status after resubscribing + message: + type: string + description: Human-readable confirmation message + + SubscribeResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Response after successfully subscribing to a billing plan.' + required: + - status + - billing_op_id + properties: + billing_op_id: + type: string + description: Billing operation ID to poll for status via GET /api/billing/ops/{id} + status: + type: string + enum: + - subscribed + - needs_payment_method + - pending_payment + description: 'Status of the subscription operation: + + - subscribed: Subscription is active immediately + + - needs_payment_method: User must add payment method via payment_method_url + + - pending_payment: Upgrade initiated, waiting for payment to complete + + ' + effective_at: + type: string + format: date-time + description: When the subscription became/becomes active (present when status=subscribed or pending_payment) + payment_method_url: + type: string + description: URL to redirect user to add payment method (present when status=needs_payment_method) + + UserResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] User information response' + required: + - id + - status + properties: + id: + type: string + description: Firebase UID of the authenticated user + status: + type: string + description: User status (always "active" for authenticated users) + + WorkflowListResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Paginated list of saved workflows.' + required: + - data + - pagination + properties: + data: + type: array + items: + $ref: '#/components/schemas/WorkflowResponse' + pagination: + $ref: '#/components/schemas/PaginationInfo' From 187442cca4594a59c563780e7cd144e6d8dc02ab Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Fri, 22 May 2026 18:23:22 -0700 Subject: [PATCH 35/45] openapi: add enum values + FeedbackRequest schema for cloud cutover (PR E) (#14070) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * openapi: add enum values + FeedbackRequest schema for cloud cutover (PR E) Adds missing cloud-runtime enum values to vendor schemas that the cloud runtime emits but vendor declared as plain strings. Changes: - JobEntry.status: enum [pending, in_progress, completed, failed, cancelled] - JobDetailResponse.status: same enum - BillingStatus: enum [awaiting_payment_method, pending_payment, paid, payment_failed, inactive] - FeedbackRequest schema added (with type enum) - /api/feedback POST: requestBody now $refs FeedbackRequest All cloud-runtime-emitted; no impact on OSS-local semantics. Identified via Comfy-Org/cloud's TestCutoverSafe gate (BE-1106) as the remaining schema-level divergences after PRs A-D landed and got synced. * openapi: add type enum to Workspace schema (cutover follow-up) Cloud's Workspace runtime shape includes a 'type' field with enum [personal, team] that vendor's Workspace was missing. Cloud handlers reference the generated ingest.WorkspaceType Go enum. Same kind of surgical addition as JobEntry.status / BillingStatus / JobDetailResponse.status in this PR — adds cloud-runtime field to existing vendor schema. --- openapi.yaml | 62 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 46 insertions(+), 16 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index bbe5b3562..2347bd659 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -6315,22 +6315,7 @@ paths: content: application/json: schema: - type: object - required: - - message - properties: - message: - type: string - description: Feedback message - rating: - type: integer - minimum: 1 - maximum: 5 - description: Optional satisfaction rating - context: - type: object - additionalProperties: true - description: Additional context metadata + $ref: "#/components/schemas/FeedbackRequest" responses: "201": description: Feedback submitted @@ -7535,6 +7520,12 @@ components: description: Unique job identifier (same as prompt_id) status: type: string + enum: + - pending + - in_progress + - completed + - failed + - cancelled description: Current job status create_time: type: integer @@ -7568,6 +7559,12 @@ components: format: uuid status: type: string + enum: + - pending + - in_progress + - completed + - failed + - cancelled workflow: type: object additionalProperties: true @@ -9598,6 +9595,12 @@ components: $ref: "#/components/schemas/BillingBalance" has_payment_method: type: boolean + enum: + - awaiting_payment_method + - pending_payment + - paid + - payment_failed + - inactive BillingSubscription: type: object @@ -9659,6 +9662,12 @@ components: type: string name: type: string + type: + type: string + enum: + - personal + - team + description: Workspace type (personal vs. team). owner_id: type: string member_count: @@ -11666,3 +11675,24 @@ components: $ref: '#/components/schemas/WorkflowResponse' pagination: $ref: '#/components/schemas/PaginationInfo' + + FeedbackRequest: + type: object + x-runtime: [cloud] + description: "[cloud-only] User feedback submission body." + required: + - message + properties: + type: + type: string + enum: + - missing_nodes + - general + - missing_models + description: Feedback category + category: + type: string + description: Additional category metadata + message: + type: string + description: User-provided feedback message From d80fcafee78a9453e89c21da41ecc815ad69a116 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 22 May 2026 19:56:36 -0700 Subject: [PATCH 36/45] Remove dead code. (#14072) --- comfy/samplers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/comfy/samplers.py b/comfy/samplers.py index 0a4d062db..c5e36ff05 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -265,7 +265,6 @@ def _calc_cond_batch(model: BaseModel, conds: list[list[dict]], x_in: torch.Tens input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] cond_shapes = collections.defaultdict(list) for tt in batch_amount: - cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()} for k, v in to_run[tt][0].conditioning.items(): cond_shapes[k].append(v.size()) From 0af123022de374a091d7bf6ca6ad767fa6dcc69d Mon Sep 17 00:00:00 2001 From: Comfy Org PR Bot Date: Sun, 24 May 2026 09:27:52 +0900 Subject: [PATCH 37/45] Bump comfyui-frontend-package to 1.44.19 (#14074) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index e20b6e044..b70c21e1e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -comfyui-frontend-package==1.43.18 +comfyui-frontend-package==1.44.19 comfyui-workflow-templates==0.9.82 comfyui-embedded-docs==0.5.0 torch From 08d809d128df9c6b6800dbb4198cf11cabc5422e Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sat, 23 May 2026 17:44:28 -0700 Subject: [PATCH 38/45] Fix --use-flash-attention ignored when xformers installed. (#14083) --- comfy/ldm/modules/attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index a68cb8439..55360535a 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -741,12 +741,12 @@ optimized_attention = attention_basic if model_management.sage_attention_enabled(): logging.info("Using sage attention") optimized_attention = attention_sage -elif model_management.xformers_enabled(): - logging.info("Using xformers attention") - optimized_attention = attention_xformers elif model_management.flash_attention_enabled(): logging.info("Using Flash Attention") optimized_attention = attention_flash +elif model_management.xformers_enabled(): + logging.info("Using xformers attention") + optimized_attention = attention_xformers elif model_management.pytorch_attention_enabled(): logging.info("Using pytorch attention") optimized_attention = attention_pytorch From 32a7092c52d2cee053fded50a6e12c7e275b195e Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Sat, 23 May 2026 19:48:31 -0700 Subject: [PATCH 39/45] fix: correct description of where compiled FE files live (#14013) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5125bad14..dc2389266 100644 --- a/README.md +++ b/README.md @@ -433,7 +433,7 @@ See also: [https://www.comfy.org/](https://www.comfy.org/) ## Frontend Development -As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory. +As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). The compiled JS files (from TS/Vue) are published to [pypi](https://pypi.org/project/comfyui-frontend-package) and installed as a dependency in ComfyUI. ### Reporting Issues and Requesting Features From ea62dc11c9a10dae52186fdcc3da033eb46018a1 Mon Sep 17 00:00:00 2001 From: Matt Miller Date: Sat, 23 May 2026 19:58:35 -0700 Subject: [PATCH 40/45] openapi: fix invalid BillingStatus schema (object + enum hybrid) (#14071) --- openapi.yaml | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index 2347bd659..502e518c7 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -9585,16 +9585,9 @@ components: description: List of plan features BillingStatus: - type: object + type: string x-runtime: [cloud] - description: "[cloud-only] Overall billing and subscription status." - properties: - subscription: - $ref: "#/components/schemas/BillingSubscription" - balance: - $ref: "#/components/schemas/BillingBalance" - has_payment_method: - type: boolean + description: "[cloud-only] Overall billing/payment lifecycle status." enum: - awaiting_payment_method - pending_payment From 39f963b4b02522b0103fe7ca53fa8d1a0d17ceae Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 25 May 2026 08:25:59 +1000 Subject: [PATCH 41/45] mark loads to pins as cold immediately (#14088) This does the posix_fadvise to kick pins out of the disk cache (to avoid a double copy in RAM). --- comfy/model_management.py | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 3894dfa9c..cd8772d3a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1217,7 +1217,7 @@ def get_aimdo_cast_buffer(offload_stream, device): def get_pin_buffer(offload_stream): pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None) if pin_buffer is None: - pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3)) + pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3), mark_cold=False) STREAM_PIN_BUFFERS[offload_stream] = pin_buffer elif offload_stream is not None: event = getattr(pin_buffer, "_comfy_event", None) diff --git a/requirements.txt b/requirements.txt index b70c21e1e..a22fa50ad 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0 filelock av>=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo==0.4.3 +comfy-aimdo==0.4.5 requests simpleeval>=1.0.0 blake3 From b30e980a206607d1a9d56b7a6f7df3999d68438a Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Mon, 25 May 2026 08:26:50 +1000 Subject: [PATCH 42/45] cache-ram: lower thresholds (#14089) Use the RAM right up to the wire as the community is bit accustomed too. This trades off headroom for the case where large chunky intermediates arrive and potenitally hits pagefile/swap, but a lot of people have "it just fits" workflows out there, so strike a compromise with 75->90%. Disable the incative cache for all but the very high RAM users. --- comfy/cli_args.py | 2 +- main.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 9d88c8517..47b8174f4 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -111,7 +111,7 @@ parser.add_argument("--preview-method", type=LatentPreviewMethod, default=Latent parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.") cache_group = parser.add_mutually_exclusive_group() -cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metavar="GB", help="Use RAM pressure caching with the specified headroom thresholds. This is the default caching mode. The first value sets the active-cache threshold; the optional second value sets the inactive-cache/pin threshold. Defaults when no values are provided: active 25%% of system RAM (min 4GB, max 32GB), inactive 75%% of system RAM (min 12GB, max 96GB).") +cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metavar="GB", help="Use RAM pressure caching with the specified headroom thresholds. This is the default caching mode. The first value sets the active-cache threshold; the optional second value sets the inactive-cache/pin threshold. Defaults when no values are provided: active 10%% of system RAM (min 2GB, max 10GB), inactive 100%% of system RAM (max 96GB).") cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") diff --git a/main.py b/main.py index 1e47cab84..f23074942 100644 --- a/main.py +++ b/main.py @@ -286,8 +286,8 @@ def prompt_worker(q, server_instance): cache_ram = 0 cache_ram_inactive = 0 if not args.cache_classic and not args.cache_none and args.cache_lru <= 0: - cache_ram = min(32.0, max(4.0, comfy.model_management.total_ram * 0.25 / 1024.0)) - cache_ram_inactive = min(96.0, max(12.0, comfy.model_management.total_ram * 0.75 / 1024.0)) + cache_ram = min(10.0, max(2.0, comfy.model_management.total_ram * 0.10 / 1024.0)) + cache_ram_inactive = min(96.0, comfy.model_management.total_ram / 1024.0) if len(args.cache_ram) > 0: cache_ram = args.cache_ram[0] if len(args.cache_ram) > 1: From 63bcaec5d14cb309679a72ddbf875c5dc8d62d46 Mon Sep 17 00:00:00 2001 From: Talmaj Date: Mon, 25 May 2026 04:00:55 +0200 Subject: [PATCH 43/45] Add colored logs (#14036) --- app/logger.py | 40 ++++++++++++++++++++++++++++++++++++++-- main.py | 4 ++-- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/app/logger.py b/app/logger.py index 3d26d98fe..bde815822 100644 --- a/app/logger.py +++ b/app/logger.py @@ -5,6 +5,40 @@ import logging import sys import threading +ANSI_NAMED_COLORS = { + 'black': '\033[30m', + 'red': '\033[31m', + 'green': '\033[32m', + 'yellow': '\033[33m', + 'blue': '\033[34m', + 'magenta': '\033[35m', + 'cyan': '\033[36m', + 'white': '\033[37m', +} + +ANSI_LEVEL_COLORS = { + 'DEBUG': ANSI_NAMED_COLORS['cyan'], + 'INFO': ANSI_NAMED_COLORS['green'], + 'WARNING': ANSI_NAMED_COLORS['yellow'], + 'ERROR': ANSI_NAMED_COLORS['red'], + 'CRITICAL': ANSI_NAMED_COLORS['magenta'], +} + +ANSI_RESET = '\033[0m' +ANSI_BOLD = '\033[1m' + + +class ColoredFormatter(logging.Formatter): + def format(self, record): + color = ANSI_LEVEL_COLORS.get(record.levelname, '') + bold = ANSI_BOLD if record.levelno >= logging.WARNING else '' + level_tag = f"{bold}{color}[{record.levelname}]{ANSI_RESET} " + message = super().format(record) + line_color = ANSI_NAMED_COLORS.get(getattr(record, 'color', ''), '') + if line_color: + return f"{level_tag}{line_color}{message}{ANSI_RESET}" + return level_tag + message + logs = None stdout_interceptor = None stderr_interceptor = None @@ -68,8 +102,10 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool logger = logging.getLogger() logger.setLevel(log_level) + formatter = ColoredFormatter("%(message)s") + stream_handler = logging.StreamHandler() - stream_handler.setFormatter(logging.Formatter("%(message)s")) + stream_handler.setFormatter(formatter) if use_stdout: # Only errors and critical to stderr @@ -77,7 +113,7 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool # Lesser to stdout stdout_handler = logging.StreamHandler(sys.stdout) - stdout_handler.setFormatter(logging.Formatter("%(message)s")) + stdout_handler.setFormatter(formatter) stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR) logger.addHandler(stdout_handler) diff --git a/main.py b/main.py index f23074942..26d523c30 100644 --- a/main.py +++ b/main.py @@ -344,9 +344,9 @@ def prompt_worker(q, server_instance): # Log Time in a more readable way after 10 minutes if execution_time > 600: execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time)) - logging.info(f"Prompt executed in {execution_time}") + logging.info(f"Prompt executed in {execution_time}", extra={'color': 'green'}) else: - logging.info("Prompt executed in {:.2f} seconds".format(execution_time)) + logging.info("Prompt executed in {:.2f} seconds".format(execution_time), extra={'color': 'green'}) if not asset_seeder.is_disabled(): paths = _collect_output_absolute_paths(e.history_result) From 0077d78cbfb44eee4d8dabc27c84f8b1ab7e2852 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Sun, 24 May 2026 20:01:34 -0700 Subject: [PATCH 44/45] Save Image advanced node (CORE-32) (#13850) --- comfy_extras/nodes_images.py | 408 +++++++++++++++++++++++++++++++++++ 1 file changed, 408 insertions(+) diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 4856346d7..33933229d 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -3,15 +3,23 @@ from __future__ import annotations import nodes import folder_paths +import av import json + import os import re import math +import numpy as np +import struct import torch + +import zlib import comfy.utils +from fractions import Fraction from server import PromptServer from comfy_api.latest import ComfyExtension, IO, UI +from comfy.cli_args import args from typing_extensions import override SVG = IO.SVG.Type # TODO: temporary solution for backward compatibility, will be removed later. @@ -835,6 +843,405 @@ class ImageMergeTileList(IO.ComfyNode): return IO.NodeOutput(merged_image) +# --------------------------------------------------------------------------- +# Format specifications +# --------------------------------------------------------------------------- + +# Maps (file_format, bit_depth, has_alpha) -> (numpy dtype scale, av pixel format, +# stream pix_fmt). Keeps the encode path declarative instead of branchy. +_FORMAT_SPECS = { + ("png", "8-bit", False): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgb24", "stream_fmt": "rgb24"}, + ("png", "8-bit", True): {"scale": 255.0, "dtype": np.uint8, "frame_fmt": "rgba", "stream_fmt": "rgba"}, + ("png", "16-bit", False): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgb48le", "stream_fmt": "rgb48be"}, + ("png", "16-bit", True): {"scale": 65535.0, "dtype": np.uint16, "frame_fmt": "rgba64le", "stream_fmt": "rgba64be"}, + ("exr", "32-bit float", False): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrpf32le", "stream_fmt": "gbrpf32le"}, + ("exr", "32-bit float", True): {"scale": 1.0, "dtype": np.float32, "frame_fmt": "gbrapf32le", "stream_fmt": "gbrapf32le"}, +} + + +# --------------------------------------------------------------------------- +# Color transforms +# --------------------------------------------------------------------------- + +def srgb_to_linear(t: torch.Tensor) -> torch.Tensor: + """Inverse sRGB EOTF (IEC 61966-2-1). Operates on RGB channels only; + alpha (if present as the 4th channel) is passed through unchanged.""" + if t.shape[-1] == 4: + rgb, alpha = t[..., :3], t[..., 3:] + return torch.cat([srgb_to_linear(rgb), alpha], dim=-1) + + # Piecewise: linear toe below 0.04045, gamma curve above. + low = t / 12.92 + high = ((t.clamp(min=0.0) + 0.055) / 1.055) ** 2.4 + return torch.where(t <= 0.04045, low, high) + + +# HLG OETF constants from BT.2100 Table 5. +_HLG_A = 0.17883277 +_HLG_B = 0.28466892 +_HLG_C = 0.55991072928 # = 0.5 - a*ln(4*a) + + +def hlg_to_linear(t: torch.Tensor) -> torch.Tensor: + """Inverse HLG OETF (BT.2100). Maps a non-linear HLG signal in [0, 1] to + *scene*-linear light in [0, 1]. Per BT.2100 Note 5a, this is the correct + transform when converting HLG to a linear scene-light representation + (rather than display-light, which would also involve the HLG OOTF). + + Operates on RGB channels only; alpha is passed through unchanged.""" + if t.shape[-1] == 4: + rgb, alpha = t[..., :3], t[..., 3:] + return torch.cat([hlg_to_linear(rgb), alpha], dim=-1) + + # Piecewise: sqrt branch below 0.5, log branch above. + # Clamp inside the log branch so negative / out-of-range values don't blow up; + # values above 1.0 are allowed and extrapolate naturally. + low = (t ** 2) / 3.0 + high = (torch.exp((t.clamp(min=_HLG_C) - _HLG_C) / _HLG_A) + _HLG_B) / 12.0 + return torch.where(t <= 0.5, low, high) + + +# --------------------------------------------------------------------------- +# Metadata injection +# --------------------------------------------------------------------------- + +_PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n" + + +def _png_chunk(chunk_type: bytes, data: bytes) -> bytes: + """Build a single PNG chunk: length | type | data | CRC32(type+data).""" + crc = zlib.crc32(chunk_type + data) & 0xFFFFFFFF + return struct.pack(">I", len(data)) + chunk_type + data + struct.pack(">I", crc) + + +def _png_text_chunk(keyword: str, text: str) -> bytes: + """tEXt chunk: latin-1 keyword + NUL + latin-1 text.""" + payload = keyword.encode("latin-1") + b"\x00" + text.encode("latin-1", errors="replace") + return _png_chunk(b"tEXt", payload) + + +def inject_png_metadata(png_bytes: bytes, prompt: dict | None, extra_pnginfo: dict | None) -> bytes: + """Insert ComfyUI prompt/workflow as tEXt chunks right after IHDR.""" + if not png_bytes.startswith(_PNG_SIGNATURE): + return png_bytes + + chunks: list[bytes] = [] + if prompt is not None: + chunks.append(_png_text_chunk("prompt", json.dumps(prompt))) + if extra_pnginfo: + for key, value in extra_pnginfo.items(): + chunks.append(_png_text_chunk(key, json.dumps(value))) + if not chunks: + return png_bytes + + # IHDR is always the first chunk; insert ours immediately after it. + ihdr_length = struct.unpack(">I", png_bytes[8:12])[0] + ihdr_end = 8 + 8 + ihdr_length + 4 # signature + (len+type) + data + crc + return png_bytes[:ihdr_end] + b"".join(chunks) + png_bytes[ihdr_end:] + + +# Standard chromaticities (CIE 1931 xy) for the colorspaces this node writes. +# Each tuple is (Rx, Ry, Gx, Gy, Bx, By, Wx, Wy). All share D65 white point. +_CHROMATICITIES = { + # ITU-R BT.709 / sRGB primaries + "Rec.709": (0.6400, 0.3300, 0.3000, 0.6000, 0.1500, 0.0600, 0.3127, 0.3290), + # ITU-R BT.2020 (UHDTV / wide-gamut HDR) primaries + "Rec.2020": (0.7080, 0.2920, 0.1700, 0.7970, 0.1310, 0.0460, 0.3127, 0.3290), +} + + +def _pack_chromaticities(primaries: tuple) -> bytes: + """Serialize 8 chromaticity floats into the EXR `chromaticities` payload.""" + return struct.pack("<8f", *primaries) + + +def _exr_attribute(name: str, attr_type: str, value: bytes) -> bytes: + """Serialize one EXR header attribute: name\\0 type\\0 size:int32 value.""" + return ( + name.encode("utf-8") + b"\x00" + + attr_type.encode("utf-8") + b"\x00" + + struct.pack(" bytes: + """Insert ComfyUI metadata and color-space info into an EXR header. + + Color: EXR pixels are linear by convention. The standard way to describe + their RGB→XYZ relationship is the `chromaticities` attribute. We pick the + primaries that match what the user told us their input was: + + colorspace="sRGB" → Rec. 709 / sRGB primaries (D65) + colorspace="HDR" → Rec. 2020 / BT.2100 primaries (D65) + + Pixels are always converted to linear scene light upstream (sRGB EOTF + inverse for sRGB; HLG OETF inverse for HDR), so the file content is + scene-linear in the indicated gamut. OpenEXR has no standard transfer- + function attribute (the OpenEXR TSC has discussed adding one but it + doesn't exist), so we don't invent one — `chromaticities` plus the EXR + linear-by-convention rule fully specifies the color. + + Prompt/workflow: written as plain `string` attributes using the same keys + (`prompt`, `workflow`, ...) that Comfy uses for PNG tEXt chunks, so the + same readers can pull them out symmetrically. + + Implementation note: the chunk-offset table that follows the header stores + *absolute* byte offsets into the file. Inserting N bytes into the header + means every offset must be incremented by N or the file becomes unreadable. + """ + if len(exr_bytes) < 8 or exr_bytes[:4] != b"\x76\x2f\x31\x01": + return exr_bytes + + new_blob = b"" + if prompt is not None: + new_blob += _exr_attribute("prompt", "string", json.dumps(prompt).encode("utf-8")) + if extra_pnginfo: + for key, value in extra_pnginfo.items(): + new_blob += _exr_attribute(key, "string", json.dumps(value).encode("utf-8")) + if colorspace is not None: + # Map each colorspace option to the RGB primaries the linear pixels + # are now in. "sRGB" and "linear" both produce Rec. 709 linear; "HDR" + # (HLG-encoded Rec. 2020 input) produces Rec. 2020 linear. + primaries_name = { + "sRGB": "Rec.709", + "linear": "Rec.709", + "HDR": "Rec.2020", + }.get(colorspace, "Rec.709") + new_blob += _exr_attribute( + "chromaticities", + "chromaticities", + _pack_chromaticities(_CHROMATICITIES[primaries_name]), + ) + if not new_blob: + return exr_bytes + + # Walk header attributes to find the terminating null byte, and pick up + # dataWindow + compression so we know how many chunks the offset table has. + pos = 8 # past magic (4) + version (4) + data_window = None + compression = 0 + while pos < len(exr_bytes) and exr_bytes[pos] != 0: + name_end = exr_bytes.index(b"\x00", pos) + attr_name = exr_bytes[pos:name_end].decode("latin-1", errors="replace") + type_end = exr_bytes.index(b"\x00", name_end + 1) + attr_type = exr_bytes[name_end + 1:type_end].decode("latin-1", errors="replace") + size = struct.unpack(" bytes: + """Encode a single HxWxC tensor to PNG or EXR bytes in memory. + + For EXR the input is interpreted according to `colorspace` and converted + to scene-linear (EXR's convention) before writing: + + "sRGB" → input is sRGB-encoded Rec. 709; apply inverse sRGB EOTF. + "HDR" → input is HLG-encoded Rec. 2020 (BT.2100); apply inverse HLG + OETF to get scene-linear, per BT.2100 Note 5a. + "linear" → input is already scene-linear (Rec. 709 primaries); write + through unchanged. Use this for renderer/compositor output. + + For PNG, colorspace selection does not modify pixels — PNG is delivered + sRGB-encoded and there is no PNG path for wide-gamut HDR in this node. + """ + height, width, num_channels = img_tensor.shape + has_alpha = num_channels == 4 + + spec = _FORMAT_SPECS[(file_format, bit_depth, has_alpha)] + + if spec["dtype"] == np.float32: + # EXR path: preserve full range, no clamp. + if colorspace == "sRGB": + img_tensor = srgb_to_linear(img_tensor) + elif colorspace == "HDR": + img_tensor = hlg_to_linear(img_tensor) + img_np = img_tensor.cpu().numpy().astype(np.float32) + else: + # PNG path: quantize to integer range. + scaled = (img_tensor * spec["scale"]).clamp(0, spec["scale"]) + img_np = scaled.to(torch.int32).cpu().numpy().astype(spec["dtype"]) + + # Encode directly via CodecContext. PyAV's `image2` muxer does NOT write to + # BytesIO (it expects a real file path), so we bypass the container entirely. + # For single-frame PNG/EXR the raw codec output IS the file. + codec = av.CodecContext.create(file_format, "w") + codec.width = width + codec.height = height + codec.pix_fmt = spec["stream_fmt"] + codec.time_base = Fraction(1, 1) + + frame = av.VideoFrame.from_ndarray(img_np, format=spec["frame_fmt"]) + if spec["frame_fmt"] != spec["stream_fmt"]: + frame = frame.reformat(format=spec["stream_fmt"]) + frame.pts = 0 + frame.time_base = codec.time_base + + packets = list(codec.encode(frame)) + list(codec.encode(None)) # flush with None + return b"".join(bytes(p) for p in packets) + + +# --------------------------------------------------------------------------- +# Node +# --------------------------------------------------------------------------- + +class SaveImageAdvanced(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveImageAdvanced", + search_aliases=["save", "save image", "export image", "output image", "write image"], + display_name="Save Image (Advanced)", + description="Saves the input images to your ComfyUI output directory.", + category="image", + essentials_category="Basics", + inputs=[ + IO.Image.Input("images", tooltip="The images to save."), + IO.String.Input( + "filename_prefix", + default="ComfyUI", + tooltip=( + "The prefix for the file to save. May include formatting tokens " + "such as %date:yyyy-MM-dd% or %Empty Latent Image.width%." + ), + ), + IO.DynamicCombo.Input( + "format", + options=[ + IO.DynamicCombo.Option("png", [ + IO.Combo.Input("bit_depth", options=["8-bit", "16-bit"], + default="8-bit", advanced=True), + IO.Combo.Input("input_color_space", options=["sRGB"], + default="sRGB", advanced=True), + ]), + IO.DynamicCombo.Option("exr", [ + IO.Combo.Input("bit_depth", options=["32-bit float"], + default="32-bit float", advanced=True), + IO.Combo.Input( + "input_color_space", + options=["sRGB", "HDR", "linear"], + default="sRGB", + advanced=True, + tooltip=( + "Colorspace of the input tensor. The EXR is " + "always written as scene-linear in the matching " + "gamut.\n" + " 'sRGB' — input is sRGB-encoded Rec.709; " + "the inverse sRGB EOTF is applied.\n" + " 'HDR' — input is HLG-encoded Rec.2020 " + "(BT.2100); the inverse HLG OETF is applied " + "to get scene-linear light.\n" + " 'linear' — input is already scene-linear " + "(Rec.709 primaries); written through unchanged. " + "Use this for renderer/compositor output." + ), + ), + ]), + ], + tooltip="The file format in which to save the image.", + ), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, images, filename_prefix: str, format: dict) -> IO.NodeOutput: + file_format = format["format"] + bit_depth = format["bit_depth"] + colorspace = format.get("input_color_space", "sRGB") + + output_dir = folder_paths.get_output_directory() + full_output_folder, filename, counter, subfolder, filename_prefix = ( + folder_paths.get_save_image_path( + filename_prefix, output_dir, images[0].shape[1], images[0].shape[0] + ) + ) + + prompt = cls.hidden.prompt + extra_pnginfo = cls.hidden.extra_pnginfo + write_metadata = not args.disable_metadata + + results = [] + for batch_number, image in enumerate(images): + encoded = _encode_image(image, file_format, bit_depth, colorspace) + + if write_metadata: + if file_format == "png": + encoded = inject_png_metadata(encoded, prompt, extra_pnginfo) + elif file_format == "exr": + encoded = inject_exr_metadata(encoded, prompt, extra_pnginfo, colorspace) + + name = filename.replace("%batch_num%", str(batch_number)) + file = f"{name}_{counter:05}.{file_format}" + with open(os.path.join(full_output_folder, file), "wb") as f: + f.write(encoded) + + results.append({"filename": file, "subfolder": subfolder, "type": "output"}) + counter += 1 + + return IO.NodeOutput(ui={"images": results}) + + class ImagesExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -847,6 +1254,7 @@ class ImagesExtension(ComfyExtension): ImageAddNoise, SaveAnimatedWEBP, SaveAnimatedPNG, + SaveImageAdvanced, SaveSVGNode, ImageStitch, ResizeAndPadImage, From a4141a0f5a90b8a43f834f51d5f4796862965e53 Mon Sep 17 00:00:00 2001 From: "Daxiong (Lin)" Date: Tue, 26 May 2026 01:57:18 +0800 Subject: [PATCH 45/45] chore: update embedded docs to v0.5.1 (#14101) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index a22fa50ad..2ca6d8929 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ comfyui-frontend-package==1.44.19 comfyui-workflow-templates==0.9.82 -comfyui-embedded-docs==0.5.0 +comfyui-embedded-docs==0.5.1 torch torchsde torchvision