mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-29 21:13:33 +08:00
Merge 2841684700 into fc1fdf3389
This commit is contained in:
commit
de28e26214
@ -1810,3 +1810,100 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
||||
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
|
||||
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
|
||||
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
||||
"""
|
||||
Autoregressive video sampler: block-by-block denoising with KV cache
|
||||
and flow-match re-noising for Causal Forcing / Self-Forcing models.
|
||||
|
||||
Requires a Causal-WAN compatible model (diffusion_model must expose
|
||||
init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].
|
||||
"""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
model_options = extra_args.get("model_options", {})
|
||||
transformer_options = model_options.get("transformer_options", {})
|
||||
ar_config = transformer_options.get("ar_config", {})
|
||||
|
||||
if x.ndim != 5:
|
||||
raise ValueError(
|
||||
f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. "
|
||||
"This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)."
|
||||
)
|
||||
|
||||
inner_model = model.inner_model.inner_model
|
||||
causal_model = inner_model.diffusion_model
|
||||
|
||||
if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")):
|
||||
raise TypeError(
|
||||
"ar_video sampler requires a Causal-WAN compatible model whose diffusion_model "
|
||||
"exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint "
|
||||
"does not support this interface — choose a different sampler."
|
||||
)
|
||||
|
||||
num_frame_per_block = ar_config.get("num_frame_per_block", 1)
|
||||
seed = extra_args.get("seed", 0)
|
||||
|
||||
bs, c, lat_t, lat_h, lat_w = x.shape
|
||||
frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division
|
||||
num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
|
||||
device = x.device
|
||||
model_dtype = inner_model.get_dtype()
|
||||
|
||||
kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype)
|
||||
crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype)
|
||||
|
||||
output = torch.zeros_like(x)
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
current_start_frame = 0
|
||||
num_sigma_steps = len(sigmas) - 1
|
||||
total_real_steps = num_blocks * num_sigma_steps
|
||||
step_count = 0
|
||||
|
||||
try:
|
||||
for block_idx in trange(num_blocks, disable=disable):
|
||||
bf = min(num_frame_per_block, lat_t - current_start_frame)
|
||||
fs, fe = current_start_frame, current_start_frame + bf
|
||||
noisy_input = x[:, :, fs:fe]
|
||||
|
||||
ar_state = {
|
||||
"start_frame": current_start_frame,
|
||||
"kv_caches": kv_caches,
|
||||
"crossattn_caches": crossattn_caches,
|
||||
}
|
||||
transformer_options["ar_state"] = ar_state
|
||||
|
||||
for i in range(num_sigma_steps):
|
||||
denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
|
||||
|
||||
if callback is not None:
|
||||
scaled_i = step_count * num_sigma_steps // total_real_steps
|
||||
callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
|
||||
"sigma_hat": sigmas[i], "denoised": denoised})
|
||||
|
||||
if sigmas[i + 1] == 0:
|
||||
noisy_input = denoised
|
||||
else:
|
||||
sigma_next = sigmas[i + 1]
|
||||
torch.manual_seed(seed + block_idx * 1000 + i)
|
||||
fresh_noise = torch.randn_like(denoised)
|
||||
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
|
||||
|
||||
for cache in kv_caches:
|
||||
cache["end"] -= bf * frame_seq_len
|
||||
|
||||
step_count += 1
|
||||
|
||||
output[:, :, fs:fe] = noisy_input
|
||||
|
||||
for cache in kv_caches:
|
||||
cache["end"] -= bf * frame_seq_len
|
||||
zero_sigma = sigmas.new_zeros([1])
|
||||
_ = model(noisy_input, zero_sigma * s_in, **extra_args)
|
||||
|
||||
current_start_frame += bf
|
||||
finally:
|
||||
transformer_options.pop("ar_state", None)
|
||||
|
||||
return output
|
||||
|
||||
276
comfy/ldm/wan/ar_model.py
Normal file
276
comfy/ldm/wan/ar_model.py
Normal file
@ -0,0 +1,276 @@
|
||||
"""
|
||||
CausalWanModel: Wan 2.1 backbone with KV-cached causal self-attention for
|
||||
autoregressive (frame-by-frame) video generation via Causal Forcing.
|
||||
|
||||
Weight-compatible with the standard WanModel -- same layer names, same shapes.
|
||||
The difference is purely in the forward pass: this model processes one temporal
|
||||
block at a time and maintains a KV cache across blocks.
|
||||
|
||||
Reference: https://github.com/thu-ml/Causal-Forcing
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.wan.model import (
|
||||
sinusoidal_embedding_1d,
|
||||
repeat_e,
|
||||
WanModel,
|
||||
WanAttentionBlock,
|
||||
)
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class CausalWanSelfAttention(nn.Module):
|
||||
"""Self-attention with KV cache support for autoregressive inference."""
|
||||
|
||||
def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True,
|
||||
eps=1e-6, operation_settings={}):
|
||||
assert dim % num_heads == 0
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qk_norm = qk_norm
|
||||
self.eps = eps
|
||||
|
||||
ops = operation_settings.get("operations")
|
||||
device = operation_settings.get("device")
|
||||
dtype = operation_settings.get("dtype")
|
||||
|
||||
self.q = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.k = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.v = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.o = ops.Linear(dim, dim, device=device, dtype=dtype)
|
||||
self.norm_q = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||
self.norm_k = ops.RMSNorm(dim, eps=eps, elementwise_affine=True, device=device, dtype=dtype) if qk_norm else nn.Identity()
|
||||
|
||||
def forward(self, x, freqs, kv_cache=None, transformer_options={}):
|
||||
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
||||
|
||||
q = apply_rope1(self.norm_q(self.q(x)).view(b, s, n, d), freqs)
|
||||
k = apply_rope1(self.norm_k(self.k(x)).view(b, s, n, d), freqs)
|
||||
v = self.v(x).view(b, s, n, d)
|
||||
|
||||
if kv_cache is None:
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
k.view(b, s, n * d),
|
||||
v.view(b, s, n * d),
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
else:
|
||||
end = kv_cache["end"]
|
||||
new_end = end + s
|
||||
|
||||
# Roped K and plain V go into cache
|
||||
kv_cache["k"][:, end:new_end] = k
|
||||
kv_cache["v"][:, end:new_end] = v
|
||||
kv_cache["end"] = new_end
|
||||
|
||||
x = optimized_attention(
|
||||
q.view(b, s, n * d),
|
||||
kv_cache["k"][:, :new_end].view(b, new_end, n * d),
|
||||
kv_cache["v"][:, :new_end].view(b, new_end, n * d),
|
||||
heads=self.num_heads,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
|
||||
x = self.o(x)
|
||||
return x
|
||||
|
||||
|
||||
class CausalWanAttentionBlock(WanAttentionBlock):
|
||||
"""Transformer block with KV-cached self-attention and cross-attention caching."""
|
||||
|
||||
def __init__(self, cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size=(-1, -1), qk_norm=True, cross_attn_norm=False,
|
||||
eps=1e-6, operation_settings={}):
|
||||
super().__init__(cross_attn_type, dim, ffn_dim, num_heads,
|
||||
window_size, qk_norm, cross_attn_norm, eps,
|
||||
operation_settings=operation_settings)
|
||||
self.self_attn = CausalWanSelfAttention(
|
||||
dim, num_heads, window_size, qk_norm, eps,
|
||||
operation_settings=operation_settings)
|
||||
|
||||
def forward(self, x, e, freqs, context, context_img_len=257,
|
||||
kv_cache=None, crossattn_cache=None, transformer_options={}):
|
||||
if e.ndim < 4:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
|
||||
else:
|
||||
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e).unbind(2)
|
||||
|
||||
# Self-attention with optional KV cache
|
||||
x = x.contiguous()
|
||||
y = self.self_attn(
|
||||
torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)),
|
||||
freqs, kv_cache=kv_cache, transformer_options=transformer_options)
|
||||
x = torch.addcmul(x, y, repeat_e(e[2], x))
|
||||
del y
|
||||
|
||||
# Cross-attention with optional caching
|
||||
if crossattn_cache is not None and crossattn_cache.get("is_init"):
|
||||
q = self.cross_attn.norm_q(self.cross_attn.q(self.norm3(x)))
|
||||
x_ca = optimized_attention(
|
||||
q, crossattn_cache["k"], crossattn_cache["v"],
|
||||
heads=self.num_heads, transformer_options=transformer_options)
|
||||
x = x + self.cross_attn.o(x_ca)
|
||||
else:
|
||||
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len, transformer_options=transformer_options)
|
||||
if crossattn_cache is not None:
|
||||
crossattn_cache["k"] = self.cross_attn.norm_k(self.cross_attn.k(context))
|
||||
crossattn_cache["v"] = self.cross_attn.v(context)
|
||||
crossattn_cache["is_init"] = True
|
||||
|
||||
# FFN
|
||||
y = self.ffn(torch.addcmul(repeat_e(e[3], x), self.norm2(x), 1 + repeat_e(e[4], x)))
|
||||
x = torch.addcmul(x, y, repeat_e(e[5], x))
|
||||
return x
|
||||
|
||||
|
||||
class CausalWanModel(WanModel):
|
||||
"""
|
||||
Wan 2.1 diffusion backbone with causal KV-cache support.
|
||||
|
||||
Same weight structure as WanModel -- loads identical state dicts.
|
||||
Adds forward_block() for frame-by-frame autoregressive inference.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_type='t2v',
|
||||
patch_size=(1, 2, 2),
|
||||
text_len=512,
|
||||
in_dim=16,
|
||||
dim=2048,
|
||||
ffn_dim=8192,
|
||||
freq_dim=256,
|
||||
text_dim=4096,
|
||||
out_dim=16,
|
||||
num_heads=16,
|
||||
num_layers=32,
|
||||
window_size=(-1, -1),
|
||||
qk_norm=True,
|
||||
cross_attn_norm=True,
|
||||
eps=1e-6,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
operations=None):
|
||||
super().__init__(
|
||||
model_type=model_type, patch_size=patch_size, text_len=text_len,
|
||||
in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim,
|
||||
text_dim=text_dim, out_dim=out_dim, num_heads=num_heads,
|
||||
num_layers=num_layers, window_size=window_size, qk_norm=qk_norm,
|
||||
cross_attn_norm=cross_attn_norm, eps=eps, image_model=image_model,
|
||||
wan_attn_block_class=CausalWanAttentionBlock,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
|
||||
def forward_block(self, x, timestep, context, start_frame,
|
||||
kv_caches, crossattn_caches, clip_fea=None):
|
||||
"""
|
||||
Forward one temporal block for autoregressive inference.
|
||||
|
||||
Args:
|
||||
x: [B, C, block_frames, H, W] input latent for the current block
|
||||
timestep: [B, block_frames] per-frame timesteps
|
||||
context: [B, L, text_dim] raw text embeddings (pre-text_embedding)
|
||||
start_frame: temporal frame index for RoPE offset
|
||||
kv_caches: list of per-layer KV cache dicts
|
||||
crossattn_caches: list of per-layer cross-attention cache dicts
|
||||
clip_fea: optional CLIP features for I2V
|
||||
|
||||
Returns:
|
||||
flow_pred: [B, C_out, block_frames, H, W] flow prediction
|
||||
"""
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
bs, c, t, h, w = x.shape
|
||||
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
# Per-frame time embedding
|
||||
e = self.time_embedding(
|
||||
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
|
||||
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
|
||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||
|
||||
# Text embedding (reuses crossattn_cache after first block)
|
||||
context = self.text_embedding(context)
|
||||
|
||||
context_img_len = None
|
||||
if clip_fea is not None and self.img_emb is not None:
|
||||
context_clip = self.img_emb(clip_fea)
|
||||
context = torch.concat([context_clip, context], dim=1)
|
||||
context_img_len = clip_fea.shape[-2]
|
||||
|
||||
# RoPE for current block's temporal position
|
||||
freqs = self.rope_encode(t, h, w, t_start=start_frame, device=x.device, dtype=x.dtype)
|
||||
|
||||
# Transformer blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
x = block(x, e=e0, freqs=freqs, context=context,
|
||||
context_img_len=context_img_len,
|
||||
kv_cache=kv_caches[i],
|
||||
crossattn_cache=crossattn_caches[i])
|
||||
|
||||
# Head
|
||||
x = self.head(x, e)
|
||||
|
||||
# Unpatchify
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x[:, :, :t, :h, :w]
|
||||
|
||||
def init_kv_caches(self, batch_size, max_seq_len, device, dtype):
|
||||
"""Create fresh KV caches for all layers."""
|
||||
caches = []
|
||||
for _ in range(self.num_layers):
|
||||
caches.append({
|
||||
"k": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||
"v": torch.zeros(batch_size, max_seq_len, self.num_heads, self.head_dim, device=device, dtype=dtype),
|
||||
"end": 0,
|
||||
})
|
||||
return caches
|
||||
|
||||
def init_crossattn_caches(self, batch_size, device, dtype):
|
||||
"""Create fresh cross-attention caches for all layers."""
|
||||
caches = []
|
||||
for _ in range(self.num_layers):
|
||||
caches.append({"is_init": False})
|
||||
return caches
|
||||
|
||||
def reset_kv_caches(self, kv_caches):
|
||||
"""Reset KV caches to empty (reuse allocated memory)."""
|
||||
for cache in kv_caches:
|
||||
cache["end"] = 0
|
||||
|
||||
def reset_crossattn_caches(self, crossattn_caches):
|
||||
"""Reset cross-attention caches."""
|
||||
for cache in crossattn_caches:
|
||||
cache["is_init"] = False
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.dim // self.num_heads
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
ar_state = transformer_options.get("ar_state")
|
||||
if ar_state is not None:
|
||||
bs = x.shape[0]
|
||||
block_frames = x.shape[2]
|
||||
t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames)
|
||||
return self.forward_block(
|
||||
x=x, timestep=t_per_frame, context=context,
|
||||
start_frame=ar_state["start_frame"],
|
||||
kv_caches=ar_state["kv_caches"],
|
||||
crossattn_caches=ar_state["crossattn_caches"],
|
||||
clip_fea=clip_fea,
|
||||
)
|
||||
|
||||
return super().forward(x, timestep, context, clip_fea=clip_fea,
|
||||
time_dim_concat=time_dim_concat,
|
||||
transformer_options=transformer_options, **kwargs)
|
||||
@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2
|
||||
import comfy.ldm.lumina.model
|
||||
import comfy.ldm.wan.model
|
||||
import comfy.ldm.wan.model_animate
|
||||
import comfy.ldm.wan.ar_model
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
@ -1353,6 +1354,13 @@ class WAN21(BaseModel):
|
||||
return out
|
||||
|
||||
|
||||
class WAN21_CausalAR(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.wan.ar_model.CausalWanModel)
|
||||
self.image_to_video = False
|
||||
|
||||
|
||||
class WAN21_Vace(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
|
||||
|
||||
@ -719,11 +719,15 @@ class Sampler:
|
||||
sigma = float(sigmas[0])
|
||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||
|
||||
# "ar_video" is model-specific (requires Causal-WAN KV-cache interface + 5-D latents)
|
||||
# but is kept here so it appears in standard sampler dropdowns; sample_ar_video
|
||||
# validates at runtime and raises a clear error for incompatible checkpoints.
|
||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
|
||||
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece",
|
||||
"ar_video"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||
|
||||
@ -1165,6 +1165,25 @@ class WAN21_T2V(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
|
||||
|
||||
class WAN21_CausalAR_T2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "t2v",
|
||||
"causal_ar": True,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 5.0,
|
||||
}
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.unet_config.pop("causal_ar", None)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.WAN21_CausalAR(self, device=device)
|
||||
|
||||
|
||||
class WAN21_I2V(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -1734,6 +1753,6 @@ class LongCatImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_CausalAR_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
49
comfy_extras/nodes_ar_video.py
Normal file
49
comfy_extras/nodes_ar_video.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""
|
||||
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
||||
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
|
||||
"""
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
class EmptyARVideoLatent(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="EmptyARVideoLatent",
|
||||
category="latent/video",
|
||||
inputs=[
|
||||
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=1024, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||
],
|
||||
outputs=[
|
||||
io.Latent.Output(display_name="LATENT"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, width, height, length, batch_size) -> io.NodeOutput:
|
||||
lat_t = ((length - 1) // 4) + 1
|
||||
latent = torch.zeros(
|
||||
[batch_size, 16, lat_t, height // 8, width // 8],
|
||||
device=comfy.model_management.intermediate_device(),
|
||||
)
|
||||
return io.NodeOutput({"samples": latent})
|
||||
|
||||
|
||||
class ARVideoExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
EmptyARVideoLatent,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ARVideoExtension:
|
||||
return ARVideoExtension()
|
||||
Loading…
Reference in New Issue
Block a user