mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-13 18:47:29 +08:00
Merge branch 'Comfy-Org:master' into master
This commit is contained in:
commit
fc06fb6cd9
1
.gitignore
vendored
1
.gitignore
vendored
@ -23,3 +23,4 @@ web_custom_versions/
|
|||||||
.DS_Store
|
.DS_Store
|
||||||
filtered-openapi.yaml
|
filtered-openapi.yaml
|
||||||
uv.lock
|
uv.lock
|
||||||
|
.comfy_environment
|
||||||
|
|||||||
@ -133,7 +133,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
|
|||||||
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
ComfyUI follows a weekly release cycle targeting Monday but this regularly changes because of model releases or large changes to the codebase. There are three interconnected repositories:
|
||||||
|
|
||||||
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
|
||||||
- Releases a new stable version (e.g., v0.7.0) roughly every week.
|
- Releases a new major stable version (e.g., v0.7.0) roughly every 2 weeks.
|
||||||
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
|
- Starting from v0.4.0 patch versions will be used for fixes backported onto the current stable release.
|
||||||
- Minor versions will be used for releases off the master branch.
|
- Minor versions will be used for releases off the master branch.
|
||||||
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
|
- Patch versions may still be used for releases on the master branch in cases where a backport would not make sense.
|
||||||
|
|||||||
@ -238,6 +238,8 @@ database_default_path = os.path.abspath(
|
|||||||
)
|
)
|
||||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||||
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
|
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
|
||||||
|
parser.add_argument("--feature-flag", type=str, action='append', default=[], metavar="KEY[=VALUE]", help="Set a server feature flag. Use KEY=VALUE to set an explicit value, or bare KEY to set it to true. Can be specified multiple times. Boolean values (true/false) and numbers are auto-converted. Examples: --feature-flag show_signin_button=true or --feature-flag show_signin_button")
|
||||||
|
parser.add_argument("--list-feature-flags", action="store_true", help="Print the registry of known CLI-settable feature flags as JSON and exit.")
|
||||||
|
|
||||||
if comfy.options.args_parsing:
|
if comfy.options.args_parsing:
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|||||||
34
comfy/deploy_environment.py
Normal file
34
comfy/deploy_environment.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_DEFAULT_DEPLOY_ENV = "local-git"
|
||||||
|
_ENV_FILENAME = ".comfy_environment"
|
||||||
|
|
||||||
|
# Resolve the ComfyUI install directory (the parent of this `comfy/` package).
|
||||||
|
# We deliberately avoid `folder_paths.base_path` here because that is overridden
|
||||||
|
# by the `--base-directory` CLI arg to a user-supplied path, whereas the
|
||||||
|
# `.comfy_environment` marker is written by launchers/installers next to the
|
||||||
|
# ComfyUI install itself.
|
||||||
|
_COMFY_INSTALL_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def get_deploy_environment() -> str:
|
||||||
|
env_file = os.path.join(_COMFY_INSTALL_DIR, _ENV_FILENAME)
|
||||||
|
try:
|
||||||
|
with open(env_file, encoding="utf-8") as f:
|
||||||
|
# Cap the read so a malformed or maliciously crafted file (e.g.
|
||||||
|
# a single huge line with no newline) can't blow up memory.
|
||||||
|
first_line = f.readline(128).strip()
|
||||||
|
value = "".join(c for c in first_line if 32 <= ord(c) < 127)
|
||||||
|
if value:
|
||||||
|
return value
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to read %s: %s", env_file, e)
|
||||||
|
|
||||||
|
return _DEFAULT_DEPLOY_ENV
|
||||||
@ -1810,3 +1810,102 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
|||||||
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
|
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
|
||||||
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
|
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
|
||||||
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
|
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||||
|
num_frame_per_block=1):
|
||||||
|
"""
|
||||||
|
Autoregressive video sampler: block-by-block denoising with KV cache
|
||||||
|
and flow-match re-noising for Causal Forcing / Self-Forcing models.
|
||||||
|
|
||||||
|
Requires a Causal-WAN compatible model (diffusion_model must expose
|
||||||
|
init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].
|
||||||
|
|
||||||
|
All AR-loop parameters are passed via the SamplerARVideo node, not read
|
||||||
|
from the checkpoint or transformer_options.
|
||||||
|
"""
|
||||||
|
extra_args = {} if extra_args is None else extra_args
|
||||||
|
model_options = extra_args.get("model_options", {})
|
||||||
|
transformer_options = model_options.get("transformer_options", {})
|
||||||
|
|
||||||
|
if x.ndim != 5:
|
||||||
|
raise ValueError(
|
||||||
|
f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. "
|
||||||
|
"This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)."
|
||||||
|
)
|
||||||
|
|
||||||
|
inner_model = model.inner_model.inner_model
|
||||||
|
causal_model = inner_model.diffusion_model
|
||||||
|
|
||||||
|
if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")):
|
||||||
|
raise TypeError(
|
||||||
|
"ar_video sampler requires a Causal-WAN compatible model whose diffusion_model "
|
||||||
|
"exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint "
|
||||||
|
"does not support this interface — choose a different sampler."
|
||||||
|
)
|
||||||
|
|
||||||
|
seed = extra_args.get("seed", 0)
|
||||||
|
|
||||||
|
bs, c, lat_t, lat_h, lat_w = x.shape
|
||||||
|
frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division
|
||||||
|
num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
|
||||||
|
device = x.device
|
||||||
|
model_dtype = inner_model.get_dtype()
|
||||||
|
|
||||||
|
kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype)
|
||||||
|
crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype)
|
||||||
|
|
||||||
|
output = torch.zeros_like(x)
|
||||||
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
current_start_frame = 0
|
||||||
|
num_sigma_steps = len(sigmas) - 1
|
||||||
|
total_real_steps = num_blocks * num_sigma_steps
|
||||||
|
step_count = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
for block_idx in trange(num_blocks, disable=disable):
|
||||||
|
bf = min(num_frame_per_block, lat_t - current_start_frame)
|
||||||
|
fs, fe = current_start_frame, current_start_frame + bf
|
||||||
|
noisy_input = x[:, :, fs:fe]
|
||||||
|
|
||||||
|
ar_state = {
|
||||||
|
"start_frame": current_start_frame,
|
||||||
|
"kv_caches": kv_caches,
|
||||||
|
"crossattn_caches": crossattn_caches,
|
||||||
|
}
|
||||||
|
transformer_options["ar_state"] = ar_state
|
||||||
|
|
||||||
|
for i in range(num_sigma_steps):
|
||||||
|
denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
|
||||||
|
|
||||||
|
if callback is not None:
|
||||||
|
scaled_i = step_count * num_sigma_steps // total_real_steps
|
||||||
|
callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
|
||||||
|
"sigma_hat": sigmas[i], "denoised": denoised})
|
||||||
|
|
||||||
|
if sigmas[i + 1] == 0:
|
||||||
|
noisy_input = denoised
|
||||||
|
else:
|
||||||
|
sigma_next = sigmas[i + 1]
|
||||||
|
torch.manual_seed(seed + block_idx * 1000 + i)
|
||||||
|
fresh_noise = torch.randn_like(denoised)
|
||||||
|
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
|
||||||
|
|
||||||
|
for cache in kv_caches:
|
||||||
|
cache["end"] -= bf * frame_seq_len
|
||||||
|
|
||||||
|
step_count += 1
|
||||||
|
|
||||||
|
output[:, :, fs:fe] = noisy_input
|
||||||
|
|
||||||
|
for cache in kv_caches:
|
||||||
|
cache["end"] -= bf * frame_seq_len
|
||||||
|
zero_sigma = sigmas.new_zeros([1])
|
||||||
|
_ = model(noisy_input, zero_sigma * s_in, **extra_args)
|
||||||
|
|
||||||
|
current_start_frame += bf
|
||||||
|
finally:
|
||||||
|
transformer_options.pop("ar_state", None)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|||||||
276
comfy/ldm/wan/ar_model.py
Normal file
276
comfy/ldm/wan/ar_model.py
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
"""
|
||||||
|
CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
|
||||||
|
autoregressive (frame-by-frame) video generation via Causal Forcing.
|
||||||
|
|
||||||
|
Weight-compatible with the standard WanModel -- same layer names, same shapes.
|
||||||
|
The difference is purely in the forward pass: this model processes one temporal
|
||||||
|
block at a time and maintains a KV cache across blocks.
|
||||||
|
|
||||||
|
Reference: https://github.com/thu-ml/Causal-Forcing
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from comfy.ldm.modules.attention import optimized_attention
|
||||||
|
from comfy.ldm.flux.math import apply_rope1
|
||||||
|
from comfy.ldm.wan.model import (
|
||||||
|
sinusoidal_embedding_1d,
|
||||||
|
repeat_e,
|
||||||
|
WanModel,
|
||||||
|
WanAttentionBlock,
|
||||||
|
)
|
||||||
|
import comfy.ldm.common_dit
|
||||||
|
import comfy.model_management
|
||||||
|
|
||||||
|
|
||||||
|
class CausalWanSelfAttention(nn.Module):
|
||||||
|
"""Self-attention with KV cache support for autoregressive inference."""
|
||||||
|
|
||||||
|
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
|
||||||
|
eps=1e-6, operation_settings={}):
|
||||||
|
assert dim % num_heads == 0
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = dim // num_heads
|
||||||
|
self.qk_norm = qk_norm
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
ops = operation_settings.get("operations")
|
||||||
|
device = operation_settings.get("device")
|
||||||
|
dtype = operation_settings.get("dtype")
|
||||||
|
|
||||||
|
self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||||
|
self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||||
|
self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||||
|
self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||||
|
self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||||
|
self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x, freqs, kv_cache=None, transformer_options={}):
|
||||||
|
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||||
|
|
||||||
|
q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
|
||||||
|
k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
|
||||||
|
v = self.v(x).view(b, s, n, d)
|
||||||
|
|
||||||
|
if kv_cache is None:
|
||||||
|
x = optimized_attention(
|
||||||
|
q.view(b, s, n * d),
|
||||||
|
k.view(b, s, n * d),
|
||||||
|
v.view(b, s, n * d),
|
||||||
|
heads=self.num_heads,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
end = kv_cache["end"]
|
||||||
|
new_end = end + s
|
||||||
|
|
||||||
|
# Roped K and plain V go into cache
|
||||||
|
kv_cache["k"][:, end:new_end] = k
|
||||||
|
kv_cache["v"][:, end:new_end] = v
|
||||||
|
kv_cache["end"] = new_end
|
||||||
|
|
||||||
|
x = optimized_attention(
|
||||||
|
q.view(b, s, n * d),
|
||||||
|
kv_cache["k"][:, :new_end].view(b, new_end, n * d),
|
||||||
|
kv_cache["v"][:, :new_end].view(b, new_end, n * d),
|
||||||
|
heads=self.num_heads,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = self.o(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CausalWanAttentionBlock(WanAttentionBlock):
|
||||||
|
"""Transformer block with KV-cached self-attention and cross-attention caching."""
|
||||||
|
|
||||||
|
def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
|
||||||
|
window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
|
||||||
|
eps=1e-6, operation_settings={}):
|
||||||
|
super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
|
||||||
|
window_size, qk_norm, cross_attn_norm, eps,
|
||||||
|
operation_settings=operation_settings)
|
||||||
|
self.self_attn = CausalWanSelfAttention(
|
||||||
|
dim, num_heads, window_size, qk_norm, eps,
|
||||||
|
operation_settings=operation_settings)
|
||||||
|
|
||||||
|
def forward(self, x, e, freqs, context, context_img_len=257,
|
||||||
|
kv_cache=None, crossattn_cache=None, transformer_options={}):
|
||||||
|
if e.ndim < 4:
|
||||||
|
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||||
|
else:
|
||||||
|
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
|
||||||
|
|
||||||
|
# Self-attention with optional KV cache
|
||||||
|
x = x.contiguous()
|
||||||
|
y = self.self_attn(
|
||||||
|
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||||
|
freqs, kv_cache=kv_cache, transformer_options=transformer_options)
|
||||||
|
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||||
|
del y
|
||||||
|
|
||||||
|
# Cross-attention with optional caching
|
||||||
|
if crossattn_cache is not None and crossattn_cache.get("is_init"):
|
||||||
|
q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
|
||||||
|
x_ca = optimized_attention(
|
||||||
|
q, crossattn_cache["k"], crossattn_cache["v"],
|
||||||
|
heads=self.num_heads, transformer_options=transformer_options)
|
||||||
|
x = x + self.cross_attn.o(x_ca)
|
||||||
|
else:
|
||||||
|
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||||
|
if crossattn_cache is not None:
|
||||||
|
crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
|
||||||
|
crossattn_cache["v"] = self.cross_attn.v(context)
|
||||||
|
crossattn_cache["is_init"] = True
|
||||||
|
|
||||||
|
# FFN
|
||||||
|
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||||
|
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CausalWanModel(WanModel):
|
||||||
|
"""
|
||||||
|
Wan 2.1 diffusion backbone with causal KV-cache support.
|
||||||
|
|
||||||
|
Same weight structure as WanModel -- loads identical state dicts.
|
||||||
|
Adds forward_block() for frame-by-frame autoregressive inference.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_type='t2v',
|
||||||
|
patch_size=(1, 2, 2),
|
||||||
|
text_len=512,
|
||||||
|
in_dim=16,
|
||||||
|
dim=2048,
|
||||||
|
ffn_dim=8192,
|
||||||
|
freq_dim=256,
|
||||||
|
text_dim=4096,
|
||||||
|
out_dim=16,
|
||||||
|
num_heads=16,
|
||||||
|
num_layers=32,
|
||||||
|
window_size=(-1, -1),
|
||||||
|
qk_norm=True,
|
||||||
|
cross_attn_norm=True,
|
||||||
|
eps=1e-6,
|
||||||
|
image_model=None,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
operations=None):
|
||||||
|
super().__init__(
|
||||||
|
model_type=model_type, patch_size=patch_size, text_len=text_len,
|
||||||
|
in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
|
||||||
|
text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
|
||||||
|
num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
|
||||||
|
cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
|
||||||
|
wan_attn_block_class=CausalWanAttentionBlock,
|
||||||
|
device=device, dtype=dtype, operations=operations)
|
||||||
|
|
||||||
|
def forward_block(self, x, timestep, context, start_frame,
|
||||||
|
kv_caches, crossattn_caches, clip_fea=None):
|
||||||
|
"""
|
||||||
|
Forward one temporal block for autoregressive inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: [B, C, block_frames, H, W] input latent for the current block
|
||||||
|
timestep: [B, block_frames] per-frame timesteps
|
||||||
|
context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
|
||||||
|
start_frame: temporal frame index for RoPE offset
|
||||||
|
kv_caches: list of per-layer KV cache dicts
|
||||||
|
crossattn_caches: list of per-layer cross-attention cache dicts
|
||||||
|
clip_fea: optional CLIP features for I2V
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
flow_pred: [B, C_out, block_frames, H, W] flow prediction
|
||||||
|
"""
|
||||||
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||||
|
bs, c, t, h, w = x.shape
|
||||||
|
|
||||||
|
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||||
|
grid_sizes = x.shape[2:]
|
||||||
|
x = x.flatten(2).transpose(1, 2)
|
||||||
|
|
||||||
|
# Per-frame time embedding
|
||||||
|
e = self.time_embedding(
|
||||||
|
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
|
||||||
|
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
|
||||||
|
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||||
|
|
||||||
|
# Text embedding (reuses crossattn_cache after first block)
|
||||||
|
context = self.text_embedding(context)
|
||||||
|
|
||||||
|
context_img_len = None
|
||||||
|
if clip_fea is not None and self.img_emb is not None:
|
||||||
|
context_clip = self.img_emb(clip_fea)
|
||||||
|
context = torch.concat([context_clip, context], dim=1)
|
||||||
|
context_img_len = clip_fea.shape[-2]
|
||||||
|
|
||||||
|
# RoPE for current block's temporal position
|
||||||
|
freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
# Transformer blocks
|
||||||
|
for i, block in enumerate(self.blocks):
|
||||||
|
x = block(x, e=e0, freqs=freqs, context=context,
|
||||||
|
context_img_len=context_img_len,
|
||||||
|
kv_cache=kv_caches[i],
|
||||||
|
crossattn_cache=crossattn_caches[i])
|
||||||
|
|
||||||
|
# Head
|
||||||
|
x = self.head(x, e)
|
||||||
|
|
||||||
|
# Unpatchify
|
||||||
|
x = self.unpatchify(x, grid_sizes)
|
||||||
|
return x[:, :, :t, :h, :w]
|
||||||
|
|
||||||
|
def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
|
||||||
|
"""Create fresh KV caches for all layers."""
|
||||||
|
caches = []
|
||||||
|
for _ in range(self.num_layers):
|
||||||
|
caches.append({
|
||||||
|
"k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||||
|
"v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||||
|
"end": 0,
|
||||||
|
})
|
||||||
|
return caches
|
||||||
|
|
||||||
|
def init_crossattn_caches(self, batch_size, device, dtype):
|
||||||
|
"""Create fresh cross-attention caches for all layers."""
|
||||||
|
caches = []
|
||||||
|
for _ in range(self.num_layers):
|
||||||
|
caches.append({"is_init": False})
|
||||||
|
return caches
|
||||||
|
|
||||||
|
def reset_kv_caches(self, kv_caches):
|
||||||
|
"""Reset KV caches to empty (reuse allocated memory)."""
|
||||||
|
for cache in kv_caches:
|
||||||
|
cache["end"] = 0
|
||||||
|
|
||||||
|
def reset_crossattn_caches(self, crossattn_caches):
|
||||||
|
"""Reset cross-attention caches."""
|
||||||
|
for cache in crossattn_caches:
|
||||||
|
cache["is_init"] = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def head_dim(self):
|
||||||
|
return self.dim // self.num_heads
|
||||||
|
|
||||||
|
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||||
|
ar_state = transformer_options.get("ar_state")
|
||||||
|
if ar_state is not None:
|
||||||
|
bs = x.shape[0]
|
||||||
|
block_frames = x.shape[2]
|
||||||
|
t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames)
|
||||||
|
return self.forward_block(
|
||||||
|
x=x, timestep=t_per_frame, context=context,
|
||||||
|
start_frame=ar_state["start_frame"],
|
||||||
|
kv_caches=ar_state["kv_caches"],
|
||||||
|
crossattn_caches=ar_state["crossattn_caches"],
|
||||||
|
clip_fea=clip_fea,
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().forward(x, timestep, context, clip_fea=clip_fea,
|
||||||
|
time_dim_concat=time_dim_concat,
|
||||||
|
transformer_options=transformer_options, **kwargs)
|
||||||
@ -43,6 +43,7 @@ import comfy.ldm.lumina.model
|
|||||||
import comfy.ldm.twinflow.model
|
import comfy.ldm.twinflow.model
|
||||||
import comfy.ldm.wan.model
|
import comfy.ldm.wan.model
|
||||||
import comfy.ldm.wan.model_animate
|
import comfy.ldm.wan.model_animate
|
||||||
|
import comfy.ldm.wan.ar_model
|
||||||
import comfy.ldm.hunyuan3d.model
|
import comfy.ldm.hunyuan3d.model
|
||||||
import comfy.ldm.hidream.model
|
import comfy.ldm.hidream.model
|
||||||
import comfy.ldm.chroma.model
|
import comfy.ldm.chroma.model
|
||||||
@ -1371,6 +1372,13 @@ class WAN21(BaseModel):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class WAN21_CausalAR(WAN21):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
super(WAN21, self).__init__(model_config, model_type, device=device,
|
||||||
|
unet_model=comfy.ldm.wan.ar_model.CausalWanModel)
|
||||||
|
self.image_to_video = False
|
||||||
|
|
||||||
|
|
||||||
class WAN21_Vace(WAN21):
|
class WAN21_Vace(WAN21):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
|
||||||
|
|||||||
@ -721,13 +721,15 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
else:
|
else:
|
||||||
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
||||||
|
|
||||||
models_temp = set()
|
# Order-preserving dedup. A plain set() would randomize iteration order across runs
|
||||||
|
models_temp = {}
|
||||||
for m in models:
|
for m in models:
|
||||||
models_temp.add(m)
|
models_temp[m] = None
|
||||||
for mm in m.model_patches_models():
|
for mm in m.model_patches_models():
|
||||||
models_temp.add(mm)
|
models_temp[mm] = None
|
||||||
|
|
||||||
models = models_temp
|
models = list(models_temp)
|
||||||
|
models.reverse()
|
||||||
|
|
||||||
models_to_load = []
|
models_to_load = []
|
||||||
|
|
||||||
|
|||||||
@ -253,6 +253,9 @@ def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, w
|
|||||||
if bias is not None:
|
if bias is not None:
|
||||||
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
|
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
|
||||||
|
|
||||||
|
if prefetch["signature"] is not None:
|
||||||
|
prefetch["resident"] = True
|
||||||
|
|
||||||
return weight, bias
|
return weight, bias
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -89,7 +89,8 @@ def get_additional_models(conds, dtype):
|
|||||||
gligen += get_models_from_cond(conds[k], "gligen")
|
gligen += get_models_from_cond(conds[k], "gligen")
|
||||||
add_models += get_models_from_cond(conds[k], "additional_models")
|
add_models += get_models_from_cond(conds[k], "additional_models")
|
||||||
|
|
||||||
control_nets = set(cnets)
|
# Order-preserving dedup. A plain set() would randomize iteration order across runs
|
||||||
|
control_nets = list(dict.fromkeys(cnets))
|
||||||
|
|
||||||
inference_memory = 0
|
inference_memory = 0
|
||||||
control_models = []
|
control_models = []
|
||||||
|
|||||||
@ -1176,6 +1176,25 @@ class WAN21_T2V(supported_models_base.BASE):
|
|||||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
|
||||||
|
|
||||||
|
class WAN21_CausalAR_T2V(WAN21_T2V):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "wan2.1",
|
||||||
|
"model_type": "t2v",
|
||||||
|
"causal_ar": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
sampling_settings = {
|
||||||
|
"shift": 5.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self, unet_config):
|
||||||
|
super().__init__(unet_config)
|
||||||
|
self.unet_config.pop("causal_ar", None)
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
return model_base.WAN21_CausalAR(self, device=device)
|
||||||
|
|
||||||
|
|
||||||
class WAN21_I2V(WAN21_T2V):
|
class WAN21_I2V(WAN21_T2V):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "wan2.1",
|
"image_model": "wan2.1",
|
||||||
@ -1939,6 +1958,7 @@ models = [
|
|||||||
TwinFlow_Z_Image,
|
TwinFlow_Z_Image,
|
||||||
Lumina2,
|
Lumina2,
|
||||||
WAN22_T2V,
|
WAN22_T2V,
|
||||||
|
WAN21_CausalAR_T2V,
|
||||||
WAN21_T2V,
|
WAN21_T2V,
|
||||||
WAN21_I2V,
|
WAN21_I2V,
|
||||||
WAN21_FunControl2V,
|
WAN21_FunControl2V,
|
||||||
|
|||||||
@ -5,12 +5,95 @@ This module handles capability negotiation between frontend and backend,
|
|||||||
allowing graceful protocol evolution while maintaining backward compatibility.
|
allowing graceful protocol evolution while maintaining backward compatibility.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any
|
import logging
|
||||||
|
from typing import Any, TypedDict
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureFlagInfo(TypedDict):
|
||||||
|
type: str
|
||||||
|
default: Any
|
||||||
|
description: str
|
||||||
|
|
||||||
|
|
||||||
|
# Registry of known CLI-settable feature flags.
|
||||||
|
# Launchers can query this via --list-feature-flags to discover valid flags.
|
||||||
|
CLI_FEATURE_FLAG_REGISTRY: dict[str, FeatureFlagInfo] = {
|
||||||
|
"show_signin_button": {
|
||||||
|
"type": "bool",
|
||||||
|
"default": False,
|
||||||
|
"description": "Show the sign-in button in the frontend even when not signed in",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_bool(v: str) -> bool:
|
||||||
|
"""Strict bool coercion: only 'true'/'false' (case-insensitive).
|
||||||
|
|
||||||
|
Anything else raises ValueError so the caller can warn and drop the flag,
|
||||||
|
rather than silently treating typos like 'ture' or 'yes' as False.
|
||||||
|
"""
|
||||||
|
lower = v.lower()
|
||||||
|
if lower == "true":
|
||||||
|
return True
|
||||||
|
if lower == "false":
|
||||||
|
return False
|
||||||
|
raise ValueError(f"expected 'true' or 'false', got {v!r}")
|
||||||
|
|
||||||
|
|
||||||
|
_COERCE_FNS: dict[str, Any] = {
|
||||||
|
"bool": _coerce_bool,
|
||||||
|
"int": lambda v: int(v),
|
||||||
|
"float": lambda v: float(v),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_flag_value(key: str, raw_value: str) -> Any:
|
||||||
|
"""Coerce a raw string value using the registry type, or keep as string.
|
||||||
|
|
||||||
|
Returns the raw string if the key is unregistered or the type is unknown.
|
||||||
|
Raises ValueError/TypeError if the key is registered with a known type but
|
||||||
|
the value cannot be coerced; callers are expected to warn and drop the flag.
|
||||||
|
"""
|
||||||
|
info = CLI_FEATURE_FLAG_REGISTRY.get(key)
|
||||||
|
if info is None:
|
||||||
|
return raw_value
|
||||||
|
coerce = _COERCE_FNS.get(info["type"])
|
||||||
|
if coerce is None:
|
||||||
|
return raw_value
|
||||||
|
return coerce(raw_value)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_cli_feature_flags() -> dict[str, Any]:
|
||||||
|
"""Parse --feature-flag key=value pairs from CLI args into a dict.
|
||||||
|
|
||||||
|
Items without '=' default to the value 'true' (bare flag form).
|
||||||
|
Flags whose value cannot be coerced to the registered type are dropped
|
||||||
|
with a warning, so a typo like '--feature-flag some_bool=ture' does not
|
||||||
|
silently take effect as the wrong value.
|
||||||
|
"""
|
||||||
|
result: dict[str, Any] = {}
|
||||||
|
for item in getattr(args, "feature_flag", []):
|
||||||
|
key, sep, raw_value = item.partition("=")
|
||||||
|
key = key.strip()
|
||||||
|
if not key:
|
||||||
|
continue
|
||||||
|
if not sep:
|
||||||
|
raw_value = "true"
|
||||||
|
try:
|
||||||
|
result[key] = _coerce_flag_value(key, raw_value.strip())
|
||||||
|
except (ValueError, TypeError) as e:
|
||||||
|
info = CLI_FEATURE_FLAG_REGISTRY.get(key, {})
|
||||||
|
logging.warning(
|
||||||
|
"Could not coerce --feature-flag %s=%r to %s (%s); dropping flag.",
|
||||||
|
key, raw_value.strip(), info.get("type", "?"), e,
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
# Default server capabilities
|
# Default server capabilities
|
||||||
SERVER_FEATURE_FLAGS: dict[str, Any] = {
|
_CORE_FEATURE_FLAGS: dict[str, Any] = {
|
||||||
"supports_preview_metadata": True,
|
"supports_preview_metadata": True,
|
||||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||||
"extension": {"manager": {"supports_v4": True}},
|
"extension": {"manager": {"supports_v4": True}},
|
||||||
@ -18,6 +101,11 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
|
|||||||
"assets": args.enable_assets,
|
"assets": args.enable_assets,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# CLI-provided flags cannot overwrite core flags
|
||||||
|
_cli_flags = {k: v for k, v in _parse_cli_feature_flags().items() if k not in _CORE_FEATURE_FLAGS}
|
||||||
|
|
||||||
|
SERVER_FEATURE_FLAGS: dict[str, Any] = {**_CORE_FEATURE_FLAGS, **_cli_flags}
|
||||||
|
|
||||||
|
|
||||||
def get_connection_feature(
|
def get_connection_feature(
|
||||||
sockets_metadata: dict[str, dict[str, Any]],
|
sockets_metadata: dict[str, dict[str, Any]],
|
||||||
|
|||||||
@ -33,7 +33,7 @@ class OpenAIVideoSora2(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="OpenAIVideoSora2",
|
node_id="OpenAIVideoSora2",
|
||||||
display_name="OpenAI Sora - Video (Deprecated)",
|
display_name="OpenAI Sora - Video (DEPRECATED)",
|
||||||
category="api node/video/Sora",
|
category="api node/video/Sora",
|
||||||
description=(
|
description=(
|
||||||
"OpenAI video and audio generation.\n\n"
|
"OpenAI video and audio generation.\n\n"
|
||||||
|
|||||||
@ -19,6 +19,8 @@ from comfy import utils
|
|||||||
from comfy_api.latest import IO
|
from comfy_api.latest import IO
|
||||||
from server import PromptServer
|
from server import PromptServer
|
||||||
|
|
||||||
|
from comfy.deploy_environment import get_deploy_environment
|
||||||
|
|
||||||
from . import request_logger
|
from . import request_logger
|
||||||
from ._helpers import (
|
from ._helpers import (
|
||||||
default_base_url,
|
default_base_url,
|
||||||
@ -624,6 +626,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
|||||||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||||
|
payload_headers["Comfy-Env"] = get_deploy_environment()
|
||||||
if cfg.endpoint.headers:
|
if cfg.endpoint.headers:
|
||||||
payload_headers.update(cfg.endpoint.headers)
|
payload_headers.update(cfg.endpoint.headers)
|
||||||
|
|
||||||
|
|||||||
84
comfy_extras/nodes_ar_video.py
Normal file
84
comfy_extras/nodes_ar_video.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
"""
|
||||||
|
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
||||||
|
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
|
||||||
|
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.samplers
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyARVideoLatent(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="EmptyARVideoLatent",
|
||||||
|
category="latent/video",
|
||||||
|
inputs=[
|
||||||
|
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||||
|
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||||
|
io.Int.Input("length", default=81, min=1, max=1024, step=4),
|
||||||
|
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Latent.Output(display_name="LATENT"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, width, height, length, batch_size) -> io.NodeOutput:
|
||||||
|
lat_t = ((length - 1) // 4) + 1
|
||||||
|
latent = torch.zeros(
|
||||||
|
[batch_size, 16, lat_t, height // 8, width // 8],
|
||||||
|
device=comfy.model_management.intermediate_device(),
|
||||||
|
)
|
||||||
|
return io.NodeOutput({"samples": latent})
|
||||||
|
|
||||||
|
|
||||||
|
class SamplerARVideo(io.ComfyNode):
|
||||||
|
"""Sampler for autoregressive video models (Causal Forcing, Self-Forcing).
|
||||||
|
|
||||||
|
All AR-loop parameters are owned by this node so they live in the workflow.
|
||||||
|
Add new widgets here as the AR sampler grows new options.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="SamplerARVideo",
|
||||||
|
display_name="Sampler AR Video",
|
||||||
|
category="sampling/custom_sampling/samplers",
|
||||||
|
inputs=[
|
||||||
|
io.Int.Input(
|
||||||
|
"num_frame_per_block",
|
||||||
|
default=1, min=1, max=64,
|
||||||
|
tooltip="Frames per autoregressive block. 1 = framewise, "
|
||||||
|
"3 = chunkwise. Must match the checkpoint's training mode.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[io.Sampler.Output()],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, num_frame_per_block) -> io.NodeOutput:
|
||||||
|
extra_options = {
|
||||||
|
"num_frame_per_block": num_frame_per_block,
|
||||||
|
}
|
||||||
|
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
|
||||||
|
|
||||||
|
|
||||||
|
class ARVideoExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
EmptyARVideoLatent,
|
||||||
|
SamplerARVideo,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> ARVideoExtension:
|
||||||
|
return ARVideoExtension()
|
||||||
@ -78,7 +78,7 @@ class FrameInterpolate(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="FrameInterpolate",
|
node_id="FrameInterpolate",
|
||||||
display_name="Frame Interpolate",
|
display_name="Frame Interpolate",
|
||||||
category="image/video",
|
category="video",
|
||||||
search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"],
|
search_aliases=["rife", "film", "frame interpolation", "slow motion", "interpolate frames", "vfi"],
|
||||||
inputs=[
|
inputs=[
|
||||||
FrameInterpolationModel.Input("interp_model"),
|
FrameInterpolationModel.Input("interp_model"),
|
||||||
|
|||||||
@ -11,7 +11,7 @@ class ImageCompare(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ImageCompare",
|
node_id="ImageCompare",
|
||||||
display_name="Image Compare",
|
display_name="Compare Images",
|
||||||
description="Compares two images side by side with a slider.",
|
description="Compares two images side by side with a slider.",
|
||||||
category="image",
|
category="image",
|
||||||
essentials_category="Image Tools",
|
essentials_category="Image Tools",
|
||||||
|
|||||||
@ -24,7 +24,7 @@ class ImageCrop(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ImageCrop",
|
node_id="ImageCrop",
|
||||||
search_aliases=["trim"],
|
search_aliases=["trim"],
|
||||||
display_name="Image Crop (Deprecated)",
|
display_name="Crop Image (DEPRECATED)",
|
||||||
category="image/transform",
|
category="image/transform",
|
||||||
is_deprecated=True,
|
is_deprecated=True,
|
||||||
essentials_category="Image Tools",
|
essentials_category="Image Tools",
|
||||||
@ -56,7 +56,7 @@ class ImageCropV2(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ImageCropV2",
|
node_id="ImageCropV2",
|
||||||
search_aliases=["trim"],
|
search_aliases=["trim"],
|
||||||
display_name="Image Crop",
|
display_name="Crop Image",
|
||||||
category="image/transform",
|
category="image/transform",
|
||||||
essentials_category="Image Tools",
|
essentials_category="Image Tools",
|
||||||
has_intermediate_output=True,
|
has_intermediate_output=True,
|
||||||
@ -109,6 +109,7 @@ class RepeatImageBatch(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="RepeatImageBatch",
|
node_id="RepeatImageBatch",
|
||||||
search_aliases=["duplicate image", "clone image"],
|
search_aliases=["duplicate image", "clone image"],
|
||||||
|
display_name="Repeat Image Batch",
|
||||||
category="image/batch",
|
category="image/batch",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Image.Input("image"),
|
IO.Image.Input("image"),
|
||||||
@ -131,6 +132,7 @@ class ImageFromBatch(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ImageFromBatch",
|
node_id="ImageFromBatch",
|
||||||
search_aliases=["select image", "pick from batch", "extract image"],
|
search_aliases=["select image", "pick from batch", "extract image"],
|
||||||
|
display_name="Get Image from Batch",
|
||||||
category="image/batch",
|
category="image/batch",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Image.Input("image"),
|
IO.Image.Input("image"),
|
||||||
@ -157,7 +159,8 @@ class ImageAddNoise(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ImageAddNoise",
|
node_id="ImageAddNoise",
|
||||||
search_aliases=["film grain"],
|
search_aliases=["film grain"],
|
||||||
category="image",
|
display_name="Add Noise to Image",
|
||||||
|
category="image/postprocessing",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Image.Input("image"),
|
IO.Image.Input("image"),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
@ -259,7 +262,7 @@ class ImageStitch(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ImageStitch",
|
node_id="ImageStitch",
|
||||||
search_aliases=["combine images", "join images", "concatenate images", "side by side"],
|
search_aliases=["combine images", "join images", "concatenate images", "side by side"],
|
||||||
display_name="Image Stitch",
|
display_name="Stitch Images",
|
||||||
description="Stitches image2 to image1 in the specified direction.\n"
|
description="Stitches image2 to image1 in the specified direction.\n"
|
||||||
"If image2 is not provided, returns image1 unchanged.\n"
|
"If image2 is not provided, returns image1 unchanged.\n"
|
||||||
"Optional spacing can be added between images.",
|
"Optional spacing can be added between images.",
|
||||||
@ -434,6 +437,7 @@ class ResizeAndPadImage(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ResizeAndPadImage",
|
node_id="ResizeAndPadImage",
|
||||||
search_aliases=["fit to size"],
|
search_aliases=["fit to size"],
|
||||||
|
display_name="Resize And Pad Image",
|
||||||
category="image/transform",
|
category="image/transform",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Image.Input("image"),
|
IO.Image.Input("image"),
|
||||||
@ -485,6 +489,7 @@ class SaveSVGNode(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="SaveSVGNode",
|
node_id="SaveSVGNode",
|
||||||
search_aliases=["export vector", "save vector graphics"],
|
search_aliases=["export vector", "save vector graphics"],
|
||||||
|
display_name="Save SVG",
|
||||||
description="Save SVG files on disk.",
|
description="Save SVG files on disk.",
|
||||||
category="image/save",
|
category="image/save",
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -591,7 +596,7 @@ class ImageRotate(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ImageRotate",
|
node_id="ImageRotate",
|
||||||
display_name="Image Rotate",
|
display_name="Rotate Image",
|
||||||
search_aliases=["turn", "flip orientation"],
|
search_aliases=["turn", "flip orientation"],
|
||||||
category="image/transform",
|
category="image/transform",
|
||||||
essentials_category="Image Tools",
|
essentials_category="Image Tools",
|
||||||
@ -624,6 +629,7 @@ class ImageFlip(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ImageFlip",
|
node_id="ImageFlip",
|
||||||
search_aliases=["mirror", "reflect"],
|
search_aliases=["mirror", "reflect"],
|
||||||
|
display_name="Flip Image",
|
||||||
category="image/transform",
|
category="image/transform",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Image.Input("image"),
|
IO.Image.Input("image"),
|
||||||
@ -650,6 +656,7 @@ class ImageScaleToMaxDimension(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ImageScaleToMaxDimension",
|
node_id="ImageScaleToMaxDimension",
|
||||||
|
display_name="Scale Image to Max Dimension",
|
||||||
category="image/upscaling",
|
category="image/upscaling",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Image.Input("image"),
|
IO.Image.Input("image"),
|
||||||
|
|||||||
@ -80,7 +80,8 @@ class ImageCompositeMasked(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="ImageCompositeMasked",
|
node_id="ImageCompositeMasked",
|
||||||
search_aliases=["paste image", "overlay", "layer"],
|
search_aliases=["overlay", "layer", "paste image", "images composition"],
|
||||||
|
display_name="Image Composite Masked",
|
||||||
category="image",
|
category="image",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Image.Input("destination"),
|
IO.Image.Input("destination"),
|
||||||
@ -201,6 +202,7 @@ class InvertMask(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="InvertMask",
|
node_id="InvertMask",
|
||||||
search_aliases=["reverse mask", "flip mask"],
|
search_aliases=["reverse mask", "flip mask"],
|
||||||
|
display_name="Invert Mask",
|
||||||
category="mask",
|
category="mask",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Mask.Input("mask"),
|
IO.Mask.Input("mask"),
|
||||||
@ -222,6 +224,7 @@ class CropMask(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="CropMask",
|
node_id="CropMask",
|
||||||
search_aliases=["cut mask", "extract mask region", "mask slice"],
|
search_aliases=["cut mask", "extract mask region", "mask slice"],
|
||||||
|
display_name="Crop Mask",
|
||||||
category="mask",
|
category="mask",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Mask.Input("mask"),
|
IO.Mask.Input("mask"),
|
||||||
@ -247,7 +250,8 @@ class MaskComposite(IO.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="MaskComposite",
|
node_id="MaskComposite",
|
||||||
search_aliases=["combine masks", "blend masks", "layer masks"],
|
search_aliases=["combine masks", "blend masks", "layer masks", "masks composition"],
|
||||||
|
display_name="Combine Masks",
|
||||||
category="mask",
|
category="mask",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Mask.Input("destination"),
|
IO.Mask.Input("destination"),
|
||||||
@ -298,6 +302,7 @@ class FeatherMask(IO.ComfyNode):
|
|||||||
return IO.Schema(
|
return IO.Schema(
|
||||||
node_id="FeatherMask",
|
node_id="FeatherMask",
|
||||||
search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"],
|
search_aliases=["soft edge mask", "blur mask edges", "gradient mask edge"],
|
||||||
|
display_name="Feather Mask",
|
||||||
category="mask",
|
category="mask",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Mask.Input("mask"),
|
IO.Mask.Input("mask"),
|
||||||
|
|||||||
@ -59,7 +59,8 @@ class ImageRGBToYUV(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ImageRGBToYUV",
|
node_id="ImageRGBToYUV",
|
||||||
search_aliases=["color space conversion"],
|
search_aliases=["color space conversion"],
|
||||||
category="image/batch",
|
display_name="Image RGB to YUV",
|
||||||
|
category="image/color",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
],
|
],
|
||||||
@ -81,7 +82,8 @@ class ImageYUVToRGB(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ImageYUVToRGB",
|
node_id="ImageYUVToRGB",
|
||||||
search_aliases=["color space conversion"],
|
search_aliases=["color space conversion"],
|
||||||
category="image/batch",
|
display_name="Image YUV to RGB",
|
||||||
|
category="image/color",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("Y"),
|
io.Image.Input("Y"),
|
||||||
io.Image.Input("U"),
|
io.Image.Input("U"),
|
||||||
|
|||||||
@ -20,7 +20,8 @@ class Blend(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ImageBlend",
|
node_id="ImageBlend",
|
||||||
display_name="Image Blend",
|
search_aliases=["mix images"],
|
||||||
|
display_name="Blend Images",
|
||||||
category="image/postprocessing",
|
category="image/postprocessing",
|
||||||
essentials_category="Image Tools",
|
essentials_category="Image Tools",
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -224,6 +225,7 @@ class ImageScaleToTotalPixels(io.ComfyNode):
|
|||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ImageScaleToTotalPixels",
|
node_id="ImageScaleToTotalPixels",
|
||||||
|
display_name="Scale Image to Total Pixels",
|
||||||
category="image/upscaling",
|
category="image/upscaling",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image"),
|
io.Image.Input("image"),
|
||||||
@ -568,7 +570,7 @@ class BatchImagesNode(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="BatchImagesNode",
|
node_id="BatchImagesNode",
|
||||||
display_name="Batch Images",
|
display_name="Batch Images",
|
||||||
category="image",
|
category="image/batch",
|
||||||
essentials_category="Image Tools",
|
essentials_category="Image Tools",
|
||||||
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
|
search_aliases=["batch", "image batch", "batch images", "combine images", "merge images", "stack images"],
|
||||||
inputs=[
|
inputs=[
|
||||||
|
|||||||
@ -17,7 +17,8 @@ class SaveWEBM(io.ComfyNode):
|
|||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="SaveWEBM",
|
node_id="SaveWEBM",
|
||||||
search_aliases=["export webm"],
|
search_aliases=["export webm"],
|
||||||
category="image/video",
|
display_name="Save WEBM",
|
||||||
|
category="video",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("images"),
|
io.Image.Input("images"),
|
||||||
@ -72,7 +73,7 @@ class SaveVideo(io.ComfyNode):
|
|||||||
node_id="SaveVideo",
|
node_id="SaveVideo",
|
||||||
search_aliases=["export video"],
|
search_aliases=["export video"],
|
||||||
display_name="Save Video",
|
display_name="Save Video",
|
||||||
category="image/video",
|
category="video",
|
||||||
essentials_category="Basics",
|
essentials_category="Basics",
|
||||||
description="Saves the input images to your ComfyUI output directory.",
|
description="Saves the input images to your ComfyUI output directory.",
|
||||||
inputs=[
|
inputs=[
|
||||||
@ -121,7 +122,7 @@ class CreateVideo(io.ComfyNode):
|
|||||||
node_id="CreateVideo",
|
node_id="CreateVideo",
|
||||||
search_aliases=["images to video"],
|
search_aliases=["images to video"],
|
||||||
display_name="Create Video",
|
display_name="Create Video",
|
||||||
category="image/video",
|
category="video",
|
||||||
description="Create a video from images.",
|
description="Create a video from images.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("images", tooltip="The images to create a video from."),
|
io.Image.Input("images", tooltip="The images to create a video from."),
|
||||||
@ -146,7 +147,7 @@ class GetVideoComponents(io.ComfyNode):
|
|||||||
node_id="GetVideoComponents",
|
node_id="GetVideoComponents",
|
||||||
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
||||||
display_name="Get Video Components",
|
display_name="Get Video Components",
|
||||||
category="image/video",
|
category="video",
|
||||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
description="Extracts all components from a video: frames, audio, and framerate.",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Video.Input("video", tooltip="The video to extract components from."),
|
io.Video.Input("video", tooltip="The video to extract components from."),
|
||||||
@ -174,7 +175,7 @@ class LoadVideo(io.ComfyNode):
|
|||||||
node_id="LoadVideo",
|
node_id="LoadVideo",
|
||||||
search_aliases=["import video", "open video", "video file"],
|
search_aliases=["import video", "open video", "video file"],
|
||||||
display_name="Load Video",
|
display_name="Load Video",
|
||||||
category="image/video",
|
category="video",
|
||||||
essentials_category="Basics",
|
essentials_category="Basics",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video),
|
io.Combo.Input("file", options=sorted(files), upload=io.UploadType.video),
|
||||||
@ -216,7 +217,7 @@ class VideoSlice(io.ComfyNode):
|
|||||||
"frame load cap",
|
"frame load cap",
|
||||||
"start time",
|
"start time",
|
||||||
],
|
],
|
||||||
category="image/video",
|
category="video",
|
||||||
essentials_category="Video Tools",
|
essentials_category="Video Tools",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Video.Input("video"),
|
io.Video.Input("video"),
|
||||||
|
|||||||
10
main.py
10
main.py
@ -1,13 +1,21 @@
|
|||||||
import comfy.options
|
import comfy.options
|
||||||
comfy.options.enable_args_parsing()
|
comfy.options.enable_args_parsing()
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
if args.list_feature_flags:
|
||||||
|
import json
|
||||||
|
from comfy_api.feature_flags import CLI_FEATURE_FLAG_REGISTRY
|
||||||
|
print(json.dumps(CLI_FEATURE_FLAG_REGISTRY, indent=2)) # noqa: T201
|
||||||
|
raise SystemExit(0)
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import importlib.util
|
import importlib.util
|
||||||
import shutil
|
import shutil
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import time
|
import time
|
||||||
from comfy.cli_args import args, enables_dynamic_vram
|
from comfy.cli_args import enables_dynamic_vram
|
||||||
from app.logger import setup_logger
|
from app.logger import setup_logger
|
||||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||||
|
|
||||||
|
|||||||
26
nodes.py
26
nodes.py
@ -1887,7 +1887,7 @@ class ImageInvert:
|
|||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "invert"
|
FUNCTION = "invert"
|
||||||
|
|
||||||
CATEGORY = "image"
|
CATEGORY = "image/color"
|
||||||
|
|
||||||
def invert(self, image):
|
def invert(self, image):
|
||||||
s = 1.0 - image
|
s = 1.0 - image
|
||||||
@ -1903,7 +1903,7 @@ class ImageBatch:
|
|||||||
RETURN_TYPES = ("IMAGE",)
|
RETURN_TYPES = ("IMAGE",)
|
||||||
FUNCTION = "batch"
|
FUNCTION = "batch"
|
||||||
|
|
||||||
CATEGORY = "image"
|
CATEGORY = "image/batch"
|
||||||
DEPRECATED = True
|
DEPRECATED = True
|
||||||
|
|
||||||
def batch(self, image1, image2):
|
def batch(self, image1, image2):
|
||||||
@ -1960,7 +1960,7 @@ class ImagePadForOutpaint:
|
|||||||
RETURN_TYPES = ("IMAGE", "MASK")
|
RETURN_TYPES = ("IMAGE", "MASK")
|
||||||
FUNCTION = "expand_image"
|
FUNCTION = "expand_image"
|
||||||
|
|
||||||
CATEGORY = "image"
|
CATEGORY = "image/transform"
|
||||||
|
|
||||||
def expand_image(self, image, left, top, right, bottom, feathering):
|
def expand_image(self, image, left, top, right, bottom, feathering):
|
||||||
d1, d2, d3, d4 = image.size()
|
d1, d2, d3, d4 = image.size()
|
||||||
@ -2103,7 +2103,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"ConditioningSetArea": "Conditioning (Set Area)",
|
"ConditioningSetArea": "Conditioning (Set Area)",
|
||||||
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
|
||||||
"ConditioningSetMask": "Conditioning (Set Mask)",
|
"ConditioningSetMask": "Conditioning (Set Mask)",
|
||||||
"ControlNetApply": "Apply ControlNet (OLD)",
|
"ControlNetApply": "Apply ControlNet (DEPRECATED)",
|
||||||
"ControlNetApplyAdvanced": "Apply ControlNet",
|
"ControlNetApplyAdvanced": "Apply ControlNet",
|
||||||
# Latent
|
# Latent
|
||||||
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
|
||||||
@ -2121,6 +2121,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"LatentFromBatch" : "Latent From Batch",
|
"LatentFromBatch" : "Latent From Batch",
|
||||||
"RepeatLatentBatch": "Repeat Latent Batch",
|
"RepeatLatentBatch": "Repeat Latent Batch",
|
||||||
# Image
|
# Image
|
||||||
|
"EmptyImage": "Empty Image",
|
||||||
"SaveImage": "Save Image",
|
"SaveImage": "Save Image",
|
||||||
"PreviewImage": "Preview Image",
|
"PreviewImage": "Preview Image",
|
||||||
"LoadImage": "Load Image",
|
"LoadImage": "Load Image",
|
||||||
@ -2128,15 +2129,15 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"LoadImageOutput": "Load Image (from Outputs)",
|
"LoadImageOutput": "Load Image (from Outputs)",
|
||||||
"ImageScale": "Upscale Image",
|
"ImageScale": "Upscale Image",
|
||||||
"ImageScaleBy": "Upscale Image By",
|
"ImageScaleBy": "Upscale Image By",
|
||||||
"ImageInvert": "Invert Image",
|
"ImageInvert": "Invert Image Colors",
|
||||||
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
"ImagePadForOutpaint": "Pad Image for Outpainting",
|
||||||
"ImageBatch": "Batch Images",
|
"ImageBatch": "Batch Images (DEPRECATED)",
|
||||||
"ImageCrop": "Image Crop",
|
"ImageCrop": "Crop Image",
|
||||||
"ImageStitch": "Image Stitch",
|
"ImageStitch": "Stitch Images",
|
||||||
"ImageBlend": "Image Blend",
|
"ImageBlend": "Blend Images",
|
||||||
"ImageBlur": "Image Blur",
|
"ImageBlur": "Blur Image",
|
||||||
"ImageQuantize": "Image Quantize",
|
"ImageQuantize": "Quantize Image",
|
||||||
"ImageSharpen": "Image Sharpen",
|
"ImageSharpen": "Sharpen Image",
|
||||||
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
"ImageScaleToTotalPixels": "Scale Image to Total Pixels",
|
||||||
"GetImageSize": "Get Image Size",
|
"GetImageSize": "Get Image Size",
|
||||||
# _for_testing
|
# _for_testing
|
||||||
@ -2411,6 +2412,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_nop.py",
|
"nodes_nop.py",
|
||||||
"nodes_kandinsky5.py",
|
"nodes_kandinsky5.py",
|
||||||
"nodes_wanmove.py",
|
"nodes_wanmove.py",
|
||||||
|
"nodes_ar_video.py",
|
||||||
"nodes_image_compare.py",
|
"nodes_image_compare.py",
|
||||||
"nodes_zimage.py",
|
"nodes_zimage.py",
|
||||||
"nodes_glsl.py",
|
"nodes_glsl.py",
|
||||||
|
|||||||
40
openapi.yaml
40
openapi.yaml
@ -1999,6 +1999,26 @@ components:
|
|||||||
items:
|
items:
|
||||||
type: string
|
type: string
|
||||||
description: List of node IDs to execute (partial graph execution)
|
description: List of node IDs to execute (partial graph execution)
|
||||||
|
workflow_id:
|
||||||
|
type: string
|
||||||
|
format: uuid
|
||||||
|
nullable: true
|
||||||
|
x-runtime: [cloud]
|
||||||
|
description: |
|
||||||
|
UUID identifying a hosted-cloud workflow entity to associate with this
|
||||||
|
job. Local ComfyUI doesn't track workflow entities and returns `null`
|
||||||
|
(or omits the field). The `x-runtime: [cloud]` extension marks this
|
||||||
|
as populated only by the hosted-cloud runtime; absence of the tag
|
||||||
|
means a field is populated by all runtimes.
|
||||||
|
workflow_version_id:
|
||||||
|
type: string
|
||||||
|
format: uuid
|
||||||
|
nullable: true
|
||||||
|
x-runtime: [cloud]
|
||||||
|
description: |
|
||||||
|
UUID identifying a hosted-cloud workflow version to associate with
|
||||||
|
this job. Local ComfyUI returns `null` (or omits the field). See
|
||||||
|
`workflow_id` above for `x-runtime` semantics.
|
||||||
|
|
||||||
PromptResponse:
|
PromptResponse:
|
||||||
type: object
|
type: object
|
||||||
@ -2347,7 +2367,12 @@ components:
|
|||||||
description: Device type (cuda, mps, cpu, etc.)
|
description: Device type (cuda, mps, cpu, etc.)
|
||||||
index:
|
index:
|
||||||
type: number
|
type: number
|
||||||
description: Device index
|
nullable: true
|
||||||
|
description: |
|
||||||
|
Device index within its type (e.g. CUDA ordinal for `cuda:0`,
|
||||||
|
`cuda:1`). `null` for devices with no index, including the CPU
|
||||||
|
device returned in `--cpu` mode (PyTorch's `torch.device('cpu').index`
|
||||||
|
is `None`).
|
||||||
vram_total:
|
vram_total:
|
||||||
type: number
|
type: number
|
||||||
description: Total VRAM in bytes
|
description: Total VRAM in bytes
|
||||||
@ -2503,7 +2528,18 @@ components:
|
|||||||
description: Alternative search terms for finding this node
|
description: Alternative search terms for finding this node
|
||||||
essentials_category:
|
essentials_category:
|
||||||
type: string
|
type: string
|
||||||
description: Category override used by the essentials pack
|
nullable: true
|
||||||
|
description: |
|
||||||
|
Category override used by the essentials pack. The
|
||||||
|
`essentials_category` key may be present with a string value,
|
||||||
|
present and `null`, or absent entirely:
|
||||||
|
|
||||||
|
- V1 nodes: `essentials_category` is **omitted** when the node
|
||||||
|
class doesn't define an `ESSENTIALS_CATEGORY` attribute, and
|
||||||
|
**`null`** if the attribute is explicitly set to `None`.
|
||||||
|
- V3 nodes (`comfy_api.latest.io`): `essentials_category` is
|
||||||
|
**always present**, and **`null`** for nodes whose `Schema`
|
||||||
|
doesn't populate it.
|
||||||
|
|
||||||
# -------------------------------------------------------------------
|
# -------------------------------------------------------------------
|
||||||
# Models
|
# Models
|
||||||
|
|||||||
109
tests-unit/deploy_environment_test.py
Normal file
109
tests-unit/deploy_environment_test.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
"""Tests for comfy.deploy_environment."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from comfy import deploy_environment
|
||||||
|
from comfy.deploy_environment import get_deploy_environment
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _reset_cache_and_install_dir(tmp_path, monkeypatch):
|
||||||
|
"""Reset the functools cache and point the ComfyUI install dir at a tmp dir for each test."""
|
||||||
|
get_deploy_environment.cache_clear()
|
||||||
|
monkeypatch.setattr(deploy_environment, "_COMFY_INSTALL_DIR", str(tmp_path))
|
||||||
|
yield
|
||||||
|
get_deploy_environment.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
|
def _write_env_file(tmp_path, content: str) -> str:
|
||||||
|
"""Write the env file with exact content (no newline translation).
|
||||||
|
|
||||||
|
`newline=""` disables Python's text-mode newline translation so the bytes
|
||||||
|
on disk match the literal string passed in, regardless of host OS.
|
||||||
|
Newline-style tests (CRLF, lone CR) rely on this.
|
||||||
|
"""
|
||||||
|
path = os.path.join(str(tmp_path), ".comfy_environment")
|
||||||
|
with open(path, "w", encoding="utf-8", newline="") as f:
|
||||||
|
f.write(content)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetDeployEnvironment:
    """Behavioral tests for ``get_deploy_environment()``."""

    def test_returns_local_git_when_file_missing(self):
        # No .comfy_environment file in the install dir -> default marker.
        assert get_deploy_environment() == "local-git"

    def test_reads_value_from_file(self, tmp_path):
        _write_env_file(tmp_path, "local-desktop2-standalone\n")
        env = get_deploy_environment()
        assert env == "local-desktop2-standalone"

    def test_strips_trailing_whitespace_and_newline(self, tmp_path):
        _write_env_file(tmp_path, " local-desktop2-standalone \n")
        assert get_deploy_environment() == "local-desktop2-standalone"

    def test_only_first_line_is_used(self, tmp_path):
        _write_env_file(tmp_path, "first-line\nsecond-line\n")
        assert get_deploy_environment() == "first-line"

    def test_crlf_line_ending(self, tmp_path):
        # Files saved by Windows editors end lines with CRLF; the CR must
        # never leak into the reported environment string.
        _write_env_file(tmp_path, "local-desktop2-standalone\r\n")
        assert get_deploy_environment() == "local-desktop2-standalone"

    def test_crlf_multiline_only_first_line_used(self, tmp_path):
        _write_env_file(tmp_path, "first-line\r\nsecond-line\r\n")
        assert get_deploy_environment() == "first-line"

    def test_crlf_with_surrounding_whitespace(self, tmp_path):
        _write_env_file(tmp_path, " local-desktop2-standalone \r\n")
        assert get_deploy_environment() == "local-desktop2-standalone"

    def test_lone_cr_line_ending(self, tmp_path):
        # A bare CR (classic-Mac / legacy editors) also terminates the
        # first line under universal-newlines decoding.
        _write_env_file(tmp_path, "local-desktop2-standalone\r")
        assert get_deploy_environment() == "local-desktop2-standalone"

    def test_empty_file_falls_back_to_default(self, tmp_path):
        _write_env_file(tmp_path, "")
        assert get_deploy_environment() == "local-git"

    def test_empty_after_whitespace_strip_falls_back_to_default(self, tmp_path):
        _write_env_file(tmp_path, " \n")
        assert get_deploy_environment() == "local-git"

    def test_strips_control_chars_within_first_line(self, tmp_path):
        # Control characters embedded in the value (NUL, BEL, ...) are
        # removed — guards against header-injection / smuggling via the
        # env file.
        _write_env_file(tmp_path, "abc\x00\x07xyz\n")
        assert get_deploy_environment() == "abcxyz"

    def test_strips_non_ascii_characters(self, tmp_path):
        _write_env_file(tmp_path, "café-é\n")
        assert get_deploy_environment() == "caf-"

    def test_caps_read_at_128_bytes(self, tmp_path):
        # A giant single line with no newline must be truncated at the
        # 128-byte read cap rather than loaded wholesale into memory.
        _write_env_file(tmp_path, "x" * 10_000)
        env = get_deploy_environment()
        assert env == "x" * 128

    def test_result_is_cached_across_calls(self, tmp_path):
        env_path = _write_env_file(tmp_path, "first_value\n")
        assert get_deploy_environment() == "first_value"
        # Rewrite the file on disk; the functools cache must keep serving
        # the value read on the first call.
        with open(env_path, "w", encoding="utf-8") as handle:
            handle.write("second_value\n")
        assert get_deploy_environment() == "first_value"

    def test_unreadable_file_falls_back_to_default(self, tmp_path, monkeypatch):
        _write_env_file(tmp_path, "should_not_be_used\n")

        def _boom(*args, **kwargs):
            raise OSError("simulated read failure")

        # Any I/O failure while reading the env file must degrade to the
        # default instead of propagating.
        monkeypatch.setattr("builtins.open", _boom)
        assert get_deploy_environment() == "local-git"
|
||||||
@ -1,10 +1,15 @@
|
|||||||
"""Tests for feature flags functionality."""
|
"""Tests for feature flags functionality."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from comfy_api.feature_flags import (
|
from comfy_api.feature_flags import (
|
||||||
get_connection_feature,
|
get_connection_feature,
|
||||||
supports_feature,
|
supports_feature,
|
||||||
get_server_features,
|
get_server_features,
|
||||||
|
CLI_FEATURE_FLAG_REGISTRY,
|
||||||
SERVER_FEATURE_FLAGS,
|
SERVER_FEATURE_FLAGS,
|
||||||
|
_coerce_flag_value,
|
||||||
|
_parse_cli_feature_flags,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -96,3 +101,83 @@ class TestFeatureFlags:
|
|||||||
result = get_connection_feature(sockets_metadata, "sid1", "any_feature")
|
result = get_connection_feature(sockets_metadata, "sid1", "any_feature")
|
||||||
assert result is False
|
assert result is False
|
||||||
assert supports_feature(sockets_metadata, "sid1", "any_feature") is False
|
assert supports_feature(sockets_metadata, "sid1", "any_feature") is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestCoerceFlagValue:
    """Test suite for ``_coerce_flag_value``."""

    def test_registered_bool_true(self):
        # Case-insensitive acceptance of the literal word "true".
        for raw in ("true", "True"):
            assert _coerce_flag_value("show_signin_button", raw) is True

    def test_registered_bool_false(self):
        for raw in ("false", "FALSE"):
            assert _coerce_flag_value("show_signin_button", raw) is False

    def test_unregistered_key_stays_string(self):
        # Keys absent from the registry get no coercion at all.
        assert _coerce_flag_value("unknown_flag", "true") == "true"
        assert _coerce_flag_value("unknown_flag", "42") == "42"

    def test_bool_typo_raises(self):
        """Strict bool: typos like 'ture' or 'yes' must raise so the flag can be dropped."""
        for bad_value in ("ture", "yes", "1", ""):
            with pytest.raises(ValueError):
                _coerce_flag_value("show_signin_button", bad_value)

    def test_failed_int_coercion_raises(self, monkeypatch):
        """Malformed values for typed flags must raise; caller decides what to do."""
        entry = {"type": "int", "default": 0, "description": "test"}
        monkeypatch.setitem(CLI_FEATURE_FLAG_REGISTRY, "test_int_flag", entry)
        with pytest.raises(ValueError):
            _coerce_flag_value("test_int_flag", "not_a_number")
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseCliFeatureFlags:
    """Test suite for ``_parse_cli_feature_flags``."""

    @staticmethod
    def _patch_cli_args(monkeypatch, flags):
        # Swap the module-level parsed args for a minimal stub exposing
        # only the `feature_flag` list — all the parser reads.
        stub = type("Args", (), {"feature_flag": flags})()
        monkeypatch.setattr("comfy_api.feature_flags.args", stub)

    def test_single_flag(self, monkeypatch):
        self._patch_cli_args(monkeypatch, ["show_signin_button=true"])
        assert _parse_cli_feature_flags() == {"show_signin_button": True}

    def test_missing_equals_defaults_to_true(self, monkeypatch):
        """Bare flag without '=' is treated as the string 'true' (and coerced if registered)."""
        self._patch_cli_args(monkeypatch, ["show_signin_button", "valid=1"])
        parsed = _parse_cli_feature_flags()
        assert parsed == {"show_signin_button": True, "valid": "1"}

    def test_empty_key_skipped(self, monkeypatch):
        self._patch_cli_args(monkeypatch, ["=value", "valid=1"])
        assert _parse_cli_feature_flags() == {"valid": "1"}

    def test_invalid_bool_value_dropped(self, monkeypatch, caplog):
        """A typo'd bool value must be dropped entirely — not silently set to
        False, not stored as a raw string — and a warning must be logged."""
        self._patch_cli_args(monkeypatch, ["show_signin_button=ture", "valid=1"])
        with caplog.at_level("WARNING"):
            parsed = _parse_cli_feature_flags()
        assert parsed == {"valid": "1"}
        assert "show_signin_button" not in parsed
        warned = [
            r
            for r in caplog.records
            if "show_signin_button" in r.message and "drop" in r.message.lower()
        ]
        assert warned
|
||||||
|
|
||||||
|
|
||||||
|
class TestCliFeatureFlagRegistry:
    """Test suite for the CLI feature flag registry."""

    def test_registry_entries_have_required_fields(self):
        # Every registry entry must carry the full metadata triple.
        required_fields = ("type", "default", "description")
        for key, info in CLI_FEATURE_FLAG_REGISTRY.items():
            for field in required_fields:
                assert field in info, f"{key} missing '{field}'"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user