Merge branch 'comfyanonymous:master' into master

commit 483ba1e98b
Author: mengqin
Date: 2025-12-05 07:55:05 -08:00 (committed via GitHub)
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)

55 changed files with 2496 additions and 928 deletions

View File

@@ -66,8 +66,10 @@ if branch is None:
     try:
         ref = repo.lookup_reference('refs/remotes/origin/master')
     except:
-        print("pulling.") # noqa: T201
-        pull(repo)
+        print("fetching.") # noqa: T201
+        for remote in repo.remotes:
+            if remote.name == "origin":
+                remote.fetch()
         ref = repo.lookup_reference('refs/remotes/origin/master')
     repo.checkout(ref)
     branch = repo.lookup_branch('master')
@@ -149,3 +151,4 @@ try:
     shutil.copy(stable_update_script, stable_update_script_to)
 except:
     pass
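For context, a minimal standalone sketch of what the new fetch-based path does with pygit2. The repository path is an assumption for illustration; the remote iteration and checkout mirror the diff above.

import pygit2

# A sketch, assuming a local clone with an "origin" remote.
repo = pygit2.Repository("./ComfyUI")

# Fetch instead of pull: update refs/remotes/origin/* without touching the working tree.
for remote in repo.remotes:
    if remote.name == "origin":
        remote.fetch()

# The remote-tracking branch can then be checked out explicitly.
ref = repo.lookup_reference('refs/remotes/origin/master')
repo.checkout(ref)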

View File

@@ -1,3 +1,2 @@
 # Admins
-* @comfyanonymous
-* @kosinkadink
+* @comfyanonymous @kosinkadink @guill

View File

@@ -81,6 +81,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/)
    - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
    - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
    - [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
+   - [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
 - Audio Models
    - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
    - [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)

View File

@@ -122,6 +122,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force
 upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")

+parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
+manager_group = parser.add_mutually_exclusive_group()
+manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
+manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
+
 vram_group = parser.add_mutually_exclusive_group()
 vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
 vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
@@ -169,6 +175,7 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
 parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
 parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")

 # The default built-in provider hosted under web/
 DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
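A minimal sketch of the flag layout added above, using the same names as the diff. The two UI switches sit in a mutually exclusive argparse group, so passing both on one command line exits with an error; --enable-manager is independent of the group.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-manager", action="store_true")

# Mutually exclusive: argparse rejects --disable-manager-ui together with --enable-manager-legacy-ui.
manager_group = parser.add_mutually_exclusive_group()
manager_group.add_argument("--disable-manager-ui", action="store_true")
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true")

args = parser.parse_args(["--enable-manager", "--disable-manager-ui"])
print(args.enable_manager, args.disable_manager_ui)  # True True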

View File

@@ -40,7 +40,8 @@ class ChromaParams:
     out_dim: int
     hidden_dim: int
     n_layers: int
+    txt_ids_dims: list
+    vec_in_dim: int

View File

@@ -57,6 +57,35 @@ class MLPEmbedder(nn.Module):
     def forward(self, x: Tensor) -> Tensor:
         return self.out_layer(self.silu(self.in_layer(x)))

+class YakMLP(nn.Module):
+    def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+        self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+        self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
+        self.act_fn = nn.SiLU()
+
+    def forward(self, x: Tensor) -> Tensor:
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
+    if yak_mlp:
+        return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
+    if mlp_silu_act:
+        return nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+            SiLUActivation(),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+        )
+    else:
+        return nn.Sequential(
+            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+            nn.GELU(approximate="tanh"),
+            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+        )

 class RMSNorm(torch.nn.Module):
     def __init__(self, dim: int, dtype=None, device=None, operations=None):
@@ -140,7 +169,7 @@ class SiLUActivation(nn.Module):
 class DoubleStreamBlock(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
+    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
         super().__init__()
         mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -156,18 +185,7 @@ class DoubleStreamBlock(nn.Module):
         self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        if mlp_silu_act:
-            self.img_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
-                SiLUActivation(),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
-            )
-        else:
-            self.img_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-                nn.GELU(approximate="tanh"),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-            )
+        self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)

         if self.modulation:
             self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
@@ -177,18 +195,7 @@ class DoubleStreamBlock(nn.Module):
         self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
-        if mlp_silu_act:
-            self.txt_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
-                SiLUActivation(),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
-            )
-        else:
-            self.txt_mlp = nn.Sequential(
-                operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
-                nn.GELU(approximate="tanh"),
-                operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
-            )
+        self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)

         self.flipped_img_txt = flipped_img_txt
@@ -275,6 +282,7 @@ class SingleStreamBlock(nn.Module):
         modulation=True,
         mlp_silu_act=False,
         bias=True,
+        yak_mlp=False,
         dtype=None,
         device=None,
         operations=None
@@ -288,12 +296,17 @@ class SingleStreamBlock(nn.Module):
         self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
         self.mlp_hidden_dim_first = self.mlp_hidden_dim
+        self.yak_mlp = yak_mlp
         if mlp_silu_act:
             self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
             self.mlp_act = SiLUActivation()
         else:
             self.mlp_act = nn.GELU(approximate="tanh")
+        if self.yak_mlp:
+            self.mlp_hidden_dim_first *= 2
+            self.mlp_act = nn.SiLU()
+
         # qkv and mlp_in
         self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
         # proj and mlp_out
@@ -325,6 +338,9 @@ class SingleStreamBlock(nn.Module):
         attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
         del q, k, v
         # compute activation in mlp stream, cat again and run second linear layer
-        mlp = self.mlp_act(mlp)
+        if self.yak_mlp:
+            mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
+        else:
+            mlp = self.mlp_act(mlp)
         output = self.linear2(torch.cat((attn, mlp), 2))
         x += apply_mod(output, mod.gate, None, modulation_dims)
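The "yak" MLP added above is a SwiGLU-style gated feed-forward: the hidden projection is split into a gate and a value, the gate goes through SiLU, and the product is passed to the down projection. A minimal sketch with plain torch.nn layers (the dimensions are illustrative, not the model's):

import torch
import torch.nn as nn

class GatedMLPSketch(nn.Module):
    # Illustrative stand-in for YakMLP: down(SiLU(gate(x)) * up(x)).
    def __init__(self, hidden_size=128, intermediate_size=512):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size)
        self.up_proj = nn.Linear(hidden_size, intermediate_size)
        self.down_proj = nn.Linear(intermediate_size, hidden_size)
        self.act_fn = nn.SiLU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

# In the fused single-stream path the same idea is applied to one projection of width
# 2 * mlp_hidden_dim: SiLU of the second half gates (multiplies) the first half.
x = torch.randn(2, 16, 128)
print(GatedMLPSketch()(x).shape)  # torch.Size([2, 16, 128])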

View File

@@ -15,7 +15,8 @@ from .layers import (
     MLPEmbedder,
     SingleStreamBlock,
     timestep_embedding,
-    Modulation
+    Modulation,
+    RMSNorm
 )

 @dataclass
@@ -34,11 +35,14 @@ class FluxParams:
     patch_size: int
     qkv_bias: bool
     guidance_embed: bool
+    txt_ids_dims: list
     global_modulation: bool = False
     mlp_silu_act: bool = False
     ops_bias: bool = True
     default_ref_method: str = "offset"
     ref_index_scale: float = 1.0
+    yak_mlp: bool = False
+    txt_norm: bool = False

 class Flux(nn.Module):
@@ -76,6 +80,11 @@ class Flux(nn.Module):
         )
         self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
+        if params.txt_norm:
+            self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+        else:
+            self.txt_norm = None
+
         self.double_blocks = nn.ModuleList(
             [
                 DoubleStreamBlock(
@@ -86,6 +95,7 @@ class Flux(nn.Module):
                     modulation=params.global_modulation is False,
                     mlp_silu_act=params.mlp_silu_act,
                     proj_bias=params.ops_bias,
+                    yak_mlp=params.yak_mlp,
                     dtype=dtype, device=device, operations=operations
                 )
                 for _ in range(params.depth)
@@ -94,7 +104,7 @@ class Flux(nn.Module):
         self.single_blocks = nn.ModuleList(
             [
-                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
                 for _ in range(params.depth_single_blocks)
             ]
         )
@@ -150,6 +160,8 @@ class Flux(nn.Module):
                 y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
             vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])

+        if self.txt_norm is not None:
+            txt = self.txt_norm(txt)
         txt = self.txt_in(txt)

         vec_orig = vec
@@ -332,8 +344,9 @@ class Flux(nn.Module):
         txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
-        if len(self.params.axes_dim) == 4: # Flux 2
-            txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
+        if len(self.params.txt_ids_dims) > 0:
+            for i in self.params.txt_ids_dims:
+                txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
         out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
         out = out[:, :img_tokens]
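The txt_ids_dims parameter generalizes which axes of the text position ids receive a 0..L-1 ramp, replacing the hard-coded axes_dim == 4 check. A minimal standalone sketch of just that construction (function name and shapes are illustrative; the [] / [3] / [1, 2] values come from the model_detection changes later in this commit):

import torch

def make_txt_ids(bs, seq_len, num_axes, txt_ids_dims):
    # [] keeps all-zero ids (original Flux behaviour), [3] matches the Flux 2 branch,
    # [1, 2] is used for the Ovis-style checkpoints detected below.
    txt_ids = torch.zeros((bs, seq_len, num_axes), dtype=torch.float32)
    for i in txt_ids_dims:
        txt_ids[:, :, i] = torch.linspace(0, seq_len - 1, steps=seq_len)
    return txt_ids

print(make_txt_ids(1, 4, 4, [3])[0])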

View File

@@ -1,7 +1,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
+from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
 import model_management, model_patcher

 class SRResidualCausalBlock3D(nn.Module):

View File

@@ -1,42 +1,12 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
 import comfy.ops
 import comfy.ldm.models.autoencoder
 import comfy.model_management

 ops = comfy.ops.disable_weight_init

-class NoPadConv3d(nn.Module):
-    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
-        super().__init__()
-        self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
-
-    def forward(self, x):
-        return self.conv(x)
-
-def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
-    x = xl[0]
-    xl.clear()
-    if conv_carry_out is not None:
-        to_push = x[:, :, -2:, :, :].clone()
-        conv_carry_out.append(to_push)
-
-    if isinstance(op, NoPadConv3d):
-        if conv_carry_in is None:
-            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
-        else:
-            carry_len = conv_carry_in[0].shape[2]
-            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
-            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
-
-    out = op(x)
-    return out
-
 class RMS_norm(nn.Module):
     def __init__(self, dim):
@@ -49,7 +19,7 @@ class RMS_norm(nn.Module):
         return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)

 class DnSmpl(nn.Module):
-    def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
+    def __init__(self, ic, oc, tds, refiner_vae, op):
         super().__init__()
         fct = 2 * 2 * 2 if tds else 1 * 2 * 2
         assert oc % fct == 0
@@ -109,7 +79,7 @@ class DnSmpl(nn.Module):

 class UpSmpl(nn.Module):
-    def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
+    def __init__(self, ic, oc, tus, refiner_vae, op):
         super().__init__()
         fct = 2 * 2 * 2 if tus else 1 * 2 * 2
         self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
@@ -163,23 +133,6 @@ class UpSmpl(nn.Module):
         return h + x

-class HunyuanRefinerResnetBlock(ResnetBlock):
-    def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
-        super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
-
-    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
-        h = x
-        h = [ self.swish(self.norm1(x)) ]
-        h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
-        h = [ self.dropout(self.swish(self.norm2(h))) ]
-        h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
-        if self.in_channels != self.out_channels:
-            x = self.nin_shortcut(x)
-
-        return x+h
-
 class Encoder(nn.Module):
     def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
                  ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
@@ -191,7 +144,7 @@ class Encoder(nn.Module):
         self.refiner_vae = refiner_vae

         if self.refiner_vae:
-            conv_op = NoPadConv3d
+            conv_op = CarriedConv3d
             norm_op = RMS_norm
         else:
             conv_op = ops.Conv3d
@@ -206,8 +159,9 @@ class Encoder(nn.Module):
         for i, tgt in enumerate(block_out_channels):
             stage = nn.Module()
-            stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
+            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
+                                                     temb_channels=0,
                                                      conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks)])
             ch = tgt
@@ -218,9 +172,9 @@ class Encoder(nn.Module):
             self.down.append(stage)

         self.mid = nn.Module()
-        self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
         self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-        self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)

         self.norm_out = norm_op(ch)
         self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@@ -246,22 +200,20 @@ class Encoder(nn.Module):
             conv_carry_out = []
             if i == len(x) - 1:
                 conv_carry_out = None

             x1 = [ x1 ]
             x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
             for stage in self.down:
                 for blk in stage.block:
-                    x1 = blk(x1, conv_carry_in, conv_carry_out)
+                    x1 = blk(x1, None, conv_carry_in, conv_carry_out)
                 if hasattr(stage, 'downsample'):
                     x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)

             out.append(x1)
             conv_carry_in = conv_carry_out

-        if len(out) > 1:
-            out = torch.cat(out, dim=2)
-        else:
-            out = out[0]
+        out = torch_cat_if_needed(out, dim=2)

         x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
         del out
@@ -288,7 +240,7 @@ class Decoder(nn.Module):
         self.refiner_vae = refiner_vae

         if self.refiner_vae:
-            conv_op = NoPadConv3d
+            conv_op = CarriedConv3d
             norm_op = RMS_norm
         else:
             conv_op = ops.Conv3d
@@ -298,9 +250,9 @@ class Decoder(nn.Module):
         self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)

         self.mid = nn.Module()
-        self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
         self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
-        self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+        self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)

         self.up = nn.ModuleList()
         depth = (ffactor_spatial >> 1).bit_length()
@@ -308,8 +260,9 @@ class Decoder(nn.Module):
         for i, tgt in enumerate(block_out_channels):
             stage = nn.Module()
-            stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
+            stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
                                                      out_channels=tgt,
+                                                     temb_channels=0,
                                                      conv_op=conv_op, norm_op=norm_op)
                                          for j in range(num_res_blocks + 1)])
             ch = tgt
@@ -340,7 +293,7 @@ class Decoder(nn.Module):
                 conv_carry_out = None
             for stage in self.up:
                 for blk in stage.block:
-                    x1 = blk(x1, conv_carry_in, conv_carry_out)
+                    x1 = blk(x1, None, conv_carry_in, conv_carry_out)
                 if hasattr(stage, 'upsample'):
                     x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
@@ -350,10 +303,7 @@ class Decoder(nn.Module):
             conv_carry_in = conv_carry_out
             del x

-        if len(out) > 1:
-            out = torch.cat(out, dim=2)
-        else:
-            out = out[0]
+        out = torch_cat_if_needed(out, dim=2)

         if not self.refiner_vae:
             if z.shape[-3] == 1:

View File

@@ -0,0 +1,113 @@
+import torch
+from torch import nn
+
+from .model import JointTransformerBlock
+
+
+class ZImageControlTransformerBlock(JointTransformerBlock):
+    def __init__(
+        self,
+        layer_id: int,
+        dim: int,
+        n_heads: int,
+        n_kv_heads: int,
+        multiple_of: int,
+        ffn_dim_multiplier: float,
+        norm_eps: float,
+        qk_norm: bool,
+        modulation=True,
+        block_id=0,
+        operation_settings=None,
+    ):
+        super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
+        self.block_id = block_id
+        if block_id == 0:
+            self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+        self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+    def forward(self, c, x, **kwargs):
+        if self.block_id == 0:
+            c = self.before_proj(c) + x
+        c = super().forward(c, **kwargs)
+        c_skip = self.after_proj(c)
+        return c_skip, c
+
+
+class ZImage_Control(torch.nn.Module):
+    def __init__(
+        self,
+        dim: int = 3840,
+        n_heads: int = 30,
+        n_kv_heads: int = 30,
+        multiple_of: int = 256,
+        ffn_dim_multiplier: float = (8.0 / 3.0),
+        norm_eps: float = 1e-5,
+        qk_norm: bool = True,
+        dtype=None,
+        device=None,
+        operations=None,
+        **kwargs
+    ):
+        super().__init__()
+        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+        self.additional_in_dim = 0
+        self.control_in_dim = 16
+        n_refiner_layers = 2
+        self.n_control_layers = 6
+
+        self.control_layers = nn.ModuleList(
+            [
+                ZImageControlTransformerBlock(
+                    i,
+                    dim,
+                    n_heads,
+                    n_kv_heads,
+                    multiple_of,
+                    ffn_dim_multiplier,
+                    norm_eps,
+                    qk_norm,
+                    block_id=i,
+                    operation_settings=operation_settings,
+                )
+                for i in range(self.n_control_layers)
+            ]
+        )
+
+        all_x_embedder = {}
+        patch_size = 2
+        f_patch_size = 1
+        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
+        all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
+
+        self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
+
+        self.control_noise_refiner = nn.ModuleList(
+            [
+                JointTransformerBlock(
+                    layer_id,
+                    dim,
+                    n_heads,
+                    n_kv_heads,
+                    multiple_of,
+                    ffn_dim_multiplier,
+                    norm_eps,
+                    qk_norm,
+                    modulation=True,
+                    z_image_modulation=True,
+                    operation_settings=operation_settings,
+                )
+                for layer_id in range(n_refiner_layers)
+            ]
+        )
+
+    def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
+        patch_size = 2
+        f_patch_size = 1
+        pH = pW = patch_size
+        B, C, H, W = control_context.shape
+        control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
+
+        x_attn_mask = None
+        for layer in self.control_noise_refiner:
+            control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
+
+        return control_context
+
+    def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
+        return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
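The consumer of these control blocks is not part of this new file. A rough sketch of the usual pattern for a block that returns (c_skip, c), with made-up names and dummy blocks (this is not the actual NextDiT wiring): the carried control state c feeds the next control layer while c_skip is added back into the main stream.

import torch

def apply_control_sketch(control_blocks, x, c):
    for block in control_blocks:
        c_skip, c = block(c, x)   # block 0 mixes x into c via before_proj
        x = x + c_skip            # residual injection into the main stream
    return x

# Dummy stand-ins for ZImageControlTransformerBlock instances.
dummy_blocks = [lambda c, x: (0.1 * c, c + 1.0) for _ in range(3)]
x = torch.zeros(2, 8, 16)
c = torch.ones(2, 8, 16)
print(apply_control_sketch(dummy_blocks, x, c).shape)  # torch.Size([2, 8, 16])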

View File

@@ -22,6 +22,10 @@ def modulate(x, scale):
 #             Core NextDiT Model              #
 #############################################################################

+def clamp_fp16(x):
+    if x.dtype == torch.float16:
+        return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+    return x

 class JointAttention(nn.Module):
     """Multi-head attention module."""
@@ -169,7 +173,7 @@ class FeedForward(nn.Module):
     # @torch.compile
     def _forward_silu_gating(self, x1, x3):
-        return F.silu(x1) * x3
+        return clamp_fp16(F.silu(x1) * x3)

     def forward(self, x):
         return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
@@ -273,27 +277,27 @@ class JointTransformerBlock(nn.Module):
             scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)

             x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
-                self.attention(
+                clamp_fp16(self.attention(
                     modulate(self.attention_norm1(x), scale_msa),
                     x_mask,
                     freqs_cis,
                     transformer_options=transformer_options,
-                )
+                ))
             )
             x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
-                self.feed_forward(
+                clamp_fp16(self.feed_forward(
                     modulate(self.ffn_norm1(x), scale_mlp),
-                )
+                ))
             )
         else:
             assert adaln_input is None
             x = x + self.attention_norm2(
-                self.attention(
+                clamp_fp16(self.attention(
                     self.attention_norm1(x),
                     x_mask,
                     freqs_cis,
                     transformer_options=transformer_options,
-                )
+                ))
             )
             x = x + self.ffn_norm2(
                 self.feed_forward(
@@ -564,7 +568,7 @@ class NextDiT(nn.Module):
         ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)

     # def forward(self, x, t, cap_feats, cap_mask):
-    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
+    def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
         t = 1.0 - timesteps
         cap_feats = context
         cap_mask = attention_mask
@@ -581,16 +585,24 @@ class NextDiT(nn.Module):
         cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute

+        patches = transformer_options.get("patches", {})
         transformer_options = kwargs.get("transformer_options", {})
         x_is_tensor = isinstance(x, torch.Tensor)
-        x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
-        freqs_cis = freqs_cis.to(x.device)
+        img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
+        freqs_cis = freqs_cis.to(img.device)

-        for layer in self.layers:
-            x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+        for i, layer in enumerate(self.layers):
+            img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+            if "double_block" in patches:
+                for p in patches["double_block"]:
+                    out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+                    if "img" in out:
+                        img[:, cap_size[0]:] = out["img"]
+                    if "txt" in out:
+                        img[:, :cap_size[0]] = out["txt"]

-        x = self.final_layer(x, adaln_input)
-        x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
+        img = self.final_layer(img, adaln_input)
+        img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]

-        return -x
+        return -img
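A minimal demonstration of the clamp_fp16 helper added above: values that overflow float16 (or turn into NaN) are mapped back into the finite fp16 range, while other dtypes pass through untouched. The sample tensor is made up.

import torch

def clamp_fp16(x):
    if x.dtype == torch.float16:
        return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
    return x

x = torch.tensor([1.0, float("inf"), float("-inf"), float("nan")], dtype=torch.float16)
print(clamp_fp16(x))  # tensor([ 1.0000e+00,  6.5504e+04, -6.5504e+04,  0.0000e+00], dtype=torch.float16)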

View File

@@ -529,6 +529,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
 @wrap_attn
 def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    exception_fallback = False
     if skip_reshape:
         b, _, _, dim_head = q.shape
         tensor_layout = "HND"
@@ -553,6 +554,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
         out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
     except Exception as e:
         logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
+        exception_fallback = True
+    if exception_fallback:
         if tensor_layout == "NHD":
             q, k, v = map(
                 lambda t: t.transpose(1, 2),
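The fallback path now runs after the except block has exited rather than inside it, presumably so the slower PyTorch attention does not execute while the failed call's exception frame (and whatever it references) is still live. A minimal sketch of the flag pattern; the function names here are illustrative, not the module's API:

import logging

def attention_with_fallback(fast_impl, slow_impl, *args):
    exception_fallback = False
    try:
        return fast_impl(*args)
    except Exception as e:
        logging.error("fast attention failed: %s, falling back", e)
        exception_fallback = True
    # Runs after the except block has finished, not inside it.
    if exception_fallback:
        return slow_impl(*args)

print(attention_with_fallback(lambda x: 1 / x, lambda x: 0.0, 0))  # 0.0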

View File

@@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae():
     import xformers
     import xformers.ops

+def torch_cat_if_needed(xl, dim):
+    if len(xl) > 1:
+        return torch.cat(xl, dim)
+    else:
+        return xl[0]
+
 def get_timestep_embedding(timesteps, embedding_dim):
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models:
@@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32):
     return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

+class CarriedConv3d(nn.Module):
+    def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
+        super().__init__()
+        self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+    def forward(self, x):
+        return self.conv(x)
+
+def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
+    x = xl[0]
+    xl.clear()
+    if isinstance(op, CarriedConv3d):
+        if conv_carry_in is None:
+            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
+        else:
+            carry_len = conv_carry_in[0].shape[2]
+            x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
+            x = torch.cat([conv_carry_in.pop(0), x], dim=2)
+
+        if conv_carry_out is not None:
+            to_push = x[:, :, -2:, :, :].clone()
+            conv_carry_out.append(to_push)
+
+    out = op(x)
+    return out
+
 class VideoConv3d(nn.Module):
     def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
         super().__init__()
@@ -89,29 +126,24 @@ class Upsample(nn.Module):
                                         stride=1,
                                         padding=1)

-    def forward(self, x):
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
         scale_factor = self.scale_factor
         if isinstance(scale_factor, (int, float)):
             scale_factor = (scale_factor,) * (x.ndim - 2)

         if x.ndim == 5 and scale_factor[0] > 1.0:
-            t = x.shape[2]
-            if t > 1:
-                a, b = x.split((1, t - 1), dim=2)
-                del x
-                b = interpolate_up(b, scale_factor)
-            else:
-                a = x
-
-            a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
-            if t > 1:
-                x = torch.cat((a, b), dim=2)
-            else:
-                x = a
+            results = []
+            if conv_carry_in is None:
+                first = x[:, :, :1, :, :]
+                results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
+                x = x[:, :, 1:, :, :]
+            if x.shape[2] > 0:
+                results.append(interpolate_up(x, scale_factor))
+            x = torch_cat_if_needed(results, dim=2)
         else:
             x = interpolate_up(x, scale_factor)
         if self.with_conv:
-            x = self.conv(x)
+            x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
         return x
@@ -127,12 +159,15 @@ class Downsample(nn.Module):
                                         stride=stride,
                                         padding=0)

-    def forward(self, x):
+    def forward(self, x, conv_carry_in=None, conv_carry_out=None):
         if self.with_conv:
-            if x.ndim == 4:
+            if isinstance(self.conv, CarriedConv3d):
+                x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
+            elif x.ndim == 4:
                 pad = (0, 1, 0, 1)
                 mode = "constant"
                 x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
+                x = self.conv(x)
             elif x.ndim == 5:
                 pad = (1, 1, 1, 1, 2, 0)
                 mode = "replicate"
@@ -183,23 +218,23 @@ class ResnetBlock(nn.Module):
                                             stride=1,
                                             padding=0)

-    def forward(self, x, temb=None):
+    def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
         h = x
         h = self.norm1(h)
-        h = self.swish(h)
-        h = self.conv1(h)
+        h = [ self.swish(h) ]
+        h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)

         if temb is not None:
             h = h + self.temb_proj(self.swish(temb))[:,:,None,None]

         h = self.norm2(h)
         h = self.swish(h)
-        h = self.dropout(h)
-        h = self.conv2(h)
+        h = [ self.dropout(h) ]
+        h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)

         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
-                x = self.conv_shortcut(x)
+                x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
             else:
                 x = self.nin_shortcut(x)
@@ -279,6 +314,7 @@ def pytorch_attention(q, k, v):
     orig_shape = q.shape
     B = orig_shape[0]
     C = orig_shape[1]
+    oom_fallback = False
     q, k, v = map(
         lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
         (q, k, v),
@@ -289,6 +325,8 @@ def pytorch_attention(q, k, v):
         out = out.transpose(2, 3).reshape(orig_shape)
     except model_management.OOM_EXCEPTION:
         logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
+        oom_fallback = True
+    if oom_fallback:
         out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
     return out
@@ -517,8 +555,13 @@ class Encoder(nn.Module):
         self.num_res_blocks = num_res_blocks
         self.resolution = resolution
         self.in_channels = in_channels
+        self.carried = False
         if conv3d:
-            conv_op = VideoConv3d
+            if not attn_resolutions:
+                conv_op = CarriedConv3d
+                self.carried = True
+            else:
+                conv_op = VideoConv3d
             mid_attn_conv_op = ops.Conv3d
         else:
@@ -532,6 +575,7 @@ class Encoder(nn.Module):
                                        stride=1,
                                        padding=1)

+        self.time_compress = 1
         curr_res = resolution
         in_ch_mult = (1,)+tuple(ch_mult)
         self.in_ch_mult = in_ch_mult
@@ -558,10 +602,15 @@ class Encoder(nn.Module):
                 if time_compress is not None:
                     if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
                         stride = (1, 2, 2)
+                    else:
+                        self.time_compress *= 2
                 down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
                 curr_res = curr_res // 2
             self.down.append(down)

+        if time_compress is not None:
+            self.time_compress = time_compress
+
         # middle
         self.mid = nn.Module()
         self.mid.block_1 = ResnetBlock(in_channels=block_in,
@@ -587,15 +636,42 @@ class Encoder(nn.Module):
     def forward(self, x):
         # timestep embedding
         temb = None

-        # downsampling
-        h = self.conv_in(x)
-        for i_level in range(self.num_resolutions):
-            for i_block in range(self.num_res_blocks):
-                h = self.down[i_level].block[i_block](h, temb)
-                if len(self.down[i_level].attn) > 0:
-                    h = self.down[i_level].attn[i_block](h)
-            if i_level != self.num_resolutions-1:
-                h = self.down[i_level].downsample(h)
+        if self.carried:
+            xl = [x[:, :, :1, :, :]]
+            if x.shape[2] > self.time_compress:
+                tc = self.time_compress
+                xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
+            x = xl
+        else:
+            x = [x]
+
+        out = []
+        conv_carry_in = None
+        for i, x1 in enumerate(x):
+            conv_carry_out = []
+            if i == len(x) - 1:
+                conv_carry_out = None
+
+            # downsampling
+            x1 = [ x1 ]
+            h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
+            for i_level in range(self.num_resolutions):
+                for i_block in range(self.num_res_blocks):
+                    h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
+                    if len(self.down[i_level].attn) > 0:
+                        assert i == 0 #carried should not happen if attn exists
+                        h1 = self.down[i_level].attn[i_block](h1)
+                if i_level != self.num_resolutions-1:
+                    h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
+
+            out.append(h1)
+            conv_carry_in = conv_carry_out
+
+        h = torch_cat_if_needed(out, dim=2)
+        del out

         # middle
         h = self.mid.block_1(h, temb)
@@ -604,15 +680,15 @@ class Encoder(nn.Module):
         # end
         h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h)
+        h = [ nonlinearity(h) ]
+        h = conv_carry_causal_3d(h, self.conv_out)
         return h

 class Decoder(nn.Module):
     def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                  attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
-                 resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+                 resolution, z_channels, tanh_out=False, use_linear_attn=False,
                  conv_out_op=ops.Conv2d,
                  resnet_op=ResnetBlock,
                  attn_op=AttnBlock,
@@ -626,12 +702,18 @@ class Decoder(nn.Module):
         self.num_res_blocks = num_res_blocks
         self.resolution = resolution
         self.in_channels = in_channels
-        self.give_pre_end = give_pre_end
         self.tanh_out = tanh_out
+        self.carried = False
         if conv3d:
-            conv_op = VideoConv3d
-            conv_out_op = VideoConv3d
+            if not attn_resolutions and resnet_op == ResnetBlock:
+                conv_op = CarriedConv3d
+                conv_out_op = CarriedConv3d
+                self.carried = True
+            else:
+                conv_op = VideoConv3d
+                conv_out_op = VideoConv3d
             mid_attn_conv_op = ops.Conv3d
         else:
             conv_op = ops.Conv2d
@@ -706,29 +788,43 @@ class Decoder(nn.Module):
         temb = None

         # z to block_in
-        h = self.conv_in(z)
+        h = conv_carry_causal_3d([z], self.conv_in)

         # middle
         h = self.mid.block_1(h, temb, **kwargs)
         h = self.mid.attn_1(h, **kwargs)
         h = self.mid.block_2(h, temb, **kwargs)

+        if self.carried:
+            h = torch.split(h, 2, dim=2)
+        else:
+            h = [ h ]
+
+        out = []
+        conv_carry_in = None
+
         # upsampling
-        for i_level in reversed(range(self.num_resolutions)):
-            for i_block in range(self.num_res_blocks+1):
-                h = self.up[i_level].block[i_block](h, temb, **kwargs)
-                if len(self.up[i_level].attn) > 0:
-                    h = self.up[i_level].attn[i_block](h, **kwargs)
-            if i_level != 0:
-                h = self.up[i_level].upsample(h)
-
-        # end
-        if self.give_pre_end:
-            return h
-
-        h = self.norm_out(h)
-        h = nonlinearity(h)
-        h = self.conv_out(h, **kwargs)
-        if self.tanh_out:
-            h = torch.tanh(h)
-        return h
+        for i, h1 in enumerate(h):
+            conv_carry_out = []
+            if i == len(h) - 1:
+                conv_carry_out = None
+            for i_level in reversed(range(self.num_resolutions)):
+                for i_block in range(self.num_res_blocks+1):
+                    h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
+                    if len(self.up[i_level].attn) > 0:
+                        assert i == 0 #carried should not happen if attn exists
+                        h1 = self.up[i_level].attn[i_block](h1, **kwargs)
+                if i_level != 0:
+                    h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
+
+            h1 = self.norm_out(h1)
+            h1 = [ nonlinearity(h1) ]
+            h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
+            if self.tanh_out:
+                h1 = torch.tanh(h1)
+            out.append(h1)
+            conv_carry_in = conv_carry_out
+
+        out = torch_cat_if_needed(out, dim=2)
+        return out
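The "carried" machinery above lets the VAE process a video in temporal chunks: each causal, replicate-padded 3D conv keeps the last two frames of the previous chunk so the result matches a single pass over the whole clip while peak memory stays bounded by the chunk size. A simplified single-conv illustration of that equivalence (one plain Conv3d, made-up shapes; not the ComfyUI classes themselves):

import torch
import torch.nn.functional as F

conv = torch.nn.Conv3d(1, 1, kernel_size=3, padding=0)

def causal_conv_full(x):
    # Causal in time (pad 2 in front), replicate-padded spatially.
    return conv(F.pad(x, (1, 1, 1, 1, 2, 0), mode="replicate"))

def causal_conv_chunked(x, chunk=4):
    outs, carry = [], None
    for start in range(0, x.shape[2], chunk):
        piece = x[:, :, start:start + chunk]
        if carry is None:
            padded = F.pad(piece, (1, 1, 1, 1, 2, 0), mode="replicate")
        else:
            padded = torch.cat([carry, F.pad(piece, (1, 1, 1, 1, 0, 0), mode="replicate")], dim=2)
        carry = padded[:, :, -2:].clone()   # carry the last two (already padded) frames forward
        outs.append(conv(padded))
    return torch.cat(outs, dim=2)

x = torch.randn(1, 1, 8, 6, 6)
print(torch.allclose(causal_conv_full(x), causal_conv_chunked(x), atol=1e-5))  # True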

View File

@@ -208,12 +208,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["theta"] = 2000
             dit_config["out_channels"] = 128
             dit_config["global_modulation"] = True
-            dit_config["vec_in_dim"] = None
             dit_config["mlp_silu_act"] = True
             dit_config["qkv_bias"] = False
             dit_config["ops_bias"] = False
             dit_config["default_ref_method"] = "index"
             dit_config["ref_index_scale"] = 10.0
+            dit_config["txt_ids_dims"] = [3]
             patch_size = 1
         else:
             dit_config["image_model"] = "flux"
@@ -223,6 +223,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["theta"] = 10000
             dit_config["out_channels"] = 16
             dit_config["qkv_bias"] = True
+            dit_config["txt_ids_dims"] = []
             patch_size = 2

         dit_config["in_channels"] = 16
@@ -245,6 +246,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
         if vec_in_key in state_dict_keys:
             dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
+        else:
+            dit_config["vec_in_dim"] = None

         dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
         dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
@@ -270,6 +273,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
             dit_config["nerf_embedder_dtype"] = torch.float32
         else:
             dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+            dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
+            dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+            if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
+                dit_config["txt_ids_dims"] = [1, 2]
+
         return dit_config

     if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
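The new flags follow the detection pattern used throughout this function: architecture variants are inferred purely from which weight names exist in the checkpoint. A minimal self-contained sketch of that pattern with the same key names as the diff (the state dict contents here are fabricated):

state_dict_keys = {
    "model.diffusion_model.double_blocks.0.img_mlp.gate_proj.weight",
    "model.diffusion_model.txt_norm.scale",
}
key_prefix = "model.diffusion_model."

dit_config = {}
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
if dit_config["yak_mlp"] and dit_config["txt_norm"]:  # Ovis-style checkpoint
    dit_config["txt_ids_dims"] = [1, 2]
print(dit_config)  # {'yak_mlp': True, 'txt_norm': True, 'txt_ids_dims': [1, 2]}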

View File

@@ -699,12 +699,12 @@ class ModelPatcher:
         offloaded = []
         offload_buffer = 0
         loading.sort(reverse=True)
-        for x in loading:
+        for i, x in enumerate(loading):
             module_offload_mem, module_mem, n, m, params = x

             lowvram_weight = False
-            potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
+            potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]]))
             lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory

             weight_key = "{}.weight".format(n)
@@ -876,14 +876,18 @@ class ModelPatcher:
         patch_counter = 0
         unload_list = self._load_list()
         unload_list.sort()

         offload_buffer = self.model.model_offload_buffer_memory
+        if len(unload_list) > 0:
+            NS = comfy.model_management.NUM_STREAMS
+            offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
+
         for unload in unload_list:
             if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
                 break
             module_offload_mem, module_mem, n, m, params = unload

-            potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
+            potential_offload = module_offload_mem + sum(offload_weight_factor)
             lowvram_possible = hasattr(m, "comfy_cast_weights")

             if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@@ -935,6 +939,8 @@ class ModelPatcher:
                 m.comfy_patched_weights = False

             memory_freed += module_mem
             offload_buffer = max(offload_buffer, potential_offload)
+            offload_weight_factor.append(module_mem)
+            offload_weight_factor.pop(0)
             logging.debug("freed {}".format(n))

             for param in params:
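The offload-buffer estimate changes from a fixed multiplier of the current module to a sliding window over the next NUM_STREAMS modules in the sorted list, which is tighter when module sizes vary. A toy comparison of the two formulas (all numbers made up; this is not the ModelPatcher API):

NUM_STREAMS = 2
loading = [(40, 40, "a"), (30, 30, "b"), (20, 20, "c"), (10, 10, "d")]  # (offload_mem, mem, name)

offload_buffer = 0
for i, (module_offload_mem, module_mem, name) in enumerate(loading):
    old_estimate = module_offload_mem * (NUM_STREAMS + 1)
    new_estimate = module_offload_mem + sum(x[1] for x in loading[i + 1:i + 1 + NUM_STREAMS])
    offload_buffer = max(offload_buffer, new_estimate)
    print(name, old_estimate, new_estimate)  # e.g. a 120 90, b 90 60, ...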

View File

@@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if s.bias is not None:
         bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
-        if bias_has_function:
-            with wf_context:
-                for f in s.bias_function:
-                    bias = f(bias)
-
-    if weight_has_function or weight.dtype != dtype:
-        with wf_context:
-            weight = weight.to(dtype=dtype)
-            if isinstance(weight, QuantizedTensor):
-                weight = weight.dequantize()
-            for f in s.weight_function:
-                weight = f(weight)
-
-    comfy.model_management.sync_stream(device, offload_stream)
+
+    comfy.model_management.sync_stream(device, offload_stream)
+
+    bias_a = bias
+    weight_a = weight
+
+    if s.bias is not None:
+        for f in s.bias_function:
+            bias = f(bias)
+
+    if weight_has_function or weight.dtype != dtype:
+        weight = weight.to(dtype=dtype)
+        if isinstance(weight, QuantizedTensor):
+            weight = weight.dequantize()
+        for f in s.weight_function:
+            weight = f(weight)
+
     if offloadable:
-        return weight, bias, offload_stream
+        return weight, bias, (offload_stream, weight_a, bias_a)
     else:
         #Legacy function signature
         return weight, bias
@@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of

 def uncast_bias_weight(s, weight, bias, offload_stream):
     if offload_stream is None:
         return
-    if weight is not None:
-        device = weight.device
-    else:
-        if bias is None:
-            return
-        device = bias.device
-    offload_stream.wait_stream(comfy.model_management.current_stream(device))
+    os, weight_a, bias_a = offload_stream
+    if os is None:
+        return
+    if weight_a is not None:
+        device = weight_a.device
+    else:
+        if bias_a is None:
+            return
+        device = bias_a.device
+    os.wait_stream(comfy.model_management.current_stream(device))

 class CastWeightBiasOp:

View File

@ -53,6 +53,7 @@ import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image import comfy.text_encoders.z_image
import comfy.text_encoders.ovis
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
@ -192,6 +193,7 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model() self.load_model()
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
all_hooks.reset() all_hooks.reset()
self.patcher.patch_hooks(None) self.patcher.patch_hooks(None)
if show_pbar: if show_pbar:
@ -239,6 +241,7 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model() self.load_model()
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens) o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2] cond, pooled = o[:2]
if return_dict: if return_dict:
@ -468,7 +471,7 @@ class VAE:
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd: elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True ddconfig["conv3d"] = True
@ -480,8 +483,10 @@ class VAE:
self.latent_dim = 3 self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype) # This is likely to significantly over-estimate with a single image or low frame counts, as the
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype) # implementation is able to completely skip caching. Rework if used as an image-only VAE.
self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.unpatcher3d.wavelets" in sd: elif "decoder.unpatcher3d.wavelets" in sd:
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8) self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
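As a rough sanity check on the revised decode estimate in the hunk above, here is the arithmetic for a hypothetical bf16 video latent of shape [1, 16, 13, 64, 64] (13 latent frames at 64x64 latent resolution); the shape is an assumed example, not taken from the diff.

# Illustrative only; mirrors the new memory_used_decode lambda with assumed inputs.
shape = (1, 16, 13, 64, 64)                # [batch, channels, latent frames, latent h, latent w]
dtype_size = 2                             # bytes per element for bf16
frames = min(8, ((shape[2] - 1) * 4) + 1)  # decoded-frame term, capped at 8 -> 8 here
estimate = 2800 * frames * shape[3] * shape[4] * (8 * 8) * dtype_size
print(estimate / 1e9)                      # ~11.7 GB reserved before decoding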
@ -956,6 +961,7 @@ class CLIPType(Enum):
QWEN_IMAGE = 18 QWEN_IMAGE = 18
HUNYUAN_IMAGE = 19 HUNYUAN_IMAGE = 19
HUNYUAN_VIDEO_15 = 20 HUNYUAN_VIDEO_15 = 20
OVIS = 21
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@ -987,6 +993,7 @@ class TEModel(Enum):
MISTRAL3_24B = 14 MISTRAL3_24B = 14
MISTRAL3_24B_PRUNED_FLUX2 = 15 MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16 QWEN3_4B = 16
QWEN3_2B = 17
def detect_te_model(sd): def detect_te_model(sd):
@ -1020,9 +1027,12 @@ def detect_te_model(sd):
if weight.shape[0] == 512: if weight.shape[0] == 512:
return TEModel.QWEN25_7B return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd: if "model.layers.0.post_attention_layernorm.weight" in sd:
if 'model.layers.0.self_attn.q_norm.weight' in sd:
return TEModel.QWEN3_4B
weight = sd['model.layers.0.post_attention_layernorm.weight'] weight = sd['model.layers.0.post_attention_layernorm.weight']
if 'model.layers.0.self_attn.q_norm.weight' in sd:
if weight.shape[0] == 2560:
return TEModel.QWEN3_4B
elif weight.shape[0] == 2048:
return TEModel.QWEN3_2B
if weight.shape[0] == 5120: if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd: if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B return TEModel.MISTRAL3_24B
@ -1150,6 +1160,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.QWEN3_4B: elif te_model == TEModel.QWEN3_4B:
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data)) clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
elif te_model == TEModel.QWEN3_2B:
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
else: else:
# clip_l # clip_l
if clip_type == CLIPType.SD3: if clip_type == CLIPType.SD3:

View File

@ -147,6 +147,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_norm_hidden_state = layer_norm_hidden_state self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled self.return_projected_pooled = return_projected_pooled
self.return_attention_masks = return_attention_masks self.return_attention_masks = return_attention_masks
self.execution_device = None
if layer == "hidden": if layer == "hidden":
assert layer_idx is not None assert layer_idx is not None
@ -163,6 +164,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options): def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx) layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
self.execution_device = options.get("execution_device", self.execution_device)
if isinstance(self.layer, list) or self.layer == "all": if isinstance(self.layer, list) or self.layer == "all":
pass pass
elif layer_idx is None or abs(layer_idx) > self.num_layers: elif layer_idx is None or abs(layer_idx) > self.num_layers:
@ -175,6 +177,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer = self.options_default[0] self.layer = self.options_default[0]
self.layer_idx = self.options_default[1] self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2] self.return_projected_pooled = self.options_default[2]
self.execution_device = None
def process_tokens(self, tokens, device): def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None) end_token = self.special_tokens.get("end", None)
@ -258,7 +261,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
def forward(self, tokens): def forward(self, tokens):
if self.execution_device is None:
device = self.transformer.get_input_embeddings().weight.device device = self.transformer.get_input_embeddings().weight.device
else:
device = self.execution_device
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device) embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
attention_mask_model = None attention_mask_model = None

View File

@ -1027,6 +1027,8 @@ class ZImage(Lumina2):
memory_usage_factor = 1.7 memory_usage_factor = 1.7
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0] pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))

View File

@ -100,6 +100,28 @@ class Qwen3_4BConfig:
rope_scale = None rope_scale = None
final_norm: bool = True final_norm: bool = True
@dataclass
class Ovis25_2BConfig:
vocab_size: int = 151936
hidden_size: int = 2048
intermediate_size: int = 6144
num_hidden_layers: int = 28
num_attention_heads: int = 16
num_key_value_heads: int = 8
max_position_embeddings: int = 40960
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
transformer_type: str = "llama"
head_dim = 128
rms_norm_add = False
mlp_activation = "silu"
qkv_bias = False
rope_dims = None
q_norm = "gemma3"
k_norm = "gemma3"
rope_scale = None
final_norm: bool = True
@dataclass @dataclass
class Qwen25_7BVLI_Config: class Qwen25_7BVLI_Config:
vocab_size: int = 152064 vocab_size: int = 152064
@ -542,6 +564,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Ovis25_2B(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Ovis25_2BConfig(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module): class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()

View File

@ -0,0 +1,69 @@
from transformers import Qwen2Tokenizer
import comfy.text_encoders.llama
from comfy import sd1_clip
import os
import torch
import numbers
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data)
class OvisTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer)
self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
return tokens
class Ovis25_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)
class OvisTEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs, template_end=-1):
out, pooled = super().encode_token_weights(token_weight_pairs)
tok_pairs = token_weight_pairs["qwen3_2b"][0]
count_im_start = 0
if template_end == -1:
for i, v in enumerate(tok_pairs):
elem = v[0]
if not torch.is_tensor(elem):
if isinstance(elem, numbers.Integral):
if elem == 4004 and count_im_start < 1:
template_end = i
count_im_start += 1
if out.shape[1] > (template_end + 1):
if tok_pairs[template_end + 1][0] == 25:
template_end += 1
out = out[:, template_end:]
return out, pooled, {}
def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
class OvisTEModel_(OvisTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return OvisTEModel_
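For reference, a small sketch of what OvisTokenizer ends up feeding to the underlying Qwen2 tokenizer; the prompt text is an assumed example, not part of the diff.

# Hypothetical usage sketch; "a red fox" is an example prompt.
tok = OvisTokenizer()
tokens = tok.tokenize_with_weights("a red fox")
# Because llama_template is applied first, the text actually tokenized is:
#   "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size,
#    texture, spatial relationships of the objects and background: a red fox<|im_end|>\n"
#   "<|im_start|>assistant\n<think>\n\n</think>\n\n"
# OvisTEModel.encode_token_weights later slices this template prefix off the returned embeddings.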

View File

@ -179,36 +179,36 @@
"special": false "special": false
}, },
"151665": { "151665": {
"content": "<|img|>", "content": "<tool_response>",
"lstrip": false, "lstrip": false,
"normalized": false, "normalized": false,
"rstrip": false, "rstrip": false,
"single_word": false, "single_word": false,
"special": true "special": false
}, },
"151666": { "151666": {
"content": "<|endofimg|>", "content": "</tool_response>",
"lstrip": false, "lstrip": false,
"normalized": false, "normalized": false,
"rstrip": false, "rstrip": false,
"single_word": false, "single_word": false,
"special": true "special": false
}, },
"151667": { "151667": {
"content": "<|meta|>", "content": "<think>",
"lstrip": false, "lstrip": false,
"normalized": false, "normalized": false,
"rstrip": false, "rstrip": false,
"single_word": false, "single_word": false,
"special": true "special": false
}, },
"151668": { "151668": {
"content": "<|endofmeta|>", "content": "</think>",
"lstrip": false, "lstrip": false,
"normalized": false, "normalized": false,
"rstrip": false, "rstrip": false,
"single_word": false, "single_word": false,
"special": true "special": false
} }
}, },
"additional_special_tokens": [ "additional_special_tokens": [

View File

@ -13,6 +13,7 @@ from comfy.cli_args import args
SERVER_FEATURE_FLAGS: Dict[str, Any] = { SERVER_FEATURE_FLAGS: Dict[str, Any] = {
"supports_preview_metadata": True, "supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
} }

View File

@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
from . import _io as io from . import _io_public as io
from . import _ui as ui from . import _ui_public as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple from comfy_execution.progress import get_progress_state, PreviewImageTuple

View File

@ -336,7 +336,10 @@ class VideoFromComponents(VideoInput):
raise ValueError("Only MP4 format is supported for now") raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264: if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now") raise ValueError("Only H264 codec is supported for now")
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output: extra_kwargs = {}
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
extra_kwargs["format"] = format.value
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
# Add metadata before writing any streams # Add metadata before writing any streams
if metadata is not None: if metadata is not None:
for key, value in metadata.items(): for key, value in metadata.items():

View File

@ -4,7 +4,8 @@ import copy
import inspect import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import Counter from collections import Counter
from dataclasses import asdict, dataclass from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from enum import Enum from enum import Enum
from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
from typing_extensions import NotRequired, final from typing_extensions import NotRequired, final
@ -150,6 +151,9 @@ class _IO_V3:
def __init__(self): def __init__(self):
pass pass
def validate(self):
pass
@property @property
def io_type(self): def io_type(self):
return self.Parent.io_type return self.Parent.io_type
@ -182,6 +186,9 @@ class Input(_IO_V3):
def get_io_type(self): def get_io_type(self):
return _StringIOType(self.io_type) return _StringIOType(self.io_type)
def get_all(self) -> list[Input]:
return [self]
class WidgetInput(Input): class WidgetInput(Input):
''' '''
Base class for a V3 Input with widget. Base class for a V3 Input with widget.
@ -814,13 +821,61 @@ class MultiType:
else: else:
return super().as_dict() return super().as_dict()
@comfytype(io_type="COMFY_MATCHTYPE_V3")
class MatchType(ComfyTypeIO):
class Template:
def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType):
self.template_id = template_id
# account for syntactic sugar
if not isinstance(allowed_types, Iterable):
allowed_types = [allowed_types]
for t in allowed_types:
if not isinstance(t, type):
if not isinstance(t, _ComfyType):
raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}")
else:
if not issubclass(t, _ComfyType):
raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}")
self.allowed_types = allowed_types
def as_dict(self):
return {
"template_id": self.template_id,
"allowed_types": ",".join([t.io_type for t in self.allowed_types]),
}
class Input(Input):
def __init__(self, id: str, template: MatchType.Template,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.template = template
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
})
class Output(Output):
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
self.template = template
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
})
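A minimal sketch of how the new MatchType is intended to be declared in a node schema; Image and Latent are assumed to be the existing ComfyType classes in this module, and the template id "T" is an arbitrary example.

# Illustrative only: one template shared by an input and an output,
# so the frontend resolves both slots to the same concrete type.
T = MatchType.Template("T", allowed_types=[Image, Latent])
inputs = [MatchType.Input("value", template=T)]
outputs = [MatchType.Output(template=T)]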
class DynamicInput(Input, ABC): class DynamicInput(Input, ABC):
''' '''
Abstract class for dynamic input registration. Abstract class for dynamic input registration.
''' '''
@abstractmethod
def get_dynamic(self) -> list[Input]: def get_dynamic(self) -> list[Input]:
... return []
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
pass
class DynamicOutput(Output, ABC): class DynamicOutput(Output, ABC):
''' '''
@ -830,99 +885,223 @@ class DynamicOutput(Output, ABC):
is_output_list=False): is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list) super().__init__(id, display_name, tooltip, is_output_list)
@abstractmethod
def get_dynamic(self) -> list[Output]: def get_dynamic(self) -> list[Output]:
... return []
@comfytype(io_type="COMFY_AUTOGROW_V3") @comfytype(io_type="COMFY_AUTOGROW_V3")
class AutogrowDynamic(ComfyTypeI): class Autogrow(ComfyTypeI):
Type = list[Any] Type = dict[str, Any]
class Input(DynamicInput): _MaxNames = 100 # NOTE: max 100 names for sanity
def __init__(self, id: str, template_input: Input, min: int=1, max: int=None,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): class _AutogrowTemplate:
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) def __init__(self, input: Input):
self.template_input = template_input # dynamic inputs are not allowed as the template input
if min is not None: assert(not isinstance(input, DynamicInput))
assert(min >= 1) self.input = copy.copy(input)
if max is not None: if isinstance(self.input, WidgetInput):
self.input.force_input = True
self.names: list[str] = []
self.cached_inputs = {}
def _create_input(self, input: Input, name: str):
new_input = copy.copy(self.input)
new_input.id = name
return new_input
def _create_cached_inputs(self):
for name in self.names:
self.cached_inputs[name] = self._create_input(self.input, name)
def get_all(self) -> list[Input]:
return list(self.cached_inputs.values())
def as_dict(self):
return prune_dict({
"input": create_input_dict_v1([self.input]),
})
def validate(self):
self.input.validate()
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
real_inputs = []
for name, input in self.cached_inputs.items():
if name in live_inputs:
real_inputs.append(input)
add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, real_inputs, curr_prefix)
class TemplatePrefix(_AutogrowTemplate):
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
super().__init__(input)
self.prefix = prefix
assert(min >= 0)
assert(max >= 1) assert(max >= 1)
assert(max <= Autogrow._MaxNames)
self.min = min self.min = min
self.max = max self.max = max
self.names = [f"{self.prefix}{i}" for i in range(self.max)]
self._create_cached_inputs()
def as_dict(self):
return super().as_dict() | prune_dict({
"prefix": self.prefix,
"min": self.min,
"max": self.max,
})
class TemplateNames(_AutogrowTemplate):
def __init__(self, input: Input, names: list[str], min: int=1):
super().__init__(input)
self.names = names[:Autogrow._MaxNames]
assert(min >= 0)
self.min = min
self._create_cached_inputs()
def as_dict(self):
return super().as_dict() | prune_dict({
"names": self.names,
"min": self.min,
})
class Input(DynamicInput):
def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames,
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.template = template
def as_dict(self):
return super().as_dict() | prune_dict({
"template": self.template.as_dict(),
})
def get_dynamic(self) -> list[Input]: def get_dynamic(self) -> list[Input]:
curr_count = 1 return self.template.get_all()
new_inputs = []
for i in range(self.min):
new_input = copy.copy(self.template_input)
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
if new_input.display_name is not None:
new_input.display_name = f"{new_input.display_name}{curr_count}"
new_input.optional = self.optional or new_input.optional
if isinstance(self.template_input, WidgetInput):
new_input.force_input = True
new_inputs.append(new_input)
curr_count += 1
# pretend to expand up to max
for i in range(curr_count-1, self.max):
new_input = copy.copy(self.template_input)
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
if new_input.display_name is not None:
new_input.display_name = f"{new_input.display_name}{curr_count}"
new_input.optional = True
if isinstance(self.template_input, WidgetInput):
new_input.force_input = True
new_inputs.append(new_input)
curr_count += 1
return new_inputs
@comfytype(io_type="COMFY_COMBODYNAMIC_V3") def get_all(self) -> list[Input]:
class ComboDynamic(ComfyTypeI): return [self] + self.template.get_all()
class Input(DynamicInput):
def __init__(self, id: str):
pass
@comfytype(io_type="COMFY_MATCHTYPE_V3") def validate(self):
class MatchType(ComfyTypeIO): self.template.validate()
class Template:
def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]): def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
self.template_id = template_id curr_prefix = f"{curr_prefix}{self.id}."
self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types # need to remove self from expected inputs dictionary; replaced by template inputs in frontend
for inner_dict in d.values():
if self.id in inner_dict:
del inner_dict[self.id]
self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
class DynamicCombo(ComfyTypeI):
Type = dict[str, Any]
class Option:
def __init__(self, key: str, inputs: list[Input]):
self.key = key
self.inputs = inputs
def as_dict(self): def as_dict(self):
return { return {
"template_id": self.template_id, "key": self.key,
"allowed_types": "".join(t.io_type for t in self.allowed_types), "inputs": create_input_dict_v1(self.inputs),
} }
class Input(DynamicInput): class Input(DynamicInput):
def __init__(self, id: str, template: MatchType.Template, def __init__(self, id: str, options: list[DynamicCombo.Option],
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
self.template = template self.options = options
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
# check if dynamic input's id is in live_inputs
if self.id in live_inputs:
curr_prefix = f"{curr_prefix}{self.id}."
key = live_inputs[self.id]
selected_option = None
for option in self.options:
if option.key == key:
selected_option = option
break
if selected_option is not None:
add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
def get_dynamic(self) -> list[Input]: def get_dynamic(self) -> list[Input]:
return [self] return [input for option in self.options for input in option.inputs]
def get_all(self) -> list[Input]:
return [self] + [input for option in self.options for input in option.inputs]
def as_dict(self): def as_dict(self):
return super().as_dict() | prune_dict({ return super().as_dict() | prune_dict({
"template": self.template.as_dict(), "options": [o.as_dict() for o in self.options],
}) })
class Output(DynamicOutput): def validate(self):
def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None, # make sure all nested inputs are validated
is_output_list=False): for option in self.options:
super().__init__(id, display_name, tooltip, is_output_list) for input in option.inputs:
self.template = template input.validate()
def get_dynamic(self) -> list[Output]: @comfytype(io_type="COMFY_DYNAMICSLOT_V3")
return [self] class DynamicSlot(ComfyTypeI):
Type = dict[str, Any]
class Input(DynamicInput):
def __init__(self, slot: Input, inputs: list[Input],
display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None):
assert(not isinstance(slot, DynamicInput))
self.slot = copy.copy(slot)
self.slot.display_name = slot.display_name if slot.display_name is not None else display_name
optional = True
self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip
self.slot.lazy = slot.lazy if slot.lazy is not None else lazy
self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict
super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict)
self.inputs = inputs
self.force_input = None
# force widget inputs to have no widgets, otherwise this would be awkward
if isinstance(self.slot, WidgetInput):
self.force_input = True
self.slot.force_input = True
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
if self.id in live_inputs:
curr_prefix = f"{curr_prefix}{self.id}."
add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
def get_dynamic(self) -> list[Input]:
return [self.slot] + self.inputs
def get_all(self) -> list[Input]:
return [self] + [self.slot] + self.inputs
def as_dict(self): def as_dict(self):
return super().as_dict() | prune_dict({ return super().as_dict() | prune_dict({
"template": self.template.as_dict(), "slotType": str(self.slot.get_io_type()),
"inputs": create_input_dict_v1(self.inputs),
"forceInput": self.force_input,
}) })
def validate(self):
self.slot.validate()
for input in self.inputs:
input.validate()
def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
dynamic = d.setdefault("dynamic_paths", {})
if self is not None:
dynamic[self.id] = f"{curr_prefix}{self.id}"
for i in inputs:
if not isinstance(i, DynamicInput):
dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
class V3Data(TypedDict):
hidden_inputs: dict[str, Any]
dynamic_paths: dict[str, Any]
class HiddenHolder: class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any, def __init__(self, unique_id: str, prompt: Any,
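A minimal sketch of a DynamicCombo declaration under the reworked API above; the option keys and nested inputs are assumed examples. The selected key decides which nested inputs the frontend exposes, and those inputs arrive prefixed (e.g. "mode.steps") so they can be regrouped by build_nested_inputs further below.

# Illustrative only: selecting "advanced" reveals the nested steps/cfg inputs.
mode = DynamicCombo.Input("mode", options=[
    DynamicCombo.Option("simple", []),
    DynamicCombo.Option("advanced", [Int.Input("steps", default=20), Float.Input("cfg", default=7.5)]),
])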
@ -984,6 +1163,7 @@ class NodeInfoV1:
output_is_list: list[bool]=None output_is_list: list[bool]=None
output_name: list[str]=None output_name: list[str]=None
output_tooltips: list[str]=None output_tooltips: list[str]=None
output_matchtypes: list[str]=None
name: str=None name: str=None
display_name: str=None display_name: str=None
description: str=None description: str=None
@ -1019,9 +1199,9 @@ class Schema:
"""Display name of node.""" """Display name of node."""
category: str = "sd" category: str = "sd"
"""The category of the node, as per the "Add Node" menu.""" """The category of the node, as per the "Add Node" menu."""
inputs: list[Input]=None inputs: list[Input] = field(default_factory=list)
outputs: list[Output]=None outputs: list[Output] = field(default_factory=list)
hidden: list[Hidden]=None hidden: list[Hidden] = field(default_factory=list)
description: str="" description: str=""
"""Node description, shown as a tooltip when hovering over the node.""" """Node description, shown as a tooltip when hovering over the node."""
is_input_list: bool = False is_input_list: bool = False
@ -1061,7 +1241,11 @@ class Schema:
'''Validate the schema: '''Validate the schema:
- verify ids on inputs and outputs are unique - both internally and in relation to each other - verify ids on inputs and outputs are unique - both internally and in relation to each other
''' '''
input_ids = [i.id for i in self.inputs] if self.inputs is not None else [] nested_inputs: list[Input] = []
if self.inputs is not None:
for input in self.inputs:
nested_inputs.extend(input.get_all())
input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
input_set = set(input_ids) input_set = set(input_ids)
output_set = set(output_ids) output_set = set(output_ids)
@ -1077,6 +1261,13 @@ class Schema:
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.") issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
if len(issues) > 0: if len(issues) > 0:
raise ValueError("\n".join(issues)) raise ValueError("\n".join(issues))
# validate inputs and outputs
if self.inputs is not None:
for input in self.inputs:
input.validate()
if self.outputs is not None:
for output in self.outputs:
output.validate()
def finalize(self): def finalize(self):
"""Add hidden based on selected schema options, and give outputs without ids default ids.""" """Add hidden based on selected schema options, and give outputs without ids default ids."""
@ -1102,19 +1293,10 @@ class Schema:
if output.id is None: if output.id is None:
output.id = f"_{i}_{output.io_type}_" output.id = f"_{i}_{output.io_type}_"
def get_v1_info(self, cls) -> NodeInfoV1: def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
# NOTE: live_inputs is a temporary mechanism and will soon be replaced by a different approach
# get V1 inputs # get V1 inputs
input = { input = create_input_dict_v1(self.inputs, live_inputs)
"required": {}
}
if self.inputs:
for i in self.inputs:
if isinstance(i, DynamicInput):
dynamic_inputs = i.get_dynamic()
for d in dynamic_inputs:
add_to_dict_v1(d, input)
else:
add_to_dict_v1(i, input)
if self.hidden: if self.hidden:
for hidden in self.hidden: for hidden in self.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,) input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@ -1123,12 +1305,24 @@ class Schema:
output_is_list = [] output_is_list = []
output_name = [] output_name = []
output_tooltips = [] output_tooltips = []
output_matchtypes = []
any_matchtypes = False
if self.outputs: if self.outputs:
for o in self.outputs: for o in self.outputs:
output.append(o.io_type) output.append(o.io_type)
output_is_list.append(o.is_output_list) output_is_list.append(o.is_output_list)
output_name.append(o.display_name if o.display_name else o.io_type) output_name.append(o.display_name if o.display_name else o.io_type)
output_tooltips.append(o.tooltip if o.tooltip else None) output_tooltips.append(o.tooltip if o.tooltip else None)
# special handling for MatchType
if isinstance(o, MatchType.Output):
output_matchtypes.append(o.template.template_id)
any_matchtypes = True
else:
output_matchtypes.append(None)
# clear out lists that are all None
if not any_matchtypes:
output_matchtypes = None
info = NodeInfoV1( info = NodeInfoV1(
input=input, input=input,
@ -1137,6 +1331,7 @@ class Schema:
output_is_list=output_is_list, output_is_list=output_is_list,
output_name=output_name, output_name=output_name,
output_tooltips=output_tooltips, output_tooltips=output_tooltips,
output_matchtypes=output_matchtypes,
name=self.node_id, name=self.node_id,
display_name=self.display_name, display_name=self.display_name,
category=self.category, category=self.category,
@ -1182,16 +1377,57 @@ class Schema:
return info return info
def add_to_dict_v1(i: Input, input: dict): def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
input = {
"required": {}
}
add_to_input_dict_v1(input, inputs, live_inputs)
return input
def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
for i in inputs:
if isinstance(i, DynamicInput):
add_to_dict_v1(i, d)
if live_inputs is not None:
i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
else:
add_to_dict_v1(i, d)
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
key = "optional" if i.optional else "required" key = "optional" if i.optional else "required"
as_dict = i.as_dict() as_dict = i.as_dict()
# for v1, we don't want to include the optional key # for v1, we don't want to include the optional key
as_dict.pop("optional", None) as_dict.pop("optional", None)
input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict) if dynamic_dict is None:
value = (i.get_io_type(), as_dict)
else:
value = (i.get_io_type(), as_dict, dynamic_dict)
d.setdefault(key, {})[i.id] = value
def add_to_dict_v3(io: Input | Output, d: dict): def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict()) d[io.id] = (io.get_io_type(), io.as_dict())
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
paths = v3_data.get("dynamic_paths", None)
if paths is None:
return values
values = values.copy()
result = {}
for key, path in paths.items():
parts = path.split(".")
current = result
for i, p in enumerate(parts):
is_last = (i == len(parts) - 1)
if is_last:
current[p] = values.pop(key, None)
else:
current = current.setdefault(p, {})
values.update(result)
return values
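Continuing the assumed DynamicCombo example above, a worked illustration of how build_nested_inputs regroups the flat execution inputs using dynamic_paths; keys and values are assumed examples.

# Illustrative only: flat values plus dynamic_paths become a nested dict keyed by the combo id.
values = {"image": "IMG", "mode": "advanced", "steps": 20, "cfg": 7.5}
v3_data = {"dynamic_paths": {"mode": "mode.mode", "steps": "mode.steps", "cfg": "mode.cfg"}}
build_nested_inputs(values, v3_data)
# -> {"image": "IMG", "mode": {"mode": "advanced", "steps": 20, "cfg": 7.5}}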
class _ComfyNodeBaseInternal(_ComfyNodeInternal): class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@ -1311,12 +1547,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final @final
@classmethod @classmethod
def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]: def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
"""Creates clone of real node class to prevent monkey-patching.""" """Creates clone of real node class to prevent monkey-patching."""
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
type_clone: type[ComfyNode] = shallow_clone_class(c_type) type_clone: type[ComfyNode] = shallow_clone_class(c_type)
# set hidden # set hidden
type_clone.hidden = HiddenHolder.from_dict(hidden_inputs) type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
return type_clone return type_clone
@final @final
@ -1433,14 +1669,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final @final
@classmethod @classmethod
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]: def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
schema = cls.FINALIZE_SCHEMA() schema = cls.FINALIZE_SCHEMA()
info = schema.get_v1_info(cls) info = schema.get_v1_info(cls, live_inputs)
input = info.input input = info.input
if not include_hidden: if not include_hidden:
input.pop("hidden", None) input.pop("hidden", None)
if return_schema: if return_schema:
return input, schema v3_data: V3Data = {}
dynamic = input.pop("dynamic_paths", None)
if dynamic is not None:
v3_data["dynamic_paths"] = dynamic
return input, schema, v3_data
return input return input
@final @final
@ -1513,7 +1753,7 @@ class ComfyNode(_ComfyNodeBaseInternal):
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def validate_inputs(cls, **kwargs) -> bool: def validate_inputs(cls, **kwargs) -> bool | str:
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.""" """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
raise NotImplementedError raise NotImplementedError
@ -1628,6 +1868,7 @@ __all__ = [
"StyleModel", "StyleModel",
"Gligen", "Gligen",
"UpscaleModel", "UpscaleModel",
"LatentUpscaleModel",
"Audio", "Audio",
"Video", "Video",
"SVG", "SVG",
@ -1651,6 +1892,10 @@ __all__ = [
"SEGS", "SEGS",
"AnyType", "AnyType",
"MultiType", "MultiType",
# Dynamic Types
"MatchType",
# "DynamicCombo",
# "Autogrow",
# Other classes # Other classes
"HiddenHolder", "HiddenHolder",
"Hidden", "Hidden",
@ -1661,4 +1906,5 @@ __all__ = [
"NodeOutput", "NodeOutput",
"add_to_dict_v1", "add_to_dict_v1",
"add_to_dict_v3", "add_to_dict_v3",
"V3Data",
] ]

View File

@ -0,0 +1 @@
from ._io import * # noqa: F403

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import json import json
import os import os
import random import random
import uuid
from io import BytesIO from io import BytesIO
from typing import Type from typing import Type
@ -318,9 +319,10 @@ class AudioSaveHelper:
for key, value in metadata.items(): for key, value in metadata.items():
output_container.metadata[key] = value output_container.metadata[key] = value
layout = "mono" if waveform.shape[0] == 1 else "stereo"
# Set up the output stream with appropriate properties # Set up the output stream with appropriate properties
if format == "opus": if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate) out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
if quality == "64k": if quality == "64k":
out_stream.bit_rate = 64000 out_stream.bit_rate = 64000
elif quality == "96k": elif quality == "96k":
@ -332,7 +334,7 @@ class AudioSaveHelper:
elif quality == "320k": elif quality == "320k":
out_stream.bit_rate = 320000 out_stream.bit_rate = 320000
elif format == "mp3": elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate) out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
if quality == "V0": if quality == "V0":
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1 out_stream.codec_context.qscale = 1
@ -341,12 +343,12 @@ class AudioSaveHelper:
elif quality == "320k": elif quality == "320k":
out_stream.bit_rate = 320000 out_stream.bit_rate = 320000
else: # format == "flac": else: # format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate) out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
frame = av.AudioFrame.from_ndarray( frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(), waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format="flt", format="flt",
layout="mono" if waveform.shape[0] == 1 else "stereo", layout=layout,
) )
frame.sample_rate = sample_rate frame.sample_rate = sample_rate
frame.pts = 0 frame.pts = 0
@ -436,9 +438,19 @@ class PreviewUI3D(_UIOutput):
def __init__(self, model_file, camera_info, **kwargs): def __init__(self, model_file, camera_info, **kwargs):
self.model_file = model_file self.model_file = model_file
self.camera_info = camera_info self.camera_info = camera_info
self.bg_image_path = None
bg_image = kwargs.get("bg_image", None)
if bg_image is not None:
img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
img = PILImage.fromarray(img_array)
temp_dir = folder_paths.get_temp_directory()
filename = f"bg_{uuid.uuid4().hex}.png"
bg_image_path = os.path.join(temp_dir, filename)
img.save(bg_image_path, compress_level=1)
self.bg_image_path = f"temp/{filename}"
def as_dict(self): def as_dict(self):
return {"result": [self.model_file, self.camera_info]} return {"result": [self.model_file, self.camera_info, self.bg_image_path]}
class PreviewText(_UIOutput): class PreviewText(_UIOutput):

View File

@ -0,0 +1 @@
from ._ui import * # noqa: F403

View File

@ -6,7 +6,7 @@ from comfy_api.latest import (
) )
from typing import Type, TYPE_CHECKING from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest import io, ui, ComfyExtension #noqa: F401 from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest): class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@ -42,4 +42,8 @@ __all__ = [
"InputImpl", "InputImpl",
"Types", "Types",
"ComfyExtension", "ComfyExtension",
"io",
"IO",
"ui",
"UI",
] ]

View File

@ -0,0 +1,86 @@
from pydantic import BaseModel, Field
class OmniProText2VideoRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
class OmniParamImage(BaseModel):
image_url: str = Field(...)
type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'")
class OmniParamVideo(BaseModel):
video_url: str = Field(...)
refer_type: str | None = Field(..., description="Can be 'base' or 'feature'")
keep_original_sound: str = Field(..., description="'yes' or 'no'")
class OmniProFirstLastFrameRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7)
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
class OmniProReferences2VideoRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'")
image_list: list[OmniParamImage] | None = Field(
None, max_length=7, description="Max length 4 when video is present."
)
video_list: list[OmniParamVideo] | None = Field(None, max_length=1)
duration: str | None = Field(..., description="From 3 to 10.")
prompt: str = Field(...)
mode: str = Field("pro")
class TaskStatusVideoResult(BaseModel):
duration: str | None = Field(None, description="Total video duration")
id: str | None = Field(None, description="Generated video ID")
url: str | None = Field(None, description="URL for generated video")
class TaskStatusImageResult(BaseModel):
index: int = Field(..., description="Image number (0-9)")
url: str = Field(..., description="URL for generated image")
class OmniTaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None)
class OmniTaskStatusResponseData(BaseModel):
created_at: int | None = Field(None, description="Task creation time")
updated_at: int | None = Field(None, description="Task update time")
task_status: str | None = None
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
task_id: str | None = Field(None, description="Task ID")
task_result: OmniTaskStatusResults | None = Field(None)
class OmniTaskStatusResponse(BaseModel):
code: int | None = Field(None, description="Error code")
message: str | None = Field(None, description="Error message")
request_id: str | None = Field(None, description="Request ID")
data: OmniTaskStatusResponseData | None = Field(None)
class OmniImageParamImage(BaseModel):
image: str = Field(...)
class OmniProImageRequest(BaseModel):
model_name: str = Field(..., description="kling-image-o1")
resolution: str = Field(..., description="'1k' or '2k'")
aspect_ratio: str | None = Field(...)
prompt: str = Field(...)
mode: str = Field("pro")
n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)

View File

@ -4,13 +4,14 @@ For source of truth on the allowed permutations of request fields, please refere
- [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) - [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap)
""" """
import math
import logging import logging
import math
from typing_extensions import override import re
import torch import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis import ( from comfy_api_nodes.apis import (
KlingCameraControl, KlingCameraControl,
KlingCameraConfig, KlingCameraConfig,
@ -48,23 +49,33 @@ from comfy_api_nodes.apis import (
KlingCharacterEffectModelName, KlingCharacterEffectModelName,
KlingSingleImageEffectModelName, KlingSingleImageEffectModelName,
) )
from comfy_api_nodes.apis.kling_api import (
OmniImageParamImage,
OmniParamImage,
OmniParamVideo,
OmniProFirstLastFrameRequest,
OmniProImageRequest,
OmniProReferences2VideoRequest,
OmniProText2VideoRequest,
OmniTaskStatusResponse,
)
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
validate_image_dimensions, ApiEndpoint,
download_url_to_image_tensor,
download_url_to_video_output,
get_number_of_images,
poll_op,
sync_op,
tensor_to_base64_string,
upload_audio_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_image_aspect_ratio, validate_image_aspect_ratio,
validate_image_dimensions,
validate_string,
validate_video_dimensions, validate_video_dimensions,
validate_video_duration, validate_video_duration,
tensor_to_base64_string,
validate_string,
upload_audio_to_comfyapi,
download_url_to_image_tensor,
upload_video_to_comfyapi,
download_url_to_video_output,
sync_op,
ApiEndpoint,
poll_op,
) )
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO, Input
KLING_API_VERSION = "v1" KLING_API_VERSION = "v1"
PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video" PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video"
@ -202,6 +213,50 @@ VOICES_CONFIG = {
} }
def normalize_omni_prompt_references(prompt: str) -> str:
"""
Rewrites Kling Omni-style placeholders used in the app, like:
@image, @image1, @image2, ... @imageN
@video, @video1, @video2, ... @videoN
into the API-compatible form:
<<<image_1>>>, <<<image_2>>>, ...
<<<video_1>>>, <<<video_2>>>, ...
This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app.
"""
if not prompt:
return prompt
def _image_repl(match):
return f"<<<image_{match.group('idx') or '1'}>>>"
def _video_repl(match):
return f"<<<video_{match.group('idx') or '1'}>>>"
# (?<!\w) avoids matching e.g. "test@image.com"
# (?!\w) makes sure we only match @image / @image<digits> and not @imageFoo
prompt = re.sub(r"(?<!\w)@image(?P<idx>\d*)(?!\w)", _image_repl, prompt)
return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
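A few illustrative rewrites (example prompts assumed) showing what the two substitutions above produce:

# Illustrative only; prompts are assumed examples.
normalize_omni_prompt_references("Put @image1 next to @image2")
# -> "Put <<<image_1>>> next to <<<image_2>>>"
normalize_omni_prompt_references("Follow the motion of @video")
# -> "Follow the motion of <<<video_1>>>"
normalize_omni_prompt_references("contact test@image.com")
# -> unchanged: the (?<!\w) guard stops the match inside an email-like token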
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
response_model=OmniTaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
max_poll_attempts=160,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
def is_valid_camera_control_configs(configs: list[float]) -> bool: def is_valid_camera_control_configs(configs: list[float]) -> bool:
"""Verifies that at least one camera control configuration is non-zero.""" """Verifies that at least one camera control configuration is non-zero."""
return any(not math.isclose(value, 0.0) for value in configs) return any(not math.isclose(value, 0.0) for value in configs)
@ -449,7 +504,7 @@ async def execute_video_effect(
image_1: torch.Tensor, image_1: torch.Tensor,
image_2: torch.Tensor | None = None, image_2: torch.Tensor | None = None,
model_mode: KlingVideoGenMode | None = None, model_mode: KlingVideoGenMode | None = None,
) -> tuple[VideoFromFile, str, str]: ) -> tuple[InputImpl.VideoFromFile, str, str]:
if dual_character: if dual_character:
request_input_field = KlingDualCharacterEffectInput( request_input_field = KlingDualCharacterEffectInput(
model_name=model_name, model_name=model_name,
@ -736,6 +791,474 @@ class KlingTextToVideoNode(IO.ComfyNode):
) )
class OmniProTextToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProTextToVideoNode",
display_name="Kling Omni Text to Video (Pro)",
category="api node/video/Kling",
description="Use text prompts to generate videos with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Combo.Input("duration", options=[5, 10]),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
aspect_ratio: str,
duration: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2500)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
data=OmniProText2VideoRequest(
model_name=model_name,
prompt=prompt,
aspect_ratio=aspect_ratio,
duration=str(duration),
),
)
return await finish_omni_video_task(cls, response)
class OmniProFirstLastFrameNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProFirstLastFrameNode",
display_name="Kling Omni First-Last-Frame to Video (Pro)",
category="api node/video/Kling",
description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("duration", options=["5", "10"]),
IO.Image.Input("first_frame"),
IO.Image.Input(
"end_frame",
optional=True,
tooltip="An optional end frame for the video. "
"This cannot be used simultaneously with 'reference_images'.",
),
IO.Image.Input(
"reference_images",
optional=True,
tooltip="Up to 6 additional reference images.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
duration: int,
first_frame: Input.Image,
end_frame: Input.Image | None = None,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
if end_frame is not None and reference_images is not None:
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
validate_image_dimensions(first_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
image_list: list[OmniParamImage] = [
OmniParamImage(
image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0],
type="first_frame",
)
]
if end_frame is not None:
validate_image_dimensions(end_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
image_list.append(
OmniParamImage(
image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0],
type="end_frame",
)
)
if reference_images is not None:
if get_number_of_images(reference_images) > 6:
raise ValueError("The maximum number of reference images allowed is 6.")
for i in reference_images:
validate_image_dimensions(i, min_width=300, min_height=300)
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"):
image_list.append(OmniParamImage(image_url=i))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
data=OmniProFirstLastFrameRequest(
model_name=model_name,
prompt=prompt,
duration=str(duration),
image_list=image_list,
),
)
return await finish_omni_video_task(cls, response)
class OmniProImageToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProImageToVideoNode",
display_name="Kling Omni Image to Video (Pro)",
category="api node/video/Kling",
description="Use up to 7 reference images to generate a video with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
IO.Image.Input(
"reference_images",
tooltip="Up to 7 reference images.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
aspect_ratio: str,
duration: int,
reference_images: Input.Image,
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
if get_number_of_images(reference_images) > 7:
raise ValueError("The maximum number of reference images is 7.")
for i in reference_images:
validate_image_dimensions(i, min_width=300, min_height=300)
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
image_list: list[OmniParamImage] = []
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniParamImage(image_url=i))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
aspect_ratio=aspect_ratio,
duration=str(duration),
image_list=image_list,
),
)
return await finish_omni_video_task(cls, response)
class OmniProVideoToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProVideoToVideoNode",
display_name="Kling Omni Video to Video (Pro)",
category="api node/video/Kling",
description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
IO.Video.Input("reference_video", tooltip="Video to use as a reference."),
IO.Boolean.Input("keep_original_sound", default=True),
IO.Image.Input(
"reference_images",
tooltip="Up to 4 additional reference images.",
optional=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
aspect_ratio: str,
duration: int,
reference_video: Input.Video,
keep_original_sound: bool,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
image_list: list[OmniParamImage] = []
if reference_images is not None:
if get_number_of_images(reference_images) > 4:
raise ValueError("The maximum number of reference images allowed with a video input is 4.")
for i in reference_images:
validate_image_dimensions(i, min_width=300, min_height=300)
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniParamImage(image_url=i))
video_list = [
OmniParamVideo(
video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"),
refer_type="feature",
keep_original_sound="yes" if keep_original_sound else "no",
)
]
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
aspect_ratio=aspect_ratio,
duration=str(duration),
image_list=image_list if image_list else None,
video_list=video_list,
),
)
return await finish_omni_video_task(cls, response)

class OmniProEditVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProEditVideoNode",
display_name="Kling Omni Edit Video (Pro)",
category="api node/video/Kling",
description="Edit an existing video with the latest model from Kling.",
inputs=[
IO.Combo.Input("model_name", options=["kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
"This can include both positive and negative descriptions.",
),
IO.Video.Input("video", tooltip="Video for editing. The output video length will be the same."),
IO.Boolean.Input("keep_original_sound", default=True),
IO.Image.Input(
"reference_images",
tooltip="Up to 4 additional reference images.",
optional=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
video: Input.Video,
keep_original_sound: bool,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(video, min_duration=3.0, max_duration=10.05)
validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
image_list: list[OmniParamImage] = []
if reference_images is not None:
if get_number_of_images(reference_images) > 4:
raise ValueError("The maximum number of reference images allowed with a video input is 4.")
for i in reference_images:
validate_image_dimensions(i, min_width=300, min_height=300)
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniParamImage(image_url=i))
video_list = [
OmniParamVideo(
video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"),
refer_type="base",
keep_original_sound="yes" if keep_original_sound else "no",
)
]
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
response_model=OmniTaskStatusResponse,
data=OmniProReferences2VideoRequest(
model_name=model_name,
prompt=prompt,
aspect_ratio=None,
duration=None,
image_list=image_list if image_list else None,
video_list=video_list,
),
)
return await finish_omni_video_task(cls, response)

class OmniProImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProImageNode",
display_name="Kling Omni Image (Pro)",
category="api node/image/Kling",
description="Create or edit images with the latest model from Kling.",
inputs=[
IO.Combo.Input("model_name", options=["kling-image-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the image content. "
"This can include both positive and negative descriptions.",
),
IO.Combo.Input("resolution", options=["1K", "2K"]),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
),
IO.Image.Input(
"reference_images",
tooltip="Up to 10 additional reference images.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)
@classmethod
async def execute(
cls,
model_name: str,
prompt: str,
resolution: str,
aspect_ratio: str,
reference_images: Input.Image | None = None,
) -> IO.NodeOutput:
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
image_list: list[OmniImageParamImage] = []
if reference_images is not None:
if get_number_of_images(reference_images) > 10:
raise ValueError("The maximum number of reference images is 10.")
for i in reference_images:
validate_image_dimensions(i, min_width=300, min_height=300)
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniImageParamImage(image=i))
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
response_model=OmniTaskStatusResponse,
data=OmniProImageRequest(
model_name=model_name,
prompt=prompt,
resolution=resolution.lower(),
aspect_ratio=aspect_ratio,
image_list=image_list if image_list else None,
),
)
if response.code:
raise RuntimeError(
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
response_model=OmniTaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
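All of the Omni nodes above share the same submit-then-poll shape: validate the prompt and media, upload the media to get URLs, POST to the /proxy/kling/v1/... endpoint, then poll the task until it settles. The sketch below is a rough standalone illustration of that flow using plain aiohttp, not the node helpers (sync_op, poll_op, finish_omni_video_task); the base URL, auth header, polling path and terminal status strings are assumptions for illustration only.

import asyncio

import aiohttp

BASE_URL = "https://api.comfy.org"  # assumption: the real base URL comes from the client configuration
API_KEY = "..."                     # assumption: in practice supplied via the hidden auth inputs


async def submit_and_poll(payload: dict, poll_interval: float = 5.0) -> dict:
    """Submit an omni-video task and poll it to completion (illustrative only)."""
    headers = {"Authorization": f"Bearer {API_KEY}"}
    async with aiohttp.ClientSession(base_url=BASE_URL, headers=headers) as session:
        async with session.post("/proxy/kling/v1/videos/omni-video", json=payload) as resp:
            resp.raise_for_status()
            task = await resp.json()
        task_id = task["data"]["task_id"]
        while True:
            # assumption: the poll path and status values mirror the image endpoint shown above
            async with session.get(f"/proxy/kling/v1/videos/omni-video/{task_id}") as resp:
                resp.raise_for_status()
                status = await resp.json()
            if status["data"]["task_status"] in ("succeed", "failed"):
                return status
            await asyncio.sleep(poll_interval)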
class KlingCameraControlT2VNode(IO.ComfyNode): class KlingCameraControlT2VNode(IO.ComfyNode):
""" """
Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera. Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
@ -1162,7 +1685,10 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
category="api node/video/Kling", category="api node/video/Kling",
description="Achieve different special effects when generating a video based on the effect_scene.", description="Achieve different special effects when generating a video based on the effect_scene.",
inputs=[ inputs=[
IO.Image.Input("image", tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"), IO.Image.Input(
"image",
tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1",
),
IO.Combo.Input( IO.Combo.Input(
"effect_scene", "effect_scene",
options=[i.value for i in KlingSingleImageEffectsScene], options=[i.value for i in KlingSingleImageEffectsScene],
@ -1525,6 +2051,12 @@ class KlingExtension(ComfyExtension):
KlingImageGenerationNode, KlingImageGenerationNode,
KlingSingleImageVideoEffectNode, KlingSingleImageVideoEffectNode,
KlingDualCharacterVideoEffectNode, KlingDualCharacterVideoEffectNode,
OmniProTextToVideoNode,
OmniProFirstLastFrameNode,
OmniProImageToVideoNode,
OmniProVideoToVideoNode,
OmniProEditVideoNode,
# OmniProImageNode, # need support from backend
] ]

View File

@ -47,6 +47,7 @@ from .validation_utils import (
validate_string, validate_string,
validate_video_dimensions, validate_video_dimensions,
validate_video_duration, validate_video_duration,
validate_video_frame_count,
) )
__all__ = [ __all__ = [
@ -94,6 +95,7 @@ __all__ = [
"validate_string", "validate_string",
"validate_video_dimensions", "validate_video_dimensions",
"validate_video_duration", "validate_video_duration",
"validate_video_frame_count",
# Misc functions # Misc functions
"get_fs_object_size", "get_fs_object_size",
] ]

View File

@ -2,8 +2,8 @@ import asyncio
import contextlib import contextlib
import os import os
import time import time
from collections.abc import Callable
from io import BytesIO from io import BytesIO
from typing import Callable, Optional, Union
from comfy.cli_args import args from comfy.cli_args import args
from comfy.model_management import processing_interrupted from comfy.model_management import processing_interrupted
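This hunk and the ones that follow are a mechanical typing cleanup: Optional[X] and Union[A, B] become PEP 604 unions (Python 3.10+ syntax), and Callable/Iterable are imported from collections.abc instead of typing. A before/after sketch of the pattern (names are illustrative):

# Before: typing-based aliases
from typing import Callable, Optional, Union

def old_style(label: Optional[str] = None,
              cb: Optional[Callable[[str], None]] = None) -> Union[dict, bytes]: ...

# After: PEP 604 unions, ABCs from collections.abc
from collections.abc import Callable

def new_style(label: str | None = None,
              cb: Callable[[str], None] | None = None) -> dict | bytes: ...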
@ -35,12 +35,12 @@ def default_base_url() -> str:
async def sleep_with_interrupt( async def sleep_with_interrupt(
seconds: float, seconds: float,
node_cls: Optional[type[IO.ComfyNode]], node_cls: type[IO.ComfyNode] | None,
label: Optional[str] = None, label: str | None = None,
start_ts: Optional[float] = None, start_ts: float | None = None,
estimated_total: Optional[int] = None, estimated_total: int | None = None,
*, *,
display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None, display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None,
): ):
""" """
Sleep in 1s slices while: Sleep in 1s slices while:
@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str:
return mime_type.split("/")[-1].lower() return mime_type.split("/")[-1].lower()
def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int: def get_fs_object_size(path_or_object: str | BytesIO) -> int:
if isinstance(path_or_object, str): if isinstance(path_or_object, str):
return os.path.getsize(path_or_object) return os.path.getsize(path_or_object)
return len(path_or_object.getvalue()) return len(path_or_object.getvalue())

View File

@ -4,10 +4,11 @@ import json
import logging import logging
import time import time
import uuid import uuid
from collections.abc import Callable, Iterable
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from io import BytesIO from io import BytesIO
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union from typing import Any, Literal, TypeVar
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
import aiohttp import aiohttp
@ -37,8 +38,8 @@ class ApiEndpoint:
path: str, path: str,
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET", method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
*, *,
query_params: Optional[dict[str, Any]] = None, query_params: dict[str, Any] | None = None,
headers: Optional[dict[str, str]] = None, headers: dict[str, str] | None = None,
): ):
self.path = path self.path = path
self.method = method self.method = method
@ -52,18 +53,18 @@ class _RequestConfig:
endpoint: ApiEndpoint endpoint: ApiEndpoint
timeout: float timeout: float
content_type: str content_type: str
data: Optional[dict[str, Any]] data: dict[str, Any] | None
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] files: dict[str, Any] | list[tuple[str, Any]] | None
multipart_parser: Optional[Callable] multipart_parser: Callable | None
max_retries: int max_retries: int
retry_delay: float retry_delay: float
retry_backoff: float retry_backoff: float
wait_label: str = "Waiting" wait_label: str = "Waiting"
monitor_progress: bool = True monitor_progress: bool = True
estimated_total: Optional[int] = None estimated_total: int | None = None
final_label_on_success: Optional[str] = "Completed" final_label_on_success: str | None = "Completed"
progress_origin_ts: Optional[float] = None progress_origin_ts: float | None = None
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None price_extractor: Callable[[dict[str, Any]], float | None] | None = None
@dataclass @dataclass
@ -71,10 +72,10 @@ class _PollUIState:
started: float started: float
status_label: str = "Queued" status_label: str = "Queued"
is_queued: bool = True is_queued: bool = True
price: Optional[float] = None price: float | None = None
estimated_duration: Optional[int] = None estimated_duration: int | None = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals base_processing_elapsed: float = 0.0 # sum of completed active intervals
active_since: Optional[float] = None # start time of current active interval (None if queued) active_since: float | None = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 429, 500, 502, 503, 504} _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
@ -87,20 +88,20 @@ async def sync_op(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
endpoint: ApiEndpoint, endpoint: ApiEndpoint,
*, *,
response_model: Type[M], response_model: type[M],
price_extractor: Optional[Callable[[M], Optional[float]]] = None, price_extractor: Callable[[M | Any], float | None] | None = None,
data: Optional[BaseModel] = None, data: BaseModel | None = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, files: dict[str, Any] | list[tuple[str, Any]] | None = None,
content_type: str = "application/json", content_type: str = "application/json",
timeout: float = 3600.0, timeout: float = 3600.0,
multipart_parser: Optional[Callable] = None, multipart_parser: Callable | None = None,
max_retries: int = 3, max_retries: int = 3,
retry_delay: float = 1.0, retry_delay: float = 1.0,
retry_backoff: float = 2.0, retry_backoff: float = 2.0,
wait_label: str = "Waiting for server", wait_label: str = "Waiting for server",
estimated_duration: Optional[int] = None, estimated_duration: int | None = None,
final_label_on_success: Optional[str] = "Completed", final_label_on_success: str | None = "Completed",
progress_origin_ts: Optional[float] = None, progress_origin_ts: float | None = None,
monitor_progress: bool = True, monitor_progress: bool = True,
) -> M: ) -> M:
raw = await sync_op_raw( raw = await sync_op_raw(
@ -131,22 +132,22 @@ async def poll_op(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint, poll_endpoint: ApiEndpoint,
*, *,
response_model: Type[M], response_model: type[M],
status_extractor: Callable[[M], Optional[Union[str, int]]], status_extractor: Callable[[M | Any], str | int | None],
progress_extractor: Optional[Callable[[M], Optional[int]]] = None, progress_extractor: Callable[[M | Any], int | None] | None = None,
price_extractor: Optional[Callable[[M], Optional[float]]] = None, price_extractor: Callable[[M | Any], float | None] | None = None,
completed_statuses: Optional[list[Union[str, int]]] = None, completed_statuses: list[str | int] | None = None,
failed_statuses: Optional[list[Union[str, int]]] = None, failed_statuses: list[str | int] | None = None,
queued_statuses: Optional[list[Union[str, int]]] = None, queued_statuses: list[str | int] | None = None,
data: Optional[BaseModel] = None, data: BaseModel | None = None,
poll_interval: float = 5.0, poll_interval: float = 5.0,
max_poll_attempts: int = 120, max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0, timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3, max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0, retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0, retry_backoff_per_poll: float = 2.0,
estimated_duration: Optional[int] = None, estimated_duration: int | None = None,
cancel_endpoint: Optional[ApiEndpoint] = None, cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0, cancel_timeout: float = 10.0,
) -> M: ) -> M:
raw = await poll_op_raw( raw = await poll_op_raw(
@ -178,22 +179,22 @@ async def sync_op_raw(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
endpoint: ApiEndpoint, endpoint: ApiEndpoint,
*, *,
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
data: Optional[Union[dict[str, Any], BaseModel]] = None, data: dict[str, Any] | BaseModel | None = None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None, files: dict[str, Any] | list[tuple[str, Any]] | None = None,
content_type: str = "application/json", content_type: str = "application/json",
timeout: float = 3600.0, timeout: float = 3600.0,
multipart_parser: Optional[Callable] = None, multipart_parser: Callable | None = None,
max_retries: int = 3, max_retries: int = 3,
retry_delay: float = 1.0, retry_delay: float = 1.0,
retry_backoff: float = 2.0, retry_backoff: float = 2.0,
wait_label: str = "Waiting for server", wait_label: str = "Waiting for server",
estimated_duration: Optional[int] = None, estimated_duration: int | None = None,
as_binary: bool = False, as_binary: bool = False,
final_label_on_success: Optional[str] = "Completed", final_label_on_success: str | None = "Completed",
progress_origin_ts: Optional[float] = None, progress_origin_ts: float | None = None,
monitor_progress: bool = True, monitor_progress: bool = True,
) -> Union[dict[str, Any], bytes]: ) -> dict[str, Any] | bytes:
""" """
Make a single network request. Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON). - If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
@ -229,21 +230,21 @@ async def poll_op_raw(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint, poll_endpoint: ApiEndpoint,
*, *,
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]], status_extractor: Callable[[dict[str, Any]], str | int | None],
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None, progress_extractor: Callable[[dict[str, Any]], int | None] | None = None,
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None, price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
completed_statuses: Optional[list[Union[str, int]]] = None, completed_statuses: list[str | int] | None = None,
failed_statuses: Optional[list[Union[str, int]]] = None, failed_statuses: list[str | int] | None = None,
queued_statuses: Optional[list[Union[str, int]]] = None, queued_statuses: list[str | int] | None = None,
data: Optional[Union[dict[str, Any], BaseModel]] = None, data: dict[str, Any] | BaseModel | None = None,
poll_interval: float = 5.0, poll_interval: float = 5.0,
max_poll_attempts: int = 120, max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0, timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3, max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0, retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0, retry_backoff_per_poll: float = 2.0,
estimated_duration: Optional[int] = None, estimated_duration: int | None = None,
cancel_endpoint: Optional[ApiEndpoint] = None, cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0, cancel_timeout: float = 10.0,
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
@ -261,7 +262,7 @@ async def poll_op_raw(
consumed_attempts = 0 # counts only non-queued polls consumed_attempts = 0 # counts only non-queued polls
progress_bar = utils.ProgressBar(100) if progress_extractor else None progress_bar = utils.ProgressBar(100) if progress_extractor else None
last_progress: Optional[int] = None last_progress: int | None = None
state = _PollUIState(started=started, estimated_duration=estimated_duration) state = _PollUIState(started=started, estimated_duration=estimated_duration)
stop_ticker = asyncio.Event() stop_ticker = asyncio.Event()
@ -420,10 +421,10 @@ async def poll_op_raw(
def _display_text( def _display_text(
node_cls: type[IO.ComfyNode], node_cls: type[IO.ComfyNode],
text: Optional[str], text: str | None,
*, *,
status: Optional[Union[str, int]] = None, status: str | int | None = None,
price: Optional[float] = None, price: float | None = None,
) -> None: ) -> None:
display_lines: list[str] = [] display_lines: list[str] = []
if status: if status:
@ -440,13 +441,13 @@ def _display_text(
def _display_time_progress( def _display_time_progress(
node_cls: type[IO.ComfyNode], node_cls: type[IO.ComfyNode],
status: Optional[Union[str, int]], status: str | int | None,
elapsed_seconds: int, elapsed_seconds: int,
estimated_total: Optional[int] = None, estimated_total: int | None = None,
*, *,
price: Optional[float] = None, price: float | None = None,
is_queued: Optional[bool] = None, is_queued: bool | None = None,
processing_elapsed_seconds: Optional[int] = None, processing_elapsed_seconds: int | None = None,
) -> None: ) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False: if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
raise ValueError("files tuple must be (filename, file[, content_type])") raise ValueError("files tuple must be (filename, file[, content_type])")
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]: def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
params = dict(endpoint_params or {}) params = dict(endpoint_params or {})
if method.upper() == "GET" and data: if method.upper() == "GET" and data:
for k, v in data.items(): for k, v in data.items():
@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str:
def _snapshot_request_body_for_logging( def _snapshot_request_body_for_logging(
content_type: str, content_type: str,
method: str, method: str,
data: Optional[dict[str, Any]], data: dict[str, Any] | None,
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]], files: dict[str, Any] | list[tuple[str, Any]] | None,
) -> Optional[Union[dict[str, Any], str]]: ) -> dict[str, Any] | str | None:
if method.upper() == "GET": if method.upper() == "GET":
return None return None
if content_type == "multipart/form-data": if content_type == "multipart/form-data":
@ -586,13 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
attempt = 0 attempt = 0
delay = cfg.retry_delay delay = cfg.retry_delay
operation_succeeded: bool = False operation_succeeded: bool = False
final_elapsed_seconds: Optional[int] = None final_elapsed_seconds: int | None = None
extracted_price: Optional[float] = None extracted_price: float | None = None
while True: while True:
attempt += 1 attempt += 1
stop_event = asyncio.Event() stop_event = asyncio.Event()
monitor_task: Optional[asyncio.Task] = None monitor_task: asyncio.Task | None = None
sess: Optional[aiohttp.ClientSession] = None sess: aiohttp.ClientSession | None = None
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
@ -887,7 +888,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
) )
def _validate_or_raise(response_model: Type[M], payload: Any) -> M: def _validate_or_raise(response_model: type[M], payload: Any) -> M:
try: try:
return response_model.model_validate(payload) return response_model.model_validate(payload)
except Exception as e: except Exception as e:
@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
def _wrap_model_extractor( def _wrap_model_extractor(
response_model: Type[M], response_model: type[M],
extractor: Optional[Callable[[M], Any]], extractor: Callable[[M], Any] | None,
) -> Optional[Callable[[dict[str, Any]], Any]]: ) -> Callable[[dict[str, Any]], Any] | None:
"""Wrap a typed extractor so it can be used by the dict-based poller. """Wrap a typed extractor so it can be used by the dict-based poller.
Validates the dict into `response_model` before invoking `extractor`. Validates the dict into `response_model` before invoking `extractor`.
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
@ -929,10 +930,10 @@ def _wrap_model_extractor(
return _wrapped return _wrapped
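The wrapper described above reduces to a simple pattern: validate the raw dict into the typed model once, remember the result keyed by id() of the dict, and hand the typed object to the caller's extractor. A generic, self-contained sketch of that pattern (simplified, with an unbounded cache; names here are illustrative, not the module's internals):

from collections.abc import Callable
from typing import Any, TypeVar

from pydantic import BaseModel

M = TypeVar("M", bound=BaseModel)


def wrap_model_extractor(model: type[M], extractor: Callable[[M], Any]) -> Callable[[dict], Any]:
    cache: dict[int, M] = {}  # keyed by id(payload) so the same dict is validated only once

    def wrapped(payload: dict) -> Any:
        key = id(payload)
        if key not in cache:
            cache[key] = model.model_validate(payload)
        return extractor(cache[key])

    return wrapped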
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]: def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]:
if not values: if not values:
return set() return set()
out: set[Union[str, int]] = set() out: set[str | int] = set()
for v in values: for v in values:
nv = _normalize_status_value(v) nv = _normalize_status_value(v)
if nv is not None: if nv is not None:
@ -940,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Unio
return out return out
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]: def _normalize_status_value(val: str | int | None) -> str | int | None:
if isinstance(val, str): if isinstance(val, str):
return val.strip().lower() return val.strip().lower()
return val return val

View File

@ -4,7 +4,6 @@ import math
import mimetypes import mimetypes
import uuid import uuid
from io import BytesIO from io import BytesIO
from typing import Optional
import av import av
import numpy as np import numpy as np
@ -12,8 +11,7 @@ import torch
from PIL import Image from PIL import Image
from comfy.utils import common_upscale from comfy.utils import common_upscale
from comfy_api.latest import Input, InputImpl from comfy_api.latest import Input, InputImpl, Types
from comfy_api.util import VideoCodec, VideoContainer
from ._helpers import mimetype_to_extension from ._helpers import mimetype_to_extension
@ -57,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
def tensor_to_bytesio( def tensor_to_bytesio(
image: torch.Tensor, image: torch.Tensor,
name: Optional[str] = None, name: str | None = None,
total_pixels: int = 2048 * 2048, total_pixels: int = 2048 * 2048,
mime_type: str = "image/png", mime_type: str = "image/png",
) -> BytesIO: ) -> BytesIO:
@ -177,8 +175,8 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co
def video_to_base64_string( def video_to_base64_string(
video: Input.Video, video: Input.Video,
container_format: VideoContainer = None, container_format: Types.VideoContainer | None = None,
codec: VideoCodec = None codec: Types.VideoCodec | None = None,
) -> str: ) -> str:
""" """
Converts a video input to a base64 string. Converts a video input to a base64 string.
@ -189,12 +187,11 @@ def video_to_base64_string(
codec: Optional codec to use (defaults to video.codec if available) codec: Optional codec to use (defaults to video.codec if available)
""" """
video_bytes_io = BytesIO() video_bytes_io = BytesIO()
video.save_to(
# Use provided format/codec if specified, otherwise use video's own if available video_bytes_io,
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) format=container_format or getattr(video, "container", Types.VideoContainer.MP4),
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264) codec=codec or getattr(video, "codec", Types.VideoCodec.H264),
)
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
video_bytes_io.seek(0) video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")

View File

@ -3,15 +3,15 @@ import contextlib
import uuid import uuid
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import IO, Optional, Union from typing import IO
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
import aiohttp import aiohttp
import torch import torch
from aiohttp.client_exceptions import ClientError, ContentTypeError from aiohttp.client_exceptions import ClientError, ContentTypeError
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO as COMFY_IO from comfy_api.latest import IO as COMFY_IO
from comfy_api.latest import InputImpl
from . import request_logger from . import request_logger
from ._helpers import ( from ._helpers import (
@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
async def download_url_to_bytesio( async def download_url_to_bytesio(
url: str, url: str,
dest: Optional[Union[BytesIO, IO[bytes], str, Path]], dest: BytesIO | IO[bytes] | str | Path | None,
*, *,
timeout: Optional[float] = None, timeout: float | None = None,
max_retries: int = 5, max_retries: int = 5,
retry_delay: float = 1.0, retry_delay: float = 1.0,
retry_backoff: float = 2.0, retry_backoff: float = 2.0,
@ -71,10 +71,10 @@ async def download_url_to_bytesio(
is_path_sink = isinstance(dest, (str, Path)) is_path_sink = isinstance(dest, (str, Path))
fhandle = None fhandle = None
session: Optional[aiohttp.ClientSession] = None session: aiohttp.ClientSession | None = None
stop_evt: Optional[asyncio.Event] = None stop_evt: asyncio.Event | None = None
monitor_task: Optional[asyncio.Task] = None monitor_task: asyncio.Task | None = None
req_task: Optional[asyncio.Task] = None req_task: asyncio.Task | None = None
try: try:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
@ -234,11 +234,11 @@ async def download_url_to_video_output(
timeout: float = None, timeout: float = None,
max_retries: int = 5, max_retries: int = 5,
cls: type[COMFY_IO.ComfyNode] = None, cls: type[COMFY_IO.ComfyNode] = None,
) -> VideoFromFile: ) -> InputImpl.VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.""" """Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO() result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls) await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
return VideoFromFile(result) return InputImpl.VideoFromFile(result)
async def download_url_as_bytesio( async def download_url_as_bytesio(

View File

@ -1,5 +1,3 @@
from __future__ import annotations
import datetime import datetime
import hashlib import hashlib
import json import json

View File

@ -4,15 +4,13 @@ import logging
import time import time
import uuid import uuid
from io import BytesIO from io import BytesIO
from typing import Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import aiohttp import aiohttp
import torch import torch
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from comfy_api.latest import IO, Input from comfy_api.latest import IO, Input, Types
from comfy_api.util import VideoCodec, VideoContainer
from . import request_logger from . import request_logger
from ._helpers import is_processing_interrupted, sleep_with_interrupt from ._helpers import is_processing_interrupted, sleep_with_interrupt
@ -32,7 +30,7 @@ from .conversions import (
class UploadRequest(BaseModel): class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload") file_name: str = Field(..., description="Filename to upload")
content_type: Optional[str] = Field( content_type: str | None = Field(
None, None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
) )
@ -56,7 +54,7 @@ async def upload_images_to_comfyapi(
Uploads images to ComfyUI API and returns download URLs. Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first. To upload multiple images, stack them in the batch dimension first.
""" """
# if batch, try to upload each file if max_images is greater than 0 # if batched, try to upload each file if max_images is greater than 0
download_urls: list[str] = [] download_urls: list[str] = []
is_batch = len(image.shape) > 3 is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1 batch_len = image.shape[0] if is_batch else 1
@ -100,9 +98,10 @@ async def upload_video_to_comfyapi(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
video: Input.Video, video: Input.Video,
*, *,
container: VideoContainer = VideoContainer.MP4, container: Types.VideoContainer = Types.VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264, codec: Types.VideoCodec = Types.VideoCodec.H264,
max_duration: Optional[int] = None, max_duration: int | None = None,
wait_label: str | None = "Uploading",
) -> str: ) -> str:
""" """
Uploads a single video to ComfyUI API and returns its download URL. Uploads a single video to ComfyUI API and returns its download URL.
@ -127,7 +126,7 @@ async def upload_video_to_comfyapi(
video.save_to(video_bytes_io, format=container, codec=codec) video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0) video_bytes_io.seek(0)
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type) return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
async def upload_file_to_comfyapi( async def upload_file_to_comfyapi(
@ -219,7 +218,7 @@ async def upload_file(
return return
monitor_task = asyncio.create_task(_monitor()) monitor_task = asyncio.create_task(_monitor())
sess: Optional[aiohttp.ClientSession] = None sess: aiohttp.ClientSession | None = None
try: try:
try: try:
request_logger.log_request_response( request_logger.log_request_response(

View File

@ -1,9 +1,7 @@
import logging import logging
from typing import Optional
import torch import torch
from comfy_api.input.video_types import VideoInput
from comfy_api.latest import Input from comfy_api.latest import Input
@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
def validate_image_dimensions( def validate_image_dimensions(
image: torch.Tensor, image: torch.Tensor,
min_width: Optional[int] = None, min_width: int | None = None,
max_width: Optional[int] = None, max_width: int | None = None,
min_height: Optional[int] = None, min_height: int | None = None,
max_height: Optional[int] = None, max_height: int | None = None,
): ):
height, width = get_image_dimensions(image) height, width = get_image_dimensions(image)
@ -37,8 +35,8 @@ def validate_image_dimensions(
def validate_image_aspect_ratio( def validate_image_aspect_ratio(
image: torch.Tensor, image: torch.Tensor,
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
*, *,
strict: bool = True, # True -> (min, max); False -> [min, max] strict: bool = True, # True -> (min, max); False -> [min, max]
) -> float: ) -> float:
@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness(
def validate_aspect_ratio_string( def validate_aspect_ratio_string(
aspect_ratio: str, aspect_ratio: str,
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4) min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1) max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
*, *,
strict: bool = False, # True -> (min, max); False -> [min, max] strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float: ) -> float:
@ -97,10 +95,10 @@ def validate_aspect_ratio_string(
def validate_video_dimensions( def validate_video_dimensions(
video: Input.Video, video: Input.Video,
min_width: Optional[int] = None, min_width: int | None = None,
max_width: Optional[int] = None, max_width: int | None = None,
min_height: Optional[int] = None, min_height: int | None = None,
max_height: Optional[int] = None, max_height: int | None = None,
): ):
try: try:
width, height = video.get_dimensions() width, height = video.get_dimensions()
@ -120,8 +118,8 @@ def validate_video_dimensions(
def validate_video_duration( def validate_video_duration(
video: Input.Video, video: Input.Video,
min_duration: Optional[float] = None, min_duration: float | None = None,
max_duration: Optional[float] = None, max_duration: float | None = None,
): ):
try: try:
duration = video.get_duration() duration = video.get_duration()
@ -136,6 +134,23 @@ def validate_video_duration(
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s") raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
def validate_video_frame_count(
video: Input.Video,
min_frame_count: int | None = None,
max_frame_count: int | None = None,
):
try:
frame_count = video.get_frame_count()
except Exception as e:
logging.error("Error getting frame count of video: %s", e)
return
if min_frame_count is not None and min_frame_count > frame_count:
raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}")
if max_frame_count is not None and frame_count > max_frame_count:
raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}")
def get_number_of_images(images): def get_number_of_images(images):
if isinstance(images, torch.Tensor): if isinstance(images, torch.Tensor):
return images.shape[0] if images.ndim >= 4 else 1 return images.shape[0] if images.ndim >= 4 else 1
@ -144,8 +159,8 @@ def get_number_of_images(images):
def validate_audio_duration( def validate_audio_duration(
audio: Input.Audio, audio: Input.Audio,
min_duration: Optional[float] = None, min_duration: float | None = None,
max_duration: Optional[float] = None, max_duration: float | None = None,
) -> None: ) -> None:
sr = int(audio["sample_rate"]) sr = int(audio["sample_rate"])
dur = int(audio["waveform"].shape[-1]) / sr dur = int(audio["waveform"].shape[-1]) / sr
@ -177,7 +192,7 @@ def validate_string(
) )
def validate_container_format_is_mp4(video: VideoInput) -> None: def validate_container_format_is_mp4(video: Input.Video) -> None:
"""Validates video container format is MP4.""" """Validates video container format is MP4."""
container_format = video.get_container_format() container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float:
def _assert_ratio_bounds( def _assert_ratio_bounds(
ar: float, ar: float,
*, *,
min_ratio: Optional[tuple[float, float]] = None, min_ratio: tuple[float, float] | None = None,
max_ratio: Optional[tuple[float, float]] = None, max_ratio: tuple[float, float] | None = None,
strict: bool = True, strict: bool = True,
) -> None: ) -> None:
"""Validate a numeric aspect ratio against optional min/max ratio bounds.""" """Validate a numeric aspect ratio against optional min/max ratio bounds."""

View File

@ -1,4 +1,5 @@
from __future__ import annotations from __future__ import annotations
from comfy_api.latest import IO
def validate_node_input( def validate_node_input(
@ -23,6 +24,11 @@ def validate_node_input(
if not received_type != input_type: if not received_type != input_type:
return True return True
# If the received type or input_type is a MatchType, we can return True immediately;
# validation for this is handled by the frontend
if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
return True
# Not equal, and not strings # Not equal, and not strings
if not isinstance(received_type, str) or not isinstance(input_type, str): if not isinstance(received_type, str) or not isinstance(input_type, str):
return False return False

View File

@ -6,65 +6,80 @@ import torch
import comfy.model_management import comfy.model_management
import folder_paths import folder_paths
import os import os
import io
import json
import random
import hashlib import hashlib
import node_helpers import node_helpers
import logging import logging
from comfy.cli_args import args from typing_extensions import override
from comfy.comfy_types import FileLocator from comfy_api.latest import ComfyExtension, IO, UI
class EmptyLatentAudio: class EmptyLatentAudio(IO.ComfyNode):
def __init__(self): @classmethod
self.device = comfy.model_management.intermediate_device() def define_schema(cls):
return IO.Schema(
node_id="EmptyLatentAudio",
display_name="Empty Latent Audio",
category="latent/audio",
inputs=[
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
IO.Int.Input(
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
),
],
outputs=[IO.Latent.Output()],
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, seconds, batch_size) -> IO.NodeOutput:
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/audio"
def generate(self, seconds, batch_size):
length = round((seconds * 44100 / 2048) / 2) * 2 length = round((seconds * 44100 / 2048) / 2) * 2
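# Worked example: seconds=47.6 -> 47.6 * 44100 / 2048 ~= 1024.98 -> /2 ~= 512.49 -> round -> 512 -> *2 = 1024 latent samples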
latent = torch.zeros([batch_size, 64, length], device=self.device) latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
return ({"samples":latent, "type": "audio"}, ) return IO.NodeOutput({"samples":latent, "type": "audio"})
class ConditioningStableAudio: generate = execute # TODO: remove
class ConditioningStableAudio(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"positive": ("CONDITIONING", ), return IO.Schema(
"negative": ("CONDITIONING", ), node_id="ConditioningStableAudio",
"seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}), category="conditioning",
"seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}), inputs=[
}} IO.Conditioning.Input("positive"),
IO.Conditioning.Input("negative"),
IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
],
)
RETURN_TYPES = ("CONDITIONING","CONDITIONING") @classmethod
RETURN_NAMES = ("positive", "negative") def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput:
FUNCTION = "append"
CATEGORY = "conditioning"
def append(self, positive, negative, seconds_start, seconds_total):
positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total}) positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total}) negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
return (positive, negative) return IO.NodeOutput(positive, negative)
class VAEEncodeAudio: append = execute # TODO: remove
class VAEEncodeAudio(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}} return IO.Schema(
RETURN_TYPES = ("LATENT",) node_id="VAEEncodeAudio",
FUNCTION = "encode" display_name="VAE Encode Audio",
category="latent/audio",
inputs=[
IO.Audio.Input("audio"),
IO.Vae.Input("vae"),
],
outputs=[IO.Latent.Output()],
)
CATEGORY = "latent/audio" @classmethod
def execute(cls, vae, audio) -> IO.NodeOutput:
def encode(self, vae, audio):
sample_rate = audio["sample_rate"] sample_rate = audio["sample_rate"]
if 44100 != sample_rate: if 44100 != sample_rate:
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100) waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
@ -72,213 +87,134 @@ class VAEEncodeAudio:
waveform = audio["waveform"] waveform = audio["waveform"]
t = vae.encode(waveform.movedim(1, -1)) t = vae.encode(waveform.movedim(1, -1))
return ({"samples":t}, ) return IO.NodeOutput({"samples":t})
class VAEDecodeAudio: encode = execute # TODO: remove
class VAEDecodeAudio(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}} return IO.Schema(
RETURN_TYPES = ("AUDIO",) node_id="VAEDecodeAudio",
FUNCTION = "decode" display_name="VAE Decode Audio",
category="latent/audio",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae"),
],
outputs=[IO.Audio.Output()],
)
CATEGORY = "latent/audio" @classmethod
def execute(cls, vae, samples) -> IO.NodeOutput:
def decode(self, vae, samples):
audio = vae.decode(samples["samples"]).movedim(-1, 1) audio = vae.decode(samples["samples"]).movedim(-1, 1)
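# The next three lines attenuate loud outputs: dividing by 5x the per-item std (floored at 1.0)
# pulls hot decodes down to roughly 0.2 std while leaving already-quiet audio unchanged.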
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0 std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0 std[std < 1.0] = 1.0
audio /= std audio /= std
return ({"waveform": audio, "sample_rate": 44100}, ) return IO.NodeOutput({"waveform": audio, "sample_rate": 44100})
decode = execute # TODO: remove
def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"): class SaveAudio(IO.ComfyNode):
@classmethod
filename_prefix += self.prefix_append def define_schema(cls):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) return IO.Schema(
results: list[FileLocator] = [] node_id="SaveAudio",
display_name="Save Audio (FLAC)",
# Prepare metadata dictionary category="audio",
metadata = {} inputs=[
if not args.disable_metadata: IO.Audio.Input("audio"),
if prompt is not None: IO.String.Input("filename_prefix", default="audio/ComfyUI"),
metadata["prompt"] = json.dumps(prompt) ],
if extra_pnginfo is not None: hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
for x in extra_pnginfo: is_output_node=True,
metadata[x] = json.dumps(extra_pnginfo[x]) )
# Opus supported sample rates
OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
output_path = os.path.join(full_output_folder, file)
# Use original sample rate initially
sample_rate = audio["sample_rate"]
# Handle Opus sample rate requirements
if format == "opus":
if sample_rate > 48000:
sample_rate = 48000
elif sample_rate not in OPUS_RATES:
# Find the next highest supported rate
for rate in sorted(OPUS_RATES):
if rate > sample_rate:
sample_rate = rate
break
if sample_rate not in OPUS_RATES: # Fallback if still not supported
sample_rate = 48000
# Resample if necessary
if sample_rate != audio["sample_rate"]:
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
# Create output with specified format
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format=format)
# Set metadata on the container
for key, value in metadata.items():
output_container.metadata[key] = value
layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
# Set up the output stream with appropriate properties
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
out_stream.bit_rate = 96000
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "192k":
out_stream.bit_rate = 192000
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
if quality == "V0":
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "320k":
out_stream.bit_rate = 320000
else: #format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
frame.sample_rate = sample_rate
frame.pts = 0
output_container.mux(out_stream.encode(frame))
# Flush encoder
output_container.mux(out_stream.encode(None))
# Close containers
output_container.close()
# Write the output to file
output_buffer.seek(0)
with open(output_path, 'wb') as f:
f.write(output_buffer.getbuffer())
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return { "ui": { "audio": results } }
class SaveAudio:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
return {"required": { "audio": ("AUDIO", ), return IO.NodeOutput(
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
}, )
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = () save_flac = execute # TODO: remove
FUNCTION = "save_flac"
OUTPUT_NODE = True
CATEGORY = "audio" class SaveAudioMP3(IO.ComfyNode):
@classmethod
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None): def define_schema(cls):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo) return IO.Schema(
node_id="SaveAudioMP3",
class SaveAudioMP3: display_name="Save Audio (MP3)",
def __init__(self): category="audio",
self.output_dir = folder_paths.get_output_directory() inputs=[
self.type = "output" IO.Audio.Input("audio"),
self.prefix_append = "" IO.String.Input("filename_prefix", default="audio/ComfyUI"),
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
return {"required": { "audio": ("AUDIO", ), return IO.NodeOutput(
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), ui=UI.AudioSaveHelper.get_save_audio_ui(
"quality": (["V0", "128k", "320k"], {"default": "V0"}), audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
}, )
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, )
}
RETURN_TYPES = () save_mp3 = execute # TODO: remove
FUNCTION = "save_mp3"
OUTPUT_NODE = True
CATEGORY = "audio" class SaveAudioOpus(IO.ComfyNode):
@classmethod
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"): def define_schema(cls):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality) return IO.Schema(
node_id="SaveAudioOpus",
class SaveAudioOpus: display_name="Save Audio (Opus)",
def __init__(self): category="audio",
self.output_dir = folder_paths.get_output_directory() inputs=[
self.type = "output" IO.Audio.Input("audio"),
self.prefix_append = "" IO.String.Input("filename_prefix", default="audio/ComfyUI"),
IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
return {"required": { "audio": ("AUDIO", ), return IO.NodeOutput(
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), ui=UI.AudioSaveHelper.get_save_audio_ui(
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}), audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
}, )
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, )
}
RETURN_TYPES = () save_opus = execute # TODO: remove
FUNCTION = "save_opus"
OUTPUT_NODE = True
CATEGORY = "audio" class PreviewAudio(IO.ComfyNode):
@classmethod
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"): def define_schema(cls):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality) return IO.Schema(
node_id="PreviewAudio",
class PreviewAudio(SaveAudio): display_name="Preview Audio",
def __init__(self): category="audio",
self.output_dir = folder_paths.get_temp_directory() inputs=[
self.type = "temp" IO.Audio.Input("audio"),
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) ],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod @classmethod
def INPUT_TYPES(s): def execute(cls, audio) -> IO.NodeOutput:
return {"required": return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
{"audio": ("AUDIO", ), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, save_flac = execute # TODO: remove
}
def f32_pcm(wav: torch.Tensor) -> torch.Tensor: def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format.""" """Convert audio to float 32 bits PCM format."""
@ -316,26 +252,30 @@ def load(filepath: str) -> tuple[torch.Tensor, int]:
wav = f32_pcm(wav) wav = f32_pcm(wav)
return wav, sr return wav, sr
class LoadAudio: class LoadAudio(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
input_dir = folder_paths.get_input_directory() input_dir = folder_paths.get_input_directory()
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]) files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
return {"required": {"audio": (sorted(files), {"audio_upload": True})}} return IO.Schema(
node_id="LoadAudio",
display_name="Load Audio",
category="audio",
inputs=[
IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
],
outputs=[IO.Audio.Output()],
)
CATEGORY = "audio" @classmethod
def execute(cls, audio) -> IO.NodeOutput:
RETURN_TYPES = ("AUDIO", )
FUNCTION = "load"
def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio) audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = load(audio_path) waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, ) return IO.NodeOutput(audio)
@classmethod @classmethod
def IS_CHANGED(s, audio): def fingerprint_inputs(cls, audio):
image_path = folder_paths.get_annotated_filepath(audio) image_path = folder_paths.get_annotated_filepath(audio)
m = hashlib.sha256() m = hashlib.sha256()
with open(image_path, 'rb') as f: with open(image_path, 'rb') as f:
@ -343,46 +283,69 @@ class LoadAudio:
return m.digest().hex() return m.digest().hex()
@classmethod @classmethod
def VALIDATE_INPUTS(s, audio): def validate_inputs(cls, audio):
if not folder_paths.exists_annotated_filepath(audio): if not folder_paths.exists_annotated_filepath(audio):
return "Invalid audio file: {}".format(audio) return "Invalid audio file: {}".format(audio)
return True return True
class RecordAudio: load = execute # TODO: remove
class RecordAudio(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": {"audio": ("AUDIO_RECORD", {})}} return IO.Schema(
node_id="RecordAudio",
display_name="Record Audio",
category="audio",
inputs=[
IO.Custom("AUDIO_RECORD").Input("audio"),
],
outputs=[IO.Audio.Output()],
)
CATEGORY = "audio" @classmethod
def execute(cls, audio) -> IO.NodeOutput:
RETURN_TYPES = ("AUDIO", )
FUNCTION = "load"
def load(self, audio):
audio_path = folder_paths.get_annotated_filepath(audio) audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = load(audio_path) waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, ) return IO.NodeOutput(audio)
load = execute # TODO: remove
class TrimAudioDuration: class TrimAudioDuration(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(cls): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="TrimAudioDuration",
"audio": ("AUDIO",), display_name="Trim Audio Duration",
"start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}), description="Trim audio tensor into chosen time range.",
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}), category="audio",
}, inputs=[
} IO.Audio.Input("audio"),
IO.Float.Input(
"start_index",
default=0.0,
min=-0xffffffffffffffff,
max=0xffffffffffffffff,
step=0.01,
tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).",
),
IO.Float.Input(
"duration",
default=60.0,
min=0.0,
step=0.01,
tooltip="Duration in seconds",
),
],
outputs=[IO.Audio.Output()],
)
FUNCTION = "trim" @classmethod
RETURN_TYPES = ("AUDIO",) def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
CATEGORY = "audio"
DESCRIPTION = "Trim audio tensor into chosen time range."
def trim(self, audio, start_index, duration):
waveform = audio["waveform"] waveform = audio["waveform"]
sample_rate = audio["sample_rate"] sample_rate = audio["sample_rate"]
audio_length = waveform.shape[-1] audio_length = waveform.shape[-1]
@ -399,23 +362,30 @@ class TrimAudioDuration:
if start_frame >= end_frame: if start_frame >= end_frame:
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.") raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},) return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
trim = execute # TODO: remove
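The hunk above elides the conversion from start_index and duration (seconds) to frame indices. A minimal standalone sketch of one way that conversion can work, assuming the AUDIO dict layout {"waveform": [batch, channels, samples], "sample_rate": int} used throughout this file; the clamping details are an assumption, only the final slice and the error message come from the diff:

import torch

def trim_audio(audio: dict, start_index: float, duration: float) -> dict:
    waveform = audio["waveform"]                # [batch, channels, samples]
    sample_rate = audio["sample_rate"]
    total = waveform.shape[-1]

    start_frame = int(round(start_index * sample_rate))
    if start_frame < 0:                         # negative start counts from the end
        start_frame += total
    start_frame = max(0, min(start_frame, total))
    end_frame = min(total, start_frame + int(round(duration * sample_rate)))

    if start_frame >= end_frame:
        raise ValueError("Start time must be less than end time and be within the audio length.")
    return {"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate}

# Keep 2 seconds starting 0.5 s before the end of a 3 s clip (clamped to the last 0.5 s).
clip = {"waveform": torch.zeros(1, 2, 3 * 44100), "sample_rate": 44100}
tail = trim_audio(clip, start_index=-0.5, duration=2.0)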
class SplitAudioChannels: class SplitAudioChannels(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return IO.Schema(
"audio": ("AUDIO",), node_id="SplitAudioChannels",
}} display_name="Split Audio Channels",
description="Separates the audio into left and right channels.",
category="audio",
inputs=[
IO.Audio.Input("audio"),
],
outputs=[
IO.Audio.Output(display_name="left"),
IO.Audio.Output(display_name="right"),
],
)
RETURN_TYPES = ("AUDIO", "AUDIO") @classmethod
RETURN_NAMES = ("left", "right") def execute(cls, audio) -> IO.NodeOutput:
FUNCTION = "separate"
CATEGORY = "audio"
DESCRIPTION = "Separates the audio into left and right channels."
def separate(self, audio):
waveform = audio["waveform"] waveform = audio["waveform"]
sample_rate = audio["sample_rate"] sample_rate = audio["sample_rate"]
@ -425,7 +395,9 @@ class SplitAudioChannels:
left_channel = waveform[..., 0:1, :] left_channel = waveform[..., 0:1, :]
right_channel = waveform[..., 1:2, :] right_channel = waveform[..., 1:2, :]
return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate}) return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
separate = execute # TODO: remove
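SplitAudioChannels slices the channel axis of the [batch, channels, samples] waveform; slicing with 0:1 and 1:2 (rather than indexing with 0 and 1) keeps the channel dimension, so each output is still a valid mono AUDIO tensor. A small torch illustration:

import torch

stereo = {"waveform": torch.randn(1, 2, 48000), "sample_rate": 48000}

left  = {"waveform": stereo["waveform"][..., 0:1, :], "sample_rate": stereo["sample_rate"]}
right = {"waveform": stereo["waveform"][..., 1:2, :], "sample_rate": stereo["sample_rate"]}
print(left["waveform"].shape)   # torch.Size([1, 1, 48000]) -- channel axis preserved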
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2): def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
@ -443,21 +415,29 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_
return waveform_1, waveform_2, output_sample_rate return waveform_1, waveform_2, output_sample_rate
class AudioConcat: class AudioConcat(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return IO.Schema(
"audio1": ("AUDIO",), node_id="AudioConcat",
"audio2": ("AUDIO",), display_name="Audio Concat",
"direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}), description="Concatenates the audio1 to audio2 in the specified direction.",
}} category="audio",
inputs=[
IO.Audio.Input("audio1"),
IO.Audio.Input("audio2"),
IO.Combo.Input(
"direction",
options=['after', 'before'],
default="after",
tooltip="Whether to append audio2 after or before audio1.",
)
],
outputs=[IO.Audio.Output()],
)
RETURN_TYPES = ("AUDIO",) @classmethod
FUNCTION = "concat" def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
CATEGORY = "audio"
DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction."
def concat(self, audio1, audio2, direction):
waveform_1 = audio1["waveform"] waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"] waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"] sample_rate_1 = audio1["sample_rate"]
@ -477,26 +457,33 @@ class AudioConcat:
elif direction == 'before': elif direction == 'before':
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2) concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},) return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate})
concat = execute # TODO: remove
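AudioConcat joins the two waveforms along the time axis (dim=2 of [batch, channels, samples]) after match_audio_sample_rates has brought them to a common rate; 'before' simply swaps the operands. A short torch sketch, assuming equal sample rates and channel counts:

import torch

a = torch.randn(1, 2, 44100)            # 1.0 s of stereo at 44.1 kHz
b = torch.randn(1, 2, 22050)            # 0.5 s at the same rate

after  = torch.cat((a, b), dim=2)       # audio2 appended after audio1 -> 1.5 s
before = torch.cat((b, a), dim=2)       # direction == 'before'
print(after.shape)                      # torch.Size([1, 2, 66150])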
class AudioMerge: class AudioMerge(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(cls): def define_schema(cls):
return { return IO.Schema(
"required": { node_id="AudioMerge",
"audio1": ("AUDIO",), display_name="Audio Merge",
"audio2": ("AUDIO",), description="Combine two audio tracks by overlaying their waveforms.",
"merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}), category="audio",
}, inputs=[
} IO.Audio.Input("audio1"),
IO.Audio.Input("audio2"),
IO.Combo.Input(
"merge_method",
options=["add", "mean", "subtract", "multiply"],
tooltip="The method used to combine the audio waveforms.",
)
],
outputs=[IO.Audio.Output()],
)
FUNCTION = "merge" @classmethod
RETURN_TYPES = ("AUDIO",) def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
CATEGORY = "audio"
DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
def merge(self, audio1, audio2, merge_method):
waveform_1 = audio1["waveform"] waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"] waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"] sample_rate_1 = audio1["sample_rate"]
@ -530,85 +517,108 @@ class AudioMerge:
if max_val > 1.0: if max_val > 1.0:
waveform = waveform / max_val waveform = waveform / max_val
return ({"waveform": waveform, "sample_rate": output_sample_rate},) return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate})
merge = execute # TODO: remove
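AudioMerge overlays the two waveforms sample-by-sample; the hunk above elides the resampling/padding that makes the shapes match, but the combine step and the peak normalisation (divide by max_val when it exceeds 1.0) are visible. A sketch of just that step, assuming equal-shape inputs and that max_val is the absolute peak:

import torch

def merge_waveforms(w1: torch.Tensor, w2: torch.Tensor, method: str) -> torch.Tensor:
    if method == "add":
        out = w1 + w2
    elif method == "mean":
        out = (w1 + w2) / 2.0
    elif method == "subtract":
        out = w1 - w2
    elif method == "multiply":
        out = w1 * w2
    else:
        raise ValueError(f"Unknown merge_method: {method}")

    peak = out.abs().max()              # assumed meaning of max_val in the hunk above
    if peak > 1.0:                      # keep the result inside the [-1, 1] float PCM range
        out = out / peak
    return out

mixed = merge_waveforms(torch.randn(1, 2, 44100), torch.randn(1, 2, 44100), "mean")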
class AudioAdjustVolume: class AudioAdjustVolume(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return IO.Schema(
"audio": ("AUDIO",), node_id="AudioAdjustVolume",
"volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}), display_name="Audio Adjust Volume",
}} category="audio",
inputs=[
IO.Audio.Input("audio"),
IO.Int.Input(
"volume",
default=1,
min=-100,
max=100,
tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc",
)
],
outputs=[IO.Audio.Output()],
)
RETURN_TYPES = ("AUDIO",) @classmethod
FUNCTION = "adjust_volume" def execute(cls, audio, volume) -> IO.NodeOutput:
CATEGORY = "audio"
def adjust_volume(self, audio, volume):
if volume == 0: if volume == 0:
return (audio,) return IO.NodeOutput(audio)
waveform = audio["waveform"] waveform = audio["waveform"]
sample_rate = audio["sample_rate"] sample_rate = audio["sample_rate"]
gain = 10 ** (volume / 20) gain = 10 ** (volume / 20)
waveform = waveform * gain waveform = waveform * gain
return ({"waveform": waveform, "sample_rate": sample_rate},) return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
adjust_volume = execute # TODO: remove
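AudioAdjustVolume converts the dB value to a linear gain with gain = 10 ** (volume / 20), so +6 dB roughly doubles the amplitude and -6 dB roughly halves it. A quick check:

for db in (0, 6, -6, 20):
    print(db, 10 ** (db / 20))          # 1.0, 1.995..., 0.501..., 10.0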
class EmptyAudio: class EmptyAudio(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return IO.Schema(
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}), node_id="EmptyAudio",
"sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}), display_name="Empty Audio",
"channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}), category="audio",
}} inputs=[
IO.Float.Input(
"duration",
default=60.0,
min=0.0,
max=0xffffffffffffffff,
step=0.01,
tooltip="Duration of the empty audio clip in seconds",
),
IO.Int.Input(
"sample_rate",
default=44100,
tooltip="Sample rate of the empty audio clip.",
),
IO.Int.Input(
"channels",
default=2,
min=1,
max=2,
tooltip="Number of audio channels (1 for mono, 2 for stereo).",
),
],
outputs=[IO.Audio.Output()],
)
RETURN_TYPES = ("AUDIO",) @classmethod
FUNCTION = "create_empty_audio" def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput:
CATEGORY = "audio"
def create_empty_audio(self, duration, sample_rate, channels):
num_samples = int(round(duration * sample_rate)) num_samples = int(round(duration * sample_rate))
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32) waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
return ({"waveform": waveform, "sample_rate": sample_rate},) return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
create_empty_audio = execute # TODO: remove
NODE_CLASS_MAPPINGS = { class AudioExtension(ComfyExtension):
"EmptyLatentAudio": EmptyLatentAudio, @override
"VAEEncodeAudio": VAEEncodeAudio, async def get_node_list(self) -> list[type[IO.ComfyNode]]:
"VAEDecodeAudio": VAEDecodeAudio, return [
"SaveAudio": SaveAudio, EmptyLatentAudio,
"SaveAudioMP3": SaveAudioMP3, VAEEncodeAudio,
"SaveAudioOpus": SaveAudioOpus, VAEDecodeAudio,
"LoadAudio": LoadAudio, SaveAudio,
"PreviewAudio": PreviewAudio, SaveAudioMP3,
"ConditioningStableAudio": ConditioningStableAudio, SaveAudioOpus,
"RecordAudio": RecordAudio, LoadAudio,
"TrimAudioDuration": TrimAudioDuration, PreviewAudio,
"SplitAudioChannels": SplitAudioChannels, ConditioningStableAudio,
"AudioConcat": AudioConcat, RecordAudio,
"AudioMerge": AudioMerge, TrimAudioDuration,
"AudioAdjustVolume": AudioAdjustVolume, SplitAudioChannels,
"EmptyAudio": EmptyAudio, AudioConcat,
} AudioMerge,
AudioAdjustVolume,
EmptyAudio,
]
NODE_DISPLAY_NAME_MAPPINGS = { async def comfy_entrypoint() -> AudioExtension:
"EmptyLatentAudio": "Empty Latent Audio", return AudioExtension()
"VAEEncodeAudio": "VAE Encode Audio",
"VAEDecodeAudio": "VAE Decode Audio",
"PreviewAudio": "Preview Audio",
"LoadAudio": "Load Audio",
"SaveAudio": "Save Audio (FLAC)",
"SaveAudioMP3": "Save Audio (MP3)",
"SaveAudioOpus": "Save Audio (Opus)",
"RecordAudio": "Record Audio",
"TrimAudioDuration": "Trim Audio Duration",
"SplitAudioChannels": "Split Audio Channels",
"AudioConcat": "Audio Concat",
"AudioMerge": "Audio Merge",
"AudioAdjustVolume": "Audio Adjust Volume",
"EmptyAudio": "Empty Audio",
}
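The dictionary-based registration (NODE_CLASS_MAPPINGS plus NODE_DISPLAY_NAME_MAPPINGS) is replaced by a ComfyExtension subclass whose get_node_list() returns the node classes, with display names now carried by each node's IO.Schema. A minimal self-contained sketch of the pattern, using a hypothetical pass-through node; the API calls mirror the ones visible in this diff:

from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension


class NoOpAudio(IO.ComfyNode):
    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="NoOpAudio",             # hypothetical id (old NODE_CLASS_MAPPINGS key)
            display_name="No-Op Audio",      # replaces the NODE_DISPLAY_NAME_MAPPINGS entry
            category="audio",
            inputs=[IO.Audio.Input("audio")],
            outputs=[IO.Audio.Output()],
        )

    @classmethod
    def execute(cls, audio) -> IO.NodeOutput:
        return IO.NodeOutput(audio)          # pass the AUDIO dict straight through


class NoOpAudioExtension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        return [NoOpAudio]                   # one list replaces both mapping dicts


async def comfy_entrypoint() -> NoOpAudioExtension:
    return NoOpAudioExtension()              # module entry point, as in the file above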

View File

@ -2,22 +2,18 @@ import nodes
import folder_paths import folder_paths
import os import os
from comfy.comfy_types import IO from typing_extensions import override
from comfy_api.input_impl import VideoFromFile from comfy_api.latest import IO, ComfyExtension, InputImpl, UI
from pathlib import Path from pathlib import Path
from PIL import Image
import numpy as np
import uuid
def normalize_path(path): def normalize_path(path):
return path.replace('\\', '/') return path.replace('\\', '/')
class Load3D(): class Load3D(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d") input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True) os.makedirs(input_dir, exist_ok=True)
@ -30,23 +26,29 @@ class Load3D():
for file_path in input_path.rglob("*") for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'} if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
] ]
return IO.Schema(
node_id="Load3D",
display_name="Load 3D & Animation",
category="3d",
is_experimental=True,
inputs=[
IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
IO.Load3D.Input("image"),
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
],
outputs=[
IO.Image.Output(display_name="image"),
IO.Mask.Output(display_name="mask"),
IO.String.Output(display_name="mesh_path"),
IO.Image.Output(display_name="normal"),
IO.Load3DCamera.Output(display_name="camera_info"),
IO.Video.Output(display_name="recording_video"),
],
)
return {"required": { @classmethod
"model_file": (sorted(files), {"file_upload": True}), def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
"image": ("LOAD_3D", {}),
"width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
FUNCTION = "process"
EXPERIMENTAL = True
CATEGORY = "3d"
def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image']) image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask']) mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal']) normal_path = folder_paths.get_annotated_filepath(image['normal'])
@ -61,58 +63,47 @@ class Load3D():
if image['recording'] != "": if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording']) recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
video = VideoFromFile(recording_video_path) video = InputImpl.VideoFromFile(recording_video_path)
return output_image, output_mask, model_file, normal_image, image['camera_info'], video return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video)
class Preview3D(): process = execute # TODO: remove
class Preview3D(IO.ComfyNode):
@classmethod @classmethod
def INPUT_TYPES(s): def define_schema(cls):
return {"required": { return IO.Schema(
"model_file": ("STRING", {"default": "", "multiline": False}), node_id="Preview3D",
}, display_name="Preview 3D & Animation",
"optional": { category="3d",
"camera_info": ("LOAD3D_CAMERA", {}), is_experimental=True,
"bg_image": ("IMAGE", {}) is_output_node=True,
}} inputs=[
IO.String.Input("model_file", default="", multiline=False),
IO.Load3DCamera.Input("camera_info", optional=True),
IO.Image.Input("bg_image", optional=True),
],
outputs=[],
)
OUTPUT_NODE = True @classmethod
RETURN_TYPES = () def execute(cls, model_file, **kwargs) -> IO.NodeOutput:
CATEGORY = "3d"
FUNCTION = "process"
EXPERIMENTAL = True
def process(self, model_file, **kwargs):
camera_info = kwargs.get("camera_info", None) camera_info = kwargs.get("camera_info", None)
bg_image = kwargs.get("bg_image", None) bg_image = kwargs.get("bg_image", None)
return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image))
bg_image_path = None process = execute # TODO: remove
if bg_image is not None:
img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
img = Image.fromarray(img_array)
temp_dir = folder_paths.get_temp_directory() class Load3DExtension(ComfyExtension):
filename = f"bg_{uuid.uuid4().hex}.png" @override
bg_image_path = os.path.join(temp_dir, filename) async def get_node_list(self) -> list[type[IO.ComfyNode]]:
img.save(bg_image_path, compress_level=1) return [
Load3D,
Preview3D,
]
bg_image_path = f"temp/{filename}"
return { async def comfy_entrypoint() -> Load3DExtension:
"ui": { return Load3DExtension()
"result": [model_file, camera_info, bg_image_path]
}
}
NODE_CLASS_MAPPINGS = {
"Load3D": Load3D,
"Preview3D": Preview3D,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"Load3D": "Load 3D & Animation",
"Preview3D": "Preview 3D & Animation",
}

comfy_extras/nodes_logic.py (new file, +155 lines)
View File

@ -0,0 +1,155 @@
from typing import TypedDict
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from comfy_api.latest import _io
class SwitchNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = io.MatchType.Template("switch")
return io.Schema(
node_id="ComfySwitchNode",
display_name="Switch",
category="logic",
is_experimental=True,
inputs=[
io.Boolean.Input("switch"),
io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
io.MatchType.Input("on_true", template=template, lazy=True, optional=True),
],
outputs=[
io.MatchType.Output(template=template, display_name="output"),
],
)
@classmethod
def check_lazy_status(cls, switch, on_false=..., on_true=...):
# We use ... instead of None, as None is passed for connected-but-unevaluated inputs.
# This trick allows us to ignore the value of the switch and still be able to run execute().
# One of the inputs may be missing, in which case we need to evaluate the other input
if on_false is ...:
return ["on_true"]
if on_true is ...:
return ["on_false"]
# Normal lazy switch operation
if switch and on_true is None:
return ["on_true"]
if not switch and on_false is None:
return ["on_false"]
@classmethod
def validate_inputs(cls, switch, on_false=..., on_true=...):
# This check happens before check_lazy_status(), so we can eliminate the case where
# both inputs are missing.
if on_false is ... and on_true is ...:
return "At least one of on_false or on_true must be connected to Switch node"
return True
@classmethod
def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput:
if on_true is ...:
return io.NodeOutput(on_false)
if on_false is ...:
return io.NodeOutput(on_true)
return io.NodeOutput(on_true if switch else on_false)
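SwitchNode leans on two different "missing" markers: the default Ellipsis (...) means an input socket is not connected at all, while None means it is connected but not yet evaluated, so check_lazy_status can ask the executor to evaluate only the branch it actually needs. A plain-Python illustration of that decision, outside the node machinery:

def lazy_needed(switch, on_false=..., on_true=...):
    # ... -> not connected; None -> connected but not evaluated yet.
    if on_false is ...:                      # only one branch wired up: always evaluate it
        return ["on_true"]
    if on_true is ...:
        return ["on_false"]
    if switch and on_true is None:           # normal lazy behaviour: evaluate the taken branch
        return ["on_true"]
    if not switch and on_false is None:
        return ["on_false"]
    return []

print(lazy_needed(True, on_false=None, on_true=None))    # ['on_true']
print(lazy_needed(False, on_false=None, on_true=...))    # ['on_false']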
class DCTestNode(io.ComfyNode):
class DCValues(TypedDict):
combo: str
string: str
integer: int
image: io.Image.Type
subcombo: dict
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DCTestNode",
display_name="DCTest",
category="logic",
is_output_node=True,
inputs=[_io.DynamicCombo.Input("combo", options=[
_io.DynamicCombo.Option("option1", [io.String.Input("string")]),
_io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
_io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
_io.DynamicCombo.Option("option4", [
_io.DynamicCombo.Input("subcombo", options=[
_io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
_io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
])
])]
)],
outputs=[io.AnyType.Output()],
)
@classmethod
def execute(cls, combo: DCValues) -> io.NodeOutput:
combo_val = combo["combo"]
if combo_val == "option1":
return io.NodeOutput(combo["string"])
elif combo_val == "option2":
return io.NodeOutput(combo["integer"])
elif combo_val == "option3":
return io.NodeOutput(combo["image"])
elif combo_val == "option4":
return io.NodeOutput(f"{combo['subcombo']}")
else:
raise ValueError(f"Invalid combo: {combo_val}")
class AutogrowNamesTestNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"])
return io.Schema(
node_id="AutogrowNamesTestNode",
display_name="AutogrowNamesTest",
category="logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
outputs=[io.String.Output()],
)
@classmethod
def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
vals = list(autogrow.values())
combined = ",".join([str(x) for x in vals])
return io.NodeOutput(combined)
class AutogrowPrefixTestNode(io.ComfyNode):
@classmethod
def define_schema(cls):
template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10)
return io.Schema(
node_id="AutogrowPrefixTestNode",
display_name="AutogrowPrefixTest",
category="logic",
inputs=[
_io.Autogrow.Input("autogrow", template=template)
],
outputs=[io.String.Output()],
)
@classmethod
def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
vals = list(autogrow.values())
combined = ",".join([str(x) for x in vals])
return io.NodeOutput(combined)
class LogicExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
# SwitchNode,
# DCTestNode,
# AutogrowNamesTestNode,
# AutogrowPrefixTestNode,
]
async def comfy_entrypoint() -> LogicExtension:
return LogicExtension()

View File

@ -6,6 +6,7 @@ import comfy.ops
import comfy.model_management import comfy.model_management
import comfy.ldm.common_dit import comfy.ldm.common_dit
import comfy.latent_formats import comfy.latent_formats
import comfy.ldm.lumina.controlnet
class BlockWiseControlBlock(torch.nn.Module): class BlockWiseControlBlock(torch.nn.Module):
@ -189,6 +190,35 @@ class SigLIPMultiFeatProjModel(torch.nn.Module):
return embedding return embedding
def z_image_convert(sd):
replace_keys = {".attention.to_out.0.bias": ".attention.out.bias",
".attention.norm_k.weight": ".attention.k_norm.weight",
".attention.norm_q.weight": ".attention.q_norm.weight",
".attention.to_out.0.weight": ".attention.out.weight"
}
out_sd = {}
for k in sorted(sd.keys()):
w = sd[k]
k_out = k
if k_out.endswith(".attention.to_k.weight"):
cc = [w]
continue
if k_out.endswith(".attention.to_q.weight"):
cc = [w] + cc
continue
if k_out.endswith(".attention.to_v.weight"):
cc = cc + [w]
w = torch.cat(cc, dim=0)
k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")
for r, rr in replace_keys.items():
k_out = k_out.replace(r, rr)
out_sd[k_out] = w
return out_sd
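z_image_convert renames the attention keys and fuses the separate to_q/to_k/to_v projection weights into one qkv matrix: sorted() visits to_k first, to_q is then prepended and to_v appended, so the torch.cat along dim=0 stacks the rows in q, k, v order. A shape-only sketch of that fusion:

import torch

dim = 8
w_q, w_k, w_v = (torch.randn(dim, dim) for _ in range(3))

qkv = torch.cat([w_q, w_k, w_v], dim=0)      # same ordering the loop above produces
print(qkv.shape)                             # torch.Size([24, 8]): one (3*dim, dim) projection
assert torch.equal(qkv[:dim], w_q) and torch.equal(qkv[2 * dim:], w_v)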
class ModelPatchLoader: class ModelPatchLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -211,6 +241,9 @@ class ModelPatchLoader:
elif 'feature_embedder.mid_layer_norm.bias' in sd: elif 'feature_embedder.mid_layer_norm.bias' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True) sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
sd = z_image_convert(sd)
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
model.load_state_dict(sd) model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@ -263,6 +296,69 @@ class DiffSynthCnetPatch:
def models(self): def models(self):
return [self.model_patch] return [self.model_patch]
class ZImageControlPatch:
def __init__(self, model_patch, vae, image, strength):
self.model_patch = model_patch
self.vae = vae
self.image = image
self.strength = strength
self.encoded_image = self.encode_latent_cond(image)
self.encoded_image_size = (image.shape[1], image.shape[2])
self.temp_data = None
def encode_latent_cond(self, image):
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
return latent_image
def __call__(self, kwargs):
x = kwargs.get("x")
img = kwargs.get("img")
txt = kwargs.get("txt")
pe = kwargs.get("pe")
vec = kwargs.get("vec")
block_index = kwargs.get("block_index")
spacial_compression = self.vae.spacial_compression_encode()
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
comfy.model_management.load_models_gpu(loaded_models)
cnet_index = (block_index // 5)
cnet_index_float = (block_index / 5)
kwargs.pop("img") # we do ops in place
kwargs.pop("txt")
cnet_blocks = self.model_patch.model.n_control_layers
if cnet_index_float > (cnet_blocks - 1):
self.temp_data = None
return kwargs
if self.temp_data is None or self.temp_data[0] > cnet_index:
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
next_layer = self.temp_data[0] + 1
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
if cnet_index_float == self.temp_data[0]:
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
if cnet_blocks == self.temp_data[0] + 1:
self.temp_data = None
return kwargs
def to(self, device_or_dtype):
if isinstance(device_or_dtype, torch.device):
self.encoded_image = self.encoded_image.to(device_or_dtype)
self.temp_data = None
return self
def models(self):
return [self.model_patch]
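As read from the __call__ above, ZImageControlPatch runs for every double block but only adds a control residual when block_index / 5 lands exactly on an existing control-layer index, i.e. blocks 0, 5, 10, ... map to control layers 0, 1, 2, ...; in between it only advances and caches the control state in temp_data. A tiny sketch of that mapping, assuming four control layers:

n_control_layers = 4                          # model_patch.model.n_control_layers in the code above
for block_index in range(25):
    cnet_index = block_index // 5
    injects = (block_index / 5) == cnet_index and cnet_index < n_control_layers
    if injects:
        print(f"double block {block_index:2d} -> control layer {cnet_index}")
# double block  0 -> control layer 0 ... double block 15 -> control layer 3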
class QwenImageDiffsynthControlnet: class QwenImageDiffsynthControlnet:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -289,6 +385,9 @@ class QwenImageDiffsynthControlnet:
mask = mask.unsqueeze(2) mask = mask.unsqueeze(2)
mask = 1.0 - mask mask = 1.0 - mask
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
else:
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
return (model_patched,) return (model_patched,)

View File

@ -623,7 +623,7 @@ class TrainLoraNode(io.ComfyNode):
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed) noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
if multi_res: if multi_res:
# use first latent as dummy latent if multi_res # use first latent as dummy latent if multi_res
latents = latents[0].repeat(num_images, 1, 1, 1) latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1)))
guider.sample( guider.sample(
noise.generate_noise({"samples": latents}), noise.generate_noise({"samples": latents}),
latents, latents,
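The one-line fix replaces a hard-coded repeat(num_images, 1, 1, 1), which only works for rank-4 image latents, with a repeat tuple derived from the latent's own rank, so rank-5 video latents are duplicated across the batch as well. A quick torch check:

import torch

num_images = 3
image_latent = torch.randn(1, 4, 64, 64)          # rank 4: old and new forms agree
video_latent = torch.randn(1, 16, 5, 32, 32)      # rank 5: repeat(3, 1, 1, 1) would raise

reps = (num_images,) + ((1,) * (video_latent.ndim - 1))
print(image_latent.repeat(num_images, 1, 1, 1).shape)   # torch.Size([3, 4, 64, 64])
print(video_latent.repeat(reps).shape)                  # torch.Size([3, 16, 5, 32, 32])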

View File

@ -88,7 +88,7 @@ class SaveVideo(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput: def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput:
width, height = video.get_dimensions() width, height = video.get_dimensions()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix, filename_prefix,
@ -108,7 +108,7 @@ class SaveVideo(io.ComfyNode):
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}" file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
video.save_to( video.save_to(
os.path.join(full_output_folder, file), os.path.join(full_output_folder, file),
format=format, format=VideoContainer(format),
codec=codec, codec=codec,
metadata=saved_metadata metadata=saved_metadata
) )
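SaveVideo now passes format through the VideoContainer enum constructor instead of handing save_to the raw combo string, which validates the value and normalises it to the enum member. The pattern is plain Python enum-by-value lookup; a generic sketch with a stand-in enum (VideoContainer itself lives in comfy_api):

from enum import Enum

class Container(Enum):                 # stand-in for comfy_api's VideoContainer
    AUTO = "auto"
    MP4 = "mp4"

fmt = "mp4"                            # value coming from the node's combo input
print(Container(fmt))                  # Container.MP4 -- lookup by value
# Container("mkv") would raise ValueError, rejecting unknown formats before saving.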

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.3.75" __version__ = "0.3.76"

View File

@ -34,7 +34,7 @@ from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io from comfy_api.latest import io, _io
class ExecutionResult(Enum): class ExecutionResult(Enum):
@ -76,7 +76,7 @@ class IsChangedCache:
return self.is_changed[node_id] return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None) input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
try: try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
is_changed = await resolve_map_node_over_list_results(is_changed) is_changed = await resolve_map_node_over_list_results(is_changed)
@ -146,8 +146,9 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
is_v3 = issubclass(class_def, _ComfyNodeInternal) is_v3 = issubclass(class_def, _ComfyNodeInternal)
v3_data: io.V3Data = {}
if is_v3: if is_v3:
valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True) valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
else: else:
valid_inputs = class_def.INPUT_TYPES() valid_inputs = class_def.INPUT_TYPES()
input_data_all = {} input_data_all = {}
@ -207,7 +208,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG": if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
return input_data_all, missing_keys, hidden_inputs_v3 v3_data["hidden_inputs"] = hidden_inputs_v3
return input_data_all, missing_keys, v3_data
map_node_over_list = None #Don't hook this please map_node_over_list = None #Don't hook this please
@ -223,7 +225,7 @@ async def resolve_map_node_over_list_results(results):
raise exc raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results] return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
# check if node wants the lists # check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False) input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@ -259,13 +261,16 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
if is_class(obj): if is_class(obj):
type_obj = obj type_obj = obj
obj.VALIDATE_CLASS() obj.VALIDATE_CLASS()
class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs) class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
# otherwise, use class instance to populate/reuse some fields # otherwise, use class instance to populate/reuse some fields
else: else:
type_obj = type(obj) type_obj = type(obj)
type_obj.VALIDATE_CLASS() type_obj.VALIDATE_CLASS()
class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs) class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
f = make_locked_method_func(type_obj, func, class_clone) f = make_locked_method_func(type_obj, func, class_clone)
# in case of dynamic inputs, restructure inputs to expected nested dict
if v3_data is not None:
inputs = _io.build_nested_inputs(inputs, v3_data)
# V1 # V1
else: else:
f = getattr(obj, func) f = getattr(obj, func)
@ -320,8 +325,8 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results]) output.append([o[i] for o in results])
return output return output
async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task: if has_pending_task:
return return_values, {}, False, has_pending_task return return_values, {}, False, has_pending_task
@ -460,7 +465,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False has_subgraph = False
else: else:
get_progress_state().start_progress(unique_id) get_progress_state().start_progress(unique_id)
input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None: if server.client_id is not None:
server.last_node_id = display_node_id server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@ -475,7 +480,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
else: else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present: if lazy_status_present:
required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs) required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and ( required_inputs = [x for x in required_inputs if isinstance(x,str) and (
@ -507,7 +512,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
def pre_execute_cb(call_index): def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0) GraphBuilder.set_default_prefix(unique_id, call_index, 0)
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
if has_pending_tasks: if has_pending_tasks:
pending_async_nodes[unique_id] = output_data pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id) unblock = execution_list.add_external_block(unique_id)
@ -745,18 +750,17 @@ async def validate_inputs(prompt_id, prompt, item, validated):
class_type = prompt[unique_id]['class_type'] class_type = prompt[unique_id]['class_type']
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
class_inputs = obj_class.INPUT_TYPES()
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
errors = [] errors = []
valid = True valid = True
validate_function_inputs = [] validate_function_inputs = []
validate_has_kwargs = False validate_has_kwargs = False
if issubclass(obj_class, _ComfyNodeInternal): if issubclass(obj_class, _ComfyNodeInternal):
class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
validate_function_name = "validate_inputs" validate_function_name = "validate_inputs"
validate_function = first_real_override(obj_class, validate_function_name) validate_function = first_real_override(obj_class, validate_function_name)
else: else:
class_inputs = obj_class.INPUT_TYPES()
validate_function_name = "VALIDATE_INPUTS" validate_function_name = "VALIDATE_INPUTS"
validate_function = getattr(obj_class, validate_function_name, None) validate_function = getattr(obj_class, validate_function_name, None)
if validate_function is not None: if validate_function is not None:
@ -765,6 +769,8 @@ async def validate_inputs(prompt_id, prompt, item, validated):
validate_has_kwargs = argspec.varkw is not None validate_has_kwargs = argspec.varkw is not None
received_types = {} received_types = {}
valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
for x in valid_inputs: for x in valid_inputs:
input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs) input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
assert extra_info is not None assert extra_info is not None
@ -935,7 +941,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
continue continue
if len(validate_function_inputs) > 0 or validate_has_kwargs: if len(validate_function_inputs) > 0 or validate_has_kwargs:
input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id) input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
input_filtered = {} input_filtered = {}
for x in input_data_all: for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs: if x in validate_function_inputs or validate_has_kwargs:
@ -943,7 +949,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
if 'input_types' in validate_function_inputs: if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types] input_filtered['input_types'] = [received_types]
ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs) ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data)
ret = await resolve_map_node_over_list_results(ret) ret = await resolve_map_node_over_list_results(ret)
for x in input_filtered: for x in input_filtered:
for i, r in enumerate(ret): for i, r in enumerate(ret):

main.py (30 changed lines)
View File

@ -15,6 +15,7 @@ from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags from comfy_api import feature_flags
if __name__ == "__main__": if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes. #NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
@ -22,6 +23,23 @@ if __name__ == "__main__":
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
def handle_comfyui_manager_unavailable():
if not args.windows_standalone_build:
logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
args.enable_manager = False
if args.enable_manager:
if importlib.util.find_spec("comfyui_manager"):
import comfyui_manager
if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'):
handle_comfyui_manager_unavailable()
else:
handle_comfyui_manager_unavailable()
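--enable-manager only keeps the flag on when comfyui_manager is genuinely installed: importlib.util.find_spec confirms the package can be imported, and the __init__.py check above filters out namespace-package stubs left behind by a source checkout. A standalone sketch of that availability test:

import importlib.util

def optional_import(name: str):
    """Return the imported package, or None when it is missing or only a namespace stub."""
    if importlib.util.find_spec(name) is None:
        return None
    module = __import__(name)
    if not getattr(module, "__file__", None) or not module.__file__.endswith("__init__.py"):
        return None                     # namespace package / source stub: treat as unavailable
    return module

print(optional_import("comfyui_manager") is not None)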
def apply_custom_paths(): def apply_custom_paths():
# extra model paths # extra model paths
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml") extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
@ -79,6 +97,11 @@ def execute_prestartup_script():
for possible_module in possible_modules: for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module) module_path = os.path.join(custom_node_path, possible_module)
if args.enable_manager:
if comfyui_manager.should_be_disabled(module_path):
continue
if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__": if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
continue continue
@ -101,6 +124,10 @@ def execute_prestartup_script():
logging.info("") logging.info("")
apply_custom_paths() apply_custom_paths()
if args.enable_manager:
comfyui_manager.prestartup()
execute_prestartup_script() execute_prestartup_script()
@ -323,6 +350,9 @@ def start_comfyui(asyncio_loop=None):
asyncio.set_event_loop(asyncio_loop) asyncio.set_event_loop(asyncio_loop)
prompt_server = server.PromptServer(asyncio_loop) prompt_server = server.PromptServer(asyncio_loop)
if args.enable_manager and not args.disable_manager_ui:
comfyui_manager.start()
hook_breaker_ac10a0.save_functions() hook_breaker_ac10a0.save_functions()
asyncio_loop.run_until_complete(nodes.init_extra_nodes( asyncio_loop.run_until_complete(nodes.init_extra_nodes(
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0, init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,

manager_requirements.txt (new file, +1 line)
View File

@ -0,0 +1 @@
comfyui_manager==4.0.3b3

View File

@ -43,6 +43,9 @@ import folder_paths
import latent_preview import latent_preview
import node_helpers import node_helpers
if args.enable_manager:
import comfyui_manager
def before_node_execution(): def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted() comfy.model_management.throw_exception_if_processing_interrupted()
@ -939,7 +942,7 @@ class CLIPLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ), "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ),
}, },
"optional": { "optional": {
"device": (["default", "cpu"], {"advanced": True}), "device": (["default", "cpu"], {"advanced": True}),
@ -2243,6 +2246,12 @@ async def init_external_custom_nodes():
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes: if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes") logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
continue continue
if args.enable_manager:
if comfyui_manager.should_be_disabled(module_path):
logging.info(f"Blocked by policy: {module_path}")
continue
time_before = time.perf_counter() time_before = time.perf_counter()
success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes") success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
node_import_times.append((time.perf_counter() - time_before, module_path, success)) node_import_times.append((time.perf_counter() - time_before, module_path, success))
@ -2346,6 +2355,7 @@ async def init_builtin_extra_nodes():
"nodes_easycache.py", "nodes_easycache.py",
"nodes_audio_encoder.py", "nodes_audio_encoder.py",
"nodes_rope.py", "nodes_rope.py",
"nodes_logic.py",
"nodes_nop.py", "nodes_nop.py",
] ]

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.3.75" version = "0.3.76"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"

View File

@ -1,4 +1,4 @@
comfyui-frontend-package==1.32.9 comfyui-frontend-package==1.33.10
comfyui-workflow-templates==0.7.25 comfyui-workflow-templates==0.7.25
comfyui-embedded-docs==0.3.1 comfyui-embedded-docs==0.3.1
torch torch

View File

@ -44,6 +44,9 @@ from protocol import BinaryEventTypes
# Import cache control middleware # Import cache control middleware
from middleware.cache_middleware import cache_control from middleware.cache_middleware import cache_control
if args.enable_manager:
import comfyui_manager
async def send_socket_catch_exception(function, message): async def send_socket_catch_exception(function, message):
try: try:
await function(message) await function(message)
@ -95,7 +98,7 @@ def create_cors_middleware(allowed_origin: str):
response = await handler(request) response = await handler(request)
response.headers['Access-Control-Allow-Origin'] = allowed_origin response.headers['Access-Control-Allow-Origin'] = allowed_origin
response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS, PATCH'
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
response.headers['Access-Control-Allow-Credentials'] = 'true' response.headers['Access-Control-Allow-Credentials'] = 'true'
return response return response
@ -212,6 +215,9 @@ class PromptServer():
if args.disable_api_nodes: if args.disable_api_nodes:
middlewares.append(create_block_external_middleware()) middlewares.append(create_block_external_middleware())
if args.enable_manager:
middlewares.append(comfyui_manager.create_middleware())
max_upload_size = round(args.max_upload_size * 1024 * 1024) max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares) self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
self.sockets = dict() self.sockets = dict()
@ -599,7 +605,7 @@ class PromptServer():
system_stats = { system_stats = {
"system": { "system": {
"os": os.name, "os": sys.platform,
"ram_total": ram_total, "ram_total": ram_total,
"ram_free": ram_free, "ram_free": ram_free,
"comfyui_version": __version__, "comfyui_version": __version__,