diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py
index 51a263203..59ece5130 100755
--- a/.ci/update_windows/update.py
+++ b/.ci/update_windows/update.py
@@ -66,8 +66,10 @@ if branch is None:
try:
ref = repo.lookup_reference('refs/remotes/origin/master')
except:
- print("pulling.") # noqa: T201
- pull(repo)
+ print("fetching.") # noqa: T201
+ for remote in repo.remotes:
+ if remote.name == "origin":
+ remote.fetch()
ref = repo.lookup_reference('refs/remotes/origin/master')
repo.checkout(ref)
branch = repo.lookup_branch('master')
@@ -149,3 +151,4 @@ try:
shutil.copy(stable_update_script, stable_update_script_to)
except:
pass
+
diff --git a/CODEOWNERS b/CODEOWNERS
index b7aca9b26..4d5448636 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1,3 +1,2 @@
# Admins
-* @comfyanonymous
-* @kosinkadink
+* @comfyanonymous @kosinkadink @guill
diff --git a/README.md b/README.md
index 91fb510e1..ed857df9f 100644
--- a/README.md
+++ b/README.md
@@ -81,6 +81,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
+ - [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
- Audio Models
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index a3c4a6bc6..6becebcb5 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -122,6 +122,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
+parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
+manager_group = parser.add_mutually_exclusive_group()
+manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
+manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager.")
+
+
vram_group = parser.add_mutually_exclusive_group()
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
@@ -169,6 +175,7 @@ parser.add_argument("--multi-user", action="store_true", help="Enables per-user
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
+
# The default built-in provider hosted under web/
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index a72f8cc47..2e8ef0687 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -40,7 +40,8 @@ class ChromaParams:
out_dim: int
hidden_dim: int
n_layers: int
-
+ txt_ids_dims: list
+ vec_in_dim: int
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 2472ab79c..60f2bdae2 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -57,6 +57,35 @@ class MLPEmbedder(nn.Module):
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
+class YakMLP(nn.Module):
+ def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+ self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
+ self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
+ self.act_fn = nn.SiLU()
+
+ def forward(self, x: Tensor) -> Tensor:
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+ return down_proj
+
+def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
+ if yak_mlp:
+ return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
+ if mlp_silu_act:
+ return nn.Sequential(
+ operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
+ SiLUActivation(),
+ operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
+ )
+ else:
+ return nn.Sequential(
+ operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+ nn.GELU(approximate="tanh"),
+ operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+ )
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
@@ -140,7 +169,7 @@ class SiLUActivation(nn.Module):
class DoubleStreamBlock(nn.Module):
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, dtype=None, device=None, operations=None):
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -156,18 +185,7 @@ class DoubleStreamBlock(nn.Module):
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- if mlp_silu_act:
- self.img_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
- SiLUActivation(),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
- )
- else:
- self.img_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
- nn.GELU(approximate="tanh"),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
- )
+ self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
if self.modulation:
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
@@ -177,18 +195,7 @@ class DoubleStreamBlock(nn.Module):
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
- if mlp_silu_act:
- self.txt_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
- SiLUActivation(),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
- )
- else:
- self.txt_mlp = nn.Sequential(
- operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
- nn.GELU(approximate="tanh"),
- operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
- )
+ self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
self.flipped_img_txt = flipped_img_txt
@@ -275,6 +282,7 @@ class SingleStreamBlock(nn.Module):
modulation=True,
mlp_silu_act=False,
bias=True,
+ yak_mlp=False,
dtype=None,
device=None,
operations=None
@@ -288,12 +296,17 @@ class SingleStreamBlock(nn.Module):
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp_hidden_dim_first = self.mlp_hidden_dim
+ self.yak_mlp = yak_mlp
if mlp_silu_act:
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
self.mlp_act = SiLUActivation()
else:
self.mlp_act = nn.GELU(approximate="tanh")
+ if self.yak_mlp:
+ self.mlp_hidden_dim_first *= 2
+ self.mlp_act = nn.SiLU()
+
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
# proj and mlp_out
@@ -325,7 +338,10 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
# compute activation in mlp stream, cat again and run second linear layer
- mlp = self.mlp_act(mlp)
+ if self.yak_mlp:
+ mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
+ else:
+ mlp = self.mlp_act(mlp)
output = self.linear2(torch.cat((attn, mlp), 2))
x += apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16:
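
Note on the yak_mlp path above: it is a gated (SwiGLU-style) MLP. YakMLP uses separate gate/up/down projections; in the fused SingleStreamBlock layout the mlp slice is widened to 2x and split at runtime, with SiLU applied to the second half and multiplied into the first. A minimal sketch of the same gating, using plain nn.Linear in place of the operations.* wrappers (an assumption for illustration only):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLPSketch(nn.Module):
    # Mirrors YakMLP: down(SiLU(gate(x)) * up(x)); sizes are placeholders.
    def __init__(self, hidden: int, inter: int):
        super().__init__()
        self.gate = nn.Linear(hidden, inter)
        self.up = nn.Linear(hidden, inter)
        self.down = nn.Linear(inter, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down(F.silu(self.gate(x)) * self.up(x))

# Fused variant used in SingleStreamBlock: one matmul produces both halves;
# the split mirrors mlp[..., dim // 2:] (gate) and mlp[..., :dim // 2] (value).
def fused_gate(mlp: torch.Tensor) -> torch.Tensor:
    half = mlp.shape[-1] // 2
    return F.silu(mlp[..., half:]) * mlp[..., :half]
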
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index d5674dea6..f40c2a7a9 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -15,7 +15,8 @@ from .layers import (
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
- Modulation
+ Modulation,
+ RMSNorm
)
@dataclass
@@ -34,11 +35,14 @@ class FluxParams:
patch_size: int
qkv_bias: bool
guidance_embed: bool
+ txt_ids_dims: list
global_modulation: bool = False
mlp_silu_act: bool = False
ops_bias: bool = True
default_ref_method: str = "offset"
ref_index_scale: float = 1.0
+ yak_mlp: bool = False
+ txt_norm: bool = False
class Flux(nn.Module):
@@ -76,6 +80,11 @@ class Flux(nn.Module):
)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
+ if params.txt_norm:
+ self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+ else:
+ self.txt_norm = None
+
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
@@ -86,6 +95,7 @@ class Flux(nn.Module):
modulation=params.global_modulation is False,
mlp_silu_act=params.mlp_silu_act,
proj_bias=params.ops_bias,
+ yak_mlp=params.yak_mlp,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -94,7 +104,7 @@ class Flux(nn.Module):
self.single_blocks = nn.ModuleList(
[
- SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
@@ -150,6 +160,8 @@ class Flux(nn.Module):
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+ if self.txt_norm is not None:
+ txt = self.txt_norm(txt)
txt = self.txt_in(txt)
vec_orig = vec
@@ -332,8 +344,9 @@ class Flux(nn.Module):
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
- if len(self.params.axes_dim) == 4: # Flux 2
- txt_ids[:, :, 3] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
+ if len(self.params.txt_ids_dims) > 0:
+ for i in self.params.txt_ids_dims:
+ txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens]
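
The hard-coded Flux 2 special case (len(axes_dim) == 4) is generalized above into params.txt_ids_dims: every listed rope axis now gets a per-token position ramp. A rough illustration of what the loop writes, with made-up sizes:

import torch

bs, n_txt, n_axes = 1, 5, 4            # made-up: 5 text tokens, 4 rope axes
txt_ids = torch.zeros((bs, n_txt, n_axes))
for i in [1, 2]:                        # e.g. txt_ids_dims for the Ovis case
    txt_ids[:, :, i] = torch.linspace(0, n_txt - 1, steps=n_txt)
# axes 0 and 3 stay zero; axes 1 and 2 count 0..4 along the text sequence
# (Flux 2 keeps its previous behaviour via txt_ids_dims = [3])
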
diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py
index 9f5e91a59..85f515f67 100644
--- a/comfy/ldm/hunyuan_video/upsampler.py
+++ b/comfy/ldm/hunyuan_video/upsampler.py
@@ -1,7 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
+from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
import model_management, model_patcher
class SRResidualCausalBlock3D(nn.Module):
diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py
index 9f750dcc4..ddf77cd0e 100644
--- a/comfy/ldm/hunyuan_video/vae_refiner.py
+++ b/comfy/ldm/hunyuan_video/vae_refiner.py
@@ -1,42 +1,12 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
+from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
import comfy.ops
import comfy.ldm.models.autoencoder
import comfy.model_management
ops = comfy.ops.disable_weight_init
-class NoPadConv3d(nn.Module):
- def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
- super().__init__()
- self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
-
- def forward(self, x):
- return self.conv(x)
-
-
-def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
-
- x = xl[0]
- xl.clear()
-
- if conv_carry_out is not None:
- to_push = x[:, :, -2:, :, :].clone()
- conv_carry_out.append(to_push)
-
- if isinstance(op, NoPadConv3d):
- if conv_carry_in is None:
- x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
- else:
- carry_len = conv_carry_in[0].shape[2]
- x = torch.cat([conv_carry_in.pop(0), x], dim=2)
- x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
-
- out = op(x)
-
- return out
-
class RMS_norm(nn.Module):
def __init__(self, dim):
@@ -49,7 +19,7 @@ class RMS_norm(nn.Module):
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
class DnSmpl(nn.Module):
- def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
+ def __init__(self, ic, oc, tds, refiner_vae, op):
super().__init__()
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
assert oc % fct == 0
@@ -109,7 +79,7 @@ class DnSmpl(nn.Module):
class UpSmpl(nn.Module):
- def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
+ def __init__(self, ic, oc, tus, refiner_vae, op):
super().__init__()
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
@@ -163,23 +133,6 @@ class UpSmpl(nn.Module):
return h + x
-class HunyuanRefinerResnetBlock(ResnetBlock):
- def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm):
- super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
-
- def forward(self, x, conv_carry_in=None, conv_carry_out=None):
- h = x
- h = [ self.swish(self.norm1(x)) ]
- h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
- h = [ self.dropout(self.swish(self.norm2(h))) ]
- h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
-
- if self.in_channels != self.out_channels:
- x = self.nin_shortcut(x)
-
- return x+h
-
class Encoder(nn.Module):
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_):
@@ -191,7 +144,7 @@ class Encoder(nn.Module):
self.refiner_vae = refiner_vae
if self.refiner_vae:
- conv_op = NoPadConv3d
+ conv_op = CarriedConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
@@ -206,9 +159,10 @@ class Encoder(nn.Module):
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
- stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
- out_channels=tgt,
- conv_op=conv_op, norm_op=norm_op)
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks)])
ch = tgt
if i < depth:
@@ -218,9 +172,9 @@ class Encoder(nn.Module):
self.down.append(stage)
self.mid = nn.Module()
- self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
- self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.norm_out = norm_op(ch)
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
@@ -246,22 +200,20 @@ class Encoder(nn.Module):
conv_carry_out = []
if i == len(x) - 1:
conv_carry_out = None
+
x1 = [ x1 ]
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
for stage in self.down:
for blk in stage.block:
- x1 = blk(x1, conv_carry_in, conv_carry_out)
+ x1 = blk(x1, None, conv_carry_in, conv_carry_out)
if hasattr(stage, 'downsample'):
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
out.append(x1)
conv_carry_in = conv_carry_out
- if len(out) > 1:
- out = torch.cat(out, dim=2)
- else:
- out = out[0]
+ out = torch_cat_if_needed(out, dim=2)
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
del out
@@ -288,7 +240,7 @@ class Decoder(nn.Module):
self.refiner_vae = refiner_vae
if self.refiner_vae:
- conv_op = NoPadConv3d
+ conv_op = CarriedConv3d
norm_op = RMS_norm
else:
conv_op = ops.Conv3d
@@ -298,9 +250,9 @@ class Decoder(nn.Module):
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
self.mid = nn.Module()
- self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+ self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
- self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
+ self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
self.up = nn.ModuleList()
depth = (ffactor_spatial >> 1).bit_length()
@@ -308,9 +260,10 @@ class Decoder(nn.Module):
for i, tgt in enumerate(block_out_channels):
stage = nn.Module()
- stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt,
- out_channels=tgt,
- conv_op=conv_op, norm_op=norm_op)
+ stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
+ out_channels=tgt,
+ temb_channels=0,
+ conv_op=conv_op, norm_op=norm_op)
for j in range(num_res_blocks + 1)])
ch = tgt
if i < depth:
@@ -340,7 +293,7 @@ class Decoder(nn.Module):
conv_carry_out = None
for stage in self.up:
for blk in stage.block:
- x1 = blk(x1, conv_carry_in, conv_carry_out)
+ x1 = blk(x1, None, conv_carry_in, conv_carry_out)
if hasattr(stage, 'upsample'):
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
@@ -350,10 +303,7 @@ class Decoder(nn.Module):
conv_carry_in = conv_carry_out
del x
- if len(out) > 1:
- out = torch.cat(out, dim=2)
- else:
- out = out[0]
+ out = torch_cat_if_needed(out, dim=2)
if not self.refiner_vae:
if z.shape[-3] == 1:
diff --git a/comfy/ldm/lumina/controlnet.py b/comfy/ldm/lumina/controlnet.py
new file mode 100644
index 000000000..fd7ce3b5c
--- /dev/null
+++ b/comfy/ldm/lumina/controlnet.py
@@ -0,0 +1,113 @@
+import torch
+from torch import nn
+
+from .model import JointTransformerBlock
+
+class ZImageControlTransformerBlock(JointTransformerBlock):
+ def __init__(
+ self,
+ layer_id: int,
+ dim: int,
+ n_heads: int,
+ n_kv_heads: int,
+ multiple_of: int,
+ ffn_dim_multiplier: float,
+ norm_eps: float,
+ qk_norm: bool,
+ modulation=True,
+ block_id=0,
+ operation_settings=None,
+ ):
+ super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
+ self.block_id = block_id
+ if block_id == 0:
+ self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, c, x, **kwargs):
+ if self.block_id == 0:
+ c = self.before_proj(c) + x
+ c = super().forward(c, **kwargs)
+ c_skip = self.after_proj(c)
+ return c_skip, c
+
+class ZImage_Control(torch.nn.Module):
+ def __init__(
+ self,
+ dim: int = 3840,
+ n_heads: int = 30,
+ n_kv_heads: int = 30,
+ multiple_of: int = 256,
+ ffn_dim_multiplier: float = (8.0 / 3.0),
+ norm_eps: float = 1e-5,
+ qk_norm: bool = True,
+ dtype=None,
+ device=None,
+ operations=None,
+ **kwargs
+ ):
+ super().__init__()
+ operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+ self.additional_in_dim = 0
+ self.control_in_dim = 16
+ n_refiner_layers = 2
+ self.n_control_layers = 6
+ self.control_layers = nn.ModuleList(
+ [
+ ZImageControlTransformerBlock(
+ i,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ block_id=i,
+ operation_settings=operation_settings,
+ )
+ for i in range(self.n_control_layers)
+ ]
+ )
+
+ all_x_embedder = {}
+ patch_size = 2
+ f_patch_size = 1
+ x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
+ all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
+
+ self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
+ self.control_noise_refiner = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ modulation=True,
+ z_image_modulation=True,
+ operation_settings=operation_settings,
+ )
+ for layer_id in range(n_refiner_layers)
+ ]
+ )
+
+ def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
+ patch_size = 2
+ f_patch_size = 1
+ pH = pW = patch_size
+ B, C, H, W = control_context.shape
+ control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
+
+ x_attn_mask = None
+ for layer in self.control_noise_refiner:
+ control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
+ return control_context
+
+ def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
+ return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
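
The control blocks follow the usual ControlNet-style residual pattern: block 0 folds the host stream into the control stream via before_proj, and every block returns an after_proj skip for the host to add back at the matching depth. A runnable toy of just that wiring (ToyControlBlock and the sizes are illustrative stand-ins, not the real modules):

import torch
import torch.nn as nn

class ToyControlBlock(nn.Module):
    def __init__(self, dim: int, first: bool):
        super().__init__()
        self.first = first
        self.before_proj = nn.Linear(dim, dim) if first else None
        self.body = nn.Linear(dim, dim)        # stands in for the transformer block
        self.after_proj = nn.Linear(dim, dim)

    def forward(self, c, x):
        if self.first:
            c = self.before_proj(c) + x        # fold the host stream in once, at block 0
        c = self.body(c)
        return self.after_proj(c), c           # (skip for the host, carried control state)

dim = 8
blocks = nn.ModuleList([ToyControlBlock(dim, i == 0) for i in range(3)])
x = torch.randn(1, 4, dim)                     # host (denoiser) tokens
c = torch.randn(1, 4, dim)                     # embedded control context
for blk in blocks:
    skip, c = blk(c, x)
    x = x + skip                               # one skip consumed per control layer
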
diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
index 7d7e9112c..f1c1a0ec3 100644
--- a/comfy/ldm/lumina/model.py
+++ b/comfy/ldm/lumina/model.py
@@ -22,6 +22,10 @@ def modulate(x, scale):
# Core NextDiT Model #
#############################################################################
+def clamp_fp16(x):
+ if x.dtype == torch.float16:
+ return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+ return x
class JointAttention(nn.Module):
"""Multi-head attention module."""
@@ -169,7 +173,7 @@ class FeedForward(nn.Module):
# @torch.compile
def _forward_silu_gating(self, x1, x3):
- return F.silu(x1) * x3
+ return clamp_fp16(F.silu(x1) * x3)
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
@@ -273,27 +277,27 @@ class JointTransformerBlock(nn.Module):
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
- self.attention(
+ clamp_fp16(self.attention(
modulate(self.attention_norm1(x), scale_msa),
x_mask,
freqs_cis,
transformer_options=transformer_options,
- )
+ ))
)
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
- self.feed_forward(
+ clamp_fp16(self.feed_forward(
modulate(self.ffn_norm1(x), scale_mlp),
- )
+ ))
)
else:
assert adaln_input is None
x = x + self.attention_norm2(
- self.attention(
+ clamp_fp16(self.attention(
self.attention_norm1(x),
x_mask,
freqs_cis,
transformer_options=transformer_options,
- )
+ ))
)
x = x + self.ffn_norm2(
self.feed_forward(
@@ -564,7 +568,7 @@ class NextDiT(nn.Module):
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
# def forward(self, x, t, cap_feats, cap_mask):
- def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
+ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
@@ -581,16 +585,24 @@ class NextDiT(nn.Module):
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
+ patches = transformer_options.get("patches", {})
transformer_options = kwargs.get("transformer_options", {})
x_is_tensor = isinstance(x, torch.Tensor)
- x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
- freqs_cis = freqs_cis.to(x.device)
+ img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
+ freqs_cis = freqs_cis.to(img.device)
- for layer in self.layers:
- x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+ for i, layer in enumerate(self.layers):
+ img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
+ if "double_block" in patches:
+ for p in patches["double_block"]:
+ out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
+ if "img" in out:
+ img[:, cap_size[0]:] = out["img"]
+ if "txt" in out:
+ img[:, :cap_size[0]] = out["txt"]
- x = self.final_layer(x, adaln_input)
- x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
+ img = self.final_layer(img, adaln_input)
+ img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
- return -x
+ return -img
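
clamp_fp16 saturates at the float16 range so SiLU gating and attention outputs cannot propagate inf/NaN through the fp16 path. What nan_to_num does with those arguments:

import torch

def clamp_fp16(x):
    if x.dtype == torch.float16:
        return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
    return x

t = torch.tensor([float("inf"), float("-inf"), float("nan"), 1.0], dtype=torch.float16)
print(clamp_fp16(t))  # tensor([65504., -65504., 0., 1.], dtype=torch.float16)
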
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index d51e49da2..d23a753f9 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -529,6 +529,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
@wrap_attn
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+ exception_fallback = False
if skip_reshape:
b, _, _, dim_head = q.shape
tensor_layout = "HND"
@@ -553,6 +554,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
+ exception_fallback = True
+ if exception_fallback:
if tensor_layout == "NHD":
q, k, v = map(
lambda t: t.transpose(1, 2),
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 4245eedca..681a55db5 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae():
import xformers
import xformers.ops
+def torch_cat_if_needed(xl, dim):
+ if len(xl) > 1:
+ return torch.cat(xl, dim)
+ else:
+ return xl[0]
+
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
@@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32):
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+class CarriedConv3d(nn.Module):
+ def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
+ super().__init__()
+ self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
+
+ x = xl[0]
+ xl.clear()
+
+ if isinstance(op, CarriedConv3d):
+ if conv_carry_in is None:
+ x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
+ else:
+ carry_len = conv_carry_in[0].shape[2]
+ x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
+ x = torch.cat([conv_carry_in.pop(0), x], dim=2)
+
+ if conv_carry_out is not None:
+ to_push = x[:, :, -2:, :, :].clone()
+ conv_carry_out.append(to_push)
+
+ out = op(x)
+
+ return out
+
+
class VideoConv3d(nn.Module):
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
super().__init__()
@@ -89,29 +126,24 @@ class Upsample(nn.Module):
stride=1,
padding=1)
- def forward(self, x):
+ def forward(self, x, conv_carry_in=None, conv_carry_out=None):
scale_factor = self.scale_factor
if isinstance(scale_factor, (int, float)):
scale_factor = (scale_factor,) * (x.ndim - 2)
if x.ndim == 5 and scale_factor[0] > 1.0:
- t = x.shape[2]
- if t > 1:
- a, b = x.split((1, t - 1), dim=2)
- del x
- b = interpolate_up(b, scale_factor)
- else:
- a = x
-
- a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
- if t > 1:
- x = torch.cat((a, b), dim=2)
- else:
- x = a
+ results = []
+ if conv_carry_in is None:
+ first = x[:, :, :1, :, :]
+ results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
+ x = x[:, :, 1:, :, :]
+ if x.shape[2] > 0:
+ results.append(interpolate_up(x, scale_factor))
+ x = torch_cat_if_needed(results, dim=2)
else:
x = interpolate_up(x, scale_factor)
if self.with_conv:
- x = self.conv(x)
+ x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
return x
@@ -127,17 +159,20 @@ class Downsample(nn.Module):
stride=stride,
padding=0)
- def forward(self, x):
+ def forward(self, x, conv_carry_in=None, conv_carry_out=None):
if self.with_conv:
- if x.ndim == 4:
+ if isinstance(self.conv, CarriedConv3d):
+ x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
+ elif x.ndim == 4:
pad = (0, 1, 0, 1)
mode = "constant"
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
+ x = self.conv(x)
elif x.ndim == 5:
pad = (1, 1, 1, 1, 2, 0)
mode = "replicate"
x = torch.nn.functional.pad(x, pad, mode=mode)
- x = self.conv(x)
+ x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
@@ -183,23 +218,23 @@ class ResnetBlock(nn.Module):
stride=1,
padding=0)
- def forward(self, x, temb=None):
+ def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
h = x
h = self.norm1(h)
- h = self.swish(h)
- h = self.conv1(h)
+ h = [ self.swish(h) ]
+ h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if temb is not None:
h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
h = self.norm2(h)
h = self.swish(h)
- h = self.dropout(h)
- h = self.conv2(h)
+ h = [ self.dropout(h) ]
+ h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
- x = self.conv_shortcut(x)
+ x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
else:
x = self.nin_shortcut(x)
@@ -279,6 +314,7 @@ def pytorch_attention(q, k, v):
orig_shape = q.shape
B = orig_shape[0]
C = orig_shape[1]
+ oom_fallback = False
q, k, v = map(
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
(q, k, v),
@@ -289,6 +325,8 @@ def pytorch_attention(q, k, v):
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
+ oom_fallback = True
+ if oom_fallback:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
return out
@@ -517,9 +555,14 @@ class Encoder(nn.Module):
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
+ self.carried = False
if conv3d:
- conv_op = VideoConv3d
+ if not attn_resolutions:
+ conv_op = CarriedConv3d
+ self.carried = True
+ else:
+ conv_op = VideoConv3d
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
@@ -532,6 +575,7 @@ class Encoder(nn.Module):
stride=1,
padding=1)
+ self.time_compress = 1
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.in_ch_mult = in_ch_mult
@@ -558,10 +602,15 @@ class Encoder(nn.Module):
if time_compress is not None:
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
stride = (1, 2, 2)
+ else:
+ self.time_compress *= 2
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
curr_res = curr_res // 2
self.down.append(down)
+ if time_compress is not None:
+ self.time_compress = time_compress
+
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
@@ -587,15 +636,42 @@ class Encoder(nn.Module):
def forward(self, x):
# timestep embedding
temb = None
- # downsampling
- h = self.conv_in(x)
- for i_level in range(self.num_resolutions):
- for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](h, temb)
- if len(self.down[i_level].attn) > 0:
- h = self.down[i_level].attn[i_block](h)
- if i_level != self.num_resolutions-1:
- h = self.down[i_level].downsample(h)
+
+ if self.carried:
+ xl = [x[:, :, :1, :, :]]
+ if x.shape[2] > self.time_compress:
+ tc = self.time_compress
+ xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
+ x = xl
+ else:
+ x = [x]
+ out = []
+
+ conv_carry_in = None
+
+ for i, x1 in enumerate(x):
+ conv_carry_out = []
+ if i == len(x) - 1:
+ conv_carry_out = None
+
+ # downsampling
+ x1 = [ x1 ]
+ h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
+
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
+ if len(self.down[i_level].attn) > 0:
+                        assert i == 0  # carried chunking should not happen when attn blocks exist

+ h1 = self.down[i_level].attn[i_block](h1)
+ if i_level != self.num_resolutions-1:
+ h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
+
+ out.append(h1)
+ conv_carry_in = conv_carry_out
+
+ h = torch_cat_if_needed(out, dim=2)
+ del out
# middle
h = self.mid.block_1(h, temb)
@@ -604,15 +680,15 @@ class Encoder(nn.Module):
# end
h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h)
+ h = [ nonlinearity(h) ]
+ h = conv_carry_causal_3d(h, self.conv_out)
return h
class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
- resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
+ resolution, z_channels, tanh_out=False, use_linear_attn=False,
conv_out_op=ops.Conv2d,
resnet_op=ResnetBlock,
attn_op=AttnBlock,
@@ -626,12 +702,18 @@ class Decoder(nn.Module):
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
- self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
+ self.carried = False
if conv3d:
- conv_op = VideoConv3d
- conv_out_op = VideoConv3d
+ if not attn_resolutions and resnet_op == ResnetBlock:
+ conv_op = CarriedConv3d
+ conv_out_op = CarriedConv3d
+ self.carried = True
+ else:
+ conv_op = VideoConv3d
+ conv_out_op = VideoConv3d
+
mid_attn_conv_op = ops.Conv3d
else:
conv_op = ops.Conv2d
@@ -706,29 +788,43 @@ class Decoder(nn.Module):
temb = None
# z to block_in
- h = self.conv_in(z)
+ h = conv_carry_causal_3d([z], self.conv_in)
# middle
h = self.mid.block_1(h, temb, **kwargs)
h = self.mid.attn_1(h, **kwargs)
h = self.mid.block_2(h, temb, **kwargs)
+ if self.carried:
+ h = torch.split(h, 2, dim=2)
+ else:
+ h = [ h ]
+ out = []
+
+ conv_carry_in = None
+
# upsampling
- for i_level in reversed(range(self.num_resolutions)):
- for i_block in range(self.num_res_blocks+1):
- h = self.up[i_level].block[i_block](h, temb, **kwargs)
- if len(self.up[i_level].attn) > 0:
- h = self.up[i_level].attn[i_block](h, **kwargs)
- if i_level != 0:
- h = self.up[i_level].upsample(h)
+ for i, h1 in enumerate(h):
+ conv_carry_out = []
+ if i == len(h) - 1:
+ conv_carry_out = None
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks+1):
+ h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
+ if len(self.up[i_level].attn) > 0:
+                        assert i == 0  # carried chunking should not happen when attn blocks exist
+ h1 = self.up[i_level].attn[i_block](h1, **kwargs)
+ if i_level != 0:
+ h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
- # end
- if self.give_pre_end:
- return h
+ h1 = self.norm_out(h1)
+ h1 = [ nonlinearity(h1) ]
+ h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
+ if self.tanh_out:
+ h1 = torch.tanh(h1)
+ out.append(h1)
+ conv_carry_in = conv_carry_out
- h = self.norm_out(h)
- h = nonlinearity(h)
- h = self.conv_out(h, **kwargs)
- if self.tanh_out:
- h = torch.tanh(h)
- return h
+ out = torch_cat_if_needed(out, dim=2)
+
+ return out
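
CarriedConv3d plus conv_carry_causal_3d let the VAE process long clips in temporal chunks: each chunk is causally left-padded by 2 frames (or prepended with the already spatially padded tail of the previous chunk), and its own padded tail is pushed to conv_carry_out for the next chunk. A self-contained sketch, assuming a kernel-3, stride-1 conv, showing that chunked and whole-clip processing agree:

import torch
import torch.nn as nn
import torch.nn.functional as F

conv = nn.Conv3d(1, 1, kernel_size=3)          # no built-in padding, like CarriedConv3d

def carried_forward(x, carry_in=None, carry_out=None):
    if carry_in is None:
        x = F.pad(x, (1, 1, 1, 1, 2, 0), mode="replicate")   # spatial pad + 2-frame causal pad
    else:
        carry = carry_in.pop(0)
        x = F.pad(x, (1, 1, 1, 1, 2 - carry.shape[2], 0), mode="replicate")
        x = torch.cat([carry, x], dim=2)                      # carried frames stand in for the pad
    if carry_out is not None:
        carry_out.append(x[:, :, -2:, :, :].clone())          # tail, already spatially padded
    return conv(x)

x = torch.randn(1, 1, 9, 8, 8)
whole = carried_forward(x)                                     # the full clip in one pass
carry = []
a = carried_forward(x[:, :, :5], carry_out=carry)              # chunk 1
b = carried_forward(x[:, :, 5:], carry_in=carry)               # chunk 2 reuses chunk 1's tail
print(torch.allclose(whole, torch.cat([a, b], dim=2)))         # True
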
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 7afe4a798..7d0517e61 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -208,12 +208,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["theta"] = 2000
dit_config["out_channels"] = 128
dit_config["global_modulation"] = True
- dit_config["vec_in_dim"] = None
dit_config["mlp_silu_act"] = True
dit_config["qkv_bias"] = False
dit_config["ops_bias"] = False
dit_config["default_ref_method"] = "index"
dit_config["ref_index_scale"] = 10.0
+ dit_config["txt_ids_dims"] = [3]
patch_size = 1
else:
dit_config["image_model"] = "flux"
@@ -223,6 +223,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["theta"] = 10000
dit_config["out_channels"] = 16
dit_config["qkv_bias"] = True
+ dit_config["txt_ids_dims"] = []
patch_size = 2
dit_config["in_channels"] = 16
@@ -245,6 +246,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
if vec_in_key in state_dict_keys:
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
+ else:
+ dit_config["vec_in_dim"] = None
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
@@ -270,6 +273,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_embedder_dtype"] = torch.float32
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+ dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
+ dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+ if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
+ dit_config["txt_ids_dims"] = [1, 2]
+
return dit_config
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
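
Detection summary for the new keys: a gate_proj weight under img_mlp marks the yak_mlp variant, txt_norm.scale marks the text RMSNorm, and both together are treated as the Ovis model, which also moves the text token id ramp onto axes 1 and 2. Condensed into a standalone helper (key names copied from the hunk above, prefix handling simplified):

def sniff_flux_variant(state_dict_keys, key_prefix=""):
    yak_mlp = f"{key_prefix}double_blocks.0.img_mlp.gate_proj.weight" in state_dict_keys
    txt_norm = f"{key_prefix}txt_norm.scale" in state_dict_keys
    txt_ids_dims = [1, 2] if (yak_mlp and txt_norm) else []   # Flux 2 sets [3] elsewhere
    return {"yak_mlp": yak_mlp, "txt_norm": txt_norm, "txt_ids_dims": txt_ids_dims}

print(sniff_flux_variant({"double_blocks.0.img_mlp.gate_proj.weight", "txt_norm.scale"}))
# {'yak_mlp': True, 'txt_norm': True, 'txt_ids_dims': [1, 2]}
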
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 3eac77275..3dcac3eef 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -699,12 +699,12 @@ class ModelPatcher:
offloaded = []
offload_buffer = 0
loading.sort(reverse=True)
- for x in loading:
+ for i, x in enumerate(loading):
module_offload_mem, module_mem, n, m, params = x
lowvram_weight = False
- potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
+ potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]]))
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
weight_key = "{}.weight".format(n)
@@ -876,14 +876,18 @@ class ModelPatcher:
patch_counter = 0
unload_list = self._load_list()
unload_list.sort()
+
offload_buffer = self.model.model_offload_buffer_memory
+ if len(unload_list) > 0:
+ NS = comfy.model_management.NUM_STREAMS
+ offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
for unload in unload_list:
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
break
module_offload_mem, module_mem, n, m, params = unload
- potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
+ potential_offload = module_offload_mem + sum(offload_weight_factor)
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@@ -935,6 +939,8 @@ class ModelPatcher:
m.comfy_patched_weights = False
memory_freed += module_mem
offload_buffer = max(offload_buffer, potential_offload)
+ offload_weight_factor.append(module_mem)
+ offload_weight_factor.pop(0)
logging.debug("freed {}".format(n))
for param in params:
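
Both hunks replace the old worst-case offload headroom, module_offload_mem * (NUM_STREAMS + 1), with the current module plus the sizes of the next NUM_STREAMS entries actually present in the list. Rough numbers, with made-up sizes, NUM_STREAMS assumed to be 2, and offload/module sizes treated as equal for simplicity:

NUM_STREAMS = 2                          # assumed for illustration
module_mems = [400, 300, 200, 100, 50]   # MB, descending as in the sorted load list

i, module_offload_mem = 1, module_mems[1]
old_estimate = module_offload_mem * (NUM_STREAMS + 1)                            # 900 MB
new_estimate = module_offload_mem + sum(module_mems[i + 1:i + 1 + NUM_STREAMS])  # 600 MB
print(old_estimate, new_estimate)        # 900 600
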
diff --git a/comfy/ops.py b/comfy/ops.py
index 61a2f0754..eae434e68 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if s.bias is not None:
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
- if bias_has_function:
- with wf_context:
- for f in s.bias_function:
- bias = f(bias)
+ comfy.model_management.sync_stream(device, offload_stream)
+
+ bias_a = bias
+ weight_a = weight
+
+ if s.bias is not None:
+ for f in s.bias_function:
+ bias = f(bias)
if weight_has_function or weight.dtype != dtype:
- with wf_context:
- weight = weight.to(dtype=dtype)
- if isinstance(weight, QuantizedTensor):
- weight = weight.dequantize()
- for f in s.weight_function:
- weight = f(weight)
+ weight = weight.to(dtype=dtype)
+ if isinstance(weight, QuantizedTensor):
+ weight = weight.dequantize()
+ for f in s.weight_function:
+ weight = f(weight)
- comfy.model_management.sync_stream(device, offload_stream)
if offloadable:
- return weight, bias, offload_stream
+ return weight, bias, (offload_stream, weight_a, bias_a)
else:
#Legacy function signature
return weight, bias
@@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
def uncast_bias_weight(s, weight, bias, offload_stream):
if offload_stream is None:
return
- if weight is not None:
- device = weight.device
+ os, weight_a, bias_a = offload_stream
+ if os is None:
+ return
+ if weight_a is not None:
+ device = weight_a.device
else:
- if bias is None:
+ if bias_a is None:
return
- device = bias.device
- offload_stream.wait_stream(comfy.model_management.current_stream(device))
+ device = bias_a.device
+ os.wait_stream(comfy.model_management.current_stream(device))
class CastWeightBiasOp:
diff --git a/comfy/sd.py b/comfy/sd.py
index 9eeb0c45a..03bdb33d5 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -53,6 +53,7 @@ import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
+import comfy.text_encoders.ovis
import comfy.model_patcher
import comfy.lora
@@ -192,6 +193,7 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
+ self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
all_hooks.reset()
self.patcher.patch_hooks(None)
if show_pbar:
@@ -239,6 +241,7 @@ class CLIP:
self.cond_stage_model.set_clip_options({"projected_pooled": False})
self.load_model()
+ self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
o = self.cond_stage_model.encode_token_weights(tokens)
cond, pooled = o[:2]
if return_dict:
@@ -468,7 +471,7 @@ class VAE:
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
- self.memory_used_decode = lambda shape, dtype: (2800 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
@@ -480,8 +483,10 @@ class VAE:
self.latent_dim = 3
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
- self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
- self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+            # This is likely to significantly over-estimate for a single image or low frame counts, as the
+            # implementation can skip caching entirely. Rework if this is ever used as an image-only VAE.
+ self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
+ self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.unpatcher3d.wavelets" in sd:
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
@@ -956,6 +961,7 @@ class CLIPType(Enum):
QWEN_IMAGE = 18
HUNYUAN_IMAGE = 19
HUNYUAN_VIDEO_15 = 20
+ OVIS = 21
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -987,6 +993,7 @@ class TEModel(Enum):
MISTRAL3_24B = 14
MISTRAL3_24B_PRUNED_FLUX2 = 15
QWEN3_4B = 16
+ QWEN3_2B = 17
def detect_te_model(sd):
@@ -1020,9 +1027,12 @@ def detect_te_model(sd):
if weight.shape[0] == 512:
return TEModel.QWEN25_7B
if "model.layers.0.post_attention_layernorm.weight" in sd:
- if 'model.layers.0.self_attn.q_norm.weight' in sd:
- return TEModel.QWEN3_4B
weight = sd['model.layers.0.post_attention_layernorm.weight']
+ if 'model.layers.0.self_attn.q_norm.weight' in sd:
+ if weight.shape[0] == 2560:
+ return TEModel.QWEN3_4B
+ elif weight.shape[0] == 2048:
+ return TEModel.QWEN3_2B
if weight.shape[0] == 5120:
if "model.layers.39.post_attention_layernorm.weight" in sd:
return TEModel.MISTRAL3_24B
@@ -1150,6 +1160,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif te_model == TEModel.QWEN3_4B:
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
+ elif te_model == TEModel.QWEN3_2B:
+ clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
else:
# clip_l
if clip_type == CLIPType.SD3:
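
Both Qwen3 text encoders carry q_norm weights, so detect_te_model now disambiguates them by the width of post_attention_layernorm: 2560 selects QWEN3_4B (Z-Image), 2048 selects QWEN3_2B (Ovis 2.5). In miniature:

import torch

def pick_qwen3(sd):
    # widths taken from the sd.py hunk above; returns None for other checkpoints
    if 'model.layers.0.self_attn.q_norm.weight' not in sd:
        return None
    width = sd['model.layers.0.post_attention_layernorm.weight'].shape[0]
    return {2560: "QWEN3_4B", 2048: "QWEN3_2B"}.get(width)

print(pick_qwen3({'model.layers.0.self_attn.q_norm.weight': torch.zeros(1),
                  'model.layers.0.post_attention_layernorm.weight': torch.zeros(2048)}))  # QWEN3_2B
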
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 0fc9ab3db..503a51843 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -147,6 +147,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_norm_hidden_state = layer_norm_hidden_state
self.return_projected_pooled = return_projected_pooled
self.return_attention_masks = return_attention_masks
+ self.execution_device = None
if layer == "hidden":
assert layer_idx is not None
@@ -163,6 +164,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
+ self.execution_device = options.get("execution_device", self.execution_device)
if isinstance(self.layer, list) or self.layer == "all":
pass
elif layer_idx is None or abs(layer_idx) > self.num_layers:
@@ -175,6 +177,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer = self.options_default[0]
self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
+ self.execution_device = None
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
@@ -258,7 +261,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
def forward(self, tokens):
- device = self.transformer.get_input_embeddings().weight.device
+ if self.execution_device is None:
+ device = self.transformer.get_input_embeddings().weight.device
+ else:
+ device = self.execution_device
+
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
attention_mask_model = None
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index af8120400..afd97160b 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1027,6 +1027,8 @@ class ZImage(Lumina2):
memory_usage_factor = 1.7
+ supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index cd4b5f76c..0d07ac8c6 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -100,6 +100,28 @@ class Qwen3_4BConfig:
rope_scale = None
final_norm: bool = True
+@dataclass
+class Ovis25_2BConfig:
+ vocab_size: int = 151936
+ hidden_size: int = 2048
+ intermediate_size: int = 6144
+ num_hidden_layers: int = 28
+ num_attention_heads: int = 16
+ num_key_value_heads: int = 8
+ max_position_embeddings: int = 40960
+ rms_norm_eps: float = 1e-6
+ rope_theta: float = 1000000.0
+ transformer_type: str = "llama"
+ head_dim = 128
+ rms_norm_add = False
+ mlp_activation = "silu"
+ qkv_bias = False
+ rope_dims = None
+ q_norm = "gemma3"
+ k_norm = "gemma3"
+ rope_scale = None
+ final_norm: bool = True
+
@dataclass
class Qwen25_7BVLI_Config:
vocab_size: int = 152064
@@ -542,6 +564,15 @@ class Qwen3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
+class Ovis25_2B(BaseLlama, torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ config = Ovis25_2BConfig(**config_dict)
+ self.num_layers = config.num_hidden_layers
+
+ self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
+
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py
new file mode 100644
index 000000000..81c9bd51c
--- /dev/null
+++ b/comfy/text_encoders/ovis.py
@@ -0,0 +1,69 @@
+from transformers import Qwen2Tokenizer
+import comfy.text_encoders.llama
+from comfy import sd1_clip
+import os
+import torch
+import numbers
+
+class Qwen3Tokenizer(sd1_clip.SDTokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
+ super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data)
+
+
+class OvisTokenizer(sd1_clip.SD1Tokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer)
+ self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n\n\n\n\n"
+
+ def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
+ if llama_template is None:
+ llama_text = self.llama_template.format(text)
+ else:
+ llama_text = llama_template.format(text)
+
+ tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
+ return tokens
+
+class Ovis25_2BModel(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)
+
+
+class OvisTEModel(sd1_clip.SD1ClipModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options)
+
+ def encode_token_weights(self, token_weight_pairs, template_end=-1):
+ out, pooled = super().encode_token_weights(token_weight_pairs)
+ tok_pairs = token_weight_pairs["qwen3_2b"][0]
+ count_im_start = 0
+ if template_end == -1:
+ for i, v in enumerate(tok_pairs):
+ elem = v[0]
+ if not torch.is_tensor(elem):
+ if isinstance(elem, numbers.Integral):
+ if elem == 4004 and count_im_start < 1:
+ template_end = i
+ count_im_start += 1
+
+ if out.shape[1] > (template_end + 1):
+ if tok_pairs[template_end + 1][0] == 25:
+ template_end += 1
+
+ out = out[:, template_end:]
+ return out, pooled, {}
+
+
+def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None):
+ class OvisTEModel_(OvisTEModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ model_options = model_options.copy()
+ model_options["scaled_fp8"] = llama_scaled_fp8
+ if dtype_llama is not None:
+ dtype = dtype_llama
+ if llama_quantization_metadata is not None:
+ model_options["quantization_metadata"] = llama_quantization_metadata
+ super().__init__(device=device, dtype=dtype, model_options=model_options)
+ return OvisTEModel_
diff --git a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json
index 67688e82c..df5b5d7fe 100644
--- a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json
+++ b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json
@@ -179,36 +179,36 @@
"special": false
},
"151665": {
- "content": "<|img|>",
+ "content": "",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
- "special": true
+ "special": false
},
"151666": {
- "content": "<|endofimg|>",
+ "content": "",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
- "special": true
+ "special": false
},
"151667": {
- "content": "<|meta|>",
+ "content": "",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
- "special": true
+ "special": false
},
"151668": {
- "content": "<|endofmeta|>",
+ "content": "",
"lstrip": false,
"normalized": false,
"rstrip": false,
"single_word": false,
- "special": true
+ "special": false
}
},
"additional_special_tokens": [
diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py
index 0d4389a6e..bfb77eb5f 100644
--- a/comfy_api/feature_flags.py
+++ b/comfy_api/feature_flags.py
@@ -13,6 +13,7 @@ from comfy.cli_args import args
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
+ "extension": {"manager": {"supports_v4": True}},
}
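Frontends or extensions that read the negotiated server flags can pick up the new nested entry with plain dictionary access; a small sketch (the helper name is illustrative, only the payload shape comes from the flag defined above):

```python
def manager_supports_v4(server_flags: dict) -> bool:
    # Missing keys simply mean the server does not advertise the feature.
    return bool(server_flags.get("extension", {}).get("manager", {}).get("supports_v4", False))

assert manager_supports_v4({"extension": {"manager": {"supports_v4": True}}})
```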
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index 176ae36e0..0fa01d1e7 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
-from . import _io as io
-from . import _ui as ui
+from . import _io_public as io
+from . import _ui_public as ui
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
from comfy_execution.utils import get_executing_context
from comfy_execution.progress import get_progress_state, PreviewImageTuple
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index bde37f90a..a4cd3737d 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -336,7 +336,10 @@ class VideoFromComponents(VideoInput):
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
raise ValueError("Only H264 codec is supported for now")
- with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
+ extra_kwargs = {}
+ if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
+ extra_kwargs["format"] = format.value
+ with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
# Add metadata before writing any streams
if metadata is not None:
for key, value in metadata.items():
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 79c0722a9..866c3e0eb 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -4,7 +4,8 @@ import copy
import inspect
from abc import ABC, abstractmethod
from collections import Counter
-from dataclasses import asdict, dataclass
+from collections.abc import Iterable
+from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
from typing_extensions import NotRequired, final
@@ -150,6 +151,9 @@ class _IO_V3:
def __init__(self):
pass
+ def validate(self):
+ pass
+
@property
def io_type(self):
return self.Parent.io_type
@@ -182,6 +186,9 @@ class Input(_IO_V3):
def get_io_type(self):
return _StringIOType(self.io_type)
+ def get_all(self) -> list[Input]:
+ return [self]
+
class WidgetInput(Input):
'''
Base class for a V3 Input with widget.
@@ -814,13 +821,61 @@ class MultiType:
else:
return super().as_dict()
+@comfytype(io_type="COMFY_MATCHTYPE_V3")
+class MatchType(ComfyTypeIO):
+ class Template:
+ def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType):
+ self.template_id = template_id
+ # account for syntactic sugar
+ if not isinstance(allowed_types, Iterable):
+ allowed_types = [allowed_types]
+ for t in allowed_types:
+ if not isinstance(t, type):
+ if not isinstance(t, _ComfyType):
+ raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}")
+ else:
+ if not issubclass(t, _ComfyType):
+ raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}")
+ self.allowed_types = allowed_types
+
+ def as_dict(self):
+ return {
+ "template_id": self.template_id,
+ "allowed_types": ",".join([t.io_type for t in self.allowed_types]),
+ }
+
+ class Input(Input):
+ def __init__(self, id: str, template: MatchType.Template,
+ display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+ super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
+ self.template = template
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "template": self.template.as_dict(),
+ })
+
+ class Output(Output):
+ def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
+ is_output_list=False):
+ super().__init__(id, display_name, tooltip, is_output_list)
+ self.template = template
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "template": self.template.as_dict(),
+ })
+
class DynamicInput(Input, ABC):
'''
Abstract class for dynamic input registration.
'''
- @abstractmethod
def get_dynamic(self) -> list[Input]:
- ...
+ return []
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+ pass
+
class DynamicOutput(Output, ABC):
'''
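To show how the relocated `MatchType` is meant to be wired, here is a hypothetical schema fragment in which an input and an output share one template, so the frontend can resolve both slots to whichever allowed type actually gets connected (the node id, display name and template id are invented):

```python
template = MatchType.Template("T", allowed_types=[Image, Latent])

schema = Schema(
    node_id="ExamplePassthrough",
    display_name="Example Passthrough",
    inputs=[MatchType.Input("value", template=template)],
    outputs=[MatchType.Output(template=template)],
)
```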
@@ -830,99 +885,223 @@ class DynamicOutput(Output, ABC):
is_output_list=False):
super().__init__(id, display_name, tooltip, is_output_list)
- @abstractmethod
def get_dynamic(self) -> list[Output]:
- ...
+ return []
@comfytype(io_type="COMFY_AUTOGROW_V3")
-class AutogrowDynamic(ComfyTypeI):
- Type = list[Any]
- class Input(DynamicInput):
- def __init__(self, id: str, template_input: Input, min: int=1, max: int=None,
- display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
- super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
- self.template_input = template_input
- if min is not None:
- assert(min >= 1)
- if max is not None:
- assert(max >= 1)
+class Autogrow(ComfyTypeI):
+ Type = dict[str, Any]
+ _MaxNames = 100 # NOTE: max 100 names for sanity
+
+ class _AutogrowTemplate:
+ def __init__(self, input: Input):
+ # dynamic inputs are not allowed as the template input
+ assert(not isinstance(input, DynamicInput))
+ self.input = copy.copy(input)
+ if isinstance(self.input, WidgetInput):
+ self.input.force_input = True
+ self.names: list[str] = []
+ self.cached_inputs = {}
+
+ def _create_input(self, input: Input, name: str):
+ new_input = copy.copy(self.input)
+ new_input.id = name
+ return new_input
+
+ def _create_cached_inputs(self):
+ for name in self.names:
+ self.cached_inputs[name] = self._create_input(self.input, name)
+
+ def get_all(self) -> list[Input]:
+ return list(self.cached_inputs.values())
+
+ def as_dict(self):
+ return prune_dict({
+ "input": create_input_dict_v1([self.input]),
+ })
+
+ def validate(self):
+ self.input.validate()
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+ real_inputs = []
+ for name, input in self.cached_inputs.items():
+ if name in live_inputs:
+ real_inputs.append(input)
+ add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
+ add_dynamic_id_mapping(d, real_inputs, curr_prefix)
+
+ class TemplatePrefix(_AutogrowTemplate):
+ def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
+ super().__init__(input)
+ self.prefix = prefix
+ assert(min >= 0)
+ assert(max >= 1)
+ assert(max <= Autogrow._MaxNames)
self.min = min
self.max = max
+ self.names = [f"{self.prefix}{i}" for i in range(self.max)]
+ self._create_cached_inputs()
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "prefix": self.prefix,
+ "min": self.min,
+ "max": self.max,
+ })
+
+ class TemplateNames(_AutogrowTemplate):
+ def __init__(self, input: Input, names: list[str], min: int=1):
+ super().__init__(input)
+ self.names = names[:Autogrow._MaxNames]
+ assert(min >= 0)
+ self.min = min
+ self._create_cached_inputs()
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "names": self.names,
+ "min": self.min,
+ })
+
+ class Input(DynamicInput):
+ def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames,
+ display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
+ super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
+ self.template = template
+
+ def as_dict(self):
+ return super().as_dict() | prune_dict({
+ "template": self.template.as_dict(),
+ })
def get_dynamic(self) -> list[Input]:
- curr_count = 1
- new_inputs = []
- for i in range(self.min):
- new_input = copy.copy(self.template_input)
- new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
- if new_input.display_name is not None:
- new_input.display_name = f"{new_input.display_name}{curr_count}"
- new_input.optional = self.optional or new_input.optional
- if isinstance(self.template_input, WidgetInput):
- new_input.force_input = True
- new_inputs.append(new_input)
- curr_count += 1
- # pretend to expand up to max
- for i in range(curr_count-1, self.max):
- new_input = copy.copy(self.template_input)
- new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
- if new_input.display_name is not None:
- new_input.display_name = f"{new_input.display_name}{curr_count}"
- new_input.optional = True
- if isinstance(self.template_input, WidgetInput):
- new_input.force_input = True
- new_inputs.append(new_input)
- curr_count += 1
- return new_inputs
+ return self.template.get_all()
-@comfytype(io_type="COMFY_COMBODYNAMIC_V3")
-class ComboDynamic(ComfyTypeI):
- class Input(DynamicInput):
- def __init__(self, id: str):
- pass
+ def get_all(self) -> list[Input]:
+ return [self] + self.template.get_all()
-@comfytype(io_type="COMFY_MATCHTYPE_V3")
-class MatchType(ComfyTypeIO):
- class Template:
- def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]):
- self.template_id = template_id
- self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types
+ def validate(self):
+ self.template.validate()
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+ curr_prefix = f"{curr_prefix}{self.id}."
+ # need to remove self from expected inputs dictionary; replaced by template inputs in frontend
+ for inner_dict in d.values():
+ if self.id in inner_dict:
+ del inner_dict[self.id]
+ self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
+
+@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
+class DynamicCombo(ComfyTypeI):
+ Type = dict[str, Any]
+
+ class Option:
+ def __init__(self, key: str, inputs: list[Input]):
+ self.key = key
+ self.inputs = inputs
def as_dict(self):
return {
- "template_id": self.template_id,
- "allowed_types": "".join(t.io_type for t in self.allowed_types),
+ "key": self.key,
+ "inputs": create_input_dict_v1(self.inputs),
}
class Input(DynamicInput):
- def __init__(self, id: str, template: MatchType.Template,
+ def __init__(self, id: str, options: list[DynamicCombo.Option],
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
- self.template = template
+ self.options = options
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+ # check if dynamic input's id is in live_inputs
+ if self.id in live_inputs:
+ curr_prefix = f"{curr_prefix}{self.id}."
+ key = live_inputs[self.id]
+ selected_option = None
+ for option in self.options:
+ if option.key == key:
+ selected_option = option
+ break
+ if selected_option is not None:
+ add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
+ add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
def get_dynamic(self) -> list[Input]:
- return [self]
+ return [input for option in self.options for input in option.inputs]
+
+ def get_all(self) -> list[Input]:
+ return [self] + [input for option in self.options for input in option.inputs]
def as_dict(self):
return super().as_dict() | prune_dict({
- "template": self.template.as_dict(),
+ "options": [o.as_dict() for o in self.options],
})
- class Output(DynamicOutput):
- def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None,
- is_output_list=False):
- super().__init__(id, display_name, tooltip, is_output_list)
- self.template = template
+ def validate(self):
+ # make sure all nested inputs are validated
+ for option in self.options:
+ for input in option.inputs:
+ input.validate()
- def get_dynamic(self) -> list[Output]:
- return [self]
+@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
+class DynamicSlot(ComfyTypeI):
+ Type = dict[str, Any]
+
+ class Input(DynamicInput):
+ def __init__(self, slot: Input, inputs: list[Input],
+ display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None):
+ assert(not isinstance(slot, DynamicInput))
+ self.slot = copy.copy(slot)
+ self.slot.display_name = slot.display_name if slot.display_name is not None else display_name
+ optional = True
+ self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip
+ self.slot.lazy = slot.lazy if slot.lazy is not None else lazy
+ self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict
+ super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict)
+ self.inputs = inputs
+ self.force_input = None
+ # force widget inputs to have no widgets, otherwise this would be awkward
+ if isinstance(self.slot, WidgetInput):
+ self.force_input = True
+ self.slot.force_input = True
+
+ def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
+ if self.id in live_inputs:
+ curr_prefix = f"{curr_prefix}{self.id}."
+ add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
+ add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
+
+ def get_dynamic(self) -> list[Input]:
+ return [self.slot] + self.inputs
+
+ def get_all(self) -> list[Input]:
+ return [self] + [self.slot] + self.inputs
def as_dict(self):
return super().as_dict() | prune_dict({
- "template": self.template.as_dict(),
+ "slotType": str(self.slot.get_io_type()),
+ "inputs": create_input_dict_v1(self.inputs),
+ "forceInput": self.force_input,
})
+ def validate(self):
+ self.slot.validate()
+ for input in self.inputs:
+ input.validate()
+
+def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
+ dynamic = d.setdefault("dynamic_paths", {})
+ if self is not None:
+ dynamic[self.id] = f"{curr_prefix}{self.id}"
+ for i in inputs:
+ if not isinstance(i, DynamicInput):
+ dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
+
+class V3Data(TypedDict):
+ hidden_inputs: dict[str, Any]
+ dynamic_paths: dict[str, Any]
class HiddenHolder:
def __init__(self, unique_id: str, prompt: Any,
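For orientation, a rough sketch of how a node could declare the new dynamic inputs follows; both types are still commented out of `__all__` further down, so this is illustrative only, and the ids, option keys and inner inputs are invented:

```python
inputs = [
    # Grows up to four slots named image0..image3 as the user connects them.
    Autogrow.Input("images", template=Autogrow.TemplatePrefix(Image.Input("image"), prefix="image", max=4)),
    # Exposes different sub-inputs depending on the selected option key.
    DynamicCombo.Input("operation", options=[
        DynamicCombo.Option("add", [Int.Input("a"), Int.Input("b")]),
        DynamicCombo.Option("negate", [Int.Input("a")]),
    ]),
]
```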
@@ -984,6 +1163,7 @@ class NodeInfoV1:
output_is_list: list[bool]=None
output_name: list[str]=None
output_tooltips: list[str]=None
+ output_matchtypes: list[str]=None
name: str=None
display_name: str=None
description: str=None
@@ -1019,9 +1199,9 @@ class Schema:
"""Display name of node."""
category: str = "sd"
"""The category of the node, as per the "Add Node" menu."""
- inputs: list[Input]=None
- outputs: list[Output]=None
- hidden: list[Hidden]=None
+ inputs: list[Input] = field(default_factory=list)
+ outputs: list[Output] = field(default_factory=list)
+ hidden: list[Hidden] = field(default_factory=list)
description: str=""
"""Node description, shown as a tooltip when hovering over the node."""
is_input_list: bool = False
@@ -1061,7 +1241,11 @@ class Schema:
'''Validate the schema:
- verify ids on inputs and outputs are unique - both internally and in relation to each other
'''
- input_ids = [i.id for i in self.inputs] if self.inputs is not None else []
+ nested_inputs: list[Input] = []
+ if self.inputs is not None:
+ for input in self.inputs:
+ nested_inputs.extend(input.get_all())
+ input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
input_set = set(input_ids)
output_set = set(output_ids)
@@ -1077,6 +1261,13 @@ class Schema:
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
if len(issues) > 0:
raise ValueError("\n".join(issues))
+ # validate inputs and outputs
+ if self.inputs is not None:
+ for input in self.inputs:
+ input.validate()
+ if self.outputs is not None:
+ for output in self.outputs:
+ output.validate()
def finalize(self):
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
@@ -1102,19 +1293,10 @@ class Schema:
if output.id is None:
output.id = f"_{i}_{output.io_type}_"
- def get_v1_info(self, cls) -> NodeInfoV1:
+ def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
+ # NOTE: live_inputs is a temporary mechanism; it will soon be replaced by a different approach
# get V1 inputs
- input = {
- "required": {}
- }
- if self.inputs:
- for i in self.inputs:
- if isinstance(i, DynamicInput):
- dynamic_inputs = i.get_dynamic()
- for d in dynamic_inputs:
- add_to_dict_v1(d, input)
- else:
- add_to_dict_v1(i, input)
+ input = create_input_dict_v1(self.inputs, live_inputs)
if self.hidden:
for hidden in self.hidden:
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
@@ -1123,12 +1305,24 @@ class Schema:
output_is_list = []
output_name = []
output_tooltips = []
+ output_matchtypes = []
+ any_matchtypes = False
if self.outputs:
for o in self.outputs:
output.append(o.io_type)
output_is_list.append(o.is_output_list)
output_name.append(o.display_name if o.display_name else o.io_type)
output_tooltips.append(o.tooltip if o.tooltip else None)
+ # special handling for MatchType
+ if isinstance(o, MatchType.Output):
+ output_matchtypes.append(o.template.template_id)
+ any_matchtypes = True
+ else:
+ output_matchtypes.append(None)
+
+ # clear out lists that are all None
+ if not any_matchtypes:
+ output_matchtypes = None
info = NodeInfoV1(
input=input,
@@ -1137,6 +1331,7 @@ class Schema:
output_is_list=output_is_list,
output_name=output_name,
output_tooltips=output_tooltips,
+ output_matchtypes=output_matchtypes,
name=self.node_id,
display_name=self.display_name,
category=self.category,
@@ -1182,16 +1377,57 @@ class Schema:
return info
-def add_to_dict_v1(i: Input, input: dict):
+def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
+ input = {
+ "required": {}
+ }
+ add_to_input_dict_v1(input, inputs, live_inputs)
+ return input
+
+def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
+ for i in inputs:
+ if isinstance(i, DynamicInput):
+ add_to_dict_v1(i, d)
+ if live_inputs is not None:
+ i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
+ else:
+ add_to_dict_v1(i, d)
+
+def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
key = "optional" if i.optional else "required"
as_dict = i.as_dict()
# for v1, we don't want to include the optional key
as_dict.pop("optional", None)
- input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
+ if dynamic_dict is None:
+ value = (i.get_io_type(), as_dict)
+ else:
+ value = (i.get_io_type(), as_dict, dynamic_dict)
+ d.setdefault(key, {})[i.id] = value
def add_to_dict_v3(io: Input | Output, d: dict):
d[io.id] = (io.get_io_type(), io.as_dict())
+def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
+ paths = v3_data.get("dynamic_paths", None)
+ if paths is None:
+ return values
+ values = values.copy()
+ result = {}
+
+ for key, path in paths.items():
+ parts = path.split(".")
+ current = result
+
+ for i, p in enumerate(parts):
+ is_last = (i == len(parts) - 1)
+
+ if is_last:
+ current[p] = values.pop(key, None)
+ else:
+ current = current.setdefault(p, {})
+
+ values.update(result)
+ return values
class _ComfyNodeBaseInternal(_ComfyNodeInternal):
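Concretely, `build_nested_inputs` regroups the flat values handed to execution according to the `dynamic_paths` mapping emitted by `add_dynamic_id_mapping`. A small illustrative call, with the paths a `DynamicCombo` input named "operation" (holding integer sub-inputs "a" and "b") would generate:

```python
values = {"operation": "add", "a": 1, "b": 2}
v3_data = {"dynamic_paths": {"operation": "operation.operation", "a": "operation.a", "b": "operation.b"}}

build_nested_inputs(values, v3_data)
# -> {"operation": {"operation": "add", "a": 1, "b": 2}}
```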
@@ -1311,12 +1547,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
- def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]:
+ def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
"""Creates clone of real node class to prevent monkey-patching."""
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
# set hidden
- type_clone.hidden = HiddenHolder.from_dict(hidden_inputs)
+ type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
return type_clone
@final
@@ -1433,14 +1669,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
@final
@classmethod
- def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]:
+ def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
schema = cls.FINALIZE_SCHEMA()
- info = schema.get_v1_info(cls)
+ info = schema.get_v1_info(cls, live_inputs)
input = info.input
if not include_hidden:
input.pop("hidden", None)
if return_schema:
- return input, schema
+ v3_data: V3Data = {}
+ dynamic = input.pop("dynamic_paths", None)
+ if dynamic is not None:
+ v3_data["dynamic_paths"] = dynamic
+ return input, schema, v3_data
return input
@final
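With the extended signature, schema-aware callers now receive a third element carrying the dynamic path mapping; roughly (the node class and live values are placeholders):

```python
# SomeNode stands in for any v3 node class defined with the new schema API.
input_types, schema, v3_data = SomeNode.INPUT_TYPES(return_schema=True, live_inputs={"operation": "add"})
paths = v3_data.get("dynamic_paths", {})
```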
@@ -1513,7 +1753,7 @@ class ComfyNode(_ComfyNodeBaseInternal):
raise NotImplementedError
@classmethod
- def validate_inputs(cls, **kwargs) -> bool:
+ def validate_inputs(cls, **kwargs) -> bool | str:
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
raise NotImplementedError
@@ -1628,6 +1868,7 @@ __all__ = [
"StyleModel",
"Gligen",
"UpscaleModel",
+ "LatentUpscaleModel",
"Audio",
"Video",
"SVG",
@@ -1651,6 +1892,10 @@ __all__ = [
"SEGS",
"AnyType",
"MultiType",
+ # Dynamic Types
+ "MatchType",
+ # "DynamicCombo",
+ # "Autogrow",
# Other classes
"HiddenHolder",
"Hidden",
@@ -1661,4 +1906,5 @@ __all__ = [
"NodeOutput",
"add_to_dict_v1",
"add_to_dict_v3",
+ "V3Data",
]
diff --git a/comfy_api/latest/_io_public.py b/comfy_api/latest/_io_public.py
new file mode 100644
index 000000000..43c7680f3
--- /dev/null
+++ b/comfy_api/latest/_io_public.py
@@ -0,0 +1 @@
+from ._io import * # noqa: F403
diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py
index b0bbabe2a..5a75a3aae 100644
--- a/comfy_api/latest/_ui.py
+++ b/comfy_api/latest/_ui.py
@@ -3,6 +3,7 @@ from __future__ import annotations
import json
import os
import random
+import uuid
from io import BytesIO
from typing import Type
@@ -318,9 +319,10 @@ class AudioSaveHelper:
for key, value in metadata.items():
output_container.metadata[key] = value
+ layout = "mono" if waveform.shape[0] == 1 else "stereo"
# Set up the output stream with appropriate properties
if format == "opus":
- out_stream = output_container.add_stream("libopus", rate=sample_rate)
+ out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
@@ -332,7 +334,7 @@ class AudioSaveHelper:
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
- out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
+ out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
if quality == "V0":
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
out_stream.codec_context.qscale = 1
@@ -341,12 +343,12 @@ class AudioSaveHelper:
elif quality == "320k":
out_stream.bit_rate = 320000
else: # format == "flac":
- out_stream = output_container.add_stream("flac", rate=sample_rate)
+ out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
frame = av.AudioFrame.from_ndarray(
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
format="flt",
- layout="mono" if waveform.shape[0] == 1 else "stereo",
+ layout=layout,
)
frame.sample_rate = sample_rate
frame.pts = 0
@@ -436,9 +438,19 @@ class PreviewUI3D(_UIOutput):
def __init__(self, model_file, camera_info, **kwargs):
self.model_file = model_file
self.camera_info = camera_info
+ self.bg_image_path = None
+ bg_image = kwargs.get("bg_image", None)
+ if bg_image is not None:
+ img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
+ img = PILImage.fromarray(img_array)
+ temp_dir = folder_paths.get_temp_directory()
+ filename = f"bg_{uuid.uuid4().hex}.png"
+ bg_image_path = os.path.join(temp_dir, filename)
+ img.save(bg_image_path, compress_level=1)
+ self.bg_image_path = f"temp/{filename}"
def as_dict(self):
- return {"result": [self.model_file, self.camera_info]}
+ return {"result": [self.model_file, self.camera_info, self.bg_image_path]}
class PreviewText(_UIOutput):
diff --git a/comfy_api/latest/_ui_public.py b/comfy_api/latest/_ui_public.py
new file mode 100644
index 000000000..85b11d78b
--- /dev/null
+++ b/comfy_api/latest/_ui_public.py
@@ -0,0 +1 @@
+from ._ui import * # noqa: F403
diff --git a/comfy_api/v0_0_2/__init__.py b/comfy_api/v0_0_2/__init__.py
index de0f95001..c4fa1d971 100644
--- a/comfy_api/v0_0_2/__init__.py
+++ b/comfy_api/v0_0_2/__init__.py
@@ -6,7 +6,7 @@ from comfy_api.latest import (
)
from typing import Type, TYPE_CHECKING
from comfy_api.internal.async_to_sync import create_sync_class
-from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
+from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
@@ -42,4 +42,8 @@ __all__ = [
"InputImpl",
"Types",
"ComfyExtension",
+ "io",
+ "IO",
+ "ui",
+ "UI",
]
diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py
new file mode 100644
index 000000000..d8949f8ac
--- /dev/null
+++ b/comfy_api_nodes/apis/kling_api.py
@@ -0,0 +1,86 @@
+from pydantic import BaseModel, Field
+
+
+class OmniProText2VideoRequest(BaseModel):
+ model_name: str = Field(..., description="kling-video-o1")
+ aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
+ duration: str = Field(..., description="'5' or '10'")
+ prompt: str = Field(...)
+ mode: str = Field("pro")
+
+
+class OmniParamImage(BaseModel):
+ image_url: str = Field(...)
+ type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'")
+
+
+class OmniParamVideo(BaseModel):
+ video_url: str = Field(...)
+ refer_type: str | None = Field(..., description="Can be 'base' or 'feature'")
+ keep_original_sound: str = Field(..., description="'yes' or 'no'")
+
+
+class OmniProFirstLastFrameRequest(BaseModel):
+ model_name: str = Field(..., description="kling-video-o1")
+ image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7)
+ duration: str = Field(..., description="'5' or '10'")
+ prompt: str = Field(...)
+ mode: str = Field("pro")
+
+
+class OmniProReferences2VideoRequest(BaseModel):
+ model_name: str = Field(..., description="kling-video-o1")
+ aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'")
+ image_list: list[OmniParamImage] | None = Field(
+ None, max_length=7, description="Max length 4 when video is present."
+ )
+ video_list: list[OmniParamVideo] | None = Field(None, max_length=1)
+ duration: str | None = Field(..., description="From 3 to 10.")
+ prompt: str = Field(...)
+ mode: str = Field("pro")
+
+
+class TaskStatusVideoResult(BaseModel):
+ duration: str | None = Field(None, description="Total video duration")
+ id: str | None = Field(None, description="Generated video ID")
+ url: str | None = Field(None, description="URL for generated video")
+
+
+class TaskStatusImageResult(BaseModel):
+ index: int = Field(..., description="Image Number, 0-9")
+ url: str = Field(..., description="URL for generated image")
+
+
+class OmniTaskStatusResults(BaseModel):
+ videos: list[TaskStatusVideoResult] | None = Field(None)
+ images: list[TaskStatusImageResult] | None = Field(None)
+
+
+class OmniTaskStatusResponseData(BaseModel):
+ created_at: int | None = Field(None, description="Task creation time")
+ updated_at: int | None = Field(None, description="Task update time")
+ task_status: str | None = None
+ task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
+ task_id: str | None = Field(None, description="Task ID")
+ task_result: OmniTaskStatusResults | None = Field(None)
+
+
+class OmniTaskStatusResponse(BaseModel):
+ code: int | None = Field(None, description="Error code")
+ message: str | None = Field(None, description="Error message")
+ request_id: str | None = Field(None, description="Request ID")
+ data: OmniTaskStatusResponseData | None = Field(None)
+
+
+class OmniImageParamImage(BaseModel):
+ image: str = Field(...)
+
+
+class OmniProImageRequest(BaseModel):
+ model_name: str = Field(..., description="kling-image-o1")
+ resolution: str = Field(..., description="'1k' or '2k'")
+ aspect_ratio: str | None = Field(...)
+ prompt: str = Field(...)
+ mode: str = Field("pro")
+ n: int | None = Field(1, le=9)
+ image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
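The request classes in this new module are ordinary Pydantic models, so building a request body for the Omni endpoints is just a matter of instantiating one; in the nodes below the instance is handed to `sync_op`, which serializes it. A minimal sketch with example field values:

```python
req = OmniProText2VideoRequest(
    model_name="kling-video-o1",
    aspect_ratio="16:9",
    duration="5",
    prompt="a red fox running through fresh snow",
)
payload = req.model_dump(exclude_none=True)  # roughly what ends up at /proxy/kling/v1/videos/omni-video
```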
diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py
index 23a7f55f1..6c840dc47 100644
--- a/comfy_api_nodes/nodes_kling.py
+++ b/comfy_api_nodes/nodes_kling.py
@@ -4,13 +4,14 @@ For source of truth on the allowed permutations of request fields, please refere
- [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap)
"""
-import math
import logging
-
-from typing_extensions import override
+import math
+import re
import torch
+from typing_extensions import override
+from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis import (
KlingCameraControl,
KlingCameraConfig,
@@ -48,23 +49,33 @@ from comfy_api_nodes.apis import (
KlingCharacterEffectModelName,
KlingSingleImageEffectModelName,
)
+from comfy_api_nodes.apis.kling_api import (
+ OmniImageParamImage,
+ OmniParamImage,
+ OmniParamVideo,
+ OmniProFirstLastFrameRequest,
+ OmniProImageRequest,
+ OmniProReferences2VideoRequest,
+ OmniProText2VideoRequest,
+ OmniTaskStatusResponse,
+)
from comfy_api_nodes.util import (
- validate_image_dimensions,
+ ApiEndpoint,
+ download_url_to_image_tensor,
+ download_url_to_video_output,
+ get_number_of_images,
+ poll_op,
+ sync_op,
+ tensor_to_base64_string,
+ upload_audio_to_comfyapi,
+ upload_images_to_comfyapi,
+ upload_video_to_comfyapi,
validate_image_aspect_ratio,
+ validate_image_dimensions,
+ validate_string,
validate_video_dimensions,
validate_video_duration,
- tensor_to_base64_string,
- validate_string,
- upload_audio_to_comfyapi,
- download_url_to_image_tensor,
- upload_video_to_comfyapi,
- download_url_to_video_output,
- sync_op,
- ApiEndpoint,
- poll_op,
)
-from comfy_api.input_impl import VideoFromFile
-from comfy_api.latest import ComfyExtension, IO, Input
KLING_API_VERSION = "v1"
PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video"
@@ -202,6 +213,50 @@ VOICES_CONFIG = {
}
+def normalize_omni_prompt_references(prompt: str) -> str:
+ """
+ Rewrites Kling Omni-style placeholders used in the app, like:
+
+ @image, @image1, @image2, ... @imageN
+ @video, @video1, @video2, ... @videoN
+
+ into the API-compatible form:
+
+ <<<image1>>>, <<<image2>>>, ...
+ <<<video1>>>, <<<video2>>>, ...
+
+ This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app.
+ """
+ if not prompt:
+ return prompt
+
+ def _image_repl(match):
+ return f"<<<image{match.group('num')}>>>"
+
+ def _video_repl(match):
+ return f"<<<video{match.group('num')}>>>"
+
+ # (?<!\w) avoids matching things like user@image1; (?!\w) avoids @imageFoo
+ prompt = re.sub(r"(?<!\w)@image(?P<num>\d*)(?!\w)", _image_repl, prompt)
+ return re.sub(r"(?<!\w)@video(?P<num>\d*)(?!\w)", _video_repl, prompt)
+
+
+async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput:
+ if response.code:
+ raise RuntimeError(
+ f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ )
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
+ response_model=OmniTaskStatusResponse,
+ status_extractor=lambda r: (r.data.task_status if r.data else None),
+ max_poll_attempts=160,
+ )
+ return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
def is_valid_camera_control_configs(configs: list[float]) -> bool:
"""Verifies that at least one camera control configuration is non-zero."""
return any(not math.isclose(value, 0.0) for value in configs)
@@ -449,7 +504,7 @@ async def execute_video_effect(
image_1: torch.Tensor,
image_2: torch.Tensor | None = None,
model_mode: KlingVideoGenMode | None = None,
-) -> tuple[VideoFromFile, str, str]:
+) -> tuple[InputImpl.VideoFromFile, str, str]:
if dual_character:
request_input_field = KlingDualCharacterEffectInput(
model_name=model_name,
@@ -736,6 +791,474 @@ class KlingTextToVideoNode(IO.ComfyNode):
)
+class OmniProTextToVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProTextToVideoNode",
+ display_name="Kling Omni Text to Video (Pro)",
+ category="api node/video/Kling",
+ description="Use text prompts to generate videos with the latest Kling model.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+ IO.Combo.Input("duration", options=[5, 10]),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ aspect_ratio: str,
+ duration: int,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, min_length=1, max_length=2500)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=OmniTaskStatusResponse,
+ data=OmniProText2VideoRequest(
+ model_name=model_name,
+ prompt=prompt,
+ aspect_ratio=aspect_ratio,
+ duration=str(duration),
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
+class OmniProFirstLastFrameNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProFirstLastFrameNode",
+ display_name="Kling Omni First-Last-Frame to Video (Pro)",
+ category="api node/video/Kling",
+ description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("duration", options=["5", "10"]),
+ IO.Image.Input("first_frame"),
+ IO.Image.Input(
+ "end_frame",
+ optional=True,
+ tooltip="An optional end frame for the video. "
+ "This cannot be used simultaneously with 'reference_images'.",
+ ),
+ IO.Image.Input(
+ "reference_images",
+ optional=True,
+ tooltip="Up to 6 additional reference images.",
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ duration: int,
+ first_frame: Input.Image,
+ end_frame: Input.Image | None = None,
+ reference_images: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
+ validate_string(prompt, min_length=1, max_length=2500)
+ if end_frame is not None and reference_images is not None:
+ raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
+ validate_image_dimensions(first_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
+ image_list: list[OmniParamImage] = [
+ OmniParamImage(
+ image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0],
+ type="first_frame",
+ )
+ ]
+ if end_frame is not None:
+ validate_image_dimensions(end_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
+ image_list.append(
+ OmniParamImage(
+ image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0],
+ type="end_frame",
+ )
+ )
+ if reference_images is not None:
+ if get_number_of_images(reference_images) > 6:
+ raise ValueError("The maximum number of reference images allowed is 6.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"):
+ image_list.append(OmniParamImage(image_url=i))
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=OmniTaskStatusResponse,
+ data=OmniProFirstLastFrameRequest(
+ model_name=model_name,
+ prompt=prompt,
+ duration=str(duration),
+ image_list=image_list,
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
+class OmniProImageToVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProImageToVideoNode",
+ display_name="Kling Omni Image to Video (Pro)",
+ category="api node/video/Kling",
+ description="Use up to 7 reference images to generate a video with the latest Kling model.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+ IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+ IO.Image.Input(
+ "reference_images",
+ tooltip="Up to 7 reference images.",
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ aspect_ratio: str,
+ duration: int,
+ reference_images: Input.Image,
+ ) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
+ validate_string(prompt, min_length=1, max_length=2500)
+ if get_number_of_images(reference_images) > 7:
+ raise ValueError("The maximum number of reference images is 7.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ image_list: list[OmniParamImage] = []
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+ image_list.append(OmniParamImage(image_url=i))
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=OmniTaskStatusResponse,
+ data=OmniProReferences2VideoRequest(
+ model_name=model_name,
+ prompt=prompt,
+ aspect_ratio=aspect_ratio,
+ duration=str(duration),
+ image_list=image_list,
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
+class OmniProVideoToVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProVideoToVideoNode",
+ display_name="Kling Omni Video to Video (Pro)",
+ category="api node/video/Kling",
+ description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
+ IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+ IO.Video.Input("reference_video", tooltip="Video to use as a reference."),
+ IO.Boolean.Input("keep_original_sound", default=True),
+ IO.Image.Input(
+ "reference_images",
+ tooltip="Up to 4 additional reference images.",
+ optional=True,
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ aspect_ratio: str,
+ duration: int,
+ reference_video: Input.Video,
+ keep_original_sound: bool,
+ reference_images: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
+ validate_string(prompt, min_length=1, max_length=2500)
+ validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
+ validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
+ image_list: list[OmniParamImage] = []
+ if reference_images is not None:
+ if get_number_of_images(reference_images) > 4:
+ raise ValueError("The maximum number of reference images allowed with a video input is 4.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+ image_list.append(OmniParamImage(image_url=i))
+ video_list = [
+ OmniParamVideo(
+ video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"),
+ refer_type="feature",
+ keep_original_sound="yes" if keep_original_sound else "no",
+ )
+ ]
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=OmniTaskStatusResponse,
+ data=OmniProReferences2VideoRequest(
+ model_name=model_name,
+ prompt=prompt,
+ aspect_ratio=aspect_ratio,
+ duration=str(duration),
+ image_list=image_list if image_list else None,
+ video_list=video_list,
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
+class OmniProEditVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProEditVideoNode",
+ display_name="Kling Omni Edit Video (Pro)",
+ category="api node/video/Kling",
+ description="Edit an existing video with the latest model from Kling.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the video content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Video.Input("video", tooltip="Video for editing. The output video length will be the same."),
+ IO.Boolean.Input("keep_original_sound", default=True),
+ IO.Image.Input(
+ "reference_images",
+ tooltip="Up to 4 additional reference images.",
+ optional=True,
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ video: Input.Video,
+ keep_original_sound: bool,
+ reference_images: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
+ validate_string(prompt, min_length=1, max_length=2500)
+ validate_video_duration(video, min_duration=3.0, max_duration=10.05)
+ validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
+ image_list: list[OmniParamImage] = []
+ if reference_images is not None:
+ if get_number_of_images(reference_images) > 4:
+ raise ValueError("The maximum number of reference images allowed with a video input is 4.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+ image_list.append(OmniParamImage(image_url=i))
+ video_list = [
+ OmniParamVideo(
+ video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"),
+ refer_type="base",
+ keep_original_sound="yes" if keep_original_sound else "no",
+ )
+ ]
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
+ response_model=OmniTaskStatusResponse,
+ data=OmniProReferences2VideoRequest(
+ model_name=model_name,
+ prompt=prompt,
+ aspect_ratio=None,
+ duration=None,
+ image_list=image_list if image_list else None,
+ video_list=video_list,
+ ),
+ )
+ return await finish_omni_video_task(cls, response)
+
+
+class OmniProImageNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingOmniProImageNode",
+ display_name="Kling Omni Image (Pro)",
+ category="api node/image/Kling",
+ description="Create or edit images with the latest model from Kling.",
+ inputs=[
+ IO.Combo.Input("model_name", options=["kling-image-o1"]),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="A text prompt describing the image content. "
+ "This can include both positive and negative descriptions.",
+ ),
+ IO.Combo.Input("resolution", options=["1K", "2K"]),
+ IO.Combo.Input(
+ "aspect_ratio",
+ options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
+ ),
+ IO.Image.Input(
+ "reference_images",
+ tooltip="Up to 10 additional reference images.",
+ optional=True,
+ ),
+ ],
+ outputs=[
+ IO.Image.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_name: str,
+ prompt: str,
+ resolution: str,
+ aspect_ratio: str,
+ reference_images: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ prompt = normalize_omni_prompt_references(prompt)
+ validate_string(prompt, min_length=1, max_length=2500)
+ image_list: list[OmniImageParamImage] = []
+ if reference_images is not None:
+ if get_number_of_images(reference_images) > 10:
+ raise ValueError("The maximum number of reference images is 10.")
+ for i in reference_images:
+ validate_image_dimensions(i, min_width=300, min_height=300)
+ validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
+ for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
+ image_list.append(OmniImageParamImage(image=i))
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
+ response_model=OmniTaskStatusResponse,
+ data=OmniProImageRequest(
+ model_name=model_name,
+ prompt=prompt,
+ resolution=resolution.lower(),
+ aspect_ratio=aspect_ratio,
+ image_list=image_list if image_list else None,
+ ),
+ )
+ if response.code:
+ raise RuntimeError(
+ f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ )
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
+ response_model=OmniTaskStatusResponse,
+ status_extractor=lambda r: (r.data.task_status if r.data else None),
+ )
+ return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
+
+
class KlingCameraControlT2VNode(IO.ComfyNode):
"""
Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
@@ -1162,7 +1685,10 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
category="api node/video/Kling",
description="Achieve different special effects when generating a video based on the effect_scene.",
inputs=[
- IO.Image.Input("image", tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"),
+ IO.Image.Input(
+ "image",
+ tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1",
+ ),
IO.Combo.Input(
"effect_scene",
options=[i.value for i in KlingSingleImageEffectsScene],
@@ -1525,6 +2051,12 @@ class KlingExtension(ComfyExtension):
KlingImageGenerationNode,
KlingSingleImageVideoEffectNode,
KlingDualCharacterVideoEffectNode,
+ OmniProTextToVideoNode,
+ OmniProFirstLastFrameNode,
+ OmniProImageToVideoNode,
+ OmniProVideoToVideoNode,
+ OmniProEditVideoNode,
+ # OmniProImageNode, # need support from backend
]
diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py
index 80292fb3c..4cc22abfb 100644
--- a/comfy_api_nodes/util/__init__.py
+++ b/comfy_api_nodes/util/__init__.py
@@ -47,6 +47,7 @@ from .validation_utils import (
validate_string,
validate_video_dimensions,
validate_video_duration,
+ validate_video_frame_count,
)
__all__ = [
@@ -94,6 +95,7 @@ __all__ = [
"validate_string",
"validate_video_dimensions",
"validate_video_duration",
+ "validate_video_frame_count",
# Misc functions
"get_fs_object_size",
]
diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py
index 328fe5227..491e6b6a8 100644
--- a/comfy_api_nodes/util/_helpers.py
+++ b/comfy_api_nodes/util/_helpers.py
@@ -2,8 +2,8 @@ import asyncio
import contextlib
import os
import time
+from collections.abc import Callable
from io import BytesIO
-from typing import Callable, Optional, Union
from comfy.cli_args import args
from comfy.model_management import processing_interrupted
@@ -35,12 +35,12 @@ def default_base_url() -> str:
async def sleep_with_interrupt(
seconds: float,
- node_cls: Optional[type[IO.ComfyNode]],
- label: Optional[str] = None,
- start_ts: Optional[float] = None,
- estimated_total: Optional[int] = None,
+ node_cls: type[IO.ComfyNode] | None,
+ label: str | None = None,
+ start_ts: float | None = None,
+ estimated_total: int | None = None,
*,
- display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
+ display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None,
):
"""
Sleep in 1s slices while:
@@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str:
return mime_type.split("/")[-1].lower()
-def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
+def get_fs_object_size(path_or_object: str | BytesIO) -> int:
if isinstance(path_or_object, str):
return os.path.getsize(path_or_object)
return len(path_or_object.getvalue())
diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py
index bf01d7d36..bf37cba5f 100644
--- a/comfy_api_nodes/util/client.py
+++ b/comfy_api_nodes/util/client.py
@@ -4,10 +4,11 @@ import json
import logging
import time
import uuid
+from collections.abc import Callable, Iterable
from dataclasses import dataclass
from enum import Enum
from io import BytesIO
-from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
+from typing import Any, Literal, TypeVar
from urllib.parse import urljoin, urlparse
import aiohttp
@@ -37,8 +38,8 @@ class ApiEndpoint:
path: str,
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
*,
- query_params: Optional[dict[str, Any]] = None,
- headers: Optional[dict[str, str]] = None,
+ query_params: dict[str, Any] | None = None,
+ headers: dict[str, str] | None = None,
):
self.path = path
self.method = method
@@ -52,18 +53,18 @@ class _RequestConfig:
endpoint: ApiEndpoint
timeout: float
content_type: str
- data: Optional[dict[str, Any]]
- files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
- multipart_parser: Optional[Callable]
+ data: dict[str, Any] | None
+ files: dict[str, Any] | list[tuple[str, Any]] | None
+ multipart_parser: Callable | None
max_retries: int
retry_delay: float
retry_backoff: float
wait_label: str = "Waiting"
monitor_progress: bool = True
- estimated_total: Optional[int] = None
- final_label_on_success: Optional[str] = "Completed"
- progress_origin_ts: Optional[float] = None
- price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None
+ estimated_total: int | None = None
+ final_label_on_success: str | None = "Completed"
+ progress_origin_ts: float | None = None
+ price_extractor: Callable[[dict[str, Any]], float | None] | None = None
@dataclass
@@ -71,10 +72,10 @@ class _PollUIState:
started: float
status_label: str = "Queued"
is_queued: bool = True
- price: Optional[float] = None
- estimated_duration: Optional[int] = None
+ price: float | None = None
+ estimated_duration: int | None = None
base_processing_elapsed: float = 0.0 # sum of completed active intervals
- active_since: Optional[float] = None # start time of current active interval (None if queued)
+ active_since: float | None = None # start time of current active interval (None if queued)
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
@@ -87,20 +88,20 @@ async def sync_op(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
- response_model: Type[M],
- price_extractor: Optional[Callable[[M], Optional[float]]] = None,
- data: Optional[BaseModel] = None,
- files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
+ response_model: type[M],
+ price_extractor: Callable[[M | Any], float | None] | None = None,
+ data: BaseModel | None = None,
+ files: dict[str, Any] | list[tuple[str, Any]] | None = None,
content_type: str = "application/json",
timeout: float = 3600.0,
- multipart_parser: Optional[Callable] = None,
+ multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
- estimated_duration: Optional[int] = None,
- final_label_on_success: Optional[str] = "Completed",
- progress_origin_ts: Optional[float] = None,
+ estimated_duration: int | None = None,
+ final_label_on_success: str | None = "Completed",
+ progress_origin_ts: float | None = None,
monitor_progress: bool = True,
) -> M:
raw = await sync_op_raw(
@@ -131,22 +132,22 @@ async def poll_op(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
- response_model: Type[M],
- status_extractor: Callable[[M], Optional[Union[str, int]]],
- progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
- price_extractor: Optional[Callable[[M], Optional[float]]] = None,
- completed_statuses: Optional[list[Union[str, int]]] = None,
- failed_statuses: Optional[list[Union[str, int]]] = None,
- queued_statuses: Optional[list[Union[str, int]]] = None,
- data: Optional[BaseModel] = None,
+ response_model: type[M],
+ status_extractor: Callable[[M | Any], str | int | None],
+ progress_extractor: Callable[[M | Any], int | None] | None = None,
+ price_extractor: Callable[[M | Any], float | None] | None = None,
+ completed_statuses: list[str | int] | None = None,
+ failed_statuses: list[str | int] | None = None,
+ queued_statuses: list[str | int] | None = None,
+ data: BaseModel | None = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
- estimated_duration: Optional[int] = None,
- cancel_endpoint: Optional[ApiEndpoint] = None,
+ estimated_duration: int | None = None,
+ cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
) -> M:
raw = await poll_op_raw(
@@ -178,22 +179,22 @@ async def sync_op_raw(
cls: type[IO.ComfyNode],
endpoint: ApiEndpoint,
*,
- price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
- data: Optional[Union[dict[str, Any], BaseModel]] = None,
- files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
+ price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
+ data: dict[str, Any] | BaseModel | None = None,
+ files: dict[str, Any] | list[tuple[str, Any]] | None = None,
content_type: str = "application/json",
timeout: float = 3600.0,
- multipart_parser: Optional[Callable] = None,
+ multipart_parser: Callable | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
wait_label: str = "Waiting for server",
- estimated_duration: Optional[int] = None,
+ estimated_duration: int | None = None,
as_binary: bool = False,
- final_label_on_success: Optional[str] = "Completed",
- progress_origin_ts: Optional[float] = None,
+ final_label_on_success: str | None = "Completed",
+ progress_origin_ts: float | None = None,
monitor_progress: bool = True,
-) -> Union[dict[str, Any], bytes]:
+) -> dict[str, Any] | bytes:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON).
@@ -229,21 +230,21 @@ async def poll_op_raw(
cls: type[IO.ComfyNode],
poll_endpoint: ApiEndpoint,
*,
- status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
- progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
- price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
- completed_statuses: Optional[list[Union[str, int]]] = None,
- failed_statuses: Optional[list[Union[str, int]]] = None,
- queued_statuses: Optional[list[Union[str, int]]] = None,
- data: Optional[Union[dict[str, Any], BaseModel]] = None,
+ status_extractor: Callable[[dict[str, Any]], str | int | None],
+ progress_extractor: Callable[[dict[str, Any]], int | None] | None = None,
+ price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
+ completed_statuses: list[str | int] | None = None,
+ failed_statuses: list[str | int] | None = None,
+ queued_statuses: list[str | int] | None = None,
+ data: dict[str, Any] | BaseModel | None = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120,
timeout_per_poll: float = 120.0,
max_retries_per_poll: int = 3,
retry_delay_per_poll: float = 1.0,
retry_backoff_per_poll: float = 2.0,
- estimated_duration: Optional[int] = None,
- cancel_endpoint: Optional[ApiEndpoint] = None,
+ estimated_duration: int | None = None,
+ cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
) -> dict[str, Any]:
"""
@@ -261,7 +262,7 @@ async def poll_op_raw(
consumed_attempts = 0 # counts only non-queued polls
progress_bar = utils.ProgressBar(100) if progress_extractor else None
- last_progress: Optional[int] = None
+ last_progress: int | None = None
state = _PollUIState(started=started, estimated_duration=estimated_duration)
stop_ticker = asyncio.Event()
@@ -420,10 +421,10 @@ async def poll_op_raw(
def _display_text(
node_cls: type[IO.ComfyNode],
- text: Optional[str],
+ text: str | None,
*,
- status: Optional[Union[str, int]] = None,
- price: Optional[float] = None,
+ status: str | int | None = None,
+ price: float | None = None,
) -> None:
display_lines: list[str] = []
if status:
@@ -440,13 +441,13 @@ def _display_text(
def _display_time_progress(
node_cls: type[IO.ComfyNode],
- status: Optional[Union[str, int]],
+ status: str | int | None,
elapsed_seconds: int,
- estimated_total: Optional[int] = None,
+ estimated_total: int | None = None,
*,
- price: Optional[float] = None,
- is_queued: Optional[bool] = None,
- processing_elapsed_seconds: Optional[int] = None,
+ price: float | None = None,
+ is_queued: bool | None = None,
+ processing_elapsed_seconds: int | None = None,
) -> None:
if estimated_total is not None and estimated_total > 0 and is_queued is False:
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
@@ -488,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
raise ValueError("files tuple must be (filename, file[, content_type])")
-def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
+def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
params = dict(endpoint_params or {})
if method.upper() == "GET" and data:
for k, v in data.items():
@@ -534,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str:
def _snapshot_request_body_for_logging(
content_type: str,
method: str,
- data: Optional[dict[str, Any]],
- files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
-) -> Optional[Union[dict[str, Any], str]]:
+ data: dict[str, Any] | None,
+ files: dict[str, Any] | list[tuple[str, Any]] | None,
+) -> dict[str, Any] | str | None:
if method.upper() == "GET":
return None
if content_type == "multipart/form-data":
@@ -586,13 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
attempt = 0
delay = cfg.retry_delay
operation_succeeded: bool = False
- final_elapsed_seconds: Optional[int] = None
- extracted_price: Optional[float] = None
+ final_elapsed_seconds: int | None = None
+ extracted_price: float | None = None
while True:
attempt += 1
stop_event = asyncio.Event()
- monitor_task: Optional[asyncio.Task] = None
- sess: Optional[aiohttp.ClientSession] = None
+ monitor_task: asyncio.Task | None = None
+ sess: aiohttp.ClientSession | None = None
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
@@ -887,7 +888,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
)
-def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
+def _validate_or_raise(response_model: type[M], payload: Any) -> M:
try:
return response_model.model_validate(payload)
except Exception as e:
@@ -902,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
def _wrap_model_extractor(
- response_model: Type[M],
- extractor: Optional[Callable[[M], Any]],
-) -> Optional[Callable[[dict[str, Any]], Any]]:
+ response_model: type[M],
+ extractor: Callable[[M], Any] | None,
+) -> Callable[[dict[str, Any]], Any] | None:
"""Wrap a typed extractor so it can be used by the dict-based poller.
Validates the dict into `response_model` before invoking `extractor`.
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
@@ -929,10 +930,10 @@ def _wrap_model_extractor(
return _wrapped
-def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
+def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]:
if not values:
return set()
- out: set[Union[str, int]] = set()
+ out: set[str | int] = set()
for v in values:
nv = _normalize_status_value(v)
if nv is not None:
@@ -940,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Unio
return out
-def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
+def _normalize_status_value(val: str | int | None) -> str | int | None:
if isinstance(val, str):
return val.strip().lower()
return val
diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py
index 971dc57de..c57457580 100644
--- a/comfy_api_nodes/util/conversions.py
+++ b/comfy_api_nodes/util/conversions.py
@@ -4,7 +4,6 @@ import math
import mimetypes
import uuid
from io import BytesIO
-from typing import Optional
import av
import numpy as np
@@ -12,8 +11,7 @@ import torch
from PIL import Image
from comfy.utils import common_upscale
-from comfy_api.latest import Input, InputImpl
-from comfy_api.util import VideoCodec, VideoContainer
+from comfy_api.latest import Input, InputImpl, Types
from ._helpers import mimetype_to_extension
@@ -57,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
def tensor_to_bytesio(
image: torch.Tensor,
- name: Optional[str] = None,
+ name: str | None = None,
total_pixels: int = 2048 * 2048,
mime_type: str = "image/png",
) -> BytesIO:
@@ -177,8 +175,8 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co
def video_to_base64_string(
video: Input.Video,
- container_format: VideoContainer = None,
- codec: VideoCodec = None
+ container_format: Types.VideoContainer | None = None,
+ codec: Types.VideoCodec | None = None,
) -> str:
"""
Converts a video input to a base64 string.
@@ -189,12 +187,11 @@ def video_to_base64_string(
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = BytesIO()
-
- # Use provided format/codec if specified, otherwise use video's own if available
- format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
- codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
-
- video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
+ video.save_to(
+ video_bytes_io,
+ format=container_format or getattr(video, "container", Types.VideoContainer.MP4),
+ codec=codec or getattr(video, "codec", Types.VideoCodec.H264),
+ )
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py
index 14207dc68..3e0d0352d 100644
--- a/comfy_api_nodes/util/download_helpers.py
+++ b/comfy_api_nodes/util/download_helpers.py
@@ -3,15 +3,15 @@ import contextlib
import uuid
from io import BytesIO
from pathlib import Path
-from typing import IO, Optional, Union
+from typing import IO
from urllib.parse import urljoin, urlparse
import aiohttp
import torch
from aiohttp.client_exceptions import ClientError, ContentTypeError
-from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO as COMFY_IO
+from comfy_api.latest import InputImpl
from . import request_logger
from ._helpers import (
@@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
async def download_url_to_bytesio(
url: str,
- dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
+ dest: BytesIO | IO[bytes] | str | Path | None,
*,
- timeout: Optional[float] = None,
+ timeout: float | None = None,
max_retries: int = 5,
retry_delay: float = 1.0,
retry_backoff: float = 2.0,
@@ -71,10 +71,10 @@ async def download_url_to_bytesio(
is_path_sink = isinstance(dest, (str, Path))
fhandle = None
- session: Optional[aiohttp.ClientSession] = None
- stop_evt: Optional[asyncio.Event] = None
- monitor_task: Optional[asyncio.Task] = None
- req_task: Optional[asyncio.Task] = None
+ session: aiohttp.ClientSession | None = None
+ stop_evt: asyncio.Event | None = None
+ monitor_task: asyncio.Task | None = None
+ req_task: asyncio.Task | None = None
try:
with contextlib.suppress(Exception):
@@ -234,11 +234,11 @@ async def download_url_to_video_output(
timeout: float = None,
max_retries: int = 5,
cls: type[COMFY_IO.ComfyNode] = None,
-) -> VideoFromFile:
+) -> InputImpl.VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output."""
result = BytesIO()
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
- return VideoFromFile(result)
+ return InputImpl.VideoFromFile(result)
async def download_url_as_bytesio(
diff --git a/comfy_api_nodes/util/request_logger.py b/comfy_api_nodes/util/request_logger.py
index ac52e2eab..e0cb4428d 100644
--- a/comfy_api_nodes/util/request_logger.py
+++ b/comfy_api_nodes/util/request_logger.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
import datetime
import hashlib
import json
diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py
index b9019841f..b8d33f4d1 100644
--- a/comfy_api_nodes/util/upload_helpers.py
+++ b/comfy_api_nodes/util/upload_helpers.py
@@ -4,15 +4,13 @@ import logging
import time
import uuid
from io import BytesIO
-from typing import Optional
from urllib.parse import urlparse
import aiohttp
import torch
from pydantic import BaseModel, Field
-from comfy_api.latest import IO, Input
-from comfy_api.util import VideoCodec, VideoContainer
+from comfy_api.latest import IO, Input, Types
from . import request_logger
from ._helpers import is_processing_interrupted, sleep_with_interrupt
@@ -32,7 +30,7 @@ from .conversions import (
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
- content_type: Optional[str] = Field(
+ content_type: str | None = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
@@ -56,7 +54,7 @@ async def upload_images_to_comfyapi(
Uploads images to ComfyUI API and returns download URLs.
To upload multiple images, stack them in the batch dimension first.
"""
- # if batch, try to upload each file if max_images is greater than 0
+ # if batched, try to upload each file if max_images is greater than 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_len = image.shape[0] if is_batch else 1
@@ -100,9 +98,10 @@ async def upload_video_to_comfyapi(
cls: type[IO.ComfyNode],
video: Input.Video,
*,
- container: VideoContainer = VideoContainer.MP4,
- codec: VideoCodec = VideoCodec.H264,
- max_duration: Optional[int] = None,
+ container: Types.VideoContainer = Types.VideoContainer.MP4,
+ codec: Types.VideoCodec = Types.VideoCodec.H264,
+ max_duration: int | None = None,
+ wait_label: str | None = "Uploading",
) -> str:
"""
Uploads a single video to ComfyUI API and returns its download URL.
@@ -127,7 +126,7 @@ async def upload_video_to_comfyapi(
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)
- return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)
+ return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
async def upload_file_to_comfyapi(
@@ -219,7 +218,7 @@ async def upload_file(
return
monitor_task = asyncio.create_task(_monitor())
- sess: Optional[aiohttp.ClientSession] = None
+ sess: aiohttp.ClientSession | None = None
try:
try:
request_logger.log_request_response(
diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py
index ec7006aed..f01edea96 100644
--- a/comfy_api_nodes/util/validation_utils.py
+++ b/comfy_api_nodes/util/validation_utils.py
@@ -1,9 +1,7 @@
import logging
-from typing import Optional
import torch
-from comfy_api.input.video_types import VideoInput
from comfy_api.latest import Input
@@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
def validate_image_dimensions(
image: torch.Tensor,
- min_width: Optional[int] = None,
- max_width: Optional[int] = None,
- min_height: Optional[int] = None,
- max_height: Optional[int] = None,
+ min_width: int | None = None,
+ max_width: int | None = None,
+ min_height: int | None = None,
+ max_height: int | None = None,
):
height, width = get_image_dimensions(image)
@@ -37,8 +35,8 @@ def validate_image_dimensions(
def validate_image_aspect_ratio(
image: torch.Tensor,
- min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
- max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
+ min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
+ max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
*,
strict: bool = True, # True -> (min, max); False -> [min, max]
) -> float:
@@ -54,8 +52,8 @@ def validate_image_aspect_ratio(
def validate_images_aspect_ratio_closeness(
first_image: torch.Tensor,
second_image: torch.Tensor,
- min_rel: float, # e.g. 0.8
- max_rel: float, # e.g. 1.25
+ min_rel: float, # e.g. 0.8
+ max_rel: float, # e.g. 1.25
*,
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
@@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness(
def validate_aspect_ratio_string(
aspect_ratio: str,
- min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
- max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
+ min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
+ max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
*,
strict: bool = False, # True -> (min, max); False -> [min, max]
) -> float:
@@ -97,10 +95,10 @@ def validate_aspect_ratio_string(
def validate_video_dimensions(
video: Input.Video,
- min_width: Optional[int] = None,
- max_width: Optional[int] = None,
- min_height: Optional[int] = None,
- max_height: Optional[int] = None,
+ min_width: int | None = None,
+ max_width: int | None = None,
+ min_height: int | None = None,
+ max_height: int | None = None,
):
try:
width, height = video.get_dimensions()
@@ -120,8 +118,8 @@ def validate_video_dimensions(
def validate_video_duration(
video: Input.Video,
- min_duration: Optional[float] = None,
- max_duration: Optional[float] = None,
+ min_duration: float | None = None,
+ max_duration: float | None = None,
):
try:
duration = video.get_duration()
@@ -136,6 +134,23 @@ def validate_video_duration(
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
+def validate_video_frame_count(
+ video: Input.Video,
+ min_frame_count: int | None = None,
+ max_frame_count: int | None = None,
+):
+ try:
+ frame_count = video.get_frame_count()
+ except Exception as e:
+ logging.error("Error getting frame count of video: %s", e)
+ return
+
+ if min_frame_count is not None and min_frame_count > frame_count:
+ raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}")
+ if max_frame_count is not None and frame_count > max_frame_count:
+ raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}")
+
+
def get_number_of_images(images):
if isinstance(images, torch.Tensor):
return images.shape[0] if images.ndim >= 4 else 1
@@ -144,8 +159,8 @@ def get_number_of_images(images):
def validate_audio_duration(
audio: Input.Audio,
- min_duration: Optional[float] = None,
- max_duration: Optional[float] = None,
+ min_duration: float | None = None,
+ max_duration: float | None = None,
) -> None:
sr = int(audio["sample_rate"])
dur = int(audio["waveform"].shape[-1]) / sr
@@ -177,7 +192,7 @@ def validate_string(
)
-def validate_container_format_is_mp4(video: VideoInput) -> None:
+def validate_container_format_is_mp4(video: Input.Video) -> None:
"""Validates video container format is MP4."""
container_format = video.get_container_format()
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
@@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float:
def _assert_ratio_bounds(
ar: float,
*,
- min_ratio: Optional[tuple[float, float]] = None,
- max_ratio: Optional[tuple[float, float]] = None,
+ min_ratio: tuple[float, float] | None = None,
+ max_ratio: tuple[float, float] | None = None,
strict: bool = True,
) -> None:
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py
index cec105fc9..24c0b4ed7 100644
--- a/comfy_execution/validation.py
+++ b/comfy_execution/validation.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+from comfy_api.latest import IO
def validate_node_input(
@@ -23,6 +24,11 @@ def validate_node_input(
if not received_type != input_type:
return True
+ # If the received type or input_type is a MatchType, we can return True immediately;
+ # validation for this is handled by the frontend
+ if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
+ return True
+
# Not equal, and not strings
if not isinstance(received_type, str) or not isinstance(input_type, str):
return False
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index 2ed7e0b22..812301fb7 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -6,65 +6,80 @@ import torch
import comfy.model_management
import folder_paths
import os
-import io
-import json
-import random
import hashlib
import node_helpers
import logging
-from comfy.cli_args import args
-from comfy.comfy_types import FileLocator
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, IO, UI
-class EmptyLatentAudio:
- def __init__(self):
- self.device = comfy.model_management.intermediate_device()
+class EmptyLatentAudio(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="EmptyLatentAudio",
+ display_name="Empty Latent Audio",
+ category="latent/audio",
+ inputs=[
+ IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
+ IO.Int.Input(
+ "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
+ ),
+ ],
+ outputs=[IO.Latent.Output()],
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
- "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
- }}
- RETURN_TYPES = ("LATENT",)
- FUNCTION = "generate"
-
- CATEGORY = "latent/audio"
-
- def generate(self, seconds, batch_size):
+ def execute(cls, seconds, batch_size) -> IO.NodeOutput:
length = round((seconds * 44100 / 2048) / 2) * 2
- latent = torch.zeros([batch_size, 64, length], device=self.device)
- return ({"samples":latent, "type": "audio"}, )
+ latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
+ return IO.NodeOutput({"samples":latent, "type": "audio"})
-class ConditioningStableAudio:
+ generate = execute # TODO: remove
+
+
+class ConditioningStableAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {"positive": ("CONDITIONING", ),
- "negative": ("CONDITIONING", ),
- "seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
- "seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ConditioningStableAudio",
+ category="conditioning",
+ inputs=[
+ IO.Conditioning.Input("positive"),
+ IO.Conditioning.Input("negative"),
+ IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
+ IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
+ ],
+ outputs=[
+ IO.Conditioning.Output(display_name="positive"),
+ IO.Conditioning.Output(display_name="negative"),
+ ],
+ )
- RETURN_TYPES = ("CONDITIONING","CONDITIONING")
- RETURN_NAMES = ("positive", "negative")
-
- FUNCTION = "append"
-
- CATEGORY = "conditioning"
-
- def append(self, positive, negative, seconds_start, seconds_total):
+ @classmethod
+ def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput:
positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
- return (positive, negative)
+ return IO.NodeOutput(positive, negative)
-class VAEEncodeAudio:
+ append = execute # TODO: remove
+
+
+class VAEEncodeAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
- RETURN_TYPES = ("LATENT",)
- FUNCTION = "encode"
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="VAEEncodeAudio",
+ display_name="VAE Encode Audio",
+ category="latent/audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.Vae.Input("vae"),
+ ],
+ outputs=[IO.Latent.Output()],
+ )
- CATEGORY = "latent/audio"
-
- def encode(self, vae, audio):
+ @classmethod
+ def execute(cls, vae, audio) -> IO.NodeOutput:
sample_rate = audio["sample_rate"]
if 44100 != sample_rate:
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
@@ -72,213 +87,134 @@ class VAEEncodeAudio:
waveform = audio["waveform"]
t = vae.encode(waveform.movedim(1, -1))
- return ({"samples":t}, )
+ return IO.NodeOutput({"samples":t})
-class VAEDecodeAudio:
+ encode = execute # TODO: remove
+
+
+class VAEDecodeAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
- RETURN_TYPES = ("AUDIO",)
- FUNCTION = "decode"
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="VAEDecodeAudio",
+ display_name="VAE Decode Audio",
+ category="latent/audio",
+ inputs=[
+ IO.Latent.Input("samples"),
+ IO.Vae.Input("vae"),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- CATEGORY = "latent/audio"
-
- def decode(self, vae, samples):
+ @classmethod
+ def execute(cls, vae, samples) -> IO.NodeOutput:
audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
- return ({"waveform": audio, "sample_rate": 44100}, )
+ return IO.NodeOutput({"waveform": audio, "sample_rate": 44100})
+
+ decode = execute # TODO: remove
-def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
-
- filename_prefix += self.prefix_append
- full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
- results: list[FileLocator] = []
-
- # Prepare metadata dictionary
- metadata = {}
- if not args.disable_metadata:
- if prompt is not None:
- metadata["prompt"] = json.dumps(prompt)
- if extra_pnginfo is not None:
- for x in extra_pnginfo:
- metadata[x] = json.dumps(extra_pnginfo[x])
-
- # Opus supported sample rates
- OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
-
- for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
- filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
- file = f"{filename_with_batch_num}_{counter:05}_.{format}"
- output_path = os.path.join(full_output_folder, file)
-
- # Use original sample rate initially
- sample_rate = audio["sample_rate"]
-
- # Handle Opus sample rate requirements
- if format == "opus":
- if sample_rate > 48000:
- sample_rate = 48000
- elif sample_rate not in OPUS_RATES:
- # Find the next highest supported rate
- for rate in sorted(OPUS_RATES):
- if rate > sample_rate:
- sample_rate = rate
- break
- if sample_rate not in OPUS_RATES: # Fallback if still not supported
- sample_rate = 48000
-
- # Resample if necessary
- if sample_rate != audio["sample_rate"]:
- waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
-
- # Create output with specified format
- output_buffer = io.BytesIO()
- output_container = av.open(output_buffer, mode='w', format=format)
-
- # Set metadata on the container
- for key, value in metadata.items():
- output_container.metadata[key] = value
-
- layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
- # Set up the output stream with appropriate properties
- if format == "opus":
- out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
- if quality == "64k":
- out_stream.bit_rate = 64000
- elif quality == "96k":
- out_stream.bit_rate = 96000
- elif quality == "128k":
- out_stream.bit_rate = 128000
- elif quality == "192k":
- out_stream.bit_rate = 192000
- elif quality == "320k":
- out_stream.bit_rate = 320000
- elif format == "mp3":
- out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
- if quality == "V0":
- #TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
- out_stream.codec_context.qscale = 1
- elif quality == "128k":
- out_stream.bit_rate = 128000
- elif quality == "320k":
- out_stream.bit_rate = 320000
- else: #format == "flac":
- out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
-
- frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
- frame.sample_rate = sample_rate
- frame.pts = 0
- output_container.mux(out_stream.encode(frame))
-
- # Flush encoder
- output_container.mux(out_stream.encode(None))
-
- # Close containers
- output_container.close()
-
- # Write the output to file
- output_buffer.seek(0)
- with open(output_path, 'wb') as f:
- f.write(output_buffer.getbuffer())
-
- results.append({
- "filename": file,
- "subfolder": subfolder,
- "type": self.type
- })
- counter += 1
-
- return { "ui": { "audio": results } }
-
-class SaveAudio:
- def __init__(self):
- self.output_dir = folder_paths.get_output_directory()
- self.type = "output"
- self.prefix_append = ""
+class SaveAudio(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="SaveAudio",
+ display_name="Save Audio (FLAC)",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.String.Input("filename_prefix", default="audio/ComfyUI"),
+ ],
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+ is_output_node=True,
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "audio": ("AUDIO", ),
- "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
- },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
+ def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
+ return IO.NodeOutput(
+ ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
+ )
- RETURN_TYPES = ()
- FUNCTION = "save_flac"
+ save_flac = execute # TODO: remove
- OUTPUT_NODE = True
- CATEGORY = "audio"
-
- def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
- return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
-
-class SaveAudioMP3:
- def __init__(self):
- self.output_dir = folder_paths.get_output_directory()
- self.type = "output"
- self.prefix_append = ""
+class SaveAudioMP3(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="SaveAudioMP3",
+ display_name="Save Audio (MP3)",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.String.Input("filename_prefix", default="audio/ComfyUI"),
+ IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
+ ],
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+ is_output_node=True,
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "audio": ("AUDIO", ),
- "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
- "quality": (["V0", "128k", "320k"], {"default": "V0"}),
- },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
+ def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
+ return IO.NodeOutput(
+ ui=UI.AudioSaveHelper.get_save_audio_ui(
+ audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
+ )
+ )
- RETURN_TYPES = ()
- FUNCTION = "save_mp3"
+ save_mp3 = execute # TODO: remove
- OUTPUT_NODE = True
- CATEGORY = "audio"
-
- def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
- return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
-
-class SaveAudioOpus:
- def __init__(self):
- self.output_dir = folder_paths.get_output_directory()
- self.type = "output"
- self.prefix_append = ""
+class SaveAudioOpus(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="SaveAudioOpus",
+ display_name="Save Audio (Opus)",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.String.Input("filename_prefix", default="audio/ComfyUI"),
+ IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
+ ],
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+ is_output_node=True,
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required": { "audio": ("AUDIO", ),
- "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
- "quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
- },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
+    def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="128k") -> IO.NodeOutput:
+ return IO.NodeOutput(
+ ui=UI.AudioSaveHelper.get_save_audio_ui(
+ audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
+ )
+ )
- RETURN_TYPES = ()
- FUNCTION = "save_opus"
+ save_opus = execute # TODO: remove
- OUTPUT_NODE = True
- CATEGORY = "audio"
-
- def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
- return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
-
-class PreviewAudio(SaveAudio):
- def __init__(self):
- self.output_dir = folder_paths.get_temp_directory()
- self.type = "temp"
- self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
+class PreviewAudio(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="PreviewAudio",
+ display_name="Preview Audio",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ ],
+ hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
+ is_output_node=True,
+ )
@classmethod
- def INPUT_TYPES(s):
- return {"required":
- {"audio": ("AUDIO", ), },
- "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
- }
+ def execute(cls, audio) -> IO.NodeOutput:
+ return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
+
+ save_flac = execute # TODO: remove
+
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
"""Convert audio to float 32 bits PCM format."""
@@ -316,26 +252,30 @@ def load(filepath: str) -> tuple[torch.Tensor, int]:
wav = f32_pcm(wav)
return wav, sr
-class LoadAudio:
+class LoadAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
+ def define_schema(cls):
input_dir = folder_paths.get_input_directory()
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
- return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
+ return IO.Schema(
+ node_id="LoadAudio",
+ display_name="Load Audio",
+ category="audio",
+ inputs=[
+ IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- CATEGORY = "audio"
-
- RETURN_TYPES = ("AUDIO", )
- FUNCTION = "load"
-
- def load(self, audio):
+ @classmethod
+ def execute(cls, audio) -> IO.NodeOutput:
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
- return (audio, )
+ return IO.NodeOutput(audio)
@classmethod
- def IS_CHANGED(s, audio):
+ def fingerprint_inputs(cls, audio):
image_path = folder_paths.get_annotated_filepath(audio)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
@@ -343,46 +283,69 @@ class LoadAudio:
return m.digest().hex()
@classmethod
- def VALIDATE_INPUTS(s, audio):
+ def validate_inputs(cls, audio):
if not folder_paths.exists_annotated_filepath(audio):
return "Invalid audio file: {}".format(audio)
return True
-class RecordAudio:
+ load = execute # TODO: remove
+
+
+class RecordAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {"audio": ("AUDIO_RECORD", {})}}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="RecordAudio",
+ display_name="Record Audio",
+ category="audio",
+ inputs=[
+ IO.Custom("AUDIO_RECORD").Input("audio"),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- CATEGORY = "audio"
-
- RETURN_TYPES = ("AUDIO", )
- FUNCTION = "load"
-
- def load(self, audio):
+ @classmethod
+ def execute(cls, audio) -> IO.NodeOutput:
audio_path = folder_paths.get_annotated_filepath(audio)
waveform, sample_rate = load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
- return (audio, )
+ return IO.NodeOutput(audio)
+
+ load = execute # TODO: remove
-class TrimAudioDuration:
+class TrimAudioDuration(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "audio": ("AUDIO",),
- "start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
- "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
- },
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="TrimAudioDuration",
+ display_name="Trim Audio Duration",
+ description="Trim audio tensor into chosen time range.",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.Float.Input(
+ "start_index",
+ default=0.0,
+ min=-0xffffffffffffffff,
+ max=0xffffffffffffffff,
+ step=0.01,
+ tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).",
+ ),
+ IO.Float.Input(
+ "duration",
+ default=60.0,
+ min=0.0,
+ step=0.01,
+ tooltip="Duration in seconds",
+ ),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- FUNCTION = "trim"
- RETURN_TYPES = ("AUDIO",)
- CATEGORY = "audio"
- DESCRIPTION = "Trim audio tensor into chosen time range."
-
- def trim(self, audio, start_index, duration):
+ @classmethod
+ def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
audio_length = waveform.shape[-1]
@@ -399,23 +362,30 @@ class TrimAudioDuration:
if start_frame >= end_frame:
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
- return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
+ return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
+
+ trim = execute # TODO: remove
-class SplitAudioChannels:
+class SplitAudioChannels(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "audio": ("AUDIO",),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="SplitAudioChannels",
+ display_name="Split Audio Channels",
+ description="Separates the audio into left and right channels.",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ ],
+ outputs=[
+ IO.Audio.Output(display_name="left"),
+ IO.Audio.Output(display_name="right"),
+ ],
+ )
- RETURN_TYPES = ("AUDIO", "AUDIO")
- RETURN_NAMES = ("left", "right")
- FUNCTION = "separate"
- CATEGORY = "audio"
- DESCRIPTION = "Separates the audio into left and right channels."
-
- def separate(self, audio):
+ @classmethod
+ def execute(cls, audio) -> IO.NodeOutput:
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
@@ -425,7 +395,9 @@ class SplitAudioChannels:
left_channel = waveform[..., 0:1, :]
right_channel = waveform[..., 1:2, :]
- return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
+ return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
+
+ separate = execute # TODO: remove
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
@@ -443,21 +415,29 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_
return waveform_1, waveform_2, output_sample_rate
-class AudioConcat:
+class AudioConcat(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "audio1": ("AUDIO",),
- "audio2": ("AUDIO",),
- "direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="AudioConcat",
+ display_name="Audio Concat",
+ description="Concatenates the audio1 to audio2 in the specified direction.",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio1"),
+ IO.Audio.Input("audio2"),
+ IO.Combo.Input(
+ "direction",
+ options=['after', 'before'],
+ default="after",
+ tooltip="Whether to append audio2 after or before audio1.",
+ )
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- RETURN_TYPES = ("AUDIO",)
- FUNCTION = "concat"
- CATEGORY = "audio"
- DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction."
-
- def concat(self, audio1, audio2, direction):
+ @classmethod
+ def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@@ -477,26 +457,33 @@ class AudioConcat:
elif direction == 'before':
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
- return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
+ return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate})
+
+ concat = execute # TODO: remove
-class AudioMerge:
+class AudioMerge(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(cls):
- return {
- "required": {
- "audio1": ("AUDIO",),
- "audio2": ("AUDIO",),
- "merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
- },
- }
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="AudioMerge",
+ display_name="Audio Merge",
+ description="Combine two audio tracks by overlaying their waveforms.",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio1"),
+ IO.Audio.Input("audio2"),
+ IO.Combo.Input(
+ "merge_method",
+ options=["add", "mean", "subtract", "multiply"],
+ tooltip="The method used to combine the audio waveforms.",
+ )
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- FUNCTION = "merge"
- RETURN_TYPES = ("AUDIO",)
- CATEGORY = "audio"
- DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
-
- def merge(self, audio1, audio2, merge_method):
+ @classmethod
+ def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
@@ -530,85 +517,108 @@ class AudioMerge:
if max_val > 1.0:
waveform = waveform / max_val
- return ({"waveform": waveform, "sample_rate": output_sample_rate},)
+ return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate})
+
+ merge = execute # TODO: remove
-class AudioAdjustVolume:
+class AudioAdjustVolume(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "audio": ("AUDIO",),
- "volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="AudioAdjustVolume",
+ display_name="Audio Adjust Volume",
+ category="audio",
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.Int.Input(
+ "volume",
+ default=1,
+ min=-100,
+ max=100,
+ tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc",
+ )
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- RETURN_TYPES = ("AUDIO",)
- FUNCTION = "adjust_volume"
- CATEGORY = "audio"
-
- def adjust_volume(self, audio, volume):
+ @classmethod
+ def execute(cls, audio, volume) -> IO.NodeOutput:
if volume == 0:
- return (audio,)
+ return IO.NodeOutput(audio)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
gain = 10 ** (volume / 20)
waveform = waveform * gain
- return ({"waveform": waveform, "sample_rate": sample_rate},)
+ return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
+
+ adjust_volume = execute # TODO: remove
-class EmptyAudio:
+class EmptyAudio(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
- "sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
- "channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="EmptyAudio",
+ display_name="Empty Audio",
+ category="audio",
+ inputs=[
+ IO.Float.Input(
+ "duration",
+ default=60.0,
+ min=0.0,
+ max=0xffffffffffffffff,
+ step=0.01,
+ tooltip="Duration of the empty audio clip in seconds",
+ ),
+                IO.Int.Input(
+                    "sample_rate",
+                    default=44100,
+                    tooltip="Sample rate of the empty audio clip.",
+                ),
+                IO.Int.Input(
+                    "channels",
+                    default=2,
+                    min=1,
+                    max=2,
+                    tooltip="Number of audio channels (1 for mono, 2 for stereo).",
+                ),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
- RETURN_TYPES = ("AUDIO",)
- FUNCTION = "create_empty_audio"
- CATEGORY = "audio"
-
- def create_empty_audio(self, duration, sample_rate, channels):
+ @classmethod
+ def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput:
num_samples = int(round(duration * sample_rate))
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
- return ({"waveform": waveform, "sample_rate": sample_rate},)
+ return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
+
+ create_empty_audio = execute # TODO: remove
-NODE_CLASS_MAPPINGS = {
- "EmptyLatentAudio": EmptyLatentAudio,
- "VAEEncodeAudio": VAEEncodeAudio,
- "VAEDecodeAudio": VAEDecodeAudio,
- "SaveAudio": SaveAudio,
- "SaveAudioMP3": SaveAudioMP3,
- "SaveAudioOpus": SaveAudioOpus,
- "LoadAudio": LoadAudio,
- "PreviewAudio": PreviewAudio,
- "ConditioningStableAudio": ConditioningStableAudio,
- "RecordAudio": RecordAudio,
- "TrimAudioDuration": TrimAudioDuration,
- "SplitAudioChannels": SplitAudioChannels,
- "AudioConcat": AudioConcat,
- "AudioMerge": AudioMerge,
- "AudioAdjustVolume": AudioAdjustVolume,
- "EmptyAudio": EmptyAudio,
-}
+class AudioExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ EmptyLatentAudio,
+ VAEEncodeAudio,
+ VAEDecodeAudio,
+ SaveAudio,
+ SaveAudioMP3,
+ SaveAudioOpus,
+ LoadAudio,
+ PreviewAudio,
+ ConditioningStableAudio,
+ RecordAudio,
+ TrimAudioDuration,
+ SplitAudioChannels,
+ AudioConcat,
+ AudioMerge,
+ AudioAdjustVolume,
+ EmptyAudio,
+ ]
-NODE_DISPLAY_NAME_MAPPINGS = {
- "EmptyLatentAudio": "Empty Latent Audio",
- "VAEEncodeAudio": "VAE Encode Audio",
- "VAEDecodeAudio": "VAE Decode Audio",
- "PreviewAudio": "Preview Audio",
- "LoadAudio": "Load Audio",
- "SaveAudio": "Save Audio (FLAC)",
- "SaveAudioMP3": "Save Audio (MP3)",
- "SaveAudioOpus": "Save Audio (Opus)",
- "RecordAudio": "Record Audio",
- "TrimAudioDuration": "Trim Audio Duration",
- "SplitAudioChannels": "Split Audio Channels",
- "AudioConcat": "Audio Concat",
- "AudioMerge": "Audio Merge",
- "AudioAdjustVolume": "Audio Adjust Volume",
- "EmptyAudio": "Empty Audio",
-}
+async def comfy_entrypoint() -> AudioExtension:
+ return AudioExtension()
diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py
index 54c66ef68..545588ef8 100644
--- a/comfy_extras/nodes_load_3d.py
+++ b/comfy_extras/nodes_load_3d.py
@@ -2,22 +2,18 @@ import nodes
import folder_paths
import os
-from comfy.comfy_types import IO
-from comfy_api.input_impl import VideoFromFile
+from typing_extensions import override
+from comfy_api.latest import IO, ComfyExtension, InputImpl, UI
from pathlib import Path
-from PIL import Image
-import numpy as np
-
-import uuid
def normalize_path(path):
return path.replace('\\', '/')
-class Load3D():
+class Load3D(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
+ def define_schema(cls):
input_dir = os.path.join(folder_paths.get_input_directory(), "3d")
os.makedirs(input_dir, exist_ok=True)
@@ -30,23 +26,29 @@ class Load3D():
for file_path in input_path.rglob("*")
if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'}
]
+ return IO.Schema(
+ node_id="Load3D",
+ display_name="Load 3D & Animation",
+ category="3d",
+ is_experimental=True,
+ inputs=[
+ IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model),
+ IO.Load3D.Input("image"),
+ IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
+ IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
+ ],
+ outputs=[
+ IO.Image.Output(display_name="image"),
+ IO.Mask.Output(display_name="mask"),
+ IO.String.Output(display_name="mesh_path"),
+ IO.Image.Output(display_name="normal"),
+ IO.Load3DCamera.Output(display_name="camera_info"),
+ IO.Video.Output(display_name="recording_video"),
+ ],
+ )
- return {"required": {
- "model_file": (sorted(files), {"file_upload": True}),
- "image": ("LOAD_3D", {}),
- "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
- "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
- }}
-
- RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
- RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
-
- FUNCTION = "process"
- EXPERIMENTAL = True
-
- CATEGORY = "3d"
-
- def process(self, model_file, image, **kwargs):
+ @classmethod
+ def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
@@ -61,58 +63,47 @@ class Load3D():
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
- video = VideoFromFile(recording_video_path)
+ video = InputImpl.VideoFromFile(recording_video_path)
- return output_image, output_mask, model_file, normal_image, image['camera_info'], video
+ return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video)
-class Preview3D():
+ process = execute # TODO: remove
+
+
+class Preview3D(IO.ComfyNode):
@classmethod
- def INPUT_TYPES(s):
- return {"required": {
- "model_file": ("STRING", {"default": "", "multiline": False}),
- },
- "optional": {
- "camera_info": ("LOAD3D_CAMERA", {}),
- "bg_image": ("IMAGE", {})
- }}
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="Preview3D",
+ display_name="Preview 3D & Animation",
+ category="3d",
+ is_experimental=True,
+ is_output_node=True,
+ inputs=[
+ IO.String.Input("model_file", default="", multiline=False),
+ IO.Load3DCamera.Input("camera_info", optional=True),
+ IO.Image.Input("bg_image", optional=True),
+ ],
+ outputs=[],
+ )
- OUTPUT_NODE = True
- RETURN_TYPES = ()
-
- CATEGORY = "3d"
-
- FUNCTION = "process"
- EXPERIMENTAL = True
-
- def process(self, model_file, **kwargs):
+ @classmethod
+ def execute(cls, model_file, **kwargs) -> IO.NodeOutput:
camera_info = kwargs.get("camera_info", None)
bg_image = kwargs.get("bg_image", None)
+ return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image))
- bg_image_path = None
- if bg_image is not None:
+ process = execute # TODO: remove
- img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
- img = Image.fromarray(img_array)
- temp_dir = folder_paths.get_temp_directory()
- filename = f"bg_{uuid.uuid4().hex}.png"
- bg_image_path = os.path.join(temp_dir, filename)
- img.save(bg_image_path, compress_level=1)
+class Load3DExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+ return [
+ Load3D,
+ Preview3D,
+ ]
- bg_image_path = f"temp/{filename}"
- return {
- "ui": {
- "result": [model_file, camera_info, bg_image_path]
- }
- }
-
-NODE_CLASS_MAPPINGS = {
- "Load3D": Load3D,
- "Preview3D": Preview3D,
-}
-
-NODE_DISPLAY_NAME_MAPPINGS = {
- "Load3D": "Load 3D & Animation",
- "Preview3D": "Preview 3D & Animation",
-}
+async def comfy_entrypoint() -> Load3DExtension:
+ return Load3DExtension()
diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py
new file mode 100644
index 000000000..95a6ba788
--- /dev/null
+++ b/comfy_extras/nodes_logic.py
@@ -0,0 +1,155 @@
+from typing import Any, TypedDict
+from typing_extensions import override
+from comfy_api.latest import ComfyExtension, io
+from comfy_api.latest import _io
+
+
+class SwitchNode(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
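+        # on_false, on_true and the output all reference the same MatchType
+        # template, so the node adopts whatever type is wired into either input.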
+ template = io.MatchType.Template("switch")
+ return io.Schema(
+ node_id="ComfySwitchNode",
+ display_name="Switch",
+ category="logic",
+ is_experimental=True,
+ inputs=[
+ io.Boolean.Input("switch"),
+ io.MatchType.Input("on_false", template=template, lazy=True, optional=True),
+ io.MatchType.Input("on_true", template=template, lazy=True, optional=True),
+ ],
+ outputs=[
+ io.MatchType.Output(template=template, display_name="output"),
+ ],
+ )
+
+ @classmethod
+ def check_lazy_status(cls, switch, on_false=..., on_true=...):
+ # We use ... instead of None, as None is passed for connected-but-unevaluated inputs.
+ # This trick allows us to ignore the value of the switch and still be able to run execute().
+
+ # One of the inputs may be missing, in which case we need to evaluate the other input
+ if on_false is ...:
+ return ["on_true"]
+ if on_true is ...:
+ return ["on_false"]
+ # Normal lazy switch operation
+ if switch and on_true is None:
+ return ["on_true"]
+ if not switch and on_false is None:
+ return ["on_false"]
+
+ @classmethod
+ def validate_inputs(cls, switch, on_false=..., on_true=...):
+ # This check happens before check_lazy_status(), so we can eliminate the case where
+ # both inputs are missing.
+ if on_false is ... and on_true is ...:
+ return "At least one of on_false or on_true must be connected to Switch node"
+ return True
+
+ @classmethod
+ def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput:
+ if on_true is ...:
+ return io.NodeOutput(on_false)
+ if on_false is ...:
+ return io.NodeOutput(on_true)
+ return io.NodeOutput(on_true if switch else on_false)
+
+
+class DCTestNode(io.ComfyNode):
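+    # Test node for the experimental DynamicCombo input: execute() receives
+    # "combo" as a dict holding the selected option under the "combo" key plus
+    # the values of the nested inputs exposed by that option.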
+ class DCValues(TypedDict):
+ combo: str
+ string: str
+ integer: int
+ image: io.Image.Type
+        subcombo: dict[str, Any]
+
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="DCTestNode",
+ display_name="DCTest",
+ category="logic",
+ is_output_node=True,
+ inputs=[_io.DynamicCombo.Input("combo", options=[
+ _io.DynamicCombo.Option("option1", [io.String.Input("string")]),
+ _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]),
+ _io.DynamicCombo.Option("option3", [io.Image.Input("image")]),
+ _io.DynamicCombo.Option("option4", [
+ _io.DynamicCombo.Input("subcombo", options=[
+ _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]),
+ _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]),
+ ])
+ ])]
+ )],
+ outputs=[io.AnyType.Output()],
+ )
+
+ @classmethod
+ def execute(cls, combo: DCValues) -> io.NodeOutput:
+ combo_val = combo["combo"]
+ if combo_val == "option1":
+ return io.NodeOutput(combo["string"])
+ elif combo_val == "option2":
+ return io.NodeOutput(combo["integer"])
+ elif combo_val == "option3":
+ return io.NodeOutput(combo["image"])
+ elif combo_val == "option4":
+ return io.NodeOutput(f"{combo['subcombo']}")
+ else:
+ raise ValueError(f"Invalid combo: {combo_val}")
+
+
+
+class AutogrowNamesTestNode(io.ComfyNode):
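+    # Test node for the experimental Autogrow input using a fixed list of
+    # names ("a", "b", "c"); execute() joins the grown float values into a
+    # comma-separated string.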
+ @classmethod
+ def define_schema(cls):
+ template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"])
+ return io.Schema(
+ node_id="AutogrowNamesTestNode",
+ display_name="AutogrowNamesTest",
+ category="logic",
+ inputs=[
+ _io.Autogrow.Input("autogrow", template=template)
+ ],
+ outputs=[io.String.Output()],
+ )
+
+ @classmethod
+ def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
+ vals = list(autogrow.values())
+ combined = ",".join([str(x) for x in vals])
+ return io.NodeOutput(combined)
+
+
+class AutogrowPrefixTestNode(io.ComfyNode):
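+    # Same as AutogrowNamesTestNode, but the float inputs are grown from a
+    # "float" name prefix with a min/max count instead of a fixed name list.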
+ @classmethod
+ def define_schema(cls):
+ template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10)
+ return io.Schema(
+ node_id="AutogrowPrefixTestNode",
+ display_name="AutogrowPrefixTest",
+ category="logic",
+ inputs=[
+ _io.Autogrow.Input("autogrow", template=template)
+ ],
+ outputs=[io.String.Output()],
+ )
+
+ @classmethod
+ def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput:
+ vals = list(autogrow.values())
+ combined = ",".join([str(x) for x in vals])
+ return io.NodeOutput(combined)
+
+
+class LogicExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return [
+ # SwitchNode,
+ # DCTestNode,
+ # AutogrowNamesTestNode,
+ # AutogrowPrefixTestNode,
+ ]
+
+async def comfy_entrypoint() -> LogicExtension:
+ return LogicExtension()
diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py
index 783c59b6b..c61810dbf 100644
--- a/comfy_extras/nodes_model_patch.py
+++ b/comfy_extras/nodes_model_patch.py
@@ -6,6 +6,7 @@ import comfy.ops
import comfy.model_management
import comfy.ldm.common_dit
import comfy.latent_formats
+import comfy.ldm.lumina.controlnet
class BlockWiseControlBlock(torch.nn.Module):
@@ -189,6 +190,35 @@ class SigLIPMultiFeatProjModel(torch.nn.Module):
return embedding
+def z_image_convert(sd):
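+    # Map an alipai Z-Image Fun controlnet state dict onto the key names used by
+    # comfy.ldm.lumina.controlnet.ZImage_Control: the separate to_q/to_k/to_v
+    # projection weights are concatenated (q, k, v order) into a single
+    # ".attention.qkv.weight" tensor and the remaining keys are renamed below.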
+ replace_keys = {".attention.to_out.0.bias": ".attention.out.bias",
+ ".attention.norm_k.weight": ".attention.k_norm.weight",
+ ".attention.norm_q.weight": ".attention.q_norm.weight",
+ ".attention.to_out.0.weight": ".attention.out.weight"
+ }
+
+ out_sd = {}
+ for k in sorted(sd.keys()):
+ w = sd[k]
+
+ k_out = k
+ if k_out.endswith(".attention.to_k.weight"):
+ cc = [w]
+ continue
+ if k_out.endswith(".attention.to_q.weight"):
+ cc = [w] + cc
+ continue
+ if k_out.endswith(".attention.to_v.weight"):
+ cc = cc + [w]
+ w = torch.cat(cc, dim=0)
+ k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight")
+
+ for r, rr in replace_keys.items():
+ k_out = k_out.replace(r, rr)
+ out_sd[k_out] = w
+
+ return out_sd
+
class ModelPatchLoader:
@classmethod
def INPUT_TYPES(s):
@@ -211,6 +241,9 @@ class ModelPatchLoader:
elif 'feature_embedder.mid_layer_norm.bias' in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True)
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
+ elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
+ sd = z_image_convert(sd)
+ model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
model.load_state_dict(sd)
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@@ -263,6 +296,69 @@ class DiffSynthCnetPatch:
def models(self):
return [self.model_patch]
+class ZImageControlPatch:
+ def __init__(self, model_patch, vae, image, strength):
+ self.model_patch = model_patch
+ self.vae = vae
+ self.image = image
+ self.strength = strength
+ self.encoded_image = self.encode_latent_cond(image)
+ self.encoded_image_size = (image.shape[1], image.shape[2])
+ self.temp_data = None
+
+ def encode_latent_cond(self, image):
+ latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
+ return latent_image
+
+ def __call__(self, kwargs):
+ x = kwargs.get("x")
+ img = kwargs.get("img")
+ txt = kwargs.get("txt")
+ pe = kwargs.get("pe")
+ vec = kwargs.get("vec")
+ block_index = kwargs.get("block_index")
+ spacial_compression = self.vae.spacial_compression_encode()
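+        # if the cached encoding was cleared or the latent resolution changed, rescale and re-encode the control image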
+ if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
+ image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
+ loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
+ self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
+ self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
+ comfy.model_management.load_models_gpu(loaded_models)
+
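+        # control layer i maps to diffusion block 5*i, so a control residual is only added when block_index is an exact multiple of 5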
+ cnet_index = (block_index // 5)
+ cnet_index_float = (block_index / 5)
+
+ kwargs.pop("img") # we do ops in place
+ kwargs.pop("txt")
+
+ cnet_blocks = self.model_patch.model.n_control_layers
+ if cnet_index_float > (cnet_blocks - 1):
+ self.temp_data = None
+ return kwargs
+
+ if self.temp_data is None or self.temp_data[0] > cnet_index:
+ self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
+
+ while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
+ next_layer = self.temp_data[0] + 1
+ self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
+
+ if cnet_index_float == self.temp_data[0]:
+ img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
+ if cnet_blocks == self.temp_data[0] + 1:
+ self.temp_data = None
+
+ return kwargs
+
+ def to(self, device_or_dtype):
+ if isinstance(device_or_dtype, torch.device):
+ self.encoded_image = self.encoded_image.to(device_or_dtype)
+ self.temp_data = None
+ return self
+
+ def models(self):
+ return [self.model_patch]
+
class QwenImageDiffsynthControlnet:
@classmethod
def INPUT_TYPES(s):
@@ -289,7 +385,10 @@ class QwenImageDiffsynthControlnet:
mask = mask.unsqueeze(2)
mask = 1.0 - mask
- model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
+ if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
+ model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
+ else:
+ model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
return (model_patched,)
diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index cb24ab709..19b8baaf4 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -623,7 +623,7 @@ class TrainLoraNode(io.ComfyNode):
noise = comfy_extras.nodes_custom_sampler.Noise_RandomNoise(seed)
if multi_res:
# use first latent as dummy latent if multi_res
- latents = latents[0].repeat(num_images, 1, 1, 1)
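+                # build the repeat spec from the latent's actual rank instead of assuming 4-D latents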
+ latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1)))
guider.sample(
noise.generate_noise({"samples": latents}),
latents,
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index 69fabb12e..6cf6e39bf 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -88,7 +88,7 @@ class SaveVideo(io.ComfyNode):
)
@classmethod
- def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput:
+ def execute(cls, video: VideoInput, filename_prefix, format: str, codec) -> io.NodeOutput:
width, height = video.get_dimensions()
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(
filename_prefix,
@@ -108,7 +108,7 @@ class SaveVideo(io.ComfyNode):
file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}"
video.save_to(
os.path.join(full_output_folder, file),
- format=format,
+ format=VideoContainer(format),
codec=codec,
metadata=saved_metadata
)
diff --git a/comfyui_version.py b/comfyui_version.py
index fa4b4f4b0..4b039356e 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.3.75"
+__version__ = "0.3.76"
diff --git a/execution.py b/execution.py
index 17c77beab..c2186ac98 100644
--- a/execution.py
+++ b/execution.py
@@ -34,7 +34,7 @@ from comfy_execution.validation import validate_node_input
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
-from comfy_api.latest import io
+from comfy_api.latest import io, _io
class ExecutionResult(Enum):
@@ -76,7 +76,7 @@ class IsChangedCache:
return self.is_changed[node_id]
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
- input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None)
+ input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name)
is_changed = await resolve_map_node_over_list_results(is_changed)
@@ -146,8 +146,9 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org")
def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}):
is_v3 = issubclass(class_def, _ComfyNodeInternal)
+ v3_data: io.V3Data = {}
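+    # v3_data carries per-node schema metadata (hidden inputs and dynamic-input info) back to the caller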
if is_v3:
- valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True)
+ valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
else:
valid_inputs = class_def.INPUT_TYPES()
input_data_all = {}
@@ -207,7 +208,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
if h[x] == "API_KEY_COMFY_ORG":
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
- return input_data_all, missing_keys, hidden_inputs_v3
+ v3_data["hidden_inputs"] = hidden_inputs_v3
+ return input_data_all, missing_keys, v3_data
map_node_over_list = None #Don't hook this please
@@ -223,7 +225,7 @@ async def resolve_map_node_over_list_results(results):
raise exc
return [x.result() if isinstance(x, asyncio.Task) else x for x in results]
-async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
+async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
# check if node wants the lists
input_is_list = getattr(obj, "INPUT_IS_LIST", False)
@@ -259,13 +261,16 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f
if is_class(obj):
type_obj = obj
obj.VALIDATE_CLASS()
- class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs)
+ class_clone = obj.PREPARE_CLASS_CLONE(v3_data)
# otherwise, use class instance to populate/reuse some fields
else:
type_obj = type(obj)
type_obj.VALIDATE_CLASS()
- class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs)
+ class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data)
f = make_locked_method_func(type_obj, func, class_clone)
+            # if dynamic inputs are present, restructure them into the expected nested dict
+ if v3_data is not None:
+ inputs = _io.build_nested_inputs(inputs, v3_data)
# V1
else:
f = getattr(obj, func)
@@ -320,8 +325,8 @@ def merge_result_data(results, obj):
output.append([o[i] for o in results])
return output
-async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None):
- return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
+async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None):
+ return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values)
if has_pending_task:
return return_values, {}, False, has_pending_task
@@ -460,7 +465,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
has_subgraph = False
else:
get_progress_state().start_progress(unique_id)
- input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
+ input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data)
if server.client_id is not None:
server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
@@ -475,7 +480,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
else:
lazy_status_present = getattr(obj, "check_lazy_status", None) is not None
if lazy_status_present:
- required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs)
+ required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data)
required_inputs = await resolve_map_node_over_list_results(required_inputs)
required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], []))
required_inputs = [x for x in required_inputs if isinstance(x,str) and (
@@ -507,7 +512,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
def pre_execute_cb(call_index):
# TODO - How to handle this with async functions without contextvars (which requires Python 3.12)?
GraphBuilder.set_default_prefix(unique_id, call_index, 0)
- output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs)
+ output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
unblock = execution_list.add_external_block(unique_id)
@@ -745,18 +750,17 @@ async def validate_inputs(prompt_id, prompt, item, validated):
class_type = prompt[unique_id]['class_type']
obj_class = nodes.NODE_CLASS_MAPPINGS[class_type]
- class_inputs = obj_class.INPUT_TYPES()
- valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
-
errors = []
valid = True
validate_function_inputs = []
validate_has_kwargs = False
if issubclass(obj_class, _ComfyNodeInternal):
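+        # V3 nodes resolve their schema against the live prompt inputs so dynamic (autogrow) inputs are reflected during validation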
+ class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs)
validate_function_name = "validate_inputs"
validate_function = first_real_override(obj_class, validate_function_name)
else:
+ class_inputs = obj_class.INPUT_TYPES()
validate_function_name = "VALIDATE_INPUTS"
validate_function = getattr(obj_class, validate_function_name, None)
if validate_function is not None:
@@ -765,6 +769,8 @@ async def validate_inputs(prompt_id, prompt, item, validated):
validate_has_kwargs = argspec.varkw is not None
received_types = {}
+ valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{})))
+
for x in valid_inputs:
input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
assert extra_info is not None
@@ -935,7 +941,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
continue
if len(validate_function_inputs) > 0 or validate_has_kwargs:
- input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id)
+ input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id)
input_filtered = {}
for x in input_data_all:
if x in validate_function_inputs or validate_has_kwargs:
@@ -943,7 +949,7 @@ async def validate_inputs(prompt_id, prompt, item, validated):
if 'input_types' in validate_function_inputs:
input_filtered['input_types'] = [received_types]
- ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs)
+ ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data)
ret = await resolve_map_node_over_list_results(ret)
for x in input_filtered:
for i, r in enumerate(ret):
diff --git a/main.py b/main.py
index e1b0f1620..0cd815d9e 100644
--- a/main.py
+++ b/main.py
@@ -15,6 +15,7 @@ from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags
+
if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
@@ -22,6 +23,23 @@ if __name__ == "__main__":
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
+
+def handle_comfyui_manager_unavailable():
+ if not args.windows_standalone_build:
+        logging.warning(f"\n\nYou appear to be running comfyui-manager from source; this is not recommended. Please install comfyui-manager with the following command:\n\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
+ args.enable_manager = False
+
+
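+# the manager is only usable when comfyui_manager is installed as a proper package; a source/namespace import triggers the warning above and disables it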
+if args.enable_manager:
+ if importlib.util.find_spec("comfyui_manager"):
+ import comfyui_manager
+
+ if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'):
+ handle_comfyui_manager_unavailable()
+ else:
+ handle_comfyui_manager_unavailable()
+
+
def apply_custom_paths():
# extra model paths
extra_model_paths_config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "extra_model_paths.yaml")
@@ -79,6 +97,11 @@ def execute_prestartup_script():
for possible_module in possible_modules:
module_path = os.path.join(custom_node_path, possible_module)
+
+ if args.enable_manager:
+ if comfyui_manager.should_be_disabled(module_path):
+ continue
+
if os.path.isfile(module_path) or module_path.endswith(".disabled") or module_path == "__pycache__":
continue
@@ -101,6 +124,10 @@ def execute_prestartup_script():
logging.info("")
apply_custom_paths()
+
+if args.enable_manager:
+ comfyui_manager.prestartup()
+
execute_prestartup_script()
@@ -323,6 +350,9 @@ def start_comfyui(asyncio_loop=None):
asyncio.set_event_loop(asyncio_loop)
prompt_server = server.PromptServer(asyncio_loop)
+ if args.enable_manager and not args.disable_manager_ui:
+ comfyui_manager.start()
+
hook_breaker_ac10a0.save_functions()
asyncio_loop.run_until_complete(nodes.init_extra_nodes(
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
diff --git a/manager_requirements.txt b/manager_requirements.txt
new file mode 100644
index 000000000..52cc5389c
--- /dev/null
+++ b/manager_requirements.txt
@@ -0,0 +1 @@
+comfyui_manager==4.0.3b3
diff --git a/nodes.py b/nodes.py
index 495dec806..356aa63df 100644
--- a/nodes.py
+++ b/nodes.py
@@ -43,6 +43,9 @@ import folder_paths
import latent_preview
import node_helpers
+if args.enable_manager:
+ import comfyui_manager
+
def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted()
@@ -939,7 +942,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
- "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2"], ),
+ "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"], ),
},
"optional": {
"device": (["default", "cpu"], {"advanced": True}),
@@ -2243,6 +2246,12 @@ async def init_external_custom_nodes():
if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
continue
+
+ if args.enable_manager:
+ if comfyui_manager.should_be_disabled(module_path):
+                logging.info(f"Blocked by ComfyUI-Manager policy: {module_path}")
+ continue
+
time_before = time.perf_counter()
success = await load_custom_node(module_path, base_node_names, module_parent="custom_nodes")
node_import_times.append((time.perf_counter() - time_before, module_path, success))
@@ -2346,6 +2355,7 @@ async def init_builtin_extra_nodes():
"nodes_easycache.py",
"nodes_audio_encoder.py",
"nodes_rope.py",
+ "nodes_logic.py",
"nodes_nop.py",
]
diff --git a/pyproject.toml b/pyproject.toml
index 9009e65fe..02b94a0ce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.3.75"
+version = "0.3.76"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"
diff --git a/requirements.txt b/requirements.txt
index 386477808..f98848e20 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.32.9
+comfyui-frontend-package==1.33.10
comfyui-workflow-templates==0.7.25
comfyui-embedded-docs==0.3.1
torch
diff --git a/server.py b/server.py
index fca5050bd..ac4f42222 100644
--- a/server.py
+++ b/server.py
@@ -44,6 +44,9 @@ from protocol import BinaryEventTypes
# Import cache control middleware
from middleware.cache_middleware import cache_control
+if args.enable_manager:
+ import comfyui_manager
+
async def send_socket_catch_exception(function, message):
try:
await function(message)
@@ -95,7 +98,7 @@ def create_cors_middleware(allowed_origin: str):
response = await handler(request)
response.headers['Access-Control-Allow-Origin'] = allowed_origin
- response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS'
+ response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS, PATCH'
response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization'
response.headers['Access-Control-Allow-Credentials'] = 'true'
return response
@@ -212,6 +215,9 @@ class PromptServer():
if args.disable_api_nodes:
middlewares.append(create_block_external_middleware())
+ if args.enable_manager:
+ middlewares.append(comfyui_manager.create_middleware())
+
max_upload_size = round(args.max_upload_size * 1024 * 1024)
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
self.sockets = dict()
@@ -599,7 +605,7 @@ class PromptServer():
system_stats = {
"system": {
- "os": os.name,
+ "os": sys.platform,
"ram_total": ram_total,
"ram_free": ram_free,
"comfyui_version": __version__,