mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 21:43:43 +08:00
Rewrite causal forcing using custom sampler with KSampler node.
This commit is contained in:
parent
6f9af338ae
commit
3a9547192e
@ -1810,3 +1810,84 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
|
|||||||
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
|
def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False):
|
||||||
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
|
"""Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023)."""
|
||||||
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
|
return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2)
|
||||||
|
|
||||||
|
|
||||||
|
def _rewind_ar_kv_caches(kv_caches, num_tokens):
    """Move the ``end`` pointer of every KV cache back by ``num_tokens``.

    The update is done in place on the pointer tensor (no ``.item()`` round
    trip), so the next model call writes to the same cache positions again.
    """
    for cache in kv_caches:
        cache["end"].sub_(num_tokens)


def _seeded_noise_like(reference, seed):
    """Return Gaussian noise shaped like ``reference`` from a private generator.

    A local ``torch.Generator`` is used instead of ``torch.manual_seed`` so the
    process-wide RNG state is left untouched. The seed is wrapped to 64 bits
    because ``Generator.manual_seed`` rejects larger values.
    """
    generator = torch.Generator(device=reference.device)
    generator.manual_seed(seed % (1 << 64))
    return torch.randn(
        reference.shape,
        generator=generator,
        dtype=reference.dtype,
        layout=reference.layout,
        device=reference.device,
    )


@torch.no_grad()
def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """
    Autoregressive video sampler: block-by-block denoising with KV cache
    and flow-match re-noising for Causal Forcing / Self-Forcing models.

    The latent is split along its frame axis into blocks of
    ``num_frame_per_block`` frames (read from
    ``model_options["transformer_options"]["ar_config"]``, default 1). Each
    block is run through every sigma step; after each step the prediction is
    re-noised to the next sigma with freshly seeded noise. Once a block is
    finished the model is called one more time at sigma 0 so the KV cache
    holds the keys/values of the final block before the next block starts.

    Args:
        model: Denoiser wrapper. It is called as ``model(x, sigma, **extra_args)``
            and must expose ``model.inner_model.inner_model`` with a
            ``diffusion_model`` (providing ``init_kv_caches`` /
            ``init_crossattn_caches``) and a ``get_dtype()`` method.
        x: 5-D noise latent, unpacked as (batch, channels, frames, height, width).
        sigmas: 1-D sigma schedule; ``len(sigmas) - 1`` steps are run per block.
        extra_args: Extra keyword arguments forwarded to ``model``. ``seed`` and
            ``model_options`` are read from it.
        callback: Optional per-step callable receiving a dict with the keys
            ``x``, ``i``, ``sigma``, ``sigma_hat`` and ``denoised``.
        disable: Forwarded to ``trange`` to disable the progress bar.

    Returns:
        Tensor with the same shape as ``x`` holding the sampled blocks.

    Raises:
        ValueError: If ``num_frame_per_block`` is not positive or does not
            evenly divide the number of latent frames.
    """
    extra_args = {} if extra_args is None else extra_args
    # setdefault (rather than .get(..., {})) guarantees that the dict receiving
    # "ar_state" below is the very dict that is forwarded to the model.
    model_options = extra_args.setdefault("model_options", {})
    transformer_options = model_options.setdefault("transformer_options", {})
    ar_config = transformer_options.get("ar_config", {})

    num_frame_per_block = ar_config.get("num_frame_per_block", 1)
    seed = extra_args.get("seed")
    seed = 0 if seed is None else seed

    bs, _, lat_t, lat_h, lat_w = x.shape
    if num_frame_per_block < 1:
        raise ValueError(f"num_frame_per_block must be >= 1, got {num_frame_per_block}")
    if lat_t % num_frame_per_block != 0:
        # Without this check the trailing frames would silently stay zero.
        raise ValueError(
            f"Latent frames ({lat_t}) must be divisible by num_frame_per_block ({num_frame_per_block})"
        )

    # Tokens per latent frame; assumes a 2x2 spatial patch -- TODO confirm
    # against the diffusion model's patch size.
    frame_seq_len = (lat_h // 2) * (lat_w // 2)
    num_blocks = lat_t // num_frame_per_block
    block_tokens = num_frame_per_block * frame_seq_len

    inner_model = model.inner_model.inner_model
    causal_model = inner_model.diffusion_model
    device = x.device
    model_dtype = inner_model.get_dtype()

    kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype)
    crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype)

    output = torch.zeros_like(x)
    s_in = x.new_ones([x.shape[0]])
    zero_sigma = sigmas.new_zeros([1]) * s_in
    num_sigma_steps = len(sigmas) - 1
    total_real_steps = num_blocks * num_sigma_steps
    step_count = 0

    try:
        for block_idx in trange(num_blocks, disable=disable):
            fs = block_idx * num_frame_per_block
            fe = fs + num_frame_per_block
            noisy_input = x[:, :, fs:fe]

            transformer_options["ar_state"] = {
                "start_frame": fs,
                "kv_caches": kv_caches,
                "crossattn_caches": crossattn_caches,
            }

            for i in range(num_sigma_steps):
                denoised = model(noisy_input, sigmas[i] * s_in, **extra_args)
                # Every denoising call is rewound right away so that the next
                # call (another step, or the sigma-0 cache pass below) writes
                # to the same positions. Rewinding exactly once per call keeps
                # the pointer correct even if the schedule does not end at 0.
                _rewind_ar_kv_caches(kv_caches, block_tokens)

                if callback is not None:
                    # Scale step_count to [0, num_sigma_steps) so the progress bar fills gradually
                    scaled_i = step_count * num_sigma_steps // total_real_steps
                    callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i],
                              "sigma_hat": sigmas[i], "denoised": denoised})

                sigma_next = sigmas[i + 1]
                if sigma_next == 0:
                    noisy_input = denoised
                else:
                    # Flow-match re-noising, deterministic per (seed, block, step).
                    fresh_noise = _seeded_noise_like(denoised, seed + block_idx * 1000 + i)
                    noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise

                step_count += 1

            output[:, :, fs:fe] = noisy_input

            # Cache update: run model at t=0 with the block's final output to fill KV cache
            model(noisy_input, zero_sigma, **extra_args)
    finally:
        # Always drop the per-run state (it references the KV caches), even
        # when sampling is interrupted or the model raises.
        transformer_options.pop("ar_state", None)

    return output
|
||||||
|
|||||||
@ -281,7 +281,7 @@ class CausalWanModel(torch.nn.Module):
|
|||||||
|
|
||||||
# Per-frame time embedding → [B, block_frames, 6, dim]
|
# Per-frame time embedding → [B, block_frames, 6, dim]
|
||||||
e = self.time_embedding(
|
e = self.time_embedding(
|
||||||
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()))
|
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
|
||||||
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
|
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
|
||||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||||
|
|
||||||
@ -351,8 +351,20 @@ class CausalWanModel(torch.nn.Module):
|
|||||||
def head_dim(self):
|
def head_dim(self):
|
||||||
return self.dim // self.num_heads
|
return self.dim // self.num_heads
|
||||||
|
|
||||||
# Standard forward for non-causal use (compatibility with ComfyUI infrastructure)
|
|
||||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||||
|
ar_state = transformer_options.get("ar_state")
|
||||||
|
if ar_state is not None:
|
||||||
|
bs = x.shape[0]
|
||||||
|
block_frames = x.shape[2]
|
||||||
|
t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames)
|
||||||
|
return self.forward_block(
|
||||||
|
x=x, timestep=t_per_frame, context=context,
|
||||||
|
start_frame=ar_state["start_frame"],
|
||||||
|
kv_caches=ar_state["kv_caches"],
|
||||||
|
crossattn_caches=ar_state["crossattn_caches"],
|
||||||
|
clip_fea=clip_fea,
|
||||||
|
)
|
||||||
|
|
||||||
bs, c, t, h, w = x.shape
|
bs, c, t, h, w = x.shape
|
||||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||||
|
|
||||||
@ -369,7 +381,7 @@ class CausalWanModel(torch.nn.Module):
|
|||||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
|
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
e = self.time_embedding(
|
e = self.time_embedding(
|
||||||
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()))
|
sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype))
|
||||||
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
|
e = e.reshape(timestep.shape[0], -1, e.shape[-1])
|
||||||
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
e0 = self.time_projection(e).unflatten(2, (6, self.dim))
|
||||||
|
|
||||||
|
|||||||
@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2
|
|||||||
import comfy.ldm.lumina.model
|
import comfy.ldm.lumina.model
|
||||||
import comfy.ldm.wan.model
|
import comfy.ldm.wan.model
|
||||||
import comfy.ldm.wan.model_animate
|
import comfy.ldm.wan.model_animate
|
||||||
|
import comfy.ldm.wan.ar_model
|
||||||
import comfy.ldm.hunyuan3d.model
|
import comfy.ldm.hunyuan3d.model
|
||||||
import comfy.ldm.hidream.model
|
import comfy.ldm.hidream.model
|
||||||
import comfy.ldm.chroma.model
|
import comfy.ldm.chroma.model
|
||||||
@ -1353,6 +1354,13 @@ class WAN21(BaseModel):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class WAN21_CausalAR(WAN21):
    """WAN 2.1 model whose diffusion network is ``CausalWanModel``."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        """Initialise the base model with ``CausalWanModel`` as the unet class.

        ``super(WAN21, self)`` starts the lookup after ``WAN21`` in the MRO, so
        ``WAN21.__init__`` itself is not run; ``image_to_video`` is therefore
        set explicitly here.
        """
        causal_unet_cls = comfy.ldm.wan.ar_model.CausalWanModel
        super(WAN21, self).__init__(
            model_config,
            model_type,
            device=device,
            unet_model=causal_unet_cls,
        )
        self.image_to_video = False
|
||||||
|
|
||||||
|
|
||||||
class WAN21_Vace(WAN21):
|
class WAN21_Vace(WAN21):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
|
||||||
|
|||||||
@ -723,7 +723,8 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c
|
|||||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||||
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
|
||||||
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
"ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
|
||||||
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"]
|
"gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece",
|
||||||
|
"ar_video"]
|
||||||
|
|
||||||
class KSAMPLER(Sampler):
|
class KSAMPLER(Sampler):
|
||||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||||
|
|||||||
@ -1165,6 +1165,15 @@ class WAN21_T2V(supported_models_base.BASE):
|
|||||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
|
||||||
|
|
||||||
|
class WAN21_CausalAR_T2V(WAN21_T2V):
    """``WAN21_T2V`` config variant that builds a ``WAN21_CausalAR`` model."""

    # Overrides the sampling settings inherited from WAN21_T2V.
    sampling_settings = {"shift": 5.0}

    def get_model(self, state_dict, prefix="", device=None):
        """Return a ``model_base.WAN21_CausalAR`` built from this config.

        ``state_dict`` and ``prefix`` are part of the signature but are not
        used by this implementation.
        """
        causal_ar_model = model_base.WAN21_CausalAR(self, device=device)
        return causal_ar_model
|
||||||
|
|
||||||
|
|
||||||
class WAN21_I2V(WAN21_T2V):
|
class WAN21_I2V(WAN21_T2V):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "wan2.1",
|
"image_model": "wan2.1",
|
||||||
|
|||||||
@ -1,8 +1,7 @@
|
|||||||
"""
|
"""
|
||||||
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
||||||
- LoadARVideoModel: load original HF/training or pre-converted checkpoints
|
- LoadARVideoModel: load original HF/training or pre-converted checkpoints
|
||||||
(auto-detects format and converts state dict at runtime)
|
via the standard BaseModel + ModelPatcher pipeline
|
||||||
- ARVideoSampler: autoregressive frame-by-frame sampling with KV cache
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -13,10 +12,9 @@ from typing_extensions import override
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
import comfy.latent_formats
|
import comfy.model_patcher
|
||||||
from comfy.model_patcher import ModelPatcher
|
|
||||||
from comfy.ldm.wan.ar_model import CausalWanModel
|
|
||||||
from comfy.ldm.wan.ar_convert import extract_state_dict
|
from comfy.ldm.wan.ar_convert import extract_state_dict
|
||||||
|
from comfy.supported_models import WAN21_CausalAR_T2V
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
# ── Model size presets derived from Wan 2.1 configs ──────────────────────────
|
# ── Model size presets derived from Wan 2.1 configs ──────────────────────────
|
||||||
@ -36,6 +34,7 @@ class LoadARVideoModel(io.ComfyNode):
|
|||||||
category="loaders/video_models",
|
category="loaders/video_models",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")),
|
io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")),
|
||||||
|
io.Int.Input("num_frame_per_block", default=1, min=1, max=21),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(display_name="MODEL"),
|
io.Model.Output(display_name="MODEL"),
|
||||||
@ -43,21 +42,21 @@ class LoadARVideoModel(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, ckpt_name) -> io.NodeOutput:
|
def execute(cls, ckpt_name, num_frame_per_block) -> io.NodeOutput:
|
||||||
ckpt_path = folder_paths.get_full_path_or_raise("diffusion_models", ckpt_name)
|
ckpt_path = folder_paths.get_full_path_or_raise("diffusion_models", ckpt_name)
|
||||||
raw = comfy.utils.load_torch_file(ckpt_path)
|
raw = comfy.utils.load_torch_file(ckpt_path)
|
||||||
sd = extract_state_dict(raw, use_ema=True)
|
sd = extract_state_dict(raw, use_ema=True)
|
||||||
del raw
|
del raw
|
||||||
|
|
||||||
dim = sd["head.modulation"].shape[-1]
|
dim = sd["head.modulation"].shape[-1]
|
||||||
out_dim = sd["head.head.weight"].shape[0] // 4 # prod(patch_size) * out_dim
|
out_dim = sd["head.head.weight"].shape[0] // 4
|
||||||
in_dim = sd["patch_embedding.weight"].shape[1]
|
in_dim = sd["patch_embedding.weight"].shape[1]
|
||||||
num_layers = 0
|
num_layers = 0
|
||||||
while f"blocks.{num_layers}.self_attn.q.weight" in sd:
|
while f"blocks.{num_layers}.self_attn.q.weight" in sd:
|
||||||
num_layers += 1
|
num_layers += 1
|
||||||
|
|
||||||
if dim in WAN_CONFIGS:
|
if dim in WAN_CONFIGS:
|
||||||
ffn_dim, num_heads, expected_layers, text_dim = WAN_CONFIGS[dim]
|
ffn_dim, num_heads, _, text_dim = WAN_CONFIGS[dim]
|
||||||
else:
|
else:
|
||||||
num_heads = dim // 128
|
num_heads = dim // 128
|
||||||
ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0]
|
ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0]
|
||||||
@ -66,57 +65,60 @@ class LoadARVideoModel(io.ComfyNode):
|
|||||||
|
|
||||||
cross_attn_norm = "blocks.0.norm3.weight" in sd
|
cross_attn_norm = "blocks.0.norm3.weight" in sd
|
||||||
|
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "wan2.1",
|
||||||
|
"model_type": "t2v",
|
||||||
|
"dim": dim,
|
||||||
|
"ffn_dim": ffn_dim,
|
||||||
|
"num_heads": num_heads,
|
||||||
|
"num_layers": num_layers,
|
||||||
|
"in_dim": in_dim,
|
||||||
|
"out_dim": out_dim,
|
||||||
|
"text_dim": text_dim,
|
||||||
|
"cross_attn_norm": cross_attn_norm,
|
||||||
|
}
|
||||||
|
|
||||||
|
model_config = WAN21_CausalAR_T2V(unet_config)
|
||||||
|
unet_dtype = comfy.model_management.unet_dtype(
|
||||||
|
model_params=comfy.utils.calculate_parameters(sd),
|
||||||
|
supported_dtypes=model_config.supported_inference_dtypes,
|
||||||
|
)
|
||||||
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(
|
||||||
|
unet_dtype,
|
||||||
|
comfy.model_management.get_torch_device(),
|
||||||
|
model_config.supported_inference_dtypes,
|
||||||
|
)
|
||||||
|
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||||
|
|
||||||
|
model = model_config.get_model(sd, "")
|
||||||
load_device = comfy.model_management.get_torch_device()
|
load_device = comfy.model_management.get_torch_device()
|
||||||
offload_device = comfy.model_management.unet_offload_device()
|
offload_device = comfy.model_management.unet_offload_device()
|
||||||
ops = comfy.ops.disable_weight_init
|
|
||||||
|
|
||||||
model = CausalWanModel(
|
model_patcher = comfy.model_patcher.ModelPatcher(
|
||||||
model_type='t2v',
|
model, load_device=load_device, offload_device=offload_device,
|
||||||
patch_size=(1, 2, 2),
|
|
||||||
text_len=512,
|
|
||||||
in_dim=in_dim,
|
|
||||||
dim=dim,
|
|
||||||
ffn_dim=ffn_dim,
|
|
||||||
freq_dim=256,
|
|
||||||
text_dim=text_dim,
|
|
||||||
out_dim=out_dim,
|
|
||||||
num_heads=num_heads,
|
|
||||||
num_layers=num_layers,
|
|
||||||
window_size=(-1, -1),
|
|
||||||
qk_norm=True,
|
|
||||||
cross_attn_norm=cross_attn_norm,
|
|
||||||
eps=1e-6,
|
|
||||||
device=offload_device,
|
|
||||||
dtype=torch.bfloat16,
|
|
||||||
operations=ops,
|
|
||||||
)
|
)
|
||||||
|
if not comfy.model_management.is_device_cpu(offload_device):
|
||||||
|
model.to(offload_device)
|
||||||
|
model.load_model_weights(sd, "")
|
||||||
|
|
||||||
model.load_state_dict(sd, strict=False)
|
model_patcher.model_options.setdefault("transformer_options", {})["ar_config"] = {
|
||||||
model.eval()
|
"num_frame_per_block": num_frame_per_block,
|
||||||
|
}
|
||||||
|
|
||||||
model_size = comfy.model_management.module_size(model)
|
return io.NodeOutput(model_patcher)
|
||||||
patcher = ModelPatcher(model, load_device=load_device,
|
|
||||||
offload_device=offload_device, size=model_size)
|
|
||||||
patcher.model.latent_format = comfy.latent_formats.Wan21()
|
|
||||||
return io.NodeOutput(patcher)
|
|
||||||
|
|
||||||
|
|
||||||
class ARVideoSampler(io.ComfyNode):
|
class EmptyARVideoLatent(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls):
|
def define_schema(cls):
|
||||||
return io.Schema(
|
return io.Schema(
|
||||||
node_id="ARVideoSampler",
|
node_id="EmptyARVideoLatent",
|
||||||
category="sampling",
|
category="latent/video",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Model.Input("model"),
|
|
||||||
io.Conditioning.Input("positive"),
|
|
||||||
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True),
|
|
||||||
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||||
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||||
io.Int.Input("num_frames", default=81, min=1, max=1024, step=4),
|
io.Int.Input("length", default=81, min=1, max=1024, step=4),
|
||||||
io.Int.Input("num_frame_per_block", default=1, min=1, max=21),
|
io.Int.Input("batch_size", default=1, min=1, max=64),
|
||||||
io.Float.Input("timestep_shift", default=5.0, min=0.1, max=20.0, step=0.1),
|
|
||||||
io.String.Input("denoising_steps", default="1000,750,500,250"),
|
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Latent.Output(display_name="LATENT"),
|
io.Latent.Output(display_name="LATENT"),
|
||||||
@ -124,138 +126,13 @@ class ARVideoSampler(io.ComfyNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, model, positive, seed, width, height,
|
def execute(cls, width, height, length, batch_size) -> io.NodeOutput:
|
||||||
num_frames, num_frame_per_block, timestep_shift,
|
lat_t = ((length - 1) // 4) + 1
|
||||||
denoising_steps) -> io.NodeOutput:
|
latent = torch.zeros(
|
||||||
|
[batch_size, 16, lat_t, height // 8, width // 8],
|
||||||
device = comfy.model_management.get_torch_device()
|
device=comfy.model_management.intermediate_device(),
|
||||||
|
)
|
||||||
# Parse denoising steps
|
return io.NodeOutput({"samples": latent})
|
||||||
step_values = [int(s.strip()) for s in denoising_steps.split(",")]
|
|
||||||
|
|
||||||
# Build scheduler sigmas (FlowMatch with shift)
|
|
||||||
num_train_timesteps = 1000
|
|
||||||
raw_sigmas = torch.linspace(1.0, 0.003 / 1.002, num_train_timesteps + 1)[:-1]
|
|
||||||
sigmas = timestep_shift * raw_sigmas / (1.0 + (timestep_shift - 1.0) * raw_sigmas)
|
|
||||||
timesteps = sigmas * num_train_timesteps
|
|
||||||
|
|
||||||
# Warp denoising step indices to actual timestep values
|
|
||||||
all_timesteps = torch.cat([timesteps, torch.tensor([0.0])])
|
|
||||||
warped_steps = all_timesteps[num_train_timesteps - torch.tensor(step_values, dtype=torch.long)]
|
|
||||||
|
|
||||||
# Get the CausalWanModel from the patcher
|
|
||||||
comfy.model_management.load_model_gpu(model)
|
|
||||||
causal_model = model.model
|
|
||||||
dtype = torch.bfloat16
|
|
||||||
|
|
||||||
# Extract text embeddings from conditioning
|
|
||||||
cond = positive[0][0].to(device=device, dtype=dtype)
|
|
||||||
if cond.ndim == 2:
|
|
||||||
cond = cond.unsqueeze(0)
|
|
||||||
|
|
||||||
# Latent dimensions
|
|
||||||
lat_h = height // 8
|
|
||||||
lat_w = width // 8
|
|
||||||
lat_t = ((num_frames - 1) // 4) + 1 # Wan VAE temporal compression
|
|
||||||
in_channels = 16
|
|
||||||
|
|
||||||
# Generate noise
|
|
||||||
generator = torch.Generator(device="cpu").manual_seed(seed)
|
|
||||||
noise = torch.randn(1, in_channels, lat_t, lat_h, lat_w,
|
|
||||||
generator=generator, device="cpu").to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
assert lat_t % num_frame_per_block == 0, \
|
|
||||||
f"Latent frames ({lat_t}) must be divisible by num_frame_per_block ({num_frame_per_block})"
|
|
||||||
num_blocks = lat_t // num_frame_per_block
|
|
||||||
|
|
||||||
# Tokens per frame: (H/patch_h) * (W/patch_w) per temporal patch
|
|
||||||
frame_seq_len = (lat_h // 2) * (lat_w // 2) # patch_size = (1,2,2)
|
|
||||||
max_seq_len = lat_t * frame_seq_len
|
|
||||||
|
|
||||||
# Initialize caches
|
|
||||||
kv_caches = causal_model.init_kv_caches(1, max_seq_len, device, dtype)
|
|
||||||
crossattn_caches = causal_model.init_crossattn_caches(1, device, dtype)
|
|
||||||
|
|
||||||
output = torch.zeros_like(noise)
|
|
||||||
pbar = comfy.utils.ProgressBar(num_blocks * len(warped_steps) + num_blocks)
|
|
||||||
|
|
||||||
current_start_frame = 0
|
|
||||||
for block_idx in range(num_blocks):
|
|
||||||
block_frames = num_frame_per_block
|
|
||||||
frame_start = current_start_frame
|
|
||||||
frame_end = current_start_frame + block_frames
|
|
||||||
|
|
||||||
# Noise slice for this block: [B, C, block_frames, H, W]
|
|
||||||
noisy_input = noise[:, :, frame_start:frame_end]
|
|
||||||
|
|
||||||
# Denoising loop (e.g. 4 steps)
|
|
||||||
for step_idx, current_timestep in enumerate(warped_steps):
|
|
||||||
t_val = current_timestep.item()
|
|
||||||
|
|
||||||
# Per-frame timestep tensor [B, block_frames]
|
|
||||||
timestep_tensor = torch.full(
|
|
||||||
(1, block_frames), t_val, device=device, dtype=dtype)
|
|
||||||
|
|
||||||
# Model forward
|
|
||||||
flow_pred = causal_model.forward_block(
|
|
||||||
x=noisy_input,
|
|
||||||
timestep=timestep_tensor,
|
|
||||||
context=cond,
|
|
||||||
start_frame=current_start_frame,
|
|
||||||
kv_caches=kv_caches,
|
|
||||||
crossattn_caches=crossattn_caches,
|
|
||||||
)
|
|
||||||
|
|
||||||
# x0 = input - sigma * flow_pred
|
|
||||||
sigma_t = _lookup_sigma(sigmas, timesteps, t_val)
|
|
||||||
denoised = noisy_input - sigma_t * flow_pred
|
|
||||||
|
|
||||||
if step_idx < len(warped_steps) - 1:
|
|
||||||
# Add noise for next step
|
|
||||||
next_t = warped_steps[step_idx + 1].item()
|
|
||||||
sigma_next = _lookup_sigma(sigmas, timesteps, next_t)
|
|
||||||
fresh_noise = torch.randn_like(denoised)
|
|
||||||
noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise
|
|
||||||
|
|
||||||
# Roll back KV cache end pointer so next step re-writes same positions
|
|
||||||
for cache in kv_caches:
|
|
||||||
cache["end"].fill_(cache["end"].item() - block_frames * frame_seq_len)
|
|
||||||
else:
|
|
||||||
noisy_input = denoised
|
|
||||||
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
output[:, :, frame_start:frame_end] = noisy_input
|
|
||||||
|
|
||||||
# Cache update: forward at t=0 with clean output to fill KV cache
|
|
||||||
with torch.no_grad():
|
|
||||||
# Reset cache end to before this block so the t=0 pass writes clean K/V
|
|
||||||
for cache in kv_caches:
|
|
||||||
cache["end"].fill_(cache["end"].item() - block_frames * frame_seq_len)
|
|
||||||
|
|
||||||
t_zero = torch.zeros(1, block_frames, device=device, dtype=dtype)
|
|
||||||
causal_model.forward_block(
|
|
||||||
x=noisy_input,
|
|
||||||
timestep=t_zero,
|
|
||||||
context=cond,
|
|
||||||
start_frame=current_start_frame,
|
|
||||||
kv_caches=kv_caches,
|
|
||||||
crossattn_caches=crossattn_caches,
|
|
||||||
)
|
|
||||||
|
|
||||||
pbar.update(1)
|
|
||||||
current_start_frame += block_frames
|
|
||||||
|
|
||||||
# Denormalize latents because VAEDecode expects raw latents.
|
|
||||||
latent_format = comfy.latent_formats.Wan21()
|
|
||||||
output_denorm = latent_format.process_out(output.float().cpu())
|
|
||||||
return io.NodeOutput({"samples": output_denorm})
|
|
||||||
|
|
||||||
|
|
||||||
def _lookup_sigma(sigmas, timesteps, t_val):
|
|
||||||
"""Find the sigma corresponding to a timestep value."""
|
|
||||||
idx = torch.argmin((timesteps - t_val).abs()).item()
|
|
||||||
return sigmas[idx]
|
|
||||||
|
|
||||||
|
|
||||||
class ARVideoExtension(ComfyExtension):
|
class ARVideoExtension(ComfyExtension):
|
||||||
@ -263,7 +140,7 @@ class ARVideoExtension(ComfyExtension):
|
|||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
LoadARVideoModel,
|
LoadARVideoModel,
|
||||||
ARVideoSampler,
|
EmptyARVideoLatent,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user