diff --git a/.github/workflows/detect-unreviewed-merge.yml b/.github/workflows/detect-unreviewed-merge.yml new file mode 100644 index 000000000..4fabecb94 --- /dev/null +++ b/.github/workflows/detect-unreviewed-merge.yml @@ -0,0 +1,24 @@ +name: Detect Unreviewed Merge + +# SOC 2 compliance — reusable workflow lives in Comfy-Org/github-workflows, +# tracking issues are filed in Comfy-Org/unreviewed-merges. + +on: + push: + branches: [master] + +concurrency: + group: detect-unreviewed-merge-${{ github.sha }} + cancel-in-progress: false + +permissions: + contents: read + pull-requests: read + +jobs: + detect: + uses: Comfy-Org/github-workflows/.github/workflows/detect-unreviewed-merge.yml@4d9cb6b87f953bb7cd69954280e1465fb9bd2040 # v1 + with: + approval-mode: latest-per-reviewer + secrets: + UNREVIEWED_MERGES_TOKEN: ${{ secrets.UNREVIEWED_MERGES_TOKEN }} diff --git a/comfy/float.py b/comfy/float.py index 184b3d6d0..3c82d6359 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -1,5 +1,20 @@ +import logging + import torch +_CK_STOCHASTIC_ROUNDING_AVAILABLE = False +try: + import comfy_kitchen as ck + _ck_stochastic_rounding_fp8 = ck.stochastic_rounding_fp8 + _CK_STOCHASTIC_ROUNDING_AVAILABLE = True +except (AttributeError, ImportError): + logging.warning("comfy_kitchen does not support stochastic FP8 rounding, please update comfy_kitchen.") + +if not _CK_STOCHASTIC_ROUNDING_AVAILABLE: + def _ck_stochastic_rounding_fp8(value, rng, dtype): + raise NotImplementedError("comfy_kitchen does not support stochastic FP8 rounding") + + def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None): mantissa_scaled = torch.where( normal_mask, @@ -57,6 +72,10 @@ def stochastic_rounding(value, dtype, seed=0): if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: generator = torch.Generator(device=value.device) generator.manual_seed(seed) + if _CK_STOCHASTIC_ROUNDING_AVAILABLE: + rng = torch.randint(0, 256, value.size(), dtype=torch.uint8, layout=value.layout, device=value.device, generator=generator) + return _ck_stochastic_rounding_fp8(value, rng, dtype) + output = torch.empty_like(value, dtype=dtype) num_slices = max(1, (value.numel() / (4096 * 4096))) slice_size = max(1, round(value.shape[0] / num_slices)) diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 9a03d964e..15b5497c8 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -804,13 +804,15 @@ class ZImagePixelSpace(ChromaRadiance): """ pass - class HiDreamO1Pixel(ChromaRadiance): """Pixel-space latent format for HiDream-O1. No VAE — model patches/unpatches raw RGB internally with patch_size=32. """ pass +class PixelDiTPixel(ChromaRadiance): + pass + class CogVideoX(LatentFormat): """Latent format for CogVideoX-2b (THUDM/CogVideoX-2b). diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index a6258b755..c28be5b49 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -433,11 +433,11 @@ class Attention(nn.Module): if self.differential: q, q_diff = q.unbind(dim=1) k, k_diff = k.unbind(dim=1) - out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) - out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options) + out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options) + out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options) out = out - out_diff else: - out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) + out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options) out = self.to_out(out) diff --git a/comfy/ldm/audio/vae_sa3.py b/comfy/ldm/audio/vae_sa3.py index 276846444..8be36d6ee 100644 --- a/comfy/ldm/audio/vae_sa3.py +++ b/comfy/ldm/audio/vae_sa3.py @@ -138,11 +138,11 @@ class Attention(nn.Module): k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype) if self.differential: - out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) - - optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True)) + out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False) + - optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)) del q, k, v, q_diff, k_diff else: - out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) + out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False) del q, k, v return self.to_out(out) diff --git a/comfy/ldm/lens/model.py b/comfy/ldm/lens/model.py new file mode 100644 index 000000000..cd5015ddc --- /dev/null +++ b/comfy/ldm/lens/model.py @@ -0,0 +1,510 @@ +"""Lens denoising transformer (DiT)""" + +from __future__ import annotations + +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ldm.flux.layers +import comfy.patcher_extension +from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.flux.math import apply_rope +from comfy.ldm.modules.attention import optimized_attention + + +def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor: + return comfy.ldm.flux.layers.timestep_embedding(t, dim) + + +def _lens_position_ids( + frame: int, height: int, width: int, text_seq_len: int, + scale_rope: bool = True, device=None, +) -> torch.Tensor: + """Lens axial (frame, h, w) position ids for joint image + text sequence. + + With ``scale_rope=True`` h/w are centered around 0 (negative + positive + halves) and text starts at ``max(h//2, w//2)``. Result shape ``[seq, 3]``; + caller adds a batch dim for ``EmbedND``. + """ + if scale_rope: + h_pos = torch.cat([torch.arange(-(height - height // 2), 0, device=device), + torch.arange(0, height // 2, device=device)]) + w_pos = torch.cat([torch.arange(-(width - width // 2), 0, device=device), + torch.arange(0, width // 2, device=device)]) + text_start = max(height // 2, width // 2) + else: + h_pos = torch.arange(height, device=device) + w_pos = torch.arange(width, device=device) + text_start = max(height, width) + + f_pos = torch.arange(frame, device=device) + img_ids = torch.zeros(frame, height, width, 3, device=device) + img_ids[..., 0] = f_pos[:, None, None] + img_ids[..., 1] = h_pos[None, :, None] + img_ids[..., 2] = w_pos[None, None, :] + img_ids = img_ids.reshape(-1, 3) + + # Text positions replicate across all 3 axes (matches original packing). + txt_pos = torch.arange(text_start, text_start + text_seq_len, device=device).float() + txt_ids = txt_pos[:, None].expand(text_seq_len, 3) + + return torch.cat([img_ids, txt_ids], dim=0) + + +class _TimestepEmbedder(nn.Module): + def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None) -> None: + super().__init__() + self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device) + self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear_1(x) + x = F.silu(x) + return self.linear_2(x) + + +class LensTimestepProjEmbeddings(nn.Module): + def __init__(self, embedding_dim: int, dtype=None, device=None, operations=None) -> None: + super().__init__() + self.timestep_embedder = _TimestepEmbedder(256, embedding_dim, dtype=dtype, device=device, operations=operations) + + def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor: + proj = _lens_time_proj(timestep, 256) + return self.timestep_embedder(proj.to(dtype=hidden_states.dtype)) + + +class GateMLP(nn.Module): + """SwiGLU MLP.""" + + def __init__(self, dim: int, hidden_dim: int, dtype=None, device=None, operations=None) -> None: + super().__init__() + self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device) + self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + return self.w2(F.silu(self.w1(x), inplace=True).mul_(self.w3(x))) + + +class LensJointAttention(nn.Module): + """Joint image+text attention with fused QKV per stream.""" + + def __init__( + self, + query_dim: int, + added_kv_proj_dim: int, + dim_head: int = 64, + heads: int = 8, + out_dim: Optional[int] = None, + eps: float = 1e-5, + dtype=None, + device=None, + operations=None, + ) -> None: + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.heads = self.inner_dim // dim_head + self.dim_head = dim_head + self.out_dim = out_dim if out_dim is not None else query_dim + + self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device) + + self.img_qkv = operations.Linear(query_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device) + self.txt_qkv = operations.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device) + + # ModuleList([Linear, Identity]) for state-dict key compatibility. + self.to_out = nn.ModuleList([ + operations.Linear(self.inner_dim, self.out_dim, bias=True, dtype=dtype, device=device), + nn.Identity(), + ]) + self.to_add_out = operations.Linear(self.inner_dim, query_dim, bias=True, dtype=dtype, device=device) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + freqs_cis: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + transformer_options: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + bsz, seq_img, _ = hidden_states.shape + seq_txt = encoder_hidden_states.shape[1] + + # image stream + img_qkv = self.img_qkv(hidden_states).view(bsz, seq_img, 3, self.heads, self.dim_head) + img_q, img_k, img_v = img_qkv.unbind(dim=2) + img_q = self.norm_q(img_q) + img_k = self.norm_k(img_k) + del img_qkv + + # text stream + txt_qkv = self.txt_qkv(encoder_hidden_states).view(bsz, seq_txt, 3, self.heads, self.dim_head) + txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2) + txt_q = self.norm_added_q(txt_q) + txt_k = self.norm_added_k(txt_k) + + # [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks + q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2) + del img_q, txt_q + k = torch.cat([img_k, txt_k], dim=1).transpose(1, 2) + del img_k, txt_k + v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2) + del img_v, txt_v + + q, k = apply_rope(q, k, freqs_cis) + + if attention_mask is not None: + expected = (bsz, 1, 1, seq_img + seq_txt) + if attention_mask.shape != expected: + raise ValueError( + f"attention_mask must be {expected}, got {tuple(attention_mask.shape)}" + ) + attention_mask = attention_mask.to(q.dtype) + + out = optimized_attention( + q, k, v, self.heads, mask=attention_mask, skip_reshape=True, + transformer_options=transformer_options, + ) + + img_out = self.to_out[1](self.to_out[0](out[:, :seq_img, :])) + txt_out = self.to_add_out(out[:, seq_img:, :]) + return img_out, txt_out + + +class LensTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + eps: float = 1e-6, + rms_norm: bool = True, + dtype=None, + device=None, + operations=None, + ) -> None: + super().__init__() + + self.attn = LensJointAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + eps=1e-5, + dtype=dtype, + device=device, + operations=operations, + ) + + if rms_norm: + NormCls = operations.RMSNorm + norm_kwargs = {} + else: + NormCls = operations.LayerNorm + norm_kwargs = {"elementwise_affine": False} + + mlp_hidden = int(dim / 3 * 8) + + # Sequential(SiLU, Linear) so state-dict lands at img_mod.1.{weight,bias}. + self.img_mod = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device), + ) + self.img_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs) + self.img_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs) + self.img_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations) + + self.txt_mod = nn.Sequential( + nn.SiLU(), + operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device), + ) + self.txt_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs) + self.txt_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs) + self.txt_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations) + + @staticmethod + def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + shift, scale, gate = mod_params.chunk(3, dim=-1) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + freqs_cis: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + transformer_options: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + img_mod1, img_mod2 = self.img_mod(temb).chunk(2, dim=-1) + txt_mod1, txt_mod2 = self.txt_mod(temb).chunk(2, dim=-1) + + img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1) + txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1) + + img_attn, txt_attn = self.attn( + hidden_states=img_modulated, + encoder_hidden_states=txt_modulated, + freqs_cis=freqs_cis, + attention_mask=attention_mask, + transformer_options=transformer_options, + ) + + hidden_states = hidden_states + img_gate1 * img_attn + encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn + + img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2) + hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2) + + txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2) + encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2) + + return encoder_hidden_states, hidden_states + + +class _AdaLayerNormContinuousNoAffine(nn.Module): + """AdaLayerNormContinuous(elementwise_affine=False). + + The reference uses ``scale, shift = chunk(2)`` (scale first) — opposite + to Flux's ``LastLayer``. + """ + + def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6, + dtype=None, device=None, operations=None) -> None: + super().__init__() + self.linear = operations.Linear( + conditioning_embedding_dim, embedding_dim * 2, bias=True, dtype=dtype, device=device + ) + self.eps = eps + self.embedding_dim = embedding_dim + + def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor: + emb = self.linear(F.silu(conditioning)) + scale, shift = torch.chunk(emb, 2, dim=-1) + x = F.layer_norm(x, (self.embedding_dim,), None, None, self.eps) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class LensTransformer2DModel(nn.Module): + """Lens dual-stream MMDiT (48 blocks, inner_dim=1536, multi-layer text).""" + + def __init__( + self, + patch_size: int = 2, + in_channels: int = 128, + out_channels: Optional[int] = 32, + num_layers: int = 48, + attention_head_dim: int = 64, + num_attention_heads: int = 24, + enc_hidden_dim: int = 2880, + axes_dims_rope: Tuple[int, int, int] = (8, 28, 28), + rms_norm: bool = True, + multi_layer_encoder_feature: bool = True, + selected_layer_index: Tuple[int, ...] = (5, 11, 17, 23), + image_model=None, # unused; accepted for detection-side configs. + dtype=None, + device=None, + operations=None, + ) -> None: + super().__init__() + self.patch_size = patch_size + self.in_channels = in_channels + self.out_channels = out_channels if out_channels is not None else in_channels + self.inner_dim = num_attention_heads * attention_head_dim + self.multi_layer_encoder_feature = multi_layer_encoder_feature + self.selected_layer_index = list(selected_layer_index) + self.dtype = dtype + + self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope)) + self.time_text_embed = LensTimestepProjEmbeddings( + embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations + ) + + if self.multi_layer_encoder_feature: + self.txt_norm = nn.ModuleList( + [operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device) + for _ in self.selected_layer_index] + ) + self.txt_in = operations.Linear( + enc_hidden_dim * len(self.selected_layer_index), + self.inner_dim, bias=True, dtype=dtype, device=device, + ) + else: + self.txt_norm = operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device) + self.txt_in = operations.Linear(enc_hidden_dim, self.inner_dim, bias=True, dtype=dtype, device=device) + + self.img_in = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device) + + self.transformer_blocks = nn.ModuleList([ + LensTransformerBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + eps=1e-6, + rms_norm=rms_norm, + dtype=dtype, device=device, operations=operations, + ) + for _ in range(num_layers) + ]) + + self.norm_out = _AdaLayerNormContinuousNoAffine( + self.inner_dim, self.inner_dim, eps=1e-6, + dtype=dtype, device=device, operations=operations, + ) + self.proj_out = operations.Linear( + self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, + dtype=dtype, device=device, + ) + + def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, + transformer_options: Optional[Dict[str, Any]] = None, **kwargs) -> torch.Tensor: + if transformer_options is None: + transformer_options = {} + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options), + ).execute(x, timestep, context, attention_mask, transformer_options, **kwargs) + + def _forward( + self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + transformer_options: Optional[Dict[str, Any]] = None, + control: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> torch.Tensor: + """ComfyUI bridge: ``(x[B,128,h,w], t[B], context[B,S,L*H], mask[B,S])``.""" + if transformer_options is None: + transformer_options = {} + transformer_options = transformer_options.copy() + patches = transformer_options.get("patches", {}) + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + + B, C, h, w = x.shape + hidden_states = x.permute(0, 2, 3, 1).reshape(B, h * w, C) + + if self.multi_layer_encoder_feature: + L = len(self.selected_layer_index) + enc_dim = context.shape[-1] // L + encoder_hidden_states = list( + context.reshape(B, -1, L, enc_dim).unbind(dim=2) + ) + text_seq_len = encoder_hidden_states[0].shape[1] + else: + encoder_hidden_states = context + text_seq_len = context.shape[1] + + if attention_mask is None: + attention_mask = torch.ones( + (B, text_seq_len), dtype=torch.bool, device=x.device + ) + + img_len = h * w + joint_mask = self._build_joint_attention_mask(attention_mask, img_len) + + hidden_states = self.img_in(hidden_states) + timestep = timestep.to(hidden_states.dtype) + + if self.multi_layer_encoder_feature: + normed = [self.txt_norm[i](encoder_hidden_states[i]) for i in range(L)] + encoder_hidden_states = torch.cat(normed, dim=-1) + else: + encoder_hidden_states = self.txt_norm(encoder_hidden_states) + encoder_hidden_states = self.txt_in(encoder_hidden_states) + + if "post_input" in patches: + for p in patches["post_input"]: + out = p({ + "img": hidden_states, + "txt": encoder_hidden_states, + "transformer_options": transformer_options, + }) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + + temb = self.time_text_embed(timestep, hidden_states) + ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0) + freqs_cis = self.pos_embed(ids) + + transformer_options["total_blocks"] = len(self.transformer_blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.transformer_blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["txt"], out["img"] = block( + hidden_states=args["img"], + encoder_hidden_states=args["txt"], + temb=args["vec"], + freqs_cis=args["pe"], + attention_mask=args.get("attn_mask"), + transformer_options=args.get("transformer_options"), + ) + return out + out = blocks_replace[("double_block", i)]( + { + "img": hidden_states, + "txt": encoder_hidden_states, + "vec": temb, + "pe": freqs_cis, + "attn_mask": joint_mask, + "transformer_options": transformer_options, + }, + {"original_block": block_wrap}, + ) + encoder_hidden_states = out["txt"] + hidden_states = out["img"] + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + freqs_cis=freqs_cis, + attention_mask=joint_mask, + transformer_options=transformer_options, + ) + + if "double_block" in patches: + for p in patches["double_block"]: + out = p({ + "img": hidden_states, + "txt": encoder_hidden_states, + "x": x, + "block_index": i, + "transformer_options": transformer_options, + }) + hidden_states = out["img"] + encoder_hidden_states = out["txt"] + + if control is not None: + control_i = control.get("input") + if control_i is not None and i < len(control_i): + add = control_i[i] + if add is not None: + hidden_states[:, :add.shape[1]] += add + + hidden_states = self.norm_out(hidden_states, temb) + out = self.proj_out(hidden_states) + return out.reshape(B, h, w, C).permute(0, 3, 1, 2).contiguous() + + @staticmethod + def _build_joint_attention_mask(text_mask: torch.Tensor, img_len: int) -> torch.Tensor: + if text_mask.dtype != torch.bool: + text_mask = text_mask.bool() + bsz = text_mask.shape[0] + img_ones = torch.ones((bsz, img_len), dtype=torch.bool, device=text_mask.device) + joint = torch.cat([img_ones, text_mask], dim=1) + additive = torch.zeros_like(joint, dtype=torch.float32) + additive.masked_fill_(~joint, torch.finfo(torch.float32).min) + return additive[:, None, None, :] diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 0dc8fe789..9ab3c463c 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module): Embeds scalar timesteps into vector representations. """ - def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None): + def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000): super().__init__() if output_size is None: output_size = hidden_size @@ -221,9 +221,10 @@ class TimestepEmbedder(nn.Module): operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device), ) self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period def forward(self, t, dtype, **kwargs): - t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype) + t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype) t_emb = self.mlp(t_freq) return t_emb diff --git a/comfy/ldm/pixeldit/model.py b/comfy/ldm/pixeldit/model.py new file mode 100644 index 000000000..b044b9b29 --- /dev/null +++ b/comfy/ldm/pixeldit/model.py @@ -0,0 +1,239 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ldm.common_dit +import comfy.patcher_extension +from comfy.ldm.flux.math import apply_rope, rope +from comfy.ldm.hidream.model import FeedForwardSwiGLU +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder + +from .modules import ( + FinalLayer, + PatchTokenEmbedder, + PiTBlock, + PixelTokenEmbedder, + apply_adaln_, + precompute_freqs_cis_2d, +) + + +class MMDiTJointAttention(nn.Module): + """Joint MMDiT attention with separate Q/K/V/proj for image and text streams. + + RoPE is applied to each stream before concatenation so each stream uses its own + 2D/1D positional encoding. Concat order is [text, image] (text first). + """ + def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None): + super().__init__() + assert dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + + self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + + self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device) + self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device) + + def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}): + B, Nx, _ = x.shape + _, Ny, _ = y.shape + H = self.num_heads + D = self.head_dim + + qkv_x = self.qkv_x(x).reshape(B, Nx, 3, H, D).permute(2, 0, 3, 1, 4) + qx, kx, vx = qkv_x.unbind(0) + qx = self.q_norm_x(qx) + kx = self.k_norm_x(kx) + + qkv_y = self.qkv_y(y).reshape(B, Ny, 3, H, D).permute(2, 0, 3, 1, 4) + qy, ky, vy = qkv_y.unbind(0) + qy = self.q_norm_y(qy) + ky = self.k_norm_y(ky) + + qx, kx = apply_rope(qx, kx, pos_img[None, None]) + if pos_txt is not None: + qy, ky = apply_rope(qy, ky, pos_txt[None, None]) + + q_joint = torch.cat([qy, qx], dim=2) + k_joint = torch.cat([ky, kx], dim=2) + v_joint = torch.cat([vy, vx], dim=2) + + out_joint = optimized_attention( + q_joint, k_joint, v_joint, H, + mask=attn_mask, skip_reshape=True, skip_output_reshape=True, + transformer_options=transformer_options, + ) + + out_y = out_joint[:, :, :Ny, :].transpose(1, 2).reshape(B, Ny, H * D) + out_x = out_joint[:, :, Ny:, :].transpose(1, 2).reshape(B, Nx, H * D) + + return self.proj_x(out_x), self.proj_y(out_y) + + +class MMDiTBlockT2I(nn.Module): + def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None): + super().__init__() + self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False, dtype=dtype, device=device, operations=operations) + self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.mlp_x = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations) + self.mlp_y = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations) + self.adaLN_modulation_img = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)) + self.adaLN_modulation_txt = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)) + + def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}): + shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1) + shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1) + + x_norm = apply_adaln_(self.norm_x1(x), shift_msa_x, scale_msa_x) + y_norm = apply_adaln_(self.norm_y1(y), shift_msa_y, scale_msa_y) + attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options) + x = torch.addcmul(x, gate_msa_x, attn_x) + y = torch.addcmul(y, gate_msa_y, attn_y) + + x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln_(self.norm_x2(x), shift_mlp_x, scale_mlp_x))) + y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln_(self.norm_y2(y), shift_mlp_y, scale_mlp_y))) + return x, y + + +class PixDiT_T2I(nn.Module): + """PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint + (also runs at 512px when fed the appropriate latent size and flow_shift). + + Forward: + x: [B, 3, H, W] pixel-space input (no VAE) + timesteps:[B] in [0, 1000] (ComfyUI flow sampling convention) + context: [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended) + Returns flow-matching velocity [B, 3, H, W]. + """ + def __init__( + self, + in_channels=3, + num_groups=24, + hidden_size=1536, + pixel_hidden_size=16, + pixel_attn_hidden_size=1152, + pixel_num_groups=16, + patch_depth=14, + pixel_depth=2, + patch_size=16, + txt_embed_dim=2304, + txt_max_length=300, + use_text_rope=True, + text_rope_theta=10000.0, + image_model=None, + dtype=None, + device=None, + operations=None, + pixel_mlp_chunks=2, + ): + super().__init__() + self.dtype = dtype + self.in_channels = in_channels + self.out_channels = in_channels + self.hidden_size = hidden_size + self.num_groups = num_groups + self.patch_depth = patch_depth + self.pixel_depth = pixel_depth + self.patch_size = patch_size + self.pixel_hidden_size = pixel_hidden_size + self.pixel_attn_hidden_size = pixel_attn_hidden_size + self.pixel_num_groups = pixel_num_groups + self.txt_embed_dim = txt_embed_dim + self.txt_max_length = txt_max_length + self.use_text_rope = use_text_rope + self.text_rope_theta = text_rope_theta + + self.pixel_embedder = PixelTokenEmbedder(self.in_channels, self.pixel_hidden_size, dtype=dtype, device=device, operations=operations) + self.s_embedder = PatchTokenEmbedder(self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True, dtype=dtype, device=device, operations=operations) + self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10) + self.y_embedder = PatchTokenEmbedder(self.txt_embed_dim, self.hidden_size, bias=True, use_norm=True, dtype=dtype, device=device, operations=operations) + self.y_pos_embedding = nn.Parameter(torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device)) + + self.patch_blocks = nn.ModuleList([ + MMDiTBlockT2I(self.hidden_size, self.num_groups, + dtype=dtype, device=device, operations=operations) + for _ in range(self.patch_depth) + ]) + self.pixel_blocks = nn.ModuleList([ + PiTBlock( + self.pixel_hidden_size, + self.hidden_size, + patch_size=self.patch_size, + num_heads=self.num_groups, + attn_hidden_size=self.pixel_attn_hidden_size, + attn_num_heads=self.pixel_num_groups, + dtype=dtype, device=device, operations=operations, + mlp_chunks=pixel_mlp_chunks, + ) + for _ in range(self.pixel_depth) + ]) + + self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels, dtype=dtype, device=device, operations=operations) + + def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts): + return precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width, device=device, dtype=dtype, **rope_opts) + + def _fetch_text_pos(self, length, device, dtype): + return rope(torch.arange(length, dtype=torch.float32, device=device).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0).to(dtype=dtype) + + def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options), + ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs) + + def _pre_patch_block(self, s, i, **kwargs): + """Hook for subclasses to inject per-block state into the patch stream (e.g. PiD's LQ gate).""" + return s + + def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs): + H_orig, W_orig = x.shape[2], x.shape[3] + x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) + B, _, H, W = x.shape + Hs = H // self.patch_size + Ws = W // self.patch_size + L = Hs * Ws + + pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {})) + x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2) + + t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size) + + if context is None or context.dim() != 3: + raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]") + Ltxt = min(context.shape[1], self.txt_max_length) + y = context[:, :Ltxt, :] + y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size) + y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb) # y_pos_embedding is a raw nn.Parameter + + condition = F.silu(t_emb) + pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None + + s = self.s_embedder(x_patches) + for i, blk in enumerate(self.patch_blocks): + s = self._pre_patch_block(s, i, **kwargs) + s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options) + s = F.silu(t_emb + s) + + s_cond = s.view(B * L, self.hidden_size) + x_pixels = self.pixel_embedder(x, patch_size=self.patch_size) + for blk in self.pixel_blocks: + x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options) + + x_pixels = self.final_layer(x_pixels) + C_out = self.out_channels + P2 = self.patch_size * self.patch_size + x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).reshape(B, C_out * P2, L) + out = F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size) + return out[:, :, :H_orig, :W_orig] diff --git a/comfy/ldm/pixeldit/modules.py b/comfy/ldm/pixeldit/modules.py new file mode 100644 index 000000000..4b1e538c7 --- /dev/null +++ b/comfy/ldm/pixeldit/modules.py @@ -0,0 +1,187 @@ +import torch +import torch.nn as nn + +from comfy.ldm.flux.math import apply_rope, rope +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, get_1d_sincos_pos_embed_from_grid_torch + + +def apply_adaln_(x, shift, scale): + return x.addcmul_(x, scale).add_(shift) + + +def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0, + ref_grid_h=None, ref_grid_w=None, + scale_x=1.0, scale_y=1.0, shift_x=0.0, shift_y=0.0, + device=None, dtype=torch.float32, **kwargs): + """2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim. + + rope_options: + scale_x / scale_y multiply the position range (RoPE extrapolation). + shift_x / shift_y offset the position origin (tiled / regional inference). + With ref_grid_h/w set, also applies NTK-aware per-axis theta scaling + (rope_mode='ntk_aware'): theta_axis = theta * (current/ref)^(dim_axis/(dim_axis-2)). + Returns Flux-format rotation matrices of shape [H*W, dim/2, 2, 2]. + Layout of head-dim pairs: [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}]. + """ + dim_axis = dim // 2 + if ref_grid_h is not None and dim_axis > 2: + h_ntk = (height / ref_grid_h) ** (dim_axis / (dim_axis - 2)) + w_ntk = (width / ref_grid_w) ** (dim_axis / (dim_axis - 2)) + else: + h_ntk = w_ntk = 1.0 + + x_lin = torch.linspace(shift_x, scale * scale_x + shift_x, width, device=device) + y_lin = torch.linspace(shift_y, scale * scale_y + shift_y, height, device=device) + y_grid, x_grid = torch.meshgrid(y_lin, x_lin, indexing="ij") + x_rope = rope(x_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0) + y_rope = rope(y_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0) + out = torch.stack([x_rope, y_rope], dim=2).reshape(height * width, dim // 2, 2, 2) + return out.to(dtype=dtype) + + +def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32): + """Standard 2D sin/cos absolute positional embedding (ViT-style). + + first half encodes W-coordinates, second half H. + """ + assert embed_dim % 4 == 0 + grid_h = torch.arange(height, dtype=torch.float32, device=device) + grid_w = torch.arange(width, dtype=torch.float32, device=device) + grid_y, grid_x = torch.meshgrid(grid_h, grid_w, indexing="ij") + emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_x.reshape(-1), device=device) + emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_y.reshape(-1), device=device) + return torch.cat([emb_w, emb_h], dim=1).to(dtype=dtype) + + +class RotaryAttention(nn.Module): + """Single-stream self-attention with rotary positional encoding (used inside PiTBlock).""" + def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None): + super().__init__() + assert dim % num_heads == 0 + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device) + self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) + + def forward(self, x, pos, mask=None, transformer_options={}): + B, N, C = x.shape + H = self.num_heads + D = self.head_dim + qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = apply_rope(self.q_norm(q), self.k_norm(k), pos[None, None]) + x = optimized_attention(q, k, v, H, mask=mask, skip_reshape=True, transformer_options=transformer_options) + return self.proj(x) + + +class FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device) + + def forward(self, x): + return self.linear(self.norm(x)) + + +class PatchTokenEmbedder(nn.Module): + """Linear projection used both for patchified-image tokens and text-feature tokens.""" + def __init__(self, in_chans, embed_dim, use_norm=False, bias=True, dtype=None, device=None, operations=None): + super().__init__() + self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device) + self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device) if use_norm else nn.Identity() + + def forward(self, x): + return self.norm(self.proj(x)) + + +class PixelTokenEmbedder(nn.Module): + """Pixel-level embedder: lifts each RGB pixel to hidden_size and packs into per-patch sequences.""" + def __init__(self, in_channels, hidden_size_output, dtype=None, device=None, operations=None): + super().__init__() + self.in_channels = in_channels + self.hidden_size_output = hidden_size_output + self.proj = operations.Linear(self.in_channels, self.hidden_size_output, bias=True, dtype=dtype, device=device) + + def forward(self, inputs, patch_size): + B, _, H, W = inputs.shape + Hs, Ws = H // patch_size, W // patch_size + P2 = patch_size * patch_size + x = inputs.permute(0, 2, 3, 1).contiguous() + x = self.proj(x) + pos_full = get_2d_sincos_pos_embed(self.hidden_size_output, H, W, device=x.device, dtype=x.dtype).view(H, W, self.hidden_size_output) + x = x + pos_full.unsqueeze(0) + x = x.view(B, Hs, patch_size, Ws, patch_size, self.hidden_size_output) + return x.permute(0, 1, 3, 2, 4, 5).reshape(B * Hs * Ws, P2, self.hidden_size_output) + + +class PiTBlock(nn.Module): + """Pixel-level transformer block. + + Compresses each patch's P^2 pixel tokens → 1 attention token via a linear, + runs global self-attention across patches with 2D RoPE, then expands back to P^2 tokens. + Conditioning is per-pixel adaLN from the patch-level features. + """ + def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0, + attn_hidden_size=None, attn_num_heads=None, dtype=None, device=None, operations=None, mlp_chunks=1): + super().__init__() + self.pixel_dim = pixel_hidden_size + self.context_dim = patch_hidden_size + self.attn_dim = attn_hidden_size if attn_hidden_size is not None else patch_hidden_size + self.num_heads = attn_num_heads if attn_num_heads is not None else num_heads + assert self.attn_dim % self.num_heads == 0 + + p2 = patch_size * patch_size + self.compress_to_attn = operations.Linear(p2 * self.pixel_dim, self.attn_dim, bias=True, dtype=dtype, device=device) + self.expand_from_attn = operations.Linear(self.attn_dim, p2 * self.pixel_dim, bias=True, dtype=dtype, device=device) + + self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device) + self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False, dtype=dtype, device=device, operations=operations) + self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device) + self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio), dtype=dtype, device=device, operations=operations) + + self.adaLN_modulation_msa = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device) + self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device) + + self._rope_fn = precompute_freqs_cis_2d + self.mlp_chunks = max(1, int(mlp_chunks)) + + def _fetch_pos(self, height, width, device, dtype, **rope_opts): + return self._rope_fn(self.attn_dim // self.num_heads, height, width, device=device, dtype=dtype, **rope_opts) + + def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}): + BL, P2, _ = x.shape + Hs, Ws = image_height // patch_size, image_width // patch_size + L = Hs * Ws + B = BL // L + + # Attention path uses only msa params; compute, use, free before mlp params allocate. + msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim) + shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1) + + x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa) + x_flat = x_norm.view(BL, P2 * self.pixel_dim) + + x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim) + pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {})) + attn_out = self.attn(x_comp, pos_comp, mask=mask, transformer_options=transformer_options) + attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim)) + attn_exp = attn_flat.view(BL, P2, self.pixel_dim) + x = torch.addcmul(x, gate_msa, attn_exp) + del msa_params, shift_msa, scale_msa, gate_msa + + mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim) + shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1) + gate_mlp = gate_mlp.contiguous() # detach from mlp_params so the del below frees shift+scale storage before the MLP + mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp) + del mlp_params, shift_mlp, scale_mlp + + # MLP in chunks since the peak memory usage is huge here + chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks + for s in range(0, BL, chunk_size): + e = min(s + chunk_size, BL) + x[s:e].addcmul_(gate_mlp[s:e], self.mlp(mlp_input[s:e])) + return x diff --git a/comfy/ldm/pixeldit/pid.py b/comfy/ldm/pixeldit/pid.py new file mode 100644 index 000000000..21b73907a --- /dev/null +++ b/comfy/ldm/pixeldit/pid.py @@ -0,0 +1,227 @@ +"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent +directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I +body + LQ projection branch injected before each MMDiT patch block. +""" + +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .model import PixDiT_T2I +from .modules import precompute_freqs_cis_2d + + +class SigmaAwareGatePerTokenPerDim(nn.Module): + """gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq. + + Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1. + """ + + def __init__(self, dim: int, dtype=None, device=None, operations=None): + super().__init__() + self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device) + self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device)) + + def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + content_logit = self.content_proj(torch.cat([x, lq], dim=-1)) + # log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM. + log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32) + sigma_offset = -log_alpha.exp() * sigma.float().view(-1, 1, 1) + gate = torch.sigmoid(content_logit + sigma_offset) + return x + (gate * lq).to(x.dtype) + + +class ResBlock(nn.Module): + """Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip.""" + + def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None): + super().__init__() + self.block = nn.Sequential( + operations.GroupNorm(num_groups, channels, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device), + operations.GroupNorm(num_groups, channels, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.block(x) + + +class LQProjection2D(nn.Module): + """LQ latent -> per-block patch-aligned features for controlnet-style injection.""" + + def __init__( + self, + latent_channels: int, + hidden_dim: int = 512, + out_dim: int = 1536, + patch_size: int = 16, + sr_scale: int = 4, + latent_spatial_down_factor: int = 8, + num_res_blocks: int = 4, + num_outputs: int = 7, + interval: int = 2, + dtype=None, device=None, operations=None, + ): + super().__init__() + self.latent_channels = latent_channels + self.hidden_dim = hidden_dim + self.out_dim = out_dim + self.patch_size = patch_size + self.sr_scale = sr_scale + self.latent_spatial_down_factor = latent_spatial_down_factor + self.num_outputs = num_outputs + self.interval = interval + + z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size + self.z_to_patch_ratio = z_to_patch_ratio + if z_to_patch_ratio >= 1: + self.latent_fold_factor = 0 + latent_proj_in_ch = latent_channels + else: + fold_factor = int(1 / z_to_patch_ratio) + assert fold_factor * z_to_patch_ratio == 1.0 + self.latent_fold_factor = fold_factor + latent_proj_in_ch = latent_channels * fold_factor * fold_factor + + layers = [ + operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device), + nn.SiLU(), + operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device), + ] + for _ in range(num_res_blocks): + layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations)) + self.latent_proj = nn.Sequential(*layers) + + self.output_heads = nn.ModuleList( + [operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)] + ) + self.gate_modules = nn.ModuleList( + [SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations) + for _ in range(num_outputs)] + ) + + def is_gate_active(self, block_idx: int) -> bool: + return block_idx % self.interval == 0 + + def output_index(self, block_idx: int) -> int: + return block_idx // self.interval + + def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor: + return self.gate_modules[out_idx](x, lq_feature, sigma) + + def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor: + B, z_dim = lq_latent.shape[:2] + if self.z_to_patch_ratio >= 1: + if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW: + z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest") + else: + z_aligned = lq_latent + else: + f = self.latent_fold_factor + zH_expected, zW_expected = pH * f, pW * f + if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected: + lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest") + z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4) + z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW) + return self.latent_proj(z_aligned) + + def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]: + feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW) + B, C, H, W = feat.shape + tokens = feat.permute(0, 2, 3, 1).contiguous().view(B, H * W, C) + return [head(tokens) for head in self.output_heads] + + +class PidNet(PixDiT_T2I): + """PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block).""" + + def __init__( + self, + lq_latent_channels: int = 16, + lq_hidden_dim: int = 512, + lq_num_res_blocks: int = 4, + lq_interval: int = 2, + sr_scale: int = 4, + latent_spatial_down_factor: int = 8, + rope_ref_h: int = 1024, # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64. + rope_ref_w: int = 1024, + image_model=None, + dtype=None, device=None, operations=None, + **pixdit_kwargs, + ): + super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs) + + self.rope_ref_grid_h = rope_ref_h // self.patch_size + self.rope_ref_grid_w = rope_ref_w // self.patch_size + + # Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware. + def _pit_rope_fn(head_dim, h, w, device=None, dtype=torch.float32, **rope_opts): + return precompute_freqs_cis_2d(head_dim, h, w, ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, device=device, dtype=dtype, **rope_opts) + for blk in self.pixel_blocks: + blk._rope_fn = _pit_rope_fn + + num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval + self.lq_proj = LQProjection2D( + latent_channels=lq_latent_channels, + hidden_dim=lq_hidden_dim, + out_dim=self.hidden_size, + patch_size=self.patch_size, + sr_scale=sr_scale, + latent_spatial_down_factor=latent_spatial_down_factor, + num_res_blocks=lq_num_res_blocks, + num_outputs=num_lq_outputs, + interval=lq_interval, + dtype=dtype, + device=device, + operations=operations, + ) + + def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts): + return precompute_freqs_cis_2d( + self.hidden_size // self.num_groups, + height, width, + ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, + device=device, dtype=dtype, **rope_opts, + ) + + def _pre_patch_block(self, s, i, pid_lq_features, pid_degrade_sigma, **kwargs): + if not self.lq_proj.is_gate_active(i): + return s + out_idx = self.lq_proj.output_index(i) + if out_idx >= len(pid_lq_features): + return s + return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx) + + def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs): + if lq_latent is None: + raise ValueError("PidNet requires lq_latent — attach via PiDConditioning") + expected_c = self.lq_proj.latent_channels + if lq_latent.shape[1] != expected_c: + raise ValueError( + f"Input latent has {lq_latent.shape[1]} channels, this model variant expects {expected_c}. " + f"Flux1/SD3 = 16 channels, Flux2 = 128 channels." + ) + B = x.shape[0] + # Match the backbone's pad_to_patch_size (round up) so the LQ grid lines up with the patch stream. + Hs = -(-x.shape[2] // self.patch_size) + Ws = -(-x.shape[3] // self.patch_size) + + degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1) + if degrade_sigma.numel() == 1 and B > 1: + degrade_sigma = degrade_sigma.expand(B).contiguous() + + lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws) + + return super()._forward( + x, timesteps, + context=context, attention_mask=attention_mask, + transformer_options=transformer_options, + pid_lq_features=lq_features, + pid_degrade_sigma=degrade_sigma, + **kwargs, + ) diff --git a/comfy/model_base.py b/comfy/model_base.py index 526ea9b48..032eec96a 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -35,6 +35,7 @@ import comfy.ldm.hydit.models import comfy.ldm.audio.dit import comfy.ldm.audio.embedders import comfy.ldm.flux.model +import comfy.ldm.lens.model import comfy.ldm.lightricks.model import comfy.ldm.hunyuan_video.model import comfy.ldm.cosmos.model @@ -48,6 +49,8 @@ import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model import comfy.ldm.chroma_radiance.model +import comfy.ldm.pixeldit.model +import comfy.ldm.pixeldit.pid import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.seedvr.model @@ -1070,6 +1073,27 @@ class Flux2(Flux): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out + +class Lens(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__( + model_config, model_type, device=device, + unet_model=comfy.ldm.lens.model.LensTransformer2DModel, + ) + + def encode_adm(self, **kwargs): + return None # Lens has no pooled/ADM conditioning. + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + return out + class GenmoMochi(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint) @@ -1387,6 +1411,53 @@ class ZImagePixelSpace(Lumina2): BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace) self.memory_usage_factor_conds = ("ref_latents",) + +class PixelDiTT2I(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out["attention_mask"] = comfy.conds.CONDRegular(attention_mask) + return out + + +class PiD(PixelDiTT2I): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + BaseModel.__init__(self, model_config, model_type, device=device, + unet_model=comfy.ldm.pixeldit.pid.PidNet) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + lq_latent = kwargs.get("lq_latent", None) + if lq_latent is not None: + out["lq_latent"] = comfy.conds.CONDRegular(lq_latent) + degrade_sigma = kwargs.get("degrade_sigma", None) + if degrade_sigma is not None: + out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma) + return out + + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "lq_latent" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + lq = cond_value.cond + dim = window.dim + if dim >= lq.ndim: + return None + lq_proj = self.diffusion_model.lq_proj + ratio = lq_proj.sr_scale * lq_proj.latent_spatial_down_factor + # Map x window indices -> lq indices (deduplicated, sorted, in-bounds). + lq_size = lq.size(dim) + lq_indices = sorted({i // ratio for i in window.index_list if 0 <= i // ratio < lq_size}) + if not lq_indices: + return None + idx = tuple([slice(None)] * dim + [lq_indices]) + return cond_value._copy_with(lq[idx].to(device)) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + + class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index cb94102b0..0d221b2eb 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -463,6 +463,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["extra_per_block_abs_pos_emb_type"] = "learnable" return dit_config + # PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I. + _lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix) + if _lq_w_key in state_dict_keys: + in_ch = int(state_dict[_lq_w_key].shape[1]) + _gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix) + num_gates = len({k[len(_gate_prefix):].split('.')[0] + for k in state_dict_keys if k.startswith(_gate_prefix)}) + dit_config = {"image_model": "pid", + "lq_latent_channels": in_ch, + "latent_spatial_down_factor": 16 if in_ch >= 64 else 8} + if num_gates > 0: + dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates + return dit_config + + if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I + return {"image_model": "pixeldit_t2i"} + if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2 dit_config = {} dit_config["image_model"] = "lumina2" @@ -805,6 +822,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["timestep_scale"] = 1000.0 return dit_config + if '{}transformer_blocks.0.attn.norm_added_q.weight'.format(key_prefix) in state_dict_keys \ + and '{}transformer_blocks.0.img_mlp.w1.weight'.format(key_prefix) in state_dict_keys: # Lens + img_in_w = state_dict['{}img_in.weight'.format(key_prefix)] + proj_out_w = state_dict['{}proj_out.weight'.format(key_prefix)] + multi_layer = '{}txt_norm.0.weight'.format(key_prefix) in state_dict_keys + if multi_layer: + enc_hidden_dim = state_dict['{}txt_norm.0.weight'.format(key_prefix)].shape[0] + # Indices are TE-side; the DiT just consumes L layers in order. + selected_layer_index = tuple(range(count_blocks(state_dict_keys, '{}txt_norm.'.format(key_prefix) + '{}.'))) + else: + enc_hidden_dim = state_dict['{}txt_norm.weight'.format(key_prefix)].shape[0] + selected_layer_index = (0,) + + return { + "image_model": "lens", + "in_channels": img_in_w.shape[1], + "out_channels": proj_out_w.shape[0] // 4, # patch_size ** 2 (=2² default) + "num_layers": count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.'), + "num_attention_heads": img_in_w.shape[0] // 64, # // attention_head_dim default + "enc_hidden_dim": enc_hidden_dim, + "multi_layer_encoder_feature": multi_layer, + "selected_layer_index": selected_layer_index, + } + if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image dit_config = {} dit_config["image_model"] = "qwen_image" diff --git a/comfy/ops.py b/comfy/ops.py index 9bcd6c900..56445be8d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -18,6 +18,7 @@ import torch import logging +import contextlib import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float @@ -1047,6 +1048,144 @@ class QuantLinearFunc(torch.autograd.Function): return grad_input, grad_weight, grad_bias, None, None, None +# Quantized-weight module helpers + +def _quantized_apply(module, fn, recurse=True): + """Re-wrap Parameters after fn so .to()/.cuda() propagate through QuantizedTensor weights.""" + if recurse: + for child in module.children(): + child._apply(fn) + for key, param in module._parameters.items(): + if param is None: + continue + p = fn(param) + if (not torch.is_inference_mode_enabled()) and p.is_inference(): + p = p.clone() + module.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) + for key, buf in module._buffers.items(): + if buf is not None: + module._buffers[key] = fn(buf) + return module + + +def _load_quantized_module(module, super_load, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs, load_extra_params=False): + """Shared _load_from_state_dict body for quantized-weight modules. + + Pops weight (+ scales, +/- extras), populates module.weight as a Parameter + or Parameter-wrapped QuantizedTensor, then calls super_load and strips + consumed keys from missing_keys. Reads compute_dtype from factory_kwargs + and disabled formats from module._disabled_formats. + """ + device = module.factory_kwargs["device"] + compute_dtype = module.factory_kwargs["dtype"] + disabled_formats = module._disabled_formats + layer_name = prefix.rstrip('.') + + weight = state_dict.pop(f"{prefix}weight", None) + if weight is None: + logging.warning(f"Missing weight for layer {layer_name}") + module.weight = None + return + manually_loaded_keys = [f"{prefix}weight"] + + def pop_scale(name, dtype=None): + key = f"{prefix}{name}" + v = state_dict.pop(key, None) + if v is not None: + v = v.to(device=device) + if dtype is not None: + v = v.view(dtype=dtype) + manually_loaded_keys.append(key) + return v + + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + if layer_conf is None: + module.weight = torch.nn.Parameter(weight.to(device=device, dtype=compute_dtype), requires_grad=False) + else: + module.quant_format = layer_conf.get("format", None) + module._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) + if not module._full_precision_mm: + module._full_precision_mm = module._full_precision_mm_config + if module.quant_format in disabled_formats: + module._full_precision_mm = True + if module.quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") + + qconfig = QUANT_ALGOS[module.quant_format] + module.layout_type = qconfig["comfy_tensor_layout"] + layout_cls = get_layout_class(module.layout_type) + + # Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8. + if module.quant_format in ("float8_e4m3fn", "float8_e5m2"): + scales = {"scale": pop_scale("weight_scale")} + elif module.quant_format == "mxfp8": + bs = pop_scale("weight_scale", torch.float8_e8m0fnu) + if bs is None: + raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}") + scales = {"scale": bs} + elif module.quant_format == "nvfp4": + ts = pop_scale("weight_scale_2") + bs = pop_scale("weight_scale", torch.float8_e4m3fn) + if ts is None or bs is None: + raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") + scales = {"scale": ts, "block_scale": bs} + else: + raise ValueError(f"Unsupported quantization format: {module.quant_format}") + + params = layout_cls.Params(**scales, orig_dtype=compute_dtype, orig_shape=module._orig_shape) + module.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), module.layout_type, params), + requires_grad=False, + ) + + if load_extra_params: + for param_name in qconfig["parameters"]: + if param_name in {"weight_scale", "weight_scale_2"}: + continue + param_key = f"{prefix}{param_name}" + _v = state_dict.pop(param_key, None) + if _v is None: + continue + module.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) + + super_load(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for key in manually_loaded_keys: + if key in missing_keys: + missing_keys.remove(key) + + +def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extra_quant_params=()): + """Shared state_dict body. extra_quant_conf merges into the comfy_quant JSON; + extra_quant_params names attributes written as additional top-level keys.""" + if not hasattr(module, 'weight'): + logging.warning(f"Warning: state dict on uninitialized op {prefix}") + return sd + bias = getattr(module, 'bias', None) + if bias is not None: + sd[f"{prefix}bias"] = bias + if module.weight is None: + return sd + if isinstance(module.weight, QuantizedTensor): + sd.update(module.weight.state_dict(f"{prefix}weight")) + quant_conf = {"format": module.quant_format} + if getattr(module, '_full_precision_mm_config', False): + quant_conf["full_precision_matrix_mult"] = True + if extra_quant_conf: + quant_conf.update(extra_quant_conf) + sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8) + for name in extra_quant_params: + value = getattr(module, name, None) + if value is not None: + sd[f"{prefix}{name}"] = value + else: + sd[f"{prefix}weight"] = module.weight + return sd + def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): class MixedPrecisionOps(manual_cast): @@ -1056,21 +1195,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec _disabled = disabled class Linear(torch.nn.Module, CastWeightBiasOp): - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = True, - device=None, - dtype=None, - ) -> None: + _disabled_formats = disabled + + def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): super().__init__() self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} - # self.factory_kwargs = {"device": device, "dtype": dtype} self.in_features = in_features self.out_features = out_features + self._orig_shape = (out_features, in_features) if bias: self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) else: @@ -1083,151 +1217,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def reset_parameters(self): return None - def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None): - key = f"{prefix}{param_name}" - value = state_dict.pop(key, None) - if value is not None: - value = value.to(device=device) - if dtype is not None: - value = value.view(dtype=dtype) - manually_loaded_keys.append(key) - return value - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs): - - device = self.factory_kwargs["device"] - layer_name = prefix.rstrip('.') - weight_key = f"{prefix}weight" - weight = state_dict.pop(weight_key, None) - if weight is None: - logging.warning(f"Missing weight for layer {layer_name}") - self.weight = None - return - - manually_loaded_keys = [weight_key] - - layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) - if layer_conf is not None: - layer_conf = json.loads(layer_conf.numpy().tobytes()) - - if layer_conf is None: - self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) - else: - self.quant_format = layer_conf.get("format", None) - self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False) - if not self._full_precision_mm: - self._full_precision_mm = self._full_precision_mm_config - - if self.quant_format in MixedPrecisionOps._disabled: - self._full_precision_mm = True - - if self.quant_format is None: - raise ValueError(f"Unknown quantization format for layer {layer_name}") - - qconfig = QUANT_ALGOS[self.quant_format] - self.layout_type = qconfig["comfy_tensor_layout"] - layout_cls = get_layout_class(self.layout_type) - - # Load format-specific parameters - if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]: - # FP8: single tensor scale - scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys) - - params = layout_cls.Params( - scale=scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=(self.out_features, self.in_features), - ) - - elif self.quant_format == "mxfp8": - # MXFP8: E8M0 block scales stored as uint8 in safetensors - block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, - dtype=torch.uint8) - - if block_scale is None: - raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}") - - block_scale = block_scale.view(torch.float8_e8m0fnu) - - params = layout_cls.Params( - scale=block_scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=(self.out_features, self.in_features), - ) - - elif self.quant_format == "nvfp4": - # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) - tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) - block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, - dtype=torch.float8_e4m3fn) - - if tensor_scale is None or block_scale is None: - raise ValueError(f"Missing NVFP4 scales for layer {layer_name}") - - params = layout_cls.Params( - scale=tensor_scale, - block_scale=block_scale, - orig_dtype=MixedPrecisionOps._compute_dtype, - orig_shape=(self.out_features, self.in_features), - ) - else: - raise ValueError(f"Unsupported quantization format: {self.quant_format}") - - self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params), - requires_grad=False - ) - - for param_name in qconfig["parameters"]: - if param_name in {"weight_scale", "weight_scale_2"}: - continue # Already handled above - - param_key = f"{prefix}{param_name}" - _v = state_dict.pop(param_key, None) - if _v is None: - continue - self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) - manually_loaded_keys.append(param_key) - - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - - for key in manually_loaded_keys: - if key in missing_keys: - missing_keys.remove(key) + def _load_from_state_dict(self, *args): + _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=True) def state_dict(self, *args, destination=None, prefix="", **kwargs): - if destination is not None: - sd = destination - else: - sd = {} - - if not hasattr(self, 'weight'): - logging.warning("Warning: state dict on uninitialized op {}".format(prefix)) - return sd - - if self.bias is not None: - sd["{}bias".format(prefix)] = self.bias - - if self.weight is None: - return sd - - if isinstance(self.weight, QuantizedTensor): - sd_out = self.weight.state_dict("{}weight".format(prefix)) - for k in sd_out: - sd[k] = sd_out[k] - - quant_conf = {"format": self.quant_format} - if self._full_precision_mm_config: - quant_conf["full_precision_matrix_mult"] = True - sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) - - input_scale = getattr(self, 'input_scale', None) - if input_scale is not None: - sd["{}input_scale".format(prefix)] = input_scale - else: - sd["{}weight".format(prefix)] = self.weight - return sd + sd = destination if destination is not None else {} + return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",)) def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) @@ -1317,25 +1312,126 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.weight = torch.nn.Parameter(weight, requires_grad=False) def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working - if recurse: - for module in self.children(): - module._apply(fn) + return _quantized_apply(self, fn, recurse) - for key, param in self._parameters.items(): - if param is None: - continue - p = fn(param) - if (not torch.is_inference_mode_enabled()) and p.is_inference(): - p = p.clone() - self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) - for key, buf in self._buffers.items(): - if buf is not None: - self._buffers[key] = fn(buf) - return self + class MoEExperts(torch.nn.Module, CastWeightBiasOp): + """Container for E quantized expert weights, indexed via expert_weight(i). + + The bank lives on self.weight as a single 3D tensor — either a + compute_dtype Parameter or a Parameter wrapping a QuantizedTensor + with leading expert dim. + + State-dict layout matches mixed_precision_ops.Linear with a leading + expert dim: + {prefix}.weight quant data (storage_t), leading dim = E + {prefix}.weight_scale block / per-tensor scale + {prefix}.weight_scale_2 [E] or scalar NVFP4 only + {prefix}.bias [E, out_features] optional, compute_dtype + {prefix}.comfy_quant json -> {{"format": "...", "num_experts": E}} + + Without comfy_quant the weight loads as a plain compute_dtype 3D Parameter [E, out, in]. + """ + + _disabled_formats = disabled + + def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None): + super().__init__() + self.num_experts = num_experts + self.in_features = in_features + self.out_features = out_features + self._orig_shape = (num_experts, out_features, in_features) + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + if bias: + self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) + + # Populated by _load_from_state_dict: + self.weight = None + self.quant_format = None + self.layout_type = None + self._full_precision_mm = MixedPrecisionOps._full_precision_mm + self._full_precision_mm_config = False + self._resident_bank = None + + def reset_parameters(self): + return None + + def _apply(self, fn, recurse=True): + return _quantized_apply(self, fn, recurse) + + def _load_from_state_dict(self, *args): + _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=False) + + def expert_weight(self, i: int): + """Expert i's weight (Tensor or per-expert QuantizedTensor view).""" + if isinstance(self.weight, QuantizedTensor): + return self._expert_qt_from(self.weight, i) + return self.weight[i] + + @contextlib.contextmanager + def bank_resident(self, input): + """Cast the whole bank once; expert_linear inside reuses the cast. + Not re-entrant — do not nest calls on the same instance. + """ + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + self._resident_bank = (weight, bias) + try: + yield self + finally: + self._resident_bank = None + uncast_bias_weight(self, weight, bias, offload_stream) + + def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor: + """Linear against expert i's weight (with optional bias).""" + resident = getattr(self, "_resident_bank", None) + if resident is not None: + weight, bias = resident + return self._expert_linear_impl(input, weight, bias, i) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + try: + return self._expert_linear_impl(input, weight, bias, i) + finally: + uncast_bias_weight(self, weight, bias, offload_stream) + + def _expert_linear_impl(self, input, weight, bias, i): + if isinstance(weight, QuantizedTensor): + qw = self._expert_qt_from(weight, i) + else: + qw = weight[i] + b = cast_to_input(bias[i], input, copy=False) if bias is not None else None + + if isinstance(qw, QuantizedTensor): + use_fast = ( + not self._full_precision_mm + and qw.layout_cls.supports_fast_matmul() + and input.dim() == 2 + ) + if use_fast: + qin = QuantizedTensor.from_float(input, self.layout_type) + return torch.nn.functional.linear(qin, qw, b) + out = input @ qw.dequantize().t() + return out + b if b is not None else out + return torch.nn.functional.linear(input, qw, b) + + def _expert_qt_from(self, weight: QuantizedTensor, i: int) -> QuantizedTensor: + """Build a per-expert QuantizedTensor by indexing into a resident bank.""" + params = weight._params + kwargs = { + "scale": params.scale[i] if params.scale.dim() else params.scale, + "orig_dtype": params.orig_dtype, + "orig_shape": (self.out_features, self.in_features), + } + if hasattr(params, "block_scale"): # NVFP4 + kwargs["block_scale"] = params.block_scale[i] + return QuantizedTensor(weight._qdata[i], weight._layout_cls, type(params)(**kwargs)) + + def state_dict(self, *args, destination=None, prefix="", **kwargs): + sd = destination if destination is not None else {} + return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf={"num_experts": self.num_experts}) class Embedding(manual_cast.Embedding): - def _load_from_state_dict(self, state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): weight_key = f"{prefix}weight" layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) if layer_conf is not None: @@ -1343,14 +1439,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec # Only fp8 makes sense for embeddings (per-row dequant via index select). # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently. - quant_format = layer_conf.get("format", None) if layer_conf is not None else None - if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict: + quant_format = layer_conf.get("format") if layer_conf is not None else None + manually_loaded_keys = [] + + if quant_format in ("float8_e4m3fn", "float8_e5m2") and weight_key in state_dict: self.quant_format = quant_format qconfig = QUANT_ALGOS[quant_format] self.layout_type = qconfig["comfy_tensor_layout"] layout_cls = get_layout_class(self.layout_type) weight = state_dict.pop(weight_key) - manually_loaded_keys = [weight_key] + manually_loaded_keys.append(weight_key) scale_key = f"{prefix}weight_scale" scale = state_dict.pop(scale_key, None) @@ -1366,35 +1464,19 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self.weight = torch.nn.Parameter( QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params), requires_grad=False) + elif layer_conf is not None: + # Unsupported format — restore the marker so it round-trips; fall through to default load. + state_dict[f"{prefix}comfy_quant"] = torch.tensor( + list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8) - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - for k in manually_loaded_keys: - if k in missing_keys: - missing_keys.remove(k) - else: - if layer_conf is not None: - state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8) - super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for k in manually_loaded_keys: + if k in missing_keys: + missing_keys.remove(k) def state_dict(self, *args, destination=None, prefix="", **kwargs): - if destination is not None: - sd = destination - else: - sd = {} - - if not hasattr(self, 'weight') or self.weight is None: - return sd - - if isinstance(self.weight, QuantizedTensor): - sd_out = self.weight.state_dict("{}weight".format(prefix)) - for k in sd_out: - sd[k] = sd_out[k] - - quant_conf = {"format": self.quant_format} - sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) - else: - sd["{}weight".format(prefix)] = self.weight - return sd + sd = destination if destination is not None else {} + return _quantized_weight_state_dict(self, sd, prefix) def forward_comfy_cast_weights(self, input, out_dtype=None): weight = self.weight diff --git a/comfy/sd.py b/comfy/sd.py index b4e487308..c4be183ac 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -51,6 +51,7 @@ import comfy.text_encoders.lt import comfy.text_encoders.hunyuan_video import comfy.text_encoders.cosmos import comfy.text_encoders.lumina2 +import comfy.text_encoders.pixeldit import comfy.text_encoders.wan import comfy.text_encoders.hidream import comfy.text_encoders.ace @@ -70,6 +71,7 @@ import comfy.text_encoders.ernie import comfy.text_encoders.gemma4 import comfy.text_encoders.cogvideo import comfy.text_encoders.sa3 +import comfy.text_encoders.gpt_oss import comfy.model_patcher import comfy.lora @@ -1463,6 +1465,8 @@ class CLIPType(Enum): FLUX2 = 25 LONGCAT_IMAGE = 26 COGVIDEOX = 27 + LENS = 28 + PIXELDIT = 29 @@ -1515,6 +1519,7 @@ class TEModel(Enum): GEMMA_4_E2B = 30 GEMMA_4_31B = 31 T5_GEMMA = 32 + GPT_OSS_20B = 33 def detect_te_model(sd): @@ -1556,6 +1561,9 @@ def detect_te_model(sd): else: return TEModel.GEMMA_3_4B return TEModel.GEMMA_2_2B + # Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512). + if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd: + return TEModel.GPT_OSS_20B if 'model.layers.0.self_attn.k_proj.bias' in sd: weight = sd['model.layers.0.self_attn.k_proj.bias'] if weight.shape[0] == 256: @@ -1702,8 +1710,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.tokenizer = variant.tokenizer tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) elif te_model == TEModel.GEMMA_2_2B: - clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) - clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer + if clip_type == CLIPType.PIXELDIT: + clip_target.clip = comfy.text_encoders.pixeldit.pixeldit_te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer + else: + clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.GEMMA_3_4B: clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b") @@ -1738,6 +1750,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2) clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) + elif te_model == TEModel.GPT_OSS_20B: + clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer + tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None) elif te_model == TEModel.QWEN3_4B: if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2: clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b") diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 3fc993665..30134e6d1 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -30,6 +30,7 @@ import comfy.text_encoders.longcat_image import comfy.text_encoders.ernie import comfy.text_encoders.cogvideo import comfy.text_encoders.hidream_o1 +import comfy.text_encoders.pixeldit from . import supported_models_base from . import latent_formats @@ -829,6 +830,50 @@ class Flux2(Flux): return None + +class Lens(supported_models_base.BASE): + """Microsoft Lens (3.8B dual-stream MMDiT, GPT-OSS-20B text features, Flux2 VAE).""" + + unet_config = { + "image_model": "lens", + } + + sampling_settings = { + "shift": 1.829, # Default mu for 1440x1440 (and any seq_len > 4300 + } + + unet_extra_config = {} + latent_format = latent_formats.Flux2 + + memory_usage_factor = 4.0 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] # fp16 causes NaNs + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def __init__(self, unet_config): + super().__init__(unet_config) + + def get_model(self, state_dict, prefix="", device=None): + return model_base.Lens(self, model_type=model_base.ModelType.FLUX, device=device) + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + for hint in ("gpt_oss.transformer.", ""): + full_prefix = "{}{}".format(pref, hint) + if "{}layers.0.self_attn.sinks".format(full_prefix) in state_dict: + detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix) + return supported_models_base.ClipTarget( + comfy.text_encoders.gpt_oss.LensTokenizer, + comfy.text_encoders.gpt_oss.lens_te(**detect), + ) + return supported_models_base.ClipTarget( + comfy.text_encoders.gpt_oss.LensTokenizer, + comfy.text_encoders.gpt_oss.lens_te(), + ) + + class GenmoMochi(supported_models_base.BASE): unet_config = { "image_model": "mochi_preview", @@ -1159,6 +1204,72 @@ class ZImagePixelSpace(ZImage): def get_model(self, state_dict, prefix="", device=None): return model_base.ZImagePixelSpace(self, device=device) +class PixelDiTT2I(supported_models_base.BASE): + unet_config = { + "image_model": "pixeldit_t2i", + } + + unet_extra_config = {} + + sampling_settings = { + "shift": 4.0, # 1024px stage 3 default; 2.0 for 512px + } + + latent_format = latent_formats.PixelDiTPixel + memory_usage_factor = 0.04 + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.PixelDiTT2I(self, device=device) + + def process_unet_state_dict(self, state_dict): + # pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels); p2 derived per-weight from total // (6 * pixel_dim). + pixel_dim = next(v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")).shape[0] + + out = {} + marker = ".adaLN_modulation.0." + for k, v in state_dict.items(): + if k.startswith("_repa_projector") or k.startswith("net_ema."): + continue + if k.startswith("core."): + k = k[len("core."):] + elif k.startswith("net."): + k = k[len("net."):] + if "pixel_blocks." in k and marker in k: + # Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM + p2 = v.shape[0] // (6 * pixel_dim) + trail = v.shape[1:] # () for bias, (in_dim,) for weight + vv = v.view(p2, 6, pixel_dim, *trail) + base, suffix = k.split(marker) + out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous() + out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous() + else: + out[k] = v + return out + + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget( + comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer, + comfy.text_encoders.pixeldit.PixelDiTGemma2TE, + ) + +class PiD(PixelDiTT2I): + unet_config = { + "image_model": "pid", + } + + sampling_settings = { + "shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0] + } + + memory_usage_factor = 0.04 + + def get_model(self, state_dict, prefix="", device=None): + return model_base.PiD(self, device=device) + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -2097,6 +2208,8 @@ models = [ CosmosI2VPredict2, ZImagePixelSpace, ZImage, + PiD, + PixelDiTT2I, Lumina2, WAN22_T2V, WAN21_CausalAR_T2V, @@ -2125,6 +2238,7 @@ models = [ Omnigen2, QwenImage, Flux2, + Lens, Kandinsky5Image, Kandinsky5, Anima, diff --git a/comfy/text_encoders/gpt_oss.py b/comfy/text_encoders/gpt_oss.py new file mode 100644 index 000000000..d596ef9a0 --- /dev/null +++ b/comfy/text_encoders/gpt_oss.py @@ -0,0 +1,600 @@ +"""GPT-OSS text encoder for Lens.""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any, List, Optional, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops +from comfy import sd1_clip +from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device +from comfy.text_encoders.llama import RMSNorm, apply_rope + + +@dataclass +class GptOss20BConfig: + vocab_size: int = 201088 + hidden_size: int = 2880 + intermediate_size: int = 2880 + num_hidden_layers: int = 24 + num_attention_heads: int = 64 + num_key_value_heads: int = 8 + head_dim: int = 64 + num_local_experts: int = 32 + num_experts_per_tok: int = 4 + sliding_window: int = 128 + original_max_position_embeddings: int = 4096 + rope_theta: float = 150000.0 + rope_factor: float = 32.0 + rope_beta_fast: float = 32.0 + rope_beta_slow: float = 1.0 + rope_truncate: bool = False + rms_norm_eps: float = 1e-5 + attention_bias: bool = True + layer_types: Optional[List[str]] = None + moe_alpha: float = 1.702 + moe_limit: float = 7.0 + + def __post_init__(self): + if self.layer_types is None: + self.layer_types = [ + "sliding_attention" if (i + 1) % 2 else "full_attention" + for i in range(self.num_hidden_layers) + ] + + +def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float, + original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]: + """YARN inv_freq + attention scaling (matches transformers).""" + dim = head_dim + + def find_correction_dim(num_rotations: float) -> float: + return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + def find_correction_range() -> tuple[float, float]: + low = find_correction_dim(beta_fast) + high = find_correction_dim(beta_slow) + if truncate: + low = math.floor(low) + high = math.ceil(high) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min_: float, max_: float, n: int) -> torch.Tensor: + if min_ == max_: + max_ += 0.001 + linear = (torch.arange(n, dtype=torch.float32, device=device) - min_) / (max_ - min_) + return torch.clamp(linear, 0, 1) + + def get_mscale(scale: float) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + attention_scaling = get_mscale(factor) + + pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (factor * pos_freqs) + + low, high = find_correction_range() + extrap_factor = 1 - linear_ramp_factor(low, high, dim // 2) + inv_freq = inv_freq_interpolation * (1 - extrap_factor) + inv_freq_extrapolation * extrap_factor + return inv_freq, attention_scaling + + +def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + inv_freq_e = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + pos_e = position_ids[:, None, :].float() + freqs = (inv_freq_e @ pos_e).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = (emb.cos() * attention_scaling).to(dtype).unsqueeze(1) + sin = (emb.sin() * attention_scaling).to(dtype).unsqueeze(1) + sin_split = sin.shape[-1] // 2 + return cos, sin[..., :sin_split], -sin[..., sin_split:] + + +def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor, + attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor: + """Attention with per-head sinks. + + Sinks add a learned term to each row's softmax denominator but contribute + nothing to the output. We fake this by appending one zero k/v position and + putting the sink logit in the mask at that column. + """ + + if num_kv_groups > 1 and not TORCH_HAS_GQA: + k = k.repeat_interleave(num_kv_groups, dim=1) + v = v.repeat_interleave(num_kv_groups, dim=1) + + B, _, S_q, D = q.shape + H_kv = k.shape[1] + S_kv = k.shape[-2] + + k = torch.cat([k, k.new_zeros(B, H_kv, 1, D)], dim=-2) + v = torch.cat([v, v.new_zeros(B, H_kv, 1, D)], dim=-2) + + sinks_col = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(B, num_heads, S_q, 1) + if attention_mask is not None: + mask_left = attention_mask[..., :S_kv].expand(B, num_heads, S_q, S_kv) + else: + mask_left = q.new_zeros(B, num_heads, S_q, S_kv) + mask = torch.cat([mask_left, sinks_col], dim=-1) + + op = optimized_attention_for_device(q.device, mask=True, small_input=True) + return op(q, k, v, num_heads, mask=mask, skip_reshape=True, enable_gqa=True) + + +class GptOssAttention(nn.Module): + def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None): + super().__init__() + self.layer_idx = layer_idx + self.layer_type = config.layer_types[layer_idx] + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.num_kv_groups = self.num_heads // self.num_kv_heads + self.head_dim = config.head_dim + self.hidden_size = config.hidden_size + self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None + + bias = config.attention_bias + self.q_proj = ops.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=bias, device=device, dtype=dtype) + self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype) + self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype) + self.o_proj = ops.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=bias, device=device, dtype=dtype) + self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype)) + + def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor: + B, S, _ = hidden_states.shape + + q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2) + k = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) + v = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2) + + q, k = apply_rope(q, k, freqs_cis) + + out = _attention_with_sinks(q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups) + return self.o_proj(out) + + +# Mixture of Experts + +class GptOssTopKRouter(nn.Module): + def __init__(self, config: GptOss20BConfig, device=None, dtype=None): + super().__init__() + self.top_k = config.num_experts_per_tok + self.num_experts = config.num_local_experts + self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype)) + self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype)) + + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + weight = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False) + bias = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False) + logits = F.linear(hidden_states, weight, bias) + top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1) + # Softmax over top-k slice only + scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype) + return scores, top_idx + + +class GptOssExperts(nn.Module): + def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None): + super().__init__() + self.num_experts = config.num_local_experts + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.alpha = config.moe_alpha + self.limit = config.moe_limit + + E = self.num_experts + H = self.hidden_size + I = self.intermediate_size + + self.gate_up_proj = ops.MoEExperts(num_experts=E, in_features=H, out_features=2 * I, bias=True, device=device, dtype=dtype) + self.down_proj = ops.MoEExperts(num_experts=E, in_features=I, out_features=H, bias=True, device=device, dtype=dtype) + + def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor: + gate = gate_up[..., ::2] + up = gate_up[..., 1::2] + gate = gate.clamp(max=self.limit) + up = up.clamp(min=-self.limit, max=self.limit) + glu = gate * torch.sigmoid(gate * self.alpha) + return torch.addcmul(glu, up, glu) + + def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor: + N = hidden_states.shape[0] + top_k = router_indices.shape[-1] + H = hidden_states.shape[-1] + + per_pair = torch.zeros((N * top_k, H), dtype=hidden_states.dtype, device=hidden_states.device) + + expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + with self.gate_up_proj.bank_resident(hidden_states) as gate_up_bank, \ + self.down_proj.bank_resident(hidden_states) as down_bank: + for ei in expert_hit: + expert_idx = int(ei.item()) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current = hidden_states[token_idx] + + gate_up = gate_up_bank.expert_linear(current, expert_idx) + gated = self._apply_gate(gate_up) + expert_out = down_bank.expert_linear(gated, expert_idx) + + weighted = expert_out * routing_weights[token_idx, top_k_pos, None] + + flat_idx = token_idx * top_k + top_k_pos + per_pair[flat_idx] = weighted.to(per_pair.dtype) + + return per_pair.view(N, top_k, H).sum(dim=1) + + +class GptOssMLP(nn.Module): + def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None): + super().__init__() + self.router = GptOssTopKRouter(config, device=device, dtype=dtype) + self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + B, S, H = hidden_states.shape + flat = hidden_states.reshape(-1, H) + scores, idx = self.router(flat) + out = self.experts(flat, idx, scores) + return out.reshape(B, S, H) + + +# Decoder layer + model + +class GptOssDecoderLayer(nn.Module): + def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None): + super().__init__() + self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops) + self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + self.layer_type = config.layer_types[layer_idx] + + def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor: + residual = x + x = self.input_layernorm(x) + x = self.self_attn(x, attention_masks[self.layer_type], freqs_cis) + x = residual + x + + residual = x + x = self.post_attention_layernorm(x) + x = self.mlp(x) + x = residual + x + return x + + +def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device): + neg = torch.finfo(dtype).min + mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1) + mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous() + if key_padding_mask is not None: + kp = key_padding_mask.to(dtype=dtype) + kp = (1.0 - kp).reshape(B, 1, 1, S) * neg + mask = mask + kp + return mask + + +def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device): + neg = torch.finfo(dtype).min + i = torch.arange(S, device=device).view(-1, 1) + j = torch.arange(S, device=device).view(1, -1) + keep = (j <= i) & (j > i - window) + mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device)) + mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous() + if key_padding_mask is not None: + kp = key_padding_mask.to(dtype=dtype) + kp = (1.0 - kp).reshape(B, 1, 1, S) * neg + mask = mask + kp + return mask + + +class GptOssModel(nn.Module): + """GPT-OSS decoder with multi-layer hidden-state capture + early exit.""" + + def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None): + super().__init__() + self.config = config + self.dtype = dtype + self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) + self.layers = nn.ModuleList( + [ + GptOssDecoderLayer(config, i, device=device, dtype=dtype, ops=ops) + for i in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype) + + # Always build on CPU so the buffer survives meta-device construction. + inv_freq, attn_scaling = _yarn_inv_freq( + head_dim=config.head_dim, + base=config.rope_theta, + factor=config.rope_factor, + beta_fast=config.rope_beta_fast, + beta_slow=config.rope_beta_slow, + original_max_position_embeddings=config.original_max_position_embeddings, + truncate=config.rope_truncate, + device=torch.device("cpu"), + ) + self.register_buffer("rope_inv_freq", inv_freq, persistent=False) + self.rope_attention_scaling = float(attn_scaling) + + @property + def num_layers(self) -> int: + return self.config.num_hidden_layers + + def get_input_embeddings(self): + return self.embed_tokens + + def _build_attention_masks(self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype, device, + ) -> dict[str, torch.Tensor]: + full = _make_full_causal_mask(B, S, attention_mask, dtype, device) + masks = {"full_attention": full} + if any(t == "sliding_attention" for t in self.config.layer_types): + masks["sliding_attention"] = _make_sliding_causal_mask( + B, S, self.config.sliding_window, attention_mask, dtype, device + ) + return masks + + def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, + capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]: + B, S = input_ids.shape + device = input_ids.device + dtype = self.dtype + + hidden_states = self.embed_tokens(input_ids, out_dtype=dtype) + + position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1) + freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype) + + attn_masks = self._build_attention_masks(B, S, attention_mask, dtype, device) + + capture_layers = list(capture_layers) if capture_layers else None + if capture_layers: + max_layer = max(capture_layers) + wanted = {idx: pos for pos, idx in enumerate(capture_layers)} + captured: List[Optional[torch.Tensor]] = [None] * len(capture_layers) + else: + max_layer = self.config.num_hidden_layers - 1 + wanted = None + captured = None + + for i, layer in enumerate(self.layers): + hidden_states = layer(hidden_states, attn_masks, freqs_cis) + if wanted is not None and i in wanted: + captured[wanted[i]] = hidden_states + if i >= max_layer: + break + + if captured is not None: + return {"hidden_states": captured} + return {"last_hidden_state": self.norm(hidden_states)} + + +# Lens chat-template constants (verbatim from the reference pipeline). +_LENS_CHAT_SYSTEM = ( + "Describe the image by detailing the color, shape, size, texture, " + "quantity, text, spatial relationships of the objects and background." +) +_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description." +LENS_TXT_OFFSET = 97 +LENS_SELECTED_LAYERS = (5, 11, 17, 23) +LENS_MAX_TOKENS = 512 + + +# The reference GPT-OSS Harmony template injects today's date here +_LENS_CHAT_DATE = "2026-05-23" + + +def _lens_render_chat(prompt: str) -> str: + """Render the Lens prompt in GPT-OSS Harmony format.""" + return ( + f"<|start|>system<|message|>" + f"You are ChatGPT, a large language model trained by OpenAI.\n" + f"Knowledge cutoff: 2024-06\n" + f"Current date: {_LENS_CHAT_DATE}\n\n" + f"Reasoning: medium\n\n" + f"# Valid channels: analysis, commentary, final. " + f"Channel must be included for every message.<|end|>" + f"<|start|>developer<|message|># Instructions\n\n" + f"{_LENS_CHAT_SYSTEM}\n\n<|end|>" + f"<|start|>user<|message|>{prompt}<|end|>" + f"<|start|>assistant<|channel|>analysis<|message|>" + f"{_LENS_CHAT_ASSISTANT_THINKING}<|end|>" + f"<|start|>assistant<|channel|>final<|message|>" + ) + + +# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table). +_LENS_PAD_TOKEN_ID = 199999 # <|endoftext|> + + +class _GptOssRawTokenizer: + """Raw ``tokenizers.Tokenizer`` wrapper. + + The tokenizer JSON ships as a byte tensor inside the encoder checkpoint + (``tokenizer_json`` key) rather than as a committed file. Extracted + it in ``sd.py`` and passes it here via ``tokenizer_data``. + """ + + def __init__(self, tokenizer_json_bytes=None, **kwargs): + from tokenizers import Tokenizer + if isinstance(tokenizer_json_bytes, torch.Tensor): + tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist()) + if tokenizer_json_bytes is None: + raise ValueError( + "Lens tokenizer requires the ``tokenizer_json`` byte tensor in the " + "encoder state dict. Re-bundle the encoder via bundle_te.py so it " + "embeds the tokenizer." + ) + self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8")) + + @classmethod + def from_pretrained(cls, tokenizer_data, **kwargs): + return cls(tokenizer_json_bytes=tokenizer_data, **kwargs) + + def __call__(self, text): + return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids} + + def get_vocab(self): + return self.tokenizer.get_vocab() + + def convert_tokens_to_ids(self, tokens): + return [self.tokenizer.token_to_id(t) for t in tokens] + + def decode(self, ids, **kwargs): + return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False)) + + +class LensGptOssTokenizer(sd1_clip.SDTokenizer): + tokenizer_json_data = None + + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_json = tokenizer_data.get("tokenizer_json", None) + self.tokenizer_json_data = tokenizer_json + super().__init__( + tokenizer_json, + embedding_directory=embedding_directory, + pad_with_end=False, + embedding_size=2880, + embedding_key="gpt_oss", + tokenizer_class=_GptOssRawTokenizer, + has_start_token=False, + has_end_token=False, + pad_to_max_length=False, + max_length=99999999, + min_length=1, + pad_left=False, + disable_weights=True, + tokenizer_data=tokenizer_data, + ) + self.pad_token_id = _LENS_PAD_TOKEN_ID + + def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs): + # Empty prompt -> empty list; encode_token_weights returns zeros (uncond). + if not text or not text.strip(): + return [[]] + rendered = _lens_render_chat(text) + ids = self.tokenizer(rendered)["input_ids"] + if len(ids) > LENS_MAX_TOKENS: + ids = ids[:LENS_MAX_TOKENS] + return [[(int(t), 1.0) for t in ids]] + + def state_dict(self): + if self.tokenizer_json_data is not None: + return {"tokenizer_json": self.tokenizer_json_data} + return {} + + +class LensTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__( + embedding_directory=embedding_directory, + tokenizer_data=tokenizer_data, + name="gpt_oss", + tokenizer=LensGptOssTokenizer, + ) + + +class LensGptOssClipModel(nn.Module): + """SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor).""" + + def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs): + super().__init__() + model_options = dict(model_options or {}) + + operations = model_options.get("custom_operations") + if operations is None: + quant_config = model_options.get("quantization_metadata") or {} + operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True) + self.operations = operations + + cfg_overrides = model_options.get("gpt_oss_config", {}) + self.config = GptOss20BConfig(**cfg_overrides) + self.selected_layers = tuple(model_options.get("selected_layers", LENS_SELECTED_LAYERS)) + self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET)) + + self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations) + self.num_layers = self.config.num_hidden_layers + self.dtype = dtype + self.execution_device = None + self._pad_token_id = _LENS_PAD_TOKEN_ID + + def set_clip_options(self, options): + self.execution_device = options.get("execution_device", self.execution_device) + + def reset_clip_options(self): + self.execution_device = None + + def _gather_tokens(self, token_weight_pairs): + ids_list = [[int(t[0]) for t in batch] for batch in token_weight_pairs] + pad_id = self._pad_token_id + max_len = max(len(x) for x in ids_list) + device = self.execution_device + ids = torch.full((len(ids_list), max_len), pad_id, dtype=torch.long, device=device) + mask = torch.zeros((len(ids_list), max_len), dtype=torch.long, device=device) + for i, x in enumerate(ids_list): + ids[i, : len(x)] = torch.tensor(x, dtype=torch.long, device=device) + mask[i, : len(x)] = 1 + return ids, mask + + def encode_token_weights(self, token_weight_pairs): + # Empty negative: emit zero-length features + zero mask + if all(len(batch) == 0 for batch in token_weight_pairs): + device = self.execution_device + B = len(token_weight_pairs) + L = len(self.selected_layers) + H = self.config.hidden_size + flat = torch.zeros(B, 0, L * H, dtype=self.dtype, device=device) + mask = torch.zeros(B, 0, dtype=torch.long, device=device) + return flat, None, {"attention_mask": mask, "num_layers_stacked": L} + + input_ids, attn_mask = self._gather_tokens(token_weight_pairs) + out = self.transformer(input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers) + layers = out["hidden_states"] # list of L × [B, S, H] + stacked = torch.stack(layers, dim=2) # [B, S, L, H] + + offset = self.txt_offset + if stacked.shape[1] > offset: + stacked = stacked[:, offset:].contiguous() + mask_trim = attn_mask[:, offset:] + else: + stacked = stacked[:, :0] + mask_trim = attn_mask[:, :0] + + B, S, L, H = stacked.shape + flat = stacked.reshape(B, S, L * H) + extra = {"attention_mask": mask_trim, "num_layers_stacked": L} + return flat, None, extra + + def load_sd(self, sd): + return self.transformer.load_state_dict(sd, strict=False, assign=True) + + +class LensTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options=None): + super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {}) + + +def lens_te(dtype_llama=None, llama_quantization_metadata=None): + class LensTEModel_(LensTEModel): + def __init__(self, device="cpu", dtype=None, model_options=None): + mo = dict(model_options or {}) + if llama_quantization_metadata is not None: + mo["quantization_metadata"] = llama_quantization_metadata + if dtype is None and dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=mo) + + return LensTEModel_ diff --git a/comfy/text_encoders/pixeldit.py b/comfy/text_encoders/pixeldit.py new file mode 100644 index 000000000..3539711e4 --- /dev/null +++ b/comfy/text_encoders/pixeldit.py @@ -0,0 +1,104 @@ +import torch + +from comfy import sd1_clip +from .lumina2 import Gemma2BTokenizer, LuminaModel +import comfy.text_encoders.llama + + +class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + + super().__init__( + device=device, layer=layer, layer_idx=layer_idx, + textmodel_json_config={}, dtype=dtype, + special_tokens={"start": 2, "pad": 0}, + layer_norm_hidden_state=False, + model_class=comfy.text_encoders.llama.Gemma2_2B, + enable_attention_masks=attention_mask, + return_attention_masks=attention_mask, + model_options=model_options, + ) + + +_PIXELDIT_CHI_PROMPT = ( + 'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions ' + "suitable for image generation. Evaluate the level of detail in the user prompt:\n" + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, " + "and spatial relationships to create vivid and concrete scenes.\n" + "- If the prompt is already detailed, refine and enhance the existing details slightly without " + "overcomplicating.\n" + "Here are examples of how to transform or refine prompts:\n" + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, " + "sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n" + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring " + "glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus " + "passing by towering glass skyscrapers.\n" + "Please generate only the enhanced description for the prompt below and avoid including any " + "additional commentary or evaluations:\n" + "User Prompt: " +) + +_PIXELDIT_MAX_LENGTH = 300 +_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"' + + +class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data=None): + if tokenizer_data is None: + tokenizer_data = {} + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, + name="gemma2_2b", tokenizer=Gemma2BTokenizer) + + def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): + if not text.strip(): + return super().tokenize_with_weights("", return_word_ids=return_word_ids, disable_weights=True, min_length=_PIXELDIT_MAX_LENGTH) + + chi_token_count = len(self.gemma2_2b.tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"]) + combined = text if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX) else _PIXELDIT_CHI_PROMPT + text + max_length_all = chi_token_count + _PIXELDIT_MAX_LENGTH - 2 + out = super().tokenize_with_weights(combined, return_word_ids=return_word_ids, + disable_weights=True, min_length=max_length_all) + out["gemma2_2b"] = [out["gemma2_2b"][0][:max_length_all]] + return out + + def untokenize(self, token_weight_pair): + return self.gemma2_2b.untokenize(token_weight_pair) + + def state_dict(self): + return self.gemma2_2b.state_dict() + + +class PixelDiTGemma2TE(LuminaModel): + # PixelDiT's select_index: keep BOS + last 299 embeddings of the padded sequence. + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="gemma2_2b", + clip_model=PixelDiTGemma2_2BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + result = super().encode_token_weights(token_weight_pairs) + cond, pooled = result[0], result[1] + extra = result[2] if len(result) > 2 else None + if cond.shape[1] > _PIXELDIT_MAX_LENGTH: + cond = torch.cat([cond[:, :1], cond[:, -(_PIXELDIT_MAX_LENGTH - 1):]], dim=1) + if extra is not None and "attention_mask" in extra: + am = extra["attention_mask"] + extra["attention_mask"] = torch.cat([am[..., :1], am[..., -(_PIXELDIT_MAX_LENGTH - 1):]], dim=-1) + if extra is not None: + return cond, pooled, extra + return cond, pooled + + +def pixeldit_te(dtype_llama=None, llama_quantization_metadata=None): + class PixelDiTTE_(PixelDiTGemma2TE): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return PixelDiTTE_ diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 5ed968960..fed8dc7f0 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -762,10 +762,17 @@ class Accumulation(ComfyTypeIO): @comfytype(io_type="LOAD3D_CAMERA") class Load3DCamera(ComfyTypeIO): class CameraInfo(TypedDict): - position: dict[str, float | int] - target: dict[str, float | int] - zoom: int - cameraType: str + # Coordinate system: right-handed, Y-up, camera looks down -Z + position: dict[str, float | int] # scene units + target: dict[str, float | int] # scene units; OrbitControls focus point + zoom: float | int # dimensionless, 1 = 100% + cameraType: str # 'perspective' | 'orthographic' + quaternion: NotRequired[dict[str, float | int]] # normalized, dimensionless; camera world rotation + fov: NotRequired[float | int] # degrees, vertical FOV (perspective only) + aspect: NotRequired[float | int] # width / height (perspective only) + near: NotRequired[float | int] # scene units + far: NotRequired[float | int] # scene units + frustum: NotRequired[dict[str, float | int]] # orthographic only: {left, right, top, bottom} in scene units Type = CameraInfo diff --git a/comfy_api_nodes/apis/beeble.py b/comfy_api_nodes/apis/beeble.py new file mode 100644 index 000000000..90175b214 --- /dev/null +++ b/comfy_api_nodes/apis/beeble.py @@ -0,0 +1,32 @@ +from pydantic import BaseModel, Field + + +class CreateSwitchXRequest(BaseModel): + generation_type: str = Field(...) + source_uri: str = Field(...) + alpha_mode: str = Field(...) + prompt: str | None = Field(None, max_length=2000) + reference_image_uri: str | None = Field(None) + alpha_uri: str | None = Field(None) + max_resolution: int = Field(1080) + callback_url: str | None = Field(None) + idempotency_key: str | None = Field(None, max_length=256, min_length=1) + + +class SwitchXOutputUrls(BaseModel): + render: str | None = Field(None) + source: str | None = Field(None) + alpha: str | None = Field(None) + + +class SwitchXStatusResponse(BaseModel): + id: str = Field(...) + status: str = Field(...) + progress: int | None = Field(None) + generation_type: str | None = Field(None) + alpha_mode: str | None = Field(None) + output: SwitchXOutputUrls | None = Field(None) + error: str | None = Field(None) + created_at: str | None = Field(None) + modified_at: str | None = Field(None) + completed_at: str | None = Field(None) diff --git a/comfy_api_nodes/apis/bytedance.py b/comfy_api_nodes/apis/bytedance.py index 03f4c445b..47f24586c 100644 --- a/comfy_api_nodes/apis/bytedance.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -158,8 +158,9 @@ class SeedanceCreateAssetResponse(BaseModel): class SeedanceVirtualLibraryCreateAssetRequest(BaseModel): - url: str = Field(..., description="Publicly accessible URL of the image asset to upload.") + url: str = Field(..., description="Publicly accessible URL of the asset to upload.") hash: str = Field(..., description="Dedup key. Re-submitting the same hash returns the existing asset id.") + asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.") # Dollars per 1K tokens, keyed by (model_id, has_video_input). diff --git a/comfy_api_nodes/apis/krea.py b/comfy_api_nodes/apis/krea.py new file mode 100644 index 000000000..6e294a3b7 --- /dev/null +++ b/comfy_api_nodes/apis/krea.py @@ -0,0 +1,46 @@ +"""Pydantic models for the Krea image-generation API.""" + +from pydantic import BaseModel, Field + + +class KreaMoodboard(BaseModel): + id: str = Field(...) + strength: float = Field(default=0.35, ge=-0.5, le=1.5) + + +class KreaImageStyleReference(BaseModel): + strength: float = Field(..., ge=-2.0, le=2.0) + url: str | None = Field(default=None) + + +class KreaGenerateImageRequest(BaseModel): + prompt: str = Field(...) + aspect_ratio: str = Field(...) + resolution: str = Field(...) + seed: int | None = Field(default=None) + creativity: str = Field(default="medium") + moodboards: list[KreaMoodboard] | None = Field(default=None) + image_style_references: list[KreaImageStyleReference] | None = Field(default=None) + + +class KreaJobResult(BaseModel): + urls: list[str] | None = Field(default=None) + style_id: str | None = Field(default=None) + + +class KreaJob(BaseModel): + job_id: str = Field(...) + status: str = Field(...) + created_at: str = Field(...) + completed_at: str | None = Field(default=None) + result: KreaJobResult | None = Field(default=None) + + +class KreaAssetResponse(BaseModel): + id: str = Field(...) + image_url: str = Field(...) + uploaded_at: str = Field(...) + width: float | None = Field(default=None) + height: float | None = Field(default=None) + size_bytes: float | None = Field(default=None) + mime_type: str | None = Field(default=None) diff --git a/comfy_api_nodes/nodes_anthropic.py b/comfy_api_nodes/nodes_anthropic.py index 42ec5708f..7805c96ce 100644 --- a/comfy_api_nodes/nodes_anthropic.py +++ b/comfy_api_nodes/nodes_anthropic.py @@ -155,7 +155,7 @@ class ClaudeNode(IO.ComfyNode): return IO.Schema( node_id="ClaudeNode", display_name="Anthropic Claude", - category="api node/text/Anthropic", + category="text/partner/Anthropic", essentials_category="Text Generation", description="Generate text responses with Anthropic's Claude models. " "Provide a text prompt and optionally one or more images for multimodal context.", diff --git a/comfy_api_nodes/nodes_beeble.py b/comfy_api_nodes/nodes_beeble.py new file mode 100644 index 000000000..f1082884c --- /dev/null +++ b/comfy_api_nodes/nodes_beeble.py @@ -0,0 +1,404 @@ +from fractions import Fraction + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl, Types +from comfy_api_nodes.apis.beeble import ( + CreateSwitchXRequest, + SwitchXStatusResponse, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + bytesio_to_image_tensor, + convert_mask_to_image, + download_url_as_bytesio, + download_url_to_image_tensor, + download_url_to_video_output, + downscale_image_tensor, + downscale_video_to_max_pixels, + poll_op, + sync_op, + upload_image_to_comfyapi, + upload_video_to_comfyapi, + validate_string, + validate_video_frame_count, +) + +_MAX_PIXELS = 2_770_000 +_MAX_FRAMES = 240 +_MAX_PROMPT_LEN = 2000 + + +def _validate_inputs(prompt: str | None, reference_image: Input.Image | None) -> str | None: + """Beeble requires at least one of prompt or reference_image. Returns the cleaned prompt.""" + cleaned = prompt.strip() if prompt else "" + if not cleaned and reference_image is None: + raise ValueError("At least one of 'prompt' or 'reference_image' must be provided.") + if cleaned: + validate_string(cleaned, strip_whitespace=False, max_length=_MAX_PROMPT_LEN) + return cleaned or None + + +async def _upload_mask_as_image( + cls: type[IO.ComfyNode], + mask: Input.Image, + *, + wait_label: str, +) -> str: + """Encode a single-frame MASK (H, W) or (1, H, W) as a PNG and upload.""" + if mask.dim() == 2: + mask = mask.unsqueeze(0) + image = convert_mask_to_image(mask[:1]) + return await upload_image_to_comfyapi( + cls, + image, + mime_type="image/png", + wait_label=wait_label, + total_pixels=_MAX_PIXELS, + ) + + +async def _upload_mask_batch_as_video( + cls: type[IO.ComfyNode], + mask: Input.Image, + *, + frame_rate: Fraction, + source_frame_count: int, + wait_label: str, +) -> str: + """Encode a MASK batch (N, H, W) as a grayscale H.264 MP4 at frame_rate and upload. + + The matte is always downscaled to the pixel budget so it stays within Beeble's limit and + keeps the same dimensions as the (similarly downscaled) source — both use the same algorithm + from the same starting dimensions, and downscaling is a no-op when already within budget. + """ + if mask.dim() == 2: + mask = mask.unsqueeze(0) + if mask.shape[0] != source_frame_count: + raise ValueError( + f"Custom alpha video frame count ({mask.shape[0]}) does not match the " + f"source video frame count ({source_frame_count}). The Beeble API requires " + "one mask per source frame." + ) + images = downscale_image_tensor(convert_mask_to_image(mask), _MAX_PIXELS) + alpha_video = InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=None, frame_rate=frame_rate)) + return await upload_video_to_comfyapi(cls, alpha_video, wait_label=wait_label) + + +def _alpha_mode_input(*, video: bool) -> IO.DynamicCombo.Input: + """Build the alpha_mode DynamicCombo with mode-specific extra inputs.""" + select_keyframe_tooltip = ( + "First-frame keyframe mask. Beeble propagates this across the video." if video else "Grayscale keyframe mask." + ) + custom_tooltip = ( + "Per-frame grayscale mask covering the entire video. " + "Must have the same frame count as the source. " + "Connect a MASK output from SAM3_TrackToMask or similar." + if video + else "Grayscale mask to apply." + ) + return IO.DynamicCombo.Input( + "alpha_mode", + tooltip=( + "Controls how SwitchX decides what to keep vs. regenerate. " + "'auto' isolates the main subject automatically. " + "'fill' regenerates the entire frame while preserving geometry. " + "'select' propagates a first-frame keyframe across the clip. " + "'custom' uses a per-frame alpha matte you provide." + ), + options=[ + IO.DynamicCombo.Option("auto", []), + IO.DynamicCombo.Option("fill", []), + IO.DynamicCombo.Option( + "select", + [IO.Mask.Input("alpha_keyframe", tooltip=select_keyframe_tooltip)], + ), + IO.DynamicCombo.Option( + "custom", + [IO.Mask.Input("alpha_mask", tooltip=custom_tooltip)], + ), + ], + ) + + +def _common_inputs(*, source: IO.Input, video: bool) -> list[IO.Input]: + return [ + source, + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip=( + "Text description of the desired output (max 2000 chars). " + "At least one of 'prompt' or 'reference_image' is required." + ), + ), + IO.Image.Input( + "reference_image", + optional=True, + tooltip=( + "Reference image whose look (background, lighting, costume) the result " + "should adopt. At least one of 'reference_image' or 'prompt' is required." + ), + ), + _alpha_mode_input(video=video), + IO.Combo.Input( + "max_resolution", + options=["1080p", "720p"], + default="1080p", + tooltip="Maximum output resolution.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip=( + "Seed controls whether the node should re-run; " "results are non-deterministic regardless of seed." + ), + ), + ] + + +async def _submit_and_poll( + cls: type[IO.ComfyNode], + request: CreateSwitchXRequest, +) -> SwitchXStatusResponse: + initial = await sync_op( + cls, + ApiEndpoint(path="/proxy/beeble/v1/switchx/generations", method="POST"), + response_model=SwitchXStatusResponse, + data=request, + ) + return await poll_op( + cls, + ApiEndpoint(path=f"/proxy/beeble/v1/switchx/generations/{initial.id}"), + response_model=SwitchXStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + + +def _require_output_url(response: SwitchXStatusResponse, name: str) -> str: + if response.output is None or getattr(response.output, name) is None: + raise RuntimeError(f"Beeble job {response.id} completed without a {name!r} output URL.") + return getattr(response.output, name) + + +def _alpha_url(response: SwitchXStatusResponse, mode: str) -> str | None: + """URL of the alpha matte, or None when the mode produces no separate matte. + + 'fill' selects the whole frame, so Beeble writes no alpha asset even though the status + response still returns a (dangling) signed URL for it — fetching it 403s with S3 + AccessDenied. The other three modes ('auto', 'custom', 'select') all produce a real, + downloadable matte. + """ + if mode == "fill" or response.output is None: + return None + return response.output.alpha + + +class BeebleSwitchXVideoEdit(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="BeebleSwitchXVideoEdit", + display_name="Beeble SwitchX Video Edit", + category="video/partner/Beeble", + description=( + "Edit a video with Beeble SwitchX. Switches anything in the scene (background, " + "lighting, costume) while preserving the original subject's pixels and motion. " + "Provide a reference image and/or text prompt to describe the new look. " + "Max 240 frames, max ~2.77MP per frame." + ), + inputs=_common_inputs(source=IO.Video.Input("video"), video=True), + outputs=[ + IO.Video.Output(display_name="video"), + IO.Video.Output( + display_name="alpha", + tooltip="The alpha matte Beeble used. Empty for 'fill' mode, which has no separate matte.", + ), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["max_resolution"]), + expr=""" + ( + $rate := widgets.max_resolution = "1080p" ? 0.429 : 0.143; + {"type":"usd","usd": $rate, "format":{"suffix":"/30 frames"}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + prompt: str, + alpha_mode: dict, + max_resolution: str, + seed: int, + reference_image: Input.Image | None = None, + ) -> IO.NodeOutput: + cleaned_prompt = _validate_inputs(prompt, reference_image) + + validate_video_frame_count(video, max_frame_count=_MAX_FRAMES) + video = downscale_video_to_max_pixels(video, _MAX_PIXELS) + + mode = alpha_mode["alpha_mode"] + alpha_uri: str | None = None + if mode == "select": + alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_keyframe"], wait_label="Uploading keyframe") + elif mode == "custom": + alpha_uri = await _upload_mask_batch_as_video( + cls, + alpha_mode["alpha_mask"], + frame_rate=video.get_frame_rate(), + source_frame_count=video.get_frame_count(), + wait_label="Uploading alpha video", + ) + + source_uri = await upload_video_to_comfyapi(cls, video, wait_label="Uploading source") + reference_uri: str | None = None + if reference_image is not None: + reference_uri = await upload_image_to_comfyapi( + cls, + reference_image, + mime_type="image/png", + wait_label="Uploading reference", + total_pixels=_MAX_PIXELS, + ) + + request = CreateSwitchXRequest( + generation_type="video", + source_uri=source_uri, + alpha_mode=mode, + prompt=cleaned_prompt, + reference_image_uri=reference_uri, + alpha_uri=alpha_uri, + max_resolution=1080 if max_resolution == "1080p" else 720, + ) + response = await _submit_and_poll(cls, request) + + render = await download_url_to_video_output(_require_output_url(response, "render")) + alpha = None + if (alpha_url := _alpha_url(response, mode)) is not None: + alpha = await download_url_to_video_output(alpha_url) + return IO.NodeOutput(render, alpha) + + +class BeebleSwitchXImageEdit(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="BeebleSwitchXImageEdit", + display_name="Beeble SwitchX Image Edit", + category="image/partner/Beeble", + description=( + "Edit a single image with Beeble SwitchX. Switches anything in the scene " + "(background, lighting, costume) while preserving the original subject's pixels. " + "Provide a reference image and/or text prompt to describe the new look. " + "Max ~2.77MP." + ), + inputs=_common_inputs(source=IO.Image.Input("image"), video=False), + outputs=[ + IO.Image.Output(display_name="image"), + IO.Mask.Output( + display_name="alpha", + tooltip="The alpha matte Beeble used. Empty for 'fill' mode, which has no separate matte.", + ), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["max_resolution"]), + expr=""" + ( + $rate := widgets.max_resolution = "1080p" ? 0.429 : 0.143; + {"type":"usd","usd": $rate} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + prompt: str, + alpha_mode: dict, + max_resolution: str, + seed: int, + reference_image: Input.Image | None = None, + ) -> IO.NodeOutput: + cleaned_prompt = _validate_inputs(prompt, reference_image) + + image = downscale_image_tensor(image, _MAX_PIXELS) + + mode = alpha_mode["alpha_mode"] + alpha_uri: str | None = None + if mode == "select": + alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_keyframe"], wait_label="Uploading keyframe") + elif mode == "custom": + alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_mask"], wait_label="Uploading alpha") + + source_uri = await upload_image_to_comfyapi( + cls, + image, + mime_type="image/png", + wait_label="Uploading source", + total_pixels=None, + ) + reference_uri: str | None = None + if reference_image is not None: + reference_uri = await upload_image_to_comfyapi( + cls, + reference_image, + mime_type="image/png", + wait_label="Uploading reference", + total_pixels=_MAX_PIXELS, + ) + + request = CreateSwitchXRequest( + generation_type="image", + source_uri=source_uri, + alpha_mode=mode, + prompt=cleaned_prompt, + reference_image_uri=reference_uri, + alpha_uri=alpha_uri, + max_resolution=1080 if max_resolution == "1080p" else 720, + ) + response = await _submit_and_poll(cls, request) + + render = await download_url_to_image_tensor(_require_output_url(response, "render")) + alpha_mask = None + if (alpha_url := _alpha_url(response, mode)) is not None: + alpha_image = bytesio_to_image_tensor(await download_url_as_bytesio(alpha_url), mode="L") + alpha_mask = alpha_image.squeeze(-1) if alpha_image.dim() == 4 else alpha_image + return IO.NodeOutput(render, alpha_mask) + + +class BeebleExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + BeebleSwitchXVideoEdit, + BeebleSwitchXImageEdit, + ] + + +async def comfy_entrypoint() -> BeebleExtension: + return BeebleExtension() diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 3f0ce29d8..f1a5dc5f0 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -42,7 +42,7 @@ class FluxProUltraImageNode(IO.ComfyNode): return IO.Schema( node_id="FluxProUltraImageNode", display_name="Flux 1.1 [pro] Ultra Image", - category="api node/image/BFL", + category="image/partner/BFL", description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.", inputs=[ IO.String.Input( @@ -160,7 +160,7 @@ class FluxKontextProImageNode(IO.ComfyNode): return IO.Schema( node_id=cls.NODE_ID, display_name=cls.DISPLAY_NAME, - category="api node/image/BFL", + category="image/partner/BFL", description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.", inputs=[ IO.String.Input( @@ -282,7 +282,7 @@ class FluxProExpandNode(IO.ComfyNode): return IO.Schema( node_id="FluxProExpandNode", display_name="Flux.1 Expand Image", - category="api node/image/BFL", + category="image/partner/BFL", description="Outpaints image based on prompt.", inputs=[ IO.Image.Input("image"), @@ -419,7 +419,7 @@ class FluxProFillNode(IO.ComfyNode): return IO.Schema( node_id="FluxProFillNode", display_name="Flux.1 Fill Image", - category="api node/image/BFL", + category="image/partner/BFL", description="Inpaints image based on mask and prompt.", inputs=[ IO.Image.Input("image"), @@ -545,7 +545,7 @@ class Flux2ProImageNode(IO.ComfyNode): return IO.Schema( node_id=cls.NODE_ID, display_name=cls.DISPLAY_NAME, - category="api node/image/BFL", + category="image/partner/BFL", description="Generates images synchronously based on prompt and resolution.", inputs=[ IO.String.Input( @@ -716,7 +716,7 @@ class Flux2ImageNode(IO.ComfyNode): return IO.Schema( node_id="Flux2ImageNode", display_name="Flux.2 Image", - category="api node/image/BFL", + category="image/partner/BFL", description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py index 4044ee3ea..53e763210 100644 --- a/comfy_api_nodes/nodes_bria.py +++ b/comfy_api_nodes/nodes_bria.py @@ -31,7 +31,7 @@ class BriaImageEditNode(IO.ComfyNode): return IO.Schema( node_id="BriaImageEditNode", display_name="Bria FIBO Image Edit", - category="api node/image/Bria", + category="image/partner/Bria", description="Edit images using Bria latest model", inputs=[ IO.Combo.Input("model", options=["FIBO"]), @@ -169,7 +169,7 @@ class BriaRemoveImageBackground(IO.ComfyNode): return IO.Schema( node_id="BriaRemoveImageBackground", display_name="Bria Remove Image Background", - category="api node/image/Bria", + category="image/partner/Bria", description="Remove the background from an image using Bria RMBG 2.0.", inputs=[ IO.Image.Input("image"), @@ -245,7 +245,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode): return IO.Schema( node_id="BriaRemoveVideoBackground", display_name="Bria Remove Video Background", - category="api node/video/Bria", + category="video/partner/Bria", description="Remove the background from a video using Bria. ", inputs=[ IO.Video.Input("video"), diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index e08fc0b01..3711bac1d 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -2,11 +2,12 @@ import hashlib import logging import math import re +from io import BytesIO import torch from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api.latest import IO, ComfyExtension, Input, Types from comfy_api_nodes.apis.bytedance import ( RECOMMENDED_PRESETS, RECOMMENDED_PRESETS_SEEDREAM_4, @@ -43,6 +44,7 @@ from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, download_url_to_video_output, + downscale_image_tensor_by_max_side, downscale_video_to_max_pixels, get_number_of_images, image_tensor_pair_to_batch, @@ -121,6 +123,14 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st ) +def _prepare_seedance_image(image: Input.Image) -> Input.Image: + """Auto-downscale a Seedance image input to the per-side limits, then validate it.""" + validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + image = downscale_image_tensor_by_max_side(image, max_side=6000) + validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) + return image + + async def _resolve_reference_assets( cls: type[IO.ComfyNode], asset_ids: list[str], @@ -308,6 +318,26 @@ async def _seedance_virtual_library_upload_image_asset( return f"asset://{create_resp.asset_id}" +async def _seedance_virtual_library_upload_video_asset( + cls: type[IO.ComfyNode], + video: Input.Video, + *, + wait_label: str = "Uploading video", +) -> str: + buf = BytesIO() + video.save_to(buf, format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264) + video_hash = hashlib.sha256(buf.getbuffer()).hexdigest() + public_url = await upload_video_to_comfyapi(cls, video, wait_label=wait_label) + create_resp = await sync_op( + cls, + ApiEndpoint(path="/proxy/seedance/virtual-library/assets", method="POST"), + response_model=SeedanceCreateAssetResponse, + data=SeedanceVirtualLibraryCreateAssetRequest(url=public_url, hash=video_hash, asset_type="Video"), + ) + await _wait_for_asset_active(cls, create_resp.asset_id, group_id="virtual-library") + return f"asset://{create_resp.asset_id}" + + def _seedance2_price_extractor(model_id: str, has_video_input: bool): """Returns a price_extractor closure for Seedance 2.0 poll_op.""" rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input)) @@ -338,7 +368,7 @@ class ByteDanceImageNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceImageNode", display_name="ByteDance Image", - category="api node/image/ByteDance", + category="image/partner/ByteDance", description="Generate images using ByteDance models via api based on prompt", inputs=[ IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]), @@ -462,7 +492,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceSeedreamNode", display_name="ByteDance Seedream 4.5 & 5.0", - category="api node/image/ByteDance", + category="image/partner/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ IO.Combo.Input( @@ -724,7 +754,7 @@ class ByteDanceSeedreamNodeV2(IO.ComfyNode): return IO.Schema( node_id="ByteDanceSeedreamNodeV2", display_name="ByteDance Seedream 4.5 & 5.0", - category="api node/image/ByteDance", + category="image/partner/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ IO.String.Input( @@ -890,7 +920,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceTextToVideoNode", display_name="ByteDance Text to Video", - category="api node/video/ByteDance", + category="video/partner/ByteDance", description="Generate video using ByteDance models via api based on prompt", inputs=[ IO.Combo.Input( @@ -1018,7 +1048,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceImageToVideoNode", display_name="ByteDance Image to Video", - category="api node/video/ByteDance", + category="video/partner/ByteDance", description="Generate video using ByteDance models via api based on image and prompt", inputs=[ IO.Combo.Input( @@ -1155,7 +1185,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceFirstLastFrameNode", display_name="ByteDance First-Last-Frame to Video", - category="api node/video/ByteDance", + category="video/partner/ByteDance", description="Generate video using prompt and first and last frames.", inputs=[ IO.Combo.Input( @@ -1303,7 +1333,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceImageReferenceNode", display_name="ByteDance Reference Images to Video", - category="api node/video/ByteDance", + category="video/partner/ByteDance", description="Generate video using prompt and reference images.", inputs=[ IO.Combo.Input( @@ -1546,7 +1576,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ByteDance2TextToVideoNode", display_name="ByteDance Seedance 2.0 Text to Video", - category="api node/video/ByteDance", + category="video/partner/ByteDance", description="Generate video using Seedance 2.0 models based on a text prompt.", inputs=[ IO.DynamicCombo.Input( @@ -1647,7 +1677,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="ByteDance2FirstLastFrameNode", display_name="ByteDance Seedance 2.0 First-Last-Frame to Video", - category="api node/video/ByteDance", + category="video/partner/ByteDance", description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.", inputs=[ IO.DynamicCombo.Input( @@ -1760,6 +1790,11 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): if last_frame is not None and last_frame_asset_id: raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.") + if first_frame is not None: + first_frame = _prepare_seedance_image(first_frame) + if last_frame is not None: + last_frame = _prepare_seedance_image(last_frame) + asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a] image_assets: dict[str, str] = {} if asset_ids_to_resolve: @@ -1866,7 +1901,7 @@ def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16 ), IO.Boolean.Input( "auto_downscale", - default=False, + default=True, optional=True, tooltip="Automatically downscale reference videos that exceed the model's pixel budget " "for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.", @@ -1909,7 +1944,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode): return IO.Schema( node_id="ByteDance2ReferenceNode", display_name="ByteDance Seedance 2.0 Reference to Video", - category="api node/video/ByteDance", + category="video/partner/ByteDance", description="Generate, edit, or extend video using Seedance 2.0 with reference images, " "videos, and audio. Supports multimodal reference, video editing, and video extension.", inputs=[ @@ -2034,6 +2069,9 @@ class ByteDance2ReferenceNode(IO.ComfyNode): f"(audios={len(reference_audios)}, audio assets={len(reference_audio_assets)}). Maximum is 3." ) + for key in reference_images: + reference_images[key] = _prepare_seedance_image(reference_images[key]) + model_id = SEEDANCE_MODELS[model["model"]] has_video_input = total_videos > 0 @@ -2106,7 +2144,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode): content.append( TaskVideoContent( video_url=TaskVideoContentUrl( - url=await upload_video_to_comfyapi( + url=await _seedance_virtual_library_upload_video_asset( cls, reference_videos[key], wait_label=f"Uploading video {i}", @@ -2203,7 +2241,7 @@ class ByteDanceCreateImageAsset(IO.ComfyNode): return IO.Schema( node_id="ByteDanceCreateImageAsset", display_name="ByteDance Create Image Asset", - category="api node/image/ByteDance", + category="image/partner/ByteDance", description=( "Create a Seedance 2.0 personal image asset. Uploads the input image and " "registers it in the given asset group. If group_id is empty, runs a real-person " @@ -2270,7 +2308,7 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode): return IO.Schema( node_id="ByteDanceCreateVideoAsset", display_name="ByteDance Create Video Asset", - category="api node/video/ByteDance", + category="video/partner/ByteDance", description=( "Create a Seedance 2.0 personal video asset. Uploads the input video and " "registers it in the given asset group. If group_id is empty, runs a real-person " diff --git a/comfy_api_nodes/nodes_bytedance_llm.py b/comfy_api_nodes/nodes_bytedance_llm.py index fa7fe370a..007cac45f 100644 --- a/comfy_api_nodes/nodes_bytedance_llm.py +++ b/comfy_api_nodes/nodes_bytedance_llm.py @@ -144,7 +144,7 @@ class ByteDanceSeedNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceSeedNode", display_name="ByteDance Seed", - category="api node/text/ByteDance", + category="text/partner/ByteDance", essentials_category="Text Generation", description="Generate text responses with ByteDance's Seed 2.0 models. " "Provide a text prompt and optionally one or more images or videos for multimodal context.", diff --git a/comfy_api_nodes/nodes_elevenlabs.py b/comfy_api_nodes/nodes_elevenlabs.py index e452daf77..37eeb2601 100644 --- a/comfy_api_nodes/nodes_elevenlabs.py +++ b/comfy_api_nodes/nodes_elevenlabs.py @@ -69,7 +69,7 @@ class ElevenLabsSpeechToText(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsSpeechToText", display_name="ElevenLabs Speech to Text", - category="api node/audio/ElevenLabs", + category="audio/partner/ElevenLabs", description="Transcribe audio to text. " "Supports automatic language detection, speaker diarization, and audio event tagging.", inputs=[ @@ -210,7 +210,7 @@ class ElevenLabsVoiceSelector(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsVoiceSelector", display_name="ElevenLabs Voice Selector", - category="api node/audio/ElevenLabs", + category="audio/partner/ElevenLabs", description="Select a predefined ElevenLabs voice for text-to-speech generation.", inputs=[ IO.Combo.Input( @@ -239,7 +239,7 @@ class ElevenLabsTextToSpeech(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsTextToSpeech", display_name="ElevenLabs Text to Speech", - category="api node/audio/ElevenLabs", + category="audio/partner/ElevenLabs", description="Convert text to speech.", inputs=[ IO.Custom(ELEVENLABS_VOICE).Input( @@ -414,7 +414,7 @@ class ElevenLabsAudioIsolation(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsAudioIsolation", display_name="ElevenLabs Voice Isolation", - category="api node/audio/ElevenLabs", + category="audio/partner/ElevenLabs", description="Remove background noise from audio, isolating vocals or speech.", inputs=[ IO.Audio.Input( @@ -459,7 +459,7 @@ class ElevenLabsTextToSoundEffects(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsTextToSoundEffects", display_name="ElevenLabs Text to Sound Effects", - category="api node/audio/ElevenLabs", + category="audio/partner/ElevenLabs", description="Generate sound effects from text descriptions.", inputs=[ IO.String.Input( @@ -555,7 +555,7 @@ class ElevenLabsInstantVoiceClone(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsInstantVoiceClone", display_name="ElevenLabs Instant Voice Clone", - category="api node/audio/ElevenLabs", + category="audio/partner/ElevenLabs", description="Create a cloned voice from audio samples. " "Provide 1-8 audio recordings of the voice to clone.", inputs=[ @@ -658,7 +658,7 @@ class ElevenLabsSpeechToSpeech(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsSpeechToSpeech", display_name="ElevenLabs Speech to Speech", - category="api node/audio/ElevenLabs", + category="audio/partner/ElevenLabs", description="Transform speech from one voice to another while preserving the original content and emotion.", inputs=[ IO.Custom(ELEVENLABS_VOICE).Input( @@ -793,7 +793,7 @@ class ElevenLabsTextToDialogue(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsTextToDialogue", display_name="ElevenLabs Text to Dialogue", - category="api node/audio/ElevenLabs", + category="audio/partner/ElevenLabs", description="Generate multi-speaker dialogue from text. Each dialogue entry has its own text and voice.", inputs=[ IO.Float.Input( diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index d18c958a8..3cfd541b2 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -300,7 +300,7 @@ class GeminiNode(IO.ComfyNode): return IO.Schema( node_id="GeminiNode", display_name="Google Gemini", - category="api node/text/Gemini", + category="text/partner/Gemini", description="Generate text responses with Google's Gemini AI model. " "You can provide multiple types of inputs (text, images, audio, video) " "as context for generating more relevant and meaningful responses.", @@ -541,7 +541,7 @@ class GeminiInputFiles(IO.ComfyNode): return IO.Schema( node_id="GeminiInputFiles", display_name="Gemini Input Files", - category="api node/text/Gemini", + category="text/partner/Gemini", description="Loads and prepares input files to include as inputs for Gemini LLM nodes. " "The files will be read by the Gemini model when generating a response. " "The contents of the text file count toward the token limit. " @@ -598,7 +598,7 @@ class GeminiImage(IO.ComfyNode): return IO.Schema( node_id="GeminiImageNode", display_name="Nano Banana (Google Gemini Image)", - category="api node/image/Gemini", + category="image/partner/Gemini", description="Edit images synchronously via Google API.", inputs=[ IO.String.Input( @@ -731,7 +731,7 @@ class GeminiImage2(IO.ComfyNode): return IO.Schema( node_id="GeminiImage2Node", display_name="Nano Banana Pro (Google Gemini Image)", - category="api node/image/Gemini", + category="image/partner/Gemini", description="Generate or edit images synchronously via Google Vertex API.", inputs=[ IO.String.Input( @@ -869,7 +869,7 @@ class GeminiNanoBanana2(IO.ComfyNode): return IO.Schema( node_id="GeminiNanoBanana2", display_name="Nano Banana 2", - category="api node/image/Gemini", + category="image/partner/Gemini", description="Generate or edit images synchronously via Google Vertex API.", inputs=[ IO.String.Input( @@ -1085,7 +1085,7 @@ class GeminiNanoBanana2V2(IO.ComfyNode): return IO.Schema( node_id="GeminiNanoBanana2V2", display_name="Nano Banana 2", - category="api node/image/Gemini", + category="image/partner/Gemini", description="Generate or edit images synchronously via Google Vertex API.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index a103f24ee..43e3cdc26 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -49,7 +49,7 @@ class GrokImageNode(IO.ComfyNode): return IO.Schema( node_id="GrokImageNode", display_name="Grok Image", - category="api node/image/Grok", + category="image/partner/Grok", description="Generate images using Grok based on a text prompt", inputs=[ IO.Combo.Input( @@ -224,7 +224,7 @@ class GrokImageEditNode(IO.ComfyNode): return IO.Schema( node_id="GrokImageEditNode", display_name="Grok Image Edit", - category="api node/image/Grok", + category="image/partner/Grok", description="Modify an existing image based on a text prompt", inputs=[ IO.Combo.Input( @@ -366,7 +366,7 @@ class GrokImageEditNodeV2(IO.ComfyNode): return IO.Schema( node_id="GrokImageEditNodeV2", display_name="Grok Image Edit", - category="api node/image/Grok", + category="image/partner/Grok", description="Modify an existing image based on a text prompt", inputs=[ IO.String.Input( @@ -503,7 +503,7 @@ class GrokVideoNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoNode", display_name="Grok Video", - category="api node/video/Grok", + category="video/partner/Grok", description="Generate video from a prompt or an image", inputs=[ IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), @@ -615,7 +615,7 @@ class GrokVideoEditNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoEditNode", display_name="Grok Video Edit", - category="api node/video/Grok", + category="video/partner/Grok", description="Edit an existing video based on a text prompt.", inputs=[ IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), @@ -693,7 +693,7 @@ class GrokVideoReferenceNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoReferenceNode", display_name="Grok Reference-to-Video", - category="api node/video/Grok", + category="video/partner/Grok", description="Generate video guided by reference images as style and content references.", inputs=[ IO.String.Input( @@ -826,7 +826,7 @@ class GrokVideoExtendNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoExtendNode", display_name="Grok Video Extend", - category="api node/video/Grok", + category="video/partner/Grok", description="Extend an existing video with a seamless continuation based on a text prompt.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_hitpaw.py b/comfy_api_nodes/nodes_hitpaw.py index bca5170e4..22e679c29 100644 --- a/comfy_api_nodes/nodes_hitpaw.py +++ b/comfy_api_nodes/nodes_hitpaw.py @@ -71,7 +71,7 @@ class HitPawGeneralImageEnhance(IO.ComfyNode): return IO.Schema( node_id="HitPawGeneralImageEnhance", display_name="HitPaw General Image Enhance", - category="api node/image/HitPaw", + category="image/partner/HitPaw", description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. " f"Maximum output: {MAX_MP_GENERATIVE} megapixels.", inputs=[ @@ -201,7 +201,7 @@ class HitPawVideoEnhance(IO.ComfyNode): return IO.Schema( node_id="HitPawVideoEnhance", display_name="HitPaw Video Enhance", - category="api node/video/HitPaw", + category="video/partner/HitPaw", description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. " "Prices shown are per second of video.", inputs=[ diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py index 5fc31bccd..826a3bd2d 100644 --- a/comfy_api_nodes/nodes_hunyuan3d.py +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -123,7 +123,7 @@ class TencentTextToModelNode(IO.ComfyNode): return IO.Schema( node_id="TencentTextToModelNode", display_name="Hunyuan3D: Text to Model", - category="api node/3d/Tencent", + category="3d/partner/Tencent", essentials_category="3D", inputs=[ IO.Combo.Input( @@ -242,7 +242,7 @@ class TencentImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="TencentImageToModelNode", display_name="Hunyuan3D: Image(s) to Model", - category="api node/3d/Tencent", + category="3d/partner/Tencent", essentials_category="3D", inputs=[ IO.Combo.Input( @@ -415,7 +415,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode): return IO.Schema( node_id="TencentModelTo3DUVNode", display_name="Hunyuan3D: Model to UV", - category="api node/3d/Tencent", + category="3d/partner/Tencent", description="Perform UV unfolding on a 3D model to generate UV texture. " "Input model must have less than 30000 faces.", inputs=[ @@ -505,7 +505,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode): return IO.Schema( node_id="Tencent3DTextureEditNode", display_name="Hunyuan3D: 3D Texture Edit", - category="api node/3d/Tencent", + category="3d/partner/Tencent", description="After inputting the 3D model, perform 3D model texture redrawing.", inputs=[ IO.MultiType.Input( @@ -594,7 +594,7 @@ class Tencent3DPartNode(IO.ComfyNode): return IO.Schema( node_id="Tencent3DPartNode", display_name="Hunyuan3D: 3D Part", - category="api node/3d/Tencent", + category="3d/partner/Tencent", description="Automatically perform component identification and generation based on the model structure.", inputs=[ IO.MultiType.Input( @@ -666,7 +666,7 @@ class TencentSmartTopologyNode(IO.ComfyNode): return IO.Schema( node_id="TencentSmartTopologyNode", display_name="Hunyuan3D: Smart Topology", - category="api node/3d/Tencent", + category="3d/partner/Tencent", description="Perform smart retopology on a 3D model. " "Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.", inputs=[ diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index 97c3609bd..edd9b9435 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -234,7 +234,7 @@ class IdeogramV1(IO.ComfyNode): return IO.Schema( node_id="IdeogramV1", display_name="Ideogram V1", - category="api node/image/Ideogram", + category="image/partner/Ideogram", description="Generates images using the Ideogram V1 model.", inputs=[ IO.String.Input( @@ -360,7 +360,7 @@ class IdeogramV2(IO.ComfyNode): return IO.Schema( node_id="IdeogramV2", display_name="Ideogram V2", - category="api node/image/Ideogram", + category="image/partner/Ideogram", description="Generates images using the Ideogram V2 model.", inputs=[ IO.String.Input( @@ -526,7 +526,7 @@ class IdeogramV3(IO.ComfyNode): return IO.Schema( node_id="IdeogramV3", display_name="Ideogram V3", - category="api node/image/Ideogram", + category="image/partner/Ideogram", description="Generates images using the Ideogram V3 model. " "Supports both regular image generation from text prompts and image editing with mask.", inputs=[ diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 7586f1816..9925ec548 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -642,7 +642,7 @@ class KlingCameraControls(IO.ComfyNode): return IO.Schema( node_id="KlingCameraControls", display_name="Kling Camera Controls", - category="api node/video/Kling", + category="video/partner/Kling", description="Allows specifying configuration options for Kling Camera Controls and motion control effects.", inputs=[ IO.Combo.Input("camera_control_type", options=KlingCameraControlType), @@ -762,7 +762,7 @@ class KlingTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingTextToVideoNode", display_name="Kling Text to Video", - category="api node/video/Kling", + category="video/partner/Kling", description="Kling Text to Video Node", inputs=[ IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -849,7 +849,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProTextToVideoNode", display_name="Kling 3.0 Omni Text to Video", - category="api node/video/Kling", + category="video/partner/Kling", description="Use text prompts to generate videos with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -998,7 +998,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProFirstLastFrameNode", display_name="Kling 3.0 Omni First-Last-Frame to Video", - category="api node/video/Kling", + category="video/partner/Kling", description="Use a start frame, an optional end frame, or reference images with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -1205,7 +1205,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProImageToVideoNode", display_name="Kling 3.0 Omni Image to Video", - category="api node/video/Kling", + category="video/partner/Kling", description="Use up to 7 reference images to generate a video with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -1374,7 +1374,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProVideoToVideoNode", display_name="Kling 3.0 Omni Video to Video", - category="api node/video/Kling", + category="video/partner/Kling", description="Use a video and up to 4 reference images to generate a video with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -1485,7 +1485,7 @@ class OmniProEditVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProEditVideoNode", display_name="Kling 3.0 Omni Edit Video", - category="api node/video/Kling", + category="video/partner/Kling", essentials_category="Video Generation", description="Edit an existing video with the latest model from Kling.", inputs=[ @@ -1593,7 +1593,7 @@ class OmniProImageNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProImageNode", display_name="Kling 3.0 Omni Image", - category="api node/image/Kling", + category="image/partner/Kling", description="Create or edit images with the latest model from Kling.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]), @@ -1721,7 +1721,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode): return IO.Schema( node_id="KlingCameraControlT2VNode", display_name="Kling Text to Video (Camera Control)", - category="api node/video/Kling", + category="video/partner/Kling", description="Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text.", inputs=[ IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -1783,7 +1783,7 @@ class KlingImage2VideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingImage2VideoNode", display_name="Kling Image(First Frame) to Video", - category="api node/video/Kling", + category="video/partner/Kling", inputs=[ IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -1882,7 +1882,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode): return IO.Schema( node_id="KlingCameraControlI2VNode", display_name="Kling Image to Video (Camera Control)", - category="api node/video/Kling", + category="video/partner/Kling", description="Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image.", inputs=[ IO.Image.Input( @@ -1953,7 +1953,7 @@ class KlingStartEndFrameNode(IO.ComfyNode): return IO.Schema( node_id="KlingStartEndFrameNode", display_name="Kling Start-End Frame to Video", - category="api node/video/Kling", + category="video/partner/Kling", description="Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last.", inputs=[ IO.Image.Input( @@ -2047,7 +2047,7 @@ class KlingVideoExtendNode(IO.ComfyNode): return IO.Schema( node_id="KlingVideoExtendNode", display_name="Kling Video Extend", - category="api node/video/Kling", + category="video/partner/Kling", description="Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes.", inputs=[ IO.String.Input( @@ -2128,7 +2128,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode): return IO.Schema( node_id="KlingDualCharacterVideoEffectNode", display_name="Kling Dual Character Video Effects", - category="api node/video/Kling", + category="video/partner/Kling", description="Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite.", inputs=[ IO.Image.Input("image_left", tooltip="Left side image"), @@ -2218,7 +2218,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode): return IO.Schema( node_id="KlingSingleImageVideoEffectNode", display_name="Kling Video Effects", - category="api node/video/Kling", + category="video/partner/Kling", description="Achieve different special effects when generating a video based on the effect_scene.", inputs=[ IO.Image.Input( @@ -2291,7 +2291,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingLipSyncAudioToVideoNode", display_name="Kling Lip Sync Video with Audio", - category="api node/video/Kling", + category="video/partner/Kling", essentials_category="Video Generation", description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", inputs=[ @@ -2343,7 +2343,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingLipSyncTextToVideoNode", display_name="Kling Lip Sync Video with Text", - category="api node/video/Kling", + category="video/partner/Kling", description="Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", inputs=[ IO.Video.Input("video"), @@ -2411,7 +2411,7 @@ class KlingVirtualTryOnNode(IO.ComfyNode): return IO.Schema( node_id="KlingVirtualTryOnNode", display_name="Kling Virtual Try On", - category="api node/image/Kling", + category="image/partner/Kling", description="Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background.", inputs=[ IO.Image.Input("human_image"), @@ -2478,7 +2478,7 @@ class KlingImageGenerationNode(IO.ComfyNode): return IO.Schema( node_id="KlingImageGenerationNode", display_name="Kling 3.0 Image", - category="api node/image/Kling", + category="image/partner/Kling", description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.", inputs=[ IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -2615,7 +2615,7 @@ class TextToVideoWithAudio(IO.ComfyNode): return IO.Schema( node_id="KlingTextToVideoWithAudio", display_name="Kling 2.6 Text to Video with Audio", - category="api node/video/Kling", + category="video/partner/Kling", inputs=[ IO.Combo.Input("model_name", options=["kling-v2-6"]), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."), @@ -2683,7 +2683,7 @@ class ImageToVideoWithAudio(IO.ComfyNode): return IO.Schema( node_id="KlingImageToVideoWithAudio", display_name="Kling 2.6 Image(First Frame) to Video with Audio", - category="api node/video/Kling", + category="video/partner/Kling", inputs=[ IO.Combo.Input("model_name", options=["kling-v2-6"]), IO.Image.Input("start_frame"), @@ -2753,7 +2753,7 @@ class MotionControl(IO.ComfyNode): return IO.Schema( node_id="KlingMotionControl", display_name="Kling Motion Control", - category="api node/video/Kling", + category="video/partner/Kling", inputs=[ IO.String.Input("prompt", multiline=True), IO.Image.Input("reference_image"), @@ -2854,7 +2854,7 @@ class KlingVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingVideoNode", display_name="Kling 3.0 Video", - category="api node/video/Kling", + category="video/partner/Kling", description="Generate videos with Kling V3. " "Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.", inputs=[ @@ -3077,7 +3077,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="KlingFirstLastFrameNode", display_name="Kling 3.0 First-Last-Frame to Video", - category="api node/video/Kling", + category="video/partner/Kling", description="Generate videos with Kling V3 using first and last frames.", inputs=[ IO.String.Input("prompt", multiline=True, default=""), @@ -3202,7 +3202,7 @@ class KlingAvatarNode(IO.ComfyNode): return IO.Schema( node_id="KlingAvatarNode", display_name="Kling Avatar 2.0", - category="api node/video/Kling", + category="video/partner/Kling", description="Generate broadcast-style digital human videos from a single photo and an audio file.", inputs=[ IO.Image.Input( diff --git a/comfy_api_nodes/nodes_krea.py b/comfy_api_nodes/nodes_krea.py new file mode 100644 index 000000000..be04a272b --- /dev/null +++ b/comfy_api_nodes/nodes_krea.py @@ -0,0 +1,290 @@ +"""Krea image-generation nodes.""" + +import re + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.krea import ( + KreaAssetResponse, + KreaGenerateImageRequest, + KreaImageStyleReference, + KreaJob, + KreaMoodboard, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + poll_op, + sync_op, + tensor_to_bytesio, + validate_string, +) + + +class KreaIO: + STYLE_REF = "KREA_STYLE_REF" + + +async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Image) -> str: + """Upload an image to Krea's /assets endpoint and return the Krea-hosted image URL.""" + img_io = tensor_to_bytesio(image, total_pixels=2048 * 2048, mime_type="image/png") + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/krea/assets", method="POST"), + response_model=KreaAssetResponse, + files=[("file", (img_io.name, img_io, "image/png"))], + content_type="multipart/form-data", + max_retries=1, + wait_label="Uploading reference", + ) + return response.image_url + + +_MODEL_MEDIUM = "Krea 2 Medium" +_MODEL_LARGE = "Krea 2 Large" +_MODEL_ENDPOINTS: dict[str, str] = { + _MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium", + _MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large", +} + +_ASPECT_RATIOS = ["1:1", "4:3", "3:2", "16:9", "2.35:1", "4:5", "2:3", "9:16"] +_RESOLUTIONS = ["1K"] +_CREATIVITY_LEVELS = ["raw", "low", "medium", "high"] +_KREA_QUEUED_STATUSES = ["backlogged", "queued", "scheduled"] + +_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$") + + +def _krea_model_inputs() -> list: + """Nested inputs shared by both Krea 2 Medium and Large under the DynamicCombo.""" + return [ + IO.Combo.Input( + "aspect_ratio", + options=_ASPECT_RATIOS, + tooltip="Output aspect ratio.", + ), + IO.Combo.Input( + "resolution", + options=_RESOLUTIONS, + tooltip="Resolution scale.", + ), + IO.Combo.Input( + "creativity", + options=_CREATIVITY_LEVELS, + default="medium", + tooltip="Prompt interpretation strength: raw stays closest to the prompt; high is most creative.", + ), + IO.String.Input( + "moodboard_id", + default="", + tooltip="Optional Krea moodboard UUID (e.g. from the Krea website). " + "Leave empty to disable. Only one moodboard is supported per request.", + optional=True, + ), + IO.Float.Input( + "moodboard_strength", + default=0.35, + min=-0.5, + max=1.5, + step=0.05, + tooltip="Moodboard influence; ignored when moodboard_id is empty.", + optional=True, + ), + IO.Custom(KreaIO.STYLE_REF).Input( + "style_reference", + optional=True, + tooltip="Optional chain of style references (max 10) from Krea 2 Style Reference nodes.", + ), + ] + + +class Krea2ImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Krea2ImageNode", + display_name="Krea 2 Image", + category="image/partner/Krea", + description=( + "Generate images via Krea 2 — pick Medium (expressive illustrations) or " + "Large (expressive photorealism). Supports an optional moodboard and up " + "to 10 chained image style references." + ), + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the image.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()), + IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()), + ], + tooltip="Krea 2 Medium is best for expressive illustrations; " + "Krea 2 Large is best for expressive photorealism.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Random seed for reproducibility.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model", "model.moodboard_id"], + inputs=["model.style_reference"], + ), + expr=""" + ( + $isLarge := widgets.model = "krea 2 large"; + $hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0; + $hasStyle := $lookup(inputs, "model.style_reference").connected; + $usd := $hasMoodboard + ? ($isLarge ? 0.07 : 0.04) + : ($hasStyle + ? ($isLarge ? 0.065 : 0.035) + : ($isLarge ? 0.06 : 0.03)); + {"type":"usd","usd": $usd} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1) + + model_choice = model["model"] + endpoint_path = _MODEL_ENDPOINTS.get(model_choice) + if endpoint_path is None: + raise ValueError(f"Unknown Krea 2 model: {model_choice!r}") + + moodboards: list[KreaMoodboard] | None = None + mb_id = (model.get("moodboard_id") or "").strip() + if mb_id: + if not _UUID_RE.match(mb_id): + raise ValueError(f"moodboard_id must be a UUID (received {mb_id!r}); copy it from the Krea website.") + mb_strength = model.get("moodboard_strength") + moodboards = [KreaMoodboard(id=mb_id, strength=0.35 if mb_strength is None else float(mb_strength))] + + style_reference = model.get("style_reference") + image_style_references: list[KreaImageStyleReference] | None = None + if style_reference: + if len(style_reference) > 10: + raise ValueError(f"Krea 2 accepts at most 10 image_style_references; received {len(style_reference)}.") + image_style_references = [ + KreaImageStyleReference(url=ref["url"], strength=float(ref["strength"])) for ref in style_reference + ] + initial = await sync_op( + cls, + ApiEndpoint(path=endpoint_path, method="POST"), + response_model=KreaJob, + data=KreaGenerateImageRequest( + prompt=prompt, + aspect_ratio=model["aspect_ratio"], + resolution=model["resolution"], + seed=seed, + creativity=model["creativity"], + moodboards=moodboards, + image_style_references=image_style_references, + ), + ) + job = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/krea/jobs/{initial.job_id}", method="GET"), + response_model=KreaJob, + status_extractor=lambda r: r.status, + queued_statuses=_KREA_QUEUED_STATUSES, + ) + if not job.result or not job.result.urls: + raise RuntimeError(f"Krea 2 job {job.job_id} completed without any image URLs.") + image = await download_url_to_image_tensor(job.result.urls[0]) + return IO.NodeOutput(image) + + +class Krea2StyleReferenceNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Krea2StyleReferenceNode", + display_name="Krea 2 Style Reference", + category="image/partner/Krea", + description=( + "Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 " + "Style Reference nodes (max 10) and feed the final `style_reference` output " + "into Krea 2 Image. Each image is uploaded to ComfyAPI storage and passed as URL." + ), + inputs=[ + IO.Image.Input( + "image", + tooltip="Reference image whose style influences the generation.", + ), + IO.Float.Input( + "strength", + default=1.0, + min=-2.0, + max=2.0, + step=0.05, + tooltip="Reference strength; negative values invert the style influence.", + ), + IO.Custom(KreaIO.STYLE_REF).Input( + "style_reference", + optional=True, + tooltip="Optional incoming chain of style references; this node appends one more.", + ), + ], + outputs=[IO.Custom(KreaIO.STYLE_REF).Output(display_name="style_reference")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + strength: float, + style_reference: list[dict] | None = None, + ) -> IO.NodeOutput: + chain: list[dict] = list(style_reference) if style_reference else [] + if len(chain) >= 10: + raise ValueError("Krea 2 accepts at most 10 image_style_references in one generation.") + url = await _upload_image_to_krea_assets(cls, image) + chain.append({"url": url, "strength": float(strength)}) + return IO.NodeOutput(chain) + + +class KreaExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + Krea2ImageNode, + Krea2StyleReferenceNode, + ] + + +async def comfy_entrypoint() -> KreaExtension: + return KreaExtension() diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py index 0a219af96..01791d354 100644 --- a/comfy_api_nodes/nodes_ltxv.py +++ b/comfy_api_nodes/nodes_ltxv.py @@ -50,7 +50,7 @@ class TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="LtxvApiTextToVideo", display_name="LTXV Text To Video", - category="api node/video/LTXV", + category="video/partner/LTXV", description="Professional-quality videos with customizable duration and resolution.", inputs=[ IO.Combo.Input("model", options=list(MODELS_MAP.keys())), @@ -127,7 +127,7 @@ class ImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="LtxvApiImageToVideo", display_name="LTXV Image To Video", - category="api node/video/LTXV", + category="video/partner/LTXV", description="Professional-quality videos with customizable duration and resolution based on start image.", inputs=[ IO.Image.Input("image", tooltip="First frame to be used for the video."), diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index d92a7c382..08ae9904c 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -46,7 +46,7 @@ class LumaReferenceNode(IO.ComfyNode): return IO.Schema( node_id="LumaReferenceNode", display_name="Luma Reference", - category="api node/image/Luma", + category="image/partner/Luma", description="Holds an image and weight for use with Luma Generate Image node.", inputs=[ IO.Image.Input( @@ -85,7 +85,7 @@ class LumaConceptsNode(IO.ComfyNode): return IO.Schema( node_id="LumaConceptsNode", display_name="Luma Concepts", - category="api node/video/Luma", + category="video/partner/Luma", description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.", inputs=[ IO.Combo.Input( @@ -134,7 +134,7 @@ class LumaImageGenerationNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageNode", display_name="Luma Text to Image", - category="api node/image/Luma", + category="image/partner/Luma", description="Generates images synchronously based on prompt and aspect ratio.", inputs=[ IO.String.Input( @@ -278,7 +278,7 @@ class LumaImageModifyNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageModifyNode", display_name="Luma Image to Image", - category="api node/image/Luma", + category="image/partner/Luma", description="Modifies images synchronously based on prompt and aspect ratio.", inputs=[ IO.Image.Input( @@ -371,7 +371,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="LumaVideoNode", display_name="Luma Text to Video", - category="api node/video/Luma", + category="video/partner/Luma", description="Generates videos synchronously based on prompt and output_size.", inputs=[ IO.String.Input( @@ -472,7 +472,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageToVideoNode", display_name="Luma Image to Video", - category="api node/video/Luma", + category="video/partner/Luma", description="Generates videos synchronously based on prompt, input images, and output_size.", inputs=[ IO.String.Input( @@ -724,7 +724,7 @@ class LumaImageNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageNode2", display_name="Luma UNI-1 Image", - category="api node/image/Luma", + category="image/partner/Luma", description="Generate images from text using the Luma UNI-1 model.", inputs=[ IO.String.Input( @@ -853,7 +853,7 @@ class LumaImageEditNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageEditNode2", display_name="Luma UNI-1 Image Edit", - category="api node/image/Luma", + category="image/partner/Luma", description="Edit an existing image with a text prompt using the Luma UNI-1 model.", inputs=[ IO.Image.Input( diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py index 38b881fea..a6aeb194a 100644 --- a/comfy_api_nodes/nodes_magnific.py +++ b/comfy_api_nodes/nodes_magnific.py @@ -61,7 +61,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageUpscalerCreativeNode", display_name="Magnific Image Upscale (Creative)", - category="api node/image/Magnific", + category="image/partner/Magnific", description="Prompt‑guided enhancement, stylization, and 2x/4x/8x/16x upscaling. " "Maximum output: 25.3 megapixels.", inputs=[ @@ -240,7 +240,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): return IO.Schema( node_id="MagnificImageUpscalerPreciseV2Node", display_name="Magnific Image Upscale (Precise V2)", - category="api node/image/Magnific", + category="image/partner/Magnific", description="High-fidelity upscaling with fine control over sharpness, grain, and detail. " "Maximum output: 10060×10060 pixels.", inputs=[ @@ -400,7 +400,7 @@ class MagnificImageStyleTransferNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageStyleTransferNode", display_name="Magnific Image Style Transfer", - category="api node/image/Magnific", + category="image/partner/Magnific", description="Transfer the style from a reference image to your input image.", inputs=[ IO.Image.Input("image", tooltip="The image to apply style transfer to."), @@ -549,7 +549,7 @@ class MagnificImageRelightNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageRelightNode", display_name="Magnific Image Relight", - category="api node/image/Magnific", + category="image/partner/Magnific", description="Relight an image with lighting adjustments and optional reference-based light transfer.", inputs=[ IO.Image.Input("image", tooltip="The image to relight."), @@ -789,7 +789,7 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageSkinEnhancerNode", display_name="Magnific Image Skin Enhancer", - category="api node/image/Magnific", + category="image/partner/Magnific", description="Skin enhancement for portraits with multiple processing modes.", inputs=[ IO.Image.Input("image", tooltip="The portrait image to enhance."), diff --git a/comfy_api_nodes/nodes_meshy.py b/comfy_api_nodes/nodes_meshy.py index 3cf577f4a..4fb670404 100644 --- a/comfy_api_nodes/nodes_meshy.py +++ b/comfy_api_nodes/nodes_meshy.py @@ -33,7 +33,7 @@ class MeshyTextToModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyTextToModelNode", display_name="Meshy: Text to Model", - category="api node/3d/Meshy", + category="3d/partner/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.String.Input("prompt", multiline=True, default=""), @@ -145,7 +145,7 @@ class MeshyRefineNode(IO.ComfyNode): return IO.Schema( node_id="MeshyRefineNode", display_name="Meshy: Refine Draft Model", - category="api node/3d/Meshy", + category="3d/partner/Meshy", description="Refine a previously created draft model.", inputs=[ IO.Combo.Input("model", options=["latest"]), @@ -240,7 +240,7 @@ class MeshyImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyImageToModelNode", display_name="Meshy: Image to Model", - category="api node/3d/Meshy", + category="3d/partner/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.Image.Input("image"), @@ -405,7 +405,7 @@ class MeshyMultiImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyMultiImageToModelNode", display_name="Meshy: Multi-Image to Model", - category="api node/3d/Meshy", + category="3d/partner/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.Autogrow.Input( @@ -575,7 +575,7 @@ class MeshyRigModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyRigModelNode", display_name="Meshy: Rig Model", - category="api node/3d/Meshy", + category="3d/partner/Meshy", description="Provides a rigged character in standard formats. " "Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, " "or humanoid assets with unclear limb and body structure.", @@ -656,7 +656,7 @@ class MeshyAnimateModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyAnimateModelNode", display_name="Meshy: Animate Model", - category="api node/3d/Meshy", + category="3d/partner/Meshy", description="Apply a specific animation action to a previously rigged character.", inputs=[ IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"), @@ -722,7 +722,7 @@ class MeshyTextureNode(IO.ComfyNode): return IO.Schema( node_id="MeshyTextureNode", display_name="Meshy: Texture Model", - category="api node/3d/Meshy", + category="3d/partner/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index b5d0b461f..338584148 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -101,7 +101,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxTextToVideoNode", display_name="MiniMax Text to Video", - category="api node/video/MiniMax", + category="video/partner/MiniMax", description="Generates videos synchronously based on a prompt, and optional parameters.", inputs=[ IO.String.Input( @@ -163,7 +163,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxImageToVideoNode", display_name="MiniMax Image to Video", - category="api node/video/MiniMax", + category="video/partner/MiniMax", description="Generates videos synchronously based on an image and prompt, and optional parameters.", inputs=[ IO.Image.Input( @@ -230,7 +230,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxSubjectToVideoNode", display_name="MiniMax Subject to Video", - category="api node/video/MiniMax", + category="video/partner/MiniMax", description="Generates videos synchronously based on an image and prompt, and optional parameters.", inputs=[ IO.Image.Input( @@ -294,7 +294,7 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxHailuoVideoNode", display_name="MiniMax Hailuo Video", - category="api node/video/MiniMax", + category="video/partner/MiniMax", description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index a5a188634..48c739dfe 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -99,7 +99,7 @@ class OpenAIDalle2(IO.ComfyNode): return IO.Schema( node_id="OpenAIDalle2", display_name="OpenAI DALL·E 2", - category="api node/image/OpenAI", + category="image/partner/OpenAI", description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.", inputs=[ IO.String.Input( @@ -249,7 +249,7 @@ class OpenAIDalle3(IO.ComfyNode): return IO.Schema( node_id="OpenAIDalle3", display_name="OpenAI DALL·E 3", - category="api node/image/OpenAI", + category="image/partner/OpenAI", description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.", inputs=[ IO.String.Input( @@ -371,7 +371,7 @@ class OpenAIGPTImage1(IO.ComfyNode): return IO.Schema( node_id="OpenAIGPTImage1", display_name="OpenAI GPT Image 2", - category="api node/image/OpenAI", + category="image/partner/OpenAI", description="Generates images synchronously via OpenAI's GPT Image endpoint.", is_deprecated=True, inputs=[ @@ -695,7 +695,7 @@ class OpenAIGPTImageNodeV2(IO.ComfyNode): return IO.Schema( node_id="OpenAIGPTImageNodeV2", display_name="OpenAI GPT Image 2", - category="api node/image/OpenAI", + category="image/partner/OpenAI", description="Generates images via OpenAI's GPT Image endpoint.", inputs=[ IO.String.Input( @@ -962,7 +962,7 @@ class OpenAIChatNode(IO.ComfyNode): return IO.Schema( node_id="OpenAIChatNode", display_name="OpenAI ChatGPT", - category="api node/text/OpenAI", + category="text/partner/OpenAI", essentials_category="Text Generation", description="Generate text responses from an OpenAI model.", inputs=[ @@ -1201,7 +1201,7 @@ class OpenAIInputFiles(IO.ComfyNode): return IO.Schema( node_id="OpenAIInputFiles", display_name="OpenAI ChatGPT Input Files", - category="api node/text/OpenAI", + category="text/partner/OpenAI", description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes.", inputs=[ IO.Combo.Input( @@ -1248,7 +1248,7 @@ class OpenAIChatConfig(IO.ComfyNode): return IO.Schema( node_id="OpenAIChatConfig", display_name="OpenAI ChatGPT Advanced Options", - category="api node/text/OpenAI", + category="text/partner/OpenAI", description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.", inputs=[ IO.Combo.Input( diff --git a/comfy_api_nodes/nodes_openrouter.py b/comfy_api_nodes/nodes_openrouter.py index 031301870..d2ebbef0d 100644 --- a/comfy_api_nodes/nodes_openrouter.py +++ b/comfy_api_nodes/nodes_openrouter.py @@ -265,7 +265,7 @@ class OpenRouterLLMNode(IO.ComfyNode): return IO.Schema( node_id="OpenRouterLLMNode", display_name="OpenRouter LLM", - category="api node/text/OpenRouter", + category="text/partner/OpenRouter", essentials_category="Text Generation", description=( "Generate text responses through OpenRouter. Routes to a curated set of popular " diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index e17a24ae7..3861cfedd 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -53,7 +53,7 @@ class PixverseTemplateNode(IO.ComfyNode): return IO.Schema( node_id="PixverseTemplateNode", display_name="PixVerse Template", - category="api node/video/PixVerse", + category="video/partner/PixVerse", inputs=[ IO.Combo.Input("template", options=list(pixverse_templates.keys())), ], @@ -74,7 +74,7 @@ class PixverseTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="PixverseTextToVideoNode", display_name="PixVerse Text to Video", - category="api node/video/PixVerse", + category="video/partner/PixVerse", description="Generates videos based on prompt and output_size.", inputs=[ IO.String.Input( @@ -192,7 +192,7 @@ class PixverseImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="PixverseImageToVideoNode", display_name="PixVerse Image to Video", - category="api node/video/PixVerse", + category="video/partner/PixVerse", description="Generates videos based on prompt and output_size.", inputs=[ IO.Image.Input("image"), @@ -310,7 +310,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode): return IO.Schema( node_id="PixverseTransitionVideoNode", display_name="PixVerse Transition Video", - category="api node/video/PixVerse", + category="video/partner/PixVerse", description="Generates videos based on prompt and output_size.", inputs=[ IO.Image.Input("first_frame"), diff --git a/comfy_api_nodes/nodes_quiver.py b/comfy_api_nodes/nodes_quiver.py index 3269c0afe..ad045a7ef 100644 --- a/comfy_api_nodes/nodes_quiver.py +++ b/comfy_api_nodes/nodes_quiver.py @@ -62,7 +62,7 @@ class QuiverTextToSVGNode(IO.ComfyNode): return IO.Schema( node_id="QuiverTextToSVGNode", display_name="Quiver Text to SVG", - category="api node/image/Quiver", + category="image/partner/Quiver", description="Generate an SVG from a text prompt using Quiver AI.", inputs=[ IO.String.Input( @@ -177,7 +177,7 @@ class QuiverImageToSVGNode(IO.ComfyNode): return IO.Schema( node_id="QuiverImageToSVGNode", display_name="Quiver Image to SVG", - category="api node/image/Quiver", + category="image/partner/Quiver", description="Vectorize a raster image into SVG using Quiver AI.", inputs=[ IO.Image.Input( diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index c60cfbc4a..07387821d 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -178,7 +178,7 @@ class RecraftColorRGBNode(IO.ComfyNode): return IO.Schema( node_id="RecraftColorRGB", display_name="Recraft Color RGB", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Create Recraft Color by choosing specific RGB values.", inputs=[ IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."), @@ -204,7 +204,7 @@ class RecraftControlsNode(IO.ComfyNode): return IO.Schema( node_id="RecraftControls", display_name="Recraft Controls", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Create Recraft Controls for customizing Recraft generation.", inputs=[ IO.Custom(RecraftIO.COLOR).Input("colors", optional=True), @@ -228,7 +228,7 @@ class RecraftStyleV3RealisticImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftStyleV3RealisticImage", display_name="Recraft Style - Realistic Image", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), @@ -253,7 +253,7 @@ class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode): return IO.Schema( node_id="RecraftStyleV3DigitalIllustration", display_name="Recraft Style - Digital Illustration", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), @@ -272,7 +272,7 @@ class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode): return IO.Schema( node_id="RecraftStyleV3VectorIllustrationNode", display_name="Recraft Style - Realistic Image", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), @@ -291,7 +291,7 @@ class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode): return IO.Schema( node_id="RecraftStyleV3LogoRaster", display_name="Recraft Style - Logo Raster", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)), @@ -308,7 +308,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode): return IO.Schema( node_id="RecraftStyleV3InfiniteStyleLibrary", display_name="Recraft Style - Infinite Style Library", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.", inputs=[ IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."), @@ -331,7 +331,7 @@ class RecraftCreateStyleNode(IO.ComfyNode): return IO.Schema( node_id="RecraftCreateStyleNode", display_name="Recraft Create Style", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Create a custom style from reference images. " "Upload 1-5 images to use as style references. " "Total size of all images is limited to 5 MB.", @@ -400,7 +400,7 @@ class RecraftTextToImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftTextToImageNode", display_name="Recraft Text to Image", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Generates images synchronously based on prompt and resolution.", inputs=[ IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), @@ -512,7 +512,7 @@ class RecraftImageToImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftImageToImageNode", display_name="Recraft Image to Image", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Modify image based on prompt and strength.", inputs=[ IO.Image.Input("image"), @@ -630,7 +630,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode): return IO.Schema( node_id="RecraftImageInpaintingNode", display_name="Recraft Image Inpainting", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Modify image based on prompt and mask.", inputs=[ IO.Image.Input("image"), @@ -732,7 +732,7 @@ class RecraftTextToVectorNode(IO.ComfyNode): return IO.Schema( node_id="RecraftTextToVectorNode", display_name="Recraft Text to Vector", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Generates SVG synchronously based on prompt and resolution.", inputs=[ IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True), @@ -832,7 +832,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftVectorizeImageNode", display_name="Recraft Vectorize Image", - category="api node/image/Recraft", + category="image/partner/Recraft", essentials_category="Image Tools", description="Generates SVG synchronously from an input image.", inputs=[ @@ -876,7 +876,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode): return IO.Schema( node_id="RecraftReplaceBackgroundNode", display_name="Recraft Replace Background", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Replace background on image, based on provided prompt.", inputs=[ IO.Image.Input("image"), @@ -963,7 +963,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode): return IO.Schema( node_id="RecraftRemoveBackgroundNode", display_name="Recraft Remove Background", - category="api node/image/Recraft", + category="image/partner/Recraft", essentials_category="Image Tools", description="Remove background from image, and return processed image and mask.", inputs=[ @@ -1012,7 +1012,7 @@ class RecraftCrispUpscaleNode(IO.ComfyNode): return IO.Schema( node_id="RecraftCrispUpscaleNode", display_name="Recraft Crisp Upscale Image", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Upscale image synchronously.\n" "Enhances a given raster image using ‘crisp upscale’ tool, " "increasing image resolution, making the image sharper and cleaner.", @@ -1058,7 +1058,7 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): return IO.Schema( node_id="RecraftCreativeUpscaleNode", display_name="Recraft Creative Upscale Image", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Upscale image synchronously.\n" "Enhances a given raster image using ‘creative upscale’ tool, " "boosting resolution with a focus on refining small details and faces.", @@ -1086,7 +1086,7 @@ class RecraftV4TextToImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftV4TextToImageNode", display_name="Recraft V4 Text to Image", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Generates images using Recraft V4 or V4 Pro models.", inputs=[ IO.String.Input( @@ -1210,7 +1210,7 @@ class RecraftV4TextToVectorNode(IO.ComfyNode): return IO.Schema( node_id="RecraftV4TextToVectorNode", display_name="Recraft V4 Text to Vector", - category="api node/image/Recraft", + category="image/partner/Recraft", description="Generates SVG using Recraft V4 or V4 Pro models.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_reve.py b/comfy_api_nodes/nodes_reve.py index a87395394..2b15eadd7 100644 --- a/comfy_api_nodes/nodes_reve.py +++ b/comfy_api_nodes/nodes_reve.py @@ -109,7 +109,7 @@ class ReveImageCreateNode(IO.ComfyNode): return IO.Schema( node_id="ReveImageCreateNode", display_name="Reve Image Create", - category="api node/image/Reve", + category="image/partner/Reve", description="Generate images from text descriptions using Reve.", inputs=[ IO.String.Input( @@ -200,7 +200,7 @@ class ReveImageEditNode(IO.ComfyNode): return IO.Schema( node_id="ReveImageEditNode", display_name="Reve Image Edit", - category="api node/image/Reve", + category="image/partner/Reve", description="Edit images using natural language instructions with Reve.", inputs=[ IO.Image.Input("image", tooltip="The image to edit."), @@ -300,7 +300,7 @@ class ReveImageRemixNode(IO.ComfyNode): return IO.Schema( node_id="ReveImageRemixNode", display_name="Reve Image Remix", - category="api node/image/Reve", + category="image/partner/Reve", description="Combine reference images with text prompts to create new images using Reve.", inputs=[ IO.Autogrow.Input( diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 2df5a3e13..e14955661 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -230,7 +230,7 @@ class Rodin3D_Regular(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Regular", display_name="Rodin 3D Generate - Regular Generate", - category="api node/3d/Rodin", + category="3d/partner/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -289,7 +289,7 @@ class Rodin3D_Detail(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Detail", display_name="Rodin 3D Generate - Detail Generate", - category="api node/3d/Rodin", + category="3d/partner/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -348,7 +348,7 @@ class Rodin3D_Smooth(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Smooth", display_name="Rodin 3D Generate - Smooth Generate", - category="api node/3d/Rodin", + category="3d/partner/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -406,7 +406,7 @@ class Rodin3D_Sketch(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Sketch", display_name="Rodin 3D Generate - Sketch Generate", - category="api node/3d/Rodin", + category="3d/partner/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -468,7 +468,7 @@ class Rodin3D_Gen2(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Gen2", display_name="Rodin 3D Generate - Gen-2 Generate", - category="api node/3d/Rodin", + category="3d/partner/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -941,7 +941,7 @@ class Rodin3D_Gen25_Image(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Gen25_Image", display_name="Rodin 3D Gen-2.5 - Image to 3D", - category="api node/3d/Rodin", + category="3d/partner/Rodin", description=( "Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. " "Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost." @@ -1035,7 +1035,7 @@ class Rodin3D_Gen25_Text(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Gen25_Text", display_name="Rodin 3D Gen-2.5 - Text to 3D", - category="api node/3d/Rodin", + category="3d/partner/Rodin", description=( "Generate a 3D model from a text prompt via Rodin Gen-2.5. " "Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost." diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 573170ba2..7357c733e 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -140,7 +140,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): return IO.Schema( node_id="RunwayImageToVideoNodeGen3a", display_name="Runway Image to Video (Gen3a Turbo)", - category="api node/video/Runway", + category="video/partner/Runway", description="Generate a video from a single starting frame using Gen3a Turbo model. " "Before diving in, review these best practices to ensure that " "your input selections will set your generation up for success: " @@ -234,7 +234,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): return IO.Schema( node_id="RunwayImageToVideoNodeGen4", display_name="Runway Image to Video (Gen4 Turbo)", - category="api node/video/Runway", + category="video/partner/Runway", description="Generate a video from a single starting frame using Gen4 Turbo model. " "Before diving in, review these best practices to ensure that " "your input selections will set your generation up for success: " @@ -329,7 +329,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="RunwayFirstLastFrameNode", display_name="Runway First-Last-Frame to Video", - category="api node/video/Runway", + category="video/partner/Runway", description="Upload first and last keyframes, draft a prompt, and generate a video. " "More complex transitions, such as cases where the Last frame is completely different " "from the First frame, may benefit from the longer 10s duration. " @@ -440,7 +440,7 @@ class RunwayTextToImageNode(IO.ComfyNode): return IO.Schema( node_id="RunwayTextToImageNode", display_name="Runway Text to Image", - category="api node/image/Runway", + category="image/partner/Runway", description="Generate an image from a text prompt using Runway's Gen 4 model. " "You can also include reference image to guide the generation.", inputs=[ diff --git a/comfy_api_nodes/nodes_sonilo.py b/comfy_api_nodes/nodes_sonilo.py index 5518f5902..bc31a0074 100644 --- a/comfy_api_nodes/nodes_sonilo.py +++ b/comfy_api_nodes/nodes_sonilo.py @@ -34,7 +34,7 @@ class SoniloVideoToMusic(IO.ComfyNode): return IO.Schema( node_id="SoniloVideoToMusic", display_name="Sonilo Video to Music", - category="api node/audio/Sonilo", + category="audio/partner/Sonilo", description="Generate music from video content using Sonilo's AI model. " "Analyzes the video and creates matching music.", inputs=[ @@ -99,7 +99,7 @@ class SoniloTextToMusic(IO.ComfyNode): return IO.Schema( node_id="SoniloTextToMusic", display_name="Sonilo Text to Music", - category="api node/audio/Sonilo", + category="audio/partner/Sonilo", description="Generate music from a text prompt using Sonilo's AI model. " "Leave duration at 0 to let the model infer it from the prompt.", inputs=[ diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index c1d485188..83cfca495 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -34,7 +34,7 @@ class OpenAIVideoSora2(IO.ComfyNode): return IO.Schema( node_id="OpenAIVideoSora2", display_name="OpenAI Sora - Video (DEPRECATED)", - category="api node/video/Sora", + category="video/partner/Sora", description=( "OpenAI video and audio generation.\n\n" "DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. " diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 906d8ff35..a1753d647 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -62,7 +62,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode): return IO.Schema( node_id="StabilityStableImageUltraNode", display_name="Stability AI Stable Image Ultra", - category="api node/image/Stability AI", + category="image/partner/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.String.Input( @@ -197,7 +197,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): return IO.Schema( node_id="StabilityStableImageSD_3_5Node", display_name="Stability AI Stable Diffusion 3.5 Image", - category="api node/image/Stability AI", + category="image/partner/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.String.Input( @@ -354,7 +354,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): return IO.Schema( node_id="StabilityUpscaleConservativeNode", display_name="Stability AI Upscale Conservative", - category="api node/image/Stability AI", + category="image/partner/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("image"), @@ -457,7 +457,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): return IO.Schema( node_id="StabilityUpscaleCreativeNode", display_name="Stability AI Upscale Creative", - category="api node/image/Stability AI", + category="image/partner/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("image"), @@ -578,7 +578,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode): return IO.Schema( node_id="StabilityUpscaleFastNode", display_name="Stability AI Upscale Fast", - category="api node/image/Stability AI", + category="image/partner/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("image"), @@ -630,7 +630,7 @@ class StabilityTextToAudio(IO.ComfyNode): return IO.Schema( node_id="StabilityTextToAudio", display_name="Stability AI Text To Audio", - category="api node/audio/Stability AI", + category="audio/partner/Stability AI", essentials_category="Audio", description=cleandoc(cls.__doc__ or ""), inputs=[ @@ -708,7 +708,7 @@ class StabilityAudioToAudio(IO.ComfyNode): return IO.Schema( node_id="StabilityAudioToAudio", display_name="Stability AI Audio To Audio", - category="api node/audio/Stability AI", + category="audio/partner/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Combo.Input( @@ -802,7 +802,7 @@ class StabilityAudioInpaint(IO.ComfyNode): return IO.Schema( node_id="StabilityAudioInpaint", display_name="Stability AI Audio Inpaint", - category="api node/audio/Stability AI", + category="audio/partner/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Combo.Input( diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index e79c16d3c..d0906ee44 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -52,7 +52,7 @@ class TopazImageEnhance(IO.ComfyNode): return IO.Schema( node_id="TopazImageEnhance", display_name="Topaz Image Enhance", - category="api node/image/Topaz", + category="image/partner/Topaz", description="Industry-standard upscaling and image enhancement.", inputs=[ IO.Combo.Input("model", options=["Reimagine"]), @@ -235,7 +235,7 @@ class TopazVideoEnhance(IO.ComfyNode): return IO.Schema( node_id="TopazVideoEnhance", display_name="Topaz Video Enhance (Legacy)", - category="api node/video/Topaz", + category="video/partner/Topaz", description="Breathe new life into video with powerful upscaling and recovery technology.", inputs=[ IO.Video.Input("video"), @@ -475,7 +475,7 @@ class TopazVideoEnhanceV2(IO.ComfyNode): return IO.Schema( node_id="TopazVideoEnhanceV2", display_name="Topaz Video Enhance", - category="api node/video/Topaz", + category="video/partner/Topaz", description="Breathe new life into video with powerful upscaling and recovery technology.", inputs=[ IO.Video.Input("video"), diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index d6501dee4..6ee674a18 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -80,7 +80,7 @@ class TripoTextToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoTextToModelNode", display_name="Tripo: Text to Model", - category="api node/3d/Tripo", + category="3d/partner/Tripo", inputs=[ IO.String.Input("prompt", multiline=True), IO.String.Input("negative_prompt", multiline=True, optional=True), @@ -195,7 +195,7 @@ class TripoImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoImageToModelNode", display_name="Tripo: Image to Model", - category="api node/3d/Tripo", + category="3d/partner/Tripo", inputs=[ IO.Image.Input("image"), IO.Combo.Input( @@ -323,7 +323,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoMultiviewToModelNode", display_name="Tripo: Multiview to Model", - category="api node/3d/Tripo", + category="3d/partner/Tripo", inputs=[ IO.Image.Input("image"), IO.Image.Input("image_left", optional=True), @@ -461,7 +461,7 @@ class TripoTextureNode(IO.ComfyNode): return IO.Schema( node_id="TripoTextureNode", display_name="Tripo: Texture model", - category="api node/3d/Tripo", + category="3d/partner/Tripo", inputs=[ IO.Custom("MODEL_TASK_ID").Input("model_task_id"), IO.Boolean.Input("texture", default=True, optional=True), @@ -528,7 +528,7 @@ class TripoRefineNode(IO.ComfyNode): return IO.Schema( node_id="TripoRefineNode", display_name="Tripo: Refine Draft model", - category="api node/3d/Tripo", + category="3d/partner/Tripo", description="Refine a draft model created by v1.4 Tripo models only.", inputs=[ IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"), @@ -568,7 +568,7 @@ class TripoRigNode(IO.ComfyNode): return IO.Schema( node_id="TripoRigNode", display_name="Tripo: Rig model", - category="api node/3d/Tripo", + category="3d/partner/Tripo", inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")], outputs=[ IO.String.Output(display_name="model_file"), # for backward compatibility only @@ -605,7 +605,7 @@ class TripoRetargetNode(IO.ComfyNode): return IO.Schema( node_id="TripoRetargetNode", display_name="Tripo: Retarget rigged model", - category="api node/3d/Tripo", + category="3d/partner/Tripo", inputs=[ IO.Custom("RIG_TASK_ID").Input("original_model_task_id"), IO.Combo.Input( @@ -670,7 +670,7 @@ class TripoConversionNode(IO.ComfyNode): return IO.Schema( node_id="TripoConversionNode", display_name="Tripo: Convert model", - category="api node/3d/Tripo", + category="3d/partner/Tripo", inputs=[ IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"), IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]), diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 2ff75d9b2..068862397 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -45,7 +45,7 @@ class VeoVideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="VeoVideoGenerationNode", display_name="Google Veo 2 Video Generation", - category="api node/video/Veo", + category="video/partner/Veo", description="Generates videos from text prompts using Google's Veo 2 API", inputs=[ IO.String.Input( @@ -256,7 +256,7 @@ class Veo3VideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="Veo3VideoGenerationNode", display_name="Google Veo 3 Video Generation", - category="api node/video/Veo", + category="video/partner/Veo", description="Generates videos from text prompts using Google's Veo 3 API", inputs=[ IO.String.Input( @@ -468,7 +468,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="Veo3FirstLastFrameNode", display_name="Google Veo 3 First-Last-Frame to Video", - category="api node/video/Veo", + category="video/partner/Veo", description="Generate video using prompt and first and last frames.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 8d90cefeb..16f6113de 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -71,7 +71,7 @@ class ViduTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduTextToVideoNode", display_name="Vidu Text To Video Generation", - category="api node/video/Vidu", + category="video/partner/Vidu", description="Generate video from a text prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -169,7 +169,7 @@ class ViduImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduImageToVideoNode", display_name="Vidu Image To Video Generation", - category="api node/video/Vidu", + category="video/partner/Vidu", description="Generate video from image and optional prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -273,7 +273,7 @@ class ViduReferenceVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduReferenceVideoNode", display_name="Vidu Reference To Video Generation", - category="api node/video/Vidu", + category="video/partner/Vidu", description="Generate video from multiple images and a prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -388,7 +388,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduStartEndToVideoNode", display_name="Vidu Start End To Video Generation", - category="api node/video/Vidu", + category="video/partner/Vidu", description="Generate a video from start and end frames and a prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -492,7 +492,7 @@ class Vidu2TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2TextToVideoNode", display_name="Vidu2 Text-to-Video Generation", - category="api node/video/Vidu", + category="video/partner/Vidu", description="Generate video from a text prompt", inputs=[ IO.Combo.Input("model", options=["viduq2"]), @@ -584,7 +584,7 @@ class Vidu2ImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2ImageToVideoNode", display_name="Vidu2 Image-to-Video Generation", - category="api node/video/Vidu", + category="video/partner/Vidu", description="Generate a video from an image and an optional prompt.", inputs=[ IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), @@ -714,7 +714,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2ReferenceVideoNode", display_name="Vidu2 Reference-to-Video Generation", - category="api node/video/Vidu", + category="video/partner/Vidu", description="Generate a video from multiple reference images and a prompt.", inputs=[ IO.Combo.Input("model", options=["viduq2"]), @@ -849,7 +849,7 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2StartEndToVideoNode", display_name="Vidu2 Start/End Frame-to-Video Generation", - category="api node/video/Vidu", + category="video/partner/Vidu", description="Generate a video from a start frame, an end frame, and a prompt.", inputs=[ IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), @@ -969,7 +969,7 @@ class ViduExtendVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduExtendVideoNode", display_name="Vidu Video Extension", - category="api node/video/Vidu", + category="video/partner/Vidu", description="Extend an existing video by generating additional frames.", inputs=[ IO.DynamicCombo.Input( @@ -1138,7 +1138,7 @@ class ViduMultiFrameVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduMultiFrameVideoNode", display_name="Vidu Multi-Frame Video Generation", - category="api node/video/Vidu", + category="video/partner/Vidu", description="Generate a video with multiple keyframe transitions.", inputs=[ IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]), @@ -1284,7 +1284,7 @@ class Vidu3TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu3TextToVideoNode", display_name="Vidu Q3 Text-to-Video Generation", - category="api node/video/Vidu", + category="video/partner/Vidu", description="Generate video from a text prompt.", inputs=[ IO.DynamicCombo.Input( @@ -1429,7 +1429,7 @@ class Vidu3ImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu3ImageToVideoNode", display_name="Vidu Q3 Image-to-Video Generation", - category="api node/video/Vidu", + category="video/partner/Vidu", description="Generate a video from an image and an optional prompt.", inputs=[ IO.DynamicCombo.Input( @@ -1571,7 +1571,7 @@ class Vidu3StartEndToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu3StartEndToVideoNode", display_name="Vidu Q3 Start/End Frame-to-Video Generation", - category="api node/video/Vidu", + category="video/partner/Vidu", description="Generate a video from a start frame, an end frame, and a prompt.", inputs=[ IO.DynamicCombo.Input( diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 68061bb5c..a235dc387 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -61,7 +61,7 @@ class WanTextToImageApi(IO.ComfyNode): return IO.Schema( node_id="WanTextToImageApi", display_name="Wan Text to Image", - category="api node/image/Wan", + category="image/partner/Wan", description="Generates an image based on a text prompt.", inputs=[ IO.Combo.Input( @@ -184,7 +184,7 @@ class WanImageToImageApi(IO.ComfyNode): return IO.Schema( node_id="WanImageToImageApi", display_name="Wan Image to Image", - category="api node/image/Wan", + category="image/partner/Wan", description="Generates an image from one or two input images and a text prompt. " "The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).", inputs=[ @@ -312,7 +312,7 @@ class WanTextToVideoApi(IO.ComfyNode): return IO.Schema( node_id="WanTextToVideoApi", display_name="Wan Text to Video", - category="api node/video/Wan", + category="video/partner/Wan", description="Generates a video based on a text prompt.", inputs=[ IO.Combo.Input( @@ -495,7 +495,7 @@ class WanImageToVideoApi(IO.ComfyNode): return IO.Schema( node_id="WanImageToVideoApi", display_name="Wan Image to Video", - category="api node/video/Wan", + category="video/partner/Wan", description="Generates a video from the first frame and a text prompt.", inputs=[ IO.Combo.Input( @@ -674,7 +674,7 @@ class WanReferenceVideoApi(IO.ComfyNode): return IO.Schema( node_id="WanReferenceVideoApi", display_name="Wan Reference to Video", - category="api node/video/Wan", + category="video/partner/Wan", description="Use the character and voice from input videos, combined with a prompt, " "to generate a new video that maintains character consistency.", inputs=[ @@ -828,7 +828,7 @@ class Wan2TextToVideoApi(IO.ComfyNode): return IO.Schema( node_id="Wan2TextToVideoApi", display_name="Wan 2.7 Text to Video", - category="api node/video/Wan", + category="video/partner/Wan", description="Generates a video based on a text prompt using the Wan 2.7 model.", inputs=[ IO.DynamicCombo.Input( @@ -981,7 +981,7 @@ class Wan2ImageToVideoApi(IO.ComfyNode): return IO.Schema( node_id="Wan2ImageToVideoApi", display_name="Wan 2.7 Image to Video", - category="api node/video/Wan", + category="video/partner/Wan", description="Generate a video from a first-frame image, with optional last-frame image and audio.", inputs=[ IO.DynamicCombo.Input( @@ -1152,7 +1152,7 @@ class Wan2VideoContinuationApi(IO.ComfyNode): return IO.Schema( node_id="Wan2VideoContinuationApi", display_name="Wan 2.7 Video Continuation", - category="api node/video/Wan", + category="video/partner/Wan", description="Continue a video from where it left off, with optional last-frame control.", inputs=[ IO.DynamicCombo.Input( @@ -1319,7 +1319,7 @@ class Wan2VideoEditApi(IO.ComfyNode): return IO.Schema( node_id="Wan2VideoEditApi", display_name="Wan 2.7 Video Edit", - category="api node/video/Wan", + category="video/partner/Wan", description="Edit a video using text instructions, reference images, or style transfer.", inputs=[ IO.DynamicCombo.Input( @@ -1477,7 +1477,7 @@ class Wan2ReferenceVideoApi(IO.ComfyNode): return IO.Schema( node_id="Wan2ReferenceVideoApi", display_name="Wan 2.7 Reference to Video", - category="api node/video/Wan", + category="video/partner/Wan", description="Generate a video featuring a person or object from reference materials. " "Supports single-character performances and multi-character interactions.", inputs=[ @@ -1651,7 +1651,7 @@ class HappyHorseTextToVideoApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseTextToVideoApi", display_name="HappyHorse Text to Video", - category="api node/video/Wan", + category="video/partner/Wan", description="Generates a video based on a text prompt using the HappyHorse model.", inputs=[ IO.DynamicCombo.Input( @@ -1775,7 +1775,7 @@ class HappyHorseImageToVideoApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseImageToVideoApi", display_name="HappyHorse Image to Video", - category="api node/video/Wan", + category="video/partner/Wan", description="Generate a video from a first-frame image using the HappyHorse model.", inputs=[ IO.DynamicCombo.Input( @@ -1905,7 +1905,7 @@ class HappyHorseVideoEditApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseVideoEditApi", display_name="HappyHorse Video Edit", - category="api node/video/Wan", + category="video/partner/Wan", description="Edit a video using text instructions or reference images with the HappyHorse model. " "Output duration is 3-15s and matches the input video; inputs longer than 15s are truncated.", inputs=[ @@ -2046,7 +2046,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseReferenceVideoApi", display_name="HappyHorse Reference to Video", - category="api node/video/Wan", + category="video/partner/Wan", description="Generate a video featuring a person or object from reference materials with the HappyHorse " "model. Supports single-character performances and multi-character interactions.", inputs=[ diff --git a/comfy_api_nodes/nodes_wavespeed.py b/comfy_api_nodes/nodes_wavespeed.py index 65e45f60a..a250015c3 100644 --- a/comfy_api_nodes/nodes_wavespeed.py +++ b/comfy_api_nodes/nodes_wavespeed.py @@ -27,7 +27,7 @@ class WavespeedFlashVSRNode(IO.ComfyNode): return IO.Schema( node_id="WavespeedFlashVSRNode", display_name="FlashVSR Video Upscale", - category="api node/video/WaveSpeed", + category="video/partner/WaveSpeed", description="Fast, high-quality video upscaler that " "boosts resolution and restores clarity for low-resolution or blurry footage.", inputs=[ @@ -98,7 +98,7 @@ class WavespeedImageUpscaleNode(IO.ComfyNode): return IO.Schema( node_id="WavespeedImageUpscaleNode", display_name="WaveSpeed Image Upscale", - category="api node/image/WaveSpeed", + category="image/partner/WaveSpeed", description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.", inputs=[ IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]), diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 052301c33..57c501724 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -86,7 +86,7 @@ class _PollUIState: _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] -QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"] +QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"] async def sync_op( diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index 247d9ae8a..044077b18 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -11,7 +11,7 @@ class TextEncodeAceStepAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="TextEncodeAceStepAudio", - category="conditioning", + category="model/conditioning", inputs=[ IO.Clip.Input("clip"), IO.String.Input("tags", multiline=True, dynamic_prompts=True), @@ -33,7 +33,7 @@ class TextEncodeAceStepAudio15(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="TextEncodeAceStepAudio1.5", - category="conditioning", + category="model/conditioning", inputs=[ IO.Clip.Input("clip"), IO.String.Input("tags", multiline=True, dynamic_prompts=True), @@ -67,7 +67,7 @@ class EmptyAceStepLatentAudio(IO.ComfyNode): return IO.Schema( node_id="EmptyAceStepLatentAudio", display_name="Empty Ace Step 1.0 Latent Audio", - category="latent/audio", + category="model/latent/audio", inputs=[ IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), IO.Int.Input( @@ -90,7 +90,7 @@ class EmptyAceStep15LatentAudio(IO.ComfyNode): return IO.Schema( node_id="EmptyAceStep1.5LatentAudio", display_name="Empty Ace Step 1.5 Latent Audio", - category="latent/audio", + category="model/latent/audio", inputs=[ IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01), IO.Int.Input( diff --git a/comfy_extras/nodes_advanced_samplers.py b/comfy_extras/nodes_advanced_samplers.py index 20717ca38..77a561e30 100644 --- a/comfy_extras/nodes_advanced_samplers.py +++ b/comfy_extras/nodes_advanced_samplers.py @@ -45,7 +45,7 @@ class SamplerLCMUpscale(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="SamplerLCMUpscale", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01, advanced=True), io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1, advanced=True), @@ -91,7 +91,7 @@ class SamplerLCM(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="SamplerLCM", - category="sampling/samplers", + category="model/sampling/samplers", description=("LCM sampler with tunable per-step noise. s_noise is a multiplier on the model's training noise scale"), inputs=[ io.Float.Input("s_noise", default=1.0, min=0.0, max=64.0, step=0.01, diff --git a/comfy_extras/nodes_align_your_steps.py b/comfy_extras/nodes_align_your_steps.py index 307f41337..f89a809bb 100644 --- a/comfy_extras/nodes_align_your_steps.py +++ b/comfy_extras/nodes_align_your_steps.py @@ -29,7 +29,7 @@ class AlignYourStepsScheduler(io.ComfyNode): return io.Schema( node_id="AlignYourStepsScheduler", search_aliases=["AYS scheduler"], - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]), io.Int.Input("steps", default=10, min=1, max=10000), diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py index fd561d360..4a352038a 100644 --- a/comfy_extras/nodes_apg.py +++ b/comfy_extras/nodes_apg.py @@ -16,7 +16,7 @@ class APG(io.ComfyNode): return io.Schema( node_id="APG", display_name="Adaptive Projected Guidance", - category="sampling/custom_sampling", + category="model/sampling/custom_sampling", inputs=[ io.Model.Input("model"), io.Float.Input( diff --git a/comfy_extras/nodes_ar_video.py b/comfy_extras/nodes_ar_video.py index 1a15facfa..c22359eb2 100644 --- a/comfy_extras/nodes_ar_video.py +++ b/comfy_extras/nodes_ar_video.py @@ -19,7 +19,7 @@ class EmptyARVideoLatent(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="EmptyARVideoLatent", - category="latent/video", + category="model/latent/video", inputs=[ io.Int.Input("width", default=832, min=16, max=8192, step=16), io.Int.Input("height", default=480, min=16, max=8192, step=16), @@ -53,7 +53,7 @@ class SamplerARVideo(io.ComfyNode): return io.Schema( node_id="SamplerARVideo", display_name="Sampler AR Video", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Int.Input( "num_frame_per_block", @@ -85,7 +85,7 @@ class ARVideoI2V(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ARVideoI2V", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Model.Input("model"), io.Vae.Input("vae"), diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index f09a8a874..ff078f74c 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -16,7 +16,7 @@ class EmptyLatentAudio(IO.ComfyNode): return IO.Schema( node_id="EmptyLatentAudio", display_name="Empty Latent Audio", - category="latent/audio", + category="model/latent/audio", essentials_category="Audio", inputs=[ IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1), @@ -41,7 +41,7 @@ class ConditioningStableAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ConditioningStableAudio", - category="conditioning", + category="model/conditioning", inputs=[ IO.Conditioning.Input("positive"), IO.Conditioning.Input("negative"), @@ -70,7 +70,7 @@ class VAEEncodeAudio(IO.ComfyNode): node_id="VAEEncodeAudio", search_aliases=["audio to latent"], display_name="VAE Encode Audio", - category="latent/audio", + category="model/latent/audio", inputs=[ IO.Audio.Input("audio"), IO.Vae.Input("vae"), @@ -115,7 +115,7 @@ class VAEDecodeAudio(IO.ComfyNode): node_id="VAEDecodeAudio", search_aliases=["latent to audio"], display_name="VAE Decode Audio", - category="latent/audio", + category="model/latent/audio", inputs=[ IO.Latent.Input("samples"), IO.Vae.Input("vae"), @@ -137,7 +137,7 @@ class VAEDecodeAudioTiled(IO.ComfyNode): node_id="VAEDecodeAudioTiled", search_aliases=["latent to audio"], display_name="VAE Decode Audio (Tiled)", - category="latent/audio", + category="model/latent/audio", inputs=[ IO.Latent.Input("samples"), IO.Vae.Input("vae"), diff --git a/comfy_extras/nodes_audio_encoder.py b/comfy_extras/nodes_audio_encoder.py index 6a85da89b..2ae30d321 100644 --- a/comfy_extras/nodes_audio_encoder.py +++ b/comfy_extras/nodes_audio_encoder.py @@ -11,7 +11,7 @@ class AudioEncoderLoader(io.ComfyNode): return io.Schema( node_id="AudioEncoderLoader", display_name="Load Audio Encoder", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input( "audio_encoder_name", @@ -36,7 +36,7 @@ class AudioEncoderEncode(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="AudioEncoderEncode", - category="conditioning", + category="model/conditioning", inputs=[ io.AudioEncoder.Input("audio_encoder"), io.Audio.Input("audio"), diff --git a/comfy_extras/nodes_bg_removal.py b/comfy_extras/nodes_bg_removal.py index 793fd802b..9dc9ad854 100644 --- a/comfy_extras/nodes_bg_removal.py +++ b/comfy_extras/nodes_bg_removal.py @@ -11,7 +11,7 @@ class LoadBackgroundRemovalModel(IO.ComfyNode): return IO.Schema( node_id="LoadBackgroundRemovalModel", display_name="Load Background Removal Model", - category="loaders", + category="model/loaders", inputs=[ IO.Combo.Input("bg_removal_name", options=sorted(files), tooltip="The model used to remove backgrounds from images"), ], diff --git a/comfy_extras/nodes_camera_trajectory.py b/comfy_extras/nodes_camera_trajectory.py index 34b78e81b..13a1448f4 100644 --- a/comfy_extras/nodes_camera_trajectory.py +++ b/comfy_extras/nodes_camera_trajectory.py @@ -153,7 +153,7 @@ class WanCameraEmbedding(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanCameraEmbedding", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Combo.Input( "camera_pose", diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py index 4ebb4b51e..b585c560f 100644 --- a/comfy_extras/nodes_cfg.py +++ b/comfy_extras/nodes_cfg.py @@ -57,24 +57,55 @@ class CFGNorm(io.ComfyNode): inputs=[ io.Model.Input("model"), io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01), + io.Boolean.Input( + "pre_cfg", + default=False, + optional=True, + tooltip=( + "If true, rescale the combined noise BEFORE the sampler's CFG combine, " + "without clamping (can amplify). Matches the norm-scaled CFG used by " + "models like Lens. Default false keeps the original post-CFG x0-space " + "attenuate-only behavior." + ), + ), ], outputs=[io.Model.Output(display_name="patched_model")], is_experimental=True, ) @classmethod - def execute(cls, model, strength) -> io.NodeOutput: + def execute(cls, model, strength, pre_cfg=False) -> io.NodeOutput: m = model.clone() - def cfg_norm(args): - cond_p = args['cond_denoised'] - pred_text_ = args["denoised"] + if pre_cfg: + def cfg_norm_pre(args): + cond = args["cond"] + uncond = args["uncond"] + cond_scale = args["cond_scale"] + comb = uncond + cond_scale * (cond - uncond) + cond_norm = torch.linalg.vector_norm(cond, dim=1, keepdim=True) + comb_norm = torch.linalg.vector_norm(comb, dim=1, keepdim=True) + rescale = torch.where( + comb_norm > 0, + cond_norm / comb_norm.clamp_min(1e-12), + torch.ones_like(comb_norm), + ) + rescaled = comb * rescale + # strength blends back toward standard linear CFG (1.0 = full rescale). + if strength != 1.0: + rescaled = strength * rescaled + (1.0 - strength) * comb + return rescaled + m.set_model_sampler_cfg_function(cfg_norm_pre) + else: + def cfg_norm(args): + cond_p = args['cond_denoised'] + pred_text_ = args["denoised"] - norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True) - norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True) - scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0) - return pred_text_ * scale * strength + norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True) + norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True) + scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0) + return pred_text_ * scale * strength - m.set_model_sampler_post_cfg_function(cfg_norm) + m.set_model_sampler_post_cfg_function(cfg_norm) return io.NodeOutput(m) diff --git a/comfy_extras/nodes_chroma_radiance.py b/comfy_extras/nodes_chroma_radiance.py index 509436062..ca427e5cb 100644 --- a/comfy_extras/nodes_chroma_radiance.py +++ b/comfy_extras/nodes_chroma_radiance.py @@ -13,7 +13,7 @@ class EmptyChromaRadianceLatentImage(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="EmptyChromaRadianceLatentImage", - category="latent/chroma_radiance", + category="model/latent/chroma_radiance", inputs=[ io.Int.Input(id="width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input(id="height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), @@ -33,7 +33,7 @@ class ChromaRadianceOptions(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="ChromaRadianceOptions", - category="model_patches/chroma_radiance", + category="model/patch/chroma_radiance", description="Allows setting advanced options for the Chroma Radiance model.", inputs=[ io.Model.Input(id="model"), diff --git a/comfy_extras/nodes_color.py b/comfy_extras/nodes_color.py index 80ba121cd..01a05035e 100644 --- a/comfy_extras/nodes_color.py +++ b/comfy_extras/nodes_color.py @@ -8,7 +8,7 @@ class ColorToRGBInt(io.ComfyNode): return io.Schema( node_id="ColorToRGBInt", display_name="Color to RGB Int", - category="utils", + category="utilities", description="Convert a color to a RGB integer value.", inputs=[ io.Color.Input("color"), diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 24729c3a7..d9e32b9d9 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -9,7 +9,7 @@ class ContextWindowsManualNode(io.ComfyNode): return io.Schema( node_id="ContextWindowsManual", display_name="Context Windows (Manual)", - category="model_patches", + category="model/patch", description="Manually set context windows.", inputs=[ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), diff --git a/comfy_extras/nodes_controlnet.py b/comfy_extras/nodes_controlnet.py index 847cb0bdf..17d965405 100644 --- a/comfy_extras/nodes_controlnet.py +++ b/comfy_extras/nodes_controlnet.py @@ -9,7 +9,7 @@ class SetUnionControlNetType(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SetUnionControlNetType", - category="conditioning/controlnet", + category="model/conditioning/controlnet", inputs=[ io.ControlNet.Input("control_net"), io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())), @@ -39,7 +39,7 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode): return io.Schema( node_id="ControlNetInpaintingAliMamaApply", search_aliases=["masked controlnet"], - category="conditioning/controlnet", + category="model/conditioning/controlnet", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/comfy_extras/nodes_cosmos.py b/comfy_extras/nodes_cosmos.py index 7dd129d19..d754ab442 100644 --- a/comfy_extras/nodes_cosmos.py +++ b/comfy_extras/nodes_cosmos.py @@ -13,7 +13,7 @@ class EmptyCosmosLatentVideo(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="EmptyCosmosLatentVideo", - category="latent/video", + category="model/latent/video", inputs=[ io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16), @@ -45,7 +45,7 @@ class CosmosImageToVideoLatent(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="CosmosImageToVideoLatent", - category="conditioning/inpaint", + category="model/conditioning/inpaint", inputs=[ io.Vae.Input("vae"), io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16), @@ -88,7 +88,7 @@ class CosmosPredict2ImageToVideoLatent(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="CosmosPredict2ImageToVideoLatent", - category="conditioning/inpaint", + category="model/conditioning/inpaint", inputs=[ io.Vae.Input("vae"), io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), diff --git a/comfy_extras/nodes_curve.py b/comfy_extras/nodes_curve.py index 099453131..aa2d94bb6 100644 --- a/comfy_extras/nodes_curve.py +++ b/comfy_extras/nodes_curve.py @@ -11,7 +11,7 @@ class CurveEditor(io.ComfyNode): return io.Schema( node_id="CurveEditor", display_name="Curve Editor", - category="utils", + category="utilities", inputs=[ io.Curve.Input("curve"), io.Histogram.Input("histogram", optional=True), @@ -38,7 +38,7 @@ class ImageHistogram(io.ComfyNode): return io.Schema( node_id="ImageHistogram", display_name="Image Histogram", - category="utils", + category="utilities", inputs=[ io.Image.Input("image"), ], diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 10b56b91c..c3346bf09 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -17,7 +17,7 @@ class BasicScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="BasicScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Model.Input("model"), io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES), @@ -47,7 +47,7 @@ class KarrasScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="KarrasScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), @@ -69,7 +69,7 @@ class ExponentialScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ExponentialScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), @@ -90,7 +90,7 @@ class PolyexponentialScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PolyexponentialScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), @@ -112,7 +112,7 @@ class LaplaceScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LaplaceScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), @@ -136,7 +136,7 @@ class SDTurboScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SDTurboScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Model.Input("model"), io.Int.Input("steps", default=1, min=1, max=10), @@ -160,7 +160,7 @@ class BetaSamplingScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="BetaSamplingScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Model.Input("model"), io.Int.Input("steps", default=20, min=1, max=10000), @@ -182,7 +182,7 @@ class VPScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="VPScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), #TODO: fix default values @@ -204,7 +204,7 @@ class SplitSigmas(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SplitSigmas", - category="sampling/sigmas", + category="model/sampling/sigmas", inputs=[ io.Sigmas.Input("sigmas"), io.Int.Input("step", default=0, min=0, max=10000), @@ -228,7 +228,7 @@ class SplitSigmasDenoise(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SplitSigmasDenoise", - category="sampling/sigmas", + category="model/sampling/sigmas", inputs=[ io.Sigmas.Input("sigmas"), io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), @@ -254,7 +254,7 @@ class FlipSigmas(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="FlipSigmas", - category="sampling/sigmas", + category="model/sampling/sigmas", inputs=[io.Sigmas.Input("sigmas")], outputs=[io.Sigmas.Output()] ) @@ -276,7 +276,7 @@ class SetFirstSigma(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SetFirstSigma", - category="sampling/sigmas", + category="model/sampling/sigmas", inputs=[ io.Sigmas.Input("sigmas"), io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False), @@ -298,7 +298,7 @@ class ExtendIntermediateSigmas(io.ComfyNode): return io.Schema( node_id="ExtendIntermediateSigmas", search_aliases=["interpolate sigmas"], - category="sampling/sigmas", + category="model/sampling/sigmas", inputs=[ io.Sigmas.Input("sigmas"), io.Int.Input("steps", default=2, min=1, max=100), @@ -351,7 +351,7 @@ class SamplingPercentToSigma(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplingPercentToSigma", - category="sampling/sigmas", + category="model/sampling/sigmas", inputs=[ io.Model.Input("model"), io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001), @@ -379,7 +379,7 @@ class KSamplerSelect(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="KSamplerSelect", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES)], outputs=[io.Sampler.Output()] ) @@ -396,7 +396,7 @@ class SamplerDPMPP_3M_SDE(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerDPMPP_3M_SDE", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), @@ -421,7 +421,7 @@ class SamplerDPMPP_2M_SDE(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerDPMPP_2M_SDE", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Combo.Input("solver_type", options=['midpoint', 'heun']), io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), @@ -448,7 +448,7 @@ class SamplerDPMPP_SDE(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerDPMPP_SDE", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), @@ -474,7 +474,7 @@ class SamplerDPMPP_2S_Ancestral(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerDPMPP_2S_Ancestral", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), @@ -494,7 +494,7 @@ class SamplerEulerAncestral(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerEulerAncestral", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), @@ -515,7 +515,7 @@ class SamplerEulerAncestralCFGPP(io.ComfyNode): return io.Schema( node_id="SamplerEulerAncestralCFGPP", display_name="SamplerEulerAncestralCFG++", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False), io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False), @@ -537,7 +537,7 @@ class SamplerLMS(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerLMS", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[io.Int.Input("order", default=4, min=1, max=100, advanced=True)], outputs=[io.Sampler.Output()] ) @@ -554,7 +554,7 @@ class SamplerDPMAdaptative(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerDPMAdaptative", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Int.Input("order", default=3, min=2, max=3, advanced=True), io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False, advanced=True), @@ -585,7 +585,7 @@ class SamplerER_SDE(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerER_SDE", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]), io.Int.Input("max_stage", default=3, min=1, max=3, advanced=True), @@ -623,7 +623,7 @@ class SamplerSASolver(io.ComfyNode): return io.Schema( node_id="SamplerSASolver", search_aliases=["sde"], - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Model.Input("model"), io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False, advanced=True), @@ -668,7 +668,7 @@ class SamplerSEEDS2(io.ComfyNode): return io.Schema( node_id="SamplerSEEDS2", search_aliases=["sde", "exp heun"], - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Combo.Input("solver_type", options=["phi_1", "phi_2"]), io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength", advanced=True), @@ -727,7 +727,7 @@ class SamplerCustom(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerCustom", - category="sampling/custom_sampling", + category="model/sampling/custom_sampling", inputs=[ io.Model.Input("model"), io.Boolean.Input("add_noise", default=True, advanced=True), @@ -795,7 +795,7 @@ class BasicGuider(io.ComfyNode): return io.Schema( node_id="BasicGuider", display_name="Basic Guider", - category="sampling/guiders", + category="model/sampling/guiders", inputs=[ io.Model.Input("model"), io.Conditioning.Input("conditioning"), @@ -817,7 +817,7 @@ class CFGGuider(io.ComfyNode): return io.Schema( node_id="CFGGuider", display_name="CFG Guider", - category="sampling/guiders", + category="model/sampling/guiders", inputs=[ io.Model.Input("model"), io.Conditioning.Input("positive"), @@ -872,7 +872,7 @@ class DualCFGGuider(io.ComfyNode): node_id="DualCFGGuider", search_aliases=["dual prompt guidance"], display_name="Dual CFG Guider", - category="sampling/guiders", + category="model/sampling/guiders", inputs=[ io.Model.Input("model"), io.Conditioning.Input("cond1"), @@ -900,7 +900,7 @@ class DisableNoise(io.ComfyNode): return io.Schema( node_id="DisableNoise", search_aliases=["zero noise"], - category="sampling/noise", + category="model/sampling/noise", inputs=[], outputs=[io.Noise.Output()] ) @@ -917,7 +917,7 @@ class RandomNoise(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="RandomNoise", - category="sampling/noise", + category="model/sampling/noise", inputs=[io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True)], outputs=[io.Noise.Output()] ) @@ -934,7 +934,7 @@ class SamplerCustomAdvanced(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerCustomAdvanced", - category="sampling/custom_sampling", + category="model/sampling/custom_sampling", inputs=[ io.Noise.Input("noise"), io.Guider.Input("guider"), diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py index 22f5ff203..104d16d91 100644 --- a/comfy_extras/nodes_dataset.py +++ b/comfy_extras/nodes_dataset.py @@ -157,7 +157,7 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode): return io.NodeOutput(output_tensor, captions) -def save_images_to_folder(image_list, output_dir, prefix="image"): +def save_images_to_folder(image_list, output_dir, prefix="image", overwrite=True): """Utility function to save a list of image tensors to disk. Args: @@ -197,7 +197,11 @@ def save_images_to_folder(image_list, output_dir, prefix="image"): raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}") # Save image - filename = f"{prefix}_{idx:05d}.png" + if overwrite: + filename = f"{prefix}_{idx:05d}.png" + else: + _, _, counter, _, resolved_prefix = folder_paths.get_save_image_path(prefix, output_dir) + filename = f"{resolved_prefix}_{counter:05}_{idx:05d}.png" filepath = os.path.join(output_dir, filename) img.save(filepath) saved_files.append(filename) @@ -230,19 +234,26 @@ class SaveImageDataSetToFolderNode(io.ComfyNode): tooltip="Prefix for saved image filenames.", advanced=True, ), + io.Combo.Input( + "mode", + default="overwrite", + options=["overwrite", "increment"], + tooltip="Whether to overwrite existing files or increment filenames to avoid overwriting." + ), ], outputs=[], is_deprecated=True, # This node is redundant and superseded by existing Save Image nodes where the target folder can be specified in the filename_prefix ) @classmethod - def execute(cls, images, folder_name, filename_prefix): + def execute(cls, images, folder_name, filename_prefix, mode): # Extract scalar values folder_name = folder_name[0] filename_prefix = filename_prefix[0] + mode = mode[0] output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) - saved_files = save_images_to_folder(images, output_dir, filename_prefix) + saved_files = save_images_to_folder(images, output_dir, filename_prefix, mode=='overwrite') logging.info(f"Saved {len(saved_files)} images to {output_dir}.") return io.NodeOutput() @@ -278,18 +289,25 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode): tooltip="Prefix for saved image filenames.", advanced=True, ), + io.Combo.Input( + "mode", + default="overwrite", + options=["overwrite", "increment"], + tooltip="Whether to overwrite existing files or increment filenames to avoid overwriting." + ), ], outputs=[], ) @classmethod - def execute(cls, images, folder_name, filename_prefix, texts=None): + def execute(cls, images, folder_name, filename_prefix, mode, texts=None): # Extract scalar values folder_name = folder_name[0] filename_prefix = filename_prefix[0] + mode = mode[0] output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) - saved_files = save_images_to_folder(images, output_dir, filename_prefix) + saved_files = save_images_to_folder(images, output_dir, filename_prefix, mode=='overwrite') # Save captions if texts: @@ -574,7 +592,7 @@ class TextProcessingNode(io.ComfyNode): return io.Schema( node_id=cls.node_id, display_name=cls.display_name or cls.node_id, - category="dataset/text", + category="text", is_experimental=True, is_input_list=is_group, # True for group, False for individual inputs=inputs, @@ -1208,7 +1226,7 @@ class ResolutionBucket(io.ComfyNode): node_id="ResolutionBucket", search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"], display_name="Resolution Bucket", - category="training", + category="model/training", description="Group latents and conditionings into buckets", is_experimental=True, is_input_list=True, @@ -1302,7 +1320,7 @@ class MakeTrainingDataset(io.ComfyNode): node_id="MakeTrainingDataset", search_aliases=["encode dataset"], display_name="Make Training Dataset", - category="training", + category="model/training", description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.", is_experimental=True, is_input_list=True, # images and texts as lists @@ -1390,7 +1408,7 @@ class SaveTrainingDataset(io.ComfyNode): node_id="SaveTrainingDataset", search_aliases=["export dataset", "save dataset"], display_name="Save Training Dataset", - category="training", + category="model/training", description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.", is_experimental=True, is_output_node=True, @@ -1493,7 +1511,7 @@ class LoadTrainingDataset(io.ComfyNode): node_id="LoadTrainingDataset", search_aliases=["import dataset", "training data"], display_name="Load Training Dataset", - category="training", + category="model/training", description="Load encoded training dataset (latents + conditioning) from disk for use in training.", is_experimental=True, inputs=[ diff --git a/comfy_extras/nodes_eps.py b/comfy_extras/nodes_eps.py index 0fb3871c8..8c397f132 100644 --- a/comfy_extras/nodes_eps.py +++ b/comfy_extras/nodes_eps.py @@ -18,7 +18,7 @@ class EpsilonScaling(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Epsilon Scaling", - category="model_patches/unet", + category="model/patch/unet", inputs=[ io.Model.Input("model"), io.Float.Input( @@ -84,7 +84,7 @@ class TemporalScoreRescaling(io.ComfyNode): return io.Schema( node_id="TemporalScoreRescaling", display_name="TSR - Temporal Score Rescaling", - category="model_patches/unet", + category="model/patch/unet", inputs=[ io.Model.Input("model"), io.Float.Input( diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index 997f21c09..afc663b22 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -40,7 +40,7 @@ class EmptyFlux2LatentImage(io.ComfyNode): return io.Schema( node_id="EmptyFlux2LatentImage", display_name="Empty Flux 2 Latent", - category="latent", + category="model/latent", inputs=[ io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), @@ -215,7 +215,7 @@ class Flux2Scheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Flux2Scheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=4096), io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1), diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index 9dd34cfb8..4d5bca17e 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -19,7 +19,7 @@ class FrameInterpolationModelLoader(io.ComfyNode): return io.Schema( node_id="FrameInterpolationModelLoader", display_name="Load Frame Interpolation Model", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation"), tooltip="Select a frame interpolation model to load. Models must be placed in the 'frame_interpolation' folder."), diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index 248efdef3..ccbd1fd90 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -29,7 +29,7 @@ class FreeU(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="FreeU", - category="model_patches/unet", + category="model/patch/unet", inputs=[ IO.Model.Input("model"), IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01, advanced=True), @@ -76,7 +76,7 @@ class FreeU_V2(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="FreeU_V2", - category="model_patches/unet", + category="model/patch/unet", inputs=[ IO.Model.Input("model"), IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01, advanced=True), diff --git a/comfy_extras/nodes_gits.py b/comfy_extras/nodes_gits.py index 0b7666524..434a24387 100644 --- a/comfy_extras/nodes_gits.py +++ b/comfy_extras/nodes_gits.py @@ -340,7 +340,7 @@ class GITSScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="GITSScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05, advanced=True), io.Int.Input("steps", default=10, min=2, max=1000), diff --git a/comfy_extras/nodes_hidream_o1.py b/comfy_extras/nodes_hidream_o1.py index f393745f6..8648d2e26 100644 --- a/comfy_extras/nodes_hidream_o1.py +++ b/comfy_extras/nodes_hidream_o1.py @@ -14,7 +14,7 @@ class EmptyHiDreamO1LatentImage(io.ComfyNode): return io.Schema( node_id="EmptyHiDreamO1LatentImage", display_name="Empty HiDream-O1 Latent Image", - category="latent/image", + category="model/latent/image", description=( "Empty pixel-space latent for HiDream-O1-Image. The model was " "trained at ~4 megapixels; lower resolutions go off-distribution " @@ -47,7 +47,7 @@ class HiDreamO1ReferenceImages(io.ComfyNode): return io.Schema( node_id="HiDreamO1ReferenceImages", display_name="HiDream-O1 Reference Images", - category="conditioning/image", + category="model/conditioning/image", description=( "Attach 1-10 reference images to conditioning, one for edit instruction" "or multiple for subject-driven personalization." diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 9e4873be5..16fff12af 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -41,7 +41,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode): return io.Schema( node_id="EmptyHunyuanLatentVideo", display_name="Empty HunyuanVideo 1.0 Latent", - category="latent/video", + category="model/latent/video", inputs=[ io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), @@ -81,7 +81,7 @@ class HunyuanVideo15ImageToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="HunyuanVideo15ImageToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -132,7 +132,7 @@ class HunyuanVideo15SuperResolution(io.ComfyNode): return io.Schema( node_id="HunyuanVideo15SuperResolution", display_name="Hunyuan Video 1.5 Super Resolution", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -178,7 +178,7 @@ class LatentUpscaleModelLoader(io.ComfyNode): return io.Schema( node_id="LatentUpscaleModelLoader", display_name="Load Latent Upscale Model", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input("model_name", options=folder_paths.get_filename_list("latent_upscale_models")), ], @@ -227,7 +227,7 @@ class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode): return io.Schema( node_id="HunyuanVideo15LatentUpscaleWithModel", display_name="Hunyuan Video 15 Latent Upscale With Model", - category="latent", + category="model/latent", inputs=[ io.LatentUpscaleModel.Input("model"), io.Latent.Input("samples"), @@ -308,7 +308,7 @@ class HunyuanImageToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="HunyuanImageToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Vae.Input("vae"), @@ -359,7 +359,7 @@ class EmptyHunyuanImageLatent(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="EmptyHunyuanImageLatent", - category="latent", + category="model/latent", inputs=[ io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32), io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32), @@ -384,7 +384,7 @@ class HunyuanRefinerLatent(io.ComfyNode): return io.Schema( node_id="HunyuanRefinerLatent", display_name="Hunyuan Latent Refiner", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index bcd3f9198..60e530626 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -12,7 +12,7 @@ class EmptyLatentHunyuan3Dv2(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="EmptyLatentHunyuan3Dv2", - category="latent/3d", + category="model/latent/3d", inputs=[ IO.Int.Input("resolution", default=3072, min=1, max=8192), IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), @@ -35,7 +35,7 @@ class Hunyuan3Dv2Conditioning(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="Hunyuan3Dv2Conditioning", - category="conditioning/3d_models", + category="model/conditioning/3d_models", inputs=[ IO.ClipVisionOutput.Input("clip_vision_output"), ], @@ -60,7 +60,7 @@ class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="Hunyuan3Dv2ConditioningMultiView", - category="conditioning/3d_models", + category="model/conditioning/3d_models", inputs=[ IO.ClipVisionOutput.Input("front", optional=True), IO.ClipVisionOutput.Input("left", optional=True), @@ -97,7 +97,7 @@ class VAEDecodeHunyuan3D(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="VAEDecodeHunyuan3D", - category="latent/3d", + category="model/latent/3d", inputs=[ IO.Latent.Input("samples"), IO.Vae.Input("vae"), diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index 44a9c6f97..2d3f1bd05 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -103,7 +103,7 @@ class HypernetworkLoader(IO.ComfyNode): return IO.Schema( node_id="HypernetworkLoader", display_name="Load Hypernetwork", - category="loaders", + category="model/loaders", inputs=[ IO.Model.Input("model"), IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")), diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py index 354d96db1..2a96416be 100644 --- a/comfy_extras/nodes_hypertile.py +++ b/comfy_extras/nodes_hypertile.py @@ -27,7 +27,7 @@ class HyperTile(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="HyperTile", - category="model_patches/unet", + category="model/patch/unet", inputs=[ io.Model.Input("model"), io.Int.Input("tile_size", default=256, min=1, max=2048, advanced=True), diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index fe6008aa3..469a7be55 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -95,7 +95,7 @@ class BoundingBox(IO.ComfyNode): return IO.Schema( node_id="PrimitiveBoundingBox", display_name="Bounding Box", - category="utils/primitive", + category="utilities/primitive", inputs=[ IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION), IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION), diff --git a/comfy_extras/nodes_ip2p.py b/comfy_extras/nodes_ip2p.py index 78f29915d..9c80834f0 100644 --- a/comfy_extras/nodes_ip2p.py +++ b/comfy_extras/nodes_ip2p.py @@ -9,7 +9,7 @@ class InstructPixToPixConditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="InstructPixToPixConditioning", - category="conditioning/instructpix2pix", + category="model/conditioning/instructpix2pix", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py index 346c50cde..015965498 100644 --- a/comfy_extras/nodes_kandinsky5.py +++ b/comfy_extras/nodes_kandinsky5.py @@ -13,7 +13,7 @@ class Kandinsky5ImageToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Kandinsky5ImageToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -71,7 +71,7 @@ class NormalizeVideoLatentStart(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="NormalizeVideoLatentStart", - category="conditioning/video_models", + category="model/conditioning/video_models", description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.", inputs=[ io.Latent.Input("latent"), diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 8bb368dec..32da9e8ac 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -22,7 +22,7 @@ class LatentAdd(io.ComfyNode): return io.Schema( node_id="LatentAdd", search_aliases=["combine latents", "sum latents"], - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples1"), io.Latent.Input("samples2"), @@ -49,7 +49,7 @@ class LatentSubtract(io.ComfyNode): return io.Schema( node_id="LatentSubtract", search_aliases=["difference latent", "remove features"], - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples1"), io.Latent.Input("samples2"), @@ -76,7 +76,7 @@ class LatentMultiply(io.ComfyNode): return io.Schema( node_id="LatentMultiply", search_aliases=["scale latent", "amplify latent", "latent gain"], - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples"), io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01), @@ -100,7 +100,7 @@ class LatentInterpolate(io.ComfyNode): return io.Schema( node_id="LatentInterpolate", search_aliases=["blend latent", "mix latent", "lerp latent", "transition"], - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples1"), io.Latent.Input("samples2"), @@ -139,7 +139,7 @@ class LatentConcat(io.ComfyNode): return io.Schema( node_id="LatentConcat", search_aliases=["join latents", "stitch latents"], - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples1"), io.Latent.Input("samples2"), @@ -179,7 +179,7 @@ class LatentCut(io.ComfyNode): return io.Schema( node_id="LatentCut", search_aliases=["crop latent", "slice latent", "extract region"], - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples"), io.Combo.Input("dim", options=["x", "y", "t"]), @@ -220,7 +220,7 @@ class LatentCutToBatch(io.ComfyNode): return io.Schema( node_id="LatentCutToBatch", search_aliases=["slice to batch", "split latent", "tile latent"], - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples"), io.Combo.Input("dim", options=["t", "x", "y"]), @@ -262,7 +262,7 @@ class LatentBatch(io.ComfyNode): return io.Schema( node_id="LatentBatch", search_aliases=["combine latents", "merge latents", "join latents"], - category="latent/batch", + category="model/latent/batch", is_deprecated=True, inputs=[ io.Latent.Input("samples1"), @@ -290,7 +290,7 @@ class LatentBatchSeedBehavior(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentBatchSeedBehavior", - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples"), io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"), @@ -319,7 +319,7 @@ class LatentApplyOperation(io.ComfyNode): return io.Schema( node_id="LatentApplyOperation", search_aliases=["transform latent"], - category="latent/advanced/operations", + category="model/latent/advanced/operations", is_experimental=True, inputs=[ io.Latent.Input("samples"), @@ -343,7 +343,7 @@ class LatentApplyOperationCFG(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentApplyOperationCFG", - category="latent/advanced/operations", + category="model/latent/advanced/operations", is_experimental=True, inputs=[ io.Model.Input("model"), @@ -375,7 +375,7 @@ class LatentOperationTonemapReinhard(io.ComfyNode): return io.Schema( node_id="LatentOperationTonemapReinhard", search_aliases=["hdr latent"], - category="latent/advanced/operations", + category="model/latent/advanced/operations", is_experimental=True, inputs=[ io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01), @@ -410,7 +410,7 @@ class LatentOperationSharpen(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentOperationSharpen", - category="latent/advanced/operations", + category="model/latent/advanced/operations", is_experimental=True, inputs=[ io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1, advanced=True), @@ -447,7 +447,7 @@ class ReplaceVideoLatentFrames(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ReplaceVideoLatentFrames", - category="latent/batch", + category="model/latent/batch", inputs=[ io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."), io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."), diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 9112bdd0a..9c27c0191 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -34,7 +34,7 @@ class Load3D(IO.ComfyNode): essentials_category="Basics", is_experimental=True, inputs=[ - IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model), + IO.Combo.Input("model_file", options=["none"] + sorted(files), upload=IO.UploadType.model), IO.Load3D.Input("image"), IO.Int.Input("width", default=1024, min=1, max=4096, step=1), IO.Int.Input("height", default=1024, min=1, max=4096, step=1), @@ -68,8 +68,12 @@ class Load3D(IO.ComfyNode): video = InputImpl.VideoFromFile(recording_video_path) - file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file)) - return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d) + file_3d = None + mesh_path = "" + if model_file and model_file != "none": + file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file)) + mesh_path = model_file + return IO.NodeOutput(output_image, output_mask, mesh_path, normal_image, image['camera_info'], video, file_3d) process = execute # TODO: remove diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 92507f1fc..95f6ab848 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -13,7 +13,7 @@ class NotNode(io.ComfyNode): return io.Schema( node_id="ComfyNotNode", display_name="Not", - category="utils/logic", + category="utilities/logic", description="Logical NOT operation. Returns true if the value is falsy. Uses Python's rules for truthiness.", search_aliases=["invert", "toggle", "negate", "flip boolean"], inputs=[ @@ -40,7 +40,7 @@ class AndNode(io.ComfyNode): return io.Schema( node_id="ComfyAndNode", display_name="And", - category="utils/logic", + category="utilities/logic", description="Logical AND operation. Returns true if all of the values are truthy. Uses Python's rules for truthiness.", search_aliases=["all", "every"], inputs=[ @@ -67,7 +67,7 @@ class OrNode(io.ComfyNode): return io.Schema( node_id="ComfyOrNode", display_name="Or", - category="utils/logic", + category="utilities/logic", description="Logical OR operation. Returns true if any of the values are truthy. Uses Python's rules for truthiness.", search_aliases=["any", "some"], inputs=[ @@ -90,7 +90,7 @@ class SwitchNode(io.ComfyNode): return io.Schema( node_id="ComfySwitchNode", display_name="Switch", - category="utils/logic", + category="utilities/logic", is_experimental=True, inputs=[ io.Boolean.Input("switch"), @@ -121,7 +121,7 @@ class SoftSwitchNode(io.ComfyNode): return io.Schema( node_id="ComfySoftSwitchNode", display_name="Soft Switch", - category="utils/logic", + category="utilities/logic", is_experimental=True, inputs=[ io.Boolean.Input("switch"), @@ -176,7 +176,7 @@ class CustomComboNode(io.ComfyNode): return io.Schema( node_id="CustomCombo", display_name="Custom Combo", - category="utils", + category="utilities", is_experimental=True, inputs=[io.Combo.Input("choice", options=[])], outputs=[ @@ -211,7 +211,7 @@ class DCTestNode(io.ComfyNode): return io.Schema( node_id="DCTestNode", display_name="DCTest", - category="utils/logic", + category="utilities/logic", is_output_node=True, inputs=[io.DynamicCombo.Input("combo", options=[ io.DynamicCombo.Option("option1", [io.String.Input("string")]), @@ -249,7 +249,7 @@ class AutogrowNamesTestNode(io.ComfyNode): return io.Schema( node_id="AutogrowNamesTestNode", display_name="AutogrowNamesTest", - category="utils/logic", + category="utilities/logic", inputs=[ _io.Autogrow.Input("autogrow", template=template) ], @@ -269,7 +269,7 @@ class AutogrowPrefixTestNode(io.ComfyNode): return io.Schema( node_id="AutogrowPrefixTestNode", display_name="AutogrowPrefixTest", - category="utils/logic", + category="utilities/logic", inputs=[ _io.Autogrow.Input("autogrow", template=template) ], @@ -288,7 +288,7 @@ class ComboOutputTestNode(io.ComfyNode): return io.Schema( node_id="ComboOptionTestNode", display_name="ComboOptionTest", - category="utils/logic", + category="utilities/logic", inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]), io.Combo.Input("combo2", options=["option4", "option5", "option6"])], outputs=[io.Combo.Output(), io.Combo.Output()], @@ -305,7 +305,7 @@ class ConvertStringToComboNode(io.ComfyNode): node_id="ConvertStringToComboNode", search_aliases=["string to dropdown", "text to combo"], display_name="Convert String to Combo", - category="utils/logic", + category="utilities/logic", inputs=[io.String.Input("string")], outputs=[io.Combo.Output()], ) @@ -321,7 +321,7 @@ class InvertBooleanNode(io.ComfyNode): node_id="InvertBooleanNode", search_aliases=["not", "toggle", "negate", "flip boolean"], display_name="Invert Boolean", - category="utils/logic", + category="utilities/logic", inputs=[io.Boolean.Input("boolean")], outputs=[io.Boolean.Output()], ) diff --git a/comfy_extras/nodes_lora_debug.py b/comfy_extras/nodes_lora_debug.py index 937a0fbfb..3f68064e5 100644 --- a/comfy_extras/nodes_lora_debug.py +++ b/comfy_extras/nodes_lora_debug.py @@ -30,7 +30,7 @@ class LoraLoaderBypass: OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.") FUNCTION = "load_lora" - CATEGORY = "loaders" + CATEGORY = "model/loaders" DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios." EXPERIMENTAL = True diff --git a/comfy_extras/nodes_lotus.py b/comfy_extras/nodes_lotus.py index 9f62ba2bf..9fe4c5c7b 100644 --- a/comfy_extras/nodes_lotus.py +++ b/comfy_extras/nodes_lotus.py @@ -10,7 +10,7 @@ class LotusConditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LotusConditioning", - category="conditioning/lotus", + category="model/conditioning/lotus", inputs=[], outputs=[io.Conditioning.Output(display_name="conditioning")], ) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 51cf7951f..6d6078abe 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -25,7 +25,7 @@ class GetICLoRAParameters(io.ComfyNode): display_name="Get IC-LoRA Parameters", description="Extracts IC-LoRA parameters from the safetensors metadata of a LoRA-loaded " "model and outputs them for LTXVAddGuide (eg. reference_downscale_factor).", - category="conditioning/video_models", + category="model/conditioning/video_models", search_aliases=["ic-lora", "ic lora", "iclora", "downscale factor", "reference downscale"], inputs=[ io.Model.Input( @@ -62,7 +62,7 @@ class EmptyLTXVLatentVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="EmptyLTXVLatentVideo", - category="latent/video/ltxv", + category="model/latent/video/ltxv", inputs=[ io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32), io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32), @@ -86,7 +86,7 @@ class LTXVImgToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVImgToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -131,7 +131,7 @@ class LTXVImgToVideoInplace(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVImgToVideoInplace", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Vae.Input("vae"), io.Image.Input("image"), @@ -226,10 +226,20 @@ def get_noise_mask(latent): noise_mask = noise_mask.clone() return noise_mask -def get_keyframe_idxs(cond): +def get_keyframe_idxs(cond, latent_shape=None): keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None) if keyframe_idxs is None: return None, 0 + # Get number of keyframes from latent_shape or guide_attention_entries if available + if latent_shape is not None and len(latent_shape) == 5: + tokens_per_frame = latent_shape[-2] * latent_shape[-1] + num_keyframes = keyframe_idxs.shape[2] // tokens_per_frame + return keyframe_idxs, num_keyframes + entries = conditioning_get_any_value(cond, "guide_attention_entries", None) + if entries: + num_keyframes = sum(e["latent_shape"][0] for e in entries) + return keyframe_idxs, num_keyframes + # fallback, may under-count if keyframes share t-start # keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0] return keyframe_idxs, num_keyframes @@ -241,7 +251,7 @@ class LTXVAddGuide(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVAddGuide", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -322,9 +332,9 @@ class LTXVAddGuide(io.ComfyNode): return factor @classmethod - def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors): + def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors, latent_shape=None): time_scale_factor, _, _ = scale_factors - _, num_keyframes = get_keyframe_idxs(cond) + _, num_keyframes = get_keyframe_idxs(cond, latent_shape) latent_count = latent_length - num_keyframes frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0) if guide_length > 1 and frame_idx != 0: @@ -436,7 +446,7 @@ class LTXVAddGuide(io.ComfyNode): num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1 resolved_frame_idx = frame_idx if frame_idx < 0: - _, num_keyframes = get_keyframe_idxs(positive) + _, num_keyframes = get_keyframe_idxs(positive, latent_image.shape) resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0) causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1 @@ -454,7 +464,7 @@ class LTXVAddGuide(io.ComfyNode): if latent_downscale_factor > 1: t, guide_mask = cls.dilate_latent(t, latent_downscale_factor) - frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors) + frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors, latent_shape=latent_image.shape) assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence." positive, negative, latent_image, noise_mask = cls.append_keyframe( @@ -488,7 +498,7 @@ class LTXVCropGuides(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVCropGuides", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -506,7 +516,7 @@ class LTXVCropGuides(io.ComfyNode): latent_image = latent["samples"].clone() noise_mask = get_noise_mask(latent) - _, num_keyframes = get_keyframe_idxs(positive) + _, num_keyframes = get_keyframe_idxs(positive, latent_image.shape) if num_keyframes == 0: return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},) @@ -532,7 +542,7 @@ class LTXVConditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVConditioning", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -601,7 +611,7 @@ class LTXVScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01), @@ -736,7 +746,7 @@ class LTXVConcatAVLatent(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVConcatAVLatent", - category="latent/video/ltxv", + category="model/latent/video/ltxv", inputs=[ io.Latent.Input("video_latent"), io.Latent.Input("audio_latent"), @@ -771,7 +781,7 @@ class LTXVSeparateAVLatent(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVSeparateAVLatent", - category="latent/video/ltxv", + category="model/latent/video/ltxv", description="LTXV Separate AV Latent", inputs=[ io.Latent.Input("av_latent"), @@ -804,7 +814,7 @@ class LTXVReferenceAudio(io.ComfyNode): return io.Schema( node_id="LTXVReferenceAudio", display_name="LTXV Reference Audio (ID-LoRA)", - category="conditioning/audio", + category="model/conditioning/audio", description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).", inputs=[ io.Model.Input("model"), diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 51ddf584a..052186083 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -12,7 +12,7 @@ class LTXVAudioVAELoader(io.ComfyNode): return io.Schema( node_id="LTXVAudioVAELoader", display_name="Load LTXV Audio VAE", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input( "ckpt_name", @@ -40,7 +40,7 @@ class LTXVAudioVAEEncode(VAEEncodeAudio): return io.Schema( node_id="LTXVAudioVAEEncode", display_name="LTXV Audio VAE Encode", - category="latent/audio", + category="model/latent/audio", inputs=[ io.Audio.Input("audio", tooltip="The audio to be encoded."), io.Vae.Input( @@ -63,7 +63,7 @@ class LTXVAudioVAEDecode(io.ComfyNode): return io.Schema( node_id="LTXVAudioVAEDecode", display_name="LTXV Audio VAE Decode", - category="latent/audio", + category="model/latent/audio", inputs=[ io.Latent.Input("samples", tooltip="The latent to be decoded."), io.Vae.Input( @@ -96,7 +96,7 @@ class LTXVEmptyLatentAudio(io.ComfyNode): return io.Schema( node_id="LTXVEmptyLatentAudio", display_name="LTXV Empty Latent Audio", - category="latent/audio", + category="model/latent/audio", inputs=[ io.Int.Input( "frames_number", diff --git a/comfy_extras/nodes_lt_upsampler.py b/comfy_extras/nodes_lt_upsampler.py index f99ba13fb..be9a36e69 100644 --- a/comfy_extras/nodes_lt_upsampler.py +++ b/comfy_extras/nodes_lt_upsampler.py @@ -1,32 +1,32 @@ from comfy import model_management +from comfy_api.latest import ComfyExtension, IO +from typing_extensions import override import math -class LTXVLatentUpsampler: + +class LTXVLatentUpsampler(IO.ComfyNode): """ Upsamples a video latent by a factor of 2. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "samples": ("LATENT",), - "upscale_model": ("LATENT_UPSCALE_MODEL",), - "vae": ("VAE",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="LTXVLatentUpsampler", + category="model/latent/video", + is_experimental=True, + inputs=[ + IO.Latent.Input("samples"), + IO.LatentUpscaleModel.Input("upscale_model"), + IO.Vae.Input("vae"), + ], + outputs=[ + IO.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "upsample_latent" - CATEGORY = "latent/video" - EXPERIMENTAL = True - - def upsample_latent( - self, - samples: dict, - upscale_model, - vae, - ) -> tuple: + @classmethod + def execute(cls, samples, upscale_model, vae) -> IO.NodeOutput: """ Upsample the input latent using the provided model. @@ -34,7 +34,6 @@ class LTXVLatentUpsampler: samples (dict): Input latent samples upscale_model (LatentUpsampler): Loaded upscale model vae: VAE model for normalization - auto_tiling (bool): Whether to automatically tile the input for processing Returns: tuple: Tuple containing the upsampled latent @@ -67,9 +66,16 @@ class LTXVLatentUpsampler: return_dict = samples.copy() return_dict["samples"] = upsampled_latents return_dict.pop("noise_mask", None) - return (return_dict,) + return IO.NodeOutput(return_dict) + + upsample_latent = execute # TODO: remove -NODE_CLASS_MAPPINGS = { - "LTXVLatentUpsampler": LTXVLatentUpsampler, -} +class LTXVLatentUpsamplerExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [LTXVLatentUpsampler] + + +async def comfy_entrypoint() -> LTXVLatentUpsamplerExtension: + return LTXVLatentUpsamplerExtension() diff --git a/comfy_extras/nodes_lumina2.py b/comfy_extras/nodes_lumina2.py index b35ab8b7d..c060a86a0 100644 --- a/comfy_extras/nodes_lumina2.py +++ b/comfy_extras/nodes_lumina2.py @@ -81,7 +81,7 @@ class CLIPTextEncodeLumina2(io.ComfyNode): node_id="CLIPTextEncodeLumina2", search_aliases=["lumina prompt"], display_name="CLIP Text Encode for Lumina2", - category="conditioning", + category="model/conditioning", description="Encodes a system prompt and a user prompt using a CLIP model into an embedding " "that can be used to guide the diffusion model towards generating specific images.", inputs=[ diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index d15f1f4e7..52484697a 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -53,7 +53,7 @@ class LatentCompositeMasked(IO.ComfyNode): return IO.Schema( node_id="LatentCompositeMasked", search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"], - category="latent", + category="model/latent", inputs=[ IO.Latent.Input("destination"), IO.Latent.Input("source"), diff --git a/comfy_extras/nodes_math.py b/comfy_extras/nodes_math.py index 0040d1a92..873ee7b51 100644 --- a/comfy_extras/nodes_math.py +++ b/comfy_extras/nodes_math.py @@ -69,7 +69,7 @@ class MathExpressionNode(io.ComfyNode): return io.Schema( node_id="ComfyMathExpression", display_name="Math Expression", - category="utils", + category="utilities", search_aliases=[ "expression", "formula", "calculate", "calculator", "eval", "math", diff --git a/comfy_extras/nodes_mediapipe.py b/comfy_extras/nodes_mediapipe.py index 32dc22de3..343d88dbb 100644 --- a/comfy_extras/nodes_mediapipe.py +++ b/comfy_extras/nodes_mediapipe.py @@ -205,7 +205,7 @@ class LoadMediaPipeFaceLandmarker(io.ComfyNode): node_id="LoadMediaPipeFaceLandmarker", search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"], display_name="Load Face Detection Model (MediaPipe)", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input("model_name", options=folder_paths.get_filename_list("detection"), tooltip="Face detection model from models/detection/."), diff --git a/comfy_extras/nodes_mochi.py b/comfy_extras/nodes_mochi.py index d750194fc..3dcea6ab3 100644 --- a/comfy_extras/nodes_mochi.py +++ b/comfy_extras/nodes_mochi.py @@ -10,7 +10,7 @@ class EmptyMochiLatentVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="EmptyMochiLatentVideo", - category="latent/video", + category="model/latent/video", inputs=[ io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index 24d47a903..817542452 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -10,7 +10,7 @@ class PatchModelAddDownscale(io.ComfyNode): return io.Schema( node_id="PatchModelAddDownscale", display_name="PatchModelAddDownscale (Kohya Deep Shrink)", - category="model_patches/unet", + category="model/patch/unet", inputs=[ io.Model.Input("model"), io.Int.Input("block_number", default=3, min=1, max=32, step=1, advanced=True), diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 748559a6b..bdccbf8c4 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -548,7 +548,7 @@ class USOStyleReference: FUNCTION = "apply_patch" EXPERIMENTAL = True - CATEGORY = "advanced/model_patches/flux" + CATEGORY = "model/patch/flux" def apply_patch(self, model, model_patch, clip_vision_output): encoded_image = torch.stack((clip_vision_output.all_hidden_states[:, -20], clip_vision_output.all_hidden_states[:, -11], clip_vision_output.penultimate_hidden_states)) @@ -594,7 +594,7 @@ class SUPIRApply(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="SUPIRApply", - category="model_patches/supir", + category="model/patch/supir", is_experimental=True, inputs=[ io.Model.Input("model"), diff --git a/comfy_extras/nodes_moge.py b/comfy_extras/nodes_moge.py index 79aec5d7f..422949531 100644 --- a/comfy_extras/nodes_moge.py +++ b/comfy_extras/nodes_moge.py @@ -78,7 +78,7 @@ class LoadMoGeModel(io.ComfyNode): return io.Schema( node_id="LoadMoGeModel", display_name="Load MoGe Model", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input("model_name", options=folder_paths.get_filename_list("geometry_estimation")), ], @@ -104,7 +104,7 @@ class MoGePanoramaInference(io.ComfyNode): node_id="MoGePanoramaInference", search_aliases=["moge", "panorama", "depth", "geometry", "depth estimation", "geometry estimation"], display_name="Run MoGe Panorama Inference", - category="image/geometry_estimation", + category="image/geometry estimation", description="Run MoGe on an equirectangular panorama by splitting it into 12 perspective views, running inference on each, and merging the results into a single depth map.", inputs=[ MoGeModelType.Input("moge_model"), @@ -226,7 +226,7 @@ class MoGeInference(io.ComfyNode): search_aliases=["moge", "depth", "geometry", "depth estimation", "geometry estimation"], display_name="Run MoGe Inference", description="Run MoGe on a single image to estimate depth and geometry.", - category="image/geometry_estimation", + category="image/geometry estimation", inputs=[ MoGeModelType.Input("moge_model"), io.Image.Input("image"), @@ -283,7 +283,7 @@ class MoGeRender(io.ComfyNode): search_aliases=["moge", "render", "geometry", "depth", "normal"], display_name="Render MoGe Geometry", description="Render a depth map or normal map from geometry data", - category="image/geometry_estimation", + category="image/geometry estimation", inputs=[ MoGeGeometry.Input("moge_geometry"), io.Combo.Input("output", options=["depth", "depth_colored", "normal_opengl", "normal_directx", "mask"], default="depth", @@ -350,7 +350,7 @@ class MoGePointMapToMesh(io.ComfyNode): search_aliases=["moge", "mesh", "geometry", "point map"], display_name="Convert MoGe Point Map to Mesh", description="Convert a MoGe point map into a 3D mesh.", - category="image/geometry_estimation", + category="image/geometry estimation", inputs=[ MoGeGeometry.Input("moge_geometry"), io.Int.Input("batch_index", default=0, min=0, max=4096, diff --git a/comfy_extras/nodes_number_convert.py b/comfy_extras/nodes_number_convert.py index 01593b6e6..d7e557e95 100644 --- a/comfy_extras/nodes_number_convert.py +++ b/comfy_extras/nodes_number_convert.py @@ -20,7 +20,7 @@ class NumberConvertNode(io.ComfyNode): return io.Schema( node_id="ComfyNumberConvert", display_name="Convert Number", - category="utils", + category="utilities", search_aliases=[ "int to float", "float to int", "number convert", "int2float", "float2int", "cast", "parse number", diff --git a/comfy_extras/nodes_optimalsteps.py b/comfy_extras/nodes_optimalsteps.py index 5beeaa7db..19629790f 100644 --- a/comfy_extras/nodes_optimalsteps.py +++ b/comfy_extras/nodes_optimalsteps.py @@ -31,7 +31,7 @@ class OptimalStepsScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="OptimalStepsScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Combo.Input("model_type", options=["FLUX", "Wan", "Chroma"]), io.Int.Input("steps", default=20, min=3, max=1000), diff --git a/comfy_extras/nodes_pag.py b/comfy_extras/nodes_pag.py index 79fea5f0c..c875e1e06 100644 --- a/comfy_extras/nodes_pag.py +++ b/comfy_extras/nodes_pag.py @@ -15,7 +15,7 @@ class PerturbedAttentionGuidance(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PerturbedAttentionGuidance", - category="model_patches/unet", + category="model/patch/unet", inputs=[ io.Model.Input("model"), io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01), diff --git a/comfy_extras/nodes_pid.py b/comfy_extras/nodes_pid.py new file mode 100644 index 000000000..811b9ae8e --- /dev/null +++ b/comfy_extras/nodes_pid.py @@ -0,0 +1,55 @@ +"""PiD (Pixel Diffusion Decoder) node""" + +import torch +from typing_extensions import override + +import node_helpers +import comfy.latent_formats +from comfy_api.latest import ComfyExtension, io + + +class PiDConditioning(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="PiDConditioning", + display_name="PiD Conditioning", + category="advanced/conditioning", + description=( + "Attaches a latent and a degrade_sigma scalar to a CONDITIONING for PiD decoding/upscaling" + ), + inputs=[ + io.Conditioning.Input("positive"), + io.Latent.Input("latent", tooltip="latent (from VAEEncode or a KSampler)."), + io.Combo.Input("latent_format", options=["flux", "sd3"], default="flux", + tooltip="Flux1 and Flux2 latents auto-detected from channel dim, sd3 has to be selected manually."), + io.Float.Input( + "degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01, + tooltip="0 = clean latent. Increase to denoise corrupted latent outputs.", + ), + ], + outputs=[io.Conditioning.Output()], + ) + + @classmethod + def execute(cls, positive, latent, latent_format: str, degrade_sigma: float) -> io.NodeOutput: + samples = latent["samples"] + if latent_format == "flux": + fmt_cls = comfy.latent_formats.Flux2 if samples.shape[1] == 128 else comfy.latent_formats.Flux + else: + fmt_cls = comfy.latent_formats.SD3 + lq_latent = fmt_cls().process_in(samples) + sigma_t = torch.tensor([float(degrade_sigma)], dtype=torch.float32) + return io.NodeOutput(node_helpers.conditioning_set_values( + positive, {"lq_latent": lq_latent, "degrade_sigma": sigma_t}, + )) + + +class PiDExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [PiDConditioning] + + +async def comfy_entrypoint() -> PiDExtension: + return PiDExtension() diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index a25db277c..3e440433e 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -616,7 +616,7 @@ class BatchLatentsNode(io.ComfyNode): node_id="BatchLatentsNode", search_aliases=["combine latents", "stack latents", "merge latents"], display_name="Batch Latents", - category="latent", + category="model/latent", inputs=[ io.Autogrow.Input("latents", template=autogrow_template) ], diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index 17e25d514..1070a69d0 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -16,7 +16,7 @@ class PreviewAny(): FUNCTION = "main" OUTPUT_NODE = True - CATEGORY = "utils" + CATEGORY = "utilities" SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"] def main(self, source=None): diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index 33373266b..c44b09098 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -11,7 +11,7 @@ class String(io.ComfyNode): node_id="PrimitiveString", search_aliases=["text", "string", "text box", "prompt"], display_name="Text String", - category="utils/primitive", + category="utilities/primitive", inputs=[ io.String.Input("value"), ], @@ -30,7 +30,7 @@ class StringMultiline(io.ComfyNode): node_id="PrimitiveStringMultiline", search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"], display_name="Text String (Multiline)", - category="utils/primitive", + category="utilities/primitive", essentials_category="Basics", inputs=[ io.String.Input("value", multiline=True), @@ -49,7 +49,7 @@ class Int(io.ComfyNode): return io.Schema( node_id="PrimitiveInt", display_name="Int", - category="utils/primitive", + category="utilities/primitive", inputs=[ io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed), ], @@ -67,7 +67,7 @@ class Float(io.ComfyNode): return io.Schema( node_id="PrimitiveFloat", display_name="Float", - category="utils/primitive", + category="utilities/primitive", inputs=[ io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1), ], @@ -85,7 +85,7 @@ class Boolean(io.ComfyNode): return io.Schema( node_id="PrimitiveBoolean", display_name="Boolean", - category="utils/primitive", + category="utilities/primitive", inputs=[ io.Boolean.Input("value"), ], diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index fde8fac9a..5b92814a4 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -112,7 +112,7 @@ class EmptyQwenImageLayeredLatentImage(io.ComfyNode): return io.Schema( node_id="EmptyQwenImageLayeredLatentImage", display_name="Empty Qwen Image Layered Latent", - category="latent/qwen", + category="model/latent/qwen", inputs=[ io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), diff --git a/comfy_extras/nodes_rebatch.py b/comfy_extras/nodes_rebatch.py index 5f4e82aef..2185385f0 100644 --- a/comfy_extras/nodes_rebatch.py +++ b/comfy_extras/nodes_rebatch.py @@ -10,7 +10,7 @@ class LatentRebatch(io.ComfyNode): return io.Schema( node_id="RebatchLatents", display_name="Rebatch Latents", - category="latent/batch", + category="model/latent/batch", is_input_list=True, inputs=[ io.Latent.Input("latents"), diff --git a/comfy_extras/nodes_resolution.py b/comfy_extras/nodes_resolution.py index 1628038cc..dc405291c 100644 --- a/comfy_extras/nodes_resolution.py +++ b/comfy_extras/nodes_resolution.py @@ -35,7 +35,7 @@ class ResolutionSelector(io.ComfyNode): return io.Schema( node_id="ResolutionSelector", display_name="Resolution Selector", - category="utils", + category="utilities", description="Calculate width and height from aspect ratio and megapixel target. Useful for setting up Empty Latent Image dimensions.", inputs=[ io.Combo.Input( diff --git a/comfy_extras/nodes_rope.py b/comfy_extras/nodes_rope.py index 918ddc02b..808eee29b 100644 --- a/comfy_extras/nodes_rope.py +++ b/comfy_extras/nodes_rope.py @@ -7,7 +7,7 @@ class ScaleROPE(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ScaleROPE", - category="advanced/model_patches", + category="model/patch", description="Scale and shift the ROPE of the model.", is_experimental=True, inputs=[ diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index 6655c1ba7..38cbf117b 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -41,7 +41,7 @@ class EmptySD3LatentImage(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="EmptySD3LatentImage", - category="latent/sd3", + category="model/latent/sd3", inputs=[ io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), @@ -113,7 +113,7 @@ class ControlNetApplySD3(io.ComfyNode): return io.Schema( node_id="ControlNetApplySD3", display_name="Apply Controlnet with VAE", - category="conditioning/controlnet", + category="model/conditioning/controlnet", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/comfy_extras/nodes_sdupscale.py b/comfy_extras/nodes_sdupscale.py index 5877719d3..ea283e971 100644 --- a/comfy_extras/nodes_sdupscale.py +++ b/comfy_extras/nodes_sdupscale.py @@ -9,7 +9,7 @@ class SD_4XUpscale_Conditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SD_4XUpscale_Conditioning", - category="conditioning/upscale_diffusion", + category="model/conditioning/upscale_diffusion", inputs=[ io.Image.Input("images"), io.Conditioning.Input("positive"), diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py index 829c837a1..8a6e5b726 100644 --- a/comfy_extras/nodes_stable3d.py +++ b/comfy_extras/nodes_stable3d.py @@ -27,7 +27,7 @@ class StableZero123_Conditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StableZero123_Conditioning", - category="conditioning/3d_models", + category="model/conditioning/3d_models", inputs=[ io.ClipVision.Input("clip_vision"), io.Image.Input("init_image"), @@ -65,7 +65,7 @@ class StableZero123_Conditioning_Batched(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StableZero123_Conditioning_Batched", - category="conditioning/3d_models", + category="model/conditioning/3d_models", inputs=[ io.ClipVision.Input("clip_vision"), io.Image.Input("init_image"), @@ -112,7 +112,7 @@ class SV3D_Conditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SV3D_Conditioning", - category="conditioning/3d_models", + category="model/conditioning/3d_models", inputs=[ io.ClipVision.Input("clip_vision"), io.Image.Input("init_image"), diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py index 0dc6c9fcd..e55f248ae 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -29,7 +29,7 @@ class StableCascade_EmptyLatentImage(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StableCascade_EmptyLatentImage", - category="latent/stable_cascade", + category="model/latent/stable_cascade", inputs=[ io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8), io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8), @@ -58,7 +58,7 @@ class StableCascade_StageC_VAEEncode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StableCascade_StageC_VAEEncode", - category="latent/stable_cascade", + category="model/latent/stable_cascade", inputs=[ io.Image.Input("image"), io.Vae.Input("vae"), @@ -93,7 +93,7 @@ class StableCascade_StageB_Conditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StableCascade_StageB_Conditioning", - category="conditioning/stable_cascade", + category="model/conditioning/stable_cascade", inputs=[ io.Conditioning.Input("conditioning"), io.Latent.Input("stage_c"), diff --git a/comfy_extras/nodes_tomesd.py b/comfy_extras/nodes_tomesd.py index 87bf29b8f..3667fac3a 100644 --- a/comfy_extras/nodes_tomesd.py +++ b/comfy_extras/nodes_tomesd.py @@ -151,7 +151,7 @@ class TomePatchModel(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="TomePatchModel", - category="model_patches/unet", + category="model/patch/unet", inputs=[ io.Model.Input("model"), io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01), diff --git a/comfy_extras/nodes_toolkit.py b/comfy_extras/nodes_toolkit.py index 0548a0cf8..9f709bbe3 100644 --- a/comfy_extras/nodes_toolkit.py +++ b/comfy_extras/nodes_toolkit.py @@ -13,7 +13,7 @@ class CreateList(io.ComfyNode): return io.Schema( node_id="CreateList", display_name="Create List", - category="utils", + category="utilities", is_input_list=True, search_aliases=["Image Iterator", "Text Iterator", "Iterator"], inputs=[io.Autogrow.Input("inputs", template=template_autogrow)], diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index e9871369b..046eeaaf5 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -951,7 +951,7 @@ class TrainLoraNode(io.ComfyNode): return io.Schema( node_id="TrainLoraNode", display_name="Train LoRA", - category="training", + category="model/training", is_experimental=True, is_input_list=True, # All inputs become lists inputs=[ @@ -1309,7 +1309,7 @@ class LoraModelLoader(io.ComfyNode): return io.Schema( node_id="LoraModelLoader", display_name="Load LoRA Model", - category="loaders", + category="model/loaders", is_experimental=True, inputs=[ io.Model.Input( @@ -1405,7 +1405,7 @@ class LossGraphNode(io.ComfyNode): node_id="LossGraphNode", search_aliases=["training chart", "training visualization", "plot loss"], display_name="Plot Loss Graph", - category="training", + category="model/training", is_experimental=True, is_output_node=True, inputs=[ diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index d3ee3f1c1..1cf5a5d01 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -22,7 +22,7 @@ class UpscaleModelLoader(io.ComfyNode): return io.Schema( node_id="UpscaleModelLoader", display_name="Load Upscale Model", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")), ], diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index 8f19895a1..0d6cae6a8 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -15,7 +15,7 @@ class ImageOnlyCheckpointLoader: RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") FUNCTION = "load_checkpoint" - CATEGORY = "loaders" + CATEGORY = "model/loaders" def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) @@ -41,7 +41,7 @@ class SVD_img2vid_Conditioning: FUNCTION = "encode" - CATEGORY = "conditioning/video_models" + CATEGORY = "model/conditioning/video_models" def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level): output = clip_vision.encode_image(init_image) @@ -65,7 +65,7 @@ class VideoLinearCFGGuidance: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "sampling/guiders" + CATEGORY = "model/sampling/guiders" def patch(self, model, min_cfg): def linear_cfg(args): @@ -89,7 +89,7 @@ class VideoTriangleCFGGuidance: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "sampling/guiders" + CATEGORY = "model/sampling/guiders" def patch(self, model, min_cfg): def linear_cfg(args): @@ -138,7 +138,7 @@ class ConditioningSetAreaPercentageVideo: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def append(self, conditioning, width, height, temporal, x, y, z, strength): c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", temporal, height, width, z, y, x), diff --git a/comfy_extras/nodes_void.py b/comfy_extras/nodes_void.py index be724371a..b43154b8d 100644 --- a/comfy_extras/nodes_void.py +++ b/comfy_extras/nodes_void.py @@ -58,7 +58,7 @@ class OpticalFlowLoader(io.ComfyNode): return io.Schema( node_id="OpticalFlowLoader", display_name="Load Optical Flow Model", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input( "model_name", @@ -175,7 +175,7 @@ class VOIDInpaintConditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="VOIDInpaintConditioning", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -288,7 +288,7 @@ class VOIDWarpedNoise(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="VOIDWarpedNoise", - category="latent/video", + category="model/latent/video", inputs=[ OpticalFlow.Input( "optical_flow", @@ -393,7 +393,7 @@ class VOIDWarpedNoiseSource(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="VOIDWarpedNoiseSource", - category="sampling/noise", + category="model/sampling/noise", inputs=[ io.Latent.Input("warped_noise", tooltip="Warped noise latent from VOIDWarpedNoise"), @@ -455,7 +455,7 @@ class VOIDSampler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="VOIDSampler", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[], outputs=[io.Sampler.Output()], ) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index e50bfcd2c..67d3a8443 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -18,7 +18,7 @@ class WanImageToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanImageToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -66,7 +66,7 @@ class WanFunControlToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanFunControlToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -119,7 +119,7 @@ class Wan22FunControlToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Wan22FunControlToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -184,7 +184,7 @@ class WanFirstLastFrameToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanFirstLastFrameToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -256,7 +256,7 @@ class WanFunInpaintToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanFunInpaintToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -288,7 +288,7 @@ class WanVaceToVideo(io.ComfyNode): return io.Schema( node_id="WanVaceToVideo", search_aliases=["video conditioning", "video control"], - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -375,7 +375,7 @@ class TrimVideoLatent(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="TrimVideoLatent", - category="latent/video", + category="model/latent/video", inputs=[ io.Latent.Input("samples"), io.Int.Input("trim_amount", default=0, min=0, max=99999), @@ -398,7 +398,7 @@ class WanCameraImageToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanCameraImageToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -452,7 +452,7 @@ class WanPhantomSubjectToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanPhantomSubjectToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -707,7 +707,7 @@ class WanTrackToVideo(io.ComfyNode): return io.Schema( node_id="WanTrackToVideo", search_aliases=["motion tracking", "trajectory video", "point tracking", "keypoint animation"], - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -951,7 +951,7 @@ class WanSoundImageToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanSoundImageToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -984,7 +984,7 @@ class WanSoundImageToVideoExtend(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanSoundImageToVideoExtend", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -1046,7 +1046,7 @@ class WanHuMoImageToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanHuMoImageToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -1112,7 +1112,7 @@ class WanAnimateToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanAnimateToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -1252,7 +1252,7 @@ class Wan22ImageToVideoLatent(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Wan22ImageToVideoLatent", - category="conditioning/inpaint", + category="model/conditioning/inpaint", inputs=[ io.Vae.Input("vae"), io.Int.Input("width", default=1280, min=32, max=nodes.MAX_RESOLUTION, step=32), @@ -1302,7 +1302,7 @@ class WanInfiniteTalkToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanInfiniteTalkToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.DynamicCombo.Input("mode", options=[ io.DynamicCombo.Option("single_speaker", []), @@ -1461,7 +1461,7 @@ class WanSCAILToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanSCAILToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/comfy_extras/nodes_wandancer.py b/comfy_extras/nodes_wandancer.py index fc005ed4c..a96885745 100644 --- a/comfy_extras/nodes_wandancer.py +++ b/comfy_extras/nodes_wandancer.py @@ -713,7 +713,7 @@ class WanDancerEncodeAudio(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanDancerEncodeAudio", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Audio.Input("audio"), io.Int.Input("video_frames", default=149, min=1, max=nodes.MAX_RESOLUTION, step=4), @@ -787,7 +787,7 @@ class WanDancerVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanDancerVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/comfy_extras/nodes_wanmove.py b/comfy_extras/nodes_wanmove.py index 5acae03eb..2db064922 100644 --- a/comfy_extras/nodes_wanmove.py +++ b/comfy_extras/nodes_wanmove.py @@ -247,7 +247,7 @@ class WanMoveVisualizeTracks(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanMoveVisualizeTracks", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Image.Input("images"), io.Tracks.Input("tracks", optional=True), @@ -283,7 +283,7 @@ class WanMoveTracksFromCoords(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanMoveTracksFromCoords", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.String.Input("track_coords", force_input=True, default="[]", optional=True), io.Mask.Input("track_mask", optional=True), @@ -325,7 +325,7 @@ class GenerateTracks(io.ComfyNode): return io.Schema( node_id="GenerateTracks", search_aliases=["motion paths", "camera movement", "trajectory"], - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Int.Input("width", default=832, min=16, max=4096, step=16), io.Int.Input("height", default=480, min=16, max=4096, step=16), @@ -434,7 +434,7 @@ class WanMoveConcatTrack(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanMoveConcatTrack", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Tracks.Input("tracks_1"), io.Tracks.Input("tracks_2", optional=True), @@ -463,7 +463,7 @@ class WanMoveTrackToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanMoveTrackToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/nodes.py b/nodes.py index a3d5af27f..a5d1b16cd 100644 --- a/nodes.py +++ b/nodes.py @@ -72,7 +72,7 @@ class CLIPTextEncode(ComfyNodeABC): OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",) FUNCTION = "encode" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images." SEARCH_ALIASES = ["text", "prompt", "text prompt", "positive prompt", "negative prompt", "encode text", "text encoder", "encode prompt"] @@ -91,7 +91,7 @@ class ConditioningCombine: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "combine" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" SEARCH_ALIASES = ["combine", "merge conditioning", "combine prompts", "merge prompts", "mix prompts", "add prompt"] def combine(self, conditioning_1, conditioning_2): @@ -108,7 +108,7 @@ class ConditioningAverage : RETURN_TYPES = ("CONDITIONING",) FUNCTION = "addWeighted" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength): out = [] @@ -147,7 +147,7 @@ class ConditioningConcat: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "concat" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def concat(self, conditioning_to, conditioning_from): out = [] @@ -180,7 +180,7 @@ class ConditioningSetArea: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def append(self, conditioning, width, height, x, y, strength): c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8), @@ -201,7 +201,7 @@ class ConditioningSetAreaPercentage: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def append(self, conditioning, width, height, x, y, strength): c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x), @@ -218,7 +218,7 @@ class ConditioningSetAreaStrength: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def append(self, conditioning, strength): c = node_helpers.conditioning_set_values(conditioning, {"strength": strength}) @@ -238,7 +238,7 @@ class ConditioningSetMask: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def append(self, conditioning, mask, set_cond_area, strength): set_area_to_bounds = False @@ -307,7 +307,7 @@ class VAEDecode: OUTPUT_TOOLTIPS = ("The decoded image.",) FUNCTION = "decode" - CATEGORY = "latent" + CATEGORY = "model/latent" DESCRIPTION = "Decodes latent images back into pixel space images." SEARCH_ALIASES = ["decode", "decode latent", "latent to image", "render latent"] @@ -375,7 +375,7 @@ class VAEEncode: RETURN_TYPES = ("LATENT",) FUNCTION = "encode" - CATEGORY = "latent" + CATEGORY = "model/latent" SEARCH_ALIASES = ["encode", "encode image", "image to latent"] def encode(self, vae, pixels): @@ -410,7 +410,7 @@ class VAEEncodeForInpaint: RETURN_TYPES = ("LATENT",) FUNCTION = "encode" - CATEGORY = "latent/inpaint" + CATEGORY = "model/latent/inpaint" def encode(self, vae, pixels, mask, grow_mask_by=6): downscale_ratio = vae.spacial_compression_encode() @@ -459,7 +459,7 @@ class InpaintModelConditioning: RETURN_NAMES = ("positive", "negative", "latent") FUNCTION = "encode" - CATEGORY = "conditioning/inpaint" + CATEGORY = "model/conditioning/inpaint" def encode(self, positive, negative, pixels, vae, mask, noise_mask=True): x = (pixels.shape[1] // 8) * 8 @@ -619,7 +619,7 @@ class CheckpointLoaderSimple: "The VAE model used for encoding and decoding images to and from latent space.") FUNCTION = "load_checkpoint" - CATEGORY = "loaders" + CATEGORY = "model/loaders" DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents." SEARCH_ALIASES = ["load model", "checkpoint", "model loader", "load checkpoint", "ckpt", "model"] @@ -665,7 +665,7 @@ class unCLIPCheckpointLoader: RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION") FUNCTION = "load_checkpoint" - CATEGORY = "loaders" + CATEGORY = "model/loaders" def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) @@ -681,7 +681,7 @@ class CLIPSetLastLayer: RETURN_TYPES = ("CLIP",) FUNCTION = "set_last_layer" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def set_last_layer(self, clip, stop_at_clip_layer): clip = clip.clone() @@ -710,7 +710,7 @@ class LoraLoader: OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.") FUNCTION = "load_lora" - CATEGORY = "loaders" + CATEGORY = "model/loaders" DESCRIPTION = "This LoRA loader is used to modify both diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together." SEARCH_ALIASES = ["lora", "load lora", "apply lora", "lora loader", "lora model"] @@ -810,7 +810,7 @@ class VAELoader: RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" - CATEGORY = "loaders" + CATEGORY = "model/loaders" #TODO: scale factor? def load_vae(self, vae_name): @@ -852,7 +852,7 @@ class ControlNetLoader: RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" - CATEGORY = "loaders" + CATEGORY = "model/loaders" SEARCH_ALIASES = ["controlnet", "control net", "cn", "load controlnet", "controlnet loader"] def load_controlnet(self, control_net_name): @@ -871,7 +871,7 @@ class DiffControlNetLoader: RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" - CATEGORY = "loaders" + CATEGORY = "model/loaders" def load_controlnet(self, model, control_net_name): controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name) @@ -891,7 +891,7 @@ class ControlNetApply: FUNCTION = "apply_controlnet" DEPRECATED = True - CATEGORY = "conditioning/controlnet" + CATEGORY = "model/conditioning/controlnet" def apply_controlnet(self, conditioning, control_net, image, strength): if strength == 0: @@ -929,7 +929,7 @@ class ControlNetApplyAdvanced: RETURN_NAMES = ("positive", "negative") FUNCTION = "apply_controlnet" - CATEGORY = "conditioning/controlnet" + CATEGORY = "model/conditioning/controlnet" SEARCH_ALIASES = ["controlnet", "apply controlnet", "use controlnet", "control net"] def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]): @@ -990,7 +990,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -1000,7 +1000,7 @@ class CLIPLoader: CATEGORY = "advanced/loaders" - DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B" + DESCRIPTION = "[Recipes]\n\nstable_diffusion: clip-l\nstable_cascade: clip-g\nsd3: t5 xxl/ clip-g / clip-l\nstable_audio: t5 base\nmochi: t5 xxl\ncogvideox: t5 xxl (226-token padding)\ncosmos: old t5 xxl\nlumina2: gemma 2 2B\nwan: umt5 xxl\n hidream: llama-3.1 (Recommend) or t5\nomnigen2: qwen vl 2.5 3B\nlens: gpt-oss-20b\n pixeldit: gemma 2 2B elm" def load_clip(self, clip_name, type="stable_diffusion", device="default"): clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) @@ -1051,7 +1051,7 @@ class CLIPVisionLoader: RETURN_TYPES = ("CLIP_VISION",) FUNCTION = "load_clip" - CATEGORY = "loaders" + CATEGORY = "model/loaders" def load_clip(self, clip_name): clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name) @@ -1070,7 +1070,7 @@ class CLIPVisionEncode: RETURN_TYPES = ("CLIP_VISION_OUTPUT",) FUNCTION = "encode" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def encode(self, clip_vision, image, crop): crop_image = True @@ -1087,7 +1087,7 @@ class StyleModelLoader: RETURN_TYPES = ("STYLE_MODEL",) FUNCTION = "load_style_model" - CATEGORY = "loaders" + CATEGORY = "model/loaders" def load_style_model(self, style_model_name): style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name) @@ -1109,7 +1109,7 @@ class StyleModelApply: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_stylemodel" - CATEGORY = "conditioning/style_model" + CATEGORY = "model/conditioning/style_model" def apply_stylemodel(self, conditioning, style_model, clip_vision_output, strength, strength_type): cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0) @@ -1169,7 +1169,7 @@ class unCLIPConditioning: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_adm" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation): if strength == 0: @@ -1186,7 +1186,7 @@ class GLIGENLoader: RETURN_TYPES = ("GLIGEN",) FUNCTION = "load_gligen" - CATEGORY = "loaders" + CATEGORY = "model/loaders" def load_gligen(self, gligen_name): gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name) @@ -1208,7 +1208,7 @@ class GLIGENTextBoxApply: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" - CATEGORY = "conditioning/gligen" + CATEGORY = "model/conditioning/gligen" def append(self, conditioning_to, clip, gligen_textbox_model, text, width, height, x, y): c = [] @@ -1238,7 +1238,7 @@ class EmptyLatentImage: OUTPUT_TOOLTIPS = ("The empty latent image batch.",) FUNCTION = "generate" - CATEGORY = "latent" + CATEGORY = "model/latent" DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling." SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"] @@ -1259,7 +1259,7 @@ class LatentFromBatch: RETURN_TYPES = ("LATENT",) FUNCTION = "frombatch" - CATEGORY = "latent/batch" + CATEGORY = "model/latent/batch" def frombatch(self, samples, batch_index, length): s = samples.copy() @@ -1294,7 +1294,7 @@ class RepeatLatentBatch: RETURN_TYPES = ("LATENT",) FUNCTION = "repeat" - CATEGORY = "latent/batch" + CATEGORY = "model/latent/batch" def repeat(self, samples, amount): s = samples.copy() @@ -1326,7 +1326,7 @@ class LatentUpscale: RETURN_TYPES = ("LATENT",) FUNCTION = "upscale" - CATEGORY = "latent" + CATEGORY = "model/latent" def upscale(self, samples, upscale_method, width, height, crop): if width == 0 and height == 0: @@ -1359,7 +1359,7 @@ class LatentUpscaleBy: RETURN_TYPES = ("LATENT",) FUNCTION = "upscale" - CATEGORY = "latent" + CATEGORY = "model/latent" def upscale(self, samples, upscale_method, scale_by): s = samples.copy() @@ -1377,7 +1377,7 @@ class LatentRotate: RETURN_TYPES = ("LATENT",) FUNCTION = "rotate" - CATEGORY = "latent/transform" + CATEGORY = "model/latent/transform" def rotate(self, samples, rotation): s = samples.copy() @@ -1403,7 +1403,7 @@ class LatentFlip: RETURN_TYPES = ("LATENT",) FUNCTION = "flip" - CATEGORY = "latent/transform" + CATEGORY = "model/latent/transform" def flip(self, samples, flip_method): s = samples.copy() @@ -1428,7 +1428,7 @@ class LatentComposite: RETURN_TYPES = ("LATENT",) FUNCTION = "composite" - CATEGORY = "latent" + CATEGORY = "model/latent" def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0): x = x // 8 @@ -1515,7 +1515,7 @@ class LatentCrop: RETURN_TYPES = ("LATENT",) FUNCTION = "crop" - CATEGORY = "latent/transform" + CATEGORY = "model/latent/transform" def crop(self, samples, width, height, x, y): s = samples.copy() @@ -1545,7 +1545,7 @@ class SetLatentNoiseMask: RETURN_TYPES = ("LATENT",) FUNCTION = "set_mask" - CATEGORY = "latent/inpaint" + CATEGORY = "model/latent/inpaint" def set_mask(self, samples, mask): s = samples.copy() @@ -1599,7 +1599,7 @@ class KSampler: OUTPUT_TOOLTIPS = ("The denoised latent.",) FUNCTION = "sample" - CATEGORY = "sampling" + CATEGORY = "model/sampling" DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image." SEARCH_ALIASES = ["sampler", "sample", "generate", "denoise", "diffuse", "txt2img", "img2img"] @@ -1629,7 +1629,7 @@ class KSamplerAdvanced: RETURN_TYPES = ("LATENT",) FUNCTION = "sample" - CATEGORY = "sampling" + CATEGORY = "model/sampling" def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0): force_full_denoise = True @@ -2442,6 +2442,7 @@ async def init_builtin_extra_nodes(): "nodes_context_windows.py", "nodes_qwen.py", "nodes_chroma_radiance.py", + "nodes_pid.py", "nodes_model_patch.py", "nodes_easycache.py", "nodes_audio_encoder.py", diff --git a/openapi.yaml b/openapi.yaml index 502e518c7..f801a39d9 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -275,7 +275,10 @@ paths: responses: "200": description: Queue updated - + content: + application/json: + schema: + $ref: "#/components/schemas/QueueManageResponse" '400': description: Invalid request parameters content: @@ -3092,18 +3095,34 @@ paths: application/json: schema: type: object - required: - - asset_ids properties: + job_ids: + type: array + items: + type: string + description: Job IDs whose associated assets should all be included in the ZIP bundle. asset_ids: type: array items: type: string format: uuid - description: IDs of assets to export + description: Asset IDs to include in the ZIP bundle. Additive to assets associated with provided job IDs. export_name: type: string description: Name for the export archive + naming_strategy: + type: string + enum: [group_by_job_id, preserve, asset_id, group_by_job_time] + default: group_by_job_time + description: "Strategy for naming files in the ZIP: group by job ID, preserve original names, use the asset ID, or group by job creation time." + job_asset_name_filters: + type: object + additionalProperties: + type: array + minItems: 1 + items: + type: string + description: Optional per-job asset name filters. When provided for a job ID, only assets whose name matches one of the listed names are included. responses: "202": description: Export task accepted @@ -3575,10 +3594,7 @@ paths: content: application/json: schema: - type: array - items: - $ref: "#/components/schemas/HubLabel" - + $ref: "#/components/schemas/HubLabelListResponse" '400': description: Bad request (e.g. invalid type parameter) content: @@ -7466,6 +7482,25 @@ components: type: string description: Array of prompt IDs to delete from queue + QueueManageResponse: + type: object + x-runtime: [cloud] + description: >- + [cloud-only] Result of a queue mutation. The Cloud runtime returns which + items were deleted and whether the queue was cleared; local ComfyUI + returns an empty 200 body. + properties: + deleted: + type: array + nullable: true + items: + type: string + description: Prompt IDs that were deleted from the queue. + cleared: + type: boolean + nullable: true + description: Whether the queue was cleared. + # ------------------------------------------------------------------- # History # ------------------------------------------------------------------- @@ -7546,6 +7581,16 @@ components: outputs_count: type: integer description: Total number of output files + workflow_id: + type: string + nullable: true + x-runtime: [cloud] + description: "[cloud-only] UUID of the Cloud workflow entity this job is associated with. Local ComfyUI returns null." + execution_error: + x-runtime: [cloud] + description: "[cloud-only] Detailed execution error from ComfyUI for failed jobs. Absent on local ComfyUI." + allOf: + - $ref: "#/components/schemas/ExecutionError" JobDetailResponse: type: object @@ -10433,6 +10478,19 @@ components: - custom_node description: Label category. + HubLabelListResponse: + type: object + x-runtime: [cloud] + description: '[cloud-only] Response wrapper for the available Hub label catalog.' + required: + - labels + properties: + labels: + type: array + items: + $ref: '#/components/schemas/HubLabelInfo' + description: Available labels, optionally filtered by type. + HubProfileSummary: type: object x-runtime: [cloud] diff --git a/requirements.txt b/requirements.txt index 2ca6d8929..651315cb2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ comfyui-frontend-package==1.44.19 -comfyui-workflow-templates==0.9.82 +comfyui-workflow-templates==0.9.85 comfyui-embedded-docs==0.5.1 torch torchsde @@ -21,8 +21,8 @@ psutil alembic SQLAlchemy>=2.0.0 filelock -av>=14.2.0 -comfy-kitchen>=0.2.8 +av>=16.0.0 +comfy-kitchen==0.2.9 comfy-aimdo==0.4.5 requests simpleeval>=1.0.0