mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-29 21:13:33 +08:00
Add better error handling for a custom ar_video sampler.
This commit is contained in:
parent
e9cf4659d2
commit
2841684700
@ -1817,21 +1817,37 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
"""
|
"""
|
||||||
Autoregressive video sampler: block-by-block denoising with KV cache
|
Autoregressive video sampler: block-by-block denoising with KV cache
|
||||||
and flow-match re-noising for Causal Forcing / Self-Forcing models.
|
and flow-match re-noising for Causal Forcing / Self-Forcing models.
|
||||||
|
|
||||||
|
Requires a Causal-WAN compatible model (diffusion_model must expose
|
||||||
|
init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].
|
||||||
"""
|
"""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
model_options = extra_args.get("model_options", {})
|
model_options = extra_args.get("model_options", {})
|
||||||
transformer_options = model_options.get("transformer_options", {})
|
transformer_options = model_options.get("transformer_options", {})
|
||||||
ar_config = transformer_options.get("ar_config", {})
|
ar_config = transformer_options.get("ar_config", {})
|
||||||
|
|
||||||
|
if x.ndim != 5:
|
||||||
|
raise ValueError(
|
||||||
|
f"ar_video sampler requires 5-D video latents [B,C,T,H,W], got {x.ndim}-D tensor with shape {x.shape}. "
|
||||||
|
"This sampler is only compatible with autoregressive video models (e.g. Causal-WAN)."
|
||||||
|
)
|
||||||
|
|
||||||
|
inner_model = model.inner_model.inner_model
|
||||||
|
causal_model = inner_model.diffusion_model
|
||||||
|
|
||||||
|
if not (hasattr(causal_model, "init_kv_caches") and hasattr(causal_model, "init_crossattn_caches")):
|
||||||
|
raise TypeError(
|
||||||
|
"ar_video sampler requires a Causal-WAN compatible model whose diffusion_model "
|
||||||
|
"exposes init_kv_caches() and init_crossattn_caches(). The loaded checkpoint "
|
||||||
|
"does not support this interface — choose a different sampler."
|
||||||
|
)
|
||||||
|
|
||||||
num_frame_per_block = ar_config.get("num_frame_per_block", 1)
|
num_frame_per_block = ar_config.get("num_frame_per_block", 1)
|
||||||
seed = extra_args.get("seed", 0)
|
seed = extra_args.get("seed", 0)
|
||||||
|
|
||||||
bs, c, lat_t, lat_h, lat_w = x.shape
|
bs, c, lat_t, lat_h, lat_w = x.shape
|
||||||
frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division
|
frame_seq_len = -(-lat_h // 2) * -(-lat_w // 2) # ceiling division
|
||||||
num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
|
num_blocks = -(-lat_t // num_frame_per_block) # ceiling division
|
||||||
|
|
||||||
inner_model = model.inner_model.inner_model
|
|
||||||
causal_model = inner_model.diffusion_model
|
|
||||||
device = x.device
|
device = x.device
|
||||||
model_dtype = inner_model.get_dtype()
|
model_dtype = inner_model.get_dtype()
|
||||||
|
|
||||||
|
|||||||
@ -719,6 +719,9 @@ class Sampler:
|
|||||||
sigma = float(sigmas[0])
|
sigma = float(sigmas[0])
|
||||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||||
|
|
||||||
|
# "ar_video" is model-specific (requires Causal-WAN KV-cache interface + 5-D latents)
|
||||||
|
# but is kept here so it appears in standard sampler dropdowns; sample_ar_video
|
||||||
|
# validates at runtime and raises a clear error for incompatible checkpoints.
|
||||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user