Merge branch 'master' into seedvr2-native-support

This commit is contained in:
John Pollock 2026-05-28 21:44:34 -05:00 committed by GitHub
commit 7431bef672
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
132 changed files with 3980 additions and 742 deletions

View File

@ -0,0 +1,24 @@
# Caller workflow: delegates all logic to a reusable workflow in another repository.
name: Detect Unreviewed Merge
# SOC 2 compliance — reusable workflow lives in Comfy-Org/github-workflows,
# tracking issues are filed in Comfy-Org/unreviewed-merges.
on:
  push:
    branches: [master]
# The group key includes the commit SHA, and in-progress runs are never cancelled.
concurrency:
  group: detect-unreviewed-merge-${{ github.sha }}
  cancel-in-progress: false
# The workflow token is limited to read access on contents and pull requests.
permissions:
  contents: read
  pull-requests: read
jobs:
  detect:
    # Pinned to a full commit SHA; the trailing comment records the matching tag.
    uses: Comfy-Org/github-workflows/.github/workflows/detect-unreviewed-merge.yml@4d9cb6b87f953bb7cd69954280e1465fb9bd2040 # v1
    with:
      approval-mode: latest-per-reviewer
    secrets:
      # Separate secret forwarded to the reusable workflow (presumably used to
      # file the tracking issues mentioned above — confirm in the callee).
      UNREVIEWED_MERGES_TOKEN: ${{ secrets.UNREVIEWED_MERGES_TOKEN }}

View File

@ -1,5 +1,20 @@
import logging
import torch
# Probe comfy_kitchen for ``stochastic_rounding_fp8``. The availability flag is
# what callers branch on; the fallback only exists so the name always resolves.
try:
    import comfy_kitchen as ck
    _ck_stochastic_rounding_fp8 = ck.stochastic_rounding_fp8
except (AttributeError, ImportError):
    # Either the package is not importable or it does not expose the function.
    _CK_STOCHASTIC_ROUNDING_AVAILABLE = False
    logging.warning("comfy_kitchen does not support stochastic FP8 rounding, please update comfy_kitchen.")

    def _ck_stochastic_rounding_fp8(value, rng, dtype):
        """Stand-in used when comfy_kitchen lacks the function; always raises."""
        raise NotImplementedError("comfy_kitchen does not support stochastic FP8 rounding")
else:
    _CK_STOCHASTIC_ROUNDING_AVAILABLE = True
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
mantissa_scaled = torch.where(
normal_mask,
@ -57,6 +72,10 @@ def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
generator = torch.Generator(device=value.device)
generator.manual_seed(seed)
if _CK_STOCHASTIC_ROUNDING_AVAILABLE:
rng = torch.randint(0, 256, value.size(), dtype=torch.uint8, layout=value.layout, device=value.device, generator=generator)
return _ck_stochastic_rounding_fp8(value, rng, dtype)
output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices))

View File

@ -804,13 +804,15 @@ class ZImagePixelSpace(ChromaRadiance):
"""
pass
class HiDreamO1Pixel(ChromaRadiance):
    """Pixel-space latent format for HiDream-O1.

    There is no VAE: the model patches and unpatches raw RGB internally with
    patch_size=32. Everything else is inherited from ``ChromaRadiance``.
    """
class PixelDiTPixel(ChromaRadiance):
    """Latent format for PixelDiT; inherits ``ChromaRadiance`` with no overrides."""
    pass
class CogVideoX(LatentFormat):
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).

View File

@ -433,11 +433,11 @@ class Attention(nn.Module):
if self.differential:
q, q_diff = q.unbind(dim=1)
k, k_diff = k.unbind(dim=1)
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options)
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out = out - out_diff
else:
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out = self.to_out(out)

View File

@ -138,11 +138,11 @@ class Attention(nn.Module):
k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype)
if self.differential:
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True))
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True, low_precision_attention=False))
del q, k, v, q_diff, k_diff
else:
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
del q, k, v
return self.to_out(out)

510
comfy/ldm/lens/model.py Normal file
View File

@ -0,0 +1,510 @@
"""Lens denoising transformer (DiT)"""
from __future__ import annotations
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ldm.flux.layers
import comfy.patcher_extension
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
from comfy.ldm.modules.attention import optimized_attention
def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
    """Project timesteps ``t`` to ``dim`` features using the shared Flux helper."""
    embed = comfy.ldm.flux.layers.timestep_embedding
    return embed(t, dim)
def _lens_position_ids(
    frame: int, height: int, width: int, text_seq_len: int,
    scale_rope: bool = True, device=None,
) -> torch.Tensor:
    """Build (frame, h, w) position ids for a joint image + text sequence.

    Image tokens come first (frame-major, then row, then column), followed by
    ``text_seq_len`` text tokens whose single running index is repeated on all
    three axes.

    With ``scale_rope=True`` the h/w coordinates are centred around 0 (a
    negative half followed by a non-negative half) and text positions start
    at ``max(height // 2, width // 2)``; otherwise coordinates run from 0 and
    text starts at ``max(height, width)``.

    Returns a tensor of shape ``[frame * height * width + text_seq_len, 3]``;
    the caller adds a batch dim for ``EmbedND``.
    """
    def axis_coords(size: int) -> torch.Tensor:
        # Plain 0..size-1 range, or the centred variant when scale_rope is on.
        if not scale_rope:
            return torch.arange(size, device=device)
        half = size // 2
        return torch.cat([
            torch.arange(half - size, 0, device=device),
            torch.arange(0, half, device=device),
        ])

    if scale_rope:
        text_start = max(height // 2, width // 2)
    else:
        text_start = max(height, width)

    # Fill one coordinate channel at a time; broadcasting expands each axis.
    grid = torch.zeros(frame, height, width, 3, device=device)
    grid[..., 0] = torch.arange(frame, device=device)[:, None, None]
    grid[..., 1] = axis_coords(height)[None, :, None]
    grid[..., 2] = axis_coords(width)[None, None, :]
    image_ids = grid.reshape(-1, 3)

    # Text positions replicate across all 3 axes (matches original packing).
    text_axis = torch.arange(text_start, text_start + text_seq_len, device=device).float()
    text_ids = text_axis.unsqueeze(-1).expand(text_seq_len, 3)
    return torch.cat([image_ids, text_ids], dim=0)
class _TimestepEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) applied to timestep features."""

    def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None) -> None:
        """Create both linear layers through the supplied ``operations`` namespace."""
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.linear_1 = operations.Linear(in_channels, time_embed_dim, **factory)
        self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, **factory)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``linear_2(silu(linear_1(x)))``."""
        hidden = F.silu(self.linear_1(x))
        return self.linear_2(hidden)
class LensTimestepProjEmbeddings(nn.Module):
    """Timestep conditioning: a 256-wide projection followed by an MLP."""

    def __init__(self, embedding_dim: int, dtype=None, device=None, operations=None) -> None:
        """Build the MLP mapping the 256-wide projection to ``embedding_dim``."""
        super().__init__()
        self.timestep_embedder = _TimestepEmbedder(
            256, embedding_dim, dtype=dtype, device=device, operations=operations,
        )

    def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
        """Embed ``timestep``; ``hidden_states`` is only consulted for its dtype."""
        target_dtype = hidden_states.dtype
        projected = _lens_time_proj(timestep, 256).to(dtype=target_dtype)
        return self.timestep_embedder(projected)
class GateMLP(nn.Module):
    """SwiGLU MLP: ``w2(silu(w1(x)) * w3(x))`` with bias-free projections."""

    def __init__(self, dim: int, hidden_dim: int, dtype=None, device=None, operations=None) -> None:
        """Create the gate (``w1``), down (``w2``) and up (``w3``) projections."""
        super().__init__()
        factory = {"bias": False, "dtype": dtype, "device": device}
        self.w1 = operations.Linear(dim, hidden_dim, **factory)
        self.w2 = operations.Linear(hidden_dim, dim, **factory)
        self.w3 = operations.Linear(dim, hidden_dim, **factory)

    def forward(self, x):
        """Apply the gated MLP; the input ``x`` itself is not modified."""
        # Both in-place ops act on the fresh w1 output, so no extra buffer is kept.
        gated = F.silu(self.w1(x), inplace=True)
        gated.mul_(self.w3(x))
        return self.w2(gated)
class LensJointAttention(nn.Module):
    """Joint image+text attention with fused QKV per stream.

    Each stream has its own fused QKV projection and its own per-head RMSNorm
    on Q and K. The two streams are concatenated image-first along the
    sequence axis, rotated with ``apply_rope``, attended jointly, then split
    back and sent through separate output projections.
    """

    def __init__(
        self,
        query_dim: int,
        added_kv_proj_dim: int,
        dim_head: int = 64,
        heads: int = 8,
        out_dim: Optional[int] = None,
        eps: float = 1e-5,
        dtype=None,
        device=None,
        operations=None,
    ) -> None:
        """Create the projections and norms.

        Args:
            query_dim: Feature width of the image stream input.
            added_kv_proj_dim: Feature width of the text stream input.
            dim_head: Per-head width.
            heads: Head count; only used when ``out_dim`` is None.
            out_dim: If given, used as the attention width and as the image
                output width.
            eps: Epsilon for the four Q/K RMSNorm layers.
            dtype, device: Forwarded to every layer.
            operations: Namespace providing ``Linear`` and ``RMSNorm``.
        """
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        # Head count is re-derived from inner_dim, so it follows out_dim when that is set.
        self.heads = self.inner_dim // dim_head
        self.dim_head = dim_head
        self.out_dim = out_dim if out_dim is not None else query_dim

        self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
        self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
        self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
        self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)

        self.img_qkv = operations.Linear(query_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
        self.txt_qkv = operations.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)

        # ModuleList([Linear, Identity]) for state-dict key compatibility.
        self.to_out = nn.ModuleList([
            operations.Linear(self.inner_dim, self.out_dim, bias=True, dtype=dtype, device=device),
            nn.Identity(),
        ])
        self.to_add_out = operations.Linear(self.inner_dim, query_dim, bias=True, dtype=dtype, device=device)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        freqs_cis: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run joint attention over the image and text tokens.

        Args:
            hidden_states: Image tokens, ``[batch, seq_img, query_dim]``.
            encoder_hidden_states: Text tokens, ``[batch, seq_txt, dim]``.
            freqs_cis: Rotary data handed to ``apply_rope`` for the joint
                (image then text) sequence.
            attention_mask: Optional mask of shape
                ``(batch, 1, 1, seq_img + seq_txt)``; cast to the query dtype.
            transformer_options: Passed through to ``optimized_attention``.

        Returns:
            ``(img_out, txt_out)``.

        Raises:
            ValueError: If ``attention_mask`` has any other shape.
        """
        bsz, seq_img, _ = hidden_states.shape
        seq_txt = encoder_hidden_states.shape[1]

        # image stream
        img_qkv = self.img_qkv(hidden_states).view(bsz, seq_img, 3, self.heads, self.dim_head)
        img_q, img_k, img_v = img_qkv.unbind(dim=2)
        img_q = self.norm_q(img_q)
        img_k = self.norm_k(img_k)
        del img_qkv

        # text stream
        txt_qkv = self.txt_qkv(encoder_hidden_states).view(bsz, seq_txt, 3, self.heads, self.dim_head)
        txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2)
        txt_q = self.norm_added_q(txt_q)
        txt_k = self.norm_added_k(txt_k)

        # [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
        q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
        del img_q, txt_q
        k = torch.cat([img_k, txt_k], dim=1).transpose(1, 2)
        del img_k, txt_k
        v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
        del img_v, txt_v

        q, k = apply_rope(q, k, freqs_cis)

        if attention_mask is not None:
            # Only a per-key mask broadcast over heads and queries is accepted.
            expected = (bsz, 1, 1, seq_img + seq_txt)
            if attention_mask.shape != expected:
                raise ValueError(
                    f"attention_mask must be {expected}, got {tuple(attention_mask.shape)}"
                )
            attention_mask = attention_mask.to(q.dtype)

        # NOTE(review): the slicing below assumes the result is laid out as
        # [batch, seq, heads * dim_head] — confirm against optimized_attention.
        out = optimized_attention(
            q, k, v, self.heads, mask=attention_mask, skip_reshape=True,
            transformer_options=transformer_options,
        )

        # Image tokens occupy the first seq_img positions, text tokens the rest.
        img_out = self.to_out[1](self.to_out[0](out[:, :seq_img, :]))
        txt_out = self.to_add_out(out[:, seq_img:, :])
        return img_out, txt_out
class LensTransformerBlock(nn.Module):
    """Dual-stream block: joint attention, then a gated MLP per stream.

    The image and text streams each own a modulation layer, two norms and a
    ``GateMLP``; they share one ``LensJointAttention``.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        eps: float = 1e-6,
        rms_norm: bool = True,
        dtype=None,
        device=None,
        operations=None,
    ) -> None:
        """Build the shared attention and both per-stream stacks.

        ``eps`` applies to the block norms only; the attention Q/K norms use a
        fixed 1e-5. ``rms_norm=False`` selects a LayerNorm without affine
        parameters instead of RMSNorm.
        """
        super().__init__()
        factory = {"dtype": dtype, "device": device}

        self.attn = LensJointAttention(
            query_dim=dim,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            eps=1e-5,
            operations=operations,
            **factory,
        )

        if rms_norm:
            norm_cls, norm_extra = operations.RMSNorm, {}
        else:
            norm_cls, norm_extra = operations.LayerNorm, {"elementwise_affine": False}
        mlp_hidden = int(dim / 3 * 8)

        def build_stream():
            # Sequential(SiLU, Linear) so state-dict lands at <name>_mod.1.{weight,bias}.
            modulation = nn.Sequential(
                nn.SiLU(),
                operations.Linear(dim, 6 * dim, bias=True, **factory),
            )
            first_norm = norm_cls(dim, eps=eps, **factory, **norm_extra)
            second_norm = norm_cls(dim, eps=eps, **factory, **norm_extra)
            mlp = GateMLP(dim, mlp_hidden, operations=operations, **factory)
            return modulation, first_norm, second_norm, mlp

        # Assignment order fixes the module registration order: image stream first.
        self.img_mod, self.img_norm1, self.img_norm2, self.img_mlp = build_stream()
        self.txt_mod, self.txt_norm1, self.txt_norm2, self.txt_mlp = build_stream()

    @staticmethod
    def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split ``mod_params`` into shift/scale/gate; return the modulated ``x`` and the gate.

        Each third gets a new dim 1 so it broadcasts over the sequence axis.
        """
        shift, scale, gate = (part.unsqueeze(1) for part in mod_params.chunk(3, dim=-1))
        return x * (1 + scale) + shift, gate

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        freqs_cis: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the block.

        Returns ``(encoder_hidden_states, hidden_states)`` — text first, then image.
        """
        # The first half of each modulation drives attention, the second the MLP.
        img_attn_mod, img_mlp_mod = self.img_mod(temb).chunk(2, dim=-1)
        txt_attn_mod, txt_mlp_mod = self.txt_mod(temb).chunk(2, dim=-1)

        img_in, img_attn_gate = self._modulate(self.img_norm1(hidden_states), img_attn_mod)
        txt_in, txt_attn_gate = self._modulate(self.txt_norm1(encoder_hidden_states), txt_attn_mod)

        img_attn, txt_attn = self.attn(
            hidden_states=img_in,
            encoder_hidden_states=txt_in,
            freqs_cis=freqs_cis,
            attention_mask=attention_mask,
            transformer_options=transformer_options,
        )

        hidden_states = hidden_states + img_attn_gate * img_attn
        encoder_hidden_states = encoder_hidden_states + txt_attn_gate * txt_attn

        img_in, img_mlp_gate = self._modulate(self.img_norm2(hidden_states), img_mlp_mod)
        hidden_states = hidden_states + img_mlp_gate * self.img_mlp(img_in)

        txt_in, txt_mlp_gate = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mlp_mod)
        encoder_hidden_states = encoder_hidden_states + txt_mlp_gate * self.txt_mlp(txt_in)

        return encoder_hidden_states, hidden_states
class _AdaLayerNormContinuousNoAffine(nn.Module):
    """AdaLayerNormContinuous(elementwise_affine=False).

    The projected conditioning is split as ``scale, shift`` (scale first),
    the opposite order to Flux's ``LastLayer``, matching the reference.
    """

    def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6,
                 dtype=None, device=None, operations=None) -> None:
        """Create the conditioning projection producing ``2 * embedding_dim`` features."""
        super().__init__()
        self.linear = operations.Linear(
            conditioning_embedding_dim, embedding_dim * 2, bias=True, dtype=dtype, device=device
        )
        self.eps = eps
        self.embedding_dim = embedding_dim

    def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
        """Layer-normalise ``x`` without affine weights, then scale and shift it."""
        scale, shift = self.linear(F.silu(conditioning)).chunk(2, dim=-1)
        normed = F.layer_norm(x, (self.embedding_dim,), weight=None, bias=None, eps=self.eps)
        # A new dim 1 lets the per-sample modulation broadcast across the sequence.
        return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class LensTransformer2DModel(nn.Module):
    """Lens dual-stream MMDiT (48 blocks, inner_dim=1536, multi-layer text).

    The figures above are the constructor defaults
    (``num_layers=48``, ``24 heads * 64 head_dim = 1536``).
    """

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 128,
        out_channels: Optional[int] = 32,
        num_layers: int = 48,
        attention_head_dim: int = 64,
        num_attention_heads: int = 24,
        enc_hidden_dim: int = 2880,
        axes_dims_rope: Tuple[int, int, int] = (8, 28, 28),
        rms_norm: bool = True,
        multi_layer_encoder_feature: bool = True,
        selected_layer_index: Tuple[int, ...] = (5, 11, 17, 23),
        image_model=None,  # unused; accepted for detection-side configs.
        dtype=None,
        device=None,
        operations=None,
    ) -> None:
        """Build the embedders, transformer blocks and output head.

        Args:
            patch_size: Only used to size ``proj_out``.
            in_channels: Channel count of the incoming ``x``.
            out_channels: Per-pixel output channels; falls back to
                ``in_channels`` when None.
            num_layers: Number of ``LensTransformerBlock`` instances.
            attention_head_dim, num_attention_heads: Their product is the
                model width ``inner_dim``.
            enc_hidden_dim: Width of one text-encoder layer.
            axes_dims_rope: Per-axis sizes handed to ``EmbedND``.
            rms_norm: Norm type used inside each block.
            multi_layer_encoder_feature: If True, ``context`` carries several
                encoder layers concatenated on the feature axis, each with
                its own RMSNorm.
            selected_layer_index: Only its length is used in this class, as
                the number of concatenated encoder layers.
            dtype, device, operations: Forwarded to every layer.
        """
        super().__init__()
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = out_channels if out_channels is not None else in_channels
        self.inner_dim = num_attention_heads * attention_head_dim
        self.multi_layer_encoder_feature = multi_layer_encoder_feature
        self.selected_layer_index = list(selected_layer_index)
        self.dtype = dtype

        self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))

        self.time_text_embed = LensTimestepProjEmbeddings(
            embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
        )

        if self.multi_layer_encoder_feature:
            # One norm per encoder layer; the normed layers are concatenated before txt_in.
            self.txt_norm = nn.ModuleList(
                [operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
                 for _ in self.selected_layer_index]
            )
            self.txt_in = operations.Linear(
                enc_hidden_dim * len(self.selected_layer_index),
                self.inner_dim, bias=True, dtype=dtype, device=device,
            )
        else:
            self.txt_norm = operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
            self.txt_in = operations.Linear(enc_hidden_dim, self.inner_dim, bias=True, dtype=dtype, device=device)

        self.img_in = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)

        self.transformer_blocks = nn.ModuleList([
            LensTransformerBlock(
                dim=self.inner_dim,
                num_attention_heads=num_attention_heads,
                attention_head_dim=attention_head_dim,
                eps=1e-6,
                rms_norm=rms_norm,
                dtype=dtype, device=device, operations=operations,
            )
            for _ in range(num_layers)
        ])

        self.norm_out = _AdaLayerNormContinuousNoAffine(
            self.inner_dim, self.inner_dim, eps=1e-6,
            dtype=dtype, device=device, operations=operations,
        )
        self.proj_out = operations.Linear(
            self.inner_dim, patch_size * patch_size * self.out_channels, bias=True,
            dtype=dtype, device=device,
        )

    def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
                transformer_options: Optional[Dict[str, Any]] = None, **kwargs) -> torch.Tensor:
        """Run ``_forward`` through any DIFFUSION_MODEL wrappers found in ``transformer_options``."""
        if transformer_options is None:
            transformer_options = {}
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward, self,
            comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
        ).execute(x, timestep, context, attention_mask, transformer_options, **kwargs)

    def _forward(
        self,
        x: torch.Tensor,
        timestep: torch.Tensor,
        context: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        transformer_options: Optional[Dict[str, Any]] = None,
        control: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """ComfyUI bridge: ``(x[B,128,h,w], t[B], context[B,S,L*H], mask[B,S])``.

        Every spatial position of ``x`` becomes one image token. The text mask
        defaults to all-ones when omitted. ``transformer_options`` may carry
        ``patches`` (``post_input``, ``double_block``) and
        ``patches_replace["dit"]`` keyed by ``("double_block", i)``;
        ``control["input"][i]``, when present, is added to the image tokens
        after block ``i``.

        Returns a tensor with the same ``[B, C, h, w]`` layout as ``x``.
        """
        if transformer_options is None:
            transformer_options = {}
        # Shallow copy: the per-block keys written below must not leak to the caller's dict.
        transformer_options = transformer_options.copy()
        patches = transformer_options.get("patches", {})
        patches_replace = transformer_options.get("patches_replace", {})
        blocks_replace = patches_replace.get("dit", {})

        B, C, h, w = x.shape
        # [B, C, h, w] -> [B, h*w, C]: one token per spatial position.
        hidden_states = x.permute(0, 2, 3, 1).reshape(B, h * w, C)

        if self.multi_layer_encoder_feature:
            # Split the concatenated feature axis back into L per-layer tensors.
            L = len(self.selected_layer_index)
            enc_dim = context.shape[-1] // L
            encoder_hidden_states = list(
                context.reshape(B, -1, L, enc_dim).unbind(dim=2)
            )
            text_seq_len = encoder_hidden_states[0].shape[1]
        else:
            encoder_hidden_states = context
            text_seq_len = context.shape[1]

        if attention_mask is None:
            attention_mask = torch.ones(
                (B, text_seq_len), dtype=torch.bool, device=x.device
            )

        img_len = h * w
        joint_mask = self._build_joint_attention_mask(attention_mask, img_len)

        hidden_states = self.img_in(hidden_states)
        timestep = timestep.to(hidden_states.dtype)

        if self.multi_layer_encoder_feature:
            normed = [self.txt_norm[i](encoder_hidden_states[i]) for i in range(L)]
            encoder_hidden_states = torch.cat(normed, dim=-1)
        else:
            encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        if "post_input" in patches:
            for p in patches["post_input"]:
                out = p({
                    "img": hidden_states,
                    "txt": encoder_hidden_states,
                    "transformer_options": transformer_options,
                })
                hidden_states = out["img"]
                encoder_hidden_states = out["txt"]

        temb = self.time_text_embed(timestep, hidden_states)

        # A single frame is assumed; unsqueeze adds the batch dim EmbedND expects.
        ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
        freqs_cis = self.pos_embed(ids)

        transformer_options["total_blocks"] = len(self.transformer_blocks)
        transformer_options["block_type"] = "double"
        for i, block in enumerate(self.transformer_blocks):
            transformer_options["block_index"] = i
            if ("double_block", i) in blocks_replace:
                # Adapter between the replacement's dict protocol and the block's
                # keyword signature; it is invoked within this same iteration.
                def block_wrap(args):
                    out = {}
                    out["txt"], out["img"] = block(
                        hidden_states=args["img"],
                        encoder_hidden_states=args["txt"],
                        temb=args["vec"],
                        freqs_cis=args["pe"],
                        attention_mask=args.get("attn_mask"),
                        transformer_options=args.get("transformer_options"),
                    )
                    return out

                out = blocks_replace[("double_block", i)](
                    {
                        "img": hidden_states,
                        "txt": encoder_hidden_states,
                        "vec": temb,
                        "pe": freqs_cis,
                        "attn_mask": joint_mask,
                        "transformer_options": transformer_options,
                    },
                    {"original_block": block_wrap},
                )
                encoder_hidden_states = out["txt"]
                hidden_states = out["img"]
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    freqs_cis=freqs_cis,
                    attention_mask=joint_mask,
                    transformer_options=transformer_options,
                )

            if "double_block" in patches:
                for p in patches["double_block"]:
                    out = p({
                        "img": hidden_states,
                        "txt": encoder_hidden_states,
                        "x": x,
                        "block_index": i,
                        "transformer_options": transformer_options,
                    })
                    hidden_states = out["img"]
                    encoder_hidden_states = out["txt"]

            if control is not None:
                control_i = control.get("input")
                if control_i is not None and i < len(control_i):
                    add = control_i[i]
                    if add is not None:
                        # In-place add on the leading tokens covered by the control tensor.
                        hidden_states[:, :add.shape[1]] += add

        hidden_states = self.norm_out(hidden_states, temb)
        out = self.proj_out(hidden_states)

        # Reshaping with the input C requires patch_size**2 * out_channels == C
        # (true for the defaults: 2 * 2 * 32 == 128).
        return out.reshape(B, h, w, C).permute(0, 3, 1, 2).contiguous()

    @staticmethod
    def _build_joint_attention_mask(text_mask: torch.Tensor, img_len: int) -> torch.Tensor:
        """Turn a ``[B, S]`` text keep-mask into an additive joint mask.

        Image positions are always kept. The result is float32 with shape
        ``[B, 1, 1, img_len + S]``: 0 where attention is allowed and the most
        negative float32 value where it is masked.
        """
        if text_mask.dtype != torch.bool:
            text_mask = text_mask.bool()
        bsz = text_mask.shape[0]
        img_ones = torch.ones((bsz, img_len), dtype=torch.bool, device=text_mask.device)
        joint = torch.cat([img_ones, text_mask], dim=1)
        additive = torch.zeros_like(joint, dtype=torch.float32)
        additive.masked_fill_(~joint, torch.finfo(torch.float32).min)
        return additive[:, None, None, :]

View File

@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module):
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000):
super().__init__()
if output_size is None:
output_size = hidden_size
@ -221,9 +221,10 @@ class TimestepEmbedder(nn.Module):
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
def forward(self, t, dtype, **kwargs):
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype)
t_emb = self.mlp(t_freq)
return t_emb

239
comfy/ldm/pixeldit/model.py Normal file
View File

@ -0,0 +1,239 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ldm.common_dit
import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope, rope
from comfy.ldm.hidream.model import FeedForwardSwiGLU
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
from .modules import (
FinalLayer,
PatchTokenEmbedder,
PiTBlock,
PixelTokenEmbedder,
apply_adaln_,
precompute_freqs_cis_2d,
)
class MMDiTJointAttention(nn.Module):
"""Joint MMDiT attention with separate Q/K/V/proj for image and text streams.
RoPE is applied to each stream before concatenation so each stream uses its own
2D/1D positional encoding. Concat order is [text, image] (text first).
"""
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device)
self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
B, Nx, _ = x.shape
_, Ny, _ = y.shape
H = self.num_heads
D = self.head_dim
qkv_x = self.qkv_x(x).reshape(B, Nx, 3, H, D).permute(2, 0, 3, 1, 4)
qx, kx, vx = qkv_x.unbind(0)
qx = self.q_norm_x(qx)
kx = self.k_norm_x(kx)
qkv_y = self.qkv_y(y).reshape(B, Ny, 3, H, D).permute(2, 0, 3, 1, 4)
qy, ky, vy = qkv_y.unbind(0)
qy = self.q_norm_y(qy)
ky = self.k_norm_y(ky)
qx, kx = apply_rope(qx, kx, pos_img[None, None])
if pos_txt is not None:
qy, ky = apply_rope(qy, ky, pos_txt[None, None])
q_joint = torch.cat([qy, qx], dim=2)
k_joint = torch.cat([ky, kx], dim=2)
v_joint = torch.cat([vy, vx], dim=2)
out_joint = optimized_attention(
q_joint, k_joint, v_joint, H,
mask=attn_mask, skip_reshape=True, skip_output_reshape=True,
transformer_options=transformer_options,
)
out_y = out_joint[:, :, :Ny, :].transpose(1, 2).reshape(B, Ny, H * D)
out_x = out_joint[:, :, Ny:, :].transpose(1, 2).reshape(B, Nx, H * D)
return self.proj_x(out_x), self.proj_y(out_y)
class MMDiTBlockT2I(nn.Module):
    """MMDiT block: adaLN-modulated joint attention and a SwiGLU MLP per stream.

    ``x`` is the image stream and ``y`` the text stream; each has its own
    norms, MLP and 6-way adaLN modulation, and both share one joint attention.
    """

    def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None):
        """Build norms, attention, MLPs and modulation layers (``groups`` = head count)."""
        super().__init__()
        factory = {"dtype": dtype, "device": device}

        self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, **factory)
        self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, **factory)
        self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False, operations=operations, **factory)
        self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, **factory)
        self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, **factory)

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp_x = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, operations=operations, **factory)
        self.mlp_y = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, operations=operations, **factory)

        # Single-element Sequential wrappers keep the parameters under the "<name>.0." prefix.
        self.adaLN_modulation_img = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, **factory))
        self.adaLN_modulation_txt = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, **factory))

    def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
        """Apply the block to image tokens ``x`` and text tokens ``y``; returns ``(x, y)``."""
        # Chunk order per stream: attention shift/scale/gate, then MLP shift/scale/gate.
        (attn_shift_x, attn_scale_x, attn_gate_x,
         mlp_shift_x, mlp_scale_x, mlp_gate_x) = self.adaLN_modulation_img(c).chunk(6, dim=-1)
        (attn_shift_y, attn_scale_y, attn_gate_y,
         mlp_shift_y, mlp_scale_y, mlp_gate_y) = self.adaLN_modulation_txt(c).chunk(6, dim=-1)

        # apply_adaln_ works in place on the fresh norm outputs.
        img_tokens = apply_adaln_(self.norm_x1(x), attn_shift_x, attn_scale_x)
        txt_tokens = apply_adaln_(self.norm_y1(y), attn_shift_y, attn_scale_y)
        attn_x, attn_y = self.attn(
            img_tokens, txt_tokens, pos_img, pos_txt, attn_mask,
            transformer_options=transformer_options,
        )

        # Gated residuals: out = residual + gate * branch.
        x = torch.addcmul(x, attn_gate_x, attn_x)
        y = torch.addcmul(y, attn_gate_y, attn_y)

        img_mlp_in = apply_adaln_(self.norm_x2(x), mlp_shift_x, mlp_scale_x)
        x = torch.addcmul(x, mlp_gate_x, self.mlp_x(img_mlp_in))
        txt_mlp_in = apply_adaln_(self.norm_y2(y), mlp_shift_y, mlp_scale_y)
        y = torch.addcmul(y, mlp_gate_y, self.mlp_y(txt_mlp_in))
        return x, y
class PixDiT_T2I(nn.Module):
    """PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint
    (also runs at 512px when fed the appropriate latent size and flow_shift).

    Forward:
        x:        [B, 3, H, W] pixel-space input (no VAE)
        timesteps:[B] in [0, 1000] (ComfyUI flow sampling convention)
        context:  [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended)

    Returns flow-matching velocity [B, 3, H, W].

    Structure: a patch-level stack of ``MMDiTBlockT2I`` blocks produces one
    conditioning vector per patch, which then drives a pixel-level stack of
    ``PiTBlock`` blocks followed by ``FinalLayer``.
    """

    def __init__(
        self,
        in_channels=3,
        num_groups=24,
        hidden_size=1536,
        pixel_hidden_size=16,
        pixel_attn_hidden_size=1152,
        pixel_num_groups=16,
        patch_depth=14,
        pixel_depth=2,
        patch_size=16,
        txt_embed_dim=2304,
        txt_max_length=300,
        use_text_rope=True,
        text_rope_theta=10000.0,
        image_model=None,
        dtype=None,
        device=None,
        operations=None,
        pixel_mlp_chunks=2,
    ):
        """Store the configuration and build embedders, both block stacks and the head.

        ``image_model`` is accepted but not used here. ``num_groups`` and
        ``pixel_num_groups`` are passed on as attention head counts.
        """
        super().__init__()
        self.dtype = dtype
        self.in_channels = in_channels
        self.out_channels = in_channels
        self.hidden_size = hidden_size
        self.num_groups = num_groups
        self.patch_depth = patch_depth
        self.pixel_depth = pixel_depth
        self.patch_size = patch_size
        self.pixel_hidden_size = pixel_hidden_size
        self.pixel_attn_hidden_size = pixel_attn_hidden_size
        self.pixel_num_groups = pixel_num_groups
        self.txt_embed_dim = txt_embed_dim
        self.txt_max_length = txt_max_length
        self.use_text_rope = use_text_rope
        self.text_rope_theta = text_rope_theta

        self.pixel_embedder = PixelTokenEmbedder(self.in_channels, self.pixel_hidden_size, dtype=dtype, device=device, operations=operations)
        self.s_embedder = PatchTokenEmbedder(self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True, dtype=dtype, device=device, operations=operations)
        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10)
        self.y_embedder = PatchTokenEmbedder(self.txt_embed_dim, self.hidden_size, bias=True, use_norm=True, dtype=dtype, device=device, operations=operations)
        # Additive text position table; allocated uninitialised and expected to be
        # filled from the checkpoint.
        self.y_pos_embedding = nn.Parameter(torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device))

        self.patch_blocks = nn.ModuleList([
            MMDiTBlockT2I(self.hidden_size, self.num_groups,
                          dtype=dtype, device=device, operations=operations)
            for _ in range(self.patch_depth)
        ])
        self.pixel_blocks = nn.ModuleList([
            PiTBlock(
                self.pixel_hidden_size,
                self.hidden_size,
                patch_size=self.patch_size,
                num_heads=self.num_groups,
                attn_hidden_size=self.pixel_attn_hidden_size,
                attn_num_heads=self.pixel_num_groups,
                dtype=dtype, device=device, operations=operations,
                mlp_chunks=pixel_mlp_chunks,
            )
            for _ in range(self.pixel_depth)
        ])
        self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels, dtype=dtype, device=device, operations=operations)

    def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
        """Return 2D RoPE data for a ``height x width`` patch grid at the per-head width."""
        return precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width, device=device, dtype=dtype, **rope_opts)

    def _fetch_text_pos(self, length, device, dtype):
        """Return 1D RoPE data for text positions ``0 .. length - 1``."""
        return rope(torch.arange(length, dtype=torch.float32, device=device).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0).to(dtype=dtype)

    def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options=None, **kwargs):
        """Run ``_forward`` through any DIFFUSION_MODEL wrappers in ``transformer_options``.

        ``transformer_options=None`` is treated as an empty dict.
        """
        if transformer_options is None:
            transformer_options = {}
        return comfy.patcher_extension.WrapperExecutor.new_class_executor(
            self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
        ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)

    def _pre_patch_block(self, s, i, **kwargs):
        """Hook for subclasses to inject per-block state into the patch stream (e.g. PiD's LQ gate)."""
        return s

    def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options=None, **kwargs):
        """Predict the output for pixel-space input ``x``.

        Args:
            x: ``[B, C, H, W]`` input; padded to a multiple of ``patch_size``
                and cropped back before returning.
            timesteps: One value per batch element.
            context: Text embeddings ``[B, L, D]``; tokens beyond
                ``txt_max_length`` are dropped.
            attention_mask: Accepted for interface compatibility; not used.
            transformer_options: Optional dict; ``"rope_options"`` is forwarded
                to the image RoPE builder. None is treated as empty.
            **kwargs: Forwarded to ``_pre_patch_block``.

        Raises:
            ValueError: If ``context`` is missing or is not 3-dimensional.
        """
        # Validate before any padding/unfold/embedding work so bad input fails fast.
        if context is None or context.dim() != 3:
            raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]")
        if transformer_options is None:
            transformer_options = {}

        H_orig, W_orig = x.shape[2], x.shape[3]
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
        B, _, H, W = x.shape
        Hs = H // self.patch_size
        Ws = W // self.patch_size
        L = Hs * Ws

        pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
        # Non-overlapping patches: [B, C * P * P, L] -> [B, L, C * P * P].
        x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
        t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size)

        Ltxt = min(context.shape[1], self.txt_max_length)
        y = context[:, :Ltxt, :]
        y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
        y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb)  # y_pos_embedding is a raw nn.Parameter

        condition = F.silu(t_emb)
        pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None

        s = self.s_embedder(x_patches)
        for i, blk in enumerate(self.patch_blocks):
            s = self._pre_patch_block(s, i, **kwargs)
            s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options)
        # Per-patch conditioning for the pixel stack, flattened to one row per patch.
        s = F.silu(t_emb + s)
        s_cond = s.view(B * L, self.hidden_size)

        x_pixels = self.pixel_embedder(x, patch_size=self.patch_size)
        for blk in self.pixel_blocks:
            x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options)
        x_pixels = self.final_layer(x_pixels)

        C_out = self.out_channels
        P2 = self.patch_size * self.patch_size
        # Rearrange to the [B, C * P * P, L] layout that F.fold consumes.
        x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).reshape(B, C_out * P2, L)
        out = F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size)
        # Drop the padding added at the top.
        return out[:, :, :H_orig, :W_orig]

View File

@ -0,0 +1,187 @@
import torch
import torch.nn as nn
from comfy.ldm.flux.math import apply_rope, rope
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, get_1d_sincos_pos_embed_from_grid_torch
def apply_adaln_(x, shift, scale):
    """In-place adaLN modulation: overwrite ``x`` with ``x * (1 + scale) + shift`` and return it."""
    x.addcmul_(x, scale)  # x <- x + x * scale
    x.add_(shift)
    return x
def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0,
                            ref_grid_h=None, ref_grid_w=None,
                            scale_x=1.0, scale_y=1.0, shift_x=0.0, shift_y=0.0,
                            device=None, dtype=torch.float32, **kwargs):
    """2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim.

    rope_options:
      scale_x / scale_y multiply the position range (RoPE extrapolation).
      shift_x / shift_y offset the position origin (tiled / regional inference).

    With ref_grid_h / ref_grid_w set, also applies NTK-aware per-axis theta scaling
    (rope_mode='ntk_aware'): theta_axis = theta * (current/ref)^(dim_axis/(dim_axis-2)).
    Each axis is scaled independently: an axis whose reference size is None keeps
    the plain ``theta``. Unrecognised keyword arguments are accepted and ignored.

    Returns Flux-format rotation matrices of shape [H*W, dim/2, 2, 2].
    Layout of head-dim pairs: [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}].
    """
    dim_axis = dim // 2

    def _ntk_factor(current, ref):
        # No reference size (or too few dims for the exponent to be defined) -> no scaling.
        if ref is None or dim_axis <= 2:
            return 1.0
        return (current / ref) ** (dim_axis / (dim_axis - 2))

    # Computed per axis so a caller supplying only one reference size neither
    # crashes on `width / None` nor has the other axis silently ignored.
    h_ntk = _ntk_factor(height, ref_grid_h)
    w_ntk = _ntk_factor(width, ref_grid_w)

    x_lin = torch.linspace(shift_x, scale * scale_x + shift_x, width, device=device)
    y_lin = torch.linspace(shift_y, scale * scale_y + shift_y, height, device=device)
    y_grid, x_grid = torch.meshgrid(y_lin, x_lin, indexing="ij")
    x_rope = rope(x_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0)
    y_rope = rope(y_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0)
    # Stacking on dim 2 interleaves the x and y pairs before flattening to dim/2 pairs.
    out = torch.stack([x_rope, y_rope], dim=2).reshape(height * width, dim // 2, 2, 2)
    return out.to(dtype=dtype)
def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32):
    """Build a 2D sin/cos absolute positional embedding for a ``height`` x ``width`` grid.

    Each axis gets ``embed_dim // 2`` channels; the W-coordinate embedding is
    placed first and the H-coordinate embedding second along dim 1. One row is
    produced per grid cell, in row-major ("ij") order, cast to ``dtype``.
    """
    assert embed_dim % 4 == 0
    axis_dim = embed_dim // 2
    rows = torch.arange(height, dtype=torch.float32, device=device)
    cols = torch.arange(width, dtype=torch.float32, device=device)
    row_grid, col_grid = torch.meshgrid(rows, cols, indexing="ij")
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(axis_dim, col_grid.reshape(-1), device=device)
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(axis_dim, row_grid.reshape(-1), device=device)
    combined = torch.cat([emb_w, emb_h], dim=1)
    return combined.to(dtype=dtype)
class RotaryAttention(nn.Module):
"""Single-stream self-attention with rotary positional encoding (used inside PiTBlock)."""
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x, pos, mask=None, transformer_options={}):
B, N, C = x.shape
H = self.num_heads
D = self.head_dim
qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = apply_rope(self.q_norm(q), self.k_norm(k), pos[None, None])
x = optimized_attention(q, k, v, H, mask=mask, skip_reshape=True, transformer_options=transformer_options)
return self.proj(x)
class FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, x):
return self.linear(self.norm(x))
class PatchTokenEmbedder(nn.Module):
"""Linear projection used both for patchified-image tokens and text-feature tokens."""
def __init__(self, in_chans, embed_dim, use_norm=False, bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device)
self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device) if use_norm else nn.Identity()
def forward(self, x):
return self.norm(self.proj(x))
class PixelTokenEmbedder(nn.Module):
    """Pixel-level embedder: projects each pixel's channel vector to ``hidden_size_output``
    and packs the pixels into one sequence per patch."""

    def __init__(self, in_channels, hidden_size_output, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_size_output = hidden_size_output
        self.proj = operations.Linear(in_channels, hidden_size_output, bias=True, dtype=dtype, device=device)

    def forward(self, inputs, patch_size):
        """Embed a 4-D image and return a tensor of shape [B * Hs * Ws, patch_size**2, hidden]."""
        batch, _, height, width = inputs.shape
        hidden = self.hidden_size_output
        grid_h = height // patch_size
        grid_w = width // patch_size
        # Channels-last so the Linear acts on every pixel independently.
        pixels = self.proj(inputs.permute(0, 2, 3, 1).contiguous())
        pos_full = get_2d_sincos_pos_embed(hidden, height, width, device=pixels.device, dtype=pixels.dtype)
        pixels = pixels + pos_full.view(height, width, hidden).unsqueeze(0)
        # Split H and W into (grid, in-patch) axes, then group the in-patch axes together.
        pixels = pixels.view(batch, grid_h, patch_size, grid_w, patch_size, hidden)
        pixels = pixels.permute(0, 1, 3, 2, 4, 5)
        return pixels.reshape(batch * grid_h * grid_w, patch_size * patch_size, hidden)
class PiTBlock(nn.Module):
    """Pixel-level transformer block.

    Compresses each patch's P^2 pixel tokens into one attention token via a linear
    layer, runs self-attention across patches with 2D RoPE, then expands back to
    P^2 tokens. Both the attention and the MLP branch are modulated per pixel
    (adaLN shift/scale/gate) from the patch-level conditioning features.
    """

    def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0,
                 attn_hidden_size=None, attn_num_heads=None, dtype=None, device=None, operations=None, mlp_chunks=1):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.pixel_dim = pixel_hidden_size
        self.context_dim = patch_hidden_size
        # Attention width / head count fall back to the patch-level settings when not overridden.
        self.attn_dim = patch_hidden_size if attn_hidden_size is None else attn_hidden_size
        self.num_heads = num_heads if attn_num_heads is None else attn_num_heads
        assert self.attn_dim % self.num_heads == 0
        p2 = patch_size * patch_size
        self.compress_to_attn = operations.Linear(p2 * self.pixel_dim, self.attn_dim, bias=True, **factory)
        self.expand_from_attn = operations.Linear(self.attn_dim, p2 * self.pixel_dim, bias=True, **factory)
        self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, **factory)
        self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False, operations=operations, **factory)
        self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, **factory)
        self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio), operations=operations, **factory)
        self.adaLN_modulation_msa = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, **factory)
        self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, **factory)
        # Replaceable RoPE factory; callers may swap in a different function after construction.
        self._rope_fn = precompute_freqs_cis_2d
        self.mlp_chunks = max(1, int(mlp_chunks))

    def _fetch_pos(self, height, width, device, dtype, **rope_opts):
        """Build the rotary table for a ``height`` x ``width`` patch grid at this block's head dim."""
        head_dim = self.attn_dim // self.num_heads
        return self._rope_fn(head_dim, height, width, device=device, dtype=dtype, **rope_opts)

    def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}):
        """Apply the attention and MLP branches to ``x`` (3-D: patches-in-batch, P^2, pixel_dim)."""
        BL, P2, _ = x.shape
        grid_h = image_height // patch_size
        grid_w = image_width // patch_size
        num_patches = grid_h * grid_w
        batch = BL // num_patches
        rope_opts = transformer_options.get("rope_options") or {}

        # Attention path uses only msa params; compute, use, free before mlp params allocate.
        msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim)
        shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1)
        x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa)
        x_flat = x_norm.view(BL, P2 * self.pixel_dim)
        x_comp = self.compress_to_attn(x_flat).view(batch, num_patches, self.attn_dim)
        pos_comp = self._fetch_pos(grid_h, grid_w, x.device, x.dtype, **rope_opts)
        attn_out = self.attn(x_comp, pos_comp, mask=mask, transformer_options=transformer_options)
        attn_flat = self.expand_from_attn(attn_out.view(batch * num_patches, self.attn_dim))
        attn_exp = attn_flat.view(BL, P2, self.pixel_dim)
        x = torch.addcmul(x, gate_msa, attn_exp)
        del msa_params, shift_msa, scale_msa, gate_msa

        mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim)
        shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1)
        # Detach the gate from mlp_params so the del below frees shift+scale storage before the MLP.
        gate_mlp = gate_mlp.contiguous()
        mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp)
        del mlp_params, shift_mlp, scale_mlp

        # MLP in chunks since the peak memory usage is huge here.
        chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks
        for start in range(0, BL, chunk_size):
            stop = min(start + chunk_size, BL)
            x[start:stop].addcmul_(gate_mlp[start:stop], self.mlp(mlp_input[start:stop]))
        return x

227
comfy/ldm/pixeldit/pid.py Normal file
View File

@ -0,0 +1,227 @@
"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent
directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
body + LQ projection branch injected before each MMDiT patch block.
"""
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from .model import PixDiT_T2I
from .modules import precompute_freqs_cis_2d
class SigmaAwareGatePerTokenPerDim(nn.Module):
    """gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq.

    Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1.
    """

    def __init__(self, dim: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device)
        # Scalar parameter; allocated uninitialised here.
        self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device))

    def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Blend ``lq`` into ``x`` with a per-token, per-channel gate that shrinks as ``sigma`` grows."""
        joined = torch.cat([x, lq], dim=-1)
        content_logit = self.content_proj(joined)
        # log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
        log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32)
        # One sigma per leading (batch) entry, broadcast over the remaining two dims.
        per_sample_sigma = sigma.float().view(-1, 1, 1)
        sigma_offset = -log_alpha.exp() * per_sample_sigma
        gate = torch.sigmoid(content_logit + sigma_offset)
        gated_lq = (gate * lq).to(x.dtype)
        return x + gated_lq
class ResBlock(nn.Module):
    """Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip."""

    def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        stages = []
        # Two identical GN -> SiLU -> 3x3 Conv stages; channel count is preserved throughout.
        for _ in range(2):
            stages.append(operations.GroupNorm(num_groups, channels, **factory))
            stages.append(nn.SiLU())
            stages.append(operations.Conv2d(channels, channels, kernel_size=3, padding=1, **factory))
        self.block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = self.block(x)
        return x + residual
class LQProjection2D(nn.Module):
    """LQ latent -> per-block patch-aligned features for controlnet-style injection.

    The latent is first aligned to the patch grid (resized, and space-to-depth
    folded when one patch covers several latent cells), passed through a small
    conv/ResBlock trunk, and then projected by ``num_outputs`` independent
    linear heads. One sigma-aware gate module is kept per head.
    """

    def __init__(
        self,
        latent_channels: int,
        hidden_dim: int = 512,
        out_dim: int = 1536,
        patch_size: int = 16,
        sr_scale: int = 4,
        latent_spatial_down_factor: int = 8,
        num_res_blocks: int = 4,
        num_outputs: int = 7,
        interval: int = 2,
        dtype=None, device=None, operations=None,
    ):
        super().__init__()
        factory = {"dtype": dtype, "device": device}
        self.latent_channels = latent_channels
        self.hidden_dim = hidden_dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.sr_scale = sr_scale
        self.latent_spatial_down_factor = latent_spatial_down_factor
        self.num_outputs = num_outputs
        self.interval = interval

        # Ratio of (sr_scale * latent_spatial_down_factor) to patch_size.
        z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size
        self.z_to_patch_ratio = z_to_patch_ratio
        if z_to_patch_ratio >= 1:
            self.latent_fold_factor = 0
            latent_proj_in_ch = latent_channels
        else:
            # Several latent cells per patch: fold an f x f neighbourhood into channels.
            fold_factor = int(1 / z_to_patch_ratio)
            assert fold_factor * z_to_patch_ratio == 1.0
            self.latent_fold_factor = fold_factor
            latent_proj_in_ch = latent_channels * fold_factor * fold_factor

        trunk = [
            operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, **factory),
            nn.SiLU(),
            operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, **factory),
        ]
        trunk.extend(
            ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations)
            for _ in range(num_res_blocks)
        )
        self.latent_proj = nn.Sequential(*trunk)

        self.output_heads = nn.ModuleList(
            [operations.Linear(hidden_dim, out_dim, **factory) for _ in range(num_outputs)]
        )
        self.gate_modules = nn.ModuleList(
            [SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations)
             for _ in range(num_outputs)]
        )

    def is_gate_active(self, block_idx: int) -> bool:
        """True for every ``interval``-th block index, starting at block 0."""
        return block_idx % self.interval == 0

    def output_index(self, block_idx: int) -> int:
        """Map a block index to the index of the head/gate pair that serves it."""
        return block_idx // self.interval

    def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor:
        """Apply gate module ``out_idx`` to merge ``lq_feature`` into ``x``."""
        gate_module = self.gate_modules[out_idx]
        return gate_module(x, lq_feature, sigma)

    def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor:
        """Resize/fold the latent onto a ``pH`` x ``pW`` grid and run the conv trunk over it."""
        B, z_dim = lq_latent.shape[:2]
        if self.z_to_patch_ratio >= 1:
            already_aligned = lq_latent.shape[2] == pH and lq_latent.shape[3] == pW
            if already_aligned:
                z_aligned = lq_latent
            else:
                z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest")
        else:
            f = self.latent_fold_factor
            zH_expected = pH * f
            zW_expected = pW * f
            if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected:
                lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest")
            # Space-to-depth: (B, C, pH, f, pW, f) -> (B, C, f, f, pH, pW) -> (B, C*f*f, pH, pW).
            z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4)
            z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW)
        return self.latent_proj(z_aligned)

    def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]:
        """Return one [B, H*W, out_dim] token tensor per output head."""
        feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW)
        B, C, H, W = feat.shape
        # Channels-last, then flatten the grid into a token axis.
        tokens = feat.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)
        outputs = []
        for head in self.output_heads:
            outputs.append(head(tokens))
        return outputs
class PidNet(PixDiT_T2I):
    """PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block)."""

    def __init__(
        self,
        lq_latent_channels: int = 16,
        lq_hidden_dim: int = 512,
        lq_num_res_blocks: int = 4,
        lq_interval: int = 2,
        sr_scale: int = 4,
        latent_spatial_down_factor: int = 8,
        rope_ref_h: int = 1024,  # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64.
        rope_ref_w: int = 1024,
        image_model=None,
        dtype=None, device=None, operations=None,
        **pixdit_kwargs,
    ):
        """Build the PixDiT_T2I backbone, switch its RoPE to the NTK-aware variant, and add the LQ branch.

        ``image_model`` is accepted and ignored; remaining keyword arguments go
        to the ``PixDiT_T2I`` constructor.
        """
        super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs)
        # Reference grid sizes in patch units, used for NTK-aware theta scaling.
        self.rope_ref_grid_h = rope_ref_h // self.patch_size
        self.rope_ref_grid_w = rope_ref_w // self.patch_size

        # Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware.
        def _pit_rope_fn(head_dim, h, w, device=None, dtype=torch.float32, **rope_opts):
            return precompute_freqs_cis_2d(head_dim, h, w, ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, device=device, dtype=dtype, **rope_opts)

        for blk in self.pixel_blocks:
            blk._rope_fn = _pit_rope_fn

        # One LQ head/gate per gated patch block: ceil(patch_depth / lq_interval).
        num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval
        self.lq_proj = LQProjection2D(
            latent_channels=lq_latent_channels,
            hidden_dim=lq_hidden_dim,
            out_dim=self.hidden_size,
            patch_size=self.patch_size,
            sr_scale=sr_scale,
            latent_spatial_down_factor=latent_spatial_down_factor,
            num_res_blocks=lq_num_res_blocks,
            num_outputs=num_lq_outputs,
            interval=lq_interval,
            dtype=dtype,
            device=device,
            operations=operations,
        )

    def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
        """NTK-aware rotary table for the patch stream at the given grid size."""
        return precompute_freqs_cis_2d(
            self.hidden_size // self.num_groups,
            height, width,
            ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w,
            device=device, dtype=dtype, **rope_opts,
        )

    def _pre_patch_block(self, s, i, pid_lq_features, pid_degrade_sigma, **kwargs):
        """Gate the matching LQ feature into the patch stream before block ``i``.

        Returns ``s`` unchanged for blocks that are not on the injection
        interval or that have no corresponding LQ feature.
        """
        if not self.lq_proj.is_gate_active(i):
            return s
        out_idx = self.lq_proj.output_index(i)
        if out_idx >= len(pid_lq_features):
            return s
        return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx)

    def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs):
        """Project the LQ latent onto the patch grid, then run the backbone with per-block LQ gating.

        Raises:
            ValueError: if ``lq_latent`` or ``degrade_sigma`` is missing, or if
                ``lq_latent`` has a channel count this model variant does not expect.
        """
        if lq_latent is None:
            raise ValueError("PidNet requires lq_latent — attach via PiDConditioning")
        expected_c = self.lq_proj.latent_channels
        if lq_latent.shape[1] != expected_c:
            raise ValueError(
                f"Input latent has {lq_latent.shape[1]} channels, this model variant expects {expected_c}. "
                f"Flux1/SD3 = 16 channels, Flux2 = 128 channels."
            )
        # Fail with a clear message instead of an AttributeError on `None.to(...)` below.
        if degrade_sigma is None:
            raise ValueError("PidNet requires degrade_sigma alongside lq_latent")

        B = x.shape[0]
        # Match the backbone's pad_to_patch_size (round up) so the LQ grid lines up with the patch stream.
        Hs = -(-x.shape[2] // self.patch_size)
        Ws = -(-x.shape[3] // self.patch_size)
        degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
        # A single sigma is shared by the whole batch.
        if degrade_sigma.numel() == 1 and B > 1:
            degrade_sigma = degrade_sigma.expand(B).contiguous()
        lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws)
        return super()._forward(
            x, timesteps,
            context=context, attention_mask=attention_mask,
            transformer_options=transformer_options,
            pid_lq_features=lq_features,
            pid_degrade_sigma=degrade_sigma,
            **kwargs,
        )

View File

@ -35,6 +35,7 @@ import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.ldm.lens.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
@ -48,6 +49,8 @@ import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.chroma_radiance.model
import comfy.ldm.pixeldit.model
import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.seedvr.model
@ -1070,6 +1073,27 @@ class Flux2(Flux):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class Lens(BaseModel):
    """BaseModel wrapper around ``LensTransformer2DModel``."""

    def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
        super().__init__(
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.lens.model.LensTransformer2DModel,
        )

    def encode_adm(self, **kwargs):
        # Lens has no pooled/ADM conditioning.
        return None

    def extra_conds(self, **kwargs):
        """Extend the parent's conds with text embeddings and their attention mask when supplied."""
        out = super().extra_conds(**kwargs)
        # (kwarg name, cond key) pairs; each is forwarded only when present.
        forwarded = (
            ("cross_attn", 'c_crossattn'),
            ("attention_mask", 'attention_mask'),
        )
        for kwarg_name, cond_key in forwarded:
            value = kwargs.get(kwarg_name, None)
            if value is not None:
                out[cond_key] = comfy.conds.CONDRegular(value)
        return out
class GenmoMochi(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
@ -1387,6 +1411,53 @@ class ZImagePixelSpace(Lumina2):
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
self.memory_usage_factor_conds = ("ref_latents",)
class PixelDiTT2I(BaseModel):
    """BaseModel wrapper around ``PixDiT_T2I``."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I,
        )

    def extra_conds(self, **kwargs):
        """Extend the parent's conds with the text attention mask when one is supplied."""
        out = super().extra_conds(**kwargs)
        mask = kwargs.get("attention_mask", None)
        if mask is None:
            return out
        out["attention_mask"] = comfy.conds.CONDRegular(mask)
        return out
class PiD(PixelDiTT2I):
    """PixelDiTT2I variant whose diffusion model is ``PidNet``; adds the LQ-latent conds."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        # Deliberately bypasses PixelDiTT2I.__init__ so the network class can be replaced.
        BaseModel.__init__(
            self,
            model_config,
            model_type,
            device=device,
            unet_model=comfy.ldm.pixeldit.pid.PidNet,
        )

    def extra_conds(self, **kwargs):
        """Extend the parent's conds with ``lq_latent`` and ``degrade_sigma`` when supplied."""
        out = super().extra_conds(**kwargs)
        for cond_name in ("lq_latent", "degrade_sigma"):
            value = kwargs.get(cond_name, None)
            if value is not None:
                out[cond_name] = comfy.conds.CONDRegular(value)
        return out

    def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
        """Slice the ``lq_latent`` cond to the latent cells covered by a context window.

        Any other cond is delegated to the parent implementation. Returns
        ``None`` when the window's dim is not a dim of the latent or when no
        window index maps inside the latent.
        """
        handles_here = (
            cond_key == "lq_latent"
            and hasattr(cond_value, "cond")
            and isinstance(cond_value.cond, torch.Tensor)
        )
        if not handles_here:
            return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)

        lq = cond_value.cond
        dim = window.dim
        if dim >= lq.ndim:
            return None
        lq_proj = self.diffusion_model.lq_proj
        # Number of window (x) positions per latent cell along one axis.
        ratio = lq_proj.sr_scale * lq_proj.latent_spatial_down_factor
        # Map x window indices -> lq indices (deduplicated, sorted, in-bounds).
        lq_size = lq.size(dim)
        mapped = {i // ratio for i in window.index_list}
        lq_indices = sorted(j for j in mapped if 0 <= j < lq_size)
        if not lq_indices:
            return None
        # Full slices for the leading dims, the index list on the windowed dim.
        idx = tuple([slice(None)] * dim + [lq_indices])
        return cond_value._copy_with(lq[idx].to(device))
class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)

View File

@ -463,6 +463,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
return dit_config
# PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I.
_lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix)
if _lq_w_key in state_dict_keys:
in_ch = int(state_dict[_lq_w_key].shape[1])
_gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix)
num_gates = len({k[len(_gate_prefix):].split('.')[0]
for k in state_dict_keys if k.startswith(_gate_prefix)})
dit_config = {"image_model": "pid",
"lq_latent_channels": in_ch,
"latent_spatial_down_factor": 16 if in_ch >= 64 else 8}
if num_gates > 0:
dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates
return dit_config
if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I
return {"image_model": "pixeldit_t2i"}
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
dit_config = {}
dit_config["image_model"] = "lumina2"
@ -805,6 +822,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["timestep_scale"] = 1000.0
return dit_config
if '{}transformer_blocks.0.attn.norm_added_q.weight'.format(key_prefix) in state_dict_keys \
and '{}transformer_blocks.0.img_mlp.w1.weight'.format(key_prefix) in state_dict_keys: # Lens
img_in_w = state_dict['{}img_in.weight'.format(key_prefix)]
proj_out_w = state_dict['{}proj_out.weight'.format(key_prefix)]
multi_layer = '{}txt_norm.0.weight'.format(key_prefix) in state_dict_keys
if multi_layer:
enc_hidden_dim = state_dict['{}txt_norm.0.weight'.format(key_prefix)].shape[0]
# Indices are TE-side; the DiT just consumes L layers in order.
selected_layer_index = tuple(range(count_blocks(state_dict_keys, '{}txt_norm.'.format(key_prefix) + '{}.')))
else:
enc_hidden_dim = state_dict['{}txt_norm.weight'.format(key_prefix)].shape[0]
selected_layer_index = (0,)
return {
"image_model": "lens",
"in_channels": img_in_w.shape[1],
"out_channels": proj_out_w.shape[0] // 4, # patch_size ** 2 (=2² default)
"num_layers": count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.'),
"num_attention_heads": img_in_w.shape[0] // 64, # // attention_head_dim default
"enc_hidden_dim": enc_hidden_dim,
"multi_layer_encoder_feature": multi_layer,
"selected_layer_index": selected_layer_index,
}
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
dit_config = {}
dit_config["image_model"] = "qwen_image"

View File

@ -18,6 +18,7 @@
import torch
import logging
import contextlib
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy.float
@ -1047,6 +1048,144 @@ class QuantLinearFunc(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None
# Quantized-weight module helpers
def _quantized_apply(module, fn, recurse=True):
"""Re-wrap Parameters after fn so .to()/.cuda() propagate through QuantizedTensor weights."""
if recurse:
for child in module.children():
child._apply(fn)
for key, param in module._parameters.items():
if param is None:
continue
p = fn(param)
if (not torch.is_inference_mode_enabled()) and p.is_inference():
p = p.clone()
module.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in module._buffers.items():
if buf is not None:
module._buffers[key] = fn(buf)
return module
def _load_quantized_module(module, super_load, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs, load_extra_params=False):
"""Shared _load_from_state_dict body for quantized-weight modules.
Pops weight (+ scales, +/- extras), populates module.weight as a Parameter
or Parameter-wrapped QuantizedTensor, then calls super_load and strips
consumed keys from missing_keys. Reads compute_dtype from factory_kwargs
and disabled formats from module._disabled_formats.
"""
device = module.factory_kwargs["device"]
compute_dtype = module.factory_kwargs["dtype"]
disabled_formats = module._disabled_formats
layer_name = prefix.rstrip('.')
weight = state_dict.pop(f"{prefix}weight", None)
if weight is None:
logging.warning(f"Missing weight for layer {layer_name}")
module.weight = None
return
manually_loaded_keys = [f"{prefix}weight"]
def pop_scale(name, dtype=None):
key = f"{prefix}{name}"
v = state_dict.pop(key, None)
if v is not None:
v = v.to(device=device)
if dtype is not None:
v = v.view(dtype=dtype)
manually_loaded_keys.append(key)
return v
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
module.weight = torch.nn.Parameter(weight.to(device=device, dtype=compute_dtype), requires_grad=False)
else:
module.quant_format = layer_conf.get("format", None)
module._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
if not module._full_precision_mm:
module._full_precision_mm = module._full_precision_mm_config
if module.quant_format in disabled_formats:
module._full_precision_mm = True
if module.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[module.quant_format]
module.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(module.layout_type)
# Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8.
if module.quant_format in ("float8_e4m3fn", "float8_e5m2"):
scales = {"scale": pop_scale("weight_scale")}
elif module.quant_format == "mxfp8":
bs = pop_scale("weight_scale", torch.float8_e8m0fnu)
if bs is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
scales = {"scale": bs}
elif module.quant_format == "nvfp4":
ts = pop_scale("weight_scale_2")
bs = pop_scale("weight_scale", torch.float8_e4m3fn)
if ts is None or bs is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
scales = {"scale": ts, "block_scale": bs}
else:
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
params = layout_cls.Params(**scales, orig_dtype=compute_dtype, orig_shape=module._orig_shape)
module.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), module.layout_type, params),
requires_grad=False,
)
if load_extra_params:
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
continue
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
module.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super_load(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)
def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extra_quant_params=()):
"""Shared state_dict body. extra_quant_conf merges into the comfy_quant JSON;
extra_quant_params names attributes written as additional top-level keys."""
if not hasattr(module, 'weight'):
logging.warning(f"Warning: state dict on uninitialized op {prefix}")
return sd
bias = getattr(module, 'bias', None)
if bias is not None:
sd[f"{prefix}bias"] = bias
if module.weight is None:
return sd
if isinstance(module.weight, QuantizedTensor):
sd.update(module.weight.state_dict(f"{prefix}weight"))
quant_conf = {"format": module.quant_format}
if getattr(module, '_full_precision_mm_config', False):
quant_conf["full_precision_matrix_mult"] = True
if extra_quant_conf:
quant_conf.update(extra_quant_conf)
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
for name in extra_quant_params:
value = getattr(module, name, None)
if value is not None:
sd[f"{prefix}{name}"] = value
else:
sd[f"{prefix}weight"] = module.weight
return sd
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast):
@ -1056,21 +1195,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
_disabled = disabled
class Linear(torch.nn.Module, CastWeightBiasOp):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
_disabled_formats = disabled
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
super().__init__()
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
# self.factory_kwargs = {"device": device, "dtype": dtype}
self.in_features = in_features
self.out_features = out_features
self._orig_shape = (out_features, in_features)
if bias:
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
else:
@ -1083,151 +1217,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def reset_parameters(self):
return None
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
key = f"{prefix}{param_name}"
value = state_dict.pop(key, None)
if value is not None:
value = value.to(device=device)
if dtype is not None:
value = value.view(dtype=dtype)
manually_loaded_keys.append(key)
return value
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
device = self.factory_kwargs["device"]
layer_name = prefix.rstrip('.')
weight_key = f"{prefix}weight"
weight = state_dict.pop(weight_key, None)
if weight is None:
logging.warning(f"Missing weight for layer {layer_name}")
self.weight = None
return
manually_loaded_keys = [weight_key]
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
if layer_conf is None:
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
else:
self.quant_format = layer_conf.get("format", None)
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
if not self._full_precision_mm:
self._full_precision_mm = self._full_precision_mm_config
if self.quant_format in MixedPrecisionOps._disabled:
self._full_precision_mm = True
if self.quant_format is None:
raise ValueError(f"Unknown quantization format for layer {layer_name}")
qconfig = QUANT_ALGOS[self.quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
# Load format-specific parameters
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
# FP8: single tensor scale
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
params = layout_cls.Params(
scale=scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "mxfp8":
# MXFP8: E8M0 block scales stored as uint8 in safetensors
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.uint8)
if block_scale is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
block_scale = block_scale.view(torch.float8_e8m0fnu)
params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.float8_e4m3fn)
if tensor_scale is None or block_scale is None:
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
params = layout_cls.Params(
scale=tensor_scale,
block_scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
else:
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
requires_grad=False
)
for param_name in qconfig["parameters"]:
if param_name in {"weight_scale", "weight_scale_2"}:
continue # Already handled above
param_key = f"{prefix}{param_name}"
_v = state_dict.pop(param_key, None)
if _v is None:
continue
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
manually_loaded_keys.append(param_key)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for key in manually_loaded_keys:
if key in missing_keys:
missing_keys.remove(key)
def _load_from_state_dict(self, *args):
_load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=True)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None:
sd = destination
else:
sd = {}
if not hasattr(self, 'weight'):
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
return sd
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
if self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
if self._full_precision_mm_config:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
input_scale = getattr(self, 'input_scale', None)
if input_scale is not None:
sd["{}input_scale".format(prefix)] = input_scale
else:
sd["{}weight".format(prefix)] = self.weight
return sd
sd = destination if destination is not None else {}
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",))
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
@ -1317,25 +1312,126 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.weight = torch.nn.Parameter(weight, requires_grad=False)
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
if recurse:
for module in self.children():
module._apply(fn)
return _quantized_apply(self, fn, recurse)
for key, param in self._parameters.items():
if param is None:
continue
p = fn(param)
if (not torch.is_inference_mode_enabled()) and p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
class MoEExperts(torch.nn.Module, CastWeightBiasOp):
    """Bank of ``num_experts`` expert weights, addressed by expert index.

    The whole bank is stored on ``self.weight`` as one 3D tensor: either a
    compute-dtype Parameter, or a Parameter wrapping a QuantizedTensor whose
    leading dimension is the expert index.

    State-dict layout (same as mixed_precision_ops.Linear, plus a leading
    expert dimension):
        {prefix}.weight          quantized data (storage_t), leading dim = E
        {prefix}.weight_scale    block scale or per-tensor scale
        {prefix}.weight_scale_2  [E] or scalar; NVFP4 only
        {prefix}.bias            [E, out_features]; optional, compute dtype
        {prefix}.comfy_quant     json -> {"format": "...", "num_experts": E}

    When comfy_quant is absent the weight is loaded as a plain compute-dtype
    3D Parameter of shape [E, out, in].
    """

    _disabled_formats = disabled

    def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
        """Record the bank geometry and allocate the optional bias.

        The ``dtype`` argument is accepted but not used: parameters are always
        created with ``MixedPrecisionOps._compute_dtype``.
        """
        super().__init__()
        self.num_experts = num_experts
        self.in_features = in_features
        self.out_features = out_features
        self._orig_shape = (num_experts, out_features, in_features)
        self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}

        if not bias:
            self.register_parameter("bias", None)
        else:
            bias_storage = torch.empty(num_experts, out_features, **self.factory_kwargs)
            self.bias = torch.nn.Parameter(bias_storage)

        # Placeholders that _load_from_state_dict is expected to fill in.
        self.weight = None
        self.quant_format = None
        self.layout_type = None
        self._full_precision_mm = MixedPrecisionOps._full_precision_mm
        self._full_precision_mm_config = False
        # (weight, bias) pair cached while inside bank_resident(); None otherwise.
        self._resident_bank = None

    def reset_parameters(self):
        """No-op: nothing is initialised here."""
        return None

    def _apply(self, fn, recurse=True):
        """Delegate parameter/buffer transformation to the shared quantized helper."""
        return _quantized_apply(self, fn, recurse)

    def _load_from_state_dict(self, *args):
        """Load via the shared quantized loader, without extra per-format parameters."""
        _load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=False)

    def expert_weight(self, i: int):
        """Return expert ``i``'s weight (a Tensor, or a per-expert QuantizedTensor)."""
        bank = self.weight
        if not isinstance(bank, QuantizedTensor):
            return bank[i]
        return self._expert_qt_from(bank, i)

    @contextlib.contextmanager
    def bank_resident(self, input):
        """Cast the whole bank once so expert_linear calls inside the block reuse it.

        Not re-entrant: do not nest calls on the same instance, because leaving
        the block always resets the cached bank to None.
        """
        weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
        self._resident_bank = (weight, bias)
        try:
            yield self
        finally:
            # Drop the cache before handing the cast tensors back.
            self._resident_bank = None
            uncast_bias_weight(self, weight, bias, offload_stream)

    def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
        """Apply expert ``i``'s linear layer (with its bias, if any) to ``input``."""
        cached = getattr(self, "_resident_bank", None)
        if cached is None:
            # No resident bank: cast for this single call and release afterwards.
            weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
            try:
                return self._expert_linear_impl(input, weight, bias, i)
            finally:
                uncast_bias_weight(self, weight, bias, offload_stream)
        bank_weight, bank_bias = cached
        return self._expert_linear_impl(input, bank_weight, bank_bias, i)

    def _expert_linear_impl(self, input, weight, bias, i):
        """Select expert ``i`` from an already-cast bank and run the matmul."""
        expert_w = self._expert_qt_from(weight, i) if isinstance(weight, QuantizedTensor) else weight[i]
        expert_b = None if bias is None else cast_to_input(bias[i], input, copy=False)

        if not isinstance(expert_w, QuantizedTensor):
            return torch.nn.functional.linear(input, expert_w, expert_b)

        # Quantized matmul is only attempted for 2D inputs, when full-precision
        # matmul is not forced and the layout reports support for it.
        take_fast_path = (
            not self._full_precision_mm
            and expert_w.layout_cls.supports_fast_matmul()
            and input.dim() == 2
        )
        if take_fast_path:
            quantized_input = QuantizedTensor.from_float(input, self.layout_type)
            return torch.nn.functional.linear(quantized_input, expert_w, expert_b)

        # Fallback: dequantize the expert weight and multiply in floating point.
        result = input @ expert_w.dequantize().t()
        if expert_b is None:
            return result
        return result + expert_b

    def _expert_qt_from(self, weight: QuantizedTensor, i: int) -> QuantizedTensor:
        """Build a QuantizedTensor for expert ``i`` by indexing into the bank."""
        bank_params = weight._params
        bank_scale = bank_params.scale
        expert_kwargs = {
            # A 0-dim scale is shared by all experts; otherwise pick row i.
            "scale": bank_scale[i] if bank_scale.dim() else bank_scale,
            "orig_dtype": bank_params.orig_dtype,
            "orig_shape": (self.out_features, self.in_features),
        }
        if hasattr(bank_params, "block_scale"):  # NVFP4
            expert_kwargs["block_scale"] = bank_params.block_scale[i]
        expert_params = type(bank_params)(**expert_kwargs)
        return QuantizedTensor(weight._qdata[i], weight._layout_cls, expert_params)

    def state_dict(self, *args, destination=None, prefix="", **kwargs):
        """Serialise through the shared helper, recording num_experts in comfy_quant."""
        sd = {} if destination is None else destination
        return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf={"num_experts": self.num_experts})
class Embedding(manual_cast.Embedding):
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
weight_key = f"{prefix}weight"
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
@ -1343,14 +1439,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
# Only fp8 makes sense for embeddings (per-row dequant via index select).
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
quant_format = layer_conf.get("format") if layer_conf is not None else None
manually_loaded_keys = []
if quant_format in ("float8_e4m3fn", "float8_e5m2") and weight_key in state_dict:
self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format]
self.layout_type = qconfig["comfy_tensor_layout"]
layout_cls = get_layout_class(self.layout_type)
weight = state_dict.pop(weight_key)
manually_loaded_keys = [weight_key]
manually_loaded_keys.append(weight_key)
scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(scale_key, None)
@ -1366,35 +1464,19 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
requires_grad=False)
elif layer_conf is not None:
# Unsupported format — restore the marker so it round-trips; fall through to default load.
state_dict[f"{prefix}comfy_quant"] = torch.tensor(
list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for k in manually_loaded_keys:
if k in missing_keys:
missing_keys.remove(k)
else:
if layer_conf is not None:
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for k in manually_loaded_keys:
if k in missing_keys:
missing_keys.remove(k)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None:
sd = destination
else:
sd = {}
if not hasattr(self, 'weight') or self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
else:
sd["{}weight".format(prefix)] = self.weight
return sd
sd = destination if destination is not None else {}
return _quantized_weight_state_dict(self, sd, prefix)
def forward_comfy_cast_weights(self, input, out_dtype=None):
weight = self.weight

View File

@ -51,6 +51,7 @@ import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
import comfy.text_encoders.lumina2
import comfy.text_encoders.pixeldit
import comfy.text_encoders.wan
import comfy.text_encoders.hidream
import comfy.text_encoders.ace
@ -70,6 +71,7 @@ import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.text_encoders.cogvideo
import comfy.text_encoders.sa3
import comfy.text_encoders.gpt_oss
import comfy.model_patcher
import comfy.lora
@ -1463,6 +1465,8 @@ class CLIPType(Enum):
FLUX2 = 25
LONGCAT_IMAGE = 26
COGVIDEOX = 27
LENS = 28
PIXELDIT = 29
@ -1515,6 +1519,7 @@ class TEModel(Enum):
GEMMA_4_E2B = 30
GEMMA_4_31B = 31
T5_GEMMA = 32
GPT_OSS_20B = 33
def detect_te_model(sd):
@ -1556,6 +1561,9 @@ def detect_te_model(sd):
else:
return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
# Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512).
if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd:
return TEModel.GPT_OSS_20B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
if weight.shape[0] == 256:
@ -1702,8 +1710,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
if clip_type == CLIPType.PIXELDIT:
clip_target.clip = comfy.text_encoders.pixeldit.pixeldit_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer
else:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
@ -1738,6 +1750,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
elif te_model == TEModel.GPT_OSS_20B:
clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.QWEN3_4B:
if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2:
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b")

View File

@ -30,6 +30,7 @@ import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie
import comfy.text_encoders.cogvideo
import comfy.text_encoders.hidream_o1
import comfy.text_encoders.pixeldit
from . import supported_models_base
from . import latent_formats
@ -829,6 +830,50 @@ class Flux2(Flux):
return None
class Lens(supported_models_base.BASE):
    """Microsoft Lens (3.8B dual-stream MMDiT, GPT-OSS-20B text features, Flux2 VAE).

    NOTE(review): this class uses comfy.text_encoders.gpt_oss and
    comfy.text_encoders.hunyuan_video; confirm both are imported by this module.
    """

    unet_config = {
        "image_model": "lens",
    }

    sampling_settings = {
        "shift": 1.829,  # Default mu for 1440x1440 (and any seq_len > 4300)
    }

    unet_extra_config = {}
    latent_format = latent_formats.Flux2
    memory_usage_factor = 4.0
    supported_inference_dtypes = [torch.bfloat16, torch.float32]  # fp16 causes NaNs
    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def __init__(self, unet_config):
        super().__init__(unet_config)

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate model_base.Lens with the FLUX model type."""
        return model_base.Lens(self, model_type=model_base.ModelType.FLUX, device=device)

    def clip_target(self, state_dict={}):
        """Return the tokenizer / text-encoder pair for Lens.

        The GPT-OSS weights are searched under
        ``text_encoders.gpt_oss.transformer.`` first and then directly under
        ``text_encoders.``. When the attention-sink key of layer 0 is found,
        the llama_detect result for that prefix is passed to ``lens_te``;
        otherwise ``lens_te`` is called with its defaults.
        """
        te_root = self.text_encoder_key_prefix[0]
        candidate_prefixes = (f"{te_root}gpt_oss.transformer.", te_root)
        for te_prefix in candidate_prefixes:
            if f"{te_prefix}layers.0.self_attn.sinks" not in state_dict:
                continue
            detected = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, te_prefix)
            return supported_models_base.ClipTarget(
                comfy.text_encoders.gpt_oss.LensTokenizer,
                comfy.text_encoders.gpt_oss.lens_te(**detected),
            )
        return supported_models_base.ClipTarget(
            comfy.text_encoders.gpt_oss.LensTokenizer,
            comfy.text_encoders.gpt_oss.lens_te(),
        )
class GenmoMochi(supported_models_base.BASE):
unet_config = {
"image_model": "mochi_preview",
@ -1159,6 +1204,72 @@ class ZImagePixelSpace(ZImage):
def get_model(self, state_dict, prefix="", device=None):
return model_base.ZImagePixelSpace(self, device=device)
class PixelDiTT2I(supported_models_base.BASE):
    """Model definition for the ``pixeldit_t2i`` image model."""

    unet_config = {
        "image_model": "pixeldit_t2i",
    }

    unet_extra_config = {}

    sampling_settings = {
        "shift": 4.0,  # 1024px stage 3 default; 2.0 for 512px
    }

    latent_format = latent_formats.PixelDiTPixel
    memory_usage_factor = 0.04
    supported_inference_dtypes = [torch.bfloat16, torch.float32]
    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate model_base.PixelDiTT2I on the requested device."""
        return model_base.PixelDiTT2I(self, device=device)

    def process_unet_state_dict(self, state_dict):
        """Return a new state dict with keys renamed for the PixelDiT module layout.

        - Keys starting with ``_repa_projector`` or ``net_ema.`` are dropped.
        - A leading ``core.`` or ``net.`` prefix is stripped.
        - Each ``pixel_blocks.*.adaLN_modulation.0.*`` tensor is split into an
          ``adaLN_modulation_msa`` part (chunks 0-2) and an
          ``adaLN_modulation_mlp`` part (chunks 3-5).

        Raises:
            ValueError: if an adaLN tensor has to be split but the state dict
                has no ``pixel_embedder.proj.weight`` to derive ``pixel_dim`` from.
        """
        # pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels).
        # A default avoids a bare StopIteration when the key is missing; the
        # error is only raised below if pixel_dim is actually needed.
        embed_weight = next(
            (v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")),
            None,
        )
        pixel_dim = None if embed_weight is None else embed_weight.shape[0]

        marker = ".adaLN_modulation.0."
        out = {}
        for k, v in state_dict.items():
            if k.startswith(("_repa_projector", "net_ema.")):
                continue
            if k.startswith("core."):
                k = k[len("core."):]
            elif k.startswith("net."):
                k = k[len("net."):]

            if "pixel_blocks." not in k or marker not in k:
                out[k] = v
                continue

            if pixel_dim is None:
                raise ValueError(
                    f"Cannot split '{k}': no 'pixel_embedder.proj.weight' key found to derive pixel_dim"
                )

            # Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM.
            # p2 is derived per tensor from total // (6 * pixel_dim).
            p2 = v.shape[0] // (6 * pixel_dim)
            trail = v.shape[1:]  # () for bias, (in_dim,) for weight
            # reshape (not view) so non-contiguous tensors are handled too.
            vv = v.reshape(p2, 6, pixel_dim, *trail)
            # maxsplit=1 keeps the unpacking valid even if the marker repeats in the key.
            base, suffix = k.split(marker, 1)
            out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous()
            out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous()
        return out

    def clip_target(self, state_dict={}):
        """Return the Gemma2 tokenizer / text-encoder pair used by PixelDiT."""
        return supported_models_base.ClipTarget(
            comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer,
            comfy.text_encoders.pixeldit.PixelDiTGemma2TE,
        )
class PiD(PixelDiTT2I):
    """Variant of PixelDiTT2I for the ``pid`` image model.

    Overrides only the model identifier, the sampling shift and the model
    class returned by get_model; everything else comes from PixelDiTT2I.
    """

    unet_config = {"image_model": "pid"}

    # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0]
    sampling_settings = {"shift": 1.5}

    memory_usage_factor = 0.04

    def get_model(self, state_dict, prefix="", device=None):
        """Instantiate model_base.PiD on the requested device."""
        return model_base.PiD(self, device=device)
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
@ -2097,6 +2208,8 @@ models = [
CosmosI2VPredict2,
ZImagePixelSpace,
ZImage,
PiD,
PixelDiTT2I,
Lumina2,
WAN22_T2V,
WAN21_CausalAR_T2V,
@ -2125,6 +2238,7 @@ models = [
Omnigen2,
QwenImage,
Flux2,
Lens,
Kandinsky5Image,
Kandinsky5,
Anima,

View File

@ -0,0 +1,600 @@
"""GPT-OSS text encoder for Lens."""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Any, List, Optional, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy import sd1_clip
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
from comfy.text_encoders.llama import RMSNorm, apply_rope
@dataclass
class GptOss20BConfig:
    """Hyperparameters of the GPT-OSS text encoder (defaults for the 20B model).

    ``layer_types`` holds one entry per decoder layer. When left as None it is
    filled with an alternating pattern that starts with "sliding_attention".
    """

    vocab_size: int = 201088
    hidden_size: int = 2880
    intermediate_size: int = 2880
    num_hidden_layers: int = 24
    num_attention_heads: int = 64
    num_key_value_heads: int = 8
    head_dim: int = 64
    num_local_experts: int = 32
    num_experts_per_tok: int = 4
    sliding_window: int = 128
    original_max_position_embeddings: int = 4096
    rope_theta: float = 150000.0
    rope_factor: float = 32.0
    rope_beta_fast: float = 32.0
    rope_beta_slow: float = 1.0
    rope_truncate: bool = False
    rms_norm_eps: float = 1e-5
    attention_bias: bool = True
    layer_types: Optional[List[str]] = None
    moe_alpha: float = 1.702
    moe_limit: float = 7.0

    def __post_init__(self):
        # An explicit layer_types list is kept untouched.
        if self.layer_types is not None:
            return
        # Even layer indices use sliding-window attention, odd ones full attention.
        pattern = ("sliding_attention", "full_attention")
        self.layer_types = [pattern[idx % 2] for idx in range(self.num_hidden_layers)]
def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float,
original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]:
"""YARN inv_freq + attention scaling (matches transformers)."""
dim = head_dim
def find_correction_dim(num_rotations: float) -> float:
return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
2 * math.log(base)
)
def find_correction_range() -> tuple[float, float]:
low = find_correction_dim(beta_fast)
high = find_correction_dim(beta_slow)
if truncate:
low = math.floor(low)
high = math.ceil(high)
return max(low, 0), min(high, dim - 1)
def linear_ramp_factor(min_: float, max_: float, n: int) -> torch.Tensor:
if min_ == max_:
max_ += 0.001
linear = (torch.arange(n, dtype=torch.float32, device=device) - min_) / (max_ - min_)
return torch.clamp(linear, 0, 1)
def get_mscale(scale: float) -> float:
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
attention_scaling = get_mscale(factor)
pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
low, high = find_correction_range()
extrap_factor = 1 - linear_ramp_factor(low, high, dim // 2)
inv_freq = inv_freq_interpolation * (1 - extrap_factor) + inv_freq_extrapolation * extrap_factor
return inv_freq, attention_scaling
def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
inv_freq_e = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
pos_e = position_ids[:, None, :].float()
freqs = (inv_freq_e @ pos_e).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = (emb.cos() * attention_scaling).to(dtype).unsqueeze(1)
sin = (emb.sin() * attention_scaling).to(dtype).unsqueeze(1)
sin_split = sin.shape[-1] // 2
return cos, sin[..., :sin_split], -sin[..., sin_split:]
def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor,
                          attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor:
    """Run attention with one learned sink logit per head.

    A sink enlarges each row's softmax denominator without contributing to the
    output. This is emulated by appending a single all-zero key/value position
    and placing the sink logit in the additive mask at that extra column.
    """
    # Expand KV heads by hand only when the torch build lacks GQA support.
    if num_kv_groups > 1 and not TORCH_HAS_GQA:
        k, v = (t.repeat_interleave(num_kv_groups, dim=1) for t in (k, v))

    batch, _, q_len, head_dim = q.shape
    kv_heads = k.shape[1]
    kv_len = k.shape[-2]

    # Extra zero key/value slot that carries the sink.
    k = torch.cat([k, k.new_zeros(batch, kv_heads, 1, head_dim)], dim=-2)
    v = torch.cat([v, v.new_zeros(batch, kv_heads, 1, head_dim)], dim=-2)

    sink_bias = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(batch, num_heads, q_len, 1)
    if attention_mask is None:
        token_bias = q.new_zeros(batch, num_heads, q_len, kv_len)
    else:
        token_bias = attention_mask[..., :kv_len].expand(batch, num_heads, q_len, kv_len)
    mask = torch.cat([token_bias, sink_bias], dim=-1)

    attend = optimized_attention_for_device(q.device, mask=True, small_input=True)
    return attend(q, k, v, num_heads, mask=mask, skip_reshape=True, enable_gqa=True)
class GptOssAttention(nn.Module):
    """Grouped-query self-attention with rotary embeddings and per-head sinks."""

    def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.layer_idx = layer_idx
        self.layer_type = config.layer_types[layer_idx]
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.head_dim = config.head_dim
        self.hidden_size = config.hidden_size
        # The window is recorded only for sliding-attention layers.
        self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None

        q_width = self.num_heads * self.head_dim
        kv_width = self.num_kv_heads * self.head_dim
        linear_kwargs = {"bias": config.attention_bias, "device": device, "dtype": dtype}
        self.q_proj = ops.Linear(config.hidden_size, q_width, **linear_kwargs)
        self.k_proj = ops.Linear(config.hidden_size, kv_width, **linear_kwargs)
        self.v_proj = ops.Linear(config.hidden_size, kv_width, **linear_kwargs)
        self.o_proj = ops.Linear(q_width, config.hidden_size, **linear_kwargs)
        # One sink logit per attention head.
        self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype))

    def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor:
        """Project to q/k/v, apply rope, attend with sinks, then project back."""
        batch, seq_len, _ = hidden_states.shape

        def _heads(projection, head_count):
            # (B, S, heads * head_dim) -> (B, heads, S, head_dim)
            return projection(hidden_states).view(batch, seq_len, head_count, self.head_dim).transpose(1, 2)

        q = _heads(self.q_proj, self.num_heads)
        k = _heads(self.k_proj, self.num_kv_heads)
        v = _heads(self.v_proj, self.num_kv_heads)
        q, k = apply_rope(q, k, freqs_cis)
        attended = _attention_with_sinks(q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups)
        return self.o_proj(attended)
# Mixture of Experts
class GptOssTopKRouter(nn.Module):
    """Linear router that selects the top-k experts for every token."""

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_local_experts
        router_shape = (config.num_local_experts, config.hidden_size)
        self.weight = nn.Parameter(torch.empty(*router_shape, device=device, dtype=dtype))
        self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype))

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(scores, expert_indices)`` for the top-k experts of each token."""
        router_w = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False)
        router_b = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False)
        expert_logits = F.linear(hidden_states, router_w, router_b)
        best_logits, best_experts = torch.topk(expert_logits, self.top_k, dim=-1)
        # Normalise over the selected k logits only, not over all experts.
        weights = F.softmax(best_logits, dim=-1, dtype=best_logits.dtype)
        return weights, best_experts
class GptOssExperts(nn.Module):
    """Expert feed-forward banks (gate/up and down projections) of one MoE layer."""

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.num_experts = config.num_local_experts
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.alpha = config.moe_alpha
        self.limit = config.moe_limit

        bank_kwargs = {"num_experts": self.num_experts, "bias": True, "device": device, "dtype": dtype}
        # Gate and up projections share one bank; their outputs are interleaved.
        self.gate_up_proj = ops.MoEExperts(in_features=self.hidden_size, out_features=2 * self.intermediate_size, **bank_kwargs)
        self.down_proj = ops.MoEExperts(in_features=self.intermediate_size, out_features=self.hidden_size, **bank_kwargs)

    def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
        """Clamped gated activation over interleaved gate (even) / up (odd) columns.

        Returns ``glu * (up + 1)`` where ``glu = gate * sigmoid(alpha * gate)``.
        """
        gate = gate_up[..., ::2].clamp(max=self.limit)
        up = gate_up[..., 1::2].clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        # glu + up * glu
        return torch.addcmul(glu, up, glu)

    def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
        """Run every selected expert on its tokens and sum the weighted outputs.

        ``hidden_states`` is indexed as (tokens, hidden). ``router_indices`` and
        ``routing_weights`` are indexed as (tokens, top_k).
        """
        num_tokens = hidden_states.shape[0]
        top_k = router_indices.shape[-1]
        hidden = hidden_states.shape[-1]

        # One output row per (token, top-k slot) pair; reduced over slots at the end.
        slot_outputs = torch.zeros((num_tokens * top_k, hidden), dtype=hidden_states.dtype, device=hidden_states.device)

        # (experts, top_k, tokens) one-hot of the routing decisions.
        expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
        active_experts = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        with self.gate_up_proj.bank_resident(hidden_states) as gate_up_bank, \
                self.down_proj.bank_resident(hidden_states) as down_bank:
            for entry in active_experts:
                expert_idx = int(entry.item())
                slot_pos, token_idx = torch.where(expert_mask[expert_idx])
                tokens = hidden_states[token_idx]
                gate_up = gate_up_bank.expert_linear(tokens, expert_idx)
                activated = self._apply_gate(gate_up)
                expert_out = down_bank.expert_linear(activated, expert_idx)
                scaled = expert_out * routing_weights[token_idx, slot_pos, None]
                slot_outputs[token_idx * top_k + slot_pos] = scaled.to(slot_outputs.dtype)

        return slot_outputs.view(num_tokens, top_k, hidden).sum(dim=1)
class GptOssMLP(nn.Module):
    """Mixture-of-experts block: a top-k router followed by the expert banks."""

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
        self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its experts and return a tensor of the input shape."""
        batch, seq_len, hidden = hidden_states.shape
        # Routing works per token, so batch and sequence are flattened together.
        tokens = hidden_states.reshape(-1, hidden)
        routing_weights, expert_ids = self.router(tokens)
        mixed = self.experts(tokens, expert_ids, routing_weights)
        return mixed.reshape(batch, seq_len, hidden)
# Decoder layer + model
class GptOssDecoderLayer(nn.Module):
    """Pre-norm decoder layer: self-attention then MoE MLP, each with a residual."""

    def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
        super().__init__()
        norm_kwargs = {"eps": config.rms_norm_eps, "device": device, "dtype": dtype}
        self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
        self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops)
        self.input_layernorm = RMSNorm(config.hidden_size, **norm_kwargs)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, **norm_kwargs)
        # Used in forward() to pick this layer's mask from the attention_masks dict.
        self.layer_type = config.layer_types[layer_idx]

    def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor:
        """Apply the attention and MLP sub-blocks with residual connections."""
        attn_out = self.self_attn(self.input_layernorm(x), attention_masks[self.layer_type], freqs_cis)
        x = x + attn_out
        mlp_out = self.mlp(self.post_attention_layernorm(x))
        return x + mlp_out
def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
neg = torch.finfo(dtype).min
mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1)
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
if key_padding_mask is not None:
kp = key_padding_mask.to(dtype=dtype)
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
mask = mask + kp
return mask
def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
neg = torch.finfo(dtype).min
i = torch.arange(S, device=device).view(-1, 1)
j = torch.arange(S, device=device).view(1, -1)
keep = (j <= i) & (j > i - window)
mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device))
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
if key_padding_mask is not None:
kp = key_padding_mask.to(dtype=dtype)
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
mask = mask + kp
return mask
class GptOssModel(nn.Module):
    """GPT-OSS decoder with multi-layer hidden-state capture + early exit."""

    def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.config = config
        self.dtype = dtype
        self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
        self.layers = nn.ModuleList(
            [
                GptOssDecoderLayer(config, i, device=device, dtype=dtype, ops=ops)
                for i in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
        # Always build on CPU so the buffer survives meta-device construction.
        inv_freq, attn_scaling = _yarn_inv_freq(
            head_dim=config.head_dim,
            base=config.rope_theta,
            factor=config.rope_factor,
            beta_fast=config.rope_beta_fast,
            beta_slow=config.rope_beta_slow,
            original_max_position_embeddings=config.original_max_position_embeddings,
            truncate=config.rope_truncate,
            device=torch.device("cpu"),
        )
        # Non-persistent: recomputed from the config, never read from or written to a state dict.
        self.register_buffer("rope_inv_freq", inv_freq, persistent=False)
        self.rope_attention_scaling = float(attn_scaling)

    @property
    def num_layers(self) -> int:
        """Number of decoder layers declared by the config."""
        return self.config.num_hidden_layers

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def _build_attention_masks(
        self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype, device,
    ) -> dict[str, torch.Tensor]:
        """Build the additive masks keyed by layer type.

        The "full_attention" mask is always built; the "sliding_attention" mask is
        only built when at least one configured layer uses it.
        """
        full = _make_full_causal_mask(B, S, attention_mask, dtype, device)
        masks = {"full_attention": full}
        if any(t == "sliding_attention" for t in self.config.layer_types):
            masks["sliding_attention"] = _make_sliding_causal_mask(
                B, S, self.config.sliding_window, attention_mask, dtype, device
            )
        return masks

    def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None,
                capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]:
        """Run the decoder and return either captured layer outputs or the final normed state.

        Args:
            input_ids: (B, S) token ids.
            attention_mask: optional key-padding mask forwarded to the mask builders.
            capture_layers: zero-based layer indices whose outputs should be returned.
                When given, layers after ``max(capture_layers)`` are not executed.

        Returns:
            ``{"hidden_states": [...]}`` (one tensor per requested index, in request
            order, without the final norm) when ``capture_layers`` is non-empty,
            otherwise ``{"last_hidden_state": normed_output}``.

        Raises:
            ValueError: if ``capture_layers`` contains an index outside the layer range.
        """
        B, S = input_ids.shape
        device = input_ids.device
        dtype = self.dtype
        hidden_states = self.embed_tokens(input_ids, out_dtype=dtype)
        position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
        freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype)
        attn_masks = self._build_attention_masks(B, S, attention_mask, dtype, device)

        capture_layers = list(capture_layers) if capture_layers else None
        if capture_layers:
            n_layers = len(self.layers)
            invalid = [idx for idx in capture_layers if not 0 <= idx < n_layers]
            if invalid:
                # Previously such indices silently produced ``None`` entries in the result.
                raise ValueError(
                    f"capture_layers contains indices outside [0, {n_layers}): {invalid}"
                )
            max_layer = max(capture_layers)
            # A layer may be requested more than once; remember every output slot it fills.
            wanted: Optional[dict[int, List[int]]] = {}
            for pos, idx in enumerate(capture_layers):
                wanted.setdefault(idx, []).append(pos)
            captured: Optional[List[Optional[torch.Tensor]]] = [None] * len(capture_layers)
        else:
            max_layer = self.config.num_hidden_layers - 1
            wanted = None
            captured = None

        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, attn_masks, freqs_cis)
            if wanted is not None:
                for pos in wanted.get(i, ()):
                    captured[pos] = hidden_states
            if i >= max_layer:
                # Early exit: nothing past the deepest requested layer is needed.
                break

        if captured is not None:
            return {"hidden_states": captured}
        return {"last_hidden_state": self.norm(hidden_states)}
# Lens chat-template constants (verbatim from the reference pipeline).
# Instruction text rendered into the "developer" turn by _lens_render_chat.
_LENS_CHAT_SYSTEM = (
    "Describe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background."
)
# Fixed text rendered into the assistant "analysis" turn by _lens_render_chat.
_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description."
# Default number of leading token positions dropped from the encoder output
# (presumably the length of the fixed template prefix — TODO confirm).
LENS_TXT_OFFSET = 97
# Default zero-based decoder layer indices whose hidden states are captured and stacked.
LENS_SELECTED_LAYERS = (5, 11, 17, 23)
# Rendered chat prompts are truncated to at most this many token ids.
LENS_MAX_TOKENS = 512
# The reference GPT-OSS Harmony template injects today's date here; this file
# interpolates a fixed string into the "Current date:" line instead.
_LENS_CHAT_DATE = "2026-05-23"
def _lens_render_chat(prompt: str) -> str:
    """Render the Lens prompt in GPT-OSS Harmony format.

    The result is a system turn, a developer turn carrying the Lens instruction,
    the user turn with ``prompt``, a fixed assistant "analysis" turn, and an open
    assistant "final" turn header.
    """
    segments = [
        # system turn
        "<|start|>system<|message|>",
        "You are ChatGPT, a large language model trained by OpenAI.\n",
        "Knowledge cutoff: 2024-06\n",
        f"Current date: {_LENS_CHAT_DATE}\n\n",
        "Reasoning: medium\n\n",
        "# Valid channels: analysis, commentary, final. ",
        "Channel must be included for every message.<|end|>",
        # developer turn
        "<|start|>developer<|message|># Instructions\n\n",
        f"{_LENS_CHAT_SYSTEM}\n\n<|end|>",
        # user turn
        f"<|start|>user<|message|>{prompt}<|end|>",
        # assistant analysis turn, then the open final-channel header
        "<|start|>assistant<|channel|>analysis<|message|>",
        f"{_LENS_CHAT_ASSISTANT_THINKING}<|end|>",
        "<|start|>assistant<|channel|>final<|message|>",
    ]
    return "".join(segments)
# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table).
# Used as the fill value when shorter prompts in a batch are padded to the longest one.
_LENS_PAD_TOKEN_ID = 199999  # <|endoftext|>
class _GptOssRawTokenizer:
    """Raw ``tokenizers.Tokenizer`` wrapper.

    The tokenizer JSON ships as a byte tensor inside the encoder checkpoint
    (``tokenizer_json`` key) rather than as a committed file. ``sd.py`` extracts
    it and passes it here via ``tokenizer_data``.
    """

    def __init__(self, tokenizer_json_bytes=None, **kwargs):
        """Build the tokenizer from UTF-8 JSON given as bytes or as a byte tensor.

        Raises:
            ValueError: if no tokenizer JSON was supplied.
        """
        from tokenizers import Tokenizer

        payload = tokenizer_json_bytes
        if isinstance(payload, torch.Tensor):
            # Each tensor element is one byte value of the serialized JSON.
            payload = bytes(payload.tolist())
        if payload is None:
            raise ValueError(
                "Lens tokenizer requires the ``tokenizer_json`` byte tensor in the "
                "encoder state dict. Re-bundle the encoder via bundle_te.py so it "
                "embeds the tokenizer."
            )
        json_text = payload.decode("utf-8")
        self.tokenizer = Tokenizer.from_str(json_text)

    @classmethod
    def from_pretrained(cls, tokenizer_data, **kwargs):
        """Alternate constructor that takes the tokenizer JSON as its first argument."""
        return cls(tokenizer_json_bytes=tokenizer_data, **kwargs)

    def __call__(self, text):
        """Encode ``text`` without adding special tokens; returns ``{"input_ids": [...]}``."""
        encoding = self.tokenizer.encode(text, add_special_tokens=False)
        return {"input_ids": encoding.ids}

    def get_vocab(self):
        """Return the wrapped tokenizer's vocabulary."""
        return self.tokenizer.get_vocab()

    def convert_tokens_to_ids(self, tokens):
        """Map each token string to its id via the wrapped tokenizer."""
        lookup = self.tokenizer.token_to_id
        return [lookup(token) for token in tokens]

    def decode(self, ids, **kwargs):
        """Decode ids to text; special tokens are kept unless ``skip_special_tokens`` is passed."""
        skip_special = kwargs.get("skip_special_tokens", False)
        return self.tokenizer.decode(ids, skip_special_tokens=skip_special)
class LensGptOssTokenizer(sd1_clip.SDTokenizer):
    """SDTokenizer for Lens: renders the prompt into the chat template and emits unit-weight ids."""

    tokenizer_json_data = None

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_json = tokenizer_data.get("tokenizer_json", None)
        # Kept so state_dict() can hand the tokenizer JSON back out.
        self.tokenizer_json_data = tokenizer_json
        base_kwargs = dict(
            embedding_directory=embedding_directory,
            pad_with_end=False,
            embedding_size=2880,
            embedding_key="gpt_oss",
            tokenizer_class=_GptOssRawTokenizer,
            has_start_token=False,
            has_end_token=False,
            pad_to_max_length=False,
            max_length=99999999,
            min_length=1,
            pad_left=False,
            disable_weights=True,
            tokenizer_data=tokenizer_data,
        )
        super().__init__(tokenizer_json, **base_kwargs)
        self.pad_token_id = _LENS_PAD_TOKEN_ID

    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
        """Return one batch of ``(token_id, 1.0)`` pairs for the chat-rendered prompt.

        A blank prompt yields ``[[]]``; longer prompts are cut to LENS_MAX_TOKENS ids.
        """
        # Empty prompt -> empty list; encode_token_weights returns zeros (uncond).
        if not (text and text.strip()):
            return [[]]
        token_ids = self.tokenizer(_lens_render_chat(text))["input_ids"]
        token_ids = token_ids[:LENS_MAX_TOKENS]
        return [[(int(token_id), 1.0) for token_id in token_ids]]

    def state_dict(self):
        """Expose the tokenizer JSON under ``tokenizer_json`` when one was provided."""
        data = self.tokenizer_json_data
        if data is None:
            return {}
        return {"tokenizer_json": data}
class LensTokenizer(sd1_clip.SD1Tokenizer):
    """SD1Tokenizer wrapper that registers LensGptOssTokenizer under the name ``gpt_oss``."""

    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(
            name="gpt_oss",
            tokenizer=LensGptOssTokenizer,
            tokenizer_data=tokenizer_data,
            embedding_directory=embedding_directory,
        )
class LensGptOssClipModel(nn.Module):
    """SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""

    def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
        super().__init__()
        options = dict(model_options or {})
        ops = options.get("custom_operations")
        if ops is None:
            # No caller-supplied ops: derive them from the (possibly empty) quantization metadata.
            quant_config = options.get("quantization_metadata") or {}
            ops = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
        self.operations = ops
        self.config = GptOss20BConfig(**options.get("gpt_oss_config", {}))
        self.selected_layers = tuple(options.get("selected_layers", LENS_SELECTED_LAYERS))
        self.txt_offset = int(options.get("txt_offset", LENS_TXT_OFFSET))
        self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=ops)
        self.num_layers = self.config.num_hidden_layers
        self.dtype = dtype
        self.execution_device = None
        self._pad_token_id = _LENS_PAD_TOKEN_ID

    def set_clip_options(self, options):
        """Update the execution device when ``options`` carries one; otherwise keep the current value."""
        self.execution_device = options.get("execution_device", self.execution_device)

    def reset_clip_options(self):
        """Clear the execution device override."""
        self.execution_device = None

    def _gather_tokens(self, token_weight_pairs):
        """Right-pad the token ids of every batch entry and return ``(ids, mask)`` long tensors.

        ``mask`` is 1 on real tokens and 0 on padding.
        """
        rows = [[int(pair[0]) for pair in batch] for batch in token_weight_pairs]
        device = self.execution_device
        shape = (len(rows), max(len(row) for row in rows))
        ids = torch.full(shape, self._pad_token_id, dtype=torch.long, device=device)
        mask = torch.zeros(shape, dtype=torch.long, device=device)
        for row_idx, row in enumerate(rows):
            length = len(row)
            ids[row_idx, :length] = torch.tensor(row, dtype=torch.long, device=device)
            mask[row_idx, :length] = 1
        return ids, mask

    def encode_token_weights(self, token_weight_pairs):
        """Encode token batches into stacked multi-layer features.

        Returns ``(features, None, extra)`` where ``features`` is (B, S', L * H) after
        dropping the first ``txt_offset`` positions and ``extra`` holds the matching
        attention mask plus the number of stacked layers.
        """
        # Empty negative: emit zero-length features + zero mask
        if not any(len(batch) for batch in token_weight_pairs):
            device = self.execution_device
            batch_size = len(token_weight_pairs)
            n_layers = len(self.selected_layers)
            hidden = self.config.hidden_size
            empty_features = torch.zeros(batch_size, 0, n_layers * hidden, dtype=self.dtype, device=device)
            empty_mask = torch.zeros(batch_size, 0, dtype=torch.long, device=device)
            return empty_features, None, {"attention_mask": empty_mask, "num_layers_stacked": n_layers}

        input_ids, attn_mask = self._gather_tokens(token_weight_pairs)
        outputs = self.transformer(input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers)
        # list of L x [B, S, H] -> [B, S, L, H]
        stacked = torch.stack(outputs["hidden_states"], dim=2)

        offset = self.txt_offset
        if stacked.shape[1] > offset:
            stacked = stacked[:, offset:].contiguous()
            mask_trim = attn_mask[:, offset:]
        else:
            # Sequence is no longer than the offset: nothing remains after trimming.
            stacked = stacked[:, :0]
            mask_trim = attn_mask[:, :0]

        batch_size, seq_len, n_layers, hidden = stacked.shape
        features = stacked.reshape(batch_size, seq_len, n_layers * hidden)
        return features, None, {"attention_mask": mask_trim, "num_layers_stacked": n_layers}

    def load_sd(self, sd):
        """Load ``sd`` into the transformer non-strictly, assigning tensors rather than copying."""
        return self.transformer.load_state_dict(sd, strict=False, assign=True)
class LensTEModel(sd1_clip.SD1ClipModel):
    """SD1ClipModel wrapper that registers LensGptOssClipModel under the name ``gpt_oss``."""

    def __init__(self, device="cpu", dtype=None, model_options=None):
        options = model_options or {}
        super().__init__(
            device=device,
            dtype=dtype,
            name="gpt_oss",
            clip_model=LensGptOssClipModel,
            model_options=options,
        )
def lens_te(dtype_llama=None, llama_quantization_metadata=None):
    """Return a LensTEModel subclass with a fallback dtype and quantization metadata baked in.

    ``dtype_llama`` is only used when the constructor receives no dtype;
    ``llama_quantization_metadata``, when given, is stored under the
    ``quantization_metadata`` key of a copy of the model options.
    """

    class LensTEModel_(LensTEModel):
        def __init__(self, device="cpu", dtype=None, model_options=None):
            options = dict(model_options) if model_options else {}
            if llama_quantization_metadata is not None:
                options["quantization_metadata"] = llama_quantization_metadata
            if dtype is None:
                dtype = dtype_llama
            super().__init__(device=device, dtype=dtype, model_options=options)

    return LensTEModel_

View File

@ -0,0 +1,104 @@
import torch
from comfy import sd1_clip
from .lumina2 import Gemma2BTokenizer, LuminaModel
import comfy.text_encoders.llama
class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel):
    """SDClipModel configured to run Gemma2 2B as the PixelDiT text encoder."""

    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
        quant_metadata = model_options.get("llama_quantization_metadata", None)
        if quant_metadata is not None:
            # Work on a copy so the caller's options dict is left untouched.
            model_options = model_options.copy()
            model_options["quantization_metadata"] = quant_metadata

        super().__init__(
            device=device,
            layer=layer,
            layer_idx=layer_idx,
            textmodel_json_config={},
            dtype=dtype,
            special_tokens={"start": 2, "pad": 0},
            layer_norm_hidden_state=False,
            model_class=comfy.text_encoders.llama.Gemma2_2B,
            # The same flag controls both using and returning attention masks.
            enable_attention_masks=attention_mask,
            return_attention_masks=attention_mask,
            model_options=model_options,
        )
# Instruction prefix prepended to the user's text before tokenization
# (unless the text already starts with _PIXELDIT_CHI_PROMPT_DETECT_PREFIX).
_PIXELDIT_CHI_PROMPT = (
    'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions '
    "suitable for image generation. Evaluate the level of detail in the user prompt:\n"
    "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, "
    "and spatial relationships to create vivid and concrete scenes.\n"
    "- If the prompt is already detailed, refine and enhance the existing details slightly without "
    "overcomplicating.\n"
    "Here are examples of how to transform or refine prompts:\n"
    "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, "
    "sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n"
    "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring "
    "glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus "
    "passing by towering glass skyscrapers.\n"
    "Please generate only the enhanced description for the prompt below and avoid including any "
    "additional commentary or evaluations:\n"
    "User Prompt: "
)
# Number of embeddings kept by the encoder (first position + last 299) and the
# minimum tokenized length used for blank prompts.
_PIXELDIT_MAX_LENGTH = 300
# Leading text used to detect that the instruction prefix is already present in the input.
_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"'
class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer):
    """Gemma2 tokenizer wrapper that prepends the PixelDiT instruction prefix to prompts."""

    def __init__(self, embedding_directory=None, tokenizer_data=None):
        data = {} if tokenizer_data is None else tokenizer_data
        super().__init__(
            embedding_directory=embedding_directory,
            tokenizer_data=data,
            name="gemma2_2b",
            tokenizer=Gemma2BTokenizer,
        )

    def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
        """Tokenize ``text`` with weights disabled.

        Blank text is tokenized as "" with a minimum length of _PIXELDIT_MAX_LENGTH.
        Otherwise the instruction prefix is prepended (unless already present) and the
        first ``gemma2_2b`` row is cut to the prefix token count plus
        _PIXELDIT_MAX_LENGTH - 2.
        """
        if not text.strip():
            return super().tokenize_with_weights(
                "",
                return_word_ids=return_word_ids,
                disable_weights=True,
                min_length=_PIXELDIT_MAX_LENGTH,
            )

        prefix_ids = self.gemma2_2b.tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"]
        max_length_all = len(prefix_ids) + _PIXELDIT_MAX_LENGTH - 2

        if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX):
            combined = text
        else:
            combined = _PIXELDIT_CHI_PROMPT + text

        out = super().tokenize_with_weights(
            combined,
            return_word_ids=return_word_ids,
            disable_weights=True,
            min_length=max_length_all,
        )
        first_row = out["gemma2_2b"][0]
        out["gemma2_2b"] = [first_row[:max_length_all]]
        return out

    def untokenize(self, token_weight_pair):
        """Delegate to the wrapped Gemma2 tokenizer."""
        return self.gemma2_2b.untokenize(token_weight_pair)

    def state_dict(self):
        """Delegate to the wrapped Gemma2 tokenizer."""
        return self.gemma2_2b.state_dict()
class PixelDiTGemma2TE(LuminaModel):
    """Text encoder that trims the encoded sequence to _PIXELDIT_MAX_LENGTH positions."""

    # PixelDiT's select_index: keep BOS + last 299 embeddings of the padded sequence.
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(
            device=device,
            dtype=dtype,
            name="gemma2_2b",
            clip_model=PixelDiTGemma2_2BModel,
            model_options=model_options,
        )

    def encode_token_weights(self, token_weight_pairs):
        """Encode, then keep the first position plus the trailing positions when too long.

        The attention mask in the extras dict (if any) is trimmed the same way.
        Returns a 2-tuple or 3-tuple, mirroring what the parent returned.
        """
        outputs = super().encode_token_weights(token_weight_pairs)
        cond = outputs[0]
        pooled = outputs[1]
        extra = outputs[2] if len(outputs) > 2 else None

        if cond.shape[1] > _PIXELDIT_MAX_LENGTH:
            head = cond[:, :1]
            tail = cond[:, -(_PIXELDIT_MAX_LENGTH - 1):]
            cond = torch.cat([head, tail], dim=1)
            if extra is not None and "attention_mask" in extra:
                mask = extra["attention_mask"]
                mask_head = mask[..., :1]
                mask_tail = mask[..., -(_PIXELDIT_MAX_LENGTH - 1):]
                extra["attention_mask"] = torch.cat([mask_head, mask_tail], dim=-1)

        if extra is None:
            return cond, pooled
        return cond, pooled, extra
def pixeldit_te(dtype_llama=None, llama_quantization_metadata=None):
    """Return a PixelDiTGemma2TE subclass with a dtype override and quantization metadata baked in.

    When ``dtype_llama`` is given it replaces whatever dtype the constructor
    receives; ``llama_quantization_metadata``, when given, is stored in a copy of
    the model options under ``llama_quantization_metadata``.
    """

    class PixelDiTTE_(PixelDiTGemma2TE):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            options = model_options
            if llama_quantization_metadata is not None:
                options = options.copy()
                options["llama_quantization_metadata"] = llama_quantization_metadata
            effective_dtype = dtype if dtype_llama is None else dtype_llama
            super().__init__(device=device, dtype=effective_dtype, model_options=options)

    return PixelDiTTE_

View File

@ -762,10 +762,17 @@ class Accumulation(ComfyTypeIO):
@comfytype(io_type="LOAD3D_CAMERA")
class Load3DCamera(ComfyTypeIO):
class CameraInfo(TypedDict):
position: dict[str, float | int]
target: dict[str, float | int]
zoom: int
cameraType: str
# Coordinate system: right-handed, Y-up, camera looks down -Z
position: dict[str, float | int] # scene units
target: dict[str, float | int] # scene units; OrbitControls focus point
zoom: float | int # dimensionless, 1 = 100%
cameraType: str # 'perspective' | 'orthographic'
quaternion: NotRequired[dict[str, float | int]] # normalized, dimensionless; camera world rotation
fov: NotRequired[float | int] # degrees, vertical FOV (perspective only)
aspect: NotRequired[float | int] # width / height (perspective only)
near: NotRequired[float | int] # scene units
far: NotRequired[float | int] # scene units
frustum: NotRequired[dict[str, float | int]] # orthographic only: {left, right, top, bottom} in scene units
Type = CameraInfo

View File

@ -0,0 +1,32 @@
from pydantic import BaseModel, Field
class CreateSwitchXRequest(BaseModel):
    # Request body for creating a SwitchX generation.
    # Required fields use Field(...); optional ones default to None.
    generation_type: str = Field(...)
    source_uri: str = Field(...)
    alpha_mode: str = Field(...)
    # Optional text prompt, capped at 2000 characters.
    prompt: str | None = Field(None, max_length=2000)
    reference_image_uri: str | None = Field(None)
    alpha_uri: str | None = Field(None)
    # Defaults to 1080 (presumably output height in pixels — TODO confirm).
    max_resolution: int = Field(1080)
    callback_url: str | None = Field(None)
    # When provided, must be between 1 and 256 characters long.
    idempotency_key: str | None = Field(None, max_length=256, min_length=1)
class SwitchXOutputUrls(BaseModel):
    # Output URLs attached to a SwitchX status response; every entry is optional.
    render: str | None = Field(None)
    source: str | None = Field(None)
    alpha: str | None = Field(None)
class SwitchXStatusResponse(BaseModel):
    # Status payload for a SwitchX generation; only id and status are required.
    id: str = Field(...)
    status: str = Field(...)
    # NOTE(review): units/range of progress are not visible here — confirm against the API.
    progress: int | None = Field(None)
    generation_type: str | None = Field(None)
    alpha_mode: str | None = Field(None)
    output: SwitchXOutputUrls | None = Field(None)
    error: str | None = Field(None)
    # Timestamps are kept as plain strings (format not validated here).
    created_at: str | None = Field(None)
    modified_at: str | None = Field(None)
    completed_at: str | None = Field(None)

View File

@ -158,8 +158,9 @@ class SeedanceCreateAssetResponse(BaseModel):
class SeedanceVirtualLibraryCreateAssetRequest(BaseModel):
url: str = Field(..., description="Publicly accessible URL of the image asset to upload.")
url: str = Field(..., description="Publicly accessible URL of the asset to upload.")
hash: str = Field(..., description="Dedup key. Re-submitting the same hash returns the existing asset id.")
asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.")
# Dollars per 1K tokens, keyed by (model_id, has_video_input).

View File

@ -0,0 +1,46 @@
"""Pydantic models for the Krea image-generation API."""
from pydantic import BaseModel, Field
class KreaMoodboard(BaseModel):
    # Moodboard reference: a required id plus a strength.
    id: str = Field(...)
    # Strength defaults to 0.35 and is validated to lie in [-0.5, 1.5].
    strength: float = Field(default=0.35, ge=-0.5, le=1.5)
class KreaImageStyleReference(BaseModel):
    # Image style reference: a required strength plus an optional image URL.
    # Strength is validated to lie in [-2.0, 2.0].
    strength: float = Field(..., ge=-2.0, le=2.0)
    url: str | None = Field(default=None)
class KreaGenerateImageRequest(BaseModel):
    # Image-generation request; prompt, aspect_ratio and resolution are required.
    prompt: str = Field(...)
    aspect_ratio: str = Field(...)
    resolution: str = Field(...)
    seed: int | None = Field(default=None)
    # Free-form string defaulting to "medium"; accepted values are not validated here.
    creativity: str = Field(default="medium")
    moodboards: list[KreaMoodboard] | None = Field(default=None)
    image_style_references: list[KreaImageStyleReference] | None = Field(default=None)
class KreaJobResult(BaseModel):
    # Result section of a job; both fields are optional.
    urls: list[str] | None = Field(default=None)
    style_id: str | None = Field(default=None)
class KreaJob(BaseModel):
    # Job record: job_id, status and created_at are required.
    job_id: str = Field(...)
    status: str = Field(...)
    created_at: str = Field(...)
    # completed_at and result may be absent (presumably until the job finishes — TODO confirm).
    completed_at: str | None = Field(default=None)
    result: KreaJobResult | None = Field(default=None)
class KreaAssetResponse(BaseModel):
    # Asset record: id, image_url and uploaded_at are required.
    id: str = Field(...)
    image_url: str = Field(...)
    uploaded_at: str = Field(...)
    # Dimensions and byte size are typed as optional floats, not ints.
    width: float | None = Field(default=None)
    height: float | None = Field(default=None)
    size_bytes: float | None = Field(default=None)
    mime_type: str | None = Field(default=None)

View File

@ -155,7 +155,7 @@ class ClaudeNode(IO.ComfyNode):
return IO.Schema(
node_id="ClaudeNode",
display_name="Anthropic Claude",
category="api node/text/Anthropic",
category="text/partner/Anthropic",
essentials_category="Text Generation",
description="Generate text responses with Anthropic's Claude models. "
"Provide a text prompt and optionally one or more images for multimodal context.",

View File

@ -0,0 +1,404 @@
from fractions import Fraction
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl, Types
from comfy_api_nodes.apis.beeble import (
CreateSwitchXRequest,
SwitchXStatusResponse,
)
from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
convert_mask_to_image,
download_url_as_bytesio,
download_url_to_image_tensor,
download_url_to_video_output,
downscale_image_tensor,
downscale_video_to_max_pixels,
poll_op,
sync_op,
upload_image_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
validate_video_frame_count,
)
# Pixel budget passed to the downscale/upload helpers for sources, masks and reference images.
_MAX_PIXELS = 2_770_000
# Maximum frame count accepted for a source video.
_MAX_FRAMES = 240
# Maximum prompt length, enforced through validate_string.
_MAX_PROMPT_LEN = 2000
def _validate_inputs(prompt: str | None, reference_image: Input.Image | None) -> str | None:
    """Check that a prompt or a reference image is present and return the stripped prompt.

    Returns ``None`` when the prompt is blank but a reference image was given.

    Raises:
        ValueError: if the prompt is blank and no reference image is supplied.
    """
    cleaned = prompt.strip() if prompt else ""
    if cleaned:
        validate_string(cleaned, strip_whitespace=False, max_length=_MAX_PROMPT_LEN)
        return cleaned
    if reference_image is None:
        raise ValueError("At least one of 'prompt' or 'reference_image' must be provided.")
    return None
async def _upload_mask_as_image(
    cls: type[IO.ComfyNode],
    mask: Input.Image,
    *,
    wait_label: str,
) -> str:
    """Upload the first frame of a MASK, (H, W) or (N, H, W), as a PNG and return the upload result."""
    batched = mask.unsqueeze(0) if mask.dim() == 2 else mask
    # Only the first frame is used, even when a batch is supplied.
    first_frame = convert_mask_to_image(batched[:1])
    return await upload_image_to_comfyapi(
        cls,
        first_frame,
        mime_type="image/png",
        wait_label=wait_label,
        total_pixels=_MAX_PIXELS,
    )
async def _upload_mask_batch_as_video(
    cls: type[IO.ComfyNode],
    mask: Input.Image,
    *,
    frame_rate: Fraction,
    source_frame_count: int,
    wait_label: str,
) -> str:
    """Encode a MASK batch (N, H, W) as a grayscale H.264 MP4 at frame_rate and upload.

    The matte is always downscaled to the pixel budget so it stays within Beeble's limit and
    keeps the same dimensions as the (similarly downscaled) source: both use the same algorithm
    from the same starting dimensions, and downscaling is a no-op when already within budget.

    Raises:
        ValueError: if the number of mask frames differs from ``source_frame_count``.
    """
    frames = mask.unsqueeze(0) if mask.dim() == 2 else mask
    mask_frame_count = frames.shape[0]
    if mask_frame_count != source_frame_count:
        raise ValueError(
            f"Custom alpha video frame count ({mask_frame_count}) does not match the "
            f"source video frame count ({source_frame_count}). The Beeble API requires "
            "one mask per source frame."
        )
    images = downscale_image_tensor(convert_mask_to_image(frames), _MAX_PIXELS)
    components = Types.VideoComponents(images=images, audio=None, frame_rate=frame_rate)
    alpha_video = InputImpl.VideoFromComponents(components)
    return await upload_video_to_comfyapi(cls, alpha_video, wait_label=wait_label)
def _alpha_mode_input(*, video: bool) -> IO.DynamicCombo.Input:
    """Build the alpha_mode DynamicCombo with mode-specific extra inputs.

    ``video`` only selects the tooltip wording of the 'select' and 'custom' mask inputs.
    """
    if video:
        select_keyframe_tooltip = "First-frame keyframe mask. Beeble propagates this across the video."
        custom_tooltip = (
            "Per-frame grayscale mask covering the entire video. "
            "Must have the same frame count as the source. "
            "Connect a MASK output from SAM3_TrackToMask or similar."
        )
    else:
        select_keyframe_tooltip = "Grayscale keyframe mask."
        custom_tooltip = "Grayscale mask to apply."

    mode_tooltip = (
        "Controls how SwitchX decides what to keep vs. regenerate. "
        "'auto' isolates the main subject automatically. "
        "'fill' regenerates the entire frame while preserving geometry. "
        "'select' propagates a first-frame keyframe across the clip. "
        "'custom' uses a per-frame alpha matte you provide."
    )
    # 'auto' and 'fill' need no extra inputs; 'select' and 'custom' each expose one mask socket.
    select_option = IO.DynamicCombo.Option(
        "select",
        [IO.Mask.Input("alpha_keyframe", tooltip=select_keyframe_tooltip)],
    )
    custom_option = IO.DynamicCombo.Option(
        "custom",
        [IO.Mask.Input("alpha_mask", tooltip=custom_tooltip)],
    )
    return IO.DynamicCombo.Input(
        "alpha_mode",
        tooltip=mode_tooltip,
        options=[
            IO.DynamicCombo.Option("auto", []),
            IO.DynamicCombo.Option("fill", []),
            select_option,
            custom_option,
        ],
    )
def _common_inputs(*, source: IO.Input, video: bool) -> list[IO.Input]:
    """Return the input list shared by both SwitchX nodes, starting with ``source``."""
    prompt_input = IO.String.Input(
        "prompt",
        multiline=True,
        default="",
        tooltip=(
            "Text description of the desired output (max 2000 chars). "
            "At least one of 'prompt' or 'reference_image' is required."
        ),
    )
    reference_input = IO.Image.Input(
        "reference_image",
        optional=True,
        tooltip=(
            "Reference image whose look (background, lighting, costume) the result "
            "should adopt. At least one of 'reference_image' or 'prompt' is required."
        ),
    )
    resolution_input = IO.Combo.Input(
        "max_resolution",
        options=["1080p", "720p"],
        default="1080p",
        tooltip="Maximum output resolution.",
    )
    seed_input = IO.Int.Input(
        "seed",
        default=0,
        min=0,
        max=2147483647,
        control_after_generate=True,
        tooltip="Seed controls whether the node should re-run; results are non-deterministic regardless of seed.",
    )
    return [
        source,
        prompt_input,
        reference_input,
        _alpha_mode_input(video=video),
        resolution_input,
        seed_input,
    ]
async def _submit_and_poll(
    cls: type[IO.ComfyNode],
    request: CreateSwitchXRequest,
) -> SwitchXStatusResponse:
    """POST the generation request, then poll its status endpoint and return the polled response."""
    create_endpoint = ApiEndpoint(path="/proxy/beeble/v1/switchx/generations", method="POST")
    created = await sync_op(
        cls,
        create_endpoint,
        response_model=SwitchXStatusResponse,
        data=request,
    )
    # The id returned by the create call addresses the status endpoint.
    status_endpoint = ApiEndpoint(path=f"/proxy/beeble/v1/switchx/generations/{created.id}")
    return await poll_op(
        cls,
        status_endpoint,
        response_model=SwitchXStatusResponse,
        status_extractor=lambda resp: resp.status,
        progress_extractor=lambda resp: resp.progress,
    )
def _require_output_url(response: SwitchXStatusResponse, name: str) -> str:
    """Return the output URL stored under attribute ``name``.

    Raises:
        RuntimeError: if the response has no output section or that URL is None.
    """
    output = response.output
    url = None if output is None else getattr(output, name)
    if url is None:
        raise RuntimeError(f"Beeble job {response.id} completed without a {name!r} output URL.")
    return url
def _alpha_url(response: SwitchXStatusResponse, mode: str) -> str | None:
    """URL of the alpha matte, or None when the mode produces no separate matte.

    'fill' selects the whole frame, so Beeble writes no alpha asset even though the status
    response still returns a (dangling) signed URL for it; fetching it 403s with S3
    AccessDenied. The other three modes ('auto', 'custom', 'select') all produce a real,
    downloadable matte.
    """
    if mode == "fill":
        return None
    output = response.output
    if output is None:
        return None
    return output.alpha
class BeebleSwitchXVideoEdit(IO.ComfyNode):
    """API node that edits a video with Beeble SwitchX and returns the render and its alpha matte."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Declare the node's inputs, outputs, hidden auth fields and price badge."""
        description = (
            "Edit a video with Beeble SwitchX. Switches anything in the scene (background, "
            "lighting, costume) while preserving the original subject's pixels and motion. "
            "Provide a reference image and/or text prompt to describe the new look. "
            "Max 240 frames, max ~2.77MP per frame."
        )
        outputs = [
            IO.Video.Output(display_name="video"),
            IO.Video.Output(
                display_name="alpha",
                tooltip="The alpha matte Beeble used. Empty for 'fill' mode, which has no separate matte.",
            ),
        ]
        hidden = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        price_badge = IO.PriceBadge(
            depends_on=IO.PriceBadgeDepends(widgets=["max_resolution"]),
            expr="""
(
$rate := widgets.max_resolution = "1080p" ? 0.429 : 0.143;
{"type":"usd","usd": $rate, "format":{"suffix":"/30 frames"}}
)
""",
        )
        return IO.Schema(
            node_id="BeebleSwitchXVideoEdit",
            display_name="Beeble SwitchX Video Edit",
            category="video/partner/Beeble",
            description=description,
            inputs=_common_inputs(source=IO.Video.Input("video"), video=True),
            outputs=outputs,
            hidden=hidden,
            is_api_node=True,
            price_badge=price_badge,
        )

    @classmethod
    async def execute(
        cls,
        video: Input.Video,
        prompt: str,
        alpha_mode: dict,
        max_resolution: str,
        seed: int,
        reference_image: Input.Image | None = None,
    ) -> IO.NodeOutput:
        """Validate inputs, upload assets, run the generation and download its outputs.

        The alpha output is ``None`` when no alpha URL is available (always in 'fill' mode).
        """
        cleaned_prompt = _validate_inputs(prompt, reference_image)
        validate_video_frame_count(video, max_frame_count=_MAX_FRAMES)
        video = downscale_video_to_max_pixels(video, _MAX_PIXELS)

        # Only 'select' and 'custom' carry a user-supplied mask to upload.
        mode = alpha_mode["alpha_mode"]
        alpha_uri: str | None = None
        if mode == "custom":
            alpha_uri = await _upload_mask_batch_as_video(
                cls,
                alpha_mode["alpha_mask"],
                frame_rate=video.get_frame_rate(),
                source_frame_count=video.get_frame_count(),
                wait_label="Uploading alpha video",
            )
        elif mode == "select":
            alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_keyframe"], wait_label="Uploading keyframe")

        source_uri = await upload_video_to_comfyapi(cls, video, wait_label="Uploading source")

        reference_uri: str | None = None
        if reference_image is not None:
            reference_uri = await upload_image_to_comfyapi(
                cls,
                reference_image,
                mime_type="image/png",
                wait_label="Uploading reference",
                total_pixels=_MAX_PIXELS,
            )

        resolution = 1080 if max_resolution == "1080p" else 720
        request = CreateSwitchXRequest(
            generation_type="video",
            source_uri=source_uri,
            alpha_mode=mode,
            prompt=cleaned_prompt,
            reference_image_uri=reference_uri,
            alpha_uri=alpha_uri,
            max_resolution=resolution,
        )
        response = await _submit_and_poll(cls, request)

        render = await download_url_to_video_output(_require_output_url(response, "render"))
        alpha_url = _alpha_url(response, mode)
        alpha = None
        if alpha_url is not None:
            alpha = await download_url_to_video_output(alpha_url)
        return IO.NodeOutput(render, alpha)
class BeebleSwitchXImageEdit(IO.ComfyNode):
    """API node that edits a single image with Beeble SwitchX and returns the render and its alpha mask."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Declare the node's inputs, outputs, hidden auth fields and price badge."""
        description = (
            "Edit a single image with Beeble SwitchX. Switches anything in the scene "
            "(background, lighting, costume) while preserving the original subject's pixels. "
            "Provide a reference image and/or text prompt to describe the new look. "
            "Max ~2.77MP."
        )
        outputs = [
            IO.Image.Output(display_name="image"),
            IO.Mask.Output(
                display_name="alpha",
                tooltip="The alpha matte Beeble used. Empty for 'fill' mode, which has no separate matte.",
            ),
        ]
        hidden = [
            IO.Hidden.auth_token_comfy_org,
            IO.Hidden.api_key_comfy_org,
            IO.Hidden.unique_id,
        ]
        price_badge = IO.PriceBadge(
            depends_on=IO.PriceBadgeDepends(widgets=["max_resolution"]),
            expr="""
(
$rate := widgets.max_resolution = "1080p" ? 0.429 : 0.143;
{"type":"usd","usd": $rate}
)
""",
        )
        return IO.Schema(
            node_id="BeebleSwitchXImageEdit",
            display_name="Beeble SwitchX Image Edit",
            category="image/partner/Beeble",
            description=description,
            inputs=_common_inputs(source=IO.Image.Input("image"), video=False),
            outputs=outputs,
            hidden=hidden,
            is_api_node=True,
            price_badge=price_badge,
        )

    @classmethod
    async def execute(
        cls,
        image: Input.Image,
        prompt: str,
        alpha_mode: dict,
        max_resolution: str,
        seed: int,
        reference_image: Input.Image | None = None,
    ) -> IO.NodeOutput:
        """Validate inputs, upload assets, run the generation and download its outputs.

        The alpha output is ``None`` when no alpha URL is available (always in 'fill' mode).
        """
        cleaned_prompt = _validate_inputs(prompt, reference_image)
        image = downscale_image_tensor(image, _MAX_PIXELS)

        # Both mask-carrying modes upload a single-frame PNG for the image node.
        mode = alpha_mode["alpha_mode"]
        alpha_uri: str | None = None
        if mode == "custom":
            alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_mask"], wait_label="Uploading alpha")
        elif mode == "select":
            alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_keyframe"], wait_label="Uploading keyframe")

        # The source was already downscaled above, so no pixel budget is passed to the upload.
        source_uri = await upload_image_to_comfyapi(
            cls,
            image,
            mime_type="image/png",
            wait_label="Uploading source",
            total_pixels=None,
        )

        reference_uri: str | None = None
        if reference_image is not None:
            reference_uri = await upload_image_to_comfyapi(
                cls,
                reference_image,
                mime_type="image/png",
                wait_label="Uploading reference",
                total_pixels=_MAX_PIXELS,
            )

        resolution = 1080 if max_resolution == "1080p" else 720
        request = CreateSwitchXRequest(
            generation_type="image",
            source_uri=source_uri,
            alpha_mode=mode,
            prompt=cleaned_prompt,
            reference_image_uri=reference_uri,
            alpha_uri=alpha_uri,
            max_resolution=resolution,
        )
        response = await _submit_and_poll(cls, request)

        render = await download_url_to_image_tensor(_require_output_url(response, "render"))
        alpha_url = _alpha_url(response, mode)
        alpha_mask = None
        if alpha_url is not None:
            alpha_bytes = await download_url_as_bytesio(alpha_url)
            alpha_image = bytesio_to_image_tensor(alpha_bytes, mode="L")
            # Drop the trailing channel axis when the decoded image is 4-D.
            if alpha_image.dim() == 4:
                alpha_mask = alpha_image.squeeze(-1)
            else:
                alpha_mask = alpha_image
        return IO.NodeOutput(render, alpha_mask)
class BeebleExtension(ComfyExtension):
    """Extension exposing the Beeble SwitchX nodes."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes provided by this extension."""
        nodes: list[type[IO.ComfyNode]] = [
            BeebleSwitchXVideoEdit,
            BeebleSwitchXImageEdit,
        ]
        return nodes
async def comfy_entrypoint() -> BeebleExtension:
    """Create and return the Beeble extension instance."""
    extension = BeebleExtension()
    return extension

View File

@ -42,7 +42,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
return IO.Schema(
node_id="FluxProUltraImageNode",
display_name="Flux 1.1 [pro] Ultra Image",
category="api node/image/BFL",
category="image/partner/BFL",
description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
inputs=[
IO.String.Input(
@ -160,7 +160,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
return IO.Schema(
node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME,
category="api node/image/BFL",
category="image/partner/BFL",
description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
inputs=[
IO.String.Input(
@ -282,7 +282,7 @@ class FluxProExpandNode(IO.ComfyNode):
return IO.Schema(
node_id="FluxProExpandNode",
display_name="Flux.1 Expand Image",
category="api node/image/BFL",
category="image/partner/BFL",
description="Outpaints image based on prompt.",
inputs=[
IO.Image.Input("image"),
@ -419,7 +419,7 @@ class FluxProFillNode(IO.ComfyNode):
return IO.Schema(
node_id="FluxProFillNode",
display_name="Flux.1 Fill Image",
category="api node/image/BFL",
category="image/partner/BFL",
description="Inpaints image based on mask and prompt.",
inputs=[
IO.Image.Input("image"),
@ -545,7 +545,7 @@ class Flux2ProImageNode(IO.ComfyNode):
return IO.Schema(
node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME,
category="api node/image/BFL",
category="image/partner/BFL",
description="Generates images synchronously based on prompt and resolution.",
inputs=[
IO.String.Input(
@ -716,7 +716,7 @@ class Flux2ImageNode(IO.ComfyNode):
return IO.Schema(
node_id="Flux2ImageNode",
display_name="Flux.2 Image",
category="api node/image/BFL",
category="image/partner/BFL",
description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.",
inputs=[
IO.String.Input(

View File

@ -31,7 +31,7 @@ class BriaImageEditNode(IO.ComfyNode):
return IO.Schema(
node_id="BriaImageEditNode",
display_name="Bria FIBO Image Edit",
category="api node/image/Bria",
category="image/partner/Bria",
description="Edit images using Bria latest model",
inputs=[
IO.Combo.Input("model", options=["FIBO"]),
@ -169,7 +169,7 @@ class BriaRemoveImageBackground(IO.ComfyNode):
return IO.Schema(
node_id="BriaRemoveImageBackground",
display_name="Bria Remove Image Background",
category="api node/image/Bria",
category="image/partner/Bria",
description="Remove the background from an image using Bria RMBG 2.0.",
inputs=[
IO.Image.Input("image"),
@ -245,7 +245,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
return IO.Schema(
node_id="BriaRemoveVideoBackground",
display_name="Bria Remove Video Background",
category="api node/video/Bria",
category="video/partner/Bria",
description="Remove the background from a video using Bria. ",
inputs=[
IO.Video.Input("video"),

View File

@ -2,11 +2,12 @@ import hashlib
import logging
import math
import re
from io import BytesIO
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.bytedance import (
RECOMMENDED_PRESETS,
RECOMMENDED_PRESETS_SEEDREAM_4,
@ -43,6 +44,7 @@ from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
download_url_to_video_output,
downscale_image_tensor_by_max_side,
downscale_video_to_max_pixels,
get_number_of_images,
image_tensor_pair_to_batch,
@ -121,6 +123,14 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st
)
def _prepare_seedance_image(image: Input.Image) -> Input.Image:
    """Validate a Seedance image input and shrink it to the per-side limit.

    Steps, in order:
      1. Check the aspect ratio of the incoming image against 2:5 .. 5:2.
      2. Downscale so that no side exceeds 6000 (via
         ``downscale_image_tensor_by_max_side``).
      3. Check that the resulting image is between 300 and 6000 on each side.

    Args:
        image: The image to prepare.

    Returns:
        The image as returned by ``downscale_image_tensor_by_max_side``.

    NOTE(review): the validators presumably raise on failure -- confirm
    against their definitions in ``comfy_api_nodes.util``.
    """
    max_side = 6000
    min_side = 300

    # Ratio bounds (2, 5) and (5, 2) correspond to 0.4 to 2.5.
    validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False)

    resized = downscale_image_tensor_by_max_side(image, max_side=max_side)

    # The dimension check runs on the resized image, not on the original input.
    validate_image_dimensions(
        resized,
        min_width=min_side,
        min_height=min_side,
        max_width=max_side,
        max_height=max_side,
    )
    return resized
async def _resolve_reference_assets(
cls: type[IO.ComfyNode],
asset_ids: list[str],
@ -308,6 +318,26 @@ async def _seedance_virtual_library_upload_image_asset(
return f"asset://{create_resp.asset_id}"
async def _seedance_virtual_library_upload_video_asset(
    cls: type[IO.ComfyNode],
    video: Input.Video,
    *,
    wait_label: str = "Uploading video",
) -> str:
    """Upload a video, register it as a virtual-library asset, and return its URI.

    The video is first encoded into an in-memory MP4/H264 buffer whose SHA-256
    hex digest is sent along with the public upload URL when the asset is
    created. The function then waits for the asset to become active.

    Args:
        cls: Node class passed through to the upload / request helpers.
        video: The video to upload.
        wait_label: Label forwarded to ``upload_video_to_comfyapi``.

    Returns:
        A string of the form ``asset://<asset_id>``.
    """
    # Encode locally only to derive the content hash.
    # NOTE(review): the `video` object (not this buffer) is what gets passed to
    # upload_video_to_comfyapi below -- confirm the uploaded bytes match the
    # bytes hashed here.
    encoded = BytesIO()
    video.save_to(encoded, format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264)
    digest = hashlib.sha256(encoded.getbuffer()).hexdigest()

    uploaded_url = await upload_video_to_comfyapi(cls, video, wait_label=wait_label)

    endpoint = ApiEndpoint(path="/proxy/seedance/virtual-library/assets", method="POST")
    payload = SeedanceVirtualLibraryCreateAssetRequest(
        url=uploaded_url,
        hash=digest,
        asset_type="Video",
    )
    created = await sync_op(
        cls,
        endpoint,
        response_model=SeedanceCreateAssetResponse,
        data=payload,
    )

    asset_id = created.asset_id
    await _wait_for_asset_active(cls, asset_id, group_id="virtual-library")
    return f"asset://{asset_id}"
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
@ -338,7 +368,7 @@ class ByteDanceImageNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceImageNode",
display_name="ByteDance Image",
category="api node/image/ByteDance",
category="image/partner/ByteDance",
description="Generate images using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
@ -462,7 +492,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceSeedreamNode",
display_name="ByteDance Seedream 4.5 & 5.0",
category="api node/image/ByteDance",
category="image/partner/ByteDance",
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
inputs=[
IO.Combo.Input(
@ -724,7 +754,7 @@ class ByteDanceSeedreamNodeV2(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceSeedreamNodeV2",
display_name="ByteDance Seedream 4.5 & 5.0",
category="api node/image/ByteDance",
category="image/partner/ByteDance",
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
inputs=[
IO.String.Input(
@ -890,7 +920,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceTextToVideoNode",
display_name="ByteDance Text to Video",
category="api node/video/ByteDance",
category="video/partner/ByteDance",
description="Generate video using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input(
@ -1018,7 +1048,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceImageToVideoNode",
display_name="ByteDance Image to Video",
category="api node/video/ByteDance",
category="video/partner/ByteDance",
description="Generate video using ByteDance models via api based on image and prompt",
inputs=[
IO.Combo.Input(
@ -1155,7 +1185,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceFirstLastFrameNode",
display_name="ByteDance First-Last-Frame to Video",
category="api node/video/ByteDance",
category="video/partner/ByteDance",
description="Generate video using prompt and first and last frames.",
inputs=[
IO.Combo.Input(
@ -1303,7 +1333,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceImageReferenceNode",
display_name="ByteDance Reference Images to Video",
category="api node/video/ByteDance",
category="video/partner/ByteDance",
description="Generate video using prompt and reference images.",
inputs=[
IO.Combo.Input(
@ -1546,7 +1576,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDance2TextToVideoNode",
display_name="ByteDance Seedance 2.0 Text to Video",
category="api node/video/ByteDance",
category="video/partner/ByteDance",
description="Generate video using Seedance 2.0 models based on a text prompt.",
inputs=[
IO.DynamicCombo.Input(
@ -1647,7 +1677,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDance2FirstLastFrameNode",
display_name="ByteDance Seedance 2.0 First-Last-Frame to Video",
category="api node/video/ByteDance",
category="video/partner/ByteDance",
description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.",
inputs=[
IO.DynamicCombo.Input(
@ -1760,6 +1790,11 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
if last_frame is not None and last_frame_asset_id:
raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.")
if first_frame is not None:
first_frame = _prepare_seedance_image(first_frame)
if last_frame is not None:
last_frame = _prepare_seedance_image(last_frame)
asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a]
image_assets: dict[str, str] = {}
if asset_ids_to_resolve:
@ -1866,7 +1901,7 @@ def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16
),
IO.Boolean.Input(
"auto_downscale",
default=False,
default=True,
optional=True,
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
@ -1909,7 +1944,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDance2ReferenceNode",
display_name="ByteDance Seedance 2.0 Reference to Video",
category="api node/video/ByteDance",
category="video/partner/ByteDance",
description="Generate, edit, or extend video using Seedance 2.0 with reference images, "
"videos, and audio. Supports multimodal reference, video editing, and video extension.",
inputs=[
@ -2034,6 +2069,9 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
f"(audios={len(reference_audios)}, audio assets={len(reference_audio_assets)}). Maximum is 3."
)
for key in reference_images:
reference_images[key] = _prepare_seedance_image(reference_images[key])
model_id = SEEDANCE_MODELS[model["model"]]
has_video_input = total_videos > 0
@ -2106,7 +2144,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
content.append(
TaskVideoContent(
video_url=TaskVideoContentUrl(
url=await upload_video_to_comfyapi(
url=await _seedance_virtual_library_upload_video_asset(
cls,
reference_videos[key],
wait_label=f"Uploading video {i}",
@ -2203,7 +2241,7 @@ class ByteDanceCreateImageAsset(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceCreateImageAsset",
display_name="ByteDance Create Image Asset",
category="api node/image/ByteDance",
category="image/partner/ByteDance",
description=(
"Create a Seedance 2.0 personal image asset. Uploads the input image and "
"registers it in the given asset group. If group_id is empty, runs a real-person "
@ -2270,7 +2308,7 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceCreateVideoAsset",
display_name="ByteDance Create Video Asset",
category="api node/video/ByteDance",
category="video/partner/ByteDance",
description=(
"Create a Seedance 2.0 personal video asset. Uploads the input video and "
"registers it in the given asset group. If group_id is empty, runs a real-person "

View File

@ -144,7 +144,7 @@ class ByteDanceSeedNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceSeedNode",
display_name="ByteDance Seed",
category="api node/text/ByteDance",
category="text/partner/ByteDance",
essentials_category="Text Generation",
description="Generate text responses with ByteDance's Seed 2.0 models. "
"Provide a text prompt and optionally one or more images or videos for multimodal context.",

View File

@ -69,7 +69,7 @@ class ElevenLabsSpeechToText(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsSpeechToText",
display_name="ElevenLabs Speech to Text",
category="api node/audio/ElevenLabs",
category="audio/partner/ElevenLabs",
description="Transcribe audio to text. "
"Supports automatic language detection, speaker diarization, and audio event tagging.",
inputs=[
@ -210,7 +210,7 @@ class ElevenLabsVoiceSelector(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsVoiceSelector",
display_name="ElevenLabs Voice Selector",
category="api node/audio/ElevenLabs",
category="audio/partner/ElevenLabs",
description="Select a predefined ElevenLabs voice for text-to-speech generation.",
inputs=[
IO.Combo.Input(
@ -239,7 +239,7 @@ class ElevenLabsTextToSpeech(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsTextToSpeech",
display_name="ElevenLabs Text to Speech",
category="api node/audio/ElevenLabs",
category="audio/partner/ElevenLabs",
description="Convert text to speech.",
inputs=[
IO.Custom(ELEVENLABS_VOICE).Input(
@ -414,7 +414,7 @@ class ElevenLabsAudioIsolation(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsAudioIsolation",
display_name="ElevenLabs Voice Isolation",
category="api node/audio/ElevenLabs",
category="audio/partner/ElevenLabs",
description="Remove background noise from audio, isolating vocals or speech.",
inputs=[
IO.Audio.Input(
@ -459,7 +459,7 @@ class ElevenLabsTextToSoundEffects(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsTextToSoundEffects",
display_name="ElevenLabs Text to Sound Effects",
category="api node/audio/ElevenLabs",
category="audio/partner/ElevenLabs",
description="Generate sound effects from text descriptions.",
inputs=[
IO.String.Input(
@ -555,7 +555,7 @@ class ElevenLabsInstantVoiceClone(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsInstantVoiceClone",
display_name="ElevenLabs Instant Voice Clone",
category="api node/audio/ElevenLabs",
category="audio/partner/ElevenLabs",
description="Create a cloned voice from audio samples. "
"Provide 1-8 audio recordings of the voice to clone.",
inputs=[
@ -658,7 +658,7 @@ class ElevenLabsSpeechToSpeech(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsSpeechToSpeech",
display_name="ElevenLabs Speech to Speech",
category="api node/audio/ElevenLabs",
category="audio/partner/ElevenLabs",
description="Transform speech from one voice to another while preserving the original content and emotion.",
inputs=[
IO.Custom(ELEVENLABS_VOICE).Input(
@ -793,7 +793,7 @@ class ElevenLabsTextToDialogue(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsTextToDialogue",
display_name="ElevenLabs Text to Dialogue",
category="api node/audio/ElevenLabs",
category="audio/partner/ElevenLabs",
description="Generate multi-speaker dialogue from text. Each dialogue entry has its own text and voice.",
inputs=[
IO.Float.Input(

View File

@ -300,7 +300,7 @@ class GeminiNode(IO.ComfyNode):
return IO.Schema(
node_id="GeminiNode",
display_name="Google Gemini",
category="api node/text/Gemini",
category="text/partner/Gemini",
description="Generate text responses with Google's Gemini AI model. "
"You can provide multiple types of inputs (text, images, audio, video) "
"as context for generating more relevant and meaningful responses.",
@ -541,7 +541,7 @@ class GeminiInputFiles(IO.ComfyNode):
return IO.Schema(
node_id="GeminiInputFiles",
display_name="Gemini Input Files",
category="api node/text/Gemini",
category="text/partner/Gemini",
description="Loads and prepares input files to include as inputs for Gemini LLM nodes. "
"The files will be read by the Gemini model when generating a response. "
"The contents of the text file count toward the token limit. "
@ -598,7 +598,7 @@ class GeminiImage(IO.ComfyNode):
return IO.Schema(
node_id="GeminiImageNode",
display_name="Nano Banana (Google Gemini Image)",
category="api node/image/Gemini",
category="image/partner/Gemini",
description="Edit images synchronously via Google API.",
inputs=[
IO.String.Input(
@ -731,7 +731,7 @@ class GeminiImage2(IO.ComfyNode):
return IO.Schema(
node_id="GeminiImage2Node",
display_name="Nano Banana Pro (Google Gemini Image)",
category="api node/image/Gemini",
category="image/partner/Gemini",
description="Generate or edit images synchronously via Google Vertex API.",
inputs=[
IO.String.Input(
@ -869,7 +869,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
return IO.Schema(
node_id="GeminiNanoBanana2",
display_name="Nano Banana 2",
category="api node/image/Gemini",
category="image/partner/Gemini",
description="Generate or edit images synchronously via Google Vertex API.",
inputs=[
IO.String.Input(
@ -1085,7 +1085,7 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
return IO.Schema(
node_id="GeminiNanoBanana2V2",
display_name="Nano Banana 2",
category="api node/image/Gemini",
category="image/partner/Gemini",
description="Generate or edit images synchronously via Google Vertex API.",
inputs=[
IO.String.Input(

View File

@ -49,7 +49,7 @@ class GrokImageNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokImageNode",
display_name="Grok Image",
category="api node/image/Grok",
category="image/partner/Grok",
description="Generate images using Grok based on a text prompt",
inputs=[
IO.Combo.Input(
@ -224,7 +224,7 @@ class GrokImageEditNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokImageEditNode",
display_name="Grok Image Edit",
category="api node/image/Grok",
category="image/partner/Grok",
description="Modify an existing image based on a text prompt",
inputs=[
IO.Combo.Input(
@ -366,7 +366,7 @@ class GrokImageEditNodeV2(IO.ComfyNode):
return IO.Schema(
node_id="GrokImageEditNodeV2",
display_name="Grok Image Edit",
category="api node/image/Grok",
category="image/partner/Grok",
description="Modify an existing image based on a text prompt",
inputs=[
IO.String.Input(
@ -503,7 +503,7 @@ class GrokVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokVideoNode",
display_name="Grok Video",
category="api node/video/Grok",
category="video/partner/Grok",
description="Generate video from a prompt or an image",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
@ -615,7 +615,7 @@ class GrokVideoEditNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokVideoEditNode",
display_name="Grok Video Edit",
category="api node/video/Grok",
category="video/partner/Grok",
description="Edit an existing video based on a text prompt.",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
@ -693,7 +693,7 @@ class GrokVideoReferenceNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokVideoReferenceNode",
display_name="Grok Reference-to-Video",
category="api node/video/Grok",
category="video/partner/Grok",
description="Generate video guided by reference images as style and content references.",
inputs=[
IO.String.Input(
@ -826,7 +826,7 @@ class GrokVideoExtendNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokVideoExtendNode",
display_name="Grok Video Extend",
category="api node/video/Grok",
category="video/partner/Grok",
description="Extend an existing video with a seamless continuation based on a text prompt.",
inputs=[
IO.String.Input(

View File

@ -71,7 +71,7 @@ class HitPawGeneralImageEnhance(IO.ComfyNode):
return IO.Schema(
node_id="HitPawGeneralImageEnhance",
display_name="HitPaw General Image Enhance",
category="api node/image/HitPaw",
category="image/partner/HitPaw",
description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. "
f"Maximum output: {MAX_MP_GENERATIVE} megapixels.",
inputs=[
@ -201,7 +201,7 @@ class HitPawVideoEnhance(IO.ComfyNode):
return IO.Schema(
node_id="HitPawVideoEnhance",
display_name="HitPaw Video Enhance",
category="api node/video/HitPaw",
category="video/partner/HitPaw",
description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. "
"Prices shown are per second of video.",
inputs=[

View File

@ -123,7 +123,7 @@ class TencentTextToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TencentTextToModelNode",
display_name="Hunyuan3D: Text to Model",
category="api node/3d/Tencent",
category="3d/partner/Tencent",
essentials_category="3D",
inputs=[
IO.Combo.Input(
@ -242,7 +242,7 @@ class TencentImageToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TencentImageToModelNode",
display_name="Hunyuan3D: Image(s) to Model",
category="api node/3d/Tencent",
category="3d/partner/Tencent",
essentials_category="3D",
inputs=[
IO.Combo.Input(
@ -415,7 +415,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
return IO.Schema(
node_id="TencentModelTo3DUVNode",
display_name="Hunyuan3D: Model to UV",
category="api node/3d/Tencent",
category="3d/partner/Tencent",
description="Perform UV unfolding on a 3D model to generate UV texture. "
"Input model must have less than 30000 faces.",
inputs=[
@ -505,7 +505,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
return IO.Schema(
node_id="Tencent3DTextureEditNode",
display_name="Hunyuan3D: 3D Texture Edit",
category="api node/3d/Tencent",
category="3d/partner/Tencent",
description="After inputting the 3D model, perform 3D model texture redrawing.",
inputs=[
IO.MultiType.Input(
@ -594,7 +594,7 @@ class Tencent3DPartNode(IO.ComfyNode):
return IO.Schema(
node_id="Tencent3DPartNode",
display_name="Hunyuan3D: 3D Part",
category="api node/3d/Tencent",
category="3d/partner/Tencent",
description="Automatically perform component identification and generation based on the model structure.",
inputs=[
IO.MultiType.Input(
@ -666,7 +666,7 @@ class TencentSmartTopologyNode(IO.ComfyNode):
return IO.Schema(
node_id="TencentSmartTopologyNode",
display_name="Hunyuan3D: Smart Topology",
category="api node/3d/Tencent",
category="3d/partner/Tencent",
description="Perform smart retopology on a 3D model. "
"Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.",
inputs=[

View File

@ -234,7 +234,7 @@ class IdeogramV1(IO.ComfyNode):
return IO.Schema(
node_id="IdeogramV1",
display_name="Ideogram V1",
category="api node/image/Ideogram",
category="image/partner/Ideogram",
description="Generates images using the Ideogram V1 model.",
inputs=[
IO.String.Input(
@ -360,7 +360,7 @@ class IdeogramV2(IO.ComfyNode):
return IO.Schema(
node_id="IdeogramV2",
display_name="Ideogram V2",
category="api node/image/Ideogram",
category="image/partner/Ideogram",
description="Generates images using the Ideogram V2 model.",
inputs=[
IO.String.Input(
@ -526,7 +526,7 @@ class IdeogramV3(IO.ComfyNode):
return IO.Schema(
node_id="IdeogramV3",
display_name="Ideogram V3",
category="api node/image/Ideogram",
category="image/partner/Ideogram",
description="Generates images using the Ideogram V3 model. "
"Supports both regular image generation from text prompts and image editing with mask.",
inputs=[

View File

@ -642,7 +642,7 @@ class KlingCameraControls(IO.ComfyNode):
return IO.Schema(
node_id="KlingCameraControls",
display_name="Kling Camera Controls",
category="api node/video/Kling",
category="video/partner/Kling",
description="Allows specifying configuration options for Kling Camera Controls and motion control effects.",
inputs=[
IO.Combo.Input("camera_control_type", options=KlingCameraControlType),
@ -762,7 +762,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingTextToVideoNode",
display_name="Kling Text to Video",
category="api node/video/Kling",
category="video/partner/Kling",
description="Kling Text to Video Node",
inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -849,7 +849,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProTextToVideoNode",
display_name="Kling 3.0 Omni Text to Video",
category="api node/video/Kling",
category="video/partner/Kling",
description="Use text prompts to generate videos with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -998,7 +998,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProFirstLastFrameNode",
display_name="Kling 3.0 Omni First-Last-Frame to Video",
category="api node/video/Kling",
category="video/partner/Kling",
description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -1205,7 +1205,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProImageToVideoNode",
display_name="Kling 3.0 Omni Image to Video",
category="api node/video/Kling",
category="video/partner/Kling",
description="Use up to 7 reference images to generate a video with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -1374,7 +1374,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProVideoToVideoNode",
display_name="Kling 3.0 Omni Video to Video",
category="api node/video/Kling",
category="video/partner/Kling",
description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -1485,7 +1485,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProEditVideoNode",
display_name="Kling 3.0 Omni Edit Video",
category="api node/video/Kling",
category="video/partner/Kling",
essentials_category="Video Generation",
description="Edit an existing video with the latest model from Kling.",
inputs=[
@ -1593,7 +1593,7 @@ class OmniProImageNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProImageNode",
display_name="Kling 3.0 Omni Image",
category="api node/image/Kling",
category="image/partner/Kling",
description="Create or edit images with the latest model from Kling.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]),
@ -1721,7 +1721,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingCameraControlT2VNode",
display_name="Kling Text to Video (Camera Control)",
category="api node/video/Kling",
category="video/partner/Kling",
description="Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text.",
inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1783,7 +1783,7 @@ class KlingImage2VideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingImage2VideoNode",
display_name="Kling Image(First Frame) to Video",
category="api node/video/Kling",
category="video/partner/Kling",
inputs=[
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1882,7 +1882,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingCameraControlI2VNode",
display_name="Kling Image to Video (Camera Control)",
category="api node/video/Kling",
category="video/partner/Kling",
description="Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image.",
inputs=[
IO.Image.Input(
@ -1953,7 +1953,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingStartEndFrameNode",
display_name="Kling Start-End Frame to Video",
category="api node/video/Kling",
category="video/partner/Kling",
description="Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last.",
inputs=[
IO.Image.Input(
@ -2047,7 +2047,7 @@ class KlingVideoExtendNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingVideoExtendNode",
display_name="Kling Video Extend",
category="api node/video/Kling",
category="video/partner/Kling",
description="Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes.",
inputs=[
IO.String.Input(
@ -2128,7 +2128,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingDualCharacterVideoEffectNode",
display_name="Kling Dual Character Video Effects",
category="api node/video/Kling",
category="video/partner/Kling",
description="Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite.",
inputs=[
IO.Image.Input("image_left", tooltip="Left side image"),
@ -2218,7 +2218,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingSingleImageVideoEffectNode",
display_name="Kling Video Effects",
category="api node/video/Kling",
category="video/partner/Kling",
description="Achieve different special effects when generating a video based on the effect_scene.",
inputs=[
IO.Image.Input(
@ -2291,7 +2291,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingLipSyncAudioToVideoNode",
display_name="Kling Lip Sync Video with Audio",
category="api node/video/Kling",
category="video/partner/Kling",
essentials_category="Video Generation",
description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
inputs=[
@ -2343,7 +2343,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingLipSyncTextToVideoNode",
display_name="Kling Lip Sync Video with Text",
category="api node/video/Kling",
category="video/partner/Kling",
description="Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
inputs=[
IO.Video.Input("video"),
@ -2411,7 +2411,7 @@ class KlingVirtualTryOnNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingVirtualTryOnNode",
display_name="Kling Virtual Try On",
category="api node/image/Kling",
category="image/partner/Kling",
description="Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background.",
inputs=[
IO.Image.Input("human_image"),
@ -2478,7 +2478,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingImageGenerationNode",
display_name="Kling 3.0 Image",
category="api node/image/Kling",
category="image/partner/Kling",
description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.",
inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -2615,7 +2615,7 @@ class TextToVideoWithAudio(IO.ComfyNode):
return IO.Schema(
node_id="KlingTextToVideoWithAudio",
display_name="Kling 2.6 Text to Video with Audio",
category="api node/video/Kling",
category="video/partner/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
@ -2683,7 +2683,7 @@ class ImageToVideoWithAudio(IO.ComfyNode):
return IO.Schema(
node_id="KlingImageToVideoWithAudio",
display_name="Kling 2.6 Image(First Frame) to Video with Audio",
category="api node/video/Kling",
category="video/partner/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.Image.Input("start_frame"),
@ -2753,7 +2753,7 @@ class MotionControl(IO.ComfyNode):
return IO.Schema(
node_id="KlingMotionControl",
display_name="Kling Motion Control",
category="api node/video/Kling",
category="video/partner/Kling",
inputs=[
IO.String.Input("prompt", multiline=True),
IO.Image.Input("reference_image"),
@ -2854,7 +2854,7 @@ class KlingVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingVideoNode",
display_name="Kling 3.0 Video",
category="api node/video/Kling",
category="video/partner/Kling",
description="Generate videos with Kling V3. "
"Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.",
inputs=[
@ -3077,7 +3077,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingFirstLastFrameNode",
display_name="Kling 3.0 First-Last-Frame to Video",
category="api node/video/Kling",
category="video/partner/Kling",
description="Generate videos with Kling V3 using first and last frames.",
inputs=[
IO.String.Input("prompt", multiline=True, default=""),
@ -3202,7 +3202,7 @@ class KlingAvatarNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingAvatarNode",
display_name="Kling Avatar 2.0",
category="api node/video/Kling",
category="video/partner/Kling",
description="Generate broadcast-style digital human videos from a single photo and an audio file.",
inputs=[
IO.Image.Input(

View File

@ -0,0 +1,290 @@
"""Krea image-generation nodes."""
import re
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.krea import (
KreaAssetResponse,
KreaGenerateImageRequest,
KreaImageStyleReference,
KreaJob,
KreaMoodboard,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
poll_op,
sync_op,
tensor_to_bytesio,
validate_string,
)
class KreaIO:
    """Names of the custom socket types used by the Krea nodes in this module."""

    # Type tag passed to IO.Custom(...) for the socket that carries the chain of
    # style references from Krea 2 Style Reference nodes into Krea 2 Image.
    STYLE_REF: str = "KREA_STYLE_REF"
async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Image) -> str:
    """Post ``image`` to the proxied Krea assets endpoint and return its URL.

    The image is encoded as PNG by ``tensor_to_bytesio`` (called with a
    ``total_pixels`` budget of 2048 * 2048) and sent as a multipart form
    upload under the ``file`` field.

    Args:
        cls: Node class issuing the request; forwarded to ``sync_op``.
        image: Image to upload.

    Returns:
        The ``image_url`` attribute of the parsed ``KreaAssetResponse``.
    """
    png_mime = "image/png"
    encoded = tensor_to_bytesio(image, total_pixels=2048 * 2048, mime_type=png_mime)
    upload_endpoint = ApiEndpoint(path="/proxy/krea/assets", method="POST")
    # A single form part: (field name, (file name, file object, content type)).
    form_parts = [("file", (encoded.name, encoded, png_mime))]
    asset = await sync_op(
        cls,
        endpoint=upload_endpoint,
        response_model=KreaAssetResponse,
        files=form_parts,
        content_type="multipart/form-data",
        max_retries=1,
        wait_label="Uploading reference",
    )
    return asset.image_url
# Labels of the selectable Krea 2 models. They are used as the option names of
# the "model" DynamicCombo and as the keys of _MODEL_ENDPOINTS.
_MODEL_MEDIUM = "Krea 2 Medium"
_MODEL_LARGE = "Krea 2 Large"
# Proxy path of the image-generation endpoint for each model label.
_MODEL_ENDPOINTS: dict[str, str] = {
    _MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium",
    _MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large",
}
# Option lists for the combo inputs built by _krea_model_inputs().
_ASPECT_RATIOS = ["1:1", "4:3", "3:2", "16:9", "2.35:1", "4:5", "2:3", "9:16"]
_RESOLUTIONS = ["1K"]
_CREATIVITY_LEVELS = ["raw", "low", "medium", "high"]
# Job statuses handed to poll_op as `queued_statuses`.
_KREA_QUEUED_STATUSES = ["backlogged", "queued", "scheduled"]
# Fully anchored pattern for a hyphenated 8-4-4-4-12 hexadecimal string
# (either letter case); used to validate the moodboard_id input.
_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
def _krea_model_inputs() -> list:
    """Build the inputs nested under each option of the "model" DynamicCombo.

    Every call constructs new input objects, so each model option can be given
    its own list.

    Returns:
        The inputs in display order: aspect_ratio, resolution, creativity,
        moodboard_id, moodboard_strength, style_reference.
    """
    aspect_ratio = IO.Combo.Input(
        "aspect_ratio",
        options=_ASPECT_RATIOS,
        tooltip="Output aspect ratio.",
    )
    resolution = IO.Combo.Input(
        "resolution",
        options=_RESOLUTIONS,
        tooltip="Resolution scale.",
    )
    creativity = IO.Combo.Input(
        "creativity",
        options=_CREATIVITY_LEVELS,
        default="medium",
        tooltip="Prompt interpretation strength: raw stays closest to the prompt; high is most creative.",
    )
    # The moodboard pair is optional; an empty id disables the moodboard.
    moodboard_id = IO.String.Input(
        "moodboard_id",
        default="",
        tooltip="Optional Krea moodboard UUID (e.g. from the Krea website). "
        "Leave empty to disable. Only one moodboard is supported per request.",
        optional=True,
    )
    moodboard_strength = IO.Float.Input(
        "moodboard_strength",
        default=0.35,
        min=-0.5,
        max=1.5,
        step=0.05,
        tooltip="Moodboard influence; ignored when moodboard_id is empty.",
        optional=True,
    )
    style_reference = IO.Custom(KreaIO.STYLE_REF).Input(
        "style_reference",
        optional=True,
        tooltip="Optional chain of style references (max 10) from Krea 2 Style Reference nodes.",
    )
    return [
        aspect_ratio,
        resolution,
        creativity,
        moodboard_id,
        moodboard_strength,
        style_reference,
    ]
class Krea2ImageNode(IO.ComfyNode):
    """Node that submits a Krea 2 generation job, polls it and returns the first image."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: prompt, model (with nested options), seed and pricing."""
        prompt_input = IO.String.Input(
            "prompt",
            multiline=True,
            default="",
            tooltip="Text prompt for the image.",
        )
        # Both model options expose the same nested inputs; each option gets
        # its own list from a separate _krea_model_inputs() call.
        model_input = IO.DynamicCombo.Input(
            "model",
            options=[
                IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()),
                IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()),
            ],
            tooltip="Krea 2 Medium is best for expressive illustrations; "
            "Krea 2 Large is best for expressive photorealism.",
        )
        seed_input = IO.Int.Input(
            "seed",
            default=0,
            min=0,
            max=2147483647,
            control_after_generate=True,
            tooltip="Random seed for reproducibility.",
        )
        # Pricing expression (JSONata-style syntax). A non-empty moodboard id
        # takes precedence over a connected style reference when picking the price.
        # NOTE(review): the model is compared against the lower-cased label
        # "krea 2 large" while the option label is "Krea 2 Large"; presumably
        # widget values are lower-cased before evaluation - confirm.
        price_expr = """
(
$isLarge := widgets.model = "krea 2 large";
$hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0;
$hasStyle := $lookup(inputs, "model.style_reference").connected;
$usd := $hasMoodboard
? ($isLarge ? 0.07 : 0.04)
: ($hasStyle
? ($isLarge ? 0.065 : 0.035)
: ($isLarge ? 0.06 : 0.03));
{"type":"usd","usd": $usd}
)
"""
        price_badge = IO.PriceBadge(
            depends_on=IO.PriceBadgeDepends(
                widgets=["model", "model.moodboard_id"],
                inputs=["model.style_reference"],
            ),
            expr=price_expr,
        )
        return IO.Schema(
            node_id="Krea2ImageNode",
            display_name="Krea 2 Image",
            category="image/partner/Krea",
            description=(
                "Generate images via Krea 2 — pick Medium (expressive illustrations) or "
                "Large (expressive photorealism). Supports an optional moodboard and up "
                "to 10 chained image style references."
            ),
            inputs=[prompt_input, model_input, seed_input],
            outputs=[IO.Image.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=price_badge,
        )

    @classmethod
    async def execute(cls, prompt: str, model: dict, seed: int) -> IO.NodeOutput:
        """Validate the inputs, run one generation job and return its first image.

        Args:
            prompt: Text prompt; must contain at least one character.
            model: Values of the "model" DynamicCombo. ``model["model"]`` is the
                selected label; the remaining keys are the nested inputs.
            seed: Seed forwarded in the request.

        Returns:
            Node output wrapping the image downloaded from the first result URL.

        Raises:
            ValueError: Unknown model label, a non-empty moodboard_id that is
                not a UUID, or more than 10 style references.
            RuntimeError: The finished job carries no result URLs.
        """
        validate_string(prompt, strip_whitespace=False, min_length=1)

        selected_model = model["model"]
        if selected_model not in _MODEL_ENDPOINTS:
            raise ValueError(f"Unknown Krea 2 model: {selected_model!r}")
        endpoint_path = _MODEL_ENDPOINTS[selected_model]

        # Moodboard: only sent when a non-blank id was supplied.
        moodboards: list[KreaMoodboard] | None = None
        moodboard_id = (model.get("moodboard_id") or "").strip()
        if moodboard_id:
            if _UUID_RE.match(moodboard_id) is None:
                raise ValueError(f"moodboard_id must be a UUID (received {moodboard_id!r}); copy it from the Krea website.")
            raw_strength = model.get("moodboard_strength")
            # Fall back to 0.35 (the widget default) when no strength was provided.
            strength = float(raw_strength) if raw_strength is not None else 0.35
            moodboards = [KreaMoodboard(id=moodboard_id, strength=strength)]

        # Style references: a chain of {"url", "strength"} dicts, capped at 10.
        image_style_references: list[KreaImageStyleReference] | None = None
        incoming_refs = model.get("style_reference")
        if incoming_refs:
            if len(incoming_refs) > 10:
                raise ValueError(f"Krea 2 accepts at most 10 image_style_references; received {len(incoming_refs)}.")
            image_style_references = []
            for entry in incoming_refs:
                image_style_references.append(
                    KreaImageStyleReference(url=entry["url"], strength=float(entry["strength"]))
                )

        submit_endpoint = ApiEndpoint(path=endpoint_path, method="POST")
        request = KreaGenerateImageRequest(
            prompt=prompt,
            aspect_ratio=model["aspect_ratio"],
            resolution=model["resolution"],
            seed=seed,
            creativity=model["creativity"],
            moodboards=moodboards,
            image_style_references=image_style_references,
        )
        submitted = await sync_op(
            cls,
            submit_endpoint,
            response_model=KreaJob,
            data=request,
        )

        def _status_of(response):
            return response.status

        finished = await poll_op(
            cls,
            ApiEndpoint(path=f"/proxy/krea/jobs/{submitted.job_id}", method="GET"),
            response_model=KreaJob,
            status_extractor=_status_of,
            queued_statuses=_KREA_QUEUED_STATUSES,
        )
        if not (finished.result and finished.result.urls):
            raise RuntimeError(f"Krea 2 job {finished.job_id} completed without any image URLs.")
        # Only the first URL is downloaded.
        first_url = finished.result.urls[0]
        return IO.NodeOutput(await download_url_to_image_tensor(first_url))
class Krea2StyleReferenceNode(IO.ComfyNode):
    """Node that uploads one image and appends it to a chain of style references."""

    @classmethod
    def define_schema(cls) -> IO.Schema:
        """Describe the node: image, strength and an optional incoming chain."""
        # NOTE(review): execute() performs an authenticated upload, but this
        # schema does not set is_api_node=True the way Krea2ImageNode does -
        # confirm that is intentional.
        style_ref_type = IO.Custom(KreaIO.STYLE_REF)
        image_input = IO.Image.Input(
            "image",
            tooltip="Reference image whose style influences the generation.",
        )
        strength_input = IO.Float.Input(
            "strength",
            default=1.0,
            min=-2.0,
            max=2.0,
            step=0.05,
            tooltip="Reference strength; negative values invert the style influence.",
        )
        chain_input = style_ref_type.Input(
            "style_reference",
            optional=True,
            tooltip="Optional incoming chain of style references; this node appends one more.",
        )
        return IO.Schema(
            node_id="Krea2StyleReferenceNode",
            display_name="Krea 2 Style Reference",
            category="image/partner/Krea",
            description=(
                "Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 "
                "Style Reference nodes (max 10) and feed the final `style_reference` output "
                "into Krea 2 Image. Each image is uploaded to ComfyAPI storage and passed as URL."
            ),
            inputs=[image_input, strength_input, chain_input],
            outputs=[IO.Custom(KreaIO.STYLE_REF).Output(display_name="style_reference")],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
        )

    @classmethod
    async def execute(
        cls,
        image: Input.Image,
        strength: float,
        style_reference: list[dict] | None = None,
    ) -> IO.NodeOutput:
        """Upload ``image`` and return the incoming chain extended by one entry.

        Args:
            image: Reference image to upload.
            strength: Strength stored with the new entry (converted to float).
            style_reference: Incoming chain of ``{"url", "strength"}`` dicts, if any.

        Returns:
            Node output wrapping a new list; the incoming list is not modified.

        Raises:
            ValueError: The incoming chain already holds 10 or more entries.
        """
        if style_reference:
            # Copy so the list produced by the upstream node is left untouched.
            existing = list(style_reference)
        else:
            existing = []
        # Check the limit before uploading so a full chain costs no request.
        if len(existing) >= 10:
            raise ValueError("Krea 2 accepts at most 10 image_style_references in one generation.")
        uploaded_url = await _upload_image_to_krea_assets(cls, image)
        new_entry = {"url": uploaded_url, "strength": float(strength)}
        return IO.NodeOutput([*existing, new_entry])
class KreaExtension(ComfyExtension):
    """Extension exposing the Krea nodes defined in this module."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes: list[type[IO.ComfyNode]] = [
            Krea2ImageNode,
            Krea2StyleReferenceNode,
        ]
        return node_classes
async def comfy_entrypoint() -> KreaExtension:
    """Create and return a new KreaExtension instance."""
    extension = KreaExtension()
    return extension

View File

@ -50,7 +50,7 @@ class TextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="LtxvApiTextToVideo",
display_name="LTXV Text To Video",
category="api node/video/LTXV",
category="video/partner/LTXV",
description="Professional-quality videos with customizable duration and resolution.",
inputs=[
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
@ -127,7 +127,7 @@ class ImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="LtxvApiImageToVideo",
display_name="LTXV Image To Video",
category="api node/video/LTXV",
category="video/partner/LTXV",
description="Professional-quality videos with customizable duration and resolution based on start image.",
inputs=[
IO.Image.Input("image", tooltip="First frame to be used for the video."),

View File

@ -46,7 +46,7 @@ class LumaReferenceNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaReferenceNode",
display_name="Luma Reference",
category="api node/image/Luma",
category="image/partner/Luma",
description="Holds an image and weight for use with Luma Generate Image node.",
inputs=[
IO.Image.Input(
@ -85,7 +85,7 @@ class LumaConceptsNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaConceptsNode",
display_name="Luma Concepts",
category="api node/video/Luma",
category="video/partner/Luma",
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
inputs=[
IO.Combo.Input(
@ -134,7 +134,7 @@ class LumaImageGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageNode",
display_name="Luma Text to Image",
category="api node/image/Luma",
category="image/partner/Luma",
description="Generates images synchronously based on prompt and aspect ratio.",
inputs=[
IO.String.Input(
@ -278,7 +278,7 @@ class LumaImageModifyNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageModifyNode",
display_name="Luma Image to Image",
category="api node/image/Luma",
category="image/partner/Luma",
description="Modifies images synchronously based on prompt and aspect ratio.",
inputs=[
IO.Image.Input(
@ -371,7 +371,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaVideoNode",
display_name="Luma Text to Video",
category="api node/video/Luma",
category="video/partner/Luma",
description="Generates videos synchronously based on prompt and output_size.",
inputs=[
IO.String.Input(
@ -472,7 +472,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageToVideoNode",
display_name="Luma Image to Video",
category="api node/video/Luma",
category="video/partner/Luma",
description="Generates videos synchronously based on prompt, input images, and output_size.",
inputs=[
IO.String.Input(
@ -724,7 +724,7 @@ class LumaImageNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageNode2",
display_name="Luma UNI-1 Image",
category="api node/image/Luma",
category="image/partner/Luma",
description="Generate images from text using the Luma UNI-1 model.",
inputs=[
IO.String.Input(
@ -853,7 +853,7 @@ class LumaImageEditNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageEditNode2",
display_name="Luma UNI-1 Image Edit",
category="api node/image/Luma",
category="image/partner/Luma",
description="Edit an existing image with a text prompt using the Luma UNI-1 model.",
inputs=[
IO.Image.Input(

View File

@ -61,7 +61,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageUpscalerCreativeNode",
display_name="Magnific Image Upscale (Creative)",
category="api node/image/Magnific",
category="image/partner/Magnific",
description="Promptguided enhancement, stylization, and 2x/4x/8x/16x upscaling. "
"Maximum output: 25.3 megapixels.",
inputs=[
@ -240,7 +240,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageUpscalerPreciseV2Node",
display_name="Magnific Image Upscale (Precise V2)",
category="api node/image/Magnific",
category="image/partner/Magnific",
description="High-fidelity upscaling with fine control over sharpness, grain, and detail. "
"Maximum output: 10060×10060 pixels.",
inputs=[
@ -400,7 +400,7 @@ class MagnificImageStyleTransferNode(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageStyleTransferNode",
display_name="Magnific Image Style Transfer",
category="api node/image/Magnific",
category="image/partner/Magnific",
description="Transfer the style from a reference image to your input image.",
inputs=[
IO.Image.Input("image", tooltip="The image to apply style transfer to."),
@ -549,7 +549,7 @@ class MagnificImageRelightNode(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageRelightNode",
display_name="Magnific Image Relight",
category="api node/image/Magnific",
category="image/partner/Magnific",
description="Relight an image with lighting adjustments and optional reference-based light transfer.",
inputs=[
IO.Image.Input("image", tooltip="The image to relight."),
@ -789,7 +789,7 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageSkinEnhancerNode",
display_name="Magnific Image Skin Enhancer",
category="api node/image/Magnific",
category="image/partner/Magnific",
description="Skin enhancement for portraits with multiple processing modes.",
inputs=[
IO.Image.Input("image", tooltip="The portrait image to enhance."),

View File

@ -33,7 +33,7 @@ class MeshyTextToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyTextToModelNode",
display_name="Meshy: Text to Model",
category="api node/3d/Meshy",
category="3d/partner/Meshy",
inputs=[
IO.Combo.Input("model", options=["latest"]),
IO.String.Input("prompt", multiline=True, default=""),
@ -145,7 +145,7 @@ class MeshyRefineNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyRefineNode",
display_name="Meshy: Refine Draft Model",
category="api node/3d/Meshy",
category="3d/partner/Meshy",
description="Refine a previously created draft model.",
inputs=[
IO.Combo.Input("model", options=["latest"]),
@ -240,7 +240,7 @@ class MeshyImageToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyImageToModelNode",
display_name="Meshy: Image to Model",
category="api node/3d/Meshy",
category="3d/partner/Meshy",
inputs=[
IO.Combo.Input("model", options=["latest"]),
IO.Image.Input("image"),
@ -405,7 +405,7 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyMultiImageToModelNode",
display_name="Meshy: Multi-Image to Model",
category="api node/3d/Meshy",
category="3d/partner/Meshy",
inputs=[
IO.Combo.Input("model", options=["latest"]),
IO.Autogrow.Input(
@ -575,7 +575,7 @@ class MeshyRigModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyRigModelNode",
display_name="Meshy: Rig Model",
category="api node/3d/Meshy",
category="3d/partner/Meshy",
description="Provides a rigged character in standard formats. "
"Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, "
"or humanoid assets with unclear limb and body structure.",
@ -656,7 +656,7 @@ class MeshyAnimateModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyAnimateModelNode",
display_name="Meshy: Animate Model",
category="api node/3d/Meshy",
category="3d/partner/Meshy",
description="Apply a specific animation action to a previously rigged character.",
inputs=[
IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"),
@ -722,7 +722,7 @@ class MeshyTextureNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyTextureNode",
display_name="Meshy: Texture Model",
category="api node/3d/Meshy",
category="3d/partner/Meshy",
inputs=[
IO.Combo.Input("model", options=["latest"]),
IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"),

View File

@ -101,7 +101,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="MinimaxTextToVideoNode",
display_name="MiniMax Text to Video",
category="api node/video/MiniMax",
category="video/partner/MiniMax",
description="Generates videos synchronously based on a prompt, and optional parameters.",
inputs=[
IO.String.Input(
@ -163,7 +163,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="MinimaxImageToVideoNode",
display_name="MiniMax Image to Video",
category="api node/video/MiniMax",
category="video/partner/MiniMax",
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[
IO.Image.Input(
@ -230,7 +230,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="MinimaxSubjectToVideoNode",
display_name="MiniMax Subject to Video",
category="api node/video/MiniMax",
category="video/partner/MiniMax",
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[
IO.Image.Input(
@ -294,7 +294,7 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="MinimaxHailuoVideoNode",
display_name="MiniMax Hailuo Video",
category="api node/video/MiniMax",
category="video/partner/MiniMax",
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
inputs=[
IO.String.Input(

View File

@ -99,7 +99,7 @@ class OpenAIDalle2(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIDalle2",
display_name="OpenAI DALL·E 2",
category="api node/image/OpenAI",
category="image/partner/OpenAI",
description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
inputs=[
IO.String.Input(
@ -249,7 +249,7 @@ class OpenAIDalle3(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIDalle3",
display_name="OpenAI DALL·E 3",
category="api node/image/OpenAI",
category="image/partner/OpenAI",
description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
inputs=[
IO.String.Input(
@ -371,7 +371,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIGPTImage1",
display_name="OpenAI GPT Image 2",
category="api node/image/OpenAI",
category="image/partner/OpenAI",
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
is_deprecated=True,
inputs=[
@ -695,7 +695,7 @@ class OpenAIGPTImageNodeV2(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIGPTImageNodeV2",
display_name="OpenAI GPT Image 2",
category="api node/image/OpenAI",
category="image/partner/OpenAI",
description="Generates images via OpenAI's GPT Image endpoint.",
inputs=[
IO.String.Input(
@ -962,7 +962,7 @@ class OpenAIChatNode(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIChatNode",
display_name="OpenAI ChatGPT",
category="api node/text/OpenAI",
category="text/partner/OpenAI",
essentials_category="Text Generation",
description="Generate text responses from an OpenAI model.",
inputs=[
@ -1201,7 +1201,7 @@ class OpenAIInputFiles(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIInputFiles",
display_name="OpenAI ChatGPT Input Files",
category="api node/text/OpenAI",
category="text/partner/OpenAI",
description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes.",
inputs=[
IO.Combo.Input(
@ -1248,7 +1248,7 @@ class OpenAIChatConfig(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIChatConfig",
display_name="OpenAI ChatGPT Advanced Options",
category="api node/text/OpenAI",
category="text/partner/OpenAI",
description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.",
inputs=[
IO.Combo.Input(

View File

@ -265,7 +265,7 @@ class OpenRouterLLMNode(IO.ComfyNode):
return IO.Schema(
node_id="OpenRouterLLMNode",
display_name="OpenRouter LLM",
category="api node/text/OpenRouter",
category="text/partner/OpenRouter",
essentials_category="Text Generation",
description=(
"Generate text responses through OpenRouter. Routes to a curated set of popular "

View File

@ -53,7 +53,7 @@ class PixverseTemplateNode(IO.ComfyNode):
return IO.Schema(
node_id="PixverseTemplateNode",
display_name="PixVerse Template",
category="api node/video/PixVerse",
category="video/partner/PixVerse",
inputs=[
IO.Combo.Input("template", options=list(pixverse_templates.keys())),
],
@ -74,7 +74,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video",
category="api node/video/PixVerse",
category="video/partner/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.String.Input(
@ -192,7 +192,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video",
category="api node/video/PixVerse",
category="video/partner/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.Image.Input("image"),
@ -310,7 +310,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video",
category="api node/video/PixVerse",
category="video/partner/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.Image.Input("first_frame"),

View File

@ -62,7 +62,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
return IO.Schema(
node_id="QuiverTextToSVGNode",
display_name="Quiver Text to SVG",
category="api node/image/Quiver",
category="image/partner/Quiver",
description="Generate an SVG from a text prompt using Quiver AI.",
inputs=[
IO.String.Input(
@ -177,7 +177,7 @@ class QuiverImageToSVGNode(IO.ComfyNode):
return IO.Schema(
node_id="QuiverImageToSVGNode",
display_name="Quiver Image to SVG",
category="api node/image/Quiver",
category="image/partner/Quiver",
description="Vectorize a raster image into SVG using Quiver AI.",
inputs=[
IO.Image.Input(

View File

@ -178,7 +178,7 @@ class RecraftColorRGBNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftColorRGB",
display_name="Recraft Color RGB",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Create Recraft Color by choosing specific RGB values.",
inputs=[
IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."),
@ -204,7 +204,7 @@ class RecraftControlsNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftControls",
display_name="Recraft Controls",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Create Recraft Controls for customizing Recraft generation.",
inputs=[
IO.Custom(RecraftIO.COLOR).Input("colors", optional=True),
@ -228,7 +228,7 @@ class RecraftStyleV3RealisticImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftStyleV3RealisticImage",
display_name="Recraft Style - Realistic Image",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Select realistic_image style and optional substyle.",
inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
@ -253,7 +253,7 @@ class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode):
return IO.Schema(
node_id="RecraftStyleV3DigitalIllustration",
display_name="Recraft Style - Digital Illustration",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Select realistic_image style and optional substyle.",
inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
@ -272,7 +272,7 @@ class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode):
return IO.Schema(
node_id="RecraftStyleV3VectorIllustrationNode",
display_name="Recraft Style - Realistic Image",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Select realistic_image style and optional substyle.",
inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
@ -291,7 +291,7 @@ class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode):
return IO.Schema(
node_id="RecraftStyleV3LogoRaster",
display_name="Recraft Style - Logo Raster",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Select realistic_image style and optional substyle.",
inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)),
@ -308,7 +308,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
return IO.Schema(
node_id="RecraftStyleV3InfiniteStyleLibrary",
display_name="Recraft Style - Infinite Style Library",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
inputs=[
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
@ -331,7 +331,7 @@ class RecraftCreateStyleNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftCreateStyleNode",
display_name="Recraft Create Style",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Create a custom style from reference images. "
"Upload 1-5 images to use as style references. "
"Total size of all images is limited to 5 MB.",
@ -400,7 +400,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftTextToImageNode",
display_name="Recraft Text to Image",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Generates images synchronously based on prompt and resolution.",
inputs=[
IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."),
@ -512,7 +512,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftImageToImageNode",
display_name="Recraft Image to Image",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Modify image based on prompt and strength.",
inputs=[
IO.Image.Input("image"),
@ -630,7 +630,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftImageInpaintingNode",
display_name="Recraft Image Inpainting",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Modify image based on prompt and mask.",
inputs=[
IO.Image.Input("image"),
@ -732,7 +732,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftTextToVectorNode",
display_name="Recraft Text to Vector",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Generates SVG synchronously based on prompt and resolution.",
inputs=[
IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True),
@ -832,7 +832,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftVectorizeImageNode",
display_name="Recraft Vectorize Image",
category="api node/image/Recraft",
category="image/partner/Recraft",
essentials_category="Image Tools",
description="Generates SVG synchronously from an input image.",
inputs=[
@ -876,7 +876,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftReplaceBackgroundNode",
display_name="Recraft Replace Background",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Replace background on image, based on provided prompt.",
inputs=[
IO.Image.Input("image"),
@ -963,7 +963,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftRemoveBackgroundNode",
display_name="Recraft Remove Background",
category="api node/image/Recraft",
category="image/partner/Recraft",
essentials_category="Image Tools",
description="Remove background from image, and return processed image and mask.",
inputs=[
@ -1012,7 +1012,7 @@ class RecraftCrispUpscaleNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftCrispUpscaleNode",
display_name="Recraft Crisp Upscale Image",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Upscale image synchronously.\n"
"Enhances a given raster image using crisp upscale tool, "
"increasing image resolution, making the image sharper and cleaner.",
@ -1058,7 +1058,7 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
return IO.Schema(
node_id="RecraftCreativeUpscaleNode",
display_name="Recraft Creative Upscale Image",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Upscale image synchronously.\n"
"Enhances a given raster image using creative upscale tool, "
"boosting resolution with a focus on refining small details and faces.",
@ -1086,7 +1086,7 @@ class RecraftV4TextToImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftV4TextToImageNode",
display_name="Recraft V4 Text to Image",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Generates images using Recraft V4 or V4 Pro models.",
inputs=[
IO.String.Input(
@ -1210,7 +1210,7 @@ class RecraftV4TextToVectorNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftV4TextToVectorNode",
display_name="Recraft V4 Text to Vector",
category="api node/image/Recraft",
category="image/partner/Recraft",
description="Generates SVG using Recraft V4 or V4 Pro models.",
inputs=[
IO.String.Input(

View File

@ -109,7 +109,7 @@ class ReveImageCreateNode(IO.ComfyNode):
return IO.Schema(
node_id="ReveImageCreateNode",
display_name="Reve Image Create",
category="api node/image/Reve",
category="image/partner/Reve",
description="Generate images from text descriptions using Reve.",
inputs=[
IO.String.Input(
@ -200,7 +200,7 @@ class ReveImageEditNode(IO.ComfyNode):
return IO.Schema(
node_id="ReveImageEditNode",
display_name="Reve Image Edit",
category="api node/image/Reve",
category="image/partner/Reve",
description="Edit images using natural language instructions with Reve.",
inputs=[
IO.Image.Input("image", tooltip="The image to edit."),
@ -300,7 +300,7 @@ class ReveImageRemixNode(IO.ComfyNode):
return IO.Schema(
node_id="ReveImageRemixNode",
display_name="Reve Image Remix",
category="api node/image/Reve",
category="image/partner/Reve",
description="Combine reference images with text prompts to create new images using Reve.",
inputs=[
IO.Autogrow.Input(

View File

@ -230,7 +230,7 @@ class Rodin3D_Regular(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Regular",
display_name="Rodin 3D Generate - Regular Generate",
category="api node/3d/Rodin",
category="3d/partner/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -289,7 +289,7 @@ class Rodin3D_Detail(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Detail",
display_name="Rodin 3D Generate - Detail Generate",
category="api node/3d/Rodin",
category="3d/partner/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -348,7 +348,7 @@ class Rodin3D_Smooth(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Smooth",
display_name="Rodin 3D Generate - Smooth Generate",
category="api node/3d/Rodin",
category="3d/partner/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -406,7 +406,7 @@ class Rodin3D_Sketch(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Sketch",
display_name="Rodin 3D Generate - Sketch Generate",
category="api node/3d/Rodin",
category="3d/partner/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -468,7 +468,7 @@ class Rodin3D_Gen2(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Gen2",
display_name="Rodin 3D Generate - Gen-2 Generate",
category="api node/3d/Rodin",
category="3d/partner/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -941,7 +941,7 @@ class Rodin3D_Gen25_Image(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Gen25_Image",
display_name="Rodin 3D Gen-2.5 - Image to 3D",
category="api node/3d/Rodin",
category="3d/partner/Rodin",
description=(
"Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. "
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
@ -1035,7 +1035,7 @@ class Rodin3D_Gen25_Text(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Gen25_Text",
display_name="Rodin 3D Gen-2.5 - Text to 3D",
category="api node/3d/Rodin",
category="3d/partner/Rodin",
description=(
"Generate a 3D model from a text prompt via Rodin Gen-2.5. "
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."

View File

@ -140,7 +140,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
return IO.Schema(
node_id="RunwayImageToVideoNodeGen3a",
display_name="Runway Image to Video (Gen3a Turbo)",
category="api node/video/Runway",
category="video/partner/Runway",
description="Generate a video from a single starting frame using Gen3a Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
@ -234,7 +234,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
return IO.Schema(
node_id="RunwayImageToVideoNodeGen4",
display_name="Runway Image to Video (Gen4 Turbo)",
category="api node/video/Runway",
category="video/partner/Runway",
description="Generate a video from a single starting frame using Gen4 Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
@ -329,7 +329,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="RunwayFirstLastFrameNode",
display_name="Runway First-Last-Frame to Video",
category="api node/video/Runway",
category="video/partner/Runway",
description="Upload first and last keyframes, draft a prompt, and generate a video. "
"More complex transitions, such as cases where the Last frame is completely different "
"from the First frame, may benefit from the longer 10s duration. "
@ -440,7 +440,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RunwayTextToImageNode",
display_name="Runway Text to Image",
category="api node/image/Runway",
category="image/partner/Runway",
description="Generate an image from a text prompt using Runway's Gen 4 model. "
"You can also include reference image to guide the generation.",
inputs=[

View File

@ -34,7 +34,7 @@ class SoniloVideoToMusic(IO.ComfyNode):
return IO.Schema(
node_id="SoniloVideoToMusic",
display_name="Sonilo Video to Music",
category="api node/audio/Sonilo",
category="audio/partner/Sonilo",
description="Generate music from video content using Sonilo's AI model. "
"Analyzes the video and creates matching music.",
inputs=[
@ -99,7 +99,7 @@ class SoniloTextToMusic(IO.ComfyNode):
return IO.Schema(
node_id="SoniloTextToMusic",
display_name="Sonilo Text to Music",
category="api node/audio/Sonilo",
category="audio/partner/Sonilo",
description="Generate music from a text prompt using Sonilo's AI model. "
"Leave duration at 0 to let the model infer it from the prompt.",
inputs=[

View File

@ -34,7 +34,7 @@ class OpenAIVideoSora2(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIVideoSora2",
display_name="OpenAI Sora - Video (DEPRECATED)",
category="api node/video/Sora",
category="video/partner/Sora",
description=(
"OpenAI video and audio generation.\n\n"
"DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. "

View File

@ -62,7 +62,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
return IO.Schema(
node_id="StabilityStableImageUltraNode",
display_name="Stability AI Stable Image Ultra",
category="api node/image/Stability AI",
category="image/partner/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
@ -197,7 +197,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
return IO.Schema(
node_id="StabilityStableImageSD_3_5Node",
display_name="Stability AI Stable Diffusion 3.5 Image",
category="api node/image/Stability AI",
category="image/partner/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
@ -354,7 +354,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
return IO.Schema(
node_id="StabilityUpscaleConservativeNode",
display_name="Stability AI Upscale Conservative",
category="api node/image/Stability AI",
category="image/partner/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
@ -457,7 +457,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
return IO.Schema(
node_id="StabilityUpscaleCreativeNode",
display_name="Stability AI Upscale Creative",
category="api node/image/Stability AI",
category="image/partner/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
@ -578,7 +578,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
return IO.Schema(
node_id="StabilityUpscaleFastNode",
display_name="Stability AI Upscale Fast",
category="api node/image/Stability AI",
category="image/partner/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
@ -630,7 +630,7 @@ class StabilityTextToAudio(IO.ComfyNode):
return IO.Schema(
node_id="StabilityTextToAudio",
display_name="Stability AI Text To Audio",
category="api node/audio/Stability AI",
category="audio/partner/Stability AI",
essentials_category="Audio",
description=cleandoc(cls.__doc__ or ""),
inputs=[
@ -708,7 +708,7 @@ class StabilityAudioToAudio(IO.ComfyNode):
return IO.Schema(
node_id="StabilityAudioToAudio",
display_name="Stability AI Audio To Audio",
category="api node/audio/Stability AI",
category="audio/partner/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(
@ -802,7 +802,7 @@ class StabilityAudioInpaint(IO.ComfyNode):
return IO.Schema(
node_id="StabilityAudioInpaint",
display_name="Stability AI Audio Inpaint",
category="api node/audio/Stability AI",
category="audio/partner/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(

View File

@ -52,7 +52,7 @@ class TopazImageEnhance(IO.ComfyNode):
return IO.Schema(
node_id="TopazImageEnhance",
display_name="Topaz Image Enhance",
category="api node/image/Topaz",
category="image/partner/Topaz",
description="Industry-standard upscaling and image enhancement.",
inputs=[
IO.Combo.Input("model", options=["Reimagine"]),
@ -235,7 +235,7 @@ class TopazVideoEnhance(IO.ComfyNode):
return IO.Schema(
node_id="TopazVideoEnhance",
display_name="Topaz Video Enhance (Legacy)",
category="api node/video/Topaz",
category="video/partner/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
@ -475,7 +475,7 @@ class TopazVideoEnhanceV2(IO.ComfyNode):
return IO.Schema(
node_id="TopazVideoEnhanceV2",
display_name="Topaz Video Enhance",
category="api node/video/Topaz",
category="video/partner/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),

View File

@ -80,7 +80,7 @@ class TripoTextToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoTextToModelNode",
display_name="Tripo: Text to Model",
category="api node/3d/Tripo",
category="3d/partner/Tripo",
inputs=[
IO.String.Input("prompt", multiline=True),
IO.String.Input("negative_prompt", multiline=True, optional=True),
@ -195,7 +195,7 @@ class TripoImageToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoImageToModelNode",
display_name="Tripo: Image to Model",
category="api node/3d/Tripo",
category="3d/partner/Tripo",
inputs=[
IO.Image.Input("image"),
IO.Combo.Input(
@ -323,7 +323,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoMultiviewToModelNode",
display_name="Tripo: Multiview to Model",
category="api node/3d/Tripo",
category="3d/partner/Tripo",
inputs=[
IO.Image.Input("image"),
IO.Image.Input("image_left", optional=True),
@ -461,7 +461,7 @@ class TripoTextureNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoTextureNode",
display_name="Tripo: Texture model",
category="api node/3d/Tripo",
category="3d/partner/Tripo",
inputs=[
IO.Custom("MODEL_TASK_ID").Input("model_task_id"),
IO.Boolean.Input("texture", default=True, optional=True),
@ -528,7 +528,7 @@ class TripoRefineNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoRefineNode",
display_name="Tripo: Refine Draft model",
category="api node/3d/Tripo",
category="3d/partner/Tripo",
description="Refine a draft model created by v1.4 Tripo models only.",
inputs=[
IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"),
@ -568,7 +568,7 @@ class TripoRigNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoRigNode",
display_name="Tripo: Rig model",
category="api node/3d/Tripo",
category="3d/partner/Tripo",
inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
@ -605,7 +605,7 @@ class TripoRetargetNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoRetargetNode",
display_name="Tripo: Retarget rigged model",
category="api node/3d/Tripo",
category="3d/partner/Tripo",
inputs=[
IO.Custom("RIG_TASK_ID").Input("original_model_task_id"),
IO.Combo.Input(
@ -670,7 +670,7 @@ class TripoConversionNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoConversionNode",
display_name="Tripo: Convert model",
category="api node/3d/Tripo",
category="3d/partner/Tripo",
inputs=[
IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"),
IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]),

View File

@ -45,7 +45,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="VeoVideoGenerationNode",
display_name="Google Veo 2 Video Generation",
category="api node/video/Veo",
category="video/partner/Veo",
description="Generates videos from text prompts using Google's Veo 2 API",
inputs=[
IO.String.Input(
@ -256,7 +256,7 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="Veo3VideoGenerationNode",
display_name="Google Veo 3 Video Generation",
category="api node/video/Veo",
category="video/partner/Veo",
description="Generates videos from text prompts using Google's Veo 3 API",
inputs=[
IO.String.Input(
@ -468,7 +468,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="Veo3FirstLastFrameNode",
display_name="Google Veo 3 First-Last-Frame to Video",
category="api node/video/Veo",
category="video/partner/Veo",
description="Generate video using prompt and first and last frames.",
inputs=[
IO.String.Input(

View File

@ -71,7 +71,7 @@ class ViduTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduTextToVideoNode",
display_name="Vidu Text To Video Generation",
category="api node/video/Vidu",
category="video/partner/Vidu",
description="Generate video from a text prompt",
inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -169,7 +169,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduImageToVideoNode",
display_name="Vidu Image To Video Generation",
category="api node/video/Vidu",
category="video/partner/Vidu",
description="Generate video from image and optional prompt",
inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -273,7 +273,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduReferenceVideoNode",
display_name="Vidu Reference To Video Generation",
category="api node/video/Vidu",
category="video/partner/Vidu",
description="Generate video from multiple images and a prompt",
inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -388,7 +388,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduStartEndToVideoNode",
display_name="Vidu Start End To Video Generation",
category="api node/video/Vidu",
category="video/partner/Vidu",
description="Generate a video from start and end frames and a prompt",
inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -492,7 +492,7 @@ class Vidu2TextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu2TextToVideoNode",
display_name="Vidu2 Text-to-Video Generation",
category="api node/video/Vidu",
category="video/partner/Vidu",
description="Generate video from a text prompt",
inputs=[
IO.Combo.Input("model", options=["viduq2"]),
@ -584,7 +584,7 @@ class Vidu2ImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu2ImageToVideoNode",
display_name="Vidu2 Image-to-Video Generation",
category="api node/video/Vidu",
category="video/partner/Vidu",
description="Generate a video from an image and an optional prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
@ -714,7 +714,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu2ReferenceVideoNode",
display_name="Vidu2 Reference-to-Video Generation",
category="api node/video/Vidu",
category="video/partner/Vidu",
description="Generate a video from multiple reference images and a prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2"]),
@ -849,7 +849,7 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu2StartEndToVideoNode",
display_name="Vidu2 Start/End Frame-to-Video Generation",
category="api node/video/Vidu",
category="video/partner/Vidu",
description="Generate a video from a start frame, an end frame, and a prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
@ -969,7 +969,7 @@ class ViduExtendVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduExtendVideoNode",
display_name="Vidu Video Extension",
category="api node/video/Vidu",
category="video/partner/Vidu",
description="Extend an existing video by generating additional frames.",
inputs=[
IO.DynamicCombo.Input(
@ -1138,7 +1138,7 @@ class ViduMultiFrameVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduMultiFrameVideoNode",
display_name="Vidu Multi-Frame Video Generation",
category="api node/video/Vidu",
category="video/partner/Vidu",
description="Generate a video with multiple keyframe transitions.",
inputs=[
IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]),
@ -1284,7 +1284,7 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu3TextToVideoNode",
display_name="Vidu Q3 Text-to-Video Generation",
category="api node/video/Vidu",
category="video/partner/Vidu",
description="Generate video from a text prompt.",
inputs=[
IO.DynamicCombo.Input(
@ -1429,7 +1429,7 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu3ImageToVideoNode",
display_name="Vidu Q3 Image-to-Video Generation",
category="api node/video/Vidu",
category="video/partner/Vidu",
description="Generate a video from an image and an optional prompt.",
inputs=[
IO.DynamicCombo.Input(
@ -1571,7 +1571,7 @@ class Vidu3StartEndToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu3StartEndToVideoNode",
display_name="Vidu Q3 Start/End Frame-to-Video Generation",
category="api node/video/Vidu",
category="video/partner/Vidu",
description="Generate a video from a start frame, an end frame, and a prompt.",
inputs=[
IO.DynamicCombo.Input(

View File

@ -61,7 +61,7 @@ class WanTextToImageApi(IO.ComfyNode):
return IO.Schema(
node_id="WanTextToImageApi",
display_name="Wan Text to Image",
category="api node/image/Wan",
category="image/partner/Wan",
description="Generates an image based on a text prompt.",
inputs=[
IO.Combo.Input(
@ -184,7 +184,7 @@ class WanImageToImageApi(IO.ComfyNode):
return IO.Schema(
node_id="WanImageToImageApi",
display_name="Wan Image to Image",
category="api node/image/Wan",
category="image/partner/Wan",
description="Generates an image from one or two input images and a text prompt. "
"The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
inputs=[
@ -312,7 +312,7 @@ class WanTextToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="WanTextToVideoApi",
display_name="Wan Text to Video",
category="api node/video/Wan",
category="video/partner/Wan",
description="Generates a video based on a text prompt.",
inputs=[
IO.Combo.Input(
@ -495,7 +495,7 @@ class WanImageToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="WanImageToVideoApi",
display_name="Wan Image to Video",
category="api node/video/Wan",
category="video/partner/Wan",
description="Generates a video from the first frame and a text prompt.",
inputs=[
IO.Combo.Input(
@ -674,7 +674,7 @@ class WanReferenceVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="WanReferenceVideoApi",
display_name="Wan Reference to Video",
category="api node/video/Wan",
category="video/partner/Wan",
description="Use the character and voice from input videos, combined with a prompt, "
"to generate a new video that maintains character consistency.",
inputs=[
@ -828,7 +828,7 @@ class Wan2TextToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2TextToVideoApi",
display_name="Wan 2.7 Text to Video",
category="api node/video/Wan",
category="video/partner/Wan",
description="Generates a video based on a text prompt using the Wan 2.7 model.",
inputs=[
IO.DynamicCombo.Input(
@ -981,7 +981,7 @@ class Wan2ImageToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2ImageToVideoApi",
display_name="Wan 2.7 Image to Video",
category="api node/video/Wan",
category="video/partner/Wan",
description="Generate a video from a first-frame image, with optional last-frame image and audio.",
inputs=[
IO.DynamicCombo.Input(
@ -1152,7 +1152,7 @@ class Wan2VideoContinuationApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2VideoContinuationApi",
display_name="Wan 2.7 Video Continuation",
category="api node/video/Wan",
category="video/partner/Wan",
description="Continue a video from where it left off, with optional last-frame control.",
inputs=[
IO.DynamicCombo.Input(
@ -1319,7 +1319,7 @@ class Wan2VideoEditApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2VideoEditApi",
display_name="Wan 2.7 Video Edit",
category="api node/video/Wan",
category="video/partner/Wan",
description="Edit a video using text instructions, reference images, or style transfer.",
inputs=[
IO.DynamicCombo.Input(
@ -1477,7 +1477,7 @@ class Wan2ReferenceVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2ReferenceVideoApi",
display_name="Wan 2.7 Reference to Video",
category="api node/video/Wan",
category="video/partner/Wan",
description="Generate a video featuring a person or object from reference materials. "
"Supports single-character performances and multi-character interactions.",
inputs=[
@ -1651,7 +1651,7 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="HappyHorseTextToVideoApi",
display_name="HappyHorse Text to Video",
category="api node/video/Wan",
category="video/partner/Wan",
description="Generates a video based on a text prompt using the HappyHorse model.",
inputs=[
IO.DynamicCombo.Input(
@ -1775,7 +1775,7 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="HappyHorseImageToVideoApi",
display_name="HappyHorse Image to Video",
category="api node/video/Wan",
category="video/partner/Wan",
description="Generate a video from a first-frame image using the HappyHorse model.",
inputs=[
IO.DynamicCombo.Input(
@ -1905,7 +1905,7 @@ class HappyHorseVideoEditApi(IO.ComfyNode):
return IO.Schema(
node_id="HappyHorseVideoEditApi",
display_name="HappyHorse Video Edit",
category="api node/video/Wan",
category="video/partner/Wan",
description="Edit a video using text instructions or reference images with the HappyHorse model. "
"Output duration is 3-15s and matches the input video; inputs longer than 15s are truncated.",
inputs=[
@ -2046,7 +2046,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="HappyHorseReferenceVideoApi",
display_name="HappyHorse Reference to Video",
category="api node/video/Wan",
category="video/partner/Wan",
description="Generate a video featuring a person or object from reference materials with the HappyHorse "
"model. Supports single-character performances and multi-character interactions.",
inputs=[

View File

@ -27,7 +27,7 @@ class WavespeedFlashVSRNode(IO.ComfyNode):
return IO.Schema(
node_id="WavespeedFlashVSRNode",
display_name="FlashVSR Video Upscale",
category="api node/video/WaveSpeed",
category="video/partner/WaveSpeed",
description="Fast, high-quality video upscaler that "
"boosts resolution and restores clarity for low-resolution or blurry footage.",
inputs=[
@ -98,7 +98,7 @@ class WavespeedImageUpscaleNode(IO.ComfyNode):
return IO.Schema(
node_id="WavespeedImageUpscaleNode",
display_name="WaveSpeed Image Upscale",
category="api node/image/WaveSpeed",
category="image/partner/WaveSpeed",
description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.",
inputs=[
IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]),

View File

@ -86,7 +86,7 @@ class _PollUIState:
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"]
async def sync_op(

View File

@ -11,7 +11,7 @@ class TextEncodeAceStepAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TextEncodeAceStepAudio",
category="conditioning",
category="model/conditioning",
inputs=[
IO.Clip.Input("clip"),
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
@ -33,7 +33,7 @@ class TextEncodeAceStepAudio15(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TextEncodeAceStepAudio1.5",
category="conditioning",
category="model/conditioning",
inputs=[
IO.Clip.Input("clip"),
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
@ -67,7 +67,7 @@ class EmptyAceStepLatentAudio(IO.ComfyNode):
return IO.Schema(
node_id="EmptyAceStepLatentAudio",
display_name="Empty Ace Step 1.0 Latent Audio",
category="latent/audio",
category="model/latent/audio",
inputs=[
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
IO.Int.Input(
@ -90,7 +90,7 @@ class EmptyAceStep15LatentAudio(IO.ComfyNode):
return IO.Schema(
node_id="EmptyAceStep1.5LatentAudio",
display_name="Empty Ace Step 1.5 Latent Audio",
category="latent/audio",
category="model/latent/audio",
inputs=[
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
IO.Int.Input(

View File

@ -45,7 +45,7 @@ class SamplerLCMUpscale(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="SamplerLCMUpscale",
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01, advanced=True),
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1, advanced=True),
@ -91,7 +91,7 @@ class SamplerLCM(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="SamplerLCM",
category="sampling/samplers",
category="model/sampling/samplers",
description=("LCM sampler with tunable per-step noise. s_noise is a multiplier on the model's training noise scale"),
inputs=[
io.Float.Input("s_noise", default=1.0, min=0.0, max=64.0, step=0.01,

View File

@ -29,7 +29,7 @@ class AlignYourStepsScheduler(io.ComfyNode):
return io.Schema(
node_id="AlignYourStepsScheduler",
search_aliases=["AYS scheduler"],
category="sampling/schedulers",
category="model/sampling/schedulers",
inputs=[
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
io.Int.Input("steps", default=10, min=1, max=10000),

View File

@ -16,7 +16,7 @@ class APG(io.ComfyNode):
return io.Schema(
node_id="APG",
display_name="Adaptive Projected Guidance",
category="sampling/custom_sampling",
category="model/sampling/custom_sampling",
inputs=[
io.Model.Input("model"),
io.Float.Input(

View File

@ -19,7 +19,7 @@ class EmptyARVideoLatent(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="EmptyARVideoLatent",
category="latent/video",
category="model/latent/video",
inputs=[
io.Int.Input("width", default=832, min=16, max=8192, step=16),
io.Int.Input("height", default=480, min=16, max=8192, step=16),
@ -53,7 +53,7 @@ class SamplerARVideo(io.ComfyNode):
return io.Schema(
node_id="SamplerARVideo",
display_name="Sampler AR Video",
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[
io.Int.Input(
"num_frame_per_block",
@ -85,7 +85,7 @@ class ARVideoI2V(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ARVideoI2V",
category="conditioning/video_models",
category="model/conditioning/video_models",
inputs=[
io.Model.Input("model"),
io.Vae.Input("vae"),

View File

@ -16,7 +16,7 @@ class EmptyLatentAudio(IO.ComfyNode):
return IO.Schema(
node_id="EmptyLatentAudio",
display_name="Empty Latent Audio",
category="latent/audio",
category="model/latent/audio",
essentials_category="Audio",
inputs=[
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
@ -41,7 +41,7 @@ class ConditioningStableAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="ConditioningStableAudio",
category="conditioning",
category="model/conditioning",
inputs=[
IO.Conditioning.Input("positive"),
IO.Conditioning.Input("negative"),
@ -70,7 +70,7 @@ class VAEEncodeAudio(IO.ComfyNode):
node_id="VAEEncodeAudio",
search_aliases=["audio to latent"],
display_name="VAE Encode Audio",
category="latent/audio",
category="model/latent/audio",
inputs=[
IO.Audio.Input("audio"),
IO.Vae.Input("vae"),
@ -115,7 +115,7 @@ class VAEDecodeAudio(IO.ComfyNode):
node_id="VAEDecodeAudio",
search_aliases=["latent to audio"],
display_name="VAE Decode Audio",
category="latent/audio",
category="model/latent/audio",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
@ -137,7 +137,7 @@ class VAEDecodeAudioTiled(IO.ComfyNode):
node_id="VAEDecodeAudioTiled",
search_aliases=["latent to audio"],
display_name="VAE Decode Audio (Tiled)",
category="latent/audio",
category="model/latent/audio",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),

View File

@ -11,7 +11,7 @@ class AudioEncoderLoader(io.ComfyNode):
return io.Schema(
node_id="AudioEncoderLoader",
display_name="Load Audio Encoder",
category="loaders",
category="model/loaders",
inputs=[
io.Combo.Input(
"audio_encoder_name",
@ -36,7 +36,7 @@ class AudioEncoderEncode(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="AudioEncoderEncode",
category="conditioning",
category="model/conditioning",
inputs=[
io.AudioEncoder.Input("audio_encoder"),
io.Audio.Input("audio"),

View File

@ -11,7 +11,7 @@ class LoadBackgroundRemovalModel(IO.ComfyNode):
return IO.Schema(
node_id="LoadBackgroundRemovalModel",
display_name="Load Background Removal Model",
category="loaders",
category="model/loaders",
inputs=[
IO.Combo.Input("bg_removal_name", options=sorted(files), tooltip="The model used to remove backgrounds from images"),
],

View File

@ -153,7 +153,7 @@ class WanCameraEmbedding(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="WanCameraEmbedding",
category="conditioning/video_models",
category="model/conditioning/video_models",
inputs=[
io.Combo.Input(
"camera_pose",

View File

@ -57,24 +57,55 @@ class CFGNorm(io.ComfyNode):
inputs=[
io.Model.Input("model"),
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
io.Boolean.Input(
"pre_cfg",
default=False,
optional=True,
tooltip=(
"If true, rescale the combined noise BEFORE the sampler's CFG combine, "
"without clamping (can amplify). Matches the norm-scaled CFG used by "
"models like Lens. Default false keeps the original post-CFG x0-space "
"attenuate-only behavior."
),
),
],
outputs=[io.Model.Output(display_name="patched_model")],
is_experimental=True,
)
@classmethod
def execute(cls, model, strength) -> io.NodeOutput:
def execute(cls, model, strength, pre_cfg=False) -> io.NodeOutput:
m = model.clone()
def cfg_norm(args):
cond_p = args['cond_denoised']
pred_text_ = args["denoised"]
if pre_cfg:
def cfg_norm_pre(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]
comb = uncond + cond_scale * (cond - uncond)
cond_norm = torch.linalg.vector_norm(cond, dim=1, keepdim=True)
comb_norm = torch.linalg.vector_norm(comb, dim=1, keepdim=True)
rescale = torch.where(
comb_norm > 0,
cond_norm / comb_norm.clamp_min(1e-12),
torch.ones_like(comb_norm),
)
rescaled = comb * rescale
# strength blends back toward standard linear CFG (1.0 = full rescale).
if strength != 1.0:
rescaled = strength * rescaled + (1.0 - strength) * comb
return rescaled
m.set_model_sampler_cfg_function(cfg_norm_pre)
else:
def cfg_norm(args):
cond_p = args['cond_denoised']
pred_text_ = args["denoised"]
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
return pred_text_ * scale * strength
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
return pred_text_ * scale * strength
m.set_model_sampler_post_cfg_function(cfg_norm)
m.set_model_sampler_post_cfg_function(cfg_norm)
return io.NodeOutput(m)

View File

@ -13,7 +13,7 @@ class EmptyChromaRadianceLatentImage(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="EmptyChromaRadianceLatentImage",
category="latent/chroma_radiance",
category="model/latent/chroma_radiance",
inputs=[
io.Int.Input(id="width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input(id="height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
@ -33,7 +33,7 @@ class ChromaRadianceOptions(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="ChromaRadianceOptions",
category="model_patches/chroma_radiance",
category="model/patch/chroma_radiance",
description="Allows setting advanced options for the Chroma Radiance model.",
inputs=[
io.Model.Input(id="model"),

View File

@ -8,7 +8,7 @@ class ColorToRGBInt(io.ComfyNode):
return io.Schema(
node_id="ColorToRGBInt",
display_name="Color to RGB Int",
category="utils",
category="utilities",
description="Convert a color to a RGB integer value.",
inputs=[
io.Color.Input("color"),

View File

@ -9,7 +9,7 @@ class ContextWindowsManualNode(io.ComfyNode):
return io.Schema(
node_id="ContextWindowsManual",
display_name="Context Windows (Manual)",
category="model_patches",
category="model/patch",
description="Manually set context windows.",
inputs=[
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),

View File

@ -9,7 +9,7 @@ class SetUnionControlNetType(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SetUnionControlNetType",
category="conditioning/controlnet",
category="model/conditioning/controlnet",
inputs=[
io.ControlNet.Input("control_net"),
io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())),
@ -39,7 +39,7 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode):
return io.Schema(
node_id="ControlNetInpaintingAliMamaApply",
search_aliases=["masked controlnet"],
category="conditioning/controlnet",
category="model/conditioning/controlnet",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),

View File

@ -13,7 +13,7 @@ class EmptyCosmosLatentVideo(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="EmptyCosmosLatentVideo",
category="latent/video",
category="model/latent/video",
inputs=[
io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16),
@ -45,7 +45,7 @@ class CosmosImageToVideoLatent(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CosmosImageToVideoLatent",
category="conditioning/inpaint",
category="model/conditioning/inpaint",
inputs=[
io.Vae.Input("vae"),
io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16),
@ -88,7 +88,7 @@ class CosmosPredict2ImageToVideoLatent(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="CosmosPredict2ImageToVideoLatent",
category="conditioning/inpaint",
category="model/conditioning/inpaint",
inputs=[
io.Vae.Input("vae"),
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),

View File

@ -11,7 +11,7 @@ class CurveEditor(io.ComfyNode):
return io.Schema(
node_id="CurveEditor",
display_name="Curve Editor",
category="utils",
category="utilities",
inputs=[
io.Curve.Input("curve"),
io.Histogram.Input("histogram", optional=True),
@ -38,7 +38,7 @@ class ImageHistogram(io.ComfyNode):
return io.Schema(
node_id="ImageHistogram",
display_name="Image Histogram",
category="utils",
category="utilities",
inputs=[
io.Image.Input("image"),
],

View File

@ -17,7 +17,7 @@ class BasicScheduler(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="BasicScheduler",
category="sampling/schedulers",
category="model/sampling/schedulers",
inputs=[
io.Model.Input("model"),
io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES),
@ -47,7 +47,7 @@ class KarrasScheduler(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="KarrasScheduler",
category="sampling/schedulers",
category="model/sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
@ -69,7 +69,7 @@ class ExponentialScheduler(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ExponentialScheduler",
category="sampling/schedulers",
category="model/sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
@ -90,7 +90,7 @@ class PolyexponentialScheduler(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="PolyexponentialScheduler",
category="sampling/schedulers",
category="model/sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
@ -112,7 +112,7 @@ class LaplaceScheduler(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LaplaceScheduler",
category="sampling/schedulers",
category="model/sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
@ -136,7 +136,7 @@ class SDTurboScheduler(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SDTurboScheduler",
category="sampling/schedulers",
category="model/sampling/schedulers",
inputs=[
io.Model.Input("model"),
io.Int.Input("steps", default=1, min=1, max=10),
@ -160,7 +160,7 @@ class BetaSamplingScheduler(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="BetaSamplingScheduler",
category="sampling/schedulers",
category="model/sampling/schedulers",
inputs=[
io.Model.Input("model"),
io.Int.Input("steps", default=20, min=1, max=10000),
@ -182,7 +182,7 @@ class VPScheduler(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="VPScheduler",
category="sampling/schedulers",
category="model/sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), #TODO: fix default values
@ -204,7 +204,7 @@ class SplitSigmas(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SplitSigmas",
category="sampling/sigmas",
category="model/sampling/sigmas",
inputs=[
io.Sigmas.Input("sigmas"),
io.Int.Input("step", default=0, min=0, max=10000),
@ -228,7 +228,7 @@ class SplitSigmasDenoise(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SplitSigmasDenoise",
category="sampling/sigmas",
category="model/sampling/sigmas",
inputs=[
io.Sigmas.Input("sigmas"),
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
@ -254,7 +254,7 @@ class FlipSigmas(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="FlipSigmas",
category="sampling/sigmas",
category="model/sampling/sigmas",
inputs=[io.Sigmas.Input("sigmas")],
outputs=[io.Sigmas.Output()]
)
@ -276,7 +276,7 @@ class SetFirstSigma(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SetFirstSigma",
category="sampling/sigmas",
category="model/sampling/sigmas",
inputs=[
io.Sigmas.Input("sigmas"),
io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False),
@ -298,7 +298,7 @@ class ExtendIntermediateSigmas(io.ComfyNode):
return io.Schema(
node_id="ExtendIntermediateSigmas",
search_aliases=["interpolate sigmas"],
category="sampling/sigmas",
category="model/sampling/sigmas",
inputs=[
io.Sigmas.Input("sigmas"),
io.Int.Input("steps", default=2, min=1, max=100),
@ -351,7 +351,7 @@ class SamplingPercentToSigma(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplingPercentToSigma",
category="sampling/sigmas",
category="model/sampling/sigmas",
inputs=[
io.Model.Input("model"),
io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001),
@ -379,7 +379,7 @@ class KSamplerSelect(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="KSamplerSelect",
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES)],
outputs=[io.Sampler.Output()]
)
@ -396,7 +396,7 @@ class SamplerDPMPP_3M_SDE(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerDPMPP_3M_SDE",
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
@ -421,7 +421,7 @@ class SamplerDPMPP_2M_SDE(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerDPMPP_2M_SDE",
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=['midpoint', 'heun']),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
@ -448,7 +448,7 @@ class SamplerDPMPP_SDE(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerDPMPP_SDE",
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
@ -474,7 +474,7 @@ class SamplerDPMPP_2S_Ancestral(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerDPMPP_2S_Ancestral",
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
@ -494,7 +494,7 @@ class SamplerEulerAncestral(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerEulerAncestral",
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
@ -515,7 +515,7 @@ class SamplerEulerAncestralCFGPP(io.ComfyNode):
return io.Schema(
node_id="SamplerEulerAncestralCFGPP",
display_name="SamplerEulerAncestralCFG++",
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[
io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False),
io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False),
@ -537,7 +537,7 @@ class SamplerLMS(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerLMS",
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[io.Int.Input("order", default=4, min=1, max=100, advanced=True)],
outputs=[io.Sampler.Output()]
)
@ -554,7 +554,7 @@ class SamplerDPMAdaptative(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerDPMAdaptative",
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[
io.Int.Input("order", default=3, min=2, max=3, advanced=True),
io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
@ -585,7 +585,7 @@ class SamplerER_SDE(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerER_SDE",
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]),
io.Int.Input("max_stage", default=3, min=1, max=3, advanced=True),
@ -623,7 +623,7 @@ class SamplerSASolver(io.ComfyNode):
return io.Schema(
node_id="SamplerSASolver",
search_aliases=["sde"],
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[
io.Model.Input("model"),
io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False, advanced=True),
@ -668,7 +668,7 @@ class SamplerSEEDS2(io.ComfyNode):
return io.Schema(
node_id="SamplerSEEDS2",
search_aliases=["sde", "exp heun"],
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength", advanced=True),
@ -727,7 +727,7 @@ class SamplerCustom(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerCustom",
category="sampling/custom_sampling",
category="model/sampling/custom_sampling",
inputs=[
io.Model.Input("model"),
io.Boolean.Input("add_noise", default=True, advanced=True),
@ -795,7 +795,7 @@ class BasicGuider(io.ComfyNode):
return io.Schema(
node_id="BasicGuider",
display_name="Basic Guider",
category="sampling/guiders",
category="model/sampling/guiders",
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("conditioning"),
@ -817,7 +817,7 @@ class CFGGuider(io.ComfyNode):
return io.Schema(
node_id="CFGGuider",
display_name="CFG Guider",
category="sampling/guiders",
category="model/sampling/guiders",
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("positive"),
@ -872,7 +872,7 @@ class DualCFGGuider(io.ComfyNode):
node_id="DualCFGGuider",
search_aliases=["dual prompt guidance"],
display_name="Dual CFG Guider",
category="sampling/guiders",
category="model/sampling/guiders",
inputs=[
io.Model.Input("model"),
io.Conditioning.Input("cond1"),
@ -900,7 +900,7 @@ class DisableNoise(io.ComfyNode):
return io.Schema(
node_id="DisableNoise",
search_aliases=["zero noise"],
category="sampling/noise",
category="model/sampling/noise",
inputs=[],
outputs=[io.Noise.Output()]
)
@ -917,7 +917,7 @@ class RandomNoise(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="RandomNoise",
category="sampling/noise",
category="model/sampling/noise",
inputs=[io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True)],
outputs=[io.Noise.Output()]
)
@ -934,7 +934,7 @@ class SamplerCustomAdvanced(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SamplerCustomAdvanced",
category="sampling/custom_sampling",
category="model/sampling/custom_sampling",
inputs=[
io.Noise.Input("noise"),
io.Guider.Input("guider"),

View File

@ -157,7 +157,7 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
return io.NodeOutput(output_tensor, captions)
def save_images_to_folder(image_list, output_dir, prefix="image"):
def save_images_to_folder(image_list, output_dir, prefix="image", overwrite=True):
"""Utility function to save a list of image tensors to disk.
Args:
@ -197,7 +197,11 @@ def save_images_to_folder(image_list, output_dir, prefix="image"):
raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}")
# Save image
filename = f"{prefix}_{idx:05d}.png"
if overwrite:
filename = f"{prefix}_{idx:05d}.png"
else:
_, _, counter, _, resolved_prefix = folder_paths.get_save_image_path(prefix, output_dir)
filename = f"{resolved_prefix}_{counter:05}_{idx:05d}.png"
filepath = os.path.join(output_dir, filename)
img.save(filepath)
saved_files.append(filename)
@ -230,19 +234,26 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
tooltip="Prefix for saved image filenames.",
advanced=True,
),
io.Combo.Input(
"mode",
default="overwrite",
options=["overwrite", "increment"],
tooltip="Whether to overwrite existing files or increment filenames to avoid overwriting."
),
],
outputs=[],
is_deprecated=True, # This node is redundant and superseded by existing Save Image nodes where the target folder can be specified in the filename_prefix
)
@classmethod
def execute(cls, images, folder_name, filename_prefix):
def execute(cls, images, folder_name, filename_prefix, mode):
# Extract scalar values
folder_name = folder_name[0]
filename_prefix = filename_prefix[0]
mode = mode[0]
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
saved_files = save_images_to_folder(images, output_dir, filename_prefix)
saved_files = save_images_to_folder(images, output_dir, filename_prefix, mode=='overwrite')
logging.info(f"Saved {len(saved_files)} images to {output_dir}.")
return io.NodeOutput()
@ -278,18 +289,25 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
tooltip="Prefix for saved image filenames.",
advanced=True,
),
io.Combo.Input(
"mode",
default="overwrite",
options=["overwrite", "increment"],
tooltip="Whether to overwrite existing files or increment filenames to avoid overwriting."
),
],
outputs=[],
)
@classmethod
def execute(cls, images, folder_name, filename_prefix, texts=None):
def execute(cls, images, folder_name, filename_prefix, mode, texts=None):
# Extract scalar values
folder_name = folder_name[0]
filename_prefix = filename_prefix[0]
mode = mode[0]
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
saved_files = save_images_to_folder(images, output_dir, filename_prefix)
saved_files = save_images_to_folder(images, output_dir, filename_prefix, mode=='overwrite')
# Save captions
if texts:
@ -574,7 +592,7 @@ class TextProcessingNode(io.ComfyNode):
return io.Schema(
node_id=cls.node_id,
display_name=cls.display_name or cls.node_id,
category="dataset/text",
category="text",
is_experimental=True,
is_input_list=is_group, # True for group, False for individual
inputs=inputs,
@ -1208,7 +1226,7 @@ class ResolutionBucket(io.ComfyNode):
node_id="ResolutionBucket",
search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"],
display_name="Resolution Bucket",
category="training",
category="model/training",
description="Group latents and conditionings into buckets",
is_experimental=True,
is_input_list=True,
@ -1302,7 +1320,7 @@ class MakeTrainingDataset(io.ComfyNode):
node_id="MakeTrainingDataset",
search_aliases=["encode dataset"],
display_name="Make Training Dataset",
category="training",
category="model/training",
description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.",
is_experimental=True,
is_input_list=True, # images and texts as lists
@ -1390,7 +1408,7 @@ class SaveTrainingDataset(io.ComfyNode):
node_id="SaveTrainingDataset",
search_aliases=["export dataset", "save dataset"],
display_name="Save Training Dataset",
category="training",
category="model/training",
description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.",
is_experimental=True,
is_output_node=True,
@ -1493,7 +1511,7 @@ class LoadTrainingDataset(io.ComfyNode):
node_id="LoadTrainingDataset",
search_aliases=["import dataset", "training data"],
display_name="Load Training Dataset",
category="training",
category="model/training",
description="Load encoded training dataset (latents + conditioning) from disk for use in training.",
is_experimental=True,
inputs=[

View File

@ -18,7 +18,7 @@ class EpsilonScaling(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Epsilon Scaling",
category="model_patches/unet",
category="model/patch/unet",
inputs=[
io.Model.Input("model"),
io.Float.Input(
@ -84,7 +84,7 @@ class TemporalScoreRescaling(io.ComfyNode):
return io.Schema(
node_id="TemporalScoreRescaling",
display_name="TSR - Temporal Score Rescaling",
category="model_patches/unet",
category="model/patch/unet",
inputs=[
io.Model.Input("model"),
io.Float.Input(

View File

@ -40,7 +40,7 @@ class EmptyFlux2LatentImage(io.ComfyNode):
return io.Schema(
node_id="EmptyFlux2LatentImage",
display_name="Empty Flux 2 Latent",
category="latent",
category="model/latent",
inputs=[
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
@ -215,7 +215,7 @@ class Flux2Scheduler(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Flux2Scheduler",
category="sampling/schedulers",
category="model/sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=4096),
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),

View File

@ -19,7 +19,7 @@ class FrameInterpolationModelLoader(io.ComfyNode):
return io.Schema(
node_id="FrameInterpolationModelLoader",
display_name="Load Frame Interpolation Model",
category="loaders",
category="model/loaders",
inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation"),
tooltip="Select a frame interpolation model to load. Models must be placed in the 'frame_interpolation' folder."),

View File

@ -29,7 +29,7 @@ class FreeU(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="FreeU",
category="model_patches/unet",
category="model/patch/unet",
inputs=[
IO.Model.Input("model"),
IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01, advanced=True),
@ -76,7 +76,7 @@ class FreeU_V2(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="FreeU_V2",
category="model_patches/unet",
category="model/patch/unet",
inputs=[
IO.Model.Input("model"),
IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01, advanced=True),

View File

@ -340,7 +340,7 @@ class GITSScheduler(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="GITSScheduler",
category="sampling/schedulers",
category="model/sampling/schedulers",
inputs=[
io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05, advanced=True),
io.Int.Input("steps", default=10, min=2, max=1000),

View File

@ -14,7 +14,7 @@ class EmptyHiDreamO1LatentImage(io.ComfyNode):
return io.Schema(
node_id="EmptyHiDreamO1LatentImage",
display_name="Empty HiDream-O1 Latent Image",
category="latent/image",
category="model/latent/image",
description=(
"Empty pixel-space latent for HiDream-O1-Image. The model was "
"trained at ~4 megapixels; lower resolutions go off-distribution "
@ -47,7 +47,7 @@ class HiDreamO1ReferenceImages(io.ComfyNode):
return io.Schema(
node_id="HiDreamO1ReferenceImages",
display_name="HiDream-O1 Reference Images",
category="conditioning/image",
category="model/conditioning/image",
description=(
"Attach 1-10 reference images to conditioning, one for edit instruction"
"or multiple for subject-driven personalization."

View File

@ -41,7 +41,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
return io.Schema(
node_id="EmptyHunyuanLatentVideo",
display_name="Empty HunyuanVideo 1.0 Latent",
category="latent/video",
category="model/latent/video",
inputs=[
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
@ -81,7 +81,7 @@ class HunyuanVideo15ImageToVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="HunyuanVideo15ImageToVideo",
category="conditioning/video_models",
category="model/conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
@ -132,7 +132,7 @@ class HunyuanVideo15SuperResolution(io.ComfyNode):
return io.Schema(
node_id="HunyuanVideo15SuperResolution",
display_name="Hunyuan Video 1.5 Super Resolution",
category="conditioning/video_models",
category="model/conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
@ -178,7 +178,7 @@ class LatentUpscaleModelLoader(io.ComfyNode):
return io.Schema(
node_id="LatentUpscaleModelLoader",
display_name="Load Latent Upscale Model",
category="loaders",
category="model/loaders",
inputs=[
io.Combo.Input("model_name", options=folder_paths.get_filename_list("latent_upscale_models")),
],
@ -227,7 +227,7 @@ class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode):
return io.Schema(
node_id="HunyuanVideo15LatentUpscaleWithModel",
display_name="Hunyuan Video 15 Latent Upscale With Model",
category="latent",
category="model/latent",
inputs=[
io.LatentUpscaleModel.Input("model"),
io.Latent.Input("samples"),
@ -308,7 +308,7 @@ class HunyuanImageToVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="HunyuanImageToVideo",
category="conditioning/video_models",
category="model/conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Vae.Input("vae"),
@ -359,7 +359,7 @@ class EmptyHunyuanImageLatent(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="EmptyHunyuanImageLatent",
category="latent",
category="model/latent",
inputs=[
io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
@ -384,7 +384,7 @@ class HunyuanRefinerLatent(io.ComfyNode):
return io.Schema(
node_id="HunyuanRefinerLatent",
display_name="Hunyuan Latent Refiner",
category="conditioning/video_models",
category="model/conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),

View File

@ -12,7 +12,7 @@ class EmptyLatentHunyuan3Dv2(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="EmptyLatentHunyuan3Dv2",
category="latent/3d",
category="model/latent/3d",
inputs=[
IO.Int.Input("resolution", default=3072, min=1, max=8192),
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
@ -35,7 +35,7 @@ class Hunyuan3Dv2Conditioning(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="Hunyuan3Dv2Conditioning",
category="conditioning/3d_models",
category="model/conditioning/3d_models",
inputs=[
IO.ClipVisionOutput.Input("clip_vision_output"),
],
@ -60,7 +60,7 @@ class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="Hunyuan3Dv2ConditioningMultiView",
category="conditioning/3d_models",
category="model/conditioning/3d_models",
inputs=[
IO.ClipVisionOutput.Input("front", optional=True),
IO.ClipVisionOutput.Input("left", optional=True),
@ -97,7 +97,7 @@ class VAEDecodeHunyuan3D(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeHunyuan3D",
category="latent/3d",
category="model/latent/3d",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),

View File

@ -103,7 +103,7 @@ class HypernetworkLoader(IO.ComfyNode):
return IO.Schema(
node_id="HypernetworkLoader",
display_name="Load Hypernetwork",
category="loaders",
category="model/loaders",
inputs=[
IO.Model.Input("model"),
IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),

View File

@ -27,7 +27,7 @@ class HyperTile(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="HyperTile",
category="model_patches/unet",
category="model/patch/unet",
inputs=[
io.Model.Input("model"),
io.Int.Input("tile_size", default=256, min=1, max=2048, advanced=True),

View File

@ -95,7 +95,7 @@ class BoundingBox(IO.ComfyNode):
return IO.Schema(
node_id="PrimitiveBoundingBox",
display_name="Bounding Box",
category="utils/primitive",
category="utilities/primitive",
inputs=[
IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION),
IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION),

View File

@ -9,7 +9,7 @@ class InstructPixToPixConditioning(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="InstructPixToPixConditioning",
category="conditioning/instructpix2pix",
category="model/conditioning/instructpix2pix",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),

View File

@ -13,7 +13,7 @@ class Kandinsky5ImageToVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="Kandinsky5ImageToVideo",
category="conditioning/video_models",
category="model/conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
@ -71,7 +71,7 @@ class NormalizeVideoLatentStart(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="NormalizeVideoLatentStart",
category="conditioning/video_models",
category="model/conditioning/video_models",
description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.",
inputs=[
io.Latent.Input("latent"),

View File

@ -22,7 +22,7 @@ class LatentAdd(io.ComfyNode):
return io.Schema(
node_id="LatentAdd",
search_aliases=["combine latents", "sum latents"],
category="latent/advanced",
category="model/latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
@ -49,7 +49,7 @@ class LatentSubtract(io.ComfyNode):
return io.Schema(
node_id="LatentSubtract",
search_aliases=["difference latent", "remove features"],
category="latent/advanced",
category="model/latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
@ -76,7 +76,7 @@ class LatentMultiply(io.ComfyNode):
return io.Schema(
node_id="LatentMultiply",
search_aliases=["scale latent", "amplify latent", "latent gain"],
category="latent/advanced",
category="model/latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
@ -100,7 +100,7 @@ class LatentInterpolate(io.ComfyNode):
return io.Schema(
node_id="LatentInterpolate",
search_aliases=["blend latent", "mix latent", "lerp latent", "transition"],
category="latent/advanced",
category="model/latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
@ -139,7 +139,7 @@ class LatentConcat(io.ComfyNode):
return io.Schema(
node_id="LatentConcat",
search_aliases=["join latents", "stitch latents"],
category="latent/advanced",
category="model/latent/advanced",
inputs=[
io.Latent.Input("samples1"),
io.Latent.Input("samples2"),
@ -179,7 +179,7 @@ class LatentCut(io.ComfyNode):
return io.Schema(
node_id="LatentCut",
search_aliases=["crop latent", "slice latent", "extract region"],
category="latent/advanced",
category="model/latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("dim", options=["x", "y", "t"]),
@ -220,7 +220,7 @@ class LatentCutToBatch(io.ComfyNode):
return io.Schema(
node_id="LatentCutToBatch",
search_aliases=["slice to batch", "split latent", "tile latent"],
category="latent/advanced",
category="model/latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("dim", options=["t", "x", "y"]),
@ -262,7 +262,7 @@ class LatentBatch(io.ComfyNode):
return io.Schema(
node_id="LatentBatch",
search_aliases=["combine latents", "merge latents", "join latents"],
category="latent/batch",
category="model/latent/batch",
is_deprecated=True,
inputs=[
io.Latent.Input("samples1"),
@ -290,7 +290,7 @@ class LatentBatchSeedBehavior(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentBatchSeedBehavior",
category="latent/advanced",
category="model/latent/advanced",
inputs=[
io.Latent.Input("samples"),
io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
@ -319,7 +319,7 @@ class LatentApplyOperation(io.ComfyNode):
return io.Schema(
node_id="LatentApplyOperation",
search_aliases=["transform latent"],
category="latent/advanced/operations",
category="model/latent/advanced/operations",
is_experimental=True,
inputs=[
io.Latent.Input("samples"),
@ -343,7 +343,7 @@ class LatentApplyOperationCFG(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentApplyOperationCFG",
category="latent/advanced/operations",
category="model/latent/advanced/operations",
is_experimental=True,
inputs=[
io.Model.Input("model"),
@ -375,7 +375,7 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
return io.Schema(
node_id="LatentOperationTonemapReinhard",
search_aliases=["hdr latent"],
category="latent/advanced/operations",
category="model/latent/advanced/operations",
is_experimental=True,
inputs=[
io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
@ -410,7 +410,7 @@ class LatentOperationSharpen(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LatentOperationSharpen",
category="latent/advanced/operations",
category="model/latent/advanced/operations",
is_experimental=True,
inputs=[
io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1, advanced=True),
@ -447,7 +447,7 @@ class ReplaceVideoLatentFrames(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="ReplaceVideoLatentFrames",
category="latent/batch",
category="model/latent/batch",
inputs=[
io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."),
io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."),

View File

@ -34,7 +34,7 @@ class Load3D(IO.ComfyNode):
essentials_category="Basics",
is_experimental=True,
inputs=[
IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
IO.Combo.Input("model_file", options=["none"] + sorted(files), upload=IO.UploadType.model),
IO.Load3D.Input("image"),
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
@ -68,8 +68,12 @@ class Load3D(IO.ComfyNode):
video = InputImpl.VideoFromFile(recording_video_path)
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d)
file_3d = None
mesh_path = ""
if model_file and model_file != "none":
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
mesh_path = model_file
return IO.NodeOutput(output_image, output_mask, mesh_path, normal_image, image['camera_info'], video, file_3d)
process = execute # TODO: remove

View File

@ -13,7 +13,7 @@ class NotNode(io.ComfyNode):
return io.Schema(
node_id="ComfyNotNode",
display_name="Not",
category="utils/logic",
category="utilities/logic",
description="Logical NOT operation. Returns true if the value is falsy. Uses Python's rules for truthiness.",
search_aliases=["invert", "toggle", "negate", "flip boolean"],
inputs=[
@ -40,7 +40,7 @@ class AndNode(io.ComfyNode):
return io.Schema(
node_id="ComfyAndNode",
display_name="And",
category="utils/logic",
category="utilities/logic",
description="Logical AND operation. Returns true if all of the values are truthy. Uses Python's rules for truthiness.",
search_aliases=["all", "every"],
inputs=[
@ -67,7 +67,7 @@ class OrNode(io.ComfyNode):
return io.Schema(
node_id="ComfyOrNode",
display_name="Or",
category="utils/logic",
category="utilities/logic",
description="Logical OR operation. Returns true if any of the values are truthy. Uses Python's rules for truthiness.",
search_aliases=["any", "some"],
inputs=[
@ -90,7 +90,7 @@ class SwitchNode(io.ComfyNode):
return io.Schema(
node_id="ComfySwitchNode",
display_name="Switch",
category="utils/logic",
category="utilities/logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
@ -121,7 +121,7 @@ class SoftSwitchNode(io.ComfyNode):
return io.Schema(
node_id="ComfySoftSwitchNode",
display_name="Soft Switch",
category="utils/logic",
category="utilities/logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
@ -176,7 +176,7 @@ class CustomComboNode(io.ComfyNode):
return io.Schema(
node_id="CustomCombo",
display_name="Custom Combo",
category="utils",
category="utilities",
is_experimental=True,
inputs=[io.Combo.Input("choice", options=[])],
outputs=[
@ -211,7 +211,7 @@ class DCTestNode(io.ComfyNode):
return io.Schema(
node_id="DCTestNode",
display_name="DCTest",
category="utils/logic",
category="utilities/logic",
is_output_node=True,
inputs=[io.DynamicCombo.Input("combo", options=[
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
@ -249,7 +249,7 @@ class AutogrowNamesTestNode(io.ComfyNode):
return io.Schema(
node_id="AutogrowNamesTestNode",
display_name="AutogrowNamesTest",
category="utils/logic",
category="utilities/logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
@ -269,7 +269,7 @@ class AutogrowPrefixTestNode(io.ComfyNode):
return io.Schema(
node_id="AutogrowPrefixTestNode",
display_name="AutogrowPrefixTest",
category="utils/logic",
category="utilities/logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
@ -288,7 +288,7 @@ class ComboOutputTestNode(io.ComfyNode):
return io.Schema(
node_id="ComboOptionTestNode",
display_name="ComboOptionTest",
category="utils/logic",
category="utilities/logic",
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
outputs=[io.Combo.Output(), io.Combo.Output()],
@ -305,7 +305,7 @@ class ConvertStringToComboNode(io.ComfyNode):
node_id="ConvertStringToComboNode",
search_aliases=["string to dropdown", "text to combo"],
display_name="Convert String to Combo",
category="utils/logic",
category="utilities/logic",
inputs=[io.String.Input("string")],
outputs=[io.Combo.Output()],
)
@ -321,7 +321,7 @@ class InvertBooleanNode(io.ComfyNode):
node_id="InvertBooleanNode",
search_aliases=["not", "toggle", "negate", "flip boolean"],
display_name="Invert Boolean",
category="utils/logic",
category="utilities/logic",
inputs=[io.Boolean.Input("boolean")],
outputs=[io.Boolean.Output()],
)

View File

@ -30,7 +30,7 @@ class LoraLoaderBypass:
OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
FUNCTION = "load_lora"
CATEGORY = "loaders"
CATEGORY = "model/loaders"
DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios."
EXPERIMENTAL = True

View File

@ -10,7 +10,7 @@ class LotusConditioning(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LotusConditioning",
category="conditioning/lotus",
category="model/conditioning/lotus",
inputs=[],
outputs=[io.Conditioning.Output(display_name="conditioning")],
)

View File

@ -25,7 +25,7 @@ class GetICLoRAParameters(io.ComfyNode):
display_name="Get IC-LoRA Parameters",
description="Extracts IC-LoRA parameters from the safetensors metadata of a LoRA-loaded "
"model and outputs them for LTXVAddGuide (eg. reference_downscale_factor).",
category="conditioning/video_models",
category="model/conditioning/video_models",
search_aliases=["ic-lora", "ic lora", "iclora", "downscale factor", "reference downscale"],
inputs=[
io.Model.Input(
@ -62,7 +62,7 @@ class EmptyLTXVLatentVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="EmptyLTXVLatentVideo",
category="latent/video/ltxv",
category="model/latent/video/ltxv",
inputs=[
io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
@ -86,7 +86,7 @@ class LTXVImgToVideo(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LTXVImgToVideo",
category="conditioning/video_models",
category="model/conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
@ -131,7 +131,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LTXVImgToVideoInplace",
category="conditioning/video_models",
category="model/conditioning/video_models",
inputs=[
io.Vae.Input("vae"),
io.Image.Input("image"),
@ -226,10 +226,20 @@ def get_noise_mask(latent):
noise_mask = noise_mask.clone()
return noise_mask
def get_keyframe_idxs(cond):
def get_keyframe_idxs(cond, latent_shape=None):
    """Fetch the keyframe index tensor from ``cond`` and count its keyframes.

    Args:
        cond: Conditioning queried through ``conditioning_get_any_value``.
        latent_shape: Optional shape of the latent. It is only used when it
            has exactly five entries.

    Returns:
        A ``(keyframe_idxs, num_keyframes)`` pair. When the conditioning has
        no ``"keyframe_idxs"`` value the result is ``(None, 0)``.

    The count comes from the first source that is available:
        1. ``latent_shape`` (5 entries): ``keyframe_idxs.shape[2]`` floor-divided
           by the product of the last two entries of ``latent_shape``.
        2. ``"guide_attention_entries"`` in ``cond``: the sum of
           ``entry["latent_shape"][0]`` over all entries.
        3. Otherwise: the number of distinct values in
           ``keyframe_idxs[:, 0, :, 0]``.
    """
    idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
    if idxs is None:
        return None, 0

    if latent_shape is not None and len(latent_shape) == 5:
        # Tokens per latent frame; presumably the last two dims are the
        # spatial height and width of the latent -- TODO confirm.
        per_frame = latent_shape[-2] * latent_shape[-1]
        return idxs, idxs.shape[2] // per_frame

    guide_entries = conditioning_get_any_value(cond, "guide_attention_entries", None)
    if guide_entries:
        return idxs, sum(entry["latent_shape"][0] for entry in guide_entries)

    # Fallback; may under-count if keyframes share the same t-start.
    # The last dimension of keyframe_idxs holds start/end positions, so only
    # the unique start values are counted here.
    starts = idxs[:, 0, :, 0]
    return idxs, torch.unique(starts).shape[0]
@ -241,7 +251,7 @@ class LTXVAddGuide(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LTXVAddGuide",
category="conditioning/video_models",
category="model/conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
@ -322,9 +332,9 @@ class LTXVAddGuide(io.ComfyNode):
return factor
@classmethod
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors, latent_shape=None):
time_scale_factor, _, _ = scale_factors
_, num_keyframes = get_keyframe_idxs(cond)
_, num_keyframes = get_keyframe_idxs(cond, latent_shape)
latent_count = latent_length - num_keyframes
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
if guide_length > 1 and frame_idx != 0:
@ -436,7 +446,7 @@ class LTXVAddGuide(io.ComfyNode):
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
resolved_frame_idx = frame_idx
if frame_idx < 0:
_, num_keyframes = get_keyframe_idxs(positive)
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0)
causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1
@ -454,7 +464,7 @@ class LTXVAddGuide(io.ComfyNode):
if latent_downscale_factor > 1:
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors, latent_shape=latent_image.shape)
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
positive, negative, latent_image, noise_mask = cls.append_keyframe(
@ -488,7 +498,7 @@ class LTXVCropGuides(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LTXVCropGuides",
category="conditioning/video_models",
category="model/conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
@ -506,7 +516,7 @@ class LTXVCropGuides(io.ComfyNode):
latent_image = latent["samples"].clone()
noise_mask = get_noise_mask(latent)
_, num_keyframes = get_keyframe_idxs(positive)
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
if num_keyframes == 0:
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
@ -532,7 +542,7 @@ class LTXVConditioning(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LTXVConditioning",
category="conditioning/video_models",
category="model/conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
@ -601,7 +611,7 @@ class LTXVScheduler(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LTXVScheduler",
category="sampling/schedulers",
category="model/sampling/schedulers",
inputs=[
io.Int.Input("steps", default=20, min=1, max=10000),
io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
@ -736,7 +746,7 @@ class LTXVConcatAVLatent(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LTXVConcatAVLatent",
category="latent/video/ltxv",
category="model/latent/video/ltxv",
inputs=[
io.Latent.Input("video_latent"),
io.Latent.Input("audio_latent"),
@ -771,7 +781,7 @@ class LTXVSeparateAVLatent(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="LTXVSeparateAVLatent",
category="latent/video/ltxv",
category="model/latent/video/ltxv",
description="LTXV Separate AV Latent",
inputs=[
io.Latent.Input("av_latent"),
@ -804,7 +814,7 @@ class LTXVReferenceAudio(io.ComfyNode):
return io.Schema(
node_id="LTXVReferenceAudio",
display_name="LTXV Reference Audio (ID-LoRA)",
category="conditioning/audio",
category="model/conditioning/audio",
description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
inputs=[
io.Model.Input("model"),

View File

@ -12,7 +12,7 @@ class LTXVAudioVAELoader(io.ComfyNode):
return io.Schema(
node_id="LTXVAudioVAELoader",
display_name="Load LTXV Audio VAE",
category="loaders",
category="model/loaders",
inputs=[
io.Combo.Input(
"ckpt_name",
@ -40,7 +40,7 @@ class LTXVAudioVAEEncode(VAEEncodeAudio):
return io.Schema(
node_id="LTXVAudioVAEEncode",
display_name="LTXV Audio VAE Encode",
category="latent/audio",
category="model/latent/audio",
inputs=[
io.Audio.Input("audio", tooltip="The audio to be encoded."),
io.Vae.Input(
@ -63,7 +63,7 @@ class LTXVAudioVAEDecode(io.ComfyNode):
return io.Schema(
node_id="LTXVAudioVAEDecode",
display_name="LTXV Audio VAE Decode",
category="latent/audio",
category="model/latent/audio",
inputs=[
io.Latent.Input("samples", tooltip="The latent to be decoded."),
io.Vae.Input(
@ -96,7 +96,7 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
return io.Schema(
node_id="LTXVEmptyLatentAudio",
display_name="LTXV Empty Latent Audio",
category="latent/audio",
category="model/latent/audio",
inputs=[
io.Int.Input(
"frames_number",

View File

@ -1,32 +1,32 @@
from comfy import model_management
from comfy_api.latest import ComfyExtension, IO
from typing_extensions import override
import math
class LTXVLatentUpsampler:
class LTXVLatentUpsampler(IO.ComfyNode):
"""
Upsamples a video latent by a factor of 2.
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"samples": ("LATENT",),
"upscale_model": ("LATENT_UPSCALE_MODEL",),
"vae": ("VAE",),
}
}
def define_schema(cls):
return IO.Schema(
node_id="LTXVLatentUpsampler",
category="model/latent/video",
is_experimental=True,
inputs=[
IO.Latent.Input("samples"),
IO.LatentUpscaleModel.Input("upscale_model"),
IO.Vae.Input("vae"),
],
outputs=[
IO.Latent.Output(),
],
)
RETURN_TYPES = ("LATENT",)
FUNCTION = "upsample_latent"
CATEGORY = "latent/video"
EXPERIMENTAL = True
def upsample_latent(
self,
samples: dict,
upscale_model,
vae,
) -> tuple:
@classmethod
def execute(cls, samples, upscale_model, vae) -> IO.NodeOutput:
"""
Upsample the input latent using the provided model.
@ -34,7 +34,6 @@ class LTXVLatentUpsampler:
samples (dict): Input latent samples
upscale_model (LatentUpsampler): Loaded upscale model
vae: VAE model for normalization
auto_tiling (bool): Whether to automatically tile the input for processing
Returns:
tuple: Tuple containing the upsampled latent
@ -67,9 +66,16 @@ class LTXVLatentUpsampler:
return_dict = samples.copy()
return_dict["samples"] = upsampled_latents
return_dict.pop("noise_mask", None)
return (return_dict,)
return IO.NodeOutput(return_dict)
upsample_latent = execute # TODO: remove
NODE_CLASS_MAPPINGS = {
"LTXVLatentUpsampler": LTXVLatentUpsampler,
}
class LTXVLatentUpsamplerExtension(ComfyExtension):
    """Extension whose node list contains only ``LTXVLatentUpsampler``."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_types = [LTXVLatentUpsampler]
        return node_types
async def comfy_entrypoint() -> LTXVLatentUpsamplerExtension:
    """Return a new ``LTXVLatentUpsamplerExtension`` instance."""
    extension = LTXVLatentUpsamplerExtension()
    return extension

View File

@ -81,7 +81,7 @@ class CLIPTextEncodeLumina2(io.ComfyNode):
node_id="CLIPTextEncodeLumina2",
search_aliases=["lumina prompt"],
display_name="CLIP Text Encode for Lumina2",
category="conditioning",
category="model/conditioning",
description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
"that can be used to guide the diffusion model towards generating specific images.",
inputs=[

View File

@ -53,7 +53,7 @@ class LatentCompositeMasked(IO.ComfyNode):
return IO.Schema(
node_id="LatentCompositeMasked",
search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"],
category="latent",
category="model/latent",
inputs=[
IO.Latent.Input("destination"),
IO.Latent.Input("source"),

View File

@ -69,7 +69,7 @@ class MathExpressionNode(io.ComfyNode):
return io.Schema(
node_id="ComfyMathExpression",
display_name="Math Expression",
category="utils",
category="utilities",
search_aliases=[
"expression", "formula", "calculate", "calculator",
"eval", "math",

Some files were not shown because too many files have changed in this diff Show More