diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 6978eb717..870bff369 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1810,3 +1810,84 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F def sample_sa_solver_pece(model, x, sigmas, extra_args=None, callback=None, disable=False, tau_func=None, s_noise=1.0, noise_sampler=None, predictor_order=3, corrector_order=4, simple_order_2=False): """Stochastic Adams Solver with PECE (Predict–Evaluate–Correct–Evaluate) mode (NeurIPS 2023).""" return sample_sa_solver(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, tau_func=tau_func, s_noise=s_noise, noise_sampler=noise_sampler, predictor_order=predictor_order, corrector_order=corrector_order, use_pece=True, simple_order_2=simple_order_2) + + +@torch.no_grad() +def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=None): + """ + Autoregressive video sampler: block-by-block denoising with KV cache + and flow-match re-noising for Causal Forcing / Self-Forcing models. + """ + extra_args = {} if extra_args is None else extra_args + model_options = extra_args.get("model_options", {}) + transformer_options = model_options.get("transformer_options", {}) + ar_config = transformer_options.get("ar_config", {}) + + num_frame_per_block = ar_config.get("num_frame_per_block", 1) + seed = extra_args.get("seed", 0) + + bs, c, lat_t, lat_h, lat_w = x.shape + frame_seq_len = (lat_h // 2) * (lat_w // 2) + num_blocks = lat_t // num_frame_per_block + + inner_model = model.inner_model.inner_model + causal_model = inner_model.diffusion_model + device = x.device + model_dtype = inner_model.get_dtype() + + kv_caches = causal_model.init_kv_caches(bs, lat_t * frame_seq_len, device, model_dtype) + crossattn_caches = causal_model.init_crossattn_caches(bs, device, model_dtype) + + output = torch.zeros_like(x) + s_in = x.new_ones([x.shape[0]]) + current_start_frame = 0 + num_sigma_steps = len(sigmas) - 1 + total_real_steps = num_blocks * num_sigma_steps + step_count = 0 + + for block_idx in trange(num_blocks, disable=disable): + bf = num_frame_per_block + fs, fe = current_start_frame, current_start_frame + bf + noisy_input = x[:, :, fs:fe] + + ar_state = { + "start_frame": current_start_frame, + "kv_caches": kv_caches, + "crossattn_caches": crossattn_caches, + } + transformer_options["ar_state"] = ar_state + + for i in range(num_sigma_steps): + denoised = model(noisy_input, sigmas[i] * s_in, **extra_args) + + if callback is not None: + # Scale step_count to [0, num_sigma_steps) so the progress bar fills gradually + scaled_i = step_count * num_sigma_steps // total_real_steps + callback({"x": noisy_input, "i": scaled_i, "sigma": sigmas[i], + "sigma_hat": sigmas[i], "denoised": denoised}) + + if sigmas[i + 1] == 0: + noisy_input = denoised + else: + sigma_next = sigmas[i + 1] + torch.manual_seed(seed + block_idx * 1000 + i) + fresh_noise = torch.randn_like(denoised) + noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise + + for cache in kv_caches: + cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) + + step_count += 1 + + output[:, :, fs:fe] = noisy_input + + # Cache update: run model at t=0 with clean output to fill KV cache + for cache in kv_caches: + cache["end"].fill_(cache["end"].item() - bf * frame_seq_len) + zero_sigma = sigmas.new_zeros([1]) + _ = model(noisy_input, zero_sigma * s_in, **extra_args) + + current_start_frame += bf + + transformer_options.pop("ar_state", None) + return output diff --git a/comfy/ldm/wan/ar_model.py b/comfy/ldm/wan/ar_model.py index 775d675b7..0fe2a585c 100644 --- a/comfy/ldm/wan/ar_model.py +++ b/comfy/ldm/wan/ar_model.py @@ -281,7 +281,7 @@ class CausalWanModel(torch.nn.Module): # Per-frame time embedding → [B, block_frames, 6, dim] e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, timestep.flatten())) + sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype)) e = e.reshape(timestep.shape[0], -1, e.shape[-1]) e0 = self.time_projection(e).unflatten(2, (6, self.dim)) @@ -351,8 +351,20 @@ class CausalWanModel(torch.nn.Module): def head_dim(self): return self.dim // self.num_heads - # Standard forward for non-causal use (compatibility with ComfyUI infrastructure) def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs): + ar_state = transformer_options.get("ar_state") + if ar_state is not None: + bs = x.shape[0] + block_frames = x.shape[2] + t_per_frame = timestep.unsqueeze(1).expand(bs, block_frames) + return self.forward_block( + x=x, timestep=t_per_frame, context=context, + start_frame=ar_state["start_frame"], + kv_caches=ar_state["kv_caches"], + crossattn_caches=ar_state["crossattn_caches"], + clip_fea=clip_fea, + ) + bs, c, t, h, w = x.shape x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) @@ -369,7 +381,7 @@ class CausalWanModel(torch.nn.Module): freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype) e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, timestep.flatten())) + sinusoidal_embedding_1d(self.freq_dim, timestep.flatten()).to(dtype=x.dtype)) e = e.reshape(timestep.shape[0], -1, e.shape[-1]) e0 = self.time_projection(e).unflatten(2, (6, self.dim)) diff --git a/comfy/model_base.py b/comfy/model_base.py index 70aff886e..0bdfefba5 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -42,6 +42,7 @@ import comfy.ldm.cosmos.predict2 import comfy.ldm.lumina.model import comfy.ldm.wan.model import comfy.ldm.wan.model_animate +import comfy.ldm.wan.ar_model import comfy.ldm.hunyuan3d.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model @@ -1353,6 +1354,13 @@ class WAN21(BaseModel): return out +class WAN21_CausalAR(WAN21): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super(WAN21, self).__init__(model_config, model_type, device=device, + unet_model=comfy.ldm.wan.ar_model.CausalWanModel) + self.image_to_video = False + + class WAN21_Vace(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel) diff --git a/comfy/samplers.py b/comfy/samplers.py index 0a4d062db..03a07ec68 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -723,7 +723,8 @@ KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_c "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_2m_sde_heun", "dpmpp_2m_sde_heun_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp", - "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece"] + "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3", "sa_solver", "sa_solver_pece", + "ar_video"] class KSAMPLER(Sampler): def __init__(self, sampler_function, extra_options={}, inpaint_options={}): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 07feb31b3..4c5159fbe 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1165,6 +1165,15 @@ class WAN21_T2V(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect)) +class WAN21_CausalAR_T2V(WAN21_T2V): + sampling_settings = { + "shift": 5.0, + } + + def get_model(self, state_dict, prefix="", device=None): + return model_base.WAN21_CausalAR(self, device=device) + + class WAN21_I2V(WAN21_T2V): unet_config = { "image_model": "wan2.1", diff --git a/comfy_extras/nodes_ar_video.py b/comfy_extras/nodes_ar_video.py index 08010a6ac..41bed9414 100644 --- a/comfy_extras/nodes_ar_video.py +++ b/comfy_extras/nodes_ar_video.py @@ -1,8 +1,7 @@ """ ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.). - LoadARVideoModel: load original HF/training or pre-converted checkpoints - (auto-detects format and converts state dict at runtime) - - ARVideoSampler: autoregressive frame-by-frame sampling with KV cache + via the standard BaseModel + ModelPatcher pipeline """ import torch @@ -13,10 +12,9 @@ from typing_extensions import override import comfy.model_management import comfy.utils import comfy.ops -import comfy.latent_formats -from comfy.model_patcher import ModelPatcher -from comfy.ldm.wan.ar_model import CausalWanModel +import comfy.model_patcher from comfy.ldm.wan.ar_convert import extract_state_dict +from comfy.supported_models import WAN21_CausalAR_T2V from comfy_api.latest import ComfyExtension, io # ── Model size presets derived from Wan 2.1 configs ────────────────────────── @@ -36,6 +34,7 @@ class LoadARVideoModel(io.ComfyNode): category="loaders/video_models", inputs=[ io.Combo.Input("ckpt_name", options=folder_paths.get_filename_list("diffusion_models")), + io.Int.Input("num_frame_per_block", default=1, min=1, max=21), ], outputs=[ io.Model.Output(display_name="MODEL"), @@ -43,21 +42,21 @@ class LoadARVideoModel(io.ComfyNode): ) @classmethod - def execute(cls, ckpt_name) -> io.NodeOutput: + def execute(cls, ckpt_name, num_frame_per_block) -> io.NodeOutput: ckpt_path = folder_paths.get_full_path_or_raise("diffusion_models", ckpt_name) raw = comfy.utils.load_torch_file(ckpt_path) sd = extract_state_dict(raw, use_ema=True) del raw dim = sd["head.modulation"].shape[-1] - out_dim = sd["head.head.weight"].shape[0] // 4 # prod(patch_size) * out_dim + out_dim = sd["head.head.weight"].shape[0] // 4 in_dim = sd["patch_embedding.weight"].shape[1] num_layers = 0 while f"blocks.{num_layers}.self_attn.q.weight" in sd: num_layers += 1 if dim in WAN_CONFIGS: - ffn_dim, num_heads, expected_layers, text_dim = WAN_CONFIGS[dim] + ffn_dim, num_heads, _, text_dim = WAN_CONFIGS[dim] else: num_heads = dim // 128 ffn_dim = sd["blocks.0.ffn.0.weight"].shape[0] @@ -66,57 +65,60 @@ class LoadARVideoModel(io.ComfyNode): cross_attn_norm = "blocks.0.norm3.weight" in sd + unet_config = { + "image_model": "wan2.1", + "model_type": "t2v", + "dim": dim, + "ffn_dim": ffn_dim, + "num_heads": num_heads, + "num_layers": num_layers, + "in_dim": in_dim, + "out_dim": out_dim, + "text_dim": text_dim, + "cross_attn_norm": cross_attn_norm, + } + + model_config = WAN21_CausalAR_T2V(unet_config) + unet_dtype = comfy.model_management.unet_dtype( + model_params=comfy.utils.calculate_parameters(sd), + supported_dtypes=model_config.supported_inference_dtypes, + ) + manual_cast_dtype = comfy.model_management.unet_manual_cast( + unet_dtype, + comfy.model_management.get_torch_device(), + model_config.supported_inference_dtypes, + ) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) + + model = model_config.get_model(sd, "") load_device = comfy.model_management.get_torch_device() offload_device = comfy.model_management.unet_offload_device() - ops = comfy.ops.disable_weight_init - model = CausalWanModel( - model_type='t2v', - patch_size=(1, 2, 2), - text_len=512, - in_dim=in_dim, - dim=dim, - ffn_dim=ffn_dim, - freq_dim=256, - text_dim=text_dim, - out_dim=out_dim, - num_heads=num_heads, - num_layers=num_layers, - window_size=(-1, -1), - qk_norm=True, - cross_attn_norm=cross_attn_norm, - eps=1e-6, - device=offload_device, - dtype=torch.bfloat16, - operations=ops, + model_patcher = comfy.model_patcher.ModelPatcher( + model, load_device=load_device, offload_device=offload_device, ) + if not comfy.model_management.is_device_cpu(offload_device): + model.to(offload_device) + model.load_model_weights(sd, "") - model.load_state_dict(sd, strict=False) - model.eval() + model_patcher.model_options.setdefault("transformer_options", {})["ar_config"] = { + "num_frame_per_block": num_frame_per_block, + } - model_size = comfy.model_management.module_size(model) - patcher = ModelPatcher(model, load_device=load_device, - offload_device=offload_device, size=model_size) - patcher.model.latent_format = comfy.latent_formats.Wan21() - return io.NodeOutput(patcher) + return io.NodeOutput(model_patcher) -class ARVideoSampler(io.ComfyNode): +class EmptyARVideoLatent(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( - node_id="ARVideoSampler", - category="sampling", + node_id="EmptyARVideoLatent", + category="latent/video", inputs=[ - io.Model.Input("model"), - io.Conditioning.Input("positive"), - io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), io.Int.Input("width", default=832, min=16, max=8192, step=16), io.Int.Input("height", default=480, min=16, max=8192, step=16), - io.Int.Input("num_frames", default=81, min=1, max=1024, step=4), - io.Int.Input("num_frame_per_block", default=1, min=1, max=21), - io.Float.Input("timestep_shift", default=5.0, min=0.1, max=20.0, step=0.1), - io.String.Input("denoising_steps", default="1000,750,500,250"), + io.Int.Input("length", default=81, min=1, max=1024, step=4), + io.Int.Input("batch_size", default=1, min=1, max=64), ], outputs=[ io.Latent.Output(display_name="LATENT"), @@ -124,138 +126,13 @@ class ARVideoSampler(io.ComfyNode): ) @classmethod - def execute(cls, model, positive, seed, width, height, - num_frames, num_frame_per_block, timestep_shift, - denoising_steps) -> io.NodeOutput: - - device = comfy.model_management.get_torch_device() - - # Parse denoising steps - step_values = [int(s.strip()) for s in denoising_steps.split(",")] - - # Build scheduler sigmas (FlowMatch with shift) - num_train_timesteps = 1000 - raw_sigmas = torch.linspace(1.0, 0.003 / 1.002, num_train_timesteps + 1)[:-1] - sigmas = timestep_shift * raw_sigmas / (1.0 + (timestep_shift - 1.0) * raw_sigmas) - timesteps = sigmas * num_train_timesteps - - # Warp denoising step indices to actual timestep values - all_timesteps = torch.cat([timesteps, torch.tensor([0.0])]) - warped_steps = all_timesteps[num_train_timesteps - torch.tensor(step_values, dtype=torch.long)] - - # Get the CausalWanModel from the patcher - comfy.model_management.load_model_gpu(model) - causal_model = model.model - dtype = torch.bfloat16 - - # Extract text embeddings from conditioning - cond = positive[0][0].to(device=device, dtype=dtype) - if cond.ndim == 2: - cond = cond.unsqueeze(0) - - # Latent dimensions - lat_h = height // 8 - lat_w = width // 8 - lat_t = ((num_frames - 1) // 4) + 1 # Wan VAE temporal compression - in_channels = 16 - - # Generate noise - generator = torch.Generator(device="cpu").manual_seed(seed) - noise = torch.randn(1, in_channels, lat_t, lat_h, lat_w, - generator=generator, device="cpu").to(device=device, dtype=dtype) - - assert lat_t % num_frame_per_block == 0, \ - f"Latent frames ({lat_t}) must be divisible by num_frame_per_block ({num_frame_per_block})" - num_blocks = lat_t // num_frame_per_block - - # Tokens per frame: (H/patch_h) * (W/patch_w) per temporal patch - frame_seq_len = (lat_h // 2) * (lat_w // 2) # patch_size = (1,2,2) - max_seq_len = lat_t * frame_seq_len - - # Initialize caches - kv_caches = causal_model.init_kv_caches(1, max_seq_len, device, dtype) - crossattn_caches = causal_model.init_crossattn_caches(1, device, dtype) - - output = torch.zeros_like(noise) - pbar = comfy.utils.ProgressBar(num_blocks * len(warped_steps) + num_blocks) - - current_start_frame = 0 - for block_idx in range(num_blocks): - block_frames = num_frame_per_block - frame_start = current_start_frame - frame_end = current_start_frame + block_frames - - # Noise slice for this block: [B, C, block_frames, H, W] - noisy_input = noise[:, :, frame_start:frame_end] - - # Denoising loop (e.g. 4 steps) - for step_idx, current_timestep in enumerate(warped_steps): - t_val = current_timestep.item() - - # Per-frame timestep tensor [B, block_frames] - timestep_tensor = torch.full( - (1, block_frames), t_val, device=device, dtype=dtype) - - # Model forward - flow_pred = causal_model.forward_block( - x=noisy_input, - timestep=timestep_tensor, - context=cond, - start_frame=current_start_frame, - kv_caches=kv_caches, - crossattn_caches=crossattn_caches, - ) - - # x0 = input - sigma * flow_pred - sigma_t = _lookup_sigma(sigmas, timesteps, t_val) - denoised = noisy_input - sigma_t * flow_pred - - if step_idx < len(warped_steps) - 1: - # Add noise for next step - next_t = warped_steps[step_idx + 1].item() - sigma_next = _lookup_sigma(sigmas, timesteps, next_t) - fresh_noise = torch.randn_like(denoised) - noisy_input = (1.0 - sigma_next) * denoised + sigma_next * fresh_noise - - # Roll back KV cache end pointer so next step re-writes same positions - for cache in kv_caches: - cache["end"].fill_(cache["end"].item() - block_frames * frame_seq_len) - else: - noisy_input = denoised - - pbar.update(1) - - output[:, :, frame_start:frame_end] = noisy_input - - # Cache update: forward at t=0 with clean output to fill KV cache - with torch.no_grad(): - # Reset cache end to before this block so the t=0 pass writes clean K/V - for cache in kv_caches: - cache["end"].fill_(cache["end"].item() - block_frames * frame_seq_len) - - t_zero = torch.zeros(1, block_frames, device=device, dtype=dtype) - causal_model.forward_block( - x=noisy_input, - timestep=t_zero, - context=cond, - start_frame=current_start_frame, - kv_caches=kv_caches, - crossattn_caches=crossattn_caches, - ) - - pbar.update(1) - current_start_frame += block_frames - - # Denormalize latents because VAEDecode expects raw latents. - latent_format = comfy.latent_formats.Wan21() - output_denorm = latent_format.process_out(output.float().cpu()) - return io.NodeOutput({"samples": output_denorm}) - - -def _lookup_sigma(sigmas, timesteps, t_val): - """Find the sigma corresponding to a timestep value.""" - idx = torch.argmin((timesteps - t_val).abs()).item() - return sigmas[idx] + def execute(cls, width, height, length, batch_size) -> io.NodeOutput: + lat_t = ((length - 1) // 4) + 1 + latent = torch.zeros( + [batch_size, 16, lat_t, height // 8, width // 8], + device=comfy.model_management.intermediate_device(), + ) + return io.NodeOutput({"samples": latent}) class ARVideoExtension(ComfyExtension): @@ -263,7 +140,7 @@ class ARVideoExtension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ LoadARVideoModel, - ARVideoSampler, + EmptyARVideoLatent, ]