Compare commits

...

19 Commits

Author SHA1 Message Date
Talmaj
dd414f7a32
Merge b8688389e5 into b138133ffa 2026-05-03 15:13:47 -04:00
Silver
b138133ffa
Enable triton comfy kitchen via cli-arg (#12730) 2026-05-03 14:07:21 -04:00
Jukka Seppänen
025e6792ee
Batch broadcasting in JoinImageWithAlpha node (#13686)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Generate Pydantic Stubs from api.comfy.org / generate-models (push) Has been cancelled
* Batch broadcasting in JoinImageWithAlpha node
2026-05-03 16:30:00 +03:00
Luke Mino-Altherr
867b8d2408
fix: gracefully handle port-in-use error on server startup (#13001)
Catch EADDRINUSE OSError when binding the TCP site and exit with a clear error message instead of an unhandled traceback.
2026-05-03 20:44:20 +08:00
Alexis Rolland
d0f0b15cf5
Update ComfyUI screenshot in README (#13683)
Update ComfyUI screenshot to showcase a more modern workflow
2026-05-03 18:48:58 +08:00
Alexis Rolland
b5bb83c964
Fix issue blend images with alpha (#13615)
Make ImageBlend and ImageCompositeMasked nodes handle images with different channel counts
2026-05-03 18:17:08 +08:00
Talmaj Marinc
b8688389e5 Create a dedicated node for ar_sampler. 2026-05-01 22:21:46 +02:00
Talmaj Marinc
4a63ef01ed Add better error handling for a custom ar_video sampler. 2026-05-01 11:37:22 +02:00
Talmaj Marinc
b5d1cdb2a8 Base frame_seq_len on the padded token grid. 2026-05-01 11:37:22 +02:00
Talmaj Marinc
df0a0dae36 Move KV cache end counter to Python int to avoid per-step host synchronization in AR sampling loops. 2026-05-01 11:37:22 +02:00
Talmaj Marinc
6f99c8086f Remove ar_convert, now present in hg repackaged model repo. 2026-05-01 11:37:22 +02:00
Talmaj Marinc
586647afdb Fix 'Process the tail block instead of truncating it', fix 'Don't mutate the patcher's shared transformer_options in place'. 2026-05-01 11:37:22 +02:00
Talmaj Marinc
6739b76a20 Remove dedicated ARLoader. 2026-05-01 11:37:19 +02:00
Talmaj Marinc
7caf1e801c Refactor CausalWanModel to inherit from WanModel. 2026-05-01 11:34:07 +02:00
Talmaj Marinc
3bd42bfaa4 Rewrite causual forcing using custom sampler with KSampler node. 2026-05-01 11:34:07 +02:00
Talmaj Marinc
09e46f8509 Apply ruff. 2026-05-01 11:34:07 +02:00
Talmaj Marinc
9d40dfbab8 Rename causual forcing to using more general auto regressive naming convention. 2026-05-01 11:34:07 +02:00
Talmaj Marinc
0edbe6d8dc Fix CausalForcingSampler. 2026-05-01 11:34:07 +02:00
Talmaj Marinc
2cdc4389d7 Initial commit causual_forcing. 2026-05-01 11:34:07 +02:00
12 changed files with 516 additions and 12 deletions

View File

@ -31,7 +31,8 @@
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
<img width="1590" height="795" alt="ComfyUI Screenshot" src="https://github.com/user-attachments/assets/4aab0bef-b413-4595-9766-a2c134676d27" />
<img width="1590" height="795" alt="ComfyUI Screenshot" src="https://github.com/user-attachments/assets/36e065e0-bfae-4456-8c7f-8369d5ea48a2" />
<br>
</div>
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...

View File

@ -91,6 +91,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"

View File

@ -1810,3 +1810,102 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
"""Stochastic Adams Solver with PECE (PredictEvaluateCorrectEvaluate) mode (NeurIPS 2023)."""
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
@torch.no_grad()
def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None,
num_frame_per_block=1):
"""
Autoregressive video sampler: block-by-block denoising with KV cache
and flow-match re-noising for Causal Forcing / Self-Forcing models.
Requires a Causal-WAN compatible model (diffusion_model must expose
init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].
All AR-loop parameters are passed via the SamplerARVideo node, not read
from the checkpoint or transformer_options.
"""
extra_args = {} if extra_args is None else extra_args
model_options = extra_args.get("model_options", {})
transformer_options = model_options.get("transformer_options", {})
if x.ndim != 5:
raise ValueError(
f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. "
"This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)."
)
inner_model = model.inner_model.inner_model
causal_model = inner_model.diffusion_model
if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")):
raise TypeError(
"ar_video sampler requires a Causal-WAN compatible model whose diffusion_model "
"exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint "
"does not support this interface — choose a different sampler."
)
seed = extra_args.get("seed", 0)
bs, c, lat_t, lat_h, lat_w = x.shape
frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division
num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
device = x.device
model_dtype = inner_model.get_dtype()
kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype)
crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype)
output = torch.zeros_like(x)
s_in = x.new_ones([x.shape[0]])
current_start_frame = 0
num_sigma_steps = len(sigmas) - 1
total_real_steps = num_blocks * num_sigma_steps
step_count = 0
try:
for block_idx in trange(num_blocks, disable=disable):
bf = min(num_frame_per_block, lat_t - current_start_frame)
fs, fe = current_start_frame, current_start_frame + bf
noisy_input = x[:, :, fs:fe]
ar_state = {
"start_frame": current_start_frame,
"kv_caches": kv_caches,
"crossattn_caches": crossattn_caches,
}
transformer_options["ar_state"] = ar_state
for i in range(num_sigma_steps):
denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
if callback is not None:
scaled_i = step_count * num_sigma_steps // total_real_steps
callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
"sigma_hat": sigmas[i], "denoised": denoised})
if sigmas[i + 1] == 0:
noisy_input = denoised
else:
sigma_next = sigmas[i + 1]
torch.manual_seed(seed + block_idx * 1000 + i)
fresh_noise = torch.randn_like(denoised)
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
for cache in kv_caches:
cache["end"] -= bf * frame_seq_len
step_count += 1
output[:, :, fs:fe] = noisy_input
for cache in kv_caches:
cache["end"] -= bf * frame_seq_len
zero_sigma = sigmas.new_zeros([1])
_ = model(noisy_input, zero_sigma * s_in, **extra_args)
current_start_frame += bf
finally:
transformer_options.pop("ar_state", None)
return output

276
comfy/ldm/wan/ar_model.py Normal file
View File

@ -0,0 +1,276 @@
"""
CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
autoregressive (frame-by-frame) video generation via Causal Forcing.
Weight-compatible with the standard WanModel -- same layer names, same shapes.
The difference is purely in the forward pass: this model processes one temporal
block at a time and maintains a KV cache across blocks.
Reference: https://github.com/thu-ml/Causal-Forcing
"""
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.wan.model import (
sinusoidal_embedding_1d,
repeat_e,
WanModel,
WanAttentionBlock,
)
import comfy.ldm.common_dit
import comfy.model_management
class CausalWanSelfAttention(nn.Module):
"""Self-attention with KV cache support for autoregressive inference."""
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
eps=1e-6, operation_settings={}):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qk_norm = qk_norm
self.eps = eps
ops = operation_settings.get("operations")
device = operation_settings.get("device")
dtype = operation_settings.get("dtype")
self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
def forward(self, x, freqs, kv_cache=None, transformer_options={}):
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
v = self.v(x).view(b, s, n, d)
if kv_cache is None:
x = optimized_attention(
q.view(b, s, n * d),
k.view(b, s, n * d),
v.view(b, s, n * d),
heads=self.num_heads,
transformer_options=transformer_options,
)
else:
end = kv_cache["end"]
new_end = end + s
# Roped K and plain V go into cache
kv_cache["k"][:, end:new_end] = k
kv_cache["v"][:, end:new_end] = v
kv_cache["end"] = new_end
x = optimized_attention(
q.view(b, s, n * d),
kv_cache["k"][:, :new_end].view(b, new_end, n * d),
kv_cache["v"][:, :new_end].view(b, new_end, n * d),
heads=self.num_heads,
transformer_options=transformer_options,
)
x = self.o(x)
return x
class CausalWanAttentionBlock(WanAttentionBlock):
"""Transformer block with KV-cached self-attention and cross-attention caching."""
def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
eps=1e-6, operation_settings={}):
super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
window_size, qk_norm, cross_attn_norm, eps,
operation_settings=operation_settings)
self.self_attn = CausalWanSelfAttention(
dim, num_heads, window_size, qk_norm, eps,
operation_settings=operation_settings)
def forward(self, x, e, freqs, context, context_img_len=257,
kv_cache=None, crossattn_cache=None, transformer_options={}):
if e.ndim < 4:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
else:
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
# Self-attention with optional KV cache
x = x.contiguous()
y = self.self_attn(
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
freqs, kv_cache=kv_cache, transformer_options=transformer_options)
x = torch.addcmul(x, y, repeat_e(e[2], x))
del y
# Cross-attention with optional caching
if crossattn_cache is not None and crossattn_cache.get("is_init"):
q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
x_ca = optimized_attention(
q, crossattn_cache["k"], crossattn_cache["v"],
heads=self.num_heads, transformer_options=transformer_options)
x = x + self.cross_attn.o(x_ca)
else:
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
if crossattn_cache is not None:
crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
crossattn_cache["v"] = self.cross_attn.v(context)
crossattn_cache["is_init"] = True
# FFN
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
x = torch.addcmul(x, y, repeat_e(e[5], x))
return x
class CausalWanModel(WanModel):
"""
Wan 2.1 diffusion backbone with causal KV-cache support.
Same weight structure as WanModel -- loads identical state dicts.
Adds forward_block() for frame-by-frame autoregressive inference.
"""
def __init__(self,
model_type='t2v',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
image_model=None,
device=None,
dtype=None,
operations=None):
super().__init__(
model_type=model_type, patch_size=patch_size, text_len=text_len,
in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
wan_attn_block_class=CausalWanAttentionBlock,
device=device, dtype=dtype, operations=operations)
def forward_block(self, x, timestep, context, start_frame,
kv_caches, crossattn_caches, clip_fea=None):
"""
Forward one temporal block for autoregressive inference.
Args:
x: [B, C, block_frames, H, W] input latent for the current block
timestep: [B, block_frames] per-frame timesteps
context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
start_frame: temporal frame index for RoPE offset
kv_caches: list of per-layer KV cache dicts
crossattn_caches: list of per-layer cross-attention cache dicts
clip_fea: optional CLIP features for I2V
Returns:
flow_pred: [B, C_out, block_frames, H, W] flow prediction
"""
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
bs, c, t, h, w = x.shape
x = self.patch_embedding(x.float()).to(x.dtype)
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# Per-frame time embedding
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
# Text embedding (reuses crossattn_cache after first block)
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None and self.img_emb is not None:
context_clip = self.img_emb(clip_fea)
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
# RoPE for current block's temporal position
freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype)
# Transformer blocks
for i, block in enumerate(self.blocks):
x = block(x, e=e0, freqs=freqs, context=context,
context_img_len=context_img_len,
kv_cache=kv_caches[i],
crossattn_cache=crossattn_caches[i])
# Head
x = self.head(x, e)
# Unpatchify
x = self.unpatchify(x, grid_sizes)
return x[:, :, :t, :h, :w]
def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
"""Create fresh KV caches for all layers."""
caches = []
for _ in range(self.num_layers):
caches.append({
"k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
"v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
"end": 0,
})
return caches
def init_crossattn_caches(self, batch_size, device, dtype):
"""Create fresh cross-attention caches for all layers."""
caches = []
for _ in range(self.num_layers):
caches.append({"is_init": False})
return caches
def reset_kv_caches(self, kv_caches):
"""Reset KV caches to empty (reuse allocated memory)."""
for cache in kv_caches:
cache["end"] = 0
def reset_crossattn_caches(self, crossattn_caches):
"""Reset cross-attention caches."""
for cache in crossattn_caches:
cache["is_init"] = False
@property
def head_dim(self):
return self.dim // self.num_heads
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
ar_state = transformer_options.get("ar_state")
if ar_state is not None:
bs = x.shape[0]
block_frames = x.shape[2]
t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames)
return self.forward_block(
x=x, timestep=t_per_frame, context=context,
start_frame=ar_state["start_frame"],
kv_caches=ar_state["kv_caches"],
crossattn_caches=ar_state["crossattn_caches"],
clip_fea=clip_fea,
)
return super().forward(x, timestep, context, clip_fea=clip_fea,
time_dim_concat=time_dim_concat,
transformer_options=transformer_options, **kwargs)

View File

@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2
import comfy.ldm.lumina.model
import comfy.ldm.wan.model
import comfy.ldm.wan.model_animate
import comfy.ldm.wan.ar_model
import comfy.ldm.hunyuan3d.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
@ -1365,6 +1366,13 @@ class WAN21(BaseModel):
return out
class WAN21_CausalAR(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device,
unet_model=comfy.ldm.wan.ar_model.CausalWanModel)
self.image_to_video = False
class WAN21_Vace(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)

View File

@ -1,6 +1,8 @@
import torch
import logging
from comfy.cli_args import args
try:
import comfy_kitchen as ck
from comfy_kitchen.tensor import (
@ -21,7 +23,15 @@ try:
ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
ck.registry.disable("triton")
if args.enable_triton_backend:
try:
import triton
logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__)
except ImportError as e:
logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.")
ck.registry.disable("triton")
else:
ck.registry.disable("triton")
for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}")
except ImportError as e:

View File

@ -1167,6 +1167,25 @@ class WAN21_T2V(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
class WAN21_CausalAR_T2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "t2v",
"causal_ar": True,
}
sampling_settings = {
"shift": 5.0,
}
def __init__(self, unet_config):
super().__init__(unet_config)
self.unet_config.pop("causal_ar", None)
def get_model(self, state_dict, prefix="", device=None):
return model_base.WAN21_CausalAR(self, device=device)
class WAN21_I2V(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@ -1929,6 +1948,7 @@ models = [
ZImage,
Lumina2,
WAN22_T2V,
WAN21_CausalAR_T2V,
WAN21_T2V,
WAN21_I2V,
WAN21_FunControl2V,

View File

@ -0,0 +1,84 @@
"""
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
"""
import torch
from typing_extensions import override
import comfy.model_management
import comfy.samplers
from comfy_api.latest import ComfyExtension, io
class EmptyARVideoLatent(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="EmptyARVideoLatent",
category="latent/video",
inputs=[
io.Int.Input("width", default=832, min=16, max=8192, step=16),
io.Int.Input("height", default=480, min=16, max=8192, step=16),
io.Int.Input("length", default=81, min=1, max=1024, step=4),
io.Int.Input("batch_size", default=1, min=1, max=64),
],
outputs=[
io.Latent.Output(display_name="LATENT"),
],
)
@classmethod
def execute(cls, width, height, length, batch_size) -> io.NodeOutput:
lat_t = ((length - 1) // 4) + 1
latent = torch.zeros(
[batch_size, 16, lat_t, height // 8, width // 8],
device=comfy.model_management.intermediate_device(),
)
return io.NodeOutput({"samples": latent})
class SamplerARVideo(io.ComfyNode):
"""Sampler for autoregressive video models (Causal Forcing, Self-Forcing).
All AR-loop parameters are owned by this node so they live in the workflow.
Add new widgets here as the AR sampler grows new options.
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SamplerARVideo",
display_name="Sampler AR Video",
category="sampling/custom_sampling/samplers",
inputs=[
io.Int.Input(
"num_frame_per_block",
default=1, min=1, max=64,
tooltip="Frames per autoregressive block. 1 = framewise, "
"3 = chunkwise. Must match the checkpoint's training mode.",
),
],
outputs=[io.Sampler.Output()],
)
@classmethod
def execute(cls, num_frame_per_block) -> io.NodeOutput:
extra_options = {
"num_frame_per_block": num_frame_per_block,
}
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
class ARVideoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
EmptyARVideoLatent,
SamplerARVideo,
]
async def comfy_entrypoint() -> ARVideoExtension:
return ARVideoExtension()

View File

@ -202,14 +202,11 @@ class JoinImageWithAlpha(io.ComfyNode):
@classmethod
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = min(len(image), len(alpha))
out_images = []
batch_size = max(len(image), len(alpha))
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
return io.NodeOutput(torch.stack(out_images))
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
image = comfy.utils.repeat_to_batch_size(image, batch_size)
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
class CompositingExtension(ComfyExtension):

View File

@ -86,6 +86,6 @@ def image_alpha_fix(destination, source):
if destination.shape[-1] < source.shape[-1]:
source = source[...,:destination.shape[-1]]
elif destination.shape[-1] > source.shape[-1]:
destination = torch.nn.functional.pad(destination, (0, 1))
destination[..., -1] = 1.0
source = torch.nn.functional.pad(source, (0, 1))
source[..., -1] = 1.0
return destination, source

View File

@ -2419,6 +2419,7 @@ async def init_builtin_extra_nodes():
"nodes_nop.py",
"nodes_kandinsky5.py",
"nodes_wanmove.py",
"nodes_ar_video.py",
"nodes_image_compare.py",
"nodes_zimage.py",
"nodes_glsl.py",

View File

@ -1,3 +1,4 @@
import errno
import os
import sys
import asyncio
@ -1245,7 +1246,13 @@ class PromptServer():
address = addr[0]
port = addr[1]
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
await site.start()
try:
await site.start()
except OSError as e:
if e.errno == errno.EADDRINUSE:
logging.error(f"Port {port} is already in use on address {address}. Please close the other application or use a different port with --port.")
raise SystemExit(1)
raise
if not hasattr(self, 'address'):
self.address = address #TODO: remove this