mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-30 10:57:23 +08:00
Merge branch 'master' into seedvr2-native-support
This commit is contained in:
commit
7431bef672
24
.github/workflows/detect-unreviewed-merge.yml
vendored
Normal file
24
.github/workflows/detect-unreviewed-merge.yml
vendored
Normal file
@ -0,0 +1,24 @@
|
||||
name: Detect Unreviewed Merge
|
||||
|
||||
# SOC 2 compliance — reusable workflow lives in Comfy-Org/github-workflows,
|
||||
# tracking issues are filed in Comfy-Org/unreviewed-merges.
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master]
|
||||
|
||||
concurrency:
|
||||
group: detect-unreviewed-merge-${{ github.sha }}
|
||||
cancel-in-progress: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
detect:
|
||||
uses: Comfy-Org/github-workflows/.github/workflows/detect-unreviewed-merge.yml@4d9cb6b87f953bb7cd69954280e1465fb9bd2040 # v1
|
||||
with:
|
||||
approval-mode: latest-per-reviewer
|
||||
secrets:
|
||||
UNREVIEWED_MERGES_TOKEN: ${{ secrets.UNREVIEWED_MERGES_TOKEN }}
|
||||
@ -1,5 +1,20 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
_CK_STOCHASTIC_ROUNDING_AVAILABLE = False
|
||||
try:
|
||||
import comfy_kitchen as ck
|
||||
_ck_stochastic_rounding_fp8 = ck.stochastic_rounding_fp8
|
||||
_CK_STOCHASTIC_ROUNDING_AVAILABLE = True
|
||||
except (AttributeError, ImportError):
|
||||
logging.warning("comfy_kitchen does not support stochastic FP8 rounding, please update comfy_kitchen.")
|
||||
|
||||
if not _CK_STOCHASTIC_ROUNDING_AVAILABLE:
|
||||
def _ck_stochastic_rounding_fp8(value, rng, dtype):
|
||||
raise NotImplementedError("comfy_kitchen does not support stochastic FP8 rounding")
|
||||
|
||||
|
||||
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
||||
mantissa_scaled = torch.where(
|
||||
normal_mask,
|
||||
@ -57,6 +72,10 @@ def stochastic_rounding(value, dtype, seed=0):
|
||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||
generator = torch.Generator(device=value.device)
|
||||
generator.manual_seed(seed)
|
||||
if _CK_STOCHASTIC_ROUNDING_AVAILABLE:
|
||||
rng = torch.randint(0, 256, value.size(), dtype=torch.uint8, layout=value.layout, device=value.device, generator=generator)
|
||||
return _ck_stochastic_rounding_fp8(value, rng, dtype)
|
||||
|
||||
output = torch.empty_like(value, dtype=dtype)
|
||||
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
||||
slice_size = max(1, round(value.shape[0] / num_slices))
|
||||
|
||||
@ -804,13 +804,15 @@ class ZImagePixelSpace(ChromaRadiance):
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class HiDreamO1Pixel(ChromaRadiance):
|
||||
"""Pixel-space latent format for HiDream-O1.
|
||||
No VAE — model patches/unpatches raw RGB internally with patch_size=32.
|
||||
"""
|
||||
pass
|
||||
|
||||
class PixelDiTPixel(ChromaRadiance):
|
||||
pass
|
||||
|
||||
class CogVideoX(LatentFormat):
|
||||
"""Latent format for CogVideoX-2b (THUDM/CogVideoX-2b).
|
||||
|
||||
|
||||
@ -433,11 +433,11 @@ class Attention(nn.Module):
|
||||
if self.differential:
|
||||
q, q_diff = q.unbind(dim=1)
|
||||
k, k_diff = k.unbind(dim=1)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
|
||||
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
out = out - out_diff
|
||||
else:
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
|
||||
out = self.to_out(out)
|
||||
|
||||
|
||||
@ -138,11 +138,11 @@ class Attention(nn.Module):
|
||||
k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype)
|
||||
|
||||
if self.differential:
|
||||
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
|
||||
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True))
|
||||
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
|
||||
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True, low_precision_attention=False))
|
||||
del q, k, v, q_diff, k_diff
|
||||
else:
|
||||
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
|
||||
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
|
||||
del q, k, v
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
510
comfy/ldm/lens/model.py
Normal file
510
comfy/ldm/lens/model.py
Normal file
@ -0,0 +1,510 @@
|
||||
"""Lens denoising transformer (DiT)"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ldm.flux.layers
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
def _lens_time_proj(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
|
||||
return comfy.ldm.flux.layers.timestep_embedding(t, dim)
|
||||
|
||||
|
||||
def _lens_position_ids(
|
||||
frame: int, height: int, width: int, text_seq_len: int,
|
||||
scale_rope: bool = True, device=None,
|
||||
) -> torch.Tensor:
|
||||
"""Lens axial (frame, h, w) position ids for joint image + text sequence.
|
||||
|
||||
With ``scale_rope=True`` h/w are centered around 0 (negative + positive
|
||||
halves) and text starts at ``max(h//2, w//2)``. Result shape ``[seq, 3]``;
|
||||
caller adds a batch dim for ``EmbedND``.
|
||||
"""
|
||||
if scale_rope:
|
||||
h_pos = torch.cat([torch.arange(-(height - height // 2), 0, device=device),
|
||||
torch.arange(0, height // 2, device=device)])
|
||||
w_pos = torch.cat([torch.arange(-(width - width // 2), 0, device=device),
|
||||
torch.arange(0, width // 2, device=device)])
|
||||
text_start = max(height // 2, width // 2)
|
||||
else:
|
||||
h_pos = torch.arange(height, device=device)
|
||||
w_pos = torch.arange(width, device=device)
|
||||
text_start = max(height, width)
|
||||
|
||||
f_pos = torch.arange(frame, device=device)
|
||||
img_ids = torch.zeros(frame, height, width, 3, device=device)
|
||||
img_ids[..., 0] = f_pos[:, None, None]
|
||||
img_ids[..., 1] = h_pos[None, :, None]
|
||||
img_ids[..., 2] = w_pos[None, None, :]
|
||||
img_ids = img_ids.reshape(-1, 3)
|
||||
|
||||
# Text positions replicate across all 3 axes (matches original packing).
|
||||
txt_pos = torch.arange(text_start, text_start + text_seq_len, device=device).float()
|
||||
txt_ids = txt_pos[:, None].expand(text_seq_len, 3)
|
||||
|
||||
return torch.cat([img_ids, txt_ids], dim=0)
|
||||
|
||||
|
||||
class _TimestepEmbedder(nn.Module):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int, dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.linear_1 = operations.Linear(in_channels, time_embed_dim, dtype=dtype, device=device)
|
||||
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear_1(x)
|
||||
x = F.silu(x)
|
||||
return self.linear_2(x)
|
||||
|
||||
|
||||
class LensTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim: int, dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.timestep_embedder = _TimestepEmbedder(256, embedding_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
proj = _lens_time_proj(timestep, 256)
|
||||
return self.timestep_embedder(proj.to(dtype=hidden_states.dtype))
|
||||
|
||||
|
||||
class GateMLP(nn.Module):
|
||||
"""SwiGLU MLP."""
|
||||
|
||||
def __init__(self, dim: int, hidden_dim: int, dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
|
||||
self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
|
||||
self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(F.silu(self.w1(x), inplace=True).mul_(self.w3(x)))
|
||||
|
||||
|
||||
class LensJointAttention(nn.Module):
|
||||
"""Joint image+text attention with fused QKV per stream."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
added_kv_proj_dim: int,
|
||||
dim_head: int = 64,
|
||||
heads: int = 8,
|
||||
out_dim: Optional[int] = None,
|
||||
eps: float = 1e-5,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
|
||||
self.heads = self.inner_dim // dim_head
|
||||
self.dim_head = dim_head
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
|
||||
self.norm_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_added_q = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
self.norm_added_k = operations.RMSNorm(dim_head, eps=eps, dtype=dtype, device=device)
|
||||
|
||||
self.img_qkv = operations.Linear(query_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
self.txt_qkv = operations.Linear(added_kv_proj_dim, 3 * self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
# ModuleList([Linear, Identity]) for state-dict key compatibility.
|
||||
self.to_out = nn.ModuleList([
|
||||
operations.Linear(self.inner_dim, self.out_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.Identity(),
|
||||
])
|
||||
self.to_add_out = operations.Linear(self.inner_dim, query_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
bsz, seq_img, _ = hidden_states.shape
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
# image stream
|
||||
img_qkv = self.img_qkv(hidden_states).view(bsz, seq_img, 3, self.heads, self.dim_head)
|
||||
img_q, img_k, img_v = img_qkv.unbind(dim=2)
|
||||
img_q = self.norm_q(img_q)
|
||||
img_k = self.norm_k(img_k)
|
||||
del img_qkv
|
||||
|
||||
# text stream
|
||||
txt_qkv = self.txt_qkv(encoder_hidden_states).view(bsz, seq_txt, 3, self.heads, self.dim_head)
|
||||
txt_q, txt_k, txt_v = txt_qkv.unbind(dim=2)
|
||||
txt_q = self.norm_added_q(txt_q)
|
||||
txt_k = self.norm_added_k(txt_k)
|
||||
|
||||
# [B, S, H, D] → [B, H, S, D] for attention, dels to avoid VRAM peaks
|
||||
q = torch.cat([img_q, txt_q], dim=1).transpose(1, 2)
|
||||
del img_q, txt_q
|
||||
k = torch.cat([img_k, txt_k], dim=1).transpose(1, 2)
|
||||
del img_k, txt_k
|
||||
v = torch.cat([img_v, txt_v], dim=1).transpose(1, 2)
|
||||
del img_v, txt_v
|
||||
|
||||
q, k = apply_rope(q, k, freqs_cis)
|
||||
|
||||
if attention_mask is not None:
|
||||
expected = (bsz, 1, 1, seq_img + seq_txt)
|
||||
if attention_mask.shape != expected:
|
||||
raise ValueError(
|
||||
f"attention_mask must be {expected}, got {tuple(attention_mask.shape)}"
|
||||
)
|
||||
attention_mask = attention_mask.to(q.dtype)
|
||||
|
||||
out = optimized_attention(
|
||||
q, k, v, self.heads, mask=attention_mask, skip_reshape=True,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
img_out = self.to_out[1](self.to_out[0](out[:, :seq_img, :]))
|
||||
txt_out = self.to_add_out(out[:, seq_img:, :])
|
||||
return img_out, txt_out
|
||||
|
||||
|
||||
class LensTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
eps: float = 1e-6,
|
||||
rms_norm: bool = True,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.attn = LensJointAttention(
|
||||
query_dim=dim,
|
||||
added_kv_proj_dim=dim,
|
||||
dim_head=attention_head_dim,
|
||||
heads=num_attention_heads,
|
||||
out_dim=dim,
|
||||
eps=1e-5,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
if rms_norm:
|
||||
NormCls = operations.RMSNorm
|
||||
norm_kwargs = {}
|
||||
else:
|
||||
NormCls = operations.LayerNorm
|
||||
norm_kwargs = {"elementwise_affine": False}
|
||||
|
||||
mlp_hidden = int(dim / 3 * 8)
|
||||
|
||||
# Sequential(SiLU, Linear) so state-dict lands at img_mod.1.{weight,bias}.
|
||||
self.img_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.img_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.img_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.img_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_mod = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.txt_norm1 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.txt_norm2 = NormCls(dim, eps=eps, dtype=dtype, device=device, **norm_kwargs)
|
||||
self.txt_mlp = GateMLP(dim, mlp_hidden, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
@staticmethod
|
||||
def _modulate(x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
shift, scale, gate = mod_params.chunk(3, dim=-1)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
img_mod1, img_mod2 = self.img_mod(temb).chunk(2, dim=-1)
|
||||
txt_mod1, txt_mod2 = self.txt_mod(temb).chunk(2, dim=-1)
|
||||
|
||||
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
|
||||
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
|
||||
|
||||
img_attn, txt_attn = self.attn(
|
||||
hidden_states=img_modulated,
|
||||
encoder_hidden_states=txt_modulated,
|
||||
freqs_cis=freqs_cis,
|
||||
attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + img_gate1 * img_attn
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn
|
||||
|
||||
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
|
||||
hidden_states = hidden_states + img_gate2 * self.img_mlp(img_modulated2)
|
||||
|
||||
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
|
||||
encoder_hidden_states = encoder_hidden_states + txt_gate2 * self.txt_mlp(txt_modulated2)
|
||||
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class _AdaLayerNormContinuousNoAffine(nn.Module):
|
||||
"""AdaLayerNormContinuous(elementwise_affine=False).
|
||||
|
||||
The reference uses ``scale, shift = chunk(2)`` (scale first) — opposite
|
||||
to Flux's ``LastLayer``.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim: int, conditioning_embedding_dim: int, eps: float = 1e-6,
|
||||
dtype=None, device=None, operations=None) -> None:
|
||||
super().__init__()
|
||||
self.linear = operations.Linear(
|
||||
conditioning_embedding_dim, embedding_dim * 2, bias=True, dtype=dtype, device=device
|
||||
)
|
||||
self.eps = eps
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
|
||||
emb = self.linear(F.silu(conditioning))
|
||||
scale, shift = torch.chunk(emb, 2, dim=-1)
|
||||
x = F.layer_norm(x, (self.embedding_dim,), None, None, self.eps)
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
class LensTransformer2DModel(nn.Module):
|
||||
"""Lens dual-stream MMDiT (48 blocks, inner_dim=1536, multi-layer text)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 128,
|
||||
out_channels: Optional[int] = 32,
|
||||
num_layers: int = 48,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 24,
|
||||
enc_hidden_dim: int = 2880,
|
||||
axes_dims_rope: Tuple[int, int, int] = (8, 28, 28),
|
||||
rms_norm: bool = True,
|
||||
multi_layer_encoder_feature: bool = True,
|
||||
selected_layer_index: Tuple[int, ...] = (5, 11, 17, 23),
|
||||
image_model=None, # unused; accepted for detection-side configs.
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels if out_channels is not None else in_channels
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.multi_layer_encoder_feature = multi_layer_encoder_feature
|
||||
self.selected_layer_index = list(selected_layer_index)
|
||||
self.dtype = dtype
|
||||
|
||||
self.pos_embed = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))
|
||||
self.time_text_embed = LensTimestepProjEmbeddings(
|
||||
embedding_dim=self.inner_dim, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
if self.multi_layer_encoder_feature:
|
||||
self.txt_norm = nn.ModuleList(
|
||||
[operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
for _ in self.selected_layer_index]
|
||||
)
|
||||
self.txt_in = operations.Linear(
|
||||
enc_hidden_dim * len(self.selected_layer_index),
|
||||
self.inner_dim, bias=True, dtype=dtype, device=device,
|
||||
)
|
||||
else:
|
||||
self.txt_norm = operations.RMSNorm(enc_hidden_dim, eps=1e-5, dtype=dtype, device=device)
|
||||
self.txt_in = operations.Linear(enc_hidden_dim, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.img_in = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
LensTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
eps=1e-6,
|
||||
rms_norm=rms_norm,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.norm_out = _AdaLayerNormContinuousNoAffine(
|
||||
self.inner_dim, self.inner_dim, eps=1e-6,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.proj_out = operations.Linear(
|
||||
self.inner_dim, patch_size * patch_size * self.out_channels, bias=True,
|
||||
dtype=dtype, device=device,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, timestep: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None, **kwargs) -> torch.Tensor:
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward, self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||
).execute(x, timestep, context, attention_mask, transformer_options, **kwargs)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
transformer_options: Optional[Dict[str, Any]] = None,
|
||||
control: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""ComfyUI bridge: ``(x[B,128,h,w], t[B], context[B,S,L*H], mask[B,S])``."""
|
||||
if transformer_options is None:
|
||||
transformer_options = {}
|
||||
transformer_options = transformer_options.copy()
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
B, C, h, w = x.shape
|
||||
hidden_states = x.permute(0, 2, 3, 1).reshape(B, h * w, C)
|
||||
|
||||
if self.multi_layer_encoder_feature:
|
||||
L = len(self.selected_layer_index)
|
||||
enc_dim = context.shape[-1] // L
|
||||
encoder_hidden_states = list(
|
||||
context.reshape(B, -1, L, enc_dim).unbind(dim=2)
|
||||
)
|
||||
text_seq_len = encoder_hidden_states[0].shape[1]
|
||||
else:
|
||||
encoder_hidden_states = context
|
||||
text_seq_len = context.shape[1]
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(B, text_seq_len), dtype=torch.bool, device=x.device
|
||||
)
|
||||
|
||||
img_len = h * w
|
||||
joint_mask = self._build_joint_attention_mask(attention_mask, img_len)
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
|
||||
if self.multi_layer_encoder_feature:
|
||||
normed = [self.txt_norm[i](encoder_hidden_states[i]) for i in range(L)]
|
||||
encoder_hidden_states = torch.cat(normed, dim=-1)
|
||||
else:
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"transformer_options": transformer_options,
|
||||
})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
temb = self.time_text_embed(timestep, hidden_states)
|
||||
ids = _lens_position_ids(1, h, w, text_seq_len, device=hidden_states.device).unsqueeze(0)
|
||||
freqs_cis = self.pos_embed(ids)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["txt"], out["img"] = block(
|
||||
hidden_states=args["img"],
|
||||
encoder_hidden_states=args["txt"],
|
||||
temb=args["vec"],
|
||||
freqs_cis=args["pe"],
|
||||
attention_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"),
|
||||
)
|
||||
return out
|
||||
out = blocks_replace[("double_block", i)](
|
||||
{
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"vec": temb,
|
||||
"pe": freqs_cis,
|
||||
"attn_mask": joint_mask,
|
||||
"transformer_options": transformer_options,
|
||||
},
|
||||
{"original_block": block_wrap},
|
||||
)
|
||||
encoder_hidden_states = out["txt"]
|
||||
hidden_states = out["img"]
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
freqs_cis=freqs_cis,
|
||||
attention_mask=joint_mask,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({
|
||||
"img": hidden_states,
|
||||
"txt": encoder_hidden_states,
|
||||
"x": x,
|
||||
"block_index": i,
|
||||
"transformer_options": transformer_options,
|
||||
})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
|
||||
if control is not None:
|
||||
control_i = control.get("input")
|
||||
if control_i is not None and i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
hidden_states[:, :add.shape[1]] += add
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
out = self.proj_out(hidden_states)
|
||||
return out.reshape(B, h, w, C).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
@staticmethod
|
||||
def _build_joint_attention_mask(text_mask: torch.Tensor, img_len: int) -> torch.Tensor:
|
||||
if text_mask.dtype != torch.bool:
|
||||
text_mask = text_mask.bool()
|
||||
bsz = text_mask.shape[0]
|
||||
img_ones = torch.ones((bsz, img_len), dtype=torch.bool, device=text_mask.device)
|
||||
joint = torch.cat([img_ones, text_mask], dim=1)
|
||||
additive = torch.zeros_like(joint, dtype=torch.float32)
|
||||
additive.masked_fill_(~joint, torch.finfo(torch.float32).min)
|
||||
return additive[:, None, None, :]
|
||||
@ -211,7 +211,7 @@ class TimestepEmbedder(nn.Module):
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None, max_period=10000):
|
||||
super().__init__()
|
||||
if output_size is None:
|
||||
output_size = hidden_size
|
||||
@ -221,9 +221,10 @@ class TimestepEmbedder(nn.Module):
|
||||
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.max_period = max_period
|
||||
|
||||
def forward(self, t, dtype, **kwargs):
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size, max_period=self.max_period).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
239
comfy/ldm/pixeldit/model.py
Normal file
239
comfy/ldm/pixeldit/model.py
Normal file
@ -0,0 +1,239 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.flux.math import apply_rope, rope
|
||||
from comfy.ldm.hidream.model import FeedForwardSwiGLU
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||
|
||||
from .modules import (
|
||||
FinalLayer,
|
||||
PatchTokenEmbedder,
|
||||
PiTBlock,
|
||||
PixelTokenEmbedder,
|
||||
apply_adaln_,
|
||||
precompute_freqs_cis_2d,
|
||||
)
|
||||
|
||||
|
||||
class MMDiTJointAttention(nn.Module):
|
||||
"""Joint MMDiT attention with separate Q/K/V/proj for image and text streams.
|
||||
|
||||
RoPE is applied to each stream before concatenation so each stream uses its own
|
||||
2D/1D positional encoding. Concat order is [text, image] (text first).
|
||||
"""
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
|
||||
self.qkv_x = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.qkv_y = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
|
||||
self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.proj_x = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
self.proj_y = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, y, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
|
||||
B, Nx, _ = x.shape
|
||||
_, Ny, _ = y.shape
|
||||
H = self.num_heads
|
||||
D = self.head_dim
|
||||
|
||||
qkv_x = self.qkv_x(x).reshape(B, Nx, 3, H, D).permute(2, 0, 3, 1, 4)
|
||||
qx, kx, vx = qkv_x.unbind(0)
|
||||
qx = self.q_norm_x(qx)
|
||||
kx = self.k_norm_x(kx)
|
||||
|
||||
qkv_y = self.qkv_y(y).reshape(B, Ny, 3, H, D).permute(2, 0, 3, 1, 4)
|
||||
qy, ky, vy = qkv_y.unbind(0)
|
||||
qy = self.q_norm_y(qy)
|
||||
ky = self.k_norm_y(ky)
|
||||
|
||||
qx, kx = apply_rope(qx, kx, pos_img[None, None])
|
||||
if pos_txt is not None:
|
||||
qy, ky = apply_rope(qy, ky, pos_txt[None, None])
|
||||
|
||||
q_joint = torch.cat([qy, qx], dim=2)
|
||||
k_joint = torch.cat([ky, kx], dim=2)
|
||||
v_joint = torch.cat([vy, vx], dim=2)
|
||||
|
||||
out_joint = optimized_attention(
|
||||
q_joint, k_joint, v_joint, H,
|
||||
mask=attn_mask, skip_reshape=True, skip_output_reshape=True,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
out_y = out_joint[:, :, :Ny, :].transpose(1, 2).reshape(B, Ny, H * D)
|
||||
out_x = out_joint[:, :, Ny:, :].transpose(1, 2).reshape(B, Nx, H * D)
|
||||
|
||||
return self.proj_x(out_x), self.proj_y(out_y)
|
||||
|
||||
|
||||
class MMDiTBlockT2I(nn.Module):
|
||||
def __init__(self, hidden_size, groups, mlp_ratio=4.0, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_x1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm_y1 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = MMDiTJointAttention(hidden_size, num_heads=groups, qkv_bias=False, dtype=dtype, device=device, operations=operations)
|
||||
self.norm_x2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm_y2 = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.mlp_x = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp_y = FeedForwardSwiGLU(hidden_size, mlp_hidden_dim, multiple_of=1, dtype=dtype, device=device, operations=operations)
|
||||
self.adaLN_modulation_img = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
self.adaLN_modulation_txt = nn.Sequential(operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, y, c, pos_img, pos_txt=None, attn_mask=None, transformer_options={}):
|
||||
shift_msa_x, scale_msa_x, gate_msa_x, shift_mlp_x, scale_mlp_x, gate_mlp_x = self.adaLN_modulation_img(c).chunk(6, dim=-1)
|
||||
shift_msa_y, scale_msa_y, gate_msa_y, shift_mlp_y, scale_mlp_y, gate_mlp_y = self.adaLN_modulation_txt(c).chunk(6, dim=-1)
|
||||
|
||||
x_norm = apply_adaln_(self.norm_x1(x), shift_msa_x, scale_msa_x)
|
||||
y_norm = apply_adaln_(self.norm_y1(y), shift_msa_y, scale_msa_y)
|
||||
attn_x, attn_y = self.attn(x_norm, y_norm, pos_img, pos_txt, attn_mask, transformer_options=transformer_options)
|
||||
x = torch.addcmul(x, gate_msa_x, attn_x)
|
||||
y = torch.addcmul(y, gate_msa_y, attn_y)
|
||||
|
||||
x = torch.addcmul(x, gate_mlp_x, self.mlp_x(apply_adaln_(self.norm_x2(x), shift_mlp_x, scale_mlp_x)))
|
||||
y = torch.addcmul(y, gate_mlp_y, self.mlp_y(apply_adaln_(self.norm_y2(y), shift_mlp_y, scale_mlp_y)))
|
||||
return x, y
|
||||
|
||||
|
||||
class PixDiT_T2I(nn.Module):
|
||||
"""PixelDiT T2I model. Hardcoded for the released 1024px Stage-3 checkpoint
|
||||
(also runs at 512px when fed the appropriate latent size and flow_shift).
|
||||
|
||||
Forward:
|
||||
x: [B, 3, H, W] pixel-space input (no VAE)
|
||||
timesteps:[B] in [0, 1000] (ComfyUI flow sampling convention)
|
||||
context: [B, Ltxt, 2304] Gemma-2-2b-it hidden states (chi_prompt prepended)
|
||||
Returns flow-matching velocity [B, 3, H, W].
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
num_groups=24,
|
||||
hidden_size=1536,
|
||||
pixel_hidden_size=16,
|
||||
pixel_attn_hidden_size=1152,
|
||||
pixel_num_groups=16,
|
||||
patch_depth=14,
|
||||
pixel_depth=2,
|
||||
patch_size=16,
|
||||
txt_embed_dim=2304,
|
||||
txt_max_length=300,
|
||||
use_text_rope=True,
|
||||
text_rope_theta=10000.0,
|
||||
image_model=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
pixel_mlp_chunks=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_groups = num_groups
|
||||
self.patch_depth = patch_depth
|
||||
self.pixel_depth = pixel_depth
|
||||
self.patch_size = patch_size
|
||||
self.pixel_hidden_size = pixel_hidden_size
|
||||
self.pixel_attn_hidden_size = pixel_attn_hidden_size
|
||||
self.pixel_num_groups = pixel_num_groups
|
||||
self.txt_embed_dim = txt_embed_dim
|
||||
self.txt_max_length = txt_max_length
|
||||
self.use_text_rope = use_text_rope
|
||||
self.text_rope_theta = text_rope_theta
|
||||
|
||||
self.pixel_embedder = PixelTokenEmbedder(self.in_channels, self.pixel_hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.s_embedder = PatchTokenEmbedder(self.in_channels * self.patch_size ** 2, self.hidden_size, bias=True, dtype=dtype, device=device, operations=operations)
|
||||
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations, max_period=10)
|
||||
self.y_embedder = PatchTokenEmbedder(self.txt_embed_dim, self.hidden_size, bias=True, use_norm=True, dtype=dtype, device=device, operations=operations)
|
||||
self.y_pos_embedding = nn.Parameter(torch.empty(1, self.txt_max_length, self.hidden_size, dtype=dtype, device=device))
|
||||
|
||||
self.patch_blocks = nn.ModuleList([
|
||||
MMDiTBlockT2I(self.hidden_size, self.num_groups,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(self.patch_depth)
|
||||
])
|
||||
self.pixel_blocks = nn.ModuleList([
|
||||
PiTBlock(
|
||||
self.pixel_hidden_size,
|
||||
self.hidden_size,
|
||||
patch_size=self.patch_size,
|
||||
num_heads=self.num_groups,
|
||||
attn_hidden_size=self.pixel_attn_hidden_size,
|
||||
attn_num_heads=self.pixel_num_groups,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
mlp_chunks=pixel_mlp_chunks,
|
||||
)
|
||||
for _ in range(self.pixel_depth)
|
||||
])
|
||||
|
||||
self.final_layer = FinalLayer(self.pixel_hidden_size, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
|
||||
return precompute_freqs_cis_2d(self.hidden_size // self.num_groups, height, width, device=device, dtype=dtype, **rope_opts)
|
||||
|
||||
def _fetch_text_pos(self, length, device, dtype):
|
||||
return rope(torch.arange(length, dtype=torch.float32, device=device).reshape(1, -1), self.hidden_size // self.num_groups, self.text_rope_theta).squeeze(0).to(dtype=dtype)
|
||||
|
||||
def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
|
||||
|
||||
def _pre_patch_block(self, s, i, **kwargs):
|
||||
"""Hook for subclasses to inject per-block state into the patch stream (e.g. PiD's LQ gate)."""
|
||||
return s
|
||||
|
||||
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
H_orig, W_orig = x.shape[2], x.shape[3]
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
|
||||
B, _, H, W = x.shape
|
||||
Hs = H // self.patch_size
|
||||
Ws = W // self.patch_size
|
||||
L = Hs * Ws
|
||||
|
||||
pos_img = self._fetch_patch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
|
||||
x_patches = F.unfold(x, kernel_size=self.patch_size, stride=self.patch_size).transpose(1, 2)
|
||||
|
||||
t_emb = self.t_embedder(timesteps.view(-1), x.dtype).view(B, -1, self.hidden_size)
|
||||
|
||||
if context is None or context.dim() != 3:
|
||||
raise ValueError("PixDiT_T2I requires context (text embeddings) of shape [B, L, D]")
|
||||
Ltxt = min(context.shape[1], self.txt_max_length)
|
||||
y = context[:, :Ltxt, :]
|
||||
y_emb = self.y_embedder(y).view(B, Ltxt, self.hidden_size)
|
||||
y_emb = y_emb + self.y_pos_embedding[:, :Ltxt, :].to(y_emb) # y_pos_embedding is a raw nn.Parameter
|
||||
|
||||
condition = F.silu(t_emb)
|
||||
pos_txt = self._fetch_text_pos(Ltxt, x.device, x.dtype) if self.use_text_rope else None
|
||||
|
||||
s = self.s_embedder(x_patches)
|
||||
for i, blk in enumerate(self.patch_blocks):
|
||||
s = self._pre_patch_block(s, i, **kwargs)
|
||||
s, y_emb = blk(s, y_emb, condition, pos_img, pos_txt, None, transformer_options=transformer_options)
|
||||
s = F.silu(t_emb + s)
|
||||
|
||||
s_cond = s.view(B * L, self.hidden_size)
|
||||
x_pixels = self.pixel_embedder(x, patch_size=self.patch_size)
|
||||
for blk in self.pixel_blocks:
|
||||
x_pixels = blk(x_pixels, s_cond, H, W, self.patch_size, mask=None, transformer_options=transformer_options)
|
||||
|
||||
x_pixels = self.final_layer(x_pixels)
|
||||
C_out = self.out_channels
|
||||
P2 = self.patch_size * self.patch_size
|
||||
x_pixels = x_pixels.view(B, L, P2, C_out).permute(0, 3, 2, 1).reshape(B, C_out * P2, L)
|
||||
out = F.fold(x_pixels, (H, W), kernel_size=self.patch_size, stride=self.patch_size)
|
||||
return out[:, :, :H_orig, :W_orig]
|
||||
187
comfy/ldm/pixeldit/modules.py
Normal file
187
comfy/ldm/pixeldit/modules.py
Normal file
@ -0,0 +1,187 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.flux.math import apply_rope, rope
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, get_1d_sincos_pos_embed_from_grid_torch
|
||||
|
||||
|
||||
def apply_adaln_(x, shift, scale):
|
||||
return x.addcmul_(x, scale).add_(shift)
|
||||
|
||||
|
||||
def precompute_freqs_cis_2d(dim, height, width, theta=10000.0, scale=16.0,
|
||||
ref_grid_h=None, ref_grid_w=None,
|
||||
scale_x=1.0, scale_y=1.0, shift_x=0.0, shift_y=0.0,
|
||||
device=None, dtype=torch.float32, **kwargs):
|
||||
"""2D RoPE with x/y axis frequencies interleaved at stride 2 across head dim.
|
||||
|
||||
rope_options:
|
||||
scale_x / scale_y multiply the position range (RoPE extrapolation).
|
||||
shift_x / shift_y offset the position origin (tiled / regional inference).
|
||||
With ref_grid_h/w set, also applies NTK-aware per-axis theta scaling
|
||||
(rope_mode='ntk_aware'): theta_axis = theta * (current/ref)^(dim_axis/(dim_axis-2)).
|
||||
Returns Flux-format rotation matrices of shape [H*W, dim/2, 2, 2].
|
||||
Layout of head-dim pairs: [x_0, y_0, x_1, y_1, ..., x_{dim/4-1}, y_{dim/4-1}].
|
||||
"""
|
||||
dim_axis = dim // 2
|
||||
if ref_grid_h is not None and dim_axis > 2:
|
||||
h_ntk = (height / ref_grid_h) ** (dim_axis / (dim_axis - 2))
|
||||
w_ntk = (width / ref_grid_w) ** (dim_axis / (dim_axis - 2))
|
||||
else:
|
||||
h_ntk = w_ntk = 1.0
|
||||
|
||||
x_lin = torch.linspace(shift_x, scale * scale_x + shift_x, width, device=device)
|
||||
y_lin = torch.linspace(shift_y, scale * scale_y + shift_y, height, device=device)
|
||||
y_grid, x_grid = torch.meshgrid(y_lin, x_lin, indexing="ij")
|
||||
x_rope = rope(x_grid.reshape(1, -1), dim_axis, theta * w_ntk).squeeze(0)
|
||||
y_rope = rope(y_grid.reshape(1, -1), dim_axis, theta * h_ntk).squeeze(0)
|
||||
out = torch.stack([x_rope, y_rope], dim=2).reshape(height * width, dim // 2, 2, 2)
|
||||
return out.to(dtype=dtype)
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, height, width, device=None, dtype=torch.float32):
|
||||
"""Standard 2D sin/cos absolute positional embedding (ViT-style).
|
||||
|
||||
first half encodes W-coordinates, second half H.
|
||||
"""
|
||||
assert embed_dim % 4 == 0
|
||||
grid_h = torch.arange(height, dtype=torch.float32, device=device)
|
||||
grid_w = torch.arange(width, dtype=torch.float32, device=device)
|
||||
grid_y, grid_x = torch.meshgrid(grid_h, grid_w, indexing="ij")
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_x.reshape(-1), device=device)
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_y.reshape(-1), device=device)
|
||||
return torch.cat([emb_w, emb_h], dim=1).to(dtype=dtype)
|
||||
|
||||
|
||||
class RotaryAttention(nn.Module):
|
||||
"""Single-stream self-attention with rotary positional encoding (used inside PiTBlock)."""
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.q_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.k_norm = operations.RMSNorm(self.head_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, pos, mask=None, transformer_options={}):
|
||||
B, N, C = x.shape
|
||||
H = self.num_heads
|
||||
D = self.head_dim
|
||||
qkv = self.qkv(x).reshape(B, N, 3, H, D).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
q, k = apply_rope(self.q_norm(q), self.k_norm(k), pos[None, None])
|
||||
x = optimized_attention(q, k, v, H, mask=mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
return self.proj(x)
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm = operations.RMSNorm(hidden_size, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(self.norm(x))
|
||||
|
||||
|
||||
class PatchTokenEmbedder(nn.Module):
|
||||
"""Linear projection used both for patchified-image tokens and text-feature tokens."""
|
||||
def __init__(self, in_chans, embed_dim, use_norm=False, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = operations.Linear(in_chans, embed_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.norm = operations.RMSNorm(embed_dim, eps=1e-6, dtype=dtype, device=device) if use_norm else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
return self.norm(self.proj(x))
|
||||
|
||||
|
||||
class PixelTokenEmbedder(nn.Module):
|
||||
"""Pixel-level embedder: lifts each RGB pixel to hidden_size and packs into per-patch sequences."""
|
||||
def __init__(self, in_channels, hidden_size_output, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_size_output = hidden_size_output
|
||||
self.proj = operations.Linear(self.in_channels, self.hidden_size_output, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, inputs, patch_size):
|
||||
B, _, H, W = inputs.shape
|
||||
Hs, Ws = H // patch_size, W // patch_size
|
||||
P2 = patch_size * patch_size
|
||||
x = inputs.permute(0, 2, 3, 1).contiguous()
|
||||
x = self.proj(x)
|
||||
pos_full = get_2d_sincos_pos_embed(self.hidden_size_output, H, W, device=x.device, dtype=x.dtype).view(H, W, self.hidden_size_output)
|
||||
x = x + pos_full.unsqueeze(0)
|
||||
x = x.view(B, Hs, patch_size, Ws, patch_size, self.hidden_size_output)
|
||||
return x.permute(0, 1, 3, 2, 4, 5).reshape(B * Hs * Ws, P2, self.hidden_size_output)
|
||||
|
||||
|
||||
class PiTBlock(nn.Module):
|
||||
"""Pixel-level transformer block.
|
||||
|
||||
Compresses each patch's P^2 pixel tokens → 1 attention token via a linear,
|
||||
runs global self-attention across patches with 2D RoPE, then expands back to P^2 tokens.
|
||||
Conditioning is per-pixel adaLN from the patch-level features.
|
||||
"""
|
||||
def __init__(self, pixel_hidden_size, patch_hidden_size, patch_size, num_heads, mlp_ratio=4.0,
|
||||
attn_hidden_size=None, attn_num_heads=None, dtype=None, device=None, operations=None, mlp_chunks=1):
|
||||
super().__init__()
|
||||
self.pixel_dim = pixel_hidden_size
|
||||
self.context_dim = patch_hidden_size
|
||||
self.attn_dim = attn_hidden_size if attn_hidden_size is not None else patch_hidden_size
|
||||
self.num_heads = attn_num_heads if attn_num_heads is not None else num_heads
|
||||
assert self.attn_dim % self.num_heads == 0
|
||||
|
||||
p2 = patch_size * patch_size
|
||||
self.compress_to_attn = operations.Linear(p2 * self.pixel_dim, self.attn_dim, bias=True, dtype=dtype, device=device)
|
||||
self.expand_from_attn = operations.Linear(self.attn_dim, p2 * self.pixel_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.norm1 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = RotaryAttention(self.attn_dim, num_heads=self.num_heads, qkv_bias=False, dtype=dtype, device=device, operations=operations)
|
||||
self.norm2 = operations.RMSNorm(self.pixel_dim, eps=1e-6, dtype=dtype, device=device)
|
||||
self.mlp = Mlp(self.pixel_dim, hidden_features=int(self.pixel_dim * mlp_ratio), dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.adaLN_modulation_msa = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation_mlp = operations.Linear(self.context_dim, 3 * self.pixel_dim * p2, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self._rope_fn = precompute_freqs_cis_2d
|
||||
self.mlp_chunks = max(1, int(mlp_chunks))
|
||||
|
||||
def _fetch_pos(self, height, width, device, dtype, **rope_opts):
|
||||
return self._rope_fn(self.attn_dim // self.num_heads, height, width, device=device, dtype=dtype, **rope_opts)
|
||||
|
||||
def forward(self, x, s_cond, image_height, image_width, patch_size, mask=None, transformer_options={}):
|
||||
BL, P2, _ = x.shape
|
||||
Hs, Ws = image_height // patch_size, image_width // patch_size
|
||||
L = Hs * Ws
|
||||
B = BL // L
|
||||
|
||||
# Attention path uses only msa params; compute, use, free before mlp params allocate.
|
||||
msa_params = self.adaLN_modulation_msa(s_cond).view(BL, P2, 3 * self.pixel_dim)
|
||||
shift_msa, scale_msa, gate_msa = msa_params.chunk(3, dim=-1)
|
||||
|
||||
x_norm = apply_adaln_(self.norm1(x), shift_msa, scale_msa)
|
||||
x_flat = x_norm.view(BL, P2 * self.pixel_dim)
|
||||
|
||||
x_comp = self.compress_to_attn(x_flat).view(B, L, self.attn_dim)
|
||||
pos_comp = self._fetch_pos(Hs, Ws, x.device, x.dtype, **(transformer_options.get("rope_options") or {}))
|
||||
attn_out = self.attn(x_comp, pos_comp, mask=mask, transformer_options=transformer_options)
|
||||
attn_flat = self.expand_from_attn(attn_out.view(B * L, self.attn_dim))
|
||||
attn_exp = attn_flat.view(BL, P2, self.pixel_dim)
|
||||
x = torch.addcmul(x, gate_msa, attn_exp)
|
||||
del msa_params, shift_msa, scale_msa, gate_msa
|
||||
|
||||
mlp_params = self.adaLN_modulation_mlp(s_cond).view(BL, P2, 3 * self.pixel_dim)
|
||||
shift_mlp, scale_mlp, gate_mlp = mlp_params.chunk(3, dim=-1)
|
||||
gate_mlp = gate_mlp.contiguous() # detach from mlp_params so the del below frees shift+scale storage before the MLP
|
||||
mlp_input = apply_adaln_(self.norm2(x), shift_mlp, scale_mlp)
|
||||
del mlp_params, shift_mlp, scale_mlp
|
||||
|
||||
# MLP in chunks since the peak memory usage is huge here
|
||||
chunk_size = (BL + self.mlp_chunks - 1) // self.mlp_chunks
|
||||
for s in range(0, BL, chunk_size):
|
||||
e = min(s + chunk_size, BL)
|
||||
x[s:e].addcmul_(gate_mlp[s:e], self.mlp(mlp_input[s:e]))
|
||||
return x
|
||||
227
comfy/ldm/pixeldit/pid.py
Normal file
227
comfy/ldm/pixeldit/pid.py
Normal file
@ -0,0 +1,227 @@
|
||||
"""PiD — Pixel Diffusion Decoder. Decodes a Flux/SD3/Flux2/Z-Image latent
|
||||
directly to a 4x-upscaled image in 4 distilled flow-matching steps. PixDiT_T2I
|
||||
body + LQ projection branch injected before each MMDiT patch block.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .model import PixDiT_T2I
|
||||
from .modules import precompute_freqs_cis_2d
|
||||
|
||||
|
||||
class SigmaAwareGatePerTokenPerDim(nn.Module):
|
||||
"""gate = sigmoid(content_proj(cat[x, lq]) - exp(log_alpha) * sigma); out = x + gate * lq.
|
||||
|
||||
Trained init gives ~0.88 gate at sigma=0, ~0.05 at sigma=1.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.content_proj = operations.Linear(dim * 2, dim, dtype=dtype, device=device)
|
||||
self.log_alpha = nn.Parameter(torch.empty((), dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: torch.Tensor, lq: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
|
||||
content_logit = self.content_proj(torch.cat([x, lq], dim=-1))
|
||||
# log_alpha is a raw nn.Parameter -> doesn't auto-cast under dynamic VRAM.
|
||||
log_alpha = self.log_alpha.to(device=x.device, dtype=torch.float32)
|
||||
sigma_offset = -log_alpha.exp() * sigma.float().view(-1, 1, 1)
|
||||
gate = torch.sigmoid(content_logit + sigma_offset)
|
||||
return x + (gate * lq).to(x.dtype)
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
"""Pre-activation ResNet block: GN -> SiLU -> Conv -> GN -> SiLU -> Conv + skip."""
|
||||
|
||||
def __init__(self, channels: int, num_groups: int = 4, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
operations.GroupNorm(num_groups, channels, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(channels, channels, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.block(x)
|
||||
|
||||
|
||||
class LQProjection2D(nn.Module):
|
||||
"""LQ latent -> per-block patch-aligned features for controlnet-style injection."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
latent_channels: int,
|
||||
hidden_dim: int = 512,
|
||||
out_dim: int = 1536,
|
||||
patch_size: int = 16,
|
||||
sr_scale: int = 4,
|
||||
latent_spatial_down_factor: int = 8,
|
||||
num_res_blocks: int = 4,
|
||||
num_outputs: int = 7,
|
||||
interval: int = 2,
|
||||
dtype=None, device=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.latent_channels = latent_channels
|
||||
self.hidden_dim = hidden_dim
|
||||
self.out_dim = out_dim
|
||||
self.patch_size = patch_size
|
||||
self.sr_scale = sr_scale
|
||||
self.latent_spatial_down_factor = latent_spatial_down_factor
|
||||
self.num_outputs = num_outputs
|
||||
self.interval = interval
|
||||
|
||||
z_to_patch_ratio = (sr_scale * latent_spatial_down_factor) / patch_size
|
||||
self.z_to_patch_ratio = z_to_patch_ratio
|
||||
if z_to_patch_ratio >= 1:
|
||||
self.latent_fold_factor = 0
|
||||
latent_proj_in_ch = latent_channels
|
||||
else:
|
||||
fold_factor = int(1 / z_to_patch_ratio)
|
||||
assert fold_factor * z_to_patch_ratio == 1.0
|
||||
self.latent_fold_factor = fold_factor
|
||||
latent_proj_in_ch = latent_channels * fold_factor * fold_factor
|
||||
|
||||
layers = [
|
||||
operations.Conv2d(latent_proj_in_ch, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, dtype=dtype, device=device),
|
||||
]
|
||||
for _ in range(num_res_blocks):
|
||||
layers.append(ResBlock(hidden_dim, dtype=dtype, device=device, operations=operations))
|
||||
self.latent_proj = nn.Sequential(*layers)
|
||||
|
||||
self.output_heads = nn.ModuleList(
|
||||
[operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) for _ in range(num_outputs)]
|
||||
)
|
||||
self.gate_modules = nn.ModuleList(
|
||||
[SigmaAwareGatePerTokenPerDim(out_dim, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_outputs)]
|
||||
)
|
||||
|
||||
def is_gate_active(self, block_idx: int) -> bool:
|
||||
return block_idx % self.interval == 0
|
||||
|
||||
def output_index(self, block_idx: int) -> int:
|
||||
return block_idx // self.interval
|
||||
|
||||
def gate(self, x: torch.Tensor, lq_feature: torch.Tensor, sigma: torch.Tensor, out_idx: int) -> torch.Tensor:
|
||||
return self.gate_modules[out_idx](x, lq_feature, sigma)
|
||||
|
||||
def _align_latent_to_patch_grid(self, lq_latent: torch.Tensor, pH: int, pW: int) -> torch.Tensor:
|
||||
B, z_dim = lq_latent.shape[:2]
|
||||
if self.z_to_patch_ratio >= 1:
|
||||
if lq_latent.shape[2] != pH or lq_latent.shape[3] != pW:
|
||||
z_aligned = F.interpolate(lq_latent, size=(pH, pW), mode="nearest")
|
||||
else:
|
||||
z_aligned = lq_latent
|
||||
else:
|
||||
f = self.latent_fold_factor
|
||||
zH_expected, zW_expected = pH * f, pW * f
|
||||
if lq_latent.shape[2] != zH_expected or lq_latent.shape[3] != zW_expected:
|
||||
lq_latent = F.interpolate(lq_latent, size=(zH_expected, zW_expected), mode="nearest")
|
||||
z_aligned = lq_latent.reshape(B, z_dim, pH, f, pW, f).permute(0, 1, 3, 5, 2, 4)
|
||||
z_aligned = z_aligned.reshape(B, z_dim * f * f, pH, pW)
|
||||
return self.latent_proj(z_aligned)
|
||||
|
||||
def forward(self, lq_latent: torch.Tensor, target_pH: int, target_pW: int) -> List[torch.Tensor]:
|
||||
feat = self._align_latent_to_patch_grid(lq_latent, target_pH, target_pW)
|
||||
B, C, H, W = feat.shape
|
||||
tokens = feat.permute(0, 2, 3, 1).contiguous().view(B, H * W, C)
|
||||
return [head(tokens) for head in self.output_heads]
|
||||
|
||||
|
||||
class PidNet(PixDiT_T2I):
|
||||
"""PixDiT_T2I + LQ injection (one sigma-gated feature inserted before each patch block)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lq_latent_channels: int = 16,
|
||||
lq_hidden_dim: int = 512,
|
||||
lq_num_res_blocks: int = 4,
|
||||
lq_interval: int = 2,
|
||||
sr_scale: int = 4,
|
||||
latent_spatial_down_factor: int = 8,
|
||||
rope_ref_h: int = 1024, # NTK ref resolution in PIXEL units: 1024px / patch=16 -> grid_ref=64.
|
||||
rope_ref_w: int = 1024,
|
||||
image_model=None,
|
||||
dtype=None, device=None, operations=None,
|
||||
**pixdit_kwargs,
|
||||
):
|
||||
super().__init__(dtype=dtype, device=device, operations=operations, **pixdit_kwargs)
|
||||
|
||||
self.rope_ref_grid_h = rope_ref_h // self.patch_size
|
||||
self.rope_ref_grid_w = rope_ref_w // self.patch_size
|
||||
|
||||
# Parent's PiTBlocks were built with plain RoPE — swap in NTK-aware.
|
||||
def _pit_rope_fn(head_dim, h, w, device=None, dtype=torch.float32, **rope_opts):
|
||||
return precompute_freqs_cis_2d(head_dim, h, w, ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w, device=device, dtype=dtype, **rope_opts)
|
||||
for blk in self.pixel_blocks:
|
||||
blk._rope_fn = _pit_rope_fn
|
||||
|
||||
num_lq_outputs = (self.patch_depth + lq_interval - 1) // lq_interval
|
||||
self.lq_proj = LQProjection2D(
|
||||
latent_channels=lq_latent_channels,
|
||||
hidden_dim=lq_hidden_dim,
|
||||
out_dim=self.hidden_size,
|
||||
patch_size=self.patch_size,
|
||||
sr_scale=sr_scale,
|
||||
latent_spatial_down_factor=latent_spatial_down_factor,
|
||||
num_res_blocks=lq_num_res_blocks,
|
||||
num_outputs=num_lq_outputs,
|
||||
interval=lq_interval,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
|
||||
def _fetch_patch_pos(self, height, width, device, dtype, **rope_opts):
|
||||
return precompute_freqs_cis_2d(
|
||||
self.hidden_size // self.num_groups,
|
||||
height, width,
|
||||
ref_grid_h=self.rope_ref_grid_h, ref_grid_w=self.rope_ref_grid_w,
|
||||
device=device, dtype=dtype, **rope_opts,
|
||||
)
|
||||
|
||||
def _pre_patch_block(self, s, i, pid_lq_features, pid_degrade_sigma, **kwargs):
|
||||
if not self.lq_proj.is_gate_active(i):
|
||||
return s
|
||||
out_idx = self.lq_proj.output_index(i)
|
||||
if out_idx >= len(pid_lq_features):
|
||||
return s
|
||||
return self.lq_proj.gate(s, pid_lq_features[out_idx], pid_degrade_sigma, out_idx)
|
||||
|
||||
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, lq_latent=None, degrade_sigma=None, **kwargs):
|
||||
if lq_latent is None:
|
||||
raise ValueError("PidNet requires lq_latent — attach via PiDConditioning")
|
||||
expected_c = self.lq_proj.latent_channels
|
||||
if lq_latent.shape[1] != expected_c:
|
||||
raise ValueError(
|
||||
f"Input latent has {lq_latent.shape[1]} channels, this model variant expects {expected_c}. "
|
||||
f"Flux1/SD3 = 16 channels, Flux2 = 128 channels."
|
||||
)
|
||||
B = x.shape[0]
|
||||
# Match the backbone's pad_to_patch_size (round up) so the LQ grid lines up with the patch stream.
|
||||
Hs = -(-x.shape[2] // self.patch_size)
|
||||
Ws = -(-x.shape[3] // self.patch_size)
|
||||
|
||||
degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
|
||||
if degrade_sigma.numel() == 1 and B > 1:
|
||||
degrade_sigma = degrade_sigma.expand(B).contiguous()
|
||||
|
||||
lq_features = self.lq_proj(lq_latent=lq_latent.to(x), target_pH=Hs, target_pW=Ws)
|
||||
|
||||
return super()._forward(
|
||||
x, timesteps,
|
||||
context=context, attention_mask=attention_mask,
|
||||
transformer_options=transformer_options,
|
||||
pid_lq_features=lq_features,
|
||||
pid_degrade_sigma=degrade_sigma,
|
||||
**kwargs,
|
||||
)
|
||||
@ -35,6 +35,7 @@ import comfy.ldm.hydit.models
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.ldm.flux.model
|
||||
import comfy.ldm.lens.model
|
||||
import comfy.ldm.lightricks.model
|
||||
import comfy.ldm.hunyuan_video.model
|
||||
import comfy.ldm.cosmos.model
|
||||
@ -48,6 +49,8 @@ import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.chroma_radiance.model
|
||||
import comfy.ldm.pixeldit.model
|
||||
import comfy.ldm.pixeldit.pid
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.seedvr.model
|
||||
@ -1070,6 +1073,27 @@ class Flux2(Flux):
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
|
||||
class Lens(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(
|
||||
model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.lens.model.LensTransformer2DModel,
|
||||
)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return None # Lens has no pooled/ADM conditioning.
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
return out
|
||||
|
||||
class GenmoMochi(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.genmo.joint_model.asymm_models_joint.AsymmDiTJoint)
|
||||
@ -1387,6 +1411,53 @@ class ZImagePixelSpace(Lumina2):
|
||||
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
|
||||
self.memory_usage_factor_conds = ("ref_latents",)
|
||||
|
||||
|
||||
class PixelDiTT2I(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.pixeldit.model.PixDiT_T2I)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
|
||||
return out
|
||||
|
||||
|
||||
class PiD(PixelDiTT2I):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
BaseModel.__init__(self, model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.pixeldit.pid.PidNet)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
lq_latent = kwargs.get("lq_latent", None)
|
||||
if lq_latent is not None:
|
||||
out["lq_latent"] = comfy.conds.CONDRegular(lq_latent)
|
||||
degrade_sigma = kwargs.get("degrade_sigma", None)
|
||||
if degrade_sigma is not None:
|
||||
out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma)
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "lq_latent" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
lq = cond_value.cond
|
||||
dim = window.dim
|
||||
if dim >= lq.ndim:
|
||||
return None
|
||||
lq_proj = self.diffusion_model.lq_proj
|
||||
ratio = lq_proj.sr_scale * lq_proj.latent_spatial_down_factor
|
||||
# Map x window indices -> lq indices (deduplicated, sorted, in-bounds).
|
||||
lq_size = lq.size(dim)
|
||||
lq_indices = sorted({i // ratio for i in window.index_list if 0 <= i // ratio < lq_size})
|
||||
if not lq_indices:
|
||||
return None
|
||||
idx = tuple([slice(None)] * dim + [lq_indices])
|
||||
return cond_value._copy_with(lq[idx].to(device))
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
|
||||
class WAN21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||
|
||||
@ -463,6 +463,23 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
|
||||
return dit_config
|
||||
|
||||
# PiD (Pixel Diffusion Decoder). Must check BEFORE plain PixelDiT_T2I.
|
||||
_lq_w_key = '{}lq_proj.latent_proj.0.weight'.format(key_prefix)
|
||||
if _lq_w_key in state_dict_keys:
|
||||
in_ch = int(state_dict[_lq_w_key].shape[1])
|
||||
_gate_prefix = '{}lq_proj.gate_modules.'.format(key_prefix)
|
||||
num_gates = len({k[len(_gate_prefix):].split('.')[0]
|
||||
for k in state_dict_keys if k.startswith(_gate_prefix)})
|
||||
dit_config = {"image_model": "pid",
|
||||
"lq_latent_channels": in_ch,
|
||||
"latent_spatial_down_factor": 16 if in_ch >= 64 else 8}
|
||||
if num_gates > 0:
|
||||
dit_config["lq_interval"] = (14 + num_gates - 1) // num_gates
|
||||
return dit_config
|
||||
|
||||
if '{}core.pixel_embedder.proj.weight'.format(key_prefix) in state_dict_keys: # PixelDiT T2I
|
||||
return {"image_model": "pixeldit_t2i"}
|
||||
|
||||
if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys and '{}noise_refiner.0.attention.k_norm.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "lumina2"
|
||||
@ -805,6 +822,30 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["timestep_scale"] = 1000.0
|
||||
return dit_config
|
||||
|
||||
if '{}transformer_blocks.0.attn.norm_added_q.weight'.format(key_prefix) in state_dict_keys \
|
||||
and '{}transformer_blocks.0.img_mlp.w1.weight'.format(key_prefix) in state_dict_keys: # Lens
|
||||
img_in_w = state_dict['{}img_in.weight'.format(key_prefix)]
|
||||
proj_out_w = state_dict['{}proj_out.weight'.format(key_prefix)]
|
||||
multi_layer = '{}txt_norm.0.weight'.format(key_prefix) in state_dict_keys
|
||||
if multi_layer:
|
||||
enc_hidden_dim = state_dict['{}txt_norm.0.weight'.format(key_prefix)].shape[0]
|
||||
# Indices are TE-side; the DiT just consumes L layers in order.
|
||||
selected_layer_index = tuple(range(count_blocks(state_dict_keys, '{}txt_norm.'.format(key_prefix) + '{}.')))
|
||||
else:
|
||||
enc_hidden_dim = state_dict['{}txt_norm.weight'.format(key_prefix)].shape[0]
|
||||
selected_layer_index = (0,)
|
||||
|
||||
return {
|
||||
"image_model": "lens",
|
||||
"in_channels": img_in_w.shape[1],
|
||||
"out_channels": proj_out_w.shape[0] // 4, # patch_size ** 2 (=2² default)
|
||||
"num_layers": count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.'),
|
||||
"num_attention_heads": img_in_w.shape[0] // 64, # // attention_head_dim default
|
||||
"enc_hidden_dim": enc_hidden_dim,
|
||||
"multi_layer_encoder_feature": multi_layer,
|
||||
"selected_layer_index": selected_layer_index,
|
||||
}
|
||||
|
||||
if '{}txt_norm.weight'.format(key_prefix) in state_dict_keys: # Qwen Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "qwen_image"
|
||||
|
||||
476
comfy/ops.py
476
comfy/ops.py
@ -18,6 +18,7 @@
|
||||
|
||||
import torch
|
||||
import logging
|
||||
import contextlib
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
@ -1047,6 +1048,144 @@ class QuantLinearFunc(torch.autograd.Function):
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
# Quantized-weight module helpers
|
||||
|
||||
def _quantized_apply(module, fn, recurse=True):
|
||||
"""Re-wrap Parameters after fn so .to()/.cuda() propagate through QuantizedTensor weights."""
|
||||
if recurse:
|
||||
for child in module.children():
|
||||
child._apply(fn)
|
||||
for key, param in module._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
p = fn(param)
|
||||
if (not torch.is_inference_mode_enabled()) and p.is_inference():
|
||||
p = p.clone()
|
||||
module.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||
for key, buf in module._buffers.items():
|
||||
if buf is not None:
|
||||
module._buffers[key] = fn(buf)
|
||||
return module
|
||||
|
||||
|
||||
def _load_quantized_module(module, super_load, state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs, load_extra_params=False):
|
||||
"""Shared _load_from_state_dict body for quantized-weight modules.
|
||||
|
||||
Pops weight (+ scales, +/- extras), populates module.weight as a Parameter
|
||||
or Parameter-wrapped QuantizedTensor, then calls super_load and strips
|
||||
consumed keys from missing_keys. Reads compute_dtype from factory_kwargs
|
||||
and disabled formats from module._disabled_formats.
|
||||
"""
|
||||
device = module.factory_kwargs["device"]
|
||||
compute_dtype = module.factory_kwargs["dtype"]
|
||||
disabled_formats = module._disabled_formats
|
||||
layer_name = prefix.rstrip('.')
|
||||
|
||||
weight = state_dict.pop(f"{prefix}weight", None)
|
||||
if weight is None:
|
||||
logging.warning(f"Missing weight for layer {layer_name}")
|
||||
module.weight = None
|
||||
return
|
||||
manually_loaded_keys = [f"{prefix}weight"]
|
||||
|
||||
def pop_scale(name, dtype=None):
|
||||
key = f"{prefix}{name}"
|
||||
v = state_dict.pop(key, None)
|
||||
if v is not None:
|
||||
v = v.to(device=device)
|
||||
if dtype is not None:
|
||||
v = v.view(dtype=dtype)
|
||||
manually_loaded_keys.append(key)
|
||||
return v
|
||||
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
if layer_conf is None:
|
||||
module.weight = torch.nn.Parameter(weight.to(device=device, dtype=compute_dtype), requires_grad=False)
|
||||
else:
|
||||
module.quant_format = layer_conf.get("format", None)
|
||||
module._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||
if not module._full_precision_mm:
|
||||
module._full_precision_mm = module._full_precision_mm_config
|
||||
if module.quant_format in disabled_formats:
|
||||
module._full_precision_mm = True
|
||||
if module.quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[module.quant_format]
|
||||
module.layout_type = qconfig["comfy_tensor_layout"]
|
||||
layout_cls = get_layout_class(module.layout_type)
|
||||
|
||||
# Per-format scales; fp8 dtype views handle both legacy uint8-on-disk and native fp8.
|
||||
if module.quant_format in ("float8_e4m3fn", "float8_e5m2"):
|
||||
scales = {"scale": pop_scale("weight_scale")}
|
||||
elif module.quant_format == "mxfp8":
|
||||
bs = pop_scale("weight_scale", torch.float8_e8m0fnu)
|
||||
if bs is None:
|
||||
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||
scales = {"scale": bs}
|
||||
elif module.quant_format == "nvfp4":
|
||||
ts = pop_scale("weight_scale_2")
|
||||
bs = pop_scale("weight_scale", torch.float8_e4m3fn)
|
||||
if ts is None or bs is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||
scales = {"scale": ts, "block_scale": bs}
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {module.quant_format}")
|
||||
|
||||
params = layout_cls.Params(**scales, orig_dtype=compute_dtype, orig_shape=module._orig_shape)
|
||||
module.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), module.layout_type, params),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
if load_extra_params:
|
||||
for param_name in qconfig["parameters"]:
|
||||
if param_name in {"weight_scale", "weight_scale_2"}:
|
||||
continue
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
module.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super_load(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
for key in manually_loaded_keys:
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
|
||||
|
||||
def _quantized_weight_state_dict(module, sd, prefix, extra_quant_conf=None, extra_quant_params=()):
|
||||
"""Shared state_dict body. extra_quant_conf merges into the comfy_quant JSON;
|
||||
extra_quant_params names attributes written as additional top-level keys."""
|
||||
if not hasattr(module, 'weight'):
|
||||
logging.warning(f"Warning: state dict on uninitialized op {prefix}")
|
||||
return sd
|
||||
bias = getattr(module, 'bias', None)
|
||||
if bias is not None:
|
||||
sd[f"{prefix}bias"] = bias
|
||||
if module.weight is None:
|
||||
return sd
|
||||
if isinstance(module.weight, QuantizedTensor):
|
||||
sd.update(module.weight.state_dict(f"{prefix}weight"))
|
||||
quant_conf = {"format": module.quant_format}
|
||||
if getattr(module, '_full_precision_mm_config', False):
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
if extra_quant_conf:
|
||||
quant_conf.update(extra_quant_conf)
|
||||
sd[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(quant_conf).encode("utf-8")), dtype=torch.uint8)
|
||||
for name in extra_quant_params:
|
||||
value = getattr(module, name, None)
|
||||
if value is not None:
|
||||
sd[f"{prefix}{name}"] = value
|
||||
else:
|
||||
sd[f"{prefix}weight"] = module.weight
|
||||
return sd
|
||||
|
||||
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
@ -1056,21 +1195,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
_disabled = disabled
|
||||
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
_disabled_formats = disabled
|
||||
|
||||
def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
|
||||
super().__init__()
|
||||
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self._orig_shape = (out_features, in_features)
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||
else:
|
||||
@ -1083,151 +1217,12 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _load_scale_param(self, state_dict, prefix, param_name, device, manually_loaded_keys, dtype=None):
|
||||
key = f"{prefix}{param_name}"
|
||||
value = state_dict.pop(key, None)
|
||||
if value is not None:
|
||||
value = value.to(device=device)
|
||||
if dtype is not None:
|
||||
value = value.view(dtype=dtype)
|
||||
manually_loaded_keys.append(key)
|
||||
return value
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
device = self.factory_kwargs["device"]
|
||||
layer_name = prefix.rstrip('.')
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
logging.warning(f"Missing weight for layer {layer_name}")
|
||||
self.weight = None
|
||||
return
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
if layer_conf is None:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
self.quant_format = layer_conf.get("format", None)
|
||||
self._full_precision_mm_config = layer_conf.get("full_precision_matrix_mult", False)
|
||||
if not self._full_precision_mm:
|
||||
self._full_precision_mm = self._full_precision_mm_config
|
||||
|
||||
if self.quant_format in MixedPrecisionOps._disabled:
|
||||
self._full_precision_mm = True
|
||||
|
||||
if self.quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[self.quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
layout_cls = get_layout_class(self.layout_type)
|
||||
|
||||
# Load format-specific parameters
|
||||
if self.quant_format in ["float8_e4m3fn", "float8_e5m2"]:
|
||||
# FP8: single tensor scale
|
||||
scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "mxfp8":
|
||||
# MXFP8: E8M0 block scales stored as uint8 in safetensors
|
||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||
dtype=torch.uint8)
|
||||
|
||||
if block_scale is None:
|
||||
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||
|
||||
block_scale = block_scale.view(torch.float8_e8m0fnu)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=block_scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "nvfp4":
|
||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||
dtype=torch.float8_e4m3fn)
|
||||
|
||||
if tensor_scale is None or block_scale is None:
|
||||
raise ValueError(f"Missing NVFP4 scales for layer {layer_name}")
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=tensor_scale,
|
||||
block_scale=block_scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization format: {self.quant_format}")
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig["storage_t"]), self.layout_type, params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
if param_name in {"weight_scale", "weight_scale_2"}:
|
||||
continue # Already handled above
|
||||
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
for key in manually_loaded_keys:
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
def _load_from_state_dict(self, *args):
|
||||
_load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=True)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
if destination is not None:
|
||||
sd = destination
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
if not hasattr(self, 'weight'):
|
||||
logging.warning("Warning: state dict on uninitialized op {}".format(prefix))
|
||||
return sd
|
||||
|
||||
if self.bias is not None:
|
||||
sd["{}bias".format(prefix)] = self.bias
|
||||
|
||||
if self.weight is None:
|
||||
return sd
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
sd[k] = sd_out[k]
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
if self._full_precision_mm_config:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
input_scale = getattr(self, 'input_scale', None)
|
||||
if input_scale is not None:
|
||||
sd["{}input_scale".format(prefix)] = input_scale
|
||||
else:
|
||||
sd["{}weight".format(prefix)] = self.weight
|
||||
return sd
|
||||
sd = destination if destination is not None else {}
|
||||
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_params=("input_scale",))
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
@ -1317,25 +1312,126 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
|
||||
if recurse:
|
||||
for module in self.children():
|
||||
module._apply(fn)
|
||||
return _quantized_apply(self, fn, recurse)
|
||||
|
||||
for key, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
p = fn(param)
|
||||
if (not torch.is_inference_mode_enabled()) and p.is_inference():
|
||||
p = p.clone()
|
||||
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||
for key, buf in self._buffers.items():
|
||||
if buf is not None:
|
||||
self._buffers[key] = fn(buf)
|
||||
return self
|
||||
class MoEExperts(torch.nn.Module, CastWeightBiasOp):
|
||||
"""Container for E quantized expert weights, indexed via expert_weight(i).
|
||||
|
||||
The bank lives on self.weight as a single 3D tensor — either a
|
||||
compute_dtype Parameter or a Parameter wrapping a QuantizedTensor
|
||||
with leading expert dim.
|
||||
|
||||
State-dict layout matches mixed_precision_ops.Linear with a leading
|
||||
expert dim:
|
||||
{prefix}.weight quant data (storage_t), leading dim = E
|
||||
{prefix}.weight_scale block / per-tensor scale
|
||||
{prefix}.weight_scale_2 [E] or scalar NVFP4 only
|
||||
{prefix}.bias [E, out_features] optional, compute_dtype
|
||||
{prefix}.comfy_quant json -> {{"format": "...", "num_experts": E}}
|
||||
|
||||
Without comfy_quant the weight loads as a plain compute_dtype 3D Parameter [E, out, in].
|
||||
"""
|
||||
|
||||
_disabled_formats = disabled
|
||||
|
||||
def __init__(self, num_experts: int, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self._orig_shape = (num_experts, out_features, in_features)
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(num_experts, out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
# Populated by _load_from_state_dict:
|
||||
self.weight = None
|
||||
self.quant_format = None
|
||||
self.layout_type = None
|
||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||
self._full_precision_mm_config = False
|
||||
self._resident_bank = None
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
def _apply(self, fn, recurse=True):
|
||||
return _quantized_apply(self, fn, recurse)
|
||||
|
||||
def _load_from_state_dict(self, *args):
|
||||
_load_quantized_module(self, super()._load_from_state_dict, *args, load_extra_params=False)
|
||||
|
||||
def expert_weight(self, i: int):
|
||||
"""Expert i's weight (Tensor or per-expert QuantizedTensor view)."""
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
return self._expert_qt_from(self.weight, i)
|
||||
return self.weight[i]
|
||||
|
||||
@contextlib.contextmanager
|
||||
def bank_resident(self, input):
|
||||
"""Cast the whole bank once; expert_linear inside reuses the cast.
|
||||
Not re-entrant — do not nest calls on the same instance.
|
||||
"""
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
self._resident_bank = (weight, bias)
|
||||
try:
|
||||
yield self
|
||||
finally:
|
||||
self._resident_bank = None
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
|
||||
def expert_linear(self, input: torch.Tensor, i: int) -> torch.Tensor:
|
||||
"""Linear against expert i's weight (with optional bias)."""
|
||||
resident = getattr(self, "_resident_bank", None)
|
||||
if resident is not None:
|
||||
weight, bias = resident
|
||||
return self._expert_linear_impl(input, weight, bias, i)
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
try:
|
||||
return self._expert_linear_impl(input, weight, bias, i)
|
||||
finally:
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
|
||||
def _expert_linear_impl(self, input, weight, bias, i):
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
qw = self._expert_qt_from(weight, i)
|
||||
else:
|
||||
qw = weight[i]
|
||||
b = cast_to_input(bias[i], input, copy=False) if bias is not None else None
|
||||
|
||||
if isinstance(qw, QuantizedTensor):
|
||||
use_fast = (
|
||||
not self._full_precision_mm
|
||||
and qw.layout_cls.supports_fast_matmul()
|
||||
and input.dim() == 2
|
||||
)
|
||||
if use_fast:
|
||||
qin = QuantizedTensor.from_float(input, self.layout_type)
|
||||
return torch.nn.functional.linear(qin, qw, b)
|
||||
out = input @ qw.dequantize().t()
|
||||
return out + b if b is not None else out
|
||||
return torch.nn.functional.linear(input, qw, b)
|
||||
|
||||
def _expert_qt_from(self, weight: QuantizedTensor, i: int) -> QuantizedTensor:
|
||||
"""Build a per-expert QuantizedTensor by indexing into a resident bank."""
|
||||
params = weight._params
|
||||
kwargs = {
|
||||
"scale": params.scale[i] if params.scale.dim() else params.scale,
|
||||
"orig_dtype": params.orig_dtype,
|
||||
"orig_shape": (self.out_features, self.in_features),
|
||||
}
|
||||
if hasattr(params, "block_scale"): # NVFP4
|
||||
kwargs["block_scale"] = params.block_scale[i]
|
||||
return QuantizedTensor(weight._qdata[i], weight._layout_cls, type(params)(**kwargs))
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
sd = destination if destination is not None else {}
|
||||
return _quantized_weight_state_dict(self, sd, prefix, extra_quant_conf={"num_experts": self.num_experts})
|
||||
|
||||
class Embedding(manual_cast.Embedding):
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
||||
weight_key = f"{prefix}weight"
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
@ -1343,14 +1439,16 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
|
||||
# Only fp8 makes sense for embeddings (per-row dequant via index select).
|
||||
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
|
||||
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
|
||||
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
|
||||
quant_format = layer_conf.get("format") if layer_conf is not None else None
|
||||
manually_loaded_keys = []
|
||||
|
||||
if quant_format in ("float8_e4m3fn", "float8_e5m2") and weight_key in state_dict:
|
||||
self.quant_format = quant_format
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
layout_cls = get_layout_class(self.layout_type)
|
||||
weight = state_dict.pop(weight_key)
|
||||
manually_loaded_keys = [weight_key]
|
||||
manually_loaded_keys.append(weight_key)
|
||||
|
||||
scale_key = f"{prefix}weight_scale"
|
||||
scale = state_dict.pop(scale_key, None)
|
||||
@ -1366,35 +1464,19 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
|
||||
requires_grad=False)
|
||||
elif layer_conf is not None:
|
||||
# Unsupported format — restore the marker so it round-trips; fall through to default load.
|
||||
state_dict[f"{prefix}comfy_quant"] = torch.tensor(
|
||||
list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
for k in manually_loaded_keys:
|
||||
if k in missing_keys:
|
||||
missing_keys.remove(k)
|
||||
else:
|
||||
if layer_conf is not None:
|
||||
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
for k in manually_loaded_keys:
|
||||
if k in missing_keys:
|
||||
missing_keys.remove(k)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
if destination is not None:
|
||||
sd = destination
|
||||
else:
|
||||
sd = {}
|
||||
|
||||
if not hasattr(self, 'weight') or self.weight is None:
|
||||
return sd
|
||||
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd_out = self.weight.state_dict("{}weight".format(prefix))
|
||||
for k in sd_out:
|
||||
sd[k] = sd_out[k]
|
||||
|
||||
quant_conf = {"format": self.quant_format}
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
else:
|
||||
sd["{}weight".format(prefix)] = self.weight
|
||||
return sd
|
||||
sd = destination if destination is not None else {}
|
||||
return _quantized_weight_state_dict(self, sd, prefix)
|
||||
|
||||
def forward_comfy_cast_weights(self, input, out_dtype=None):
|
||||
weight = self.weight
|
||||
|
||||
20
comfy/sd.py
20
comfy/sd.py
@ -51,6 +51,7 @@ import comfy.text_encoders.lt
|
||||
import comfy.text_encoders.hunyuan_video
|
||||
import comfy.text_encoders.cosmos
|
||||
import comfy.text_encoders.lumina2
|
||||
import comfy.text_encoders.pixeldit
|
||||
import comfy.text_encoders.wan
|
||||
import comfy.text_encoders.hidream
|
||||
import comfy.text_encoders.ace
|
||||
@ -70,6 +71,7 @@ import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.gemma4
|
||||
import comfy.text_encoders.cogvideo
|
||||
import comfy.text_encoders.sa3
|
||||
import comfy.text_encoders.gpt_oss
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -1463,6 +1465,8 @@ class CLIPType(Enum):
|
||||
FLUX2 = 25
|
||||
LONGCAT_IMAGE = 26
|
||||
COGVIDEOX = 27
|
||||
LENS = 28
|
||||
PIXELDIT = 29
|
||||
|
||||
|
||||
|
||||
@ -1515,6 +1519,7 @@ class TEModel(Enum):
|
||||
GEMMA_4_E2B = 30
|
||||
GEMMA_4_31B = 31
|
||||
T5_GEMMA = 32
|
||||
GPT_OSS_20B = 33
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -1556,6 +1561,9 @@ def detect_te_model(sd):
|
||||
else:
|
||||
return TEModel.GEMMA_3_4B
|
||||
return TEModel.GEMMA_2_2B
|
||||
# Must precede the Qwen2.5-7B k_proj.bias=512 check (GPT-OSS also has 8*64=512).
|
||||
if "layers.0.self_attn.sinks" in sd and "layers.0.mlp.experts.gate_up_proj.weight" in sd:
|
||||
return TEModel.GPT_OSS_20B
|
||||
if 'model.layers.0.self_attn.k_proj.bias' in sd:
|
||||
weight = sd['model.layers.0.self_attn.k_proj.bias']
|
||||
if weight.shape[0] == 256:
|
||||
@ -1702,8 +1710,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.tokenizer = variant.tokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
elif te_model == TEModel.GEMMA_2_2B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
if clip_type == CLIPType.PIXELDIT:
|
||||
clip_target.clip = comfy.text_encoders.pixeldit.pixeldit_te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.GEMMA_3_4B:
|
||||
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
|
||||
@ -1738,6 +1750,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
|
||||
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
|
||||
elif te_model == TEModel.GPT_OSS_20B:
|
||||
clip_target.clip = comfy.text_encoders.gpt_oss.lens_te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.gpt_oss.LensTokenizer
|
||||
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
|
||||
elif te_model == TEModel.QWEN3_4B:
|
||||
if clip_type == CLIPType.FLUX or clip_type == CLIPType.FLUX2:
|
||||
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_4b")
|
||||
|
||||
@ -30,6 +30,7 @@ import comfy.text_encoders.longcat_image
|
||||
import comfy.text_encoders.ernie
|
||||
import comfy.text_encoders.cogvideo
|
||||
import comfy.text_encoders.hidream_o1
|
||||
import comfy.text_encoders.pixeldit
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -829,6 +830,50 @@ class Flux2(Flux):
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class Lens(supported_models_base.BASE):
|
||||
"""Microsoft Lens (3.8B dual-stream MMDiT, GPT-OSS-20B text features, Flux2 VAE)."""
|
||||
|
||||
unet_config = {
|
||||
"image_model": "lens",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 1.829, # Default mu for 1440x1440 (and any seq_len > 4300
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux2
|
||||
|
||||
memory_usage_factor = 4.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32] # fp16 causes NaNs
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.Lens(self, model_type=model_base.ModelType.FLUX, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
for hint in ("gpt_oss.transformer.", ""):
|
||||
full_prefix = "{}{}".format(pref, hint)
|
||||
if "{}layers.0.self_attn.sinks".format(full_prefix) in state_dict:
|
||||
detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, full_prefix)
|
||||
return supported_models_base.ClipTarget(
|
||||
comfy.text_encoders.gpt_oss.LensTokenizer,
|
||||
comfy.text_encoders.gpt_oss.lens_te(**detect),
|
||||
)
|
||||
return supported_models_base.ClipTarget(
|
||||
comfy.text_encoders.gpt_oss.LensTokenizer,
|
||||
comfy.text_encoders.gpt_oss.lens_te(),
|
||||
)
|
||||
|
||||
|
||||
class GenmoMochi(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "mochi_preview",
|
||||
@ -1159,6 +1204,72 @@ class ZImagePixelSpace(ZImage):
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.ZImagePixelSpace(self, device=device)
|
||||
|
||||
class PixelDiTT2I(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "pixeldit_t2i",
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 4.0, # 1024px stage 3 default; 2.0 for 512px
|
||||
}
|
||||
|
||||
latent_format = latent_formats.PixelDiTPixel
|
||||
memory_usage_factor = 0.04
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.PixelDiTT2I(self, device=device)
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
# pixel_dim from pixel_embedder.proj.weight = (pixel_dim, in_channels); p2 derived per-weight from total // (6 * pixel_dim).
|
||||
pixel_dim = next(v for k, v in state_dict.items() if k.endswith("pixel_embedder.proj.weight")).shape[0]
|
||||
|
||||
out = {}
|
||||
marker = ".adaLN_modulation.0."
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith("_repa_projector") or k.startswith("net_ema."):
|
||||
continue
|
||||
if k.startswith("core."):
|
||||
k = k[len("core."):]
|
||||
elif k.startswith("net."):
|
||||
k = k[len("net."):]
|
||||
if "pixel_blocks." in k and marker in k:
|
||||
# Split into msa (chunks 0-2) and mlp (chunks 3-5) for the two-Linear PiTBlock to reduce peak VRAM
|
||||
p2 = v.shape[0] // (6 * pixel_dim)
|
||||
trail = v.shape[1:] # () for bias, (in_dim,) for weight
|
||||
vv = v.view(p2, 6, pixel_dim, *trail)
|
||||
base, suffix = k.split(marker)
|
||||
out[f"{base}.adaLN_modulation_msa.{suffix}"] = vv[:, 0:3].reshape(3 * p2 * pixel_dim, *trail).contiguous()
|
||||
out[f"{base}.adaLN_modulation_mlp.{suffix}"] = vv[:, 3:6].reshape(3 * p2 * pixel_dim, *trail).contiguous()
|
||||
else:
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(
|
||||
comfy.text_encoders.pixeldit.PixelDiTGemma2Tokenizer,
|
||||
comfy.text_encoders.pixeldit.PixelDiTGemma2TE,
|
||||
)
|
||||
|
||||
class PiD(PixelDiTT2I):
|
||||
unet_config = {
|
||||
"image_model": "pid",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 1.5, # close approximation of the original distill 4 steps [0.999, 0.866, 0.634, 0.342, 0]
|
||||
}
|
||||
|
||||
memory_usage_factor = 0.04
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.PiD(self, device=device)
|
||||
|
||||
class WAN21_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -2097,6 +2208,8 @@ models = [
|
||||
CosmosI2VPredict2,
|
||||
ZImagePixelSpace,
|
||||
ZImage,
|
||||
PiD,
|
||||
PixelDiTT2I,
|
||||
Lumina2,
|
||||
WAN22_T2V,
|
||||
WAN21_CausalAR_T2V,
|
||||
@ -2125,6 +2238,7 @@ models = [
|
||||
Omnigen2,
|
||||
QwenImage,
|
||||
Flux2,
|
||||
Lens,
|
||||
Kandinsky5Image,
|
||||
Kandinsky5,
|
||||
Anima,
|
||||
|
||||
600
comfy/text_encoders/gpt_oss.py
Normal file
600
comfy/text_encoders/gpt_oss.py
Normal file
@ -0,0 +1,600 @@
|
||||
"""GPT-OSS text encoder for Lens."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy import sd1_clip
|
||||
from comfy.ldm.modules.attention import TORCH_HAS_GQA, optimized_attention_for_device
|
||||
from comfy.text_encoders.llama import RMSNorm, apply_rope
|
||||
|
||||
|
||||
@dataclass
|
||||
class GptOss20BConfig:
|
||||
vocab_size: int = 201088
|
||||
hidden_size: int = 2880
|
||||
intermediate_size: int = 2880
|
||||
num_hidden_layers: int = 24
|
||||
num_attention_heads: int = 64
|
||||
num_key_value_heads: int = 8
|
||||
head_dim: int = 64
|
||||
num_local_experts: int = 32
|
||||
num_experts_per_tok: int = 4
|
||||
sliding_window: int = 128
|
||||
original_max_position_embeddings: int = 4096
|
||||
rope_theta: float = 150000.0
|
||||
rope_factor: float = 32.0
|
||||
rope_beta_fast: float = 32.0
|
||||
rope_beta_slow: float = 1.0
|
||||
rope_truncate: bool = False
|
||||
rms_norm_eps: float = 1e-5
|
||||
attention_bias: bool = True
|
||||
layer_types: Optional[List[str]] = None
|
||||
moe_alpha: float = 1.702
|
||||
moe_limit: float = 7.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention" if (i + 1) % 2 else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
|
||||
|
||||
def _yarn_inv_freq(head_dim: int, base: float, factor: float, beta_fast: float, beta_slow: float,
|
||||
original_max_position_embeddings: int, truncate: bool, device=None) -> tuple[torch.Tensor, float]:
|
||||
"""YARN inv_freq + attention scaling (matches transformers)."""
|
||||
dim = head_dim
|
||||
|
||||
def find_correction_dim(num_rotations: float) -> float:
|
||||
return (dim * math.log(original_max_position_embeddings / (num_rotations * 2 * math.pi))) / (
|
||||
2 * math.log(base)
|
||||
)
|
||||
|
||||
def find_correction_range() -> tuple[float, float]:
|
||||
low = find_correction_dim(beta_fast)
|
||||
high = find_correction_dim(beta_slow)
|
||||
if truncate:
|
||||
low = math.floor(low)
|
||||
high = math.ceil(high)
|
||||
return max(low, 0), min(high, dim - 1)
|
||||
|
||||
def linear_ramp_factor(min_: float, max_: float, n: int) -> torch.Tensor:
|
||||
if min_ == max_:
|
||||
max_ += 0.001
|
||||
linear = (torch.arange(n, dtype=torch.float32, device=device) - min_) / (max_ - min_)
|
||||
return torch.clamp(linear, 0, 1)
|
||||
|
||||
def get_mscale(scale: float) -> float:
|
||||
if scale <= 1:
|
||||
return 1.0
|
||||
return 0.1 * math.log(scale) + 1.0
|
||||
|
||||
attention_scaling = get_mscale(factor)
|
||||
|
||||
pos_freqs = base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim)
|
||||
inv_freq_extrapolation = 1.0 / pos_freqs
|
||||
inv_freq_interpolation = 1.0 / (factor * pos_freqs)
|
||||
|
||||
low, high = find_correction_range()
|
||||
extrap_factor = 1 - linear_ramp_factor(low, high, dim // 2)
|
||||
inv_freq = inv_freq_interpolation * (1 - extrap_factor) + inv_freq_extrapolation * extrap_factor
|
||||
return inv_freq, attention_scaling
|
||||
|
||||
|
||||
def _build_freqs_cis(inv_freq: torch.Tensor, attention_scaling: float, position_ids: torch.Tensor, dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
inv_freq_e = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
|
||||
pos_e = position_ids[:, None, :].float()
|
||||
freqs = (inv_freq_e @ pos_e).transpose(1, 2)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = (emb.cos() * attention_scaling).to(dtype).unsqueeze(1)
|
||||
sin = (emb.sin() * attention_scaling).to(dtype).unsqueeze(1)
|
||||
sin_split = sin.shape[-1] // 2
|
||||
return cos, sin[..., :sin_split], -sin[..., sin_split:]
|
||||
|
||||
|
||||
def _attention_with_sinks(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, sinks: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor], num_heads: int, num_kv_groups: int) -> torch.Tensor:
|
||||
"""Attention with per-head sinks.
|
||||
|
||||
Sinks add a learned term to each row's softmax denominator but contribute
|
||||
nothing to the output. We fake this by appending one zero k/v position and
|
||||
putting the sink logit in the mask at that column.
|
||||
"""
|
||||
|
||||
if num_kv_groups > 1 and not TORCH_HAS_GQA:
|
||||
k = k.repeat_interleave(num_kv_groups, dim=1)
|
||||
v = v.repeat_interleave(num_kv_groups, dim=1)
|
||||
|
||||
B, _, S_q, D = q.shape
|
||||
H_kv = k.shape[1]
|
||||
S_kv = k.shape[-2]
|
||||
|
||||
k = torch.cat([k, k.new_zeros(B, H_kv, 1, D)], dim=-2)
|
||||
v = torch.cat([v, v.new_zeros(B, H_kv, 1, D)], dim=-2)
|
||||
|
||||
sinks_col = sinks.to(q.dtype).view(1, num_heads, 1, 1).expand(B, num_heads, S_q, 1)
|
||||
if attention_mask is not None:
|
||||
mask_left = attention_mask[..., :S_kv].expand(B, num_heads, S_q, S_kv)
|
||||
else:
|
||||
mask_left = q.new_zeros(B, num_heads, S_q, S_kv)
|
||||
mask = torch.cat([mask_left, sinks_col], dim=-1)
|
||||
|
||||
op = optimized_attention_for_device(q.device, mask=True, small_input=True)
|
||||
return op(q, k, v, num_heads, mask=mask, skip_reshape=True, enable_gqa=True)
|
||||
|
||||
|
||||
class GptOssAttention(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
self.layer_type = config.layer_types[layer_idx]
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_key_value_heads
|
||||
self.num_kv_groups = self.num_heads // self.num_kv_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.hidden_size = config.hidden_size
|
||||
self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
|
||||
|
||||
bias = config.attention_bias
|
||||
self.q_proj = ops.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=bias, device=device, dtype=dtype)
|
||||
self.o_proj = ops.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=bias, device=device, dtype=dtype)
|
||||
self.sinks = nn.Parameter(torch.empty(self.num_heads, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], freqs_cis) -> torch.Tensor:
|
||||
B, S, _ = hidden_states.shape
|
||||
|
||||
q = self.q_proj(hidden_states).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = self.k_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
v = self.v_proj(hidden_states).view(B, S, self.num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
q, k = apply_rope(q, k, freqs_cis)
|
||||
|
||||
out = _attention_with_sinks(q, k, v, self.sinks, attention_mask, self.num_heads, self.num_kv_groups)
|
||||
return self.o_proj(out)
|
||||
|
||||
|
||||
# Mixture of Experts
|
||||
|
||||
class GptOssTopKRouter(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.top_k = config.num_experts_per_tok
|
||||
self.num_experts = config.num_local_experts
|
||||
self.weight = nn.Parameter(torch.empty(config.num_local_experts, config.hidden_size, device=device, dtype=dtype))
|
||||
self.bias = nn.Parameter(torch.empty(config.num_local_experts, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
weight = comfy.ops.cast_to_input(self.weight, hidden_states, copy=False)
|
||||
bias = comfy.ops.cast_to_input(self.bias, hidden_states, copy=False)
|
||||
logits = F.linear(hidden_states, weight, bias)
|
||||
top_vals, top_idx = torch.topk(logits, self.top_k, dim=-1)
|
||||
# Softmax over top-k slice only
|
||||
scores = F.softmax(top_vals, dim=-1, dtype=top_vals.dtype)
|
||||
return scores, top_idx
|
||||
|
||||
|
||||
class GptOssExperts(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.num_experts = config.num_local_experts
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.alpha = config.moe_alpha
|
||||
self.limit = config.moe_limit
|
||||
|
||||
E = self.num_experts
|
||||
H = self.hidden_size
|
||||
I = self.intermediate_size
|
||||
|
||||
self.gate_up_proj = ops.MoEExperts(num_experts=E, in_features=H, out_features=2 * I, bias=True, device=device, dtype=dtype)
|
||||
self.down_proj = ops.MoEExperts(num_experts=E, in_features=I, out_features=H, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def _apply_gate(self, gate_up: torch.Tensor) -> torch.Tensor:
|
||||
gate = gate_up[..., ::2]
|
||||
up = gate_up[..., 1::2]
|
||||
gate = gate.clamp(max=self.limit)
|
||||
up = up.clamp(min=-self.limit, max=self.limit)
|
||||
glu = gate * torch.sigmoid(gate * self.alpha)
|
||||
return torch.addcmul(glu, up, glu)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, router_indices: torch.Tensor, routing_weights: torch.Tensor) -> torch.Tensor:
|
||||
N = hidden_states.shape[0]
|
||||
top_k = router_indices.shape[-1]
|
||||
H = hidden_states.shape[-1]
|
||||
|
||||
per_pair = torch.zeros((N * top_k, H), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
expert_mask = F.one_hot(router_indices, num_classes=self.num_experts).permute(2, 1, 0)
|
||||
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
|
||||
|
||||
with self.gate_up_proj.bank_resident(hidden_states) as gate_up_bank, \
|
||||
self.down_proj.bank_resident(hidden_states) as down_bank:
|
||||
for ei in expert_hit:
|
||||
expert_idx = int(ei.item())
|
||||
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
|
||||
current = hidden_states[token_idx]
|
||||
|
||||
gate_up = gate_up_bank.expert_linear(current, expert_idx)
|
||||
gated = self._apply_gate(gate_up)
|
||||
expert_out = down_bank.expert_linear(gated, expert_idx)
|
||||
|
||||
weighted = expert_out * routing_weights[token_idx, top_k_pos, None]
|
||||
|
||||
flat_idx = token_idx * top_k + top_k_pos
|
||||
per_pair[flat_idx] = weighted.to(per_pair.dtype)
|
||||
|
||||
return per_pair.view(N, top_k, H).sum(dim=1)
|
||||
|
||||
|
||||
class GptOssMLP(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.router = GptOssTopKRouter(config, device=device, dtype=dtype)
|
||||
self.experts = GptOssExperts(config, device=device, dtype=dtype, ops=ops)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
B, S, H = hidden_states.shape
|
||||
flat = hidden_states.reshape(-1, H)
|
||||
scores, idx = self.router(flat)
|
||||
out = self.experts(flat, idx, scores)
|
||||
return out.reshape(B, S, H)
|
||||
|
||||
|
||||
# Decoder layer + model
|
||||
|
||||
class GptOssDecoderLayer(nn.Module):
|
||||
def __init__(self, config: GptOss20BConfig, layer_idx: int, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.self_attn = GptOssAttention(config, layer_idx, device=device, dtype=dtype, ops=ops)
|
||||
self.mlp = GptOssMLP(config, device=device, dtype=dtype, ops=ops)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
self.layer_type = config.layer_types[layer_idx]
|
||||
|
||||
def forward(self, x: torch.Tensor, attention_masks: dict[str, Optional[torch.Tensor]], freqs_cis) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.input_layernorm(x)
|
||||
x = self.self_attn(x, attention_masks[self.layer_type], freqs_cis)
|
||||
x = residual + x
|
||||
|
||||
residual = x
|
||||
x = self.post_attention_layernorm(x)
|
||||
x = self.mlp(x)
|
||||
x = residual + x
|
||||
return x
|
||||
|
||||
|
||||
def _make_full_causal_mask(B: int, S: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
|
||||
neg = torch.finfo(dtype).min
|
||||
mask = torch.full((S, S), neg, dtype=dtype, device=device).triu_(1)
|
||||
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
|
||||
if key_padding_mask is not None:
|
||||
kp = key_padding_mask.to(dtype=dtype)
|
||||
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
|
||||
mask = mask + kp
|
||||
return mask
|
||||
|
||||
|
||||
def _make_sliding_causal_mask(B: int, S: int, window: int, key_padding_mask: Optional[torch.Tensor], dtype, device):
|
||||
neg = torch.finfo(dtype).min
|
||||
i = torch.arange(S, device=device).view(-1, 1)
|
||||
j = torch.arange(S, device=device).view(1, -1)
|
||||
keep = (j <= i) & (j > i - window)
|
||||
mask = torch.where(keep, torch.zeros((), dtype=dtype, device=device), torch.full((), neg, dtype=dtype, device=device))
|
||||
mask = mask.unsqueeze(0).unsqueeze(0).expand(B, 1, S, S).contiguous()
|
||||
if key_padding_mask is not None:
|
||||
kp = key_padding_mask.to(dtype=dtype)
|
||||
kp = (1.0 - kp).reshape(B, 1, 1, S) * neg
|
||||
mask = mask + kp
|
||||
return mask
|
||||
|
||||
|
||||
class GptOssModel(nn.Module):
|
||||
"""GPT-OSS decoder with multi-layer hidden-state capture + early exit."""
|
||||
|
||||
def __init__(self, config: GptOss20BConfig, device=None, dtype=None, ops: Any = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.dtype = dtype
|
||||
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
GptOssDecoderLayer(config, i, device=device, dtype=dtype, ops=ops)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
# Always build on CPU so the buffer survives meta-device construction.
|
||||
inv_freq, attn_scaling = _yarn_inv_freq(
|
||||
head_dim=config.head_dim,
|
||||
base=config.rope_theta,
|
||||
factor=config.rope_factor,
|
||||
beta_fast=config.rope_beta_fast,
|
||||
beta_slow=config.rope_beta_slow,
|
||||
original_max_position_embeddings=config.original_max_position_embeddings,
|
||||
truncate=config.rope_truncate,
|
||||
device=torch.device("cpu"),
|
||||
)
|
||||
self.register_buffer("rope_inv_freq", inv_freq, persistent=False)
|
||||
self.rope_attention_scaling = float(attn_scaling)
|
||||
|
||||
@property
|
||||
def num_layers(self) -> int:
|
||||
return self.config.num_hidden_layers
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def _build_attention_masks(self, B: int, S: int, attention_mask: Optional[torch.Tensor], dtype: torch.dtype, device,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
full = _make_full_causal_mask(B, S, attention_mask, dtype, device)
|
||||
masks = {"full_attention": full}
|
||||
if any(t == "sliding_attention" for t in self.config.layer_types):
|
||||
masks["sliding_attention"] = _make_sliding_causal_mask(
|
||||
B, S, self.config.sliding_window, attention_mask, dtype, device
|
||||
)
|
||||
return masks
|
||||
|
||||
def forward(self, input_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None,
|
||||
capture_layers: Optional[Sequence[int]] = None) -> dict[str, Any]:
|
||||
B, S = input_ids.shape
|
||||
device = input_ids.device
|
||||
dtype = self.dtype
|
||||
|
||||
hidden_states = self.embed_tokens(input_ids, out_dtype=dtype)
|
||||
|
||||
position_ids = torch.arange(S, device=device).unsqueeze(0).expand(B, -1)
|
||||
freqs_cis = _build_freqs_cis(self.rope_inv_freq.to(device=device), self.rope_attention_scaling, position_ids, dtype)
|
||||
|
||||
attn_masks = self._build_attention_masks(B, S, attention_mask, dtype, device)
|
||||
|
||||
capture_layers = list(capture_layers) if capture_layers else None
|
||||
if capture_layers:
|
||||
max_layer = max(capture_layers)
|
||||
wanted = {idx: pos for pos, idx in enumerate(capture_layers)}
|
||||
captured: List[Optional[torch.Tensor]] = [None] * len(capture_layers)
|
||||
else:
|
||||
max_layer = self.config.num_hidden_layers - 1
|
||||
wanted = None
|
||||
captured = None
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
hidden_states = layer(hidden_states, attn_masks, freqs_cis)
|
||||
if wanted is not None and i in wanted:
|
||||
captured[wanted[i]] = hidden_states
|
||||
if i >= max_layer:
|
||||
break
|
||||
|
||||
if captured is not None:
|
||||
return {"hidden_states": captured}
|
||||
return {"last_hidden_state": self.norm(hidden_states)}
|
||||
|
||||
|
||||
# Lens chat-template constants (verbatim from the reference pipeline).
|
||||
_LENS_CHAT_SYSTEM = (
|
||||
"Describe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background."
|
||||
)
|
||||
_LENS_CHAT_ASSISTANT_THINKING = "Need to generate one image according to the description."
|
||||
LENS_TXT_OFFSET = 97
|
||||
LENS_SELECTED_LAYERS = (5, 11, 17, 23)
|
||||
LENS_MAX_TOKENS = 512
|
||||
|
||||
|
||||
# The reference GPT-OSS Harmony template injects today's date here
|
||||
_LENS_CHAT_DATE = "2026-05-23"
|
||||
|
||||
|
||||
def _lens_render_chat(prompt: str) -> str:
|
||||
"""Render the Lens prompt in GPT-OSS Harmony format."""
|
||||
return (
|
||||
f"<|start|>system<|message|>"
|
||||
f"You are ChatGPT, a large language model trained by OpenAI.\n"
|
||||
f"Knowledge cutoff: 2024-06\n"
|
||||
f"Current date: {_LENS_CHAT_DATE}\n\n"
|
||||
f"Reasoning: medium\n\n"
|
||||
f"# Valid channels: analysis, commentary, final. "
|
||||
f"Channel must be included for every message.<|end|>"
|
||||
f"<|start|>developer<|message|># Instructions\n\n"
|
||||
f"{_LENS_CHAT_SYSTEM}\n\n<|end|>"
|
||||
f"<|start|>user<|message|>{prompt}<|end|>"
|
||||
f"<|start|>assistant<|channel|>analysis<|message|>"
|
||||
f"{_LENS_CHAT_ASSISTANT_THINKING}<|end|>"
|
||||
f"<|start|>assistant<|channel|>final<|message|>"
|
||||
)
|
||||
|
||||
|
||||
# GPT-OSS-20B fixed token IDs (from the tokenizer's added-tokens table).
|
||||
_LENS_PAD_TOKEN_ID = 199999 # <|endoftext|>
|
||||
|
||||
|
||||
class _GptOssRawTokenizer:
|
||||
"""Raw ``tokenizers.Tokenizer`` wrapper.
|
||||
|
||||
The tokenizer JSON ships as a byte tensor inside the encoder checkpoint
|
||||
(``tokenizer_json`` key) rather than as a committed file. Extracted
|
||||
it in ``sd.py`` and passes it here via ``tokenizer_data``.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer_json_bytes=None, **kwargs):
|
||||
from tokenizers import Tokenizer
|
||||
if isinstance(tokenizer_json_bytes, torch.Tensor):
|
||||
tokenizer_json_bytes = bytes(tokenizer_json_bytes.tolist())
|
||||
if tokenizer_json_bytes is None:
|
||||
raise ValueError(
|
||||
"Lens tokenizer requires the ``tokenizer_json`` byte tensor in the "
|
||||
"encoder state dict. Re-bundle the encoder via bundle_te.py so it "
|
||||
"embeds the tokenizer."
|
||||
)
|
||||
self.tokenizer = Tokenizer.from_str(tokenizer_json_bytes.decode("utf-8"))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, tokenizer_data, **kwargs):
|
||||
return cls(tokenizer_json_bytes=tokenizer_data, **kwargs)
|
||||
|
||||
def __call__(self, text):
|
||||
return {"input_ids": self.tokenizer.encode(text, add_special_tokens=False).ids}
|
||||
|
||||
def get_vocab(self):
|
||||
return self.tokenizer.get_vocab()
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
return [self.tokenizer.token_to_id(t) for t in tokens]
|
||||
|
||||
def decode(self, ids, **kwargs):
|
||||
return self.tokenizer.decode(ids, skip_special_tokens=kwargs.get("skip_special_tokens", False))
|
||||
|
||||
|
||||
class LensGptOssTokenizer(sd1_clip.SDTokenizer):
|
||||
tokenizer_json_data = None
|
||||
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_json = tokenizer_data.get("tokenizer_json", None)
|
||||
self.tokenizer_json_data = tokenizer_json
|
||||
super().__init__(
|
||||
tokenizer_json,
|
||||
embedding_directory=embedding_directory,
|
||||
pad_with_end=False,
|
||||
embedding_size=2880,
|
||||
embedding_key="gpt_oss",
|
||||
tokenizer_class=_GptOssRawTokenizer,
|
||||
has_start_token=False,
|
||||
has_end_token=False,
|
||||
pad_to_max_length=False,
|
||||
max_length=99999999,
|
||||
min_length=1,
|
||||
pad_left=False,
|
||||
disable_weights=True,
|
||||
tokenizer_data=tokenizer_data,
|
||||
)
|
||||
self.pad_token_id = _LENS_PAD_TOKEN_ID
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
|
||||
# Empty prompt -> empty list; encode_token_weights returns zeros (uncond).
|
||||
if not text or not text.strip():
|
||||
return [[]]
|
||||
rendered = _lens_render_chat(text)
|
||||
ids = self.tokenizer(rendered)["input_ids"]
|
||||
if len(ids) > LENS_MAX_TOKENS:
|
||||
ids = ids[:LENS_MAX_TOKENS]
|
||||
return [[(int(t), 1.0) for t in ids]]
|
||||
|
||||
def state_dict(self):
|
||||
if self.tokenizer_json_data is not None:
|
||||
return {"tokenizer_json": self.tokenizer_json_data}
|
||||
return {}
|
||||
|
||||
|
||||
class LensTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(
|
||||
embedding_directory=embedding_directory,
|
||||
tokenizer_data=tokenizer_data,
|
||||
name="gpt_oss",
|
||||
tokenizer=LensGptOssTokenizer,
|
||||
)
|
||||
|
||||
|
||||
class LensGptOssClipModel(nn.Module):
|
||||
"""SDClipModel-shaped Lens GPT-OSS encoder (multi-layer feature extractor)."""
|
||||
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
|
||||
super().__init__()
|
||||
model_options = dict(model_options or {})
|
||||
|
||||
operations = model_options.get("custom_operations")
|
||||
if operations is None:
|
||||
quant_config = model_options.get("quantization_metadata") or {}
|
||||
operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
|
||||
self.operations = operations
|
||||
|
||||
cfg_overrides = model_options.get("gpt_oss_config", {})
|
||||
self.config = GptOss20BConfig(**cfg_overrides)
|
||||
self.selected_layers = tuple(model_options.get("selected_layers", LENS_SELECTED_LAYERS))
|
||||
self.txt_offset = int(model_options.get("txt_offset", LENS_TXT_OFFSET))
|
||||
|
||||
self.transformer = GptOssModel(self.config, device=device, dtype=dtype, ops=operations)
|
||||
self.num_layers = self.config.num_hidden_layers
|
||||
self.dtype = dtype
|
||||
self.execution_device = None
|
||||
self._pad_token_id = _LENS_PAD_TOKEN_ID
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.execution_device = options.get("execution_device", self.execution_device)
|
||||
|
||||
def reset_clip_options(self):
|
||||
self.execution_device = None
|
||||
|
||||
def _gather_tokens(self, token_weight_pairs):
|
||||
ids_list = [[int(t[0]) for t in batch] for batch in token_weight_pairs]
|
||||
pad_id = self._pad_token_id
|
||||
max_len = max(len(x) for x in ids_list)
|
||||
device = self.execution_device
|
||||
ids = torch.full((len(ids_list), max_len), pad_id, dtype=torch.long, device=device)
|
||||
mask = torch.zeros((len(ids_list), max_len), dtype=torch.long, device=device)
|
||||
for i, x in enumerate(ids_list):
|
||||
ids[i, : len(x)] = torch.tensor(x, dtype=torch.long, device=device)
|
||||
mask[i, : len(x)] = 1
|
||||
return ids, mask
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
# Empty negative: emit zero-length features + zero mask
|
||||
if all(len(batch) == 0 for batch in token_weight_pairs):
|
||||
device = self.execution_device
|
||||
B = len(token_weight_pairs)
|
||||
L = len(self.selected_layers)
|
||||
H = self.config.hidden_size
|
||||
flat = torch.zeros(B, 0, L * H, dtype=self.dtype, device=device)
|
||||
mask = torch.zeros(B, 0, dtype=torch.long, device=device)
|
||||
return flat, None, {"attention_mask": mask, "num_layers_stacked": L}
|
||||
|
||||
input_ids, attn_mask = self._gather_tokens(token_weight_pairs)
|
||||
out = self.transformer(input_ids, attention_mask=attn_mask, capture_layers=self.selected_layers)
|
||||
layers = out["hidden_states"] # list of L × [B, S, H]
|
||||
stacked = torch.stack(layers, dim=2) # [B, S, L, H]
|
||||
|
||||
offset = self.txt_offset
|
||||
if stacked.shape[1] > offset:
|
||||
stacked = stacked[:, offset:].contiguous()
|
||||
mask_trim = attn_mask[:, offset:]
|
||||
else:
|
||||
stacked = stacked[:, :0]
|
||||
mask_trim = attn_mask[:, :0]
|
||||
|
||||
B, S, L, H = stacked.shape
|
||||
flat = stacked.reshape(B, S, L * H)
|
||||
extra = {"attention_mask": mask_trim, "num_layers_stacked": L}
|
||||
return flat, None, extra
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.transformer.load_state_dict(sd, strict=False, assign=True)
|
||||
|
||||
|
||||
class LensTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||
super().__init__(device=device, dtype=dtype, name="gpt_oss", clip_model=LensGptOssClipModel, model_options=model_options or {})
|
||||
|
||||
|
||||
def lens_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class LensTEModel_(LensTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||
mo = dict(model_options or {})
|
||||
if llama_quantization_metadata is not None:
|
||||
mo["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype is None and dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=mo)
|
||||
|
||||
return LensTEModel_
|
||||
104
comfy/text_encoders/pixeldit.py
Normal file
104
comfy/text_encoders/pixeldit.py
Normal file
@ -0,0 +1,104 @@
|
||||
import torch
|
||||
|
||||
from comfy import sd1_clip
|
||||
from .lumina2 import Gemma2BTokenizer, LuminaModel
|
||||
import comfy.text_encoders.llama
|
||||
|
||||
|
||||
class PixelDiTGemma2_2BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
super().__init__(
|
||||
device=device, layer=layer, layer_idx=layer_idx,
|
||||
textmodel_json_config={}, dtype=dtype,
|
||||
special_tokens={"start": 2, "pad": 0},
|
||||
layer_norm_hidden_state=False,
|
||||
model_class=comfy.text_encoders.llama.Gemma2_2B,
|
||||
enable_attention_masks=attention_mask,
|
||||
return_attention_masks=attention_mask,
|
||||
model_options=model_options,
|
||||
)
|
||||
|
||||
|
||||
_PIXELDIT_CHI_PROMPT = (
|
||||
'Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions '
|
||||
"suitable for image generation. Evaluate the level of detail in the user prompt:\n"
|
||||
"- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, "
|
||||
"and spatial relationships to create vivid and concrete scenes.\n"
|
||||
"- If the prompt is already detailed, refine and enhance the existing details slightly without "
|
||||
"overcomplicating.\n"
|
||||
"Here are examples of how to transform or refine prompts:\n"
|
||||
"- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, "
|
||||
"sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.\n"
|
||||
"- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring "
|
||||
"glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus "
|
||||
"passing by towering glass skyscrapers.\n"
|
||||
"Please generate only the enhanced description for the prompt below and avoid including any "
|
||||
"additional commentary or evaluations:\n"
|
||||
"User Prompt: "
|
||||
)
|
||||
|
||||
_PIXELDIT_MAX_LENGTH = 300
|
||||
_PIXELDIT_CHI_PROMPT_DETECT_PREFIX = 'Given a user prompt, generate an "Enhanced prompt"'
|
||||
|
||||
|
||||
class PixelDiTGemma2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data=None):
|
||||
if tokenizer_data is None:
|
||||
tokenizer_data = {}
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
|
||||
name="gemma2_2b", tokenizer=Gemma2BTokenizer)
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
|
||||
if not text.strip():
|
||||
return super().tokenize_with_weights("", return_word_ids=return_word_ids, disable_weights=True, min_length=_PIXELDIT_MAX_LENGTH)
|
||||
|
||||
chi_token_count = len(self.gemma2_2b.tokenizer(_PIXELDIT_CHI_PROMPT)["input_ids"])
|
||||
combined = text if text.startswith(_PIXELDIT_CHI_PROMPT_DETECT_PREFIX) else _PIXELDIT_CHI_PROMPT + text
|
||||
max_length_all = chi_token_count + _PIXELDIT_MAX_LENGTH - 2
|
||||
out = super().tokenize_with_weights(combined, return_word_ids=return_word_ids,
|
||||
disable_weights=True, min_length=max_length_all)
|
||||
out["gemma2_2b"] = [out["gemma2_2b"][0][:max_length_all]]
|
||||
return out
|
||||
|
||||
def untokenize(self, token_weight_pair):
|
||||
return self.gemma2_2b.untokenize(token_weight_pair)
|
||||
|
||||
def state_dict(self):
|
||||
return self.gemma2_2b.state_dict()
|
||||
|
||||
|
||||
class PixelDiTGemma2TE(LuminaModel):
|
||||
# PixelDiT's select_index: keep BOS + last 299 embeddings of the padded sequence.
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="gemma2_2b",
|
||||
clip_model=PixelDiTGemma2_2BModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
result = super().encode_token_weights(token_weight_pairs)
|
||||
cond, pooled = result[0], result[1]
|
||||
extra = result[2] if len(result) > 2 else None
|
||||
if cond.shape[1] > _PIXELDIT_MAX_LENGTH:
|
||||
cond = torch.cat([cond[:, :1], cond[:, -(_PIXELDIT_MAX_LENGTH - 1):]], dim=1)
|
||||
if extra is not None and "attention_mask" in extra:
|
||||
am = extra["attention_mask"]
|
||||
extra["attention_mask"] = torch.cat([am[..., :1], am[..., -(_PIXELDIT_MAX_LENGTH - 1):]], dim=-1)
|
||||
if extra is not None:
|
||||
return cond, pooled, extra
|
||||
return cond, pooled
|
||||
|
||||
|
||||
def pixeldit_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class PixelDiTTE_(PixelDiTGemma2TE):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return PixelDiTTE_
|
||||
@ -762,10 +762,17 @@ class Accumulation(ComfyTypeIO):
|
||||
@comfytype(io_type="LOAD3D_CAMERA")
|
||||
class Load3DCamera(ComfyTypeIO):
|
||||
class CameraInfo(TypedDict):
|
||||
position: dict[str, float | int]
|
||||
target: dict[str, float | int]
|
||||
zoom: int
|
||||
cameraType: str
|
||||
# Coordinate system: right-handed, Y-up, camera looks down -Z
|
||||
position: dict[str, float | int] # scene units
|
||||
target: dict[str, float | int] # scene units; OrbitControls focus point
|
||||
zoom: float | int # dimensionless, 1 = 100%
|
||||
cameraType: str # 'perspective' | 'orthographic'
|
||||
quaternion: NotRequired[dict[str, float | int]] # normalized, dimensionless; camera world rotation
|
||||
fov: NotRequired[float | int] # degrees, vertical FOV (perspective only)
|
||||
aspect: NotRequired[float | int] # width / height (perspective only)
|
||||
near: NotRequired[float | int] # scene units
|
||||
far: NotRequired[float | int] # scene units
|
||||
frustum: NotRequired[dict[str, float | int]] # orthographic only: {left, right, top, bottom} in scene units
|
||||
|
||||
Type = CameraInfo
|
||||
|
||||
|
||||
32
comfy_api_nodes/apis/beeble.py
Normal file
32
comfy_api_nodes/apis/beeble.py
Normal file
@ -0,0 +1,32 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CreateSwitchXRequest(BaseModel):
|
||||
generation_type: str = Field(...)
|
||||
source_uri: str = Field(...)
|
||||
alpha_mode: str = Field(...)
|
||||
prompt: str | None = Field(None, max_length=2000)
|
||||
reference_image_uri: str | None = Field(None)
|
||||
alpha_uri: str | None = Field(None)
|
||||
max_resolution: int = Field(1080)
|
||||
callback_url: str | None = Field(None)
|
||||
idempotency_key: str | None = Field(None, max_length=256, min_length=1)
|
||||
|
||||
|
||||
class SwitchXOutputUrls(BaseModel):
|
||||
render: str | None = Field(None)
|
||||
source: str | None = Field(None)
|
||||
alpha: str | None = Field(None)
|
||||
|
||||
|
||||
class SwitchXStatusResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
progress: int | None = Field(None)
|
||||
generation_type: str | None = Field(None)
|
||||
alpha_mode: str | None = Field(None)
|
||||
output: SwitchXOutputUrls | None = Field(None)
|
||||
error: str | None = Field(None)
|
||||
created_at: str | None = Field(None)
|
||||
modified_at: str | None = Field(None)
|
||||
completed_at: str | None = Field(None)
|
||||
@ -158,8 +158,9 @@ class SeedanceCreateAssetResponse(BaseModel):
|
||||
|
||||
|
||||
class SeedanceVirtualLibraryCreateAssetRequest(BaseModel):
|
||||
url: str = Field(..., description="Publicly accessible URL of the image asset to upload.")
|
||||
url: str = Field(..., description="Publicly accessible URL of the asset to upload.")
|
||||
hash: str = Field(..., description="Dedup key. Re-submitting the same hash returns the existing asset id.")
|
||||
asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.")
|
||||
|
||||
|
||||
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
|
||||
|
||||
46
comfy_api_nodes/apis/krea.py
Normal file
46
comfy_api_nodes/apis/krea.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""Pydantic models for the Krea image-generation API."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class KreaMoodboard(BaseModel):
|
||||
id: str = Field(...)
|
||||
strength: float = Field(default=0.35, ge=-0.5, le=1.5)
|
||||
|
||||
|
||||
class KreaImageStyleReference(BaseModel):
|
||||
strength: float = Field(..., ge=-2.0, le=2.0)
|
||||
url: str | None = Field(default=None)
|
||||
|
||||
|
||||
class KreaGenerateImageRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
aspect_ratio: str = Field(...)
|
||||
resolution: str = Field(...)
|
||||
seed: int | None = Field(default=None)
|
||||
creativity: str = Field(default="medium")
|
||||
moodboards: list[KreaMoodboard] | None = Field(default=None)
|
||||
image_style_references: list[KreaImageStyleReference] | None = Field(default=None)
|
||||
|
||||
|
||||
class KreaJobResult(BaseModel):
|
||||
urls: list[str] | None = Field(default=None)
|
||||
style_id: str | None = Field(default=None)
|
||||
|
||||
|
||||
class KreaJob(BaseModel):
|
||||
job_id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
created_at: str = Field(...)
|
||||
completed_at: str | None = Field(default=None)
|
||||
result: KreaJobResult | None = Field(default=None)
|
||||
|
||||
|
||||
class KreaAssetResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
image_url: str = Field(...)
|
||||
uploaded_at: str = Field(...)
|
||||
width: float | None = Field(default=None)
|
||||
height: float | None = Field(default=None)
|
||||
size_bytes: float | None = Field(default=None)
|
||||
mime_type: str | None = Field(default=None)
|
||||
@ -155,7 +155,7 @@ class ClaudeNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ClaudeNode",
|
||||
display_name="Anthropic Claude",
|
||||
category="api node/text/Anthropic",
|
||||
category="text/partner/Anthropic",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses with Anthropic's Claude models. "
|
||||
"Provide a text prompt and optionally one or more images for multimodal context.",
|
||||
|
||||
404
comfy_api_nodes/nodes_beeble.py
Normal file
404
comfy_api_nodes/nodes_beeble.py
Normal file
@ -0,0 +1,404 @@
|
||||
from fractions import Fraction
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl, Types
|
||||
from comfy_api_nodes.apis.beeble import (
|
||||
CreateSwitchXRequest,
|
||||
SwitchXStatusResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
bytesio_to_image_tensor,
|
||||
convert_mask_to_image,
|
||||
download_url_as_bytesio,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
downscale_image_tensor,
|
||||
downscale_video_to_max_pixels,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_image_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
validate_video_frame_count,
|
||||
)
|
||||
|
||||
_MAX_PIXELS = 2_770_000
|
||||
_MAX_FRAMES = 240
|
||||
_MAX_PROMPT_LEN = 2000
|
||||
|
||||
|
||||
def _validate_inputs(prompt: str | None, reference_image: Input.Image | None) -> str | None:
|
||||
"""Beeble requires at least one of prompt or reference_image. Returns the cleaned prompt."""
|
||||
cleaned = prompt.strip() if prompt else ""
|
||||
if not cleaned and reference_image is None:
|
||||
raise ValueError("At least one of 'prompt' or 'reference_image' must be provided.")
|
||||
if cleaned:
|
||||
validate_string(cleaned, strip_whitespace=False, max_length=_MAX_PROMPT_LEN)
|
||||
return cleaned or None
|
||||
|
||||
|
||||
async def _upload_mask_as_image(
|
||||
cls: type[IO.ComfyNode],
|
||||
mask: Input.Image,
|
||||
*,
|
||||
wait_label: str,
|
||||
) -> str:
|
||||
"""Encode a single-frame MASK (H, W) or (1, H, W) as a PNG and upload."""
|
||||
if mask.dim() == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
image = convert_mask_to_image(mask[:1])
|
||||
return await upload_image_to_comfyapi(
|
||||
cls,
|
||||
image,
|
||||
mime_type="image/png",
|
||||
wait_label=wait_label,
|
||||
total_pixels=_MAX_PIXELS,
|
||||
)
|
||||
|
||||
|
||||
async def _upload_mask_batch_as_video(
|
||||
cls: type[IO.ComfyNode],
|
||||
mask: Input.Image,
|
||||
*,
|
||||
frame_rate: Fraction,
|
||||
source_frame_count: int,
|
||||
wait_label: str,
|
||||
) -> str:
|
||||
"""Encode a MASK batch (N, H, W) as a grayscale H.264 MP4 at frame_rate and upload.
|
||||
|
||||
The matte is always downscaled to the pixel budget so it stays within Beeble's limit and
|
||||
keeps the same dimensions as the (similarly downscaled) source — both use the same algorithm
|
||||
from the same starting dimensions, and downscaling is a no-op when already within budget.
|
||||
"""
|
||||
if mask.dim() == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[0] != source_frame_count:
|
||||
raise ValueError(
|
||||
f"Custom alpha video frame count ({mask.shape[0]}) does not match the "
|
||||
f"source video frame count ({source_frame_count}). The Beeble API requires "
|
||||
"one mask per source frame."
|
||||
)
|
||||
images = downscale_image_tensor(convert_mask_to_image(mask), _MAX_PIXELS)
|
||||
alpha_video = InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=None, frame_rate=frame_rate))
|
||||
return await upload_video_to_comfyapi(cls, alpha_video, wait_label=wait_label)
|
||||
|
||||
|
||||
def _alpha_mode_input(*, video: bool) -> IO.DynamicCombo.Input:
|
||||
"""Build the alpha_mode DynamicCombo with mode-specific extra inputs."""
|
||||
select_keyframe_tooltip = (
|
||||
"First-frame keyframe mask. Beeble propagates this across the video." if video else "Grayscale keyframe mask."
|
||||
)
|
||||
custom_tooltip = (
|
||||
"Per-frame grayscale mask covering the entire video. "
|
||||
"Must have the same frame count as the source. "
|
||||
"Connect a MASK output from SAM3_TrackToMask or similar."
|
||||
if video
|
||||
else "Grayscale mask to apply."
|
||||
)
|
||||
return IO.DynamicCombo.Input(
|
||||
"alpha_mode",
|
||||
tooltip=(
|
||||
"Controls how SwitchX decides what to keep vs. regenerate. "
|
||||
"'auto' isolates the main subject automatically. "
|
||||
"'fill' regenerates the entire frame while preserving geometry. "
|
||||
"'select' propagates a first-frame keyframe across the clip. "
|
||||
"'custom' uses a per-frame alpha matte you provide."
|
||||
),
|
||||
options=[
|
||||
IO.DynamicCombo.Option("auto", []),
|
||||
IO.DynamicCombo.Option("fill", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"select",
|
||||
[IO.Mask.Input("alpha_keyframe", tooltip=select_keyframe_tooltip)],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"custom",
|
||||
[IO.Mask.Input("alpha_mask", tooltip=custom_tooltip)],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _common_inputs(*, source: IO.Input, video: bool) -> list[IO.Input]:
|
||||
return [
|
||||
source,
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip=(
|
||||
"Text description of the desired output (max 2000 chars). "
|
||||
"At least one of 'prompt' or 'reference_image' is required."
|
||||
),
|
||||
),
|
||||
IO.Image.Input(
|
||||
"reference_image",
|
||||
optional=True,
|
||||
tooltip=(
|
||||
"Reference image whose look (background, lighting, costume) the result "
|
||||
"should adopt. At least one of 'reference_image' or 'prompt' is required."
|
||||
),
|
||||
),
|
||||
_alpha_mode_input(video=video),
|
||||
IO.Combo.Input(
|
||||
"max_resolution",
|
||||
options=["1080p", "720p"],
|
||||
default="1080p",
|
||||
tooltip="Maximum output resolution.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip=(
|
||||
"Seed controls whether the node should re-run; " "results are non-deterministic regardless of seed."
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def _submit_and_poll(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: CreateSwitchXRequest,
|
||||
) -> SwitchXStatusResponse:
|
||||
initial = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/beeble/v1/switchx/generations", method="POST"),
|
||||
response_model=SwitchXStatusResponse,
|
||||
data=request,
|
||||
)
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/beeble/v1/switchx/generations/{initial.id}"),
|
||||
response_model=SwitchXStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
|
||||
|
||||
def _require_output_url(response: SwitchXStatusResponse, name: str) -> str:
|
||||
if response.output is None or getattr(response.output, name) is None:
|
||||
raise RuntimeError(f"Beeble job {response.id} completed without a {name!r} output URL.")
|
||||
return getattr(response.output, name)
|
||||
|
||||
|
||||
def _alpha_url(response: SwitchXStatusResponse, mode: str) -> str | None:
|
||||
"""URL of the alpha matte, or None when the mode produces no separate matte.
|
||||
|
||||
'fill' selects the whole frame, so Beeble writes no alpha asset even though the status
|
||||
response still returns a (dangling) signed URL for it — fetching it 403s with S3
|
||||
AccessDenied. The other three modes ('auto', 'custom', 'select') all produce a real,
|
||||
downloadable matte.
|
||||
"""
|
||||
if mode == "fill" or response.output is None:
|
||||
return None
|
||||
return response.output.alpha
|
||||
|
||||
|
||||
class BeebleSwitchXVideoEdit(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="BeebleSwitchXVideoEdit",
|
||||
display_name="Beeble SwitchX Video Edit",
|
||||
category="video/partner/Beeble",
|
||||
description=(
|
||||
"Edit a video with Beeble SwitchX. Switches anything in the scene (background, "
|
||||
"lighting, costume) while preserving the original subject's pixels and motion. "
|
||||
"Provide a reference image and/or text prompt to describe the new look. "
|
||||
"Max 240 frames, max ~2.77MP per frame."
|
||||
),
|
||||
inputs=_common_inputs(source=IO.Video.Input("video"), video=True),
|
||||
outputs=[
|
||||
IO.Video.Output(display_name="video"),
|
||||
IO.Video.Output(
|
||||
display_name="alpha",
|
||||
tooltip="The alpha matte Beeble used. Empty for 'fill' mode, which has no separate matte.",
|
||||
),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["max_resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$rate := widgets.max_resolution = "1080p" ? 0.429 : 0.143;
|
||||
{"type":"usd","usd": $rate, "format":{"suffix":"/30 frames"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
prompt: str,
|
||||
alpha_mode: dict,
|
||||
max_resolution: str,
|
||||
seed: int,
|
||||
reference_image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
cleaned_prompt = _validate_inputs(prompt, reference_image)
|
||||
|
||||
validate_video_frame_count(video, max_frame_count=_MAX_FRAMES)
|
||||
video = downscale_video_to_max_pixels(video, _MAX_PIXELS)
|
||||
|
||||
mode = alpha_mode["alpha_mode"]
|
||||
alpha_uri: str | None = None
|
||||
if mode == "select":
|
||||
alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_keyframe"], wait_label="Uploading keyframe")
|
||||
elif mode == "custom":
|
||||
alpha_uri = await _upload_mask_batch_as_video(
|
||||
cls,
|
||||
alpha_mode["alpha_mask"],
|
||||
frame_rate=video.get_frame_rate(),
|
||||
source_frame_count=video.get_frame_count(),
|
||||
wait_label="Uploading alpha video",
|
||||
)
|
||||
|
||||
source_uri = await upload_video_to_comfyapi(cls, video, wait_label="Uploading source")
|
||||
reference_uri: str | None = None
|
||||
if reference_image is not None:
|
||||
reference_uri = await upload_image_to_comfyapi(
|
||||
cls,
|
||||
reference_image,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading reference",
|
||||
total_pixels=_MAX_PIXELS,
|
||||
)
|
||||
|
||||
request = CreateSwitchXRequest(
|
||||
generation_type="video",
|
||||
source_uri=source_uri,
|
||||
alpha_mode=mode,
|
||||
prompt=cleaned_prompt,
|
||||
reference_image_uri=reference_uri,
|
||||
alpha_uri=alpha_uri,
|
||||
max_resolution=1080 if max_resolution == "1080p" else 720,
|
||||
)
|
||||
response = await _submit_and_poll(cls, request)
|
||||
|
||||
render = await download_url_to_video_output(_require_output_url(response, "render"))
|
||||
alpha = None
|
||||
if (alpha_url := _alpha_url(response, mode)) is not None:
|
||||
alpha = await download_url_to_video_output(alpha_url)
|
||||
return IO.NodeOutput(render, alpha)
|
||||
|
||||
|
||||
class BeebleSwitchXImageEdit(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="BeebleSwitchXImageEdit",
|
||||
display_name="Beeble SwitchX Image Edit",
|
||||
category="image/partner/Beeble",
|
||||
description=(
|
||||
"Edit a single image with Beeble SwitchX. Switches anything in the scene "
|
||||
"(background, lighting, costume) while preserving the original subject's pixels. "
|
||||
"Provide a reference image and/or text prompt to describe the new look. "
|
||||
"Max ~2.77MP."
|
||||
),
|
||||
inputs=_common_inputs(source=IO.Image.Input("image"), video=False),
|
||||
outputs=[
|
||||
IO.Image.Output(display_name="image"),
|
||||
IO.Mask.Output(
|
||||
display_name="alpha",
|
||||
tooltip="The alpha matte Beeble used. Empty for 'fill' mode, which has no separate matte.",
|
||||
),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["max_resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$rate := widgets.max_resolution = "1080p" ? 0.429 : 0.143;
|
||||
{"type":"usd","usd": $rate}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
alpha_mode: dict,
|
||||
max_resolution: str,
|
||||
seed: int,
|
||||
reference_image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
cleaned_prompt = _validate_inputs(prompt, reference_image)
|
||||
|
||||
image = downscale_image_tensor(image, _MAX_PIXELS)
|
||||
|
||||
mode = alpha_mode["alpha_mode"]
|
||||
alpha_uri: str | None = None
|
||||
if mode == "select":
|
||||
alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_keyframe"], wait_label="Uploading keyframe")
|
||||
elif mode == "custom":
|
||||
alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_mask"], wait_label="Uploading alpha")
|
||||
|
||||
source_uri = await upload_image_to_comfyapi(
|
||||
cls,
|
||||
image,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading source",
|
||||
total_pixels=None,
|
||||
)
|
||||
reference_uri: str | None = None
|
||||
if reference_image is not None:
|
||||
reference_uri = await upload_image_to_comfyapi(
|
||||
cls,
|
||||
reference_image,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading reference",
|
||||
total_pixels=_MAX_PIXELS,
|
||||
)
|
||||
|
||||
request = CreateSwitchXRequest(
|
||||
generation_type="image",
|
||||
source_uri=source_uri,
|
||||
alpha_mode=mode,
|
||||
prompt=cleaned_prompt,
|
||||
reference_image_uri=reference_uri,
|
||||
alpha_uri=alpha_uri,
|
||||
max_resolution=1080 if max_resolution == "1080p" else 720,
|
||||
)
|
||||
response = await _submit_and_poll(cls, request)
|
||||
|
||||
render = await download_url_to_image_tensor(_require_output_url(response, "render"))
|
||||
alpha_mask = None
|
||||
if (alpha_url := _alpha_url(response, mode)) is not None:
|
||||
alpha_image = bytesio_to_image_tensor(await download_url_as_bytesio(alpha_url), mode="L")
|
||||
alpha_mask = alpha_image.squeeze(-1) if alpha_image.dim() == 4 else alpha_image
|
||||
return IO.NodeOutput(render, alpha_mask)
|
||||
|
||||
|
||||
class BeebleExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
BeebleSwitchXVideoEdit,
|
||||
BeebleSwitchXImageEdit,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> BeebleExtension:
|
||||
return BeebleExtension()
|
||||
@ -42,7 +42,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="FluxProUltraImageNode",
|
||||
display_name="Flux 1.1 [pro] Ultra Image",
|
||||
category="api node/image/BFL",
|
||||
category="image/partner/BFL",
|
||||
description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -160,7 +160,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id=cls.NODE_ID,
|
||||
display_name=cls.DISPLAY_NAME,
|
||||
category="api node/image/BFL",
|
||||
category="image/partner/BFL",
|
||||
description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -282,7 +282,7 @@ class FluxProExpandNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="FluxProExpandNode",
|
||||
display_name="Flux.1 Expand Image",
|
||||
category="api node/image/BFL",
|
||||
category="image/partner/BFL",
|
||||
description="Outpaints image based on prompt.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -419,7 +419,7 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="FluxProFillNode",
|
||||
display_name="Flux.1 Fill Image",
|
||||
category="api node/image/BFL",
|
||||
category="image/partner/BFL",
|
||||
description="Inpaints image based on mask and prompt.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -545,7 +545,7 @@ class Flux2ProImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id=cls.NODE_ID,
|
||||
display_name=cls.DISPLAY_NAME,
|
||||
category="api node/image/BFL",
|
||||
category="image/partner/BFL",
|
||||
description="Generates images synchronously based on prompt and resolution.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -716,7 +716,7 @@ class Flux2ImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Flux2ImageNode",
|
||||
display_name="Flux.2 Image",
|
||||
category="api node/image/BFL",
|
||||
category="image/partner/BFL",
|
||||
description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -31,7 +31,7 @@ class BriaImageEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="BriaImageEditNode",
|
||||
display_name="Bria FIBO Image Edit",
|
||||
category="api node/image/Bria",
|
||||
category="image/partner/Bria",
|
||||
description="Edit images using Bria latest model",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["FIBO"]),
|
||||
@ -169,7 +169,7 @@ class BriaRemoveImageBackground(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="BriaRemoveImageBackground",
|
||||
display_name="Bria Remove Image Background",
|
||||
category="api node/image/Bria",
|
||||
category="image/partner/Bria",
|
||||
description="Remove the background from an image using Bria RMBG 2.0.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -245,7 +245,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="BriaRemoveVideoBackground",
|
||||
display_name="Bria Remove Video Background",
|
||||
category="api node/video/Bria",
|
||||
category="video/partner/Bria",
|
||||
description="Remove the background from a video using Bria. ",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
|
||||
@ -2,11 +2,12 @@ import hashlib
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api_nodes.apis.bytedance import (
|
||||
RECOMMENDED_PRESETS,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||
@ -43,6 +44,7 @@ from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
downscale_image_tensor_by_max_side,
|
||||
downscale_video_to_max_pixels,
|
||||
get_number_of_images,
|
||||
image_tensor_pair_to_batch,
|
||||
@ -121,6 +123,14 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st
|
||||
)
|
||||
|
||||
|
||||
def _prepare_seedance_image(image: Input.Image) -> Input.Image:
|
||||
"""Auto-downscale a Seedance image input to the per-side limits, then validate it."""
|
||||
validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
image = downscale_image_tensor_by_max_side(image, max_side=6000)
|
||||
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
|
||||
return image
|
||||
|
||||
|
||||
async def _resolve_reference_assets(
|
||||
cls: type[IO.ComfyNode],
|
||||
asset_ids: list[str],
|
||||
@ -308,6 +318,26 @@ async def _seedance_virtual_library_upload_image_asset(
|
||||
return f"asset://{create_resp.asset_id}"
|
||||
|
||||
|
||||
async def _seedance_virtual_library_upload_video_asset(
|
||||
cls: type[IO.ComfyNode],
|
||||
video: Input.Video,
|
||||
*,
|
||||
wait_label: str = "Uploading video",
|
||||
) -> str:
|
||||
buf = BytesIO()
|
||||
video.save_to(buf, format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264)
|
||||
video_hash = hashlib.sha256(buf.getbuffer()).hexdigest()
|
||||
public_url = await upload_video_to_comfyapi(cls, video, wait_label=wait_label)
|
||||
create_resp = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/seedance/virtual-library/assets", method="POST"),
|
||||
response_model=SeedanceCreateAssetResponse,
|
||||
data=SeedanceVirtualLibraryCreateAssetRequest(url=public_url, hash=video_hash, asset_type="Video"),
|
||||
)
|
||||
await _wait_for_asset_active(cls, create_resp.asset_id, group_id="virtual-library")
|
||||
return f"asset://{create_resp.asset_id}"
|
||||
|
||||
|
||||
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
|
||||
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
|
||||
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
|
||||
@ -338,7 +368,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceImageNode",
|
||||
display_name="ByteDance Image",
|
||||
category="api node/image/ByteDance",
|
||||
category="image/partner/ByteDance",
|
||||
description="Generate images using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
|
||||
@ -462,7 +492,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedreamNode",
|
||||
display_name="ByteDance Seedream 4.5 & 5.0",
|
||||
category="api node/image/ByteDance",
|
||||
category="image/partner/ByteDance",
|
||||
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -724,7 +754,7 @@ class ByteDanceSeedreamNodeV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedreamNodeV2",
|
||||
display_name="ByteDance Seedream 4.5 & 5.0",
|
||||
category="api node/image/ByteDance",
|
||||
category="image/partner/ByteDance",
|
||||
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -890,7 +920,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceTextToVideoNode",
|
||||
display_name="ByteDance Text to Video",
|
||||
category="api node/video/ByteDance",
|
||||
category="video/partner/ByteDance",
|
||||
description="Generate video using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1018,7 +1048,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceImageToVideoNode",
|
||||
display_name="ByteDance Image to Video",
|
||||
category="api node/video/ByteDance",
|
||||
category="video/partner/ByteDance",
|
||||
description="Generate video using ByteDance models via api based on image and prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1155,7 +1185,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceFirstLastFrameNode",
|
||||
display_name="ByteDance First-Last-Frame to Video",
|
||||
category="api node/video/ByteDance",
|
||||
category="video/partner/ByteDance",
|
||||
description="Generate video using prompt and first and last frames.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1303,7 +1333,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceImageReferenceNode",
|
||||
display_name="ByteDance Reference Images to Video",
|
||||
category="api node/video/ByteDance",
|
||||
category="video/partner/ByteDance",
|
||||
description="Generate video using prompt and reference images.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1546,7 +1576,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2TextToVideoNode",
|
||||
display_name="ByteDance Seedance 2.0 Text to Video",
|
||||
category="api node/video/ByteDance",
|
||||
category="video/partner/ByteDance",
|
||||
description="Generate video using Seedance 2.0 models based on a text prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1647,7 +1677,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2FirstLastFrameNode",
|
||||
display_name="ByteDance Seedance 2.0 First-Last-Frame to Video",
|
||||
category="api node/video/ByteDance",
|
||||
category="video/partner/ByteDance",
|
||||
description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1760,6 +1790,11 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
if last_frame is not None and last_frame_asset_id:
|
||||
raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.")
|
||||
|
||||
if first_frame is not None:
|
||||
first_frame = _prepare_seedance_image(first_frame)
|
||||
if last_frame is not None:
|
||||
last_frame = _prepare_seedance_image(last_frame)
|
||||
|
||||
asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a]
|
||||
image_assets: dict[str, str] = {}
|
||||
if asset_ids_to_resolve:
|
||||
@ -1866,7 +1901,7 @@ def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"auto_downscale",
|
||||
default=False,
|
||||
default=True,
|
||||
optional=True,
|
||||
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
|
||||
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
|
||||
@ -1909,7 +1944,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2ReferenceNode",
|
||||
display_name="ByteDance Seedance 2.0 Reference to Video",
|
||||
category="api node/video/ByteDance",
|
||||
category="video/partner/ByteDance",
|
||||
description="Generate, edit, or extend video using Seedance 2.0 with reference images, "
|
||||
"videos, and audio. Supports multimodal reference, video editing, and video extension.",
|
||||
inputs=[
|
||||
@ -2034,6 +2069,9 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
f"(audios={len(reference_audios)}, audio assets={len(reference_audio_assets)}). Maximum is 3."
|
||||
)
|
||||
|
||||
for key in reference_images:
|
||||
reference_images[key] = _prepare_seedance_image(reference_images[key])
|
||||
|
||||
model_id = SEEDANCE_MODELS[model["model"]]
|
||||
has_video_input = total_videos > 0
|
||||
|
||||
@ -2106,7 +2144,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
content.append(
|
||||
TaskVideoContent(
|
||||
video_url=TaskVideoContentUrl(
|
||||
url=await upload_video_to_comfyapi(
|
||||
url=await _seedance_virtual_library_upload_video_asset(
|
||||
cls,
|
||||
reference_videos[key],
|
||||
wait_label=f"Uploading video {i}",
|
||||
@ -2203,7 +2241,7 @@ class ByteDanceCreateImageAsset(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceCreateImageAsset",
|
||||
display_name="ByteDance Create Image Asset",
|
||||
category="api node/image/ByteDance",
|
||||
category="image/partner/ByteDance",
|
||||
description=(
|
||||
"Create a Seedance 2.0 personal image asset. Uploads the input image and "
|
||||
"registers it in the given asset group. If group_id is empty, runs a real-person "
|
||||
@ -2270,7 +2308,7 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceCreateVideoAsset",
|
||||
display_name="ByteDance Create Video Asset",
|
||||
category="api node/video/ByteDance",
|
||||
category="video/partner/ByteDance",
|
||||
description=(
|
||||
"Create a Seedance 2.0 personal video asset. Uploads the input video and "
|
||||
"registers it in the given asset group. If group_id is empty, runs a real-person "
|
||||
|
||||
@ -144,7 +144,7 @@ class ByteDanceSeedNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedNode",
|
||||
display_name="ByteDance Seed",
|
||||
category="api node/text/ByteDance",
|
||||
category="text/partner/ByteDance",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses with ByteDance's Seed 2.0 models. "
|
||||
"Provide a text prompt and optionally one or more images or videos for multimodal context.",
|
||||
|
||||
@ -69,7 +69,7 @@ class ElevenLabsSpeechToText(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsSpeechToText",
|
||||
display_name="ElevenLabs Speech to Text",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="audio/partner/ElevenLabs",
|
||||
description="Transcribe audio to text. "
|
||||
"Supports automatic language detection, speaker diarization, and audio event tagging.",
|
||||
inputs=[
|
||||
@ -210,7 +210,7 @@ class ElevenLabsVoiceSelector(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsVoiceSelector",
|
||||
display_name="ElevenLabs Voice Selector",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="audio/partner/ElevenLabs",
|
||||
description="Select a predefined ElevenLabs voice for text-to-speech generation.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -239,7 +239,7 @@ class ElevenLabsTextToSpeech(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsTextToSpeech",
|
||||
display_name="ElevenLabs Text to Speech",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="audio/partner/ElevenLabs",
|
||||
description="Convert text to speech.",
|
||||
inputs=[
|
||||
IO.Custom(ELEVENLABS_VOICE).Input(
|
||||
@ -414,7 +414,7 @@ class ElevenLabsAudioIsolation(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsAudioIsolation",
|
||||
display_name="ElevenLabs Voice Isolation",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="audio/partner/ElevenLabs",
|
||||
description="Remove background noise from audio, isolating vocals or speech.",
|
||||
inputs=[
|
||||
IO.Audio.Input(
|
||||
@ -459,7 +459,7 @@ class ElevenLabsTextToSoundEffects(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsTextToSoundEffects",
|
||||
display_name="ElevenLabs Text to Sound Effects",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="audio/partner/ElevenLabs",
|
||||
description="Generate sound effects from text descriptions.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -555,7 +555,7 @@ class ElevenLabsInstantVoiceClone(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsInstantVoiceClone",
|
||||
display_name="ElevenLabs Instant Voice Clone",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="audio/partner/ElevenLabs",
|
||||
description="Create a cloned voice from audio samples. "
|
||||
"Provide 1-8 audio recordings of the voice to clone.",
|
||||
inputs=[
|
||||
@ -658,7 +658,7 @@ class ElevenLabsSpeechToSpeech(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsSpeechToSpeech",
|
||||
display_name="ElevenLabs Speech to Speech",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="audio/partner/ElevenLabs",
|
||||
description="Transform speech from one voice to another while preserving the original content and emotion.",
|
||||
inputs=[
|
||||
IO.Custom(ELEVENLABS_VOICE).Input(
|
||||
@ -793,7 +793,7 @@ class ElevenLabsTextToDialogue(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsTextToDialogue",
|
||||
display_name="ElevenLabs Text to Dialogue",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="audio/partner/ElevenLabs",
|
||||
description="Generate multi-speaker dialogue from text. Each dialogue entry has its own text and voice.",
|
||||
inputs=[
|
||||
IO.Float.Input(
|
||||
|
||||
@ -300,7 +300,7 @@ class GeminiNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNode",
|
||||
display_name="Google Gemini",
|
||||
category="api node/text/Gemini",
|
||||
category="text/partner/Gemini",
|
||||
description="Generate text responses with Google's Gemini AI model. "
|
||||
"You can provide multiple types of inputs (text, images, audio, video) "
|
||||
"as context for generating more relevant and meaningful responses.",
|
||||
@ -541,7 +541,7 @@ class GeminiInputFiles(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiInputFiles",
|
||||
display_name="Gemini Input Files",
|
||||
category="api node/text/Gemini",
|
||||
category="text/partner/Gemini",
|
||||
description="Loads and prepares input files to include as inputs for Gemini LLM nodes. "
|
||||
"The files will be read by the Gemini model when generating a response. "
|
||||
"The contents of the text file count toward the token limit. "
|
||||
@ -598,7 +598,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiImageNode",
|
||||
display_name="Nano Banana (Google Gemini Image)",
|
||||
category="api node/image/Gemini",
|
||||
category="image/partner/Gemini",
|
||||
description="Edit images synchronously via Google API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -731,7 +731,7 @@ class GeminiImage2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiImage2Node",
|
||||
display_name="Nano Banana Pro (Google Gemini Image)",
|
||||
category="api node/image/Gemini",
|
||||
category="image/partner/Gemini",
|
||||
description="Generate or edit images synchronously via Google Vertex API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -869,7 +869,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNanoBanana2",
|
||||
display_name="Nano Banana 2",
|
||||
category="api node/image/Gemini",
|
||||
category="image/partner/Gemini",
|
||||
description="Generate or edit images synchronously via Google Vertex API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -1085,7 +1085,7 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNanoBanana2V2",
|
||||
display_name="Nano Banana 2",
|
||||
category="api node/image/Gemini",
|
||||
category="image/partner/Gemini",
|
||||
description="Generate or edit images synchronously via Google Vertex API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -49,7 +49,7 @@ class GrokImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokImageNode",
|
||||
display_name="Grok Image",
|
||||
category="api node/image/Grok",
|
||||
category="image/partner/Grok",
|
||||
description="Generate images using Grok based on a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -224,7 +224,7 @@ class GrokImageEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokImageEditNode",
|
||||
display_name="Grok Image Edit",
|
||||
category="api node/image/Grok",
|
||||
category="image/partner/Grok",
|
||||
description="Modify an existing image based on a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -366,7 +366,7 @@ class GrokImageEditNodeV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokImageEditNodeV2",
|
||||
display_name="Grok Image Edit",
|
||||
category="api node/image/Grok",
|
||||
category="image/partner/Grok",
|
||||
description="Modify an existing image based on a text prompt",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -503,7 +503,7 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoNode",
|
||||
display_name="Grok Video",
|
||||
category="api node/video/Grok",
|
||||
category="video/partner/Grok",
|
||||
description="Generate video from a prompt or an image",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
|
||||
@ -615,7 +615,7 @@ class GrokVideoEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoEditNode",
|
||||
display_name="Grok Video Edit",
|
||||
category="api node/video/Grok",
|
||||
category="video/partner/Grok",
|
||||
description="Edit an existing video based on a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
|
||||
@ -693,7 +693,7 @@ class GrokVideoReferenceNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoReferenceNode",
|
||||
display_name="Grok Reference-to-Video",
|
||||
category="api node/video/Grok",
|
||||
category="video/partner/Grok",
|
||||
description="Generate video guided by reference images as style and content references.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -826,7 +826,7 @@ class GrokVideoExtendNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoExtendNode",
|
||||
display_name="Grok Video Extend",
|
||||
category="api node/video/Grok",
|
||||
category="video/partner/Grok",
|
||||
description="Extend an existing video with a seamless continuation based on a text prompt.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -71,7 +71,7 @@ class HitPawGeneralImageEnhance(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HitPawGeneralImageEnhance",
|
||||
display_name="HitPaw General Image Enhance",
|
||||
category="api node/image/HitPaw",
|
||||
category="image/partner/HitPaw",
|
||||
description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. "
|
||||
f"Maximum output: {MAX_MP_GENERATIVE} megapixels.",
|
||||
inputs=[
|
||||
@ -201,7 +201,7 @@ class HitPawVideoEnhance(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HitPawVideoEnhance",
|
||||
display_name="HitPaw Video Enhance",
|
||||
category="api node/video/HitPaw",
|
||||
category="video/partner/HitPaw",
|
||||
description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. "
|
||||
"Prices shown are per second of video.",
|
||||
inputs=[
|
||||
|
||||
@ -123,7 +123,7 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TencentTextToModelNode",
|
||||
display_name="Hunyuan3D: Text to Model",
|
||||
category="api node/3d/Tencent",
|
||||
category="3d/partner/Tencent",
|
||||
essentials_category="3D",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -242,7 +242,7 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TencentImageToModelNode",
|
||||
display_name="Hunyuan3D: Image(s) to Model",
|
||||
category="api node/3d/Tencent",
|
||||
category="3d/partner/Tencent",
|
||||
essentials_category="3D",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -415,7 +415,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TencentModelTo3DUVNode",
|
||||
display_name="Hunyuan3D: Model to UV",
|
||||
category="api node/3d/Tencent",
|
||||
category="3d/partner/Tencent",
|
||||
description="Perform UV unfolding on a 3D model to generate UV texture. "
|
||||
"Input model must have less than 30000 faces.",
|
||||
inputs=[
|
||||
@ -505,7 +505,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Tencent3DTextureEditNode",
|
||||
display_name="Hunyuan3D: 3D Texture Edit",
|
||||
category="api node/3d/Tencent",
|
||||
category="3d/partner/Tencent",
|
||||
description="After inputting the 3D model, perform 3D model texture redrawing.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
@ -594,7 +594,7 @@ class Tencent3DPartNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Tencent3DPartNode",
|
||||
display_name="Hunyuan3D: 3D Part",
|
||||
category="api node/3d/Tencent",
|
||||
category="3d/partner/Tencent",
|
||||
description="Automatically perform component identification and generation based on the model structure.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
@ -666,7 +666,7 @@ class TencentSmartTopologyNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TencentSmartTopologyNode",
|
||||
display_name="Hunyuan3D: Smart Topology",
|
||||
category="api node/3d/Tencent",
|
||||
category="3d/partner/Tencent",
|
||||
description="Perform smart retopology on a 3D model. "
|
||||
"Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.",
|
||||
inputs=[
|
||||
|
||||
@ -234,7 +234,7 @@ class IdeogramV1(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV1",
|
||||
display_name="Ideogram V1",
|
||||
category="api node/image/Ideogram",
|
||||
category="image/partner/Ideogram",
|
||||
description="Generates images using the Ideogram V1 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -360,7 +360,7 @@ class IdeogramV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV2",
|
||||
display_name="Ideogram V2",
|
||||
category="api node/image/Ideogram",
|
||||
category="image/partner/Ideogram",
|
||||
description="Generates images using the Ideogram V2 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -526,7 +526,7 @@ class IdeogramV3(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV3",
|
||||
display_name="Ideogram V3",
|
||||
category="api node/image/Ideogram",
|
||||
category="image/partner/Ideogram",
|
||||
description="Generates images using the Ideogram V3 model. "
|
||||
"Supports both regular image generation from text prompts and image editing with mask.",
|
||||
inputs=[
|
||||
|
||||
@ -642,7 +642,7 @@ class KlingCameraControls(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingCameraControls",
|
||||
display_name="Kling Camera Controls",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Allows specifying configuration options for Kling Camera Controls and motion control effects.",
|
||||
inputs=[
|
||||
IO.Combo.Input("camera_control_type", options=KlingCameraControlType),
|
||||
@ -762,7 +762,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingTextToVideoNode",
|
||||
display_name="Kling Text to Video",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Kling Text to Video Node",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
@ -849,7 +849,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProTextToVideoNode",
|
||||
display_name="Kling 3.0 Omni Text to Video",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Use text prompts to generate videos with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
@ -998,7 +998,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProFirstLastFrameNode",
|
||||
display_name="Kling 3.0 Omni First-Last-Frame to Video",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
@ -1205,7 +1205,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProImageToVideoNode",
|
||||
display_name="Kling 3.0 Omni Image to Video",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Use up to 7 reference images to generate a video with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
@ -1374,7 +1374,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProVideoToVideoNode",
|
||||
display_name="Kling 3.0 Omni Video to Video",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
@ -1485,7 +1485,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProEditVideoNode",
|
||||
display_name="Kling 3.0 Omni Edit Video",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
essentials_category="Video Generation",
|
||||
description="Edit an existing video with the latest model from Kling.",
|
||||
inputs=[
|
||||
@ -1593,7 +1593,7 @@ class OmniProImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProImageNode",
|
||||
display_name="Kling 3.0 Omni Image",
|
||||
category="api node/image/Kling",
|
||||
category="image/partner/Kling",
|
||||
description="Create or edit images with the latest model from Kling.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]),
|
||||
@ -1721,7 +1721,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingCameraControlT2VNode",
|
||||
display_name="Kling Text to Video (Camera Control)",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
@ -1783,7 +1783,7 @@ class KlingImage2VideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingImage2VideoNode",
|
||||
display_name="Kling Image(First Frame) to Video",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
inputs=[
|
||||
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
@ -1882,7 +1882,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingCameraControlI2VNode",
|
||||
display_name="Kling Image to Video (Camera Control)",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -1953,7 +1953,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingStartEndFrameNode",
|
||||
display_name="Kling Start-End Frame to Video",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -2047,7 +2047,7 @@ class KlingVideoExtendNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingVideoExtendNode",
|
||||
display_name="Kling Video Extend",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -2128,7 +2128,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingDualCharacterVideoEffectNode",
|
||||
display_name="Kling Dual Character Video Effects",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite.",
|
||||
inputs=[
|
||||
IO.Image.Input("image_left", tooltip="Left side image"),
|
||||
@ -2218,7 +2218,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingSingleImageVideoEffectNode",
|
||||
display_name="Kling Video Effects",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Achieve different special effects when generating a video based on the effect_scene.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -2291,7 +2291,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingLipSyncAudioToVideoNode",
|
||||
display_name="Kling Lip Sync Video with Audio",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
essentials_category="Video Generation",
|
||||
description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
|
||||
inputs=[
|
||||
@ -2343,7 +2343,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingLipSyncTextToVideoNode",
|
||||
display_name="Kling Lip Sync Video with Text",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
@ -2411,7 +2411,7 @@ class KlingVirtualTryOnNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingVirtualTryOnNode",
|
||||
display_name="Kling Virtual Try On",
|
||||
category="api node/image/Kling",
|
||||
category="image/partner/Kling",
|
||||
description="Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background.",
|
||||
inputs=[
|
||||
IO.Image.Input("human_image"),
|
||||
@ -2478,7 +2478,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingImageGenerationNode",
|
||||
display_name="Kling 3.0 Image",
|
||||
category="api node/image/Kling",
|
||||
category="image/partner/Kling",
|
||||
description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
@ -2615,7 +2615,7 @@ class TextToVideoWithAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingTextToVideoWithAudio",
|
||||
display_name="Kling 2.6 Text to Video with Audio",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v2-6"]),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
|
||||
@ -2683,7 +2683,7 @@ class ImageToVideoWithAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingImageToVideoWithAudio",
|
||||
display_name="Kling 2.6 Image(First Frame) to Video with Audio",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v2-6"]),
|
||||
IO.Image.Input("start_frame"),
|
||||
@ -2753,7 +2753,7 @@ class MotionControl(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingMotionControl",
|
||||
display_name="Kling Motion Control",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True),
|
||||
IO.Image.Input("reference_image"),
|
||||
@ -2854,7 +2854,7 @@ class KlingVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingVideoNode",
|
||||
display_name="Kling 3.0 Video",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Generate videos with Kling V3. "
|
||||
"Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.",
|
||||
inputs=[
|
||||
@ -3077,7 +3077,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingFirstLastFrameNode",
|
||||
display_name="Kling 3.0 First-Last-Frame to Video",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Generate videos with Kling V3 using first and last frames.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
@ -3202,7 +3202,7 @@ class KlingAvatarNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingAvatarNode",
|
||||
display_name="Kling Avatar 2.0",
|
||||
category="api node/video/Kling",
|
||||
category="video/partner/Kling",
|
||||
description="Generate broadcast-style digital human videos from a single photo and an audio file.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
|
||||
290
comfy_api_nodes/nodes_krea.py
Normal file
290
comfy_api_nodes/nodes_krea.py
Normal file
@ -0,0 +1,290 @@
|
||||
"""Krea image-generation nodes."""
|
||||
|
||||
import re
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.krea import (
|
||||
KreaAssetResponse,
|
||||
KreaGenerateImageRequest,
|
||||
KreaImageStyleReference,
|
||||
KreaJob,
|
||||
KreaMoodboard,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_bytesio,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
|
||||
class KreaIO:
|
||||
STYLE_REF = "KREA_STYLE_REF"
|
||||
|
||||
|
||||
async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Image) -> str:
|
||||
"""Upload an image to Krea's /assets endpoint and return the Krea-hosted image URL."""
|
||||
img_io = tensor_to_bytesio(image, total_pixels=2048 * 2048, mime_type="image/png")
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/krea/assets", method="POST"),
|
||||
response_model=KreaAssetResponse,
|
||||
files=[("file", (img_io.name, img_io, "image/png"))],
|
||||
content_type="multipart/form-data",
|
||||
max_retries=1,
|
||||
wait_label="Uploading reference",
|
||||
)
|
||||
return response.image_url
|
||||
|
||||
|
||||
_MODEL_MEDIUM = "Krea 2 Medium"
|
||||
_MODEL_LARGE = "Krea 2 Large"
|
||||
_MODEL_ENDPOINTS: dict[str, str] = {
|
||||
_MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium",
|
||||
_MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large",
|
||||
}
|
||||
|
||||
_ASPECT_RATIOS = ["1:1", "4:3", "3:2", "16:9", "2.35:1", "4:5", "2:3", "9:16"]
|
||||
_RESOLUTIONS = ["1K"]
|
||||
_CREATIVITY_LEVELS = ["raw", "low", "medium", "high"]
|
||||
_KREA_QUEUED_STATUSES = ["backlogged", "queued", "scheduled"]
|
||||
|
||||
_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
|
||||
|
||||
|
||||
def _krea_model_inputs() -> list:
|
||||
"""Nested inputs shared by both Krea 2 Medium and Large under the DynamicCombo."""
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=_ASPECT_RATIOS,
|
||||
tooltip="Output aspect ratio.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=_RESOLUTIONS,
|
||||
tooltip="Resolution scale.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"creativity",
|
||||
options=_CREATIVITY_LEVELS,
|
||||
default="medium",
|
||||
tooltip="Prompt interpretation strength: raw stays closest to the prompt; high is most creative.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"moodboard_id",
|
||||
default="",
|
||||
tooltip="Optional Krea moodboard UUID (e.g. from the Krea website). "
|
||||
"Leave empty to disable. Only one moodboard is supported per request.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"moodboard_strength",
|
||||
default=0.35,
|
||||
min=-0.5,
|
||||
max=1.5,
|
||||
step=0.05,
|
||||
tooltip="Moodboard influence; ignored when moodboard_id is empty.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Custom(KreaIO.STYLE_REF).Input(
|
||||
"style_reference",
|
||||
optional=True,
|
||||
tooltip="Optional chain of style references (max 10) from Krea 2 Style Reference nodes.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class Krea2ImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Krea2ImageNode",
|
||||
display_name="Krea 2 Image",
|
||||
category="image/partner/Krea",
|
||||
description=(
|
||||
"Generate images via Krea 2 — pick Medium (expressive illustrations) or "
|
||||
"Large (expressive photorealism). Supports an optional moodboard and up "
|
||||
"to 10 chained image style references."
|
||||
),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the image.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()),
|
||||
IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()),
|
||||
],
|
||||
tooltip="Krea 2 Medium is best for expressive illustrations; "
|
||||
"Krea 2 Large is best for expressive photorealism.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Random seed for reproducibility.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model", "model.moodboard_id"],
|
||||
inputs=["model.style_reference"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isLarge := widgets.model = "krea 2 large";
|
||||
$hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0;
|
||||
$hasStyle := $lookup(inputs, "model.style_reference").connected;
|
||||
$usd := $hasMoodboard
|
||||
? ($isLarge ? 0.07 : 0.04)
|
||||
: ($hasStyle
|
||||
? ($isLarge ? 0.065 : 0.035)
|
||||
: ($isLarge ? 0.06 : 0.03));
|
||||
{"type":"usd","usd": $usd}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1)
|
||||
|
||||
model_choice = model["model"]
|
||||
endpoint_path = _MODEL_ENDPOINTS.get(model_choice)
|
||||
if endpoint_path is None:
|
||||
raise ValueError(f"Unknown Krea 2 model: {model_choice!r}")
|
||||
|
||||
moodboards: list[KreaMoodboard] | None = None
|
||||
mb_id = (model.get("moodboard_id") or "").strip()
|
||||
if mb_id:
|
||||
if not _UUID_RE.match(mb_id):
|
||||
raise ValueError(f"moodboard_id must be a UUID (received {mb_id!r}); copy it from the Krea website.")
|
||||
mb_strength = model.get("moodboard_strength")
|
||||
moodboards = [KreaMoodboard(id=mb_id, strength=0.35 if mb_strength is None else float(mb_strength))]
|
||||
|
||||
style_reference = model.get("style_reference")
|
||||
image_style_references: list[KreaImageStyleReference] | None = None
|
||||
if style_reference:
|
||||
if len(style_reference) > 10:
|
||||
raise ValueError(f"Krea 2 accepts at most 10 image_style_references; received {len(style_reference)}.")
|
||||
image_style_references = [
|
||||
KreaImageStyleReference(url=ref["url"], strength=float(ref["strength"])) for ref in style_reference
|
||||
]
|
||||
initial = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=endpoint_path, method="POST"),
|
||||
response_model=KreaJob,
|
||||
data=KreaGenerateImageRequest(
|
||||
prompt=prompt,
|
||||
aspect_ratio=model["aspect_ratio"],
|
||||
resolution=model["resolution"],
|
||||
seed=seed,
|
||||
creativity=model["creativity"],
|
||||
moodboards=moodboards,
|
||||
image_style_references=image_style_references,
|
||||
),
|
||||
)
|
||||
job = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/krea/jobs/{initial.job_id}", method="GET"),
|
||||
response_model=KreaJob,
|
||||
status_extractor=lambda r: r.status,
|
||||
queued_statuses=_KREA_QUEUED_STATUSES,
|
||||
)
|
||||
if not job.result or not job.result.urls:
|
||||
raise RuntimeError(f"Krea 2 job {job.job_id} completed without any image URLs.")
|
||||
image = await download_url_to_image_tensor(job.result.urls[0])
|
||||
return IO.NodeOutput(image)
|
||||
|
||||
|
||||
class Krea2StyleReferenceNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Krea2StyleReferenceNode",
|
||||
display_name="Krea 2 Style Reference",
|
||||
category="image/partner/Krea",
|
||||
description=(
|
||||
"Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 "
|
||||
"Style Reference nodes (max 10) and feed the final `style_reference` output "
|
||||
"into Krea 2 Image. Each image is uploaded to ComfyAPI storage and passed as URL."
|
||||
),
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Reference image whose style influences the generation.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"strength",
|
||||
default=1.0,
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
step=0.05,
|
||||
tooltip="Reference strength; negative values invert the style influence.",
|
||||
),
|
||||
IO.Custom(KreaIO.STYLE_REF).Input(
|
||||
"style_reference",
|
||||
optional=True,
|
||||
tooltip="Optional incoming chain of style references; this node appends one more.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Custom(KreaIO.STYLE_REF).Output(display_name="style_reference")],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
strength: float,
|
||||
style_reference: list[dict] | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
chain: list[dict] = list(style_reference) if style_reference else []
|
||||
if len(chain) >= 10:
|
||||
raise ValueError("Krea 2 accepts at most 10 image_style_references in one generation.")
|
||||
url = await _upload_image_to_krea_assets(cls, image)
|
||||
chain.append({"url": url, "strength": float(strength)})
|
||||
return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class KreaExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
Krea2ImageNode,
|
||||
Krea2StyleReferenceNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> KreaExtension:
|
||||
return KreaExtension()
|
||||
@ -50,7 +50,7 @@ class TextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LtxvApiTextToVideo",
|
||||
display_name="LTXV Text To Video",
|
||||
category="api node/video/LTXV",
|
||||
category="video/partner/LTXV",
|
||||
description="Professional-quality videos with customizable duration and resolution.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
|
||||
@ -127,7 +127,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LtxvApiImageToVideo",
|
||||
display_name="LTXV Image To Video",
|
||||
category="api node/video/LTXV",
|
||||
category="video/partner/LTXV",
|
||||
description="Professional-quality videos with customizable duration and resolution based on start image.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="First frame to be used for the video."),
|
||||
|
||||
@ -46,7 +46,7 @@ class LumaReferenceNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaReferenceNode",
|
||||
display_name="Luma Reference",
|
||||
category="api node/image/Luma",
|
||||
category="image/partner/Luma",
|
||||
description="Holds an image and weight for use with Luma Generate Image node.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -85,7 +85,7 @@ class LumaConceptsNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaConceptsNode",
|
||||
display_name="Luma Concepts",
|
||||
category="api node/video/Luma",
|
||||
category="video/partner/Luma",
|
||||
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -134,7 +134,7 @@ class LumaImageGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageNode",
|
||||
display_name="Luma Text to Image",
|
||||
category="api node/image/Luma",
|
||||
category="image/partner/Luma",
|
||||
description="Generates images synchronously based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -278,7 +278,7 @@ class LumaImageModifyNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageModifyNode",
|
||||
display_name="Luma Image to Image",
|
||||
category="api node/image/Luma",
|
||||
category="image/partner/Luma",
|
||||
description="Modifies images synchronously based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -371,7 +371,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaVideoNode",
|
||||
display_name="Luma Text to Video",
|
||||
category="api node/video/Luma",
|
||||
category="video/partner/Luma",
|
||||
description="Generates videos synchronously based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -472,7 +472,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageToVideoNode",
|
||||
display_name="Luma Image to Video",
|
||||
category="api node/video/Luma",
|
||||
category="video/partner/Luma",
|
||||
description="Generates videos synchronously based on prompt, input images, and output_size.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -724,7 +724,7 @@ class LumaImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageNode2",
|
||||
display_name="Luma UNI-1 Image",
|
||||
category="api node/image/Luma",
|
||||
category="image/partner/Luma",
|
||||
description="Generate images from text using the Luma UNI-1 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -853,7 +853,7 @@ class LumaImageEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageEditNode2",
|
||||
display_name="Luma UNI-1 Image Edit",
|
||||
category="api node/image/Luma",
|
||||
category="image/partner/Luma",
|
||||
description="Edit an existing image with a text prompt using the Luma UNI-1 model.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
|
||||
@ -61,7 +61,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageUpscalerCreativeNode",
|
||||
display_name="Magnific Image Upscale (Creative)",
|
||||
category="api node/image/Magnific",
|
||||
category="image/partner/Magnific",
|
||||
description="Prompt‑guided enhancement, stylization, and 2x/4x/8x/16x upscaling. "
|
||||
"Maximum output: 25.3 megapixels.",
|
||||
inputs=[
|
||||
@ -240,7 +240,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageUpscalerPreciseV2Node",
|
||||
display_name="Magnific Image Upscale (Precise V2)",
|
||||
category="api node/image/Magnific",
|
||||
category="image/partner/Magnific",
|
||||
description="High-fidelity upscaling with fine control over sharpness, grain, and detail. "
|
||||
"Maximum output: 10060×10060 pixels.",
|
||||
inputs=[
|
||||
@ -400,7 +400,7 @@ class MagnificImageStyleTransferNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageStyleTransferNode",
|
||||
display_name="Magnific Image Style Transfer",
|
||||
category="api node/image/Magnific",
|
||||
category="image/partner/Magnific",
|
||||
description="Transfer the style from a reference image to your input image.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to apply style transfer to."),
|
||||
@ -549,7 +549,7 @@ class MagnificImageRelightNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageRelightNode",
|
||||
display_name="Magnific Image Relight",
|
||||
category="api node/image/Magnific",
|
||||
category="image/partner/Magnific",
|
||||
description="Relight an image with lighting adjustments and optional reference-based light transfer.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to relight."),
|
||||
@ -789,7 +789,7 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageSkinEnhancerNode",
|
||||
display_name="Magnific Image Skin Enhancer",
|
||||
category="api node/image/Magnific",
|
||||
category="image/partner/Magnific",
|
||||
description="Skin enhancement for portraits with multiple processing modes.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The portrait image to enhance."),
|
||||
|
||||
@ -33,7 +33,7 @@ class MeshyTextToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyTextToModelNode",
|
||||
display_name="Meshy: Text to Model",
|
||||
category="api node/3d/Meshy",
|
||||
category="3d/partner/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
@ -145,7 +145,7 @@ class MeshyRefineNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyRefineNode",
|
||||
display_name="Meshy: Refine Draft Model",
|
||||
category="api node/3d/Meshy",
|
||||
category="3d/partner/Meshy",
|
||||
description="Refine a previously created draft model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
@ -240,7 +240,7 @@ class MeshyImageToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyImageToModelNode",
|
||||
display_name="Meshy: Image to Model",
|
||||
category="api node/3d/Meshy",
|
||||
category="3d/partner/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.Image.Input("image"),
|
||||
@ -405,7 +405,7 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyMultiImageToModelNode",
|
||||
display_name="Meshy: Multi-Image to Model",
|
||||
category="api node/3d/Meshy",
|
||||
category="3d/partner/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.Autogrow.Input(
|
||||
@ -575,7 +575,7 @@ class MeshyRigModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyRigModelNode",
|
||||
display_name="Meshy: Rig Model",
|
||||
category="api node/3d/Meshy",
|
||||
category="3d/partner/Meshy",
|
||||
description="Provides a rigged character in standard formats. "
|
||||
"Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, "
|
||||
"or humanoid assets with unclear limb and body structure.",
|
||||
@ -656,7 +656,7 @@ class MeshyAnimateModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyAnimateModelNode",
|
||||
display_name="Meshy: Animate Model",
|
||||
category="api node/3d/Meshy",
|
||||
category="3d/partner/Meshy",
|
||||
description="Apply a specific animation action to a previously rigged character.",
|
||||
inputs=[
|
||||
IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"),
|
||||
@ -722,7 +722,7 @@ class MeshyTextureNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyTextureNode",
|
||||
display_name="Meshy: Texture Model",
|
||||
category="api node/3d/Meshy",
|
||||
category="3d/partner/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"),
|
||||
|
||||
@ -101,7 +101,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MinimaxTextToVideoNode",
|
||||
display_name="MiniMax Text to Video",
|
||||
category="api node/video/MiniMax",
|
||||
category="video/partner/MiniMax",
|
||||
description="Generates videos synchronously based on a prompt, and optional parameters.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -163,7 +163,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MinimaxImageToVideoNode",
|
||||
display_name="MiniMax Image to Video",
|
||||
category="api node/video/MiniMax",
|
||||
category="video/partner/MiniMax",
|
||||
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -230,7 +230,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MinimaxSubjectToVideoNode",
|
||||
display_name="MiniMax Subject to Video",
|
||||
category="api node/video/MiniMax",
|
||||
category="video/partner/MiniMax",
|
||||
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -294,7 +294,7 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MinimaxHailuoVideoNode",
|
||||
display_name="MiniMax Hailuo Video",
|
||||
category="api node/video/MiniMax",
|
||||
category="video/partner/MiniMax",
|
||||
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -99,7 +99,7 @@ class OpenAIDalle2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIDalle2",
|
||||
display_name="OpenAI DALL·E 2",
|
||||
category="api node/image/OpenAI",
|
||||
category="image/partner/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -249,7 +249,7 @@ class OpenAIDalle3(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIDalle3",
|
||||
display_name="OpenAI DALL·E 3",
|
||||
category="api node/image/OpenAI",
|
||||
category="image/partner/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -371,7 +371,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIGPTImage1",
|
||||
display_name="OpenAI GPT Image 2",
|
||||
category="api node/image/OpenAI",
|
||||
category="image/partner/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
|
||||
is_deprecated=True,
|
||||
inputs=[
|
||||
@ -695,7 +695,7 @@ class OpenAIGPTImageNodeV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIGPTImageNodeV2",
|
||||
display_name="OpenAI GPT Image 2",
|
||||
category="api node/image/OpenAI",
|
||||
category="image/partner/OpenAI",
|
||||
description="Generates images via OpenAI's GPT Image endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -962,7 +962,7 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIChatNode",
|
||||
display_name="OpenAI ChatGPT",
|
||||
category="api node/text/OpenAI",
|
||||
category="text/partner/OpenAI",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses from an OpenAI model.",
|
||||
inputs=[
|
||||
@ -1201,7 +1201,7 @@ class OpenAIInputFiles(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIInputFiles",
|
||||
display_name="OpenAI ChatGPT Input Files",
|
||||
category="api node/text/OpenAI",
|
||||
category="text/partner/OpenAI",
|
||||
description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1248,7 +1248,7 @@ class OpenAIChatConfig(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIChatConfig",
|
||||
display_name="OpenAI ChatGPT Advanced Options",
|
||||
category="api node/text/OpenAI",
|
||||
category="text/partner/OpenAI",
|
||||
description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
|
||||
@ -265,7 +265,7 @@ class OpenRouterLLMNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenRouterLLMNode",
|
||||
display_name="OpenRouter LLM",
|
||||
category="api node/text/OpenRouter",
|
||||
category="text/partner/OpenRouter",
|
||||
essentials_category="Text Generation",
|
||||
description=(
|
||||
"Generate text responses through OpenRouter. Routes to a curated set of popular "
|
||||
|
||||
@ -53,7 +53,7 @@ class PixverseTemplateNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="PixverseTemplateNode",
|
||||
display_name="PixVerse Template",
|
||||
category="api node/video/PixVerse",
|
||||
category="video/partner/PixVerse",
|
||||
inputs=[
|
||||
IO.Combo.Input("template", options=list(pixverse_templates.keys())),
|
||||
],
|
||||
@ -74,7 +74,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="PixverseTextToVideoNode",
|
||||
display_name="PixVerse Text to Video",
|
||||
category="api node/video/PixVerse",
|
||||
category="video/partner/PixVerse",
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -192,7 +192,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="PixverseImageToVideoNode",
|
||||
display_name="PixVerse Image to Video",
|
||||
category="api node/video/PixVerse",
|
||||
category="video/partner/PixVerse",
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -310,7 +310,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="PixverseTransitionVideoNode",
|
||||
display_name="PixVerse Transition Video",
|
||||
category="api node/video/PixVerse",
|
||||
category="video/partner/PixVerse",
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.Image.Input("first_frame"),
|
||||
|
||||
@ -62,7 +62,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="QuiverTextToSVGNode",
|
||||
display_name="Quiver Text to SVG",
|
||||
category="api node/image/Quiver",
|
||||
category="image/partner/Quiver",
|
||||
description="Generate an SVG from a text prompt using Quiver AI.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -177,7 +177,7 @@ class QuiverImageToSVGNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="QuiverImageToSVGNode",
|
||||
display_name="Quiver Image to SVG",
|
||||
category="api node/image/Quiver",
|
||||
category="image/partner/Quiver",
|
||||
description="Vectorize a raster image into SVG using Quiver AI.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
|
||||
@ -178,7 +178,7 @@ class RecraftColorRGBNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftColorRGB",
|
||||
display_name="Recraft Color RGB",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Create Recraft Color by choosing specific RGB values.",
|
||||
inputs=[
|
||||
IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."),
|
||||
@ -204,7 +204,7 @@ class RecraftControlsNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftControls",
|
||||
display_name="Recraft Controls",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Create Recraft Controls for customizing Recraft generation.",
|
||||
inputs=[
|
||||
IO.Custom(RecraftIO.COLOR).Input("colors", optional=True),
|
||||
@ -228,7 +228,7 @@ class RecraftStyleV3RealisticImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3RealisticImage",
|
||||
display_name="Recraft Style - Realistic Image",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Select realistic_image style and optional substyle.",
|
||||
inputs=[
|
||||
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
|
||||
@ -253,7 +253,7 @@ class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3DigitalIllustration",
|
||||
display_name="Recraft Style - Digital Illustration",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Select realistic_image style and optional substyle.",
|
||||
inputs=[
|
||||
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
|
||||
@ -272,7 +272,7 @@ class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3VectorIllustrationNode",
|
||||
display_name="Recraft Style - Realistic Image",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Select realistic_image style and optional substyle.",
|
||||
inputs=[
|
||||
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
|
||||
@ -291,7 +291,7 @@ class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3LogoRaster",
|
||||
display_name="Recraft Style - Logo Raster",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Select realistic_image style and optional substyle.",
|
||||
inputs=[
|
||||
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)),
|
||||
@ -308,7 +308,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3InfiniteStyleLibrary",
|
||||
display_name="Recraft Style - Infinite Style Library",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
|
||||
inputs=[
|
||||
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
|
||||
@ -331,7 +331,7 @@ class RecraftCreateStyleNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftCreateStyleNode",
|
||||
display_name="Recraft Create Style",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Create a custom style from reference images. "
|
||||
"Upload 1-5 images to use as style references. "
|
||||
"Total size of all images is limited to 5 MB.",
|
||||
@ -400,7 +400,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftTextToImageNode",
|
||||
display_name="Recraft Text to Image",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Generates images synchronously based on prompt and resolution.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."),
|
||||
@ -512,7 +512,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftImageToImageNode",
|
||||
display_name="Recraft Image to Image",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Modify image based on prompt and strength.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -630,7 +630,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftImageInpaintingNode",
|
||||
display_name="Recraft Image Inpainting",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Modify image based on prompt and mask.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -732,7 +732,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftTextToVectorNode",
|
||||
display_name="Recraft Text to Vector",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Generates SVG synchronously based on prompt and resolution.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True),
|
||||
@ -832,7 +832,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftVectorizeImageNode",
|
||||
display_name="Recraft Vectorize Image",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
essentials_category="Image Tools",
|
||||
description="Generates SVG synchronously from an input image.",
|
||||
inputs=[
|
||||
@ -876,7 +876,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftReplaceBackgroundNode",
|
||||
display_name="Recraft Replace Background",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Replace background on image, based on provided prompt.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -963,7 +963,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftRemoveBackgroundNode",
|
||||
display_name="Recraft Remove Background",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
essentials_category="Image Tools",
|
||||
description="Remove background from image, and return processed image and mask.",
|
||||
inputs=[
|
||||
@ -1012,7 +1012,7 @@ class RecraftCrispUpscaleNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftCrispUpscaleNode",
|
||||
display_name="Recraft Crisp Upscale Image",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Upscale image synchronously.\n"
|
||||
"Enhances a given raster image using ‘crisp upscale’ tool, "
|
||||
"increasing image resolution, making the image sharper and cleaner.",
|
||||
@ -1058,7 +1058,7 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftCreativeUpscaleNode",
|
||||
display_name="Recraft Creative Upscale Image",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Upscale image synchronously.\n"
|
||||
"Enhances a given raster image using ‘creative upscale’ tool, "
|
||||
"boosting resolution with a focus on refining small details and faces.",
|
||||
@ -1086,7 +1086,7 @@ class RecraftV4TextToImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftV4TextToImageNode",
|
||||
display_name="Recraft V4 Text to Image",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Generates images using Recraft V4 or V4 Pro models.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -1210,7 +1210,7 @@ class RecraftV4TextToVectorNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftV4TextToVectorNode",
|
||||
display_name="Recraft V4 Text to Vector",
|
||||
category="api node/image/Recraft",
|
||||
category="image/partner/Recraft",
|
||||
description="Generates SVG using Recraft V4 or V4 Pro models.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -109,7 +109,7 @@ class ReveImageCreateNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageCreateNode",
|
||||
display_name="Reve Image Create",
|
||||
category="api node/image/Reve",
|
||||
category="image/partner/Reve",
|
||||
description="Generate images from text descriptions using Reve.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -200,7 +200,7 @@ class ReveImageEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageEditNode",
|
||||
display_name="Reve Image Edit",
|
||||
category="api node/image/Reve",
|
||||
category="image/partner/Reve",
|
||||
description="Edit images using natural language instructions with Reve.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to edit."),
|
||||
@ -300,7 +300,7 @@ class ReveImageRemixNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageRemixNode",
|
||||
display_name="Reve Image Remix",
|
||||
category="api node/image/Reve",
|
||||
category="image/partner/Reve",
|
||||
description="Combine reference images with text prompts to create new images using Reve.",
|
||||
inputs=[
|
||||
IO.Autogrow.Input(
|
||||
|
||||
@ -230,7 +230,7 @@ class Rodin3D_Regular(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Regular",
|
||||
display_name="Rodin 3D Generate - Regular Generate",
|
||||
category="api node/3d/Rodin",
|
||||
category="3d/partner/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -289,7 +289,7 @@ class Rodin3D_Detail(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Detail",
|
||||
display_name="Rodin 3D Generate - Detail Generate",
|
||||
category="api node/3d/Rodin",
|
||||
category="3d/partner/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -348,7 +348,7 @@ class Rodin3D_Smooth(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Smooth",
|
||||
display_name="Rodin 3D Generate - Smooth Generate",
|
||||
category="api node/3d/Rodin",
|
||||
category="3d/partner/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -406,7 +406,7 @@ class Rodin3D_Sketch(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Sketch",
|
||||
display_name="Rodin 3D Generate - Sketch Generate",
|
||||
category="api node/3d/Rodin",
|
||||
category="3d/partner/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -468,7 +468,7 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen2",
|
||||
display_name="Rodin 3D Generate - Gen-2 Generate",
|
||||
category="api node/3d/Rodin",
|
||||
category="3d/partner/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -941,7 +941,7 @@ class Rodin3D_Gen25_Image(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen25_Image",
|
||||
display_name="Rodin 3D Gen-2.5 - Image to 3D",
|
||||
category="api node/3d/Rodin",
|
||||
category="3d/partner/Rodin",
|
||||
description=(
|
||||
"Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. "
|
||||
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
|
||||
@ -1035,7 +1035,7 @@ class Rodin3D_Gen25_Text(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen25_Text",
|
||||
display_name="Rodin 3D Gen-2.5 - Text to 3D",
|
||||
category="api node/3d/Rodin",
|
||||
category="3d/partner/Rodin",
|
||||
description=(
|
||||
"Generate a 3D model from a text prompt via Rodin Gen-2.5. "
|
||||
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
|
||||
|
||||
@ -140,7 +140,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RunwayImageToVideoNodeGen3a",
|
||||
display_name="Runway Image to Video (Gen3a Turbo)",
|
||||
category="api node/video/Runway",
|
||||
category="video/partner/Runway",
|
||||
description="Generate a video from a single starting frame using Gen3a Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
@ -234,7 +234,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RunwayImageToVideoNodeGen4",
|
||||
display_name="Runway Image to Video (Gen4 Turbo)",
|
||||
category="api node/video/Runway",
|
||||
category="video/partner/Runway",
|
||||
description="Generate a video from a single starting frame using Gen4 Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
@ -329,7 +329,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RunwayFirstLastFrameNode",
|
||||
display_name="Runway First-Last-Frame to Video",
|
||||
category="api node/video/Runway",
|
||||
category="video/partner/Runway",
|
||||
description="Upload first and last keyframes, draft a prompt, and generate a video. "
|
||||
"More complex transitions, such as cases where the Last frame is completely different "
|
||||
"from the First frame, may benefit from the longer 10s duration. "
|
||||
@ -440,7 +440,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RunwayTextToImageNode",
|
||||
display_name="Runway Text to Image",
|
||||
category="api node/image/Runway",
|
||||
category="image/partner/Runway",
|
||||
description="Generate an image from a text prompt using Runway's Gen 4 model. "
|
||||
"You can also include reference image to guide the generation.",
|
||||
inputs=[
|
||||
|
||||
@ -34,7 +34,7 @@ class SoniloVideoToMusic(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="SoniloVideoToMusic",
|
||||
display_name="Sonilo Video to Music",
|
||||
category="api node/audio/Sonilo",
|
||||
category="audio/partner/Sonilo",
|
||||
description="Generate music from video content using Sonilo's AI model. "
|
||||
"Analyzes the video and creates matching music.",
|
||||
inputs=[
|
||||
@ -99,7 +99,7 @@ class SoniloTextToMusic(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="SoniloTextToMusic",
|
||||
display_name="Sonilo Text to Music",
|
||||
category="api node/audio/Sonilo",
|
||||
category="audio/partner/Sonilo",
|
||||
description="Generate music from a text prompt using Sonilo's AI model. "
|
||||
"Leave duration at 0 to let the model infer it from the prompt.",
|
||||
inputs=[
|
||||
|
||||
@ -34,7 +34,7 @@ class OpenAIVideoSora2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIVideoSora2",
|
||||
display_name="OpenAI Sora - Video (DEPRECATED)",
|
||||
category="api node/video/Sora",
|
||||
category="video/partner/Sora",
|
||||
description=(
|
||||
"OpenAI video and audio generation.\n\n"
|
||||
"DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. "
|
||||
|
||||
@ -62,7 +62,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityStableImageUltraNode",
|
||||
display_name="Stability AI Stable Image Ultra",
|
||||
category="api node/image/Stability AI",
|
||||
category="image/partner/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -197,7 +197,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityStableImageSD_3_5Node",
|
||||
display_name="Stability AI Stable Diffusion 3.5 Image",
|
||||
category="api node/image/Stability AI",
|
||||
category="image/partner/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -354,7 +354,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleConservativeNode",
|
||||
display_name="Stability AI Upscale Conservative",
|
||||
category="api node/image/Stability AI",
|
||||
category="image/partner/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -457,7 +457,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleCreativeNode",
|
||||
display_name="Stability AI Upscale Creative",
|
||||
category="api node/image/Stability AI",
|
||||
category="image/partner/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -578,7 +578,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleFastNode",
|
||||
display_name="Stability AI Upscale Fast",
|
||||
category="api node/image/Stability AI",
|
||||
category="image/partner/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -630,7 +630,7 @@ class StabilityTextToAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityTextToAudio",
|
||||
display_name="Stability AI Text To Audio",
|
||||
category="api node/audio/Stability AI",
|
||||
category="audio/partner/Stability AI",
|
||||
essentials_category="Audio",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
@ -708,7 +708,7 @@ class StabilityAudioToAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityAudioToAudio",
|
||||
display_name="Stability AI Audio To Audio",
|
||||
category="api node/audio/Stability AI",
|
||||
category="audio/partner/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -802,7 +802,7 @@ class StabilityAudioInpaint(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityAudioInpaint",
|
||||
display_name="Stability AI Audio Inpaint",
|
||||
category="api node/audio/Stability AI",
|
||||
category="audio/partner/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
|
||||
@ -52,7 +52,7 @@ class TopazImageEnhance(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TopazImageEnhance",
|
||||
display_name="Topaz Image Enhance",
|
||||
category="api node/image/Topaz",
|
||||
category="image/partner/Topaz",
|
||||
description="Industry-standard upscaling and image enhancement.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["Reimagine"]),
|
||||
@ -235,7 +235,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhance",
|
||||
display_name="Topaz Video Enhance (Legacy)",
|
||||
category="api node/video/Topaz",
|
||||
category="video/partner/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
@ -475,7 +475,7 @@ class TopazVideoEnhanceV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhanceV2",
|
||||
display_name="Topaz Video Enhance",
|
||||
category="api node/video/Topaz",
|
||||
category="video/partner/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
|
||||
@ -80,7 +80,7 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoTextToModelNode",
|
||||
display_name="Tripo: Text to Model",
|
||||
category="api node/3d/Tripo",
|
||||
category="3d/partner/Tripo",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True, optional=True),
|
||||
@ -195,7 +195,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoImageToModelNode",
|
||||
display_name="Tripo: Image to Model",
|
||||
category="api node/3d/Tripo",
|
||||
category="3d/partner/Tripo",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input(
|
||||
@ -323,7 +323,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoMultiviewToModelNode",
|
||||
display_name="Tripo: Multiview to Model",
|
||||
category="api node/3d/Tripo",
|
||||
category="3d/partner/Tripo",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Image.Input("image_left", optional=True),
|
||||
@ -461,7 +461,7 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoTextureNode",
|
||||
display_name="Tripo: Texture model",
|
||||
category="api node/3d/Tripo",
|
||||
category="3d/partner/Tripo",
|
||||
inputs=[
|
||||
IO.Custom("MODEL_TASK_ID").Input("model_task_id"),
|
||||
IO.Boolean.Input("texture", default=True, optional=True),
|
||||
@ -528,7 +528,7 @@ class TripoRefineNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoRefineNode",
|
||||
display_name="Tripo: Refine Draft model",
|
||||
category="api node/3d/Tripo",
|
||||
category="3d/partner/Tripo",
|
||||
description="Refine a draft model created by v1.4 Tripo models only.",
|
||||
inputs=[
|
||||
IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"),
|
||||
@ -568,7 +568,7 @@ class TripoRigNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoRigNode",
|
||||
display_name="Tripo: Rig model",
|
||||
category="api node/3d/Tripo",
|
||||
category="3d/partner/Tripo",
|
||||
inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
@ -605,7 +605,7 @@ class TripoRetargetNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoRetargetNode",
|
||||
display_name="Tripo: Retarget rigged model",
|
||||
category="api node/3d/Tripo",
|
||||
category="3d/partner/Tripo",
|
||||
inputs=[
|
||||
IO.Custom("RIG_TASK_ID").Input("original_model_task_id"),
|
||||
IO.Combo.Input(
|
||||
@ -670,7 +670,7 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoConversionNode",
|
||||
display_name="Tripo: Convert model",
|
||||
category="api node/3d/Tripo",
|
||||
category="3d/partner/Tripo",
|
||||
inputs=[
|
||||
IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"),
|
||||
IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]),
|
||||
|
||||
@ -45,7 +45,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="VeoVideoGenerationNode",
|
||||
display_name="Google Veo 2 Video Generation",
|
||||
category="api node/video/Veo",
|
||||
category="video/partner/Veo",
|
||||
description="Generates videos from text prompts using Google's Veo 2 API",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -256,7 +256,7 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Veo3VideoGenerationNode",
|
||||
display_name="Google Veo 3 Video Generation",
|
||||
category="api node/video/Veo",
|
||||
category="video/partner/Veo",
|
||||
description="Generates videos from text prompts using Google's Veo 3 API",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -468,7 +468,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Veo3FirstLastFrameNode",
|
||||
display_name="Google Veo 3 First-Last-Frame to Video",
|
||||
category="api node/video/Veo",
|
||||
category="video/partner/Veo",
|
||||
description="Generate video using prompt and first and last frames.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -71,7 +71,7 @@ class ViduTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduTextToVideoNode",
|
||||
display_name="Vidu Text To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="video/partner/Vidu",
|
||||
description="Generate video from a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
@ -169,7 +169,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduImageToVideoNode",
|
||||
display_name="Vidu Image To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="video/partner/Vidu",
|
||||
description="Generate video from image and optional prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
@ -273,7 +273,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduReferenceVideoNode",
|
||||
display_name="Vidu Reference To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="video/partner/Vidu",
|
||||
description="Generate video from multiple images and a prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
@ -388,7 +388,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduStartEndToVideoNode",
|
||||
display_name="Vidu Start End To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="video/partner/Vidu",
|
||||
description="Generate a video from start and end frames and a prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
@ -492,7 +492,7 @@ class Vidu2TextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2TextToVideoNode",
|
||||
display_name="Vidu2 Text-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="video/partner/Vidu",
|
||||
description="Generate video from a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2"]),
|
||||
@ -584,7 +584,7 @@ class Vidu2ImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2ImageToVideoNode",
|
||||
display_name="Vidu2 Image-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="video/partner/Vidu",
|
||||
description="Generate a video from an image and an optional prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
|
||||
@ -714,7 +714,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2ReferenceVideoNode",
|
||||
display_name="Vidu2 Reference-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="video/partner/Vidu",
|
||||
description="Generate a video from multiple reference images and a prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2"]),
|
||||
@ -849,7 +849,7 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2StartEndToVideoNode",
|
||||
display_name="Vidu2 Start/End Frame-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="video/partner/Vidu",
|
||||
description="Generate a video from a start frame, an end frame, and a prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
|
||||
@ -969,7 +969,7 @@ class ViduExtendVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduExtendVideoNode",
|
||||
display_name="Vidu Video Extension",
|
||||
category="api node/video/Vidu",
|
||||
category="video/partner/Vidu",
|
||||
description="Extend an existing video by generating additional frames.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1138,7 +1138,7 @@ class ViduMultiFrameVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduMultiFrameVideoNode",
|
||||
display_name="Vidu Multi-Frame Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="video/partner/Vidu",
|
||||
description="Generate a video with multiple keyframe transitions.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]),
|
||||
@ -1284,7 +1284,7 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu3TextToVideoNode",
|
||||
display_name="Vidu Q3 Text-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="video/partner/Vidu",
|
||||
description="Generate video from a text prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1429,7 +1429,7 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu3ImageToVideoNode",
|
||||
display_name="Vidu Q3 Image-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="video/partner/Vidu",
|
||||
description="Generate a video from an image and an optional prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1571,7 +1571,7 @@ class Vidu3StartEndToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu3StartEndToVideoNode",
|
||||
display_name="Vidu Q3 Start/End Frame-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="video/partner/Vidu",
|
||||
description="Generate a video from a start frame, an end frame, and a prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
|
||||
@ -61,7 +61,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanTextToImageApi",
|
||||
display_name="Wan Text to Image",
|
||||
category="api node/image/Wan",
|
||||
category="image/partner/Wan",
|
||||
description="Generates an image based on a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -184,7 +184,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanImageToImageApi",
|
||||
display_name="Wan Image to Image",
|
||||
category="api node/image/Wan",
|
||||
category="image/partner/Wan",
|
||||
description="Generates an image from one or two input images and a text prompt. "
|
||||
"The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
|
||||
inputs=[
|
||||
@ -312,7 +312,7 @@ class WanTextToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanTextToVideoApi",
|
||||
display_name="Wan Text to Video",
|
||||
category="api node/video/Wan",
|
||||
category="video/partner/Wan",
|
||||
description="Generates a video based on a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -495,7 +495,7 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanImageToVideoApi",
|
||||
display_name="Wan Image to Video",
|
||||
category="api node/video/Wan",
|
||||
category="video/partner/Wan",
|
||||
description="Generates a video from the first frame and a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -674,7 +674,7 @@ class WanReferenceVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanReferenceVideoApi",
|
||||
display_name="Wan Reference to Video",
|
||||
category="api node/video/Wan",
|
||||
category="video/partner/Wan",
|
||||
description="Use the character and voice from input videos, combined with a prompt, "
|
||||
"to generate a new video that maintains character consistency.",
|
||||
inputs=[
|
||||
@ -828,7 +828,7 @@ class Wan2TextToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2TextToVideoApi",
|
||||
display_name="Wan 2.7 Text to Video",
|
||||
category="api node/video/Wan",
|
||||
category="video/partner/Wan",
|
||||
description="Generates a video based on a text prompt using the Wan 2.7 model.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -981,7 +981,7 @@ class Wan2ImageToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2ImageToVideoApi",
|
||||
display_name="Wan 2.7 Image to Video",
|
||||
category="api node/video/Wan",
|
||||
category="video/partner/Wan",
|
||||
description="Generate a video from a first-frame image, with optional last-frame image and audio.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1152,7 +1152,7 @@ class Wan2VideoContinuationApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2VideoContinuationApi",
|
||||
display_name="Wan 2.7 Video Continuation",
|
||||
category="api node/video/Wan",
|
||||
category="video/partner/Wan",
|
||||
description="Continue a video from where it left off, with optional last-frame control.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1319,7 +1319,7 @@ class Wan2VideoEditApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2VideoEditApi",
|
||||
display_name="Wan 2.7 Video Edit",
|
||||
category="api node/video/Wan",
|
||||
category="video/partner/Wan",
|
||||
description="Edit a video using text instructions, reference images, or style transfer.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1477,7 +1477,7 @@ class Wan2ReferenceVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2ReferenceVideoApi",
|
||||
display_name="Wan 2.7 Reference to Video",
|
||||
category="api node/video/Wan",
|
||||
category="video/partner/Wan",
|
||||
description="Generate a video featuring a person or object from reference materials. "
|
||||
"Supports single-character performances and multi-character interactions.",
|
||||
inputs=[
|
||||
@ -1651,7 +1651,7 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HappyHorseTextToVideoApi",
|
||||
display_name="HappyHorse Text to Video",
|
||||
category="api node/video/Wan",
|
||||
category="video/partner/Wan",
|
||||
description="Generates a video based on a text prompt using the HappyHorse model.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1775,7 +1775,7 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HappyHorseImageToVideoApi",
|
||||
display_name="HappyHorse Image to Video",
|
||||
category="api node/video/Wan",
|
||||
category="video/partner/Wan",
|
||||
description="Generate a video from a first-frame image using the HappyHorse model.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1905,7 +1905,7 @@ class HappyHorseVideoEditApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HappyHorseVideoEditApi",
|
||||
display_name="HappyHorse Video Edit",
|
||||
category="api node/video/Wan",
|
||||
category="video/partner/Wan",
|
||||
description="Edit a video using text instructions or reference images with the HappyHorse model. "
|
||||
"Output duration is 3-15s and matches the input video; inputs longer than 15s are truncated.",
|
||||
inputs=[
|
||||
@ -2046,7 +2046,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HappyHorseReferenceVideoApi",
|
||||
display_name="HappyHorse Reference to Video",
|
||||
category="api node/video/Wan",
|
||||
category="video/partner/Wan",
|
||||
description="Generate a video featuring a person or object from reference materials with the HappyHorse "
|
||||
"model. Supports single-character performances and multi-character interactions.",
|
||||
inputs=[
|
||||
|
||||
@ -27,7 +27,7 @@ class WavespeedFlashVSRNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WavespeedFlashVSRNode",
|
||||
display_name="FlashVSR Video Upscale",
|
||||
category="api node/video/WaveSpeed",
|
||||
category="video/partner/WaveSpeed",
|
||||
description="Fast, high-quality video upscaler that "
|
||||
"boosts resolution and restores clarity for low-resolution or blurry footage.",
|
||||
inputs=[
|
||||
@ -98,7 +98,7 @@ class WavespeedImageUpscaleNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WavespeedImageUpscaleNode",
|
||||
display_name="WaveSpeed Image Upscale",
|
||||
category="api node/image/WaveSpeed",
|
||||
category="image/partner/WaveSpeed",
|
||||
description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]),
|
||||
|
||||
@ -86,7 +86,7 @@ class _PollUIState:
|
||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"]
|
||||
|
||||
|
||||
async def sync_op(
|
||||
|
||||
@ -11,7 +11,7 @@ class TextEncodeAceStepAudio(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TextEncodeAceStepAudio",
|
||||
category="conditioning",
|
||||
category="model/conditioning",
|
||||
inputs=[
|
||||
IO.Clip.Input("clip"),
|
||||
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
@ -33,7 +33,7 @@ class TextEncodeAceStepAudio15(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TextEncodeAceStepAudio1.5",
|
||||
category="conditioning",
|
||||
category="model/conditioning",
|
||||
inputs=[
|
||||
IO.Clip.Input("clip"),
|
||||
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
@ -67,7 +67,7 @@ class EmptyAceStepLatentAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="EmptyAceStepLatentAudio",
|
||||
display_name="Empty Ace Step 1.0 Latent Audio",
|
||||
category="latent/audio",
|
||||
category="model/latent/audio",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||
IO.Int.Input(
|
||||
@ -90,7 +90,7 @@ class EmptyAceStep15LatentAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="EmptyAceStep1.5LatentAudio",
|
||||
display_name="Empty Ace Step 1.5 Latent Audio",
|
||||
category="latent/audio",
|
||||
category="model/latent/audio",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
|
||||
IO.Int.Input(
|
||||
|
||||
@ -45,7 +45,7 @@ class SamplerLCMUpscale(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SamplerLCMUpscale",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[
|
||||
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01, advanced=True),
|
||||
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1, advanced=True),
|
||||
@ -91,7 +91,7 @@ class SamplerLCM(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SamplerLCM",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
description=("LCM sampler with tunable per-step noise. s_noise is a multiplier on the model's training noise scale"),
|
||||
inputs=[
|
||||
io.Float.Input("s_noise", default=1.0, min=0.0, max=64.0, step=0.01,
|
||||
|
||||
@ -29,7 +29,7 @@ class AlignYourStepsScheduler(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="AlignYourStepsScheduler",
|
||||
search_aliases=["AYS scheduler"],
|
||||
category="sampling/schedulers",
|
||||
category="model/sampling/schedulers",
|
||||
inputs=[
|
||||
io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]),
|
||||
io.Int.Input("steps", default=10, min=1, max=10000),
|
||||
|
||||
@ -16,7 +16,7 @@ class APG(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="APG",
|
||||
display_name="Adaptive Projected Guidance",
|
||||
category="sampling/custom_sampling",
|
||||
category="model/sampling/custom_sampling",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input(
|
||||
|
||||
@ -19,7 +19,7 @@ class EmptyARVideoLatent(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyARVideoLatent",
|
||||
category="latent/video",
|
||||
category="model/latent/video",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||
@ -53,7 +53,7 @@ class SamplerARVideo(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="SamplerARVideo",
|
||||
display_name="Sampler AR Video",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[
|
||||
io.Int.Input(
|
||||
"num_frame_per_block",
|
||||
@ -85,7 +85,7 @@ class ARVideoI2V(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ARVideoI2V",
|
||||
category="conditioning/video_models",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Vae.Input("vae"),
|
||||
|
||||
@ -16,7 +16,7 @@ class EmptyLatentAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="EmptyLatentAudio",
|
||||
display_name="Empty Latent Audio",
|
||||
category="latent/audio",
|
||||
category="model/latent/audio",
|
||||
essentials_category="Audio",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
|
||||
@ -41,7 +41,7 @@ class ConditioningStableAudio(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ConditioningStableAudio",
|
||||
category="conditioning",
|
||||
category="model/conditioning",
|
||||
inputs=[
|
||||
IO.Conditioning.Input("positive"),
|
||||
IO.Conditioning.Input("negative"),
|
||||
@ -70,7 +70,7 @@ class VAEEncodeAudio(IO.ComfyNode):
|
||||
node_id="VAEEncodeAudio",
|
||||
search_aliases=["audio to latent"],
|
||||
display_name="VAE Encode Audio",
|
||||
category="latent/audio",
|
||||
category="model/latent/audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.Vae.Input("vae"),
|
||||
@ -115,7 +115,7 @@ class VAEDecodeAudio(IO.ComfyNode):
|
||||
node_id="VAEDecodeAudio",
|
||||
search_aliases=["latent to audio"],
|
||||
display_name="VAE Decode Audio",
|
||||
category="latent/audio",
|
||||
category="model/latent/audio",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
@ -137,7 +137,7 @@ class VAEDecodeAudioTiled(IO.ComfyNode):
|
||||
node_id="VAEDecodeAudioTiled",
|
||||
search_aliases=["latent to audio"],
|
||||
display_name="VAE Decode Audio (Tiled)",
|
||||
category="latent/audio",
|
||||
category="model/latent/audio",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
|
||||
@ -11,7 +11,7 @@ class AudioEncoderLoader(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="AudioEncoderLoader",
|
||||
display_name="Load Audio Encoder",
|
||||
category="loaders",
|
||||
category="model/loaders",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"audio_encoder_name",
|
||||
@ -36,7 +36,7 @@ class AudioEncoderEncode(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="AudioEncoderEncode",
|
||||
category="conditioning",
|
||||
category="model/conditioning",
|
||||
inputs=[
|
||||
io.AudioEncoder.Input("audio_encoder"),
|
||||
io.Audio.Input("audio"),
|
||||
|
||||
@ -11,7 +11,7 @@ class LoadBackgroundRemovalModel(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LoadBackgroundRemovalModel",
|
||||
display_name="Load Background Removal Model",
|
||||
category="loaders",
|
||||
category="model/loaders",
|
||||
inputs=[
|
||||
IO.Combo.Input("bg_removal_name", options=sorted(files), tooltip="The model used to remove backgrounds from images"),
|
||||
],
|
||||
|
||||
@ -153,7 +153,7 @@ class WanCameraEmbedding(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanCameraEmbedding",
|
||||
category="conditioning/video_models",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"camera_pose",
|
||||
|
||||
@ -57,24 +57,55 @@ class CFGNorm(io.ComfyNode):
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("strength", default=1.0, min=0.0, max=100.0, step=0.01),
|
||||
io.Boolean.Input(
|
||||
"pre_cfg",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip=(
|
||||
"If true, rescale the combined noise BEFORE the sampler's CFG combine, "
|
||||
"without clamping (can amplify). Matches the norm-scaled CFG used by "
|
||||
"models like Lens. Default false keeps the original post-CFG x0-space "
|
||||
"attenuate-only behavior."
|
||||
),
|
||||
),
|
||||
],
|
||||
outputs=[io.Model.Output(display_name="patched_model")],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model, strength) -> io.NodeOutput:
|
||||
def execute(cls, model, strength, pre_cfg=False) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
def cfg_norm(args):
|
||||
cond_p = args['cond_denoised']
|
||||
pred_text_ = args["denoised"]
|
||||
if pre_cfg:
|
||||
def cfg_norm_pre(args):
|
||||
cond = args["cond"]
|
||||
uncond = args["uncond"]
|
||||
cond_scale = args["cond_scale"]
|
||||
comb = uncond + cond_scale * (cond - uncond)
|
||||
cond_norm = torch.linalg.vector_norm(cond, dim=1, keepdim=True)
|
||||
comb_norm = torch.linalg.vector_norm(comb, dim=1, keepdim=True)
|
||||
rescale = torch.where(
|
||||
comb_norm > 0,
|
||||
cond_norm / comb_norm.clamp_min(1e-12),
|
||||
torch.ones_like(comb_norm),
|
||||
)
|
||||
rescaled = comb * rescale
|
||||
# strength blends back toward standard linear CFG (1.0 = full rescale).
|
||||
if strength != 1.0:
|
||||
rescaled = strength * rescaled + (1.0 - strength) * comb
|
||||
return rescaled
|
||||
m.set_model_sampler_cfg_function(cfg_norm_pre)
|
||||
else:
|
||||
def cfg_norm(args):
|
||||
cond_p = args['cond_denoised']
|
||||
pred_text_ = args["denoised"]
|
||||
|
||||
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
|
||||
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
|
||||
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
|
||||
return pred_text_ * scale * strength
|
||||
norm_full_cond = torch.norm(cond_p, dim=1, keepdim=True)
|
||||
norm_pred_text = torch.norm(pred_text_, dim=1, keepdim=True)
|
||||
scale = (norm_full_cond / (norm_pred_text + 1e-8)).clamp(min=0.0, max=1.0)
|
||||
return pred_text_ * scale * strength
|
||||
|
||||
m.set_model_sampler_post_cfg_function(cfg_norm)
|
||||
m.set_model_sampler_post_cfg_function(cfg_norm)
|
||||
return io.NodeOutput(m)
|
||||
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ class EmptyChromaRadianceLatentImage(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="EmptyChromaRadianceLatentImage",
|
||||
category="latent/chroma_radiance",
|
||||
category="model/latent/chroma_radiance",
|
||||
inputs=[
|
||||
io.Int.Input(id="width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input(id="height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
@ -33,7 +33,7 @@ class ChromaRadianceOptions(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="ChromaRadianceOptions",
|
||||
category="model_patches/chroma_radiance",
|
||||
category="model/patch/chroma_radiance",
|
||||
description="Allows setting advanced options for the Chroma Radiance model.",
|
||||
inputs=[
|
||||
io.Model.Input(id="model"),
|
||||
|
||||
@ -8,7 +8,7 @@ class ColorToRGBInt(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ColorToRGBInt",
|
||||
display_name="Color to RGB Int",
|
||||
category="utils",
|
||||
category="utilities",
|
||||
description="Convert a color to a RGB integer value.",
|
||||
inputs=[
|
||||
io.Color.Input("color"),
|
||||
|
||||
@ -9,7 +9,7 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ContextWindowsManual",
|
||||
display_name="Context Windows (Manual)",
|
||||
category="model_patches",
|
||||
category="model/patch",
|
||||
description="Manually set context windows.",
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to apply context windows to during sampling."),
|
||||
|
||||
@ -9,7 +9,7 @@ class SetUnionControlNetType(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SetUnionControlNetType",
|
||||
category="conditioning/controlnet",
|
||||
category="model/conditioning/controlnet",
|
||||
inputs=[
|
||||
io.ControlNet.Input("control_net"),
|
||||
io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())),
|
||||
@ -39,7 +39,7 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ControlNetInpaintingAliMamaApply",
|
||||
search_aliases=["masked controlnet"],
|
||||
category="conditioning/controlnet",
|
||||
category="model/conditioning/controlnet",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
|
||||
@ -13,7 +13,7 @@ class EmptyCosmosLatentVideo(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="EmptyCosmosLatentVideo",
|
||||
category="latent/video",
|
||||
category="model/latent/video",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
@ -45,7 +45,7 @@ class CosmosImageToVideoLatent(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="CosmosImageToVideoLatent",
|
||||
category="conditioning/inpaint",
|
||||
category="model/conditioning/inpaint",
|
||||
inputs=[
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
@ -88,7 +88,7 @@ class CosmosPredict2ImageToVideoLatent(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="CosmosPredict2ImageToVideoLatent",
|
||||
category="conditioning/inpaint",
|
||||
category="model/conditioning/inpaint",
|
||||
inputs=[
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
|
||||
@ -11,7 +11,7 @@ class CurveEditor(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="CurveEditor",
|
||||
display_name="Curve Editor",
|
||||
category="utils",
|
||||
category="utilities",
|
||||
inputs=[
|
||||
io.Curve.Input("curve"),
|
||||
io.Histogram.Input("histogram", optional=True),
|
||||
@ -38,7 +38,7 @@ class ImageHistogram(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ImageHistogram",
|
||||
display_name="Image Histogram",
|
||||
category="utils",
|
||||
category="utilities",
|
||||
inputs=[
|
||||
io.Image.Input("image"),
|
||||
],
|
||||
|
||||
@ -17,7 +17,7 @@ class BasicScheduler(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="BasicScheduler",
|
||||
category="sampling/schedulers",
|
||||
category="model/sampling/schedulers",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES),
|
||||
@ -47,7 +47,7 @@ class KarrasScheduler(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="KarrasScheduler",
|
||||
category="sampling/schedulers",
|
||||
category="model/sampling/schedulers",
|
||||
inputs=[
|
||||
io.Int.Input("steps", default=20, min=1, max=10000),
|
||||
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
|
||||
@ -69,7 +69,7 @@ class ExponentialScheduler(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ExponentialScheduler",
|
||||
category="sampling/schedulers",
|
||||
category="model/sampling/schedulers",
|
||||
inputs=[
|
||||
io.Int.Input("steps", default=20, min=1, max=10000),
|
||||
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
|
||||
@ -90,7 +90,7 @@ class PolyexponentialScheduler(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="PolyexponentialScheduler",
|
||||
category="sampling/schedulers",
|
||||
category="model/sampling/schedulers",
|
||||
inputs=[
|
||||
io.Int.Input("steps", default=20, min=1, max=10000),
|
||||
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
|
||||
@ -112,7 +112,7 @@ class LaplaceScheduler(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LaplaceScheduler",
|
||||
category="sampling/schedulers",
|
||||
category="model/sampling/schedulers",
|
||||
inputs=[
|
||||
io.Int.Input("steps", default=20, min=1, max=10000),
|
||||
io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True),
|
||||
@ -136,7 +136,7 @@ class SDTurboScheduler(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SDTurboScheduler",
|
||||
category="sampling/schedulers",
|
||||
category="model/sampling/schedulers",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Int.Input("steps", default=1, min=1, max=10),
|
||||
@ -160,7 +160,7 @@ class BetaSamplingScheduler(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="BetaSamplingScheduler",
|
||||
category="sampling/schedulers",
|
||||
category="model/sampling/schedulers",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Int.Input("steps", default=20, min=1, max=10000),
|
||||
@ -182,7 +182,7 @@ class VPScheduler(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="VPScheduler",
|
||||
category="sampling/schedulers",
|
||||
category="model/sampling/schedulers",
|
||||
inputs=[
|
||||
io.Int.Input("steps", default=20, min=1, max=10000),
|
||||
io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), #TODO: fix default values
|
||||
@ -204,7 +204,7 @@ class SplitSigmas(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SplitSigmas",
|
||||
category="sampling/sigmas",
|
||||
category="model/sampling/sigmas",
|
||||
inputs=[
|
||||
io.Sigmas.Input("sigmas"),
|
||||
io.Int.Input("step", default=0, min=0, max=10000),
|
||||
@ -228,7 +228,7 @@ class SplitSigmasDenoise(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SplitSigmasDenoise",
|
||||
category="sampling/sigmas",
|
||||
category="model/sampling/sigmas",
|
||||
inputs=[
|
||||
io.Sigmas.Input("sigmas"),
|
||||
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01),
|
||||
@ -254,7 +254,7 @@ class FlipSigmas(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="FlipSigmas",
|
||||
category="sampling/sigmas",
|
||||
category="model/sampling/sigmas",
|
||||
inputs=[io.Sigmas.Input("sigmas")],
|
||||
outputs=[io.Sigmas.Output()]
|
||||
)
|
||||
@ -276,7 +276,7 @@ class SetFirstSigma(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SetFirstSigma",
|
||||
category="sampling/sigmas",
|
||||
category="model/sampling/sigmas",
|
||||
inputs=[
|
||||
io.Sigmas.Input("sigmas"),
|
||||
io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False),
|
||||
@ -298,7 +298,7 @@ class ExtendIntermediateSigmas(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ExtendIntermediateSigmas",
|
||||
search_aliases=["interpolate sigmas"],
|
||||
category="sampling/sigmas",
|
||||
category="model/sampling/sigmas",
|
||||
inputs=[
|
||||
io.Sigmas.Input("sigmas"),
|
||||
io.Int.Input("steps", default=2, min=1, max=100),
|
||||
@ -351,7 +351,7 @@ class SamplingPercentToSigma(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplingPercentToSigma",
|
||||
category="sampling/sigmas",
|
||||
category="model/sampling/sigmas",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001),
|
||||
@ -379,7 +379,7 @@ class KSamplerSelect(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="KSamplerSelect",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES)],
|
||||
outputs=[io.Sampler.Output()]
|
||||
)
|
||||
@ -396,7 +396,7 @@ class SamplerDPMPP_3M_SDE(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerDPMPP_3M_SDE",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[
|
||||
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
||||
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
||||
@ -421,7 +421,7 @@ class SamplerDPMPP_2M_SDE(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerDPMPP_2M_SDE",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[
|
||||
io.Combo.Input("solver_type", options=['midpoint', 'heun']),
|
||||
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
||||
@ -448,7 +448,7 @@ class SamplerDPMPP_SDE(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerDPMPP_SDE",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[
|
||||
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
||||
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
||||
@ -474,7 +474,7 @@ class SamplerDPMPP_2S_Ancestral(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerDPMPP_2S_Ancestral",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[
|
||||
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
|
||||
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False),
|
||||
@ -494,7 +494,7 @@ class SamplerEulerAncestral(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerEulerAncestral",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[
|
||||
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
||||
io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
||||
@ -515,7 +515,7 @@ class SamplerEulerAncestralCFGPP(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="SamplerEulerAncestralCFGPP",
|
||||
display_name="SamplerEulerAncestralCFG++",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[
|
||||
io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False),
|
||||
io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False),
|
||||
@ -537,7 +537,7 @@ class SamplerLMS(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerLMS",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[io.Int.Input("order", default=4, min=1, max=100, advanced=True)],
|
||||
outputs=[io.Sampler.Output()]
|
||||
)
|
||||
@ -554,7 +554,7 @@ class SamplerDPMAdaptative(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerDPMAdaptative",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[
|
||||
io.Int.Input("order", default=3, min=2, max=3, advanced=True),
|
||||
io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False, advanced=True),
|
||||
@ -585,7 +585,7 @@ class SamplerER_SDE(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerER_SDE",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[
|
||||
io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]),
|
||||
io.Int.Input("max_stage", default=3, min=1, max=3, advanced=True),
|
||||
@ -623,7 +623,7 @@ class SamplerSASolver(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="SamplerSASolver",
|
||||
search_aliases=["sde"],
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False, advanced=True),
|
||||
@ -668,7 +668,7 @@ class SamplerSEEDS2(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="SamplerSEEDS2",
|
||||
search_aliases=["sde", "exp heun"],
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[
|
||||
io.Combo.Input("solver_type", options=["phi_1", "phi_2"]),
|
||||
io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength", advanced=True),
|
||||
@ -727,7 +727,7 @@ class SamplerCustom(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerCustom",
|
||||
category="sampling/custom_sampling",
|
||||
category="model/sampling/custom_sampling",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Boolean.Input("add_noise", default=True, advanced=True),
|
||||
@ -795,7 +795,7 @@ class BasicGuider(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="BasicGuider",
|
||||
display_name="Basic Guider",
|
||||
category="sampling/guiders",
|
||||
category="model/sampling/guiders",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Conditioning.Input("conditioning"),
|
||||
@ -817,7 +817,7 @@ class CFGGuider(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="CFGGuider",
|
||||
display_name="CFG Guider",
|
||||
category="sampling/guiders",
|
||||
category="model/sampling/guiders",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Conditioning.Input("positive"),
|
||||
@ -872,7 +872,7 @@ class DualCFGGuider(io.ComfyNode):
|
||||
node_id="DualCFGGuider",
|
||||
search_aliases=["dual prompt guidance"],
|
||||
display_name="Dual CFG Guider",
|
||||
category="sampling/guiders",
|
||||
category="model/sampling/guiders",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Conditioning.Input("cond1"),
|
||||
@ -900,7 +900,7 @@ class DisableNoise(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="DisableNoise",
|
||||
search_aliases=["zero noise"],
|
||||
category="sampling/noise",
|
||||
category="model/sampling/noise",
|
||||
inputs=[],
|
||||
outputs=[io.Noise.Output()]
|
||||
)
|
||||
@ -917,7 +917,7 @@ class RandomNoise(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="RandomNoise",
|
||||
category="sampling/noise",
|
||||
category="model/sampling/noise",
|
||||
inputs=[io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True)],
|
||||
outputs=[io.Noise.Output()]
|
||||
)
|
||||
@ -934,7 +934,7 @@ class SamplerCustomAdvanced(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SamplerCustomAdvanced",
|
||||
category="sampling/custom_sampling",
|
||||
category="model/sampling/custom_sampling",
|
||||
inputs=[
|
||||
io.Noise.Input("noise"),
|
||||
io.Guider.Input("guider"),
|
||||
|
||||
@ -157,7 +157,7 @@ class LoadImageTextDataSetFromFolderNode(io.ComfyNode):
|
||||
return io.NodeOutput(output_tensor, captions)
|
||||
|
||||
|
||||
def save_images_to_folder(image_list, output_dir, prefix="image"):
|
||||
def save_images_to_folder(image_list, output_dir, prefix="image", overwrite=True):
|
||||
"""Utility function to save a list of image tensors to disk.
|
||||
|
||||
Args:
|
||||
@ -197,7 +197,11 @@ def save_images_to_folder(image_list, output_dir, prefix="image"):
|
||||
raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}")
|
||||
|
||||
# Save image
|
||||
filename = f"{prefix}_{idx:05d}.png"
|
||||
if overwrite:
|
||||
filename = f"{prefix}_{idx:05d}.png"
|
||||
else:
|
||||
_, _, counter, _, resolved_prefix = folder_paths.get_save_image_path(prefix, output_dir)
|
||||
filename = f"{resolved_prefix}_{counter:05}_{idx:05d}.png"
|
||||
filepath = os.path.join(output_dir, filename)
|
||||
img.save(filepath)
|
||||
saved_files.append(filename)
|
||||
@ -230,19 +234,26 @@ class SaveImageDataSetToFolderNode(io.ComfyNode):
|
||||
tooltip="Prefix for saved image filenames.",
|
||||
advanced=True,
|
||||
),
|
||||
io.Combo.Input(
|
||||
"mode",
|
||||
default="overwrite",
|
||||
options=["overwrite", "increment"],
|
||||
tooltip="Whether to overwrite existing files or increment filenames to avoid overwriting."
|
||||
),
|
||||
],
|
||||
outputs=[],
|
||||
is_deprecated=True, # This node is redundant and superseded by existing Save Image nodes where the target folder can be specified in the filename_prefix
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, folder_name, filename_prefix):
|
||||
def execute(cls, images, folder_name, filename_prefix, mode):
|
||||
# Extract scalar values
|
||||
folder_name = folder_name[0]
|
||||
filename_prefix = filename_prefix[0]
|
||||
mode = mode[0]
|
||||
|
||||
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
||||
saved_files = save_images_to_folder(images, output_dir, filename_prefix)
|
||||
saved_files = save_images_to_folder(images, output_dir, filename_prefix, mode=='overwrite')
|
||||
|
||||
logging.info(f"Saved {len(saved_files)} images to {output_dir}.")
|
||||
return io.NodeOutput()
|
||||
@ -278,18 +289,25 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode):
|
||||
tooltip="Prefix for saved image filenames.",
|
||||
advanced=True,
|
||||
),
|
||||
io.Combo.Input(
|
||||
"mode",
|
||||
default="overwrite",
|
||||
options=["overwrite", "increment"],
|
||||
tooltip="Whether to overwrite existing files or increment filenames to avoid overwriting."
|
||||
),
|
||||
],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, folder_name, filename_prefix, texts=None):
|
||||
def execute(cls, images, folder_name, filename_prefix, mode, texts=None):
|
||||
# Extract scalar values
|
||||
folder_name = folder_name[0]
|
||||
filename_prefix = filename_prefix[0]
|
||||
mode = mode[0]
|
||||
|
||||
output_dir = os.path.join(folder_paths.get_output_directory(), folder_name)
|
||||
saved_files = save_images_to_folder(images, output_dir, filename_prefix)
|
||||
saved_files = save_images_to_folder(images, output_dir, filename_prefix, mode=='overwrite')
|
||||
|
||||
# Save captions
|
||||
if texts:
|
||||
@ -574,7 +592,7 @@ class TextProcessingNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id=cls.node_id,
|
||||
display_name=cls.display_name or cls.node_id,
|
||||
category="dataset/text",
|
||||
category="text",
|
||||
is_experimental=True,
|
||||
is_input_list=is_group, # True for group, False for individual
|
||||
inputs=inputs,
|
||||
@ -1208,7 +1226,7 @@ class ResolutionBucket(io.ComfyNode):
|
||||
node_id="ResolutionBucket",
|
||||
search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"],
|
||||
display_name="Resolution Bucket",
|
||||
category="training",
|
||||
category="model/training",
|
||||
description="Group latents and conditionings into buckets",
|
||||
is_experimental=True,
|
||||
is_input_list=True,
|
||||
@ -1302,7 +1320,7 @@ class MakeTrainingDataset(io.ComfyNode):
|
||||
node_id="MakeTrainingDataset",
|
||||
search_aliases=["encode dataset"],
|
||||
display_name="Make Training Dataset",
|
||||
category="training",
|
||||
category="model/training",
|
||||
description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.",
|
||||
is_experimental=True,
|
||||
is_input_list=True, # images and texts as lists
|
||||
@ -1390,7 +1408,7 @@ class SaveTrainingDataset(io.ComfyNode):
|
||||
node_id="SaveTrainingDataset",
|
||||
search_aliases=["export dataset", "save dataset"],
|
||||
display_name="Save Training Dataset",
|
||||
category="training",
|
||||
category="model/training",
|
||||
description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
@ -1493,7 +1511,7 @@ class LoadTrainingDataset(io.ComfyNode):
|
||||
node_id="LoadTrainingDataset",
|
||||
search_aliases=["import dataset", "training data"],
|
||||
display_name="Load Training Dataset",
|
||||
category="training",
|
||||
category="model/training",
|
||||
description="Load encoded training dataset (latents + conditioning) from disk for use in training.",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
|
||||
@ -18,7 +18,7 @@ class EpsilonScaling(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Epsilon Scaling",
|
||||
category="model_patches/unet",
|
||||
category="model/patch/unet",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input(
|
||||
@ -84,7 +84,7 @@ class TemporalScoreRescaling(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="TemporalScoreRescaling",
|
||||
display_name="TSR - Temporal Score Rescaling",
|
||||
category="model_patches/unet",
|
||||
category="model/patch/unet",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Float.Input(
|
||||
|
||||
@ -40,7 +40,7 @@ class EmptyFlux2LatentImage(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="EmptyFlux2LatentImage",
|
||||
display_name="Empty Flux 2 Latent",
|
||||
category="latent",
|
||||
category="model/latent",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
@ -215,7 +215,7 @@ class Flux2Scheduler(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Flux2Scheduler",
|
||||
category="sampling/schedulers",
|
||||
category="model/sampling/schedulers",
|
||||
inputs=[
|
||||
io.Int.Input("steps", default=20, min=1, max=4096),
|
||||
io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1),
|
||||
|
||||
@ -19,7 +19,7 @@ class FrameInterpolationModelLoader(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="FrameInterpolationModelLoader",
|
||||
display_name="Load Frame Interpolation Model",
|
||||
category="loaders",
|
||||
category="model/loaders",
|
||||
inputs=[
|
||||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation"),
|
||||
tooltip="Select a frame interpolation model to load. Models must be placed in the 'frame_interpolation' folder."),
|
||||
|
||||
@ -29,7 +29,7 @@ class FreeU(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="FreeU",
|
||||
category="model_patches/unet",
|
||||
category="model/patch/unet",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||
@ -76,7 +76,7 @@ class FreeU_V2(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="FreeU_V2",
|
||||
category="model_patches/unet",
|
||||
category="model/patch/unet",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01, advanced=True),
|
||||
|
||||
@ -340,7 +340,7 @@ class GITSScheduler(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="GITSScheduler",
|
||||
category="sampling/schedulers",
|
||||
category="model/sampling/schedulers",
|
||||
inputs=[
|
||||
io.Float.Input("coeff", default=1.20, min=0.80, max=1.50, step=0.05, advanced=True),
|
||||
io.Int.Input("steps", default=10, min=2, max=1000),
|
||||
|
||||
@ -14,7 +14,7 @@ class EmptyHiDreamO1LatentImage(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="EmptyHiDreamO1LatentImage",
|
||||
display_name="Empty HiDream-O1 Latent Image",
|
||||
category="latent/image",
|
||||
category="model/latent/image",
|
||||
description=(
|
||||
"Empty pixel-space latent for HiDream-O1-Image. The model was "
|
||||
"trained at ~4 megapixels; lower resolutions go off-distribution "
|
||||
@ -47,7 +47,7 @@ class HiDreamO1ReferenceImages(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="HiDreamO1ReferenceImages",
|
||||
display_name="HiDream-O1 Reference Images",
|
||||
category="conditioning/image",
|
||||
category="model/conditioning/image",
|
||||
description=(
|
||||
"Attach 1-10 reference images to conditioning, one for edit instruction"
|
||||
"or multiple for subject-driven personalization."
|
||||
|
||||
@ -41,7 +41,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="EmptyHunyuanLatentVideo",
|
||||
display_name="Empty HunyuanVideo 1.0 Latent",
|
||||
category="latent/video",
|
||||
category="model/latent/video",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
|
||||
@ -81,7 +81,7 @@ class HunyuanVideo15ImageToVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="HunyuanVideo15ImageToVideo",
|
||||
category="conditioning/video_models",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
@ -132,7 +132,7 @@ class HunyuanVideo15SuperResolution(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="HunyuanVideo15SuperResolution",
|
||||
display_name="Hunyuan Video 1.5 Super Resolution",
|
||||
category="conditioning/video_models",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
@ -178,7 +178,7 @@ class LatentUpscaleModelLoader(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LatentUpscaleModelLoader",
|
||||
display_name="Load Latent Upscale Model",
|
||||
category="loaders",
|
||||
category="model/loaders",
|
||||
inputs=[
|
||||
io.Combo.Input("model_name", options=folder_paths.get_filename_list("latent_upscale_models")),
|
||||
],
|
||||
@ -227,7 +227,7 @@ class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="HunyuanVideo15LatentUpscaleWithModel",
|
||||
display_name="Hunyuan Video 15 Latent Upscale With Model",
|
||||
category="latent",
|
||||
category="model/latent",
|
||||
inputs=[
|
||||
io.LatentUpscaleModel.Input("model"),
|
||||
io.Latent.Input("samples"),
|
||||
@ -308,7 +308,7 @@ class HunyuanImageToVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="HunyuanImageToVideo",
|
||||
category="conditioning/video_models",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Vae.Input("vae"),
|
||||
@ -359,7 +359,7 @@ class EmptyHunyuanImageLatent(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyHunyuanImageLatent",
|
||||
category="latent",
|
||||
category="model/latent",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32),
|
||||
@ -384,7 +384,7 @@ class HunyuanRefinerLatent(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="HunyuanRefinerLatent",
|
||||
display_name="Hunyuan Latent Refiner",
|
||||
category="conditioning/video_models",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
|
||||
@ -12,7 +12,7 @@ class EmptyLatentHunyuan3Dv2(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyLatentHunyuan3Dv2",
|
||||
category="latent/3d",
|
||||
category="model/latent/3d",
|
||||
inputs=[
|
||||
IO.Int.Input("resolution", default=3072, min=1, max=8192),
|
||||
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
||||
@ -35,7 +35,7 @@ class Hunyuan3Dv2Conditioning(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Hunyuan3Dv2Conditioning",
|
||||
category="conditioning/3d_models",
|
||||
category="model/conditioning/3d_models",
|
||||
inputs=[
|
||||
IO.ClipVisionOutput.Input("clip_vision_output"),
|
||||
],
|
||||
@ -60,7 +60,7 @@ class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Hunyuan3Dv2ConditioningMultiView",
|
||||
category="conditioning/3d_models",
|
||||
category="model/conditioning/3d_models",
|
||||
inputs=[
|
||||
IO.ClipVisionOutput.Input("front", optional=True),
|
||||
IO.ClipVisionOutput.Input("left", optional=True),
|
||||
@ -97,7 +97,7 @@ class VAEDecodeHunyuan3D(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeHunyuan3D",
|
||||
category="latent/3d",
|
||||
category="model/latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
|
||||
@ -103,7 +103,7 @@ class HypernetworkLoader(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HypernetworkLoader",
|
||||
display_name="Load Hypernetwork",
|
||||
category="loaders",
|
||||
category="model/loaders",
|
||||
inputs=[
|
||||
IO.Model.Input("model"),
|
||||
IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")),
|
||||
|
||||
@ -27,7 +27,7 @@ class HyperTile(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="HyperTile",
|
||||
category="model_patches/unet",
|
||||
category="model/patch/unet",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Int.Input("tile_size", default=256, min=1, max=2048, advanced=True),
|
||||
|
||||
@ -95,7 +95,7 @@ class BoundingBox(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="PrimitiveBoundingBox",
|
||||
display_name="Bounding Box",
|
||||
category="utils/primitive",
|
||||
category="utilities/primitive",
|
||||
inputs=[
|
||||
IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION),
|
||||
IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION),
|
||||
|
||||
@ -9,7 +9,7 @@ class InstructPixToPixConditioning(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="InstructPixToPixConditioning",
|
||||
category="conditioning/instructpix2pix",
|
||||
category="model/conditioning/instructpix2pix",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
|
||||
@ -13,7 +13,7 @@ class Kandinsky5ImageToVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Kandinsky5ImageToVideo",
|
||||
category="conditioning/video_models",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
@ -71,7 +71,7 @@ class NormalizeVideoLatentStart(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="NormalizeVideoLatentStart",
|
||||
category="conditioning/video_models",
|
||||
category="model/conditioning/video_models",
|
||||
description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.",
|
||||
inputs=[
|
||||
io.Latent.Input("latent"),
|
||||
|
||||
@ -22,7 +22,7 @@ class LatentAdd(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LatentAdd",
|
||||
search_aliases=["combine latents", "sum latents"],
|
||||
category="latent/advanced",
|
||||
category="model/latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
io.Latent.Input("samples2"),
|
||||
@ -49,7 +49,7 @@ class LatentSubtract(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LatentSubtract",
|
||||
search_aliases=["difference latent", "remove features"],
|
||||
category="latent/advanced",
|
||||
category="model/latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
io.Latent.Input("samples2"),
|
||||
@ -76,7 +76,7 @@ class LatentMultiply(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LatentMultiply",
|
||||
search_aliases=["scale latent", "amplify latent", "latent gain"],
|
||||
category="latent/advanced",
|
||||
category="model/latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01),
|
||||
@ -100,7 +100,7 @@ class LatentInterpolate(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LatentInterpolate",
|
||||
search_aliases=["blend latent", "mix latent", "lerp latent", "transition"],
|
||||
category="latent/advanced",
|
||||
category="model/latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
io.Latent.Input("samples2"),
|
||||
@ -139,7 +139,7 @@ class LatentConcat(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LatentConcat",
|
||||
search_aliases=["join latents", "stitch latents"],
|
||||
category="latent/advanced",
|
||||
category="model/latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
io.Latent.Input("samples2"),
|
||||
@ -179,7 +179,7 @@ class LatentCut(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LatentCut",
|
||||
search_aliases=["crop latent", "slice latent", "extract region"],
|
||||
category="latent/advanced",
|
||||
category="model/latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
io.Combo.Input("dim", options=["x", "y", "t"]),
|
||||
@ -220,7 +220,7 @@ class LatentCutToBatch(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LatentCutToBatch",
|
||||
search_aliases=["slice to batch", "split latent", "tile latent"],
|
||||
category="latent/advanced",
|
||||
category="model/latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
io.Combo.Input("dim", options=["t", "x", "y"]),
|
||||
@ -262,7 +262,7 @@ class LatentBatch(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LatentBatch",
|
||||
search_aliases=["combine latents", "merge latents", "join latents"],
|
||||
category="latent/batch",
|
||||
category="model/latent/batch",
|
||||
is_deprecated=True,
|
||||
inputs=[
|
||||
io.Latent.Input("samples1"),
|
||||
@ -290,7 +290,7 @@ class LatentBatchSeedBehavior(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentBatchSeedBehavior",
|
||||
category="latent/advanced",
|
||||
category="model/latent/advanced",
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"),
|
||||
@ -319,7 +319,7 @@ class LatentApplyOperation(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LatentApplyOperation",
|
||||
search_aliases=["transform latent"],
|
||||
category="latent/advanced/operations",
|
||||
category="model/latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Latent.Input("samples"),
|
||||
@ -343,7 +343,7 @@ class LatentApplyOperationCFG(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentApplyOperationCFG",
|
||||
category="latent/advanced/operations",
|
||||
category="model/latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
@ -375,7 +375,7 @@ class LatentOperationTonemapReinhard(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LatentOperationTonemapReinhard",
|
||||
search_aliases=["hdr latent"],
|
||||
category="latent/advanced/operations",
|
||||
category="model/latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01),
|
||||
@ -410,7 +410,7 @@ class LatentOperationSharpen(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LatentOperationSharpen",
|
||||
category="latent/advanced/operations",
|
||||
category="model/latent/advanced/operations",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1, advanced=True),
|
||||
@ -447,7 +447,7 @@ class ReplaceVideoLatentFrames(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="ReplaceVideoLatentFrames",
|
||||
category="latent/batch",
|
||||
category="model/latent/batch",
|
||||
inputs=[
|
||||
io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."),
|
||||
io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."),
|
||||
|
||||
@ -34,7 +34,7 @@ class Load3D(IO.ComfyNode):
|
||||
essentials_category="Basics",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
|
||||
IO.Combo.Input("model_file", options=["none"] + sorted(files), upload=IO.UploadType.model),
|
||||
IO.Load3D.Input("image"),
|
||||
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
|
||||
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
|
||||
@ -68,8 +68,12 @@ class Load3D(IO.ComfyNode):
|
||||
|
||||
video = InputImpl.VideoFromFile(recording_video_path)
|
||||
|
||||
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
|
||||
return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d)
|
||||
file_3d = None
|
||||
mesh_path = ""
|
||||
if model_file and model_file != "none":
|
||||
file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file))
|
||||
mesh_path = model_file
|
||||
return IO.NodeOutput(output_image, output_mask, mesh_path, normal_image, image['camera_info'], video, file_3d)
|
||||
|
||||
process = execute # TODO: remove
|
||||
|
||||
|
||||
@ -13,7 +13,7 @@ class NotNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfyNotNode",
|
||||
display_name="Not",
|
||||
category="utils/logic",
|
||||
category="utilities/logic",
|
||||
description="Logical NOT operation. Returns true if the value is falsy. Uses Python's rules for truthiness.",
|
||||
search_aliases=["invert", "toggle", "negate", "flip boolean"],
|
||||
inputs=[
|
||||
@ -40,7 +40,7 @@ class AndNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfyAndNode",
|
||||
display_name="And",
|
||||
category="utils/logic",
|
||||
category="utilities/logic",
|
||||
description="Logical AND operation. Returns true if all of the values are truthy. Uses Python's rules for truthiness.",
|
||||
search_aliases=["all", "every"],
|
||||
inputs=[
|
||||
@ -67,7 +67,7 @@ class OrNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfyOrNode",
|
||||
display_name="Or",
|
||||
category="utils/logic",
|
||||
category="utilities/logic",
|
||||
description="Logical OR operation. Returns true if any of the values are truthy. Uses Python's rules for truthiness.",
|
||||
search_aliases=["any", "some"],
|
||||
inputs=[
|
||||
@ -90,7 +90,7 @@ class SwitchNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfySwitchNode",
|
||||
display_name="Switch",
|
||||
category="utils/logic",
|
||||
category="utilities/logic",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Boolean.Input("switch"),
|
||||
@ -121,7 +121,7 @@ class SoftSwitchNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfySoftSwitchNode",
|
||||
display_name="Soft Switch",
|
||||
category="utils/logic",
|
||||
category="utilities/logic",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Boolean.Input("switch"),
|
||||
@ -176,7 +176,7 @@ class CustomComboNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="CustomCombo",
|
||||
display_name="Custom Combo",
|
||||
category="utils",
|
||||
category="utilities",
|
||||
is_experimental=True,
|
||||
inputs=[io.Combo.Input("choice", options=[])],
|
||||
outputs=[
|
||||
@ -211,7 +211,7 @@ class DCTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="DCTestNode",
|
||||
display_name="DCTest",
|
||||
category="utils/logic",
|
||||
category="utilities/logic",
|
||||
is_output_node=True,
|
||||
inputs=[io.DynamicCombo.Input("combo", options=[
|
||||
io.DynamicCombo.Option("option1", [io.String.Input("string")]),
|
||||
@ -249,7 +249,7 @@ class AutogrowNamesTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="AutogrowNamesTestNode",
|
||||
display_name="AutogrowNamesTest",
|
||||
category="utils/logic",
|
||||
category="utilities/logic",
|
||||
inputs=[
|
||||
_io.Autogrow.Input("autogrow", template=template)
|
||||
],
|
||||
@ -269,7 +269,7 @@ class AutogrowPrefixTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="AutogrowPrefixTestNode",
|
||||
display_name="AutogrowPrefixTest",
|
||||
category="utils/logic",
|
||||
category="utilities/logic",
|
||||
inputs=[
|
||||
_io.Autogrow.Input("autogrow", template=template)
|
||||
],
|
||||
@ -288,7 +288,7 @@ class ComboOutputTestNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComboOptionTestNode",
|
||||
display_name="ComboOptionTest",
|
||||
category="utils/logic",
|
||||
category="utilities/logic",
|
||||
inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]),
|
||||
io.Combo.Input("combo2", options=["option4", "option5", "option6"])],
|
||||
outputs=[io.Combo.Output(), io.Combo.Output()],
|
||||
@ -305,7 +305,7 @@ class ConvertStringToComboNode(io.ComfyNode):
|
||||
node_id="ConvertStringToComboNode",
|
||||
search_aliases=["string to dropdown", "text to combo"],
|
||||
display_name="Convert String to Combo",
|
||||
category="utils/logic",
|
||||
category="utilities/logic",
|
||||
inputs=[io.String.Input("string")],
|
||||
outputs=[io.Combo.Output()],
|
||||
)
|
||||
@ -321,7 +321,7 @@ class InvertBooleanNode(io.ComfyNode):
|
||||
node_id="InvertBooleanNode",
|
||||
search_aliases=["not", "toggle", "negate", "flip boolean"],
|
||||
display_name="Invert Boolean",
|
||||
category="utils/logic",
|
||||
category="utilities/logic",
|
||||
inputs=[io.Boolean.Input("boolean")],
|
||||
outputs=[io.Boolean.Output()],
|
||||
)
|
||||
|
||||
@ -30,7 +30,7 @@ class LoraLoaderBypass:
|
||||
OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.")
|
||||
FUNCTION = "load_lora"
|
||||
|
||||
CATEGORY = "loaders"
|
||||
CATEGORY = "model/loaders"
|
||||
DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios."
|
||||
EXPERIMENTAL = True
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ class LotusConditioning(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LotusConditioning",
|
||||
category="conditioning/lotus",
|
||||
category="model/conditioning/lotus",
|
||||
inputs=[],
|
||||
outputs=[io.Conditioning.Output(display_name="conditioning")],
|
||||
)
|
||||
|
||||
@ -25,7 +25,7 @@ class GetICLoRAParameters(io.ComfyNode):
|
||||
display_name="Get IC-LoRA Parameters",
|
||||
description="Extracts IC-LoRA parameters from the safetensors metadata of a LoRA-loaded "
|
||||
"model and outputs them for LTXVAddGuide (eg. reference_downscale_factor).",
|
||||
category="conditioning/video_models",
|
||||
category="model/conditioning/video_models",
|
||||
search_aliases=["ic-lora", "ic lora", "iclora", "downscale factor", "reference downscale"],
|
||||
inputs=[
|
||||
io.Model.Input(
|
||||
@ -62,7 +62,7 @@ class EmptyLTXVLatentVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyLTXVLatentVideo",
|
||||
category="latent/video/ltxv",
|
||||
category="model/latent/video/ltxv",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32),
|
||||
@ -86,7 +86,7 @@ class LTXVImgToVideo(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVImgToVideo",
|
||||
category="conditioning/video_models",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
@ -131,7 +131,7 @@ class LTXVImgToVideoInplace(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVImgToVideoInplace",
|
||||
category="conditioning/video_models",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Vae.Input("vae"),
|
||||
io.Image.Input("image"),
|
||||
@ -226,10 +226,20 @@ def get_noise_mask(latent):
|
||||
noise_mask = noise_mask.clone()
|
||||
return noise_mask
|
||||
|
||||
def get_keyframe_idxs(cond):
|
||||
def get_keyframe_idxs(cond, latent_shape=None):
|
||||
keyframe_idxs = conditioning_get_any_value(cond, "keyframe_idxs", None)
|
||||
if keyframe_idxs is None:
|
||||
return None, 0
|
||||
# Get number of keyframes from latent_shape or guide_attention_entries if available
|
||||
if latent_shape is not None and len(latent_shape) == 5:
|
||||
tokens_per_frame = latent_shape[-2] * latent_shape[-1]
|
||||
num_keyframes = keyframe_idxs.shape[2] // tokens_per_frame
|
||||
return keyframe_idxs, num_keyframes
|
||||
entries = conditioning_get_any_value(cond, "guide_attention_entries", None)
|
||||
if entries:
|
||||
num_keyframes = sum(e["latent_shape"][0] for e in entries)
|
||||
return keyframe_idxs, num_keyframes
|
||||
# fallback, may under-count if keyframes share t-start
|
||||
# keyframe_idxs contains start/end positions (last dimension), checking for unqiue values only for start
|
||||
num_keyframes = torch.unique(keyframe_idxs[:, 0, :, 0]).shape[0]
|
||||
return keyframe_idxs, num_keyframes
|
||||
@ -241,7 +251,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVAddGuide",
|
||||
category="conditioning/video_models",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
@ -322,9 +332,9 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
return factor
|
||||
|
||||
@classmethod
|
||||
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
|
||||
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors, latent_shape=None):
|
||||
time_scale_factor, _, _ = scale_factors
|
||||
_, num_keyframes = get_keyframe_idxs(cond)
|
||||
_, num_keyframes = get_keyframe_idxs(cond, latent_shape)
|
||||
latent_count = latent_length - num_keyframes
|
||||
frame_idx = frame_idx if frame_idx >= 0 else max((latent_count - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||
if guide_length > 1 and frame_idx != 0:
|
||||
@ -436,7 +446,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
|
||||
resolved_frame_idx = frame_idx
|
||||
if frame_idx < 0:
|
||||
_, num_keyframes = get_keyframe_idxs(positive)
|
||||
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
|
||||
resolved_frame_idx = max((latent_length - num_keyframes - 1) * time_scale_factor + 1 + frame_idx, 0)
|
||||
causal_fix = resolved_frame_idx == 0 or num_frames_to_keep == 1
|
||||
|
||||
@ -454,7 +464,7 @@ class LTXVAddGuide(io.ComfyNode):
|
||||
if latent_downscale_factor > 1:
|
||||
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
|
||||
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
|
||||
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors, latent_shape=latent_image.shape)
|
||||
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
|
||||
|
||||
positive, negative, latent_image, noise_mask = cls.append_keyframe(
|
||||
@ -488,7 +498,7 @@ class LTXVCropGuides(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVCropGuides",
|
||||
category="conditioning/video_models",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
@ -506,7 +516,7 @@ class LTXVCropGuides(io.ComfyNode):
|
||||
latent_image = latent["samples"].clone()
|
||||
noise_mask = get_noise_mask(latent)
|
||||
|
||||
_, num_keyframes = get_keyframe_idxs(positive)
|
||||
_, num_keyframes = get_keyframe_idxs(positive, latent_image.shape)
|
||||
if num_keyframes == 0:
|
||||
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask},)
|
||||
|
||||
@ -532,7 +542,7 @@ class LTXVConditioning(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVConditioning",
|
||||
category="conditioning/video_models",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
@ -601,7 +611,7 @@ class LTXVScheduler(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVScheduler",
|
||||
category="sampling/schedulers",
|
||||
category="model/sampling/schedulers",
|
||||
inputs=[
|
||||
io.Int.Input("steps", default=20, min=1, max=10000),
|
||||
io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01),
|
||||
@ -736,7 +746,7 @@ class LTXVConcatAVLatent(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVConcatAVLatent",
|
||||
category="latent/video/ltxv",
|
||||
category="model/latent/video/ltxv",
|
||||
inputs=[
|
||||
io.Latent.Input("video_latent"),
|
||||
io.Latent.Input("audio_latent"),
|
||||
@ -771,7 +781,7 @@ class LTXVSeparateAVLatent(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LTXVSeparateAVLatent",
|
||||
category="latent/video/ltxv",
|
||||
category="model/latent/video/ltxv",
|
||||
description="LTXV Separate AV Latent",
|
||||
inputs=[
|
||||
io.Latent.Input("av_latent"),
|
||||
@ -804,7 +814,7 @@ class LTXVReferenceAudio(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LTXVReferenceAudio",
|
||||
display_name="LTXV Reference Audio (ID-LoRA)",
|
||||
category="conditioning/audio",
|
||||
category="model/conditioning/audio",
|
||||
description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
|
||||
@ -12,7 +12,7 @@ class LTXVAudioVAELoader(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAELoader",
|
||||
display_name="Load LTXV Audio VAE",
|
||||
category="loaders",
|
||||
category="model/loaders",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"ckpt_name",
|
||||
@ -40,7 +40,7 @@ class LTXVAudioVAEEncode(VAEEncodeAudio):
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEEncode",
|
||||
display_name="LTXV Audio VAE Encode",
|
||||
category="latent/audio",
|
||||
category="model/latent/audio",
|
||||
inputs=[
|
||||
io.Audio.Input("audio", tooltip="The audio to be encoded."),
|
||||
io.Vae.Input(
|
||||
@ -63,7 +63,7 @@ class LTXVAudioVAEDecode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LTXVAudioVAEDecode",
|
||||
display_name="LTXV Audio VAE Decode",
|
||||
category="latent/audio",
|
||||
category="model/latent/audio",
|
||||
inputs=[
|
||||
io.Latent.Input("samples", tooltip="The latent to be decoded."),
|
||||
io.Vae.Input(
|
||||
@ -96,7 +96,7 @@ class LTXVEmptyLatentAudio(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="LTXVEmptyLatentAudio",
|
||||
display_name="LTXV Empty Latent Audio",
|
||||
category="latent/audio",
|
||||
category="model/latent/audio",
|
||||
inputs=[
|
||||
io.Int.Input(
|
||||
"frames_number",
|
||||
|
||||
@ -1,32 +1,32 @@
|
||||
from comfy import model_management
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
from typing_extensions import override
|
||||
import math
|
||||
|
||||
class LTXVLatentUpsampler:
|
||||
|
||||
class LTXVLatentUpsampler(IO.ComfyNode):
|
||||
"""
|
||||
Upsamples a video latent by a factor of 2.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"samples": ("LATENT",),
|
||||
"upscale_model": ("LATENT_UPSCALE_MODEL",),
|
||||
"vae": ("VAE",),
|
||||
}
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="LTXVLatentUpsampler",
|
||||
category="model/latent/video",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.LatentUpscaleModel.Input("upscale_model"),
|
||||
IO.Vae.Input("vae"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "upsample_latent"
|
||||
CATEGORY = "latent/video"
|
||||
EXPERIMENTAL = True
|
||||
|
||||
def upsample_latent(
|
||||
self,
|
||||
samples: dict,
|
||||
upscale_model,
|
||||
vae,
|
||||
) -> tuple:
|
||||
@classmethod
|
||||
def execute(cls, samples, upscale_model, vae) -> IO.NodeOutput:
|
||||
"""
|
||||
Upsample the input latent using the provided model.
|
||||
|
||||
@ -34,7 +34,6 @@ class LTXVLatentUpsampler:
|
||||
samples (dict): Input latent samples
|
||||
upscale_model (LatentUpsampler): Loaded upscale model
|
||||
vae: VAE model for normalization
|
||||
auto_tiling (bool): Whether to automatically tile the input for processing
|
||||
|
||||
Returns:
|
||||
tuple: Tuple containing the upsampled latent
|
||||
@ -67,9 +66,16 @@ class LTXVLatentUpsampler:
|
||||
return_dict = samples.copy()
|
||||
return_dict["samples"] = upsampled_latents
|
||||
return_dict.pop("noise_mask", None)
|
||||
return (return_dict,)
|
||||
return IO.NodeOutput(return_dict)
|
||||
|
||||
upsample_latent = execute # TODO: remove
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"LTXVLatentUpsampler": LTXVLatentUpsampler,
|
||||
}
|
||||
class LTXVLatentUpsamplerExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [LTXVLatentUpsampler]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> LTXVLatentUpsamplerExtension:
|
||||
return LTXVLatentUpsamplerExtension()
|
||||
|
||||
@ -81,7 +81,7 @@ class CLIPTextEncodeLumina2(io.ComfyNode):
|
||||
node_id="CLIPTextEncodeLumina2",
|
||||
search_aliases=["lumina prompt"],
|
||||
display_name="CLIP Text Encode for Lumina2",
|
||||
category="conditioning",
|
||||
category="model/conditioning",
|
||||
description="Encodes a system prompt and a user prompt using a CLIP model into an embedding "
|
||||
"that can be used to guide the diffusion model towards generating specific images.",
|
||||
inputs=[
|
||||
|
||||
@ -53,7 +53,7 @@ class LatentCompositeMasked(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LatentCompositeMasked",
|
||||
search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"],
|
||||
category="latent",
|
||||
category="model/latent",
|
||||
inputs=[
|
||||
IO.Latent.Input("destination"),
|
||||
IO.Latent.Input("source"),
|
||||
|
||||
@ -69,7 +69,7 @@ class MathExpressionNode(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="ComfyMathExpression",
|
||||
display_name="Math Expression",
|
||||
category="utils",
|
||||
category="utilities",
|
||||
search_aliases=[
|
||||
"expression", "formula", "calculate", "calculator",
|
||||
"eval", "math",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user