mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-23 23:47:25 +08:00
Create a dedicated node for ar_sampler.
This commit is contained in:
parent
8ad8d101a1
commit
fc303cb2cf
@ -1813,18 +1813,21 @@ def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disa
|
|||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None):
|
def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None,
|
||||||
|
num_frame_per_block=1):
|
||||||
"""
|
"""
|
||||||
Autoregressive video sampler: block-by-block denoising with KV cache
|
Autoregressive video sampler: block-by-block denoising with KV cache
|
||||||
and flow-match re-noising for Causal Forcing / Self-Forcing models.
|
and flow-match re-noising for Causal Forcing / Self-Forcing models.
|
||||||
|
|
||||||
Requires a Causal-WAN compatible model (diffusion_model must expose
|
Requires a Causal-WAN compatible model (diffusion_model must expose
|
||||||
init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].
|
init_kv_caches / init_crossattn_caches) and 5-D latents [B,C,T,H,W].
|
||||||
|
|
||||||
|
All AR-loop parameters are passed via the SamplerARVideo node, not read
|
||||||
|
from the checkpoint or transformer_options.
|
||||||
"""
|
"""
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
model_options = extra_args.get("model_options", {})
|
model_options = extra_args.get("model_options", {})
|
||||||
transformer_options = model_options.get("transformer_options", {})
|
transformer_options = model_options.get("transformer_options", {})
|
||||||
ar_config = transformer_options.get("ar_config", {})
|
|
||||||
|
|
||||||
if x.ndim != 5:
|
if x.ndim != 5:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -1842,7 +1845,6 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
"does not support this interface — choose a different sampler."
|
"does not support this interface — choose a different sampler."
|
||||||
)
|
)
|
||||||
|
|
||||||
num_frame_per_block = ar_config.get("num_frame_per_block", 1)
|
|
||||||
seed = extra_args.get("seed", 0)
|
seed = extra_args.get("seed", 0)
|
||||||
|
|
||||||
bs, c, lat_t, lat_h, lat_w = x.shape
|
bs, c, lat_t, lat_h, lat_w = x.shape
|
||||||
|
|||||||
@ -719,15 +719,11 @@ class Sampler:
|
|||||||
sigma = float(sigmas[0])
|
sigma = float(sigmas[0])
|
||||||
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma
|
||||||
|
|
||||||
# "ar_video" is model-specific (requires Causal-WAN KV-cache interface + 5-D latents)
|
|
||||||
# but is kept here so it appears in standard sampler dropdowns; sample_ar_video
|
|
||||||
# validates at runtime and raises a clear error for incompatible checkpoints.
|
|
||||||
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
|
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2", "exp_heun_2_x0", "exp_heun_2_x0_sde", "dpm_2", "dpm_2_ancestral",
|
||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||||
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece",
|
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
|
||||||
"ar_video"]
|
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||||
|
|||||||
@ -1,12 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
||||||
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
|
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
|
||||||
|
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import comfy.samplers
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
@ -37,11 +39,44 @@ class EmptyARVideoLatent(io.ComfyNode):
|
|||||||
return io.NodeOutput({"samples": latent})
|
return io.NodeOutput({"samples": latent})
|
||||||
|
|
||||||
|
|
||||||
|
class SamplerARVideo(io.ComfyNode):
    """Node that outputs a SAMPLER built around the "ar_video" sampler name.

    The node owns the autoregressive-loop settings as widgets so that they are
    stored with the workflow. At the moment there is a single widget,
    ``num_frame_per_block``; further options would be added as extra inputs
    in ``define_schema`` and forwarded from ``execute``.
    """

    @classmethod
    def define_schema(cls):
        """Return the node schema: one integer input and one Sampler output."""
        # Integer widget constrained to 1..64, defaulting to 1.
        frames_per_block_widget = io.Int.Input(
            "num_frame_per_block",
            default=1,
            min=1,
            max=64,
            tooltip=(
                "Frames per autoregressive block. 1 = framewise, "
                "3 = chunkwise. Must match the checkpoint's training mode."
            ),
        )
        sampler_socket = io.Sampler.Output()

        return io.Schema(
            node_id="SamplerARVideo",
            display_name="Sampler AR Video",
            category="sampling/custom_sampling/samplers",
            inputs=[frames_per_block_widget],
            outputs=[sampler_socket],
        )

    @classmethod
    def execute(cls, num_frame_per_block) -> io.NodeOutput:
        """Build the "ar_video" sampler object and wrap it in a NodeOutput.

        The widget value is handed to ``comfy.samplers.ksampler`` as the
        second positional argument, inside a dict keyed by
        ``"num_frame_per_block"``.
        """
        sampler = comfy.samplers.ksampler(
            "ar_video",
            {"num_frame_per_block": num_frame_per_block},
        )
        return io.NodeOutput(sampler)
|
||||||
|
|
||||||
|
|
||||||
class ARVideoExtension(ComfyExtension):
|
class ARVideoExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
EmptyARVideoLatent,
|
EmptyARVideoLatent,
|
||||||
|
SamplerARVideo,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user