mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-13 18:47:29 +08:00
Add I2V for causal forcing model. (#13719)
This commit is contained in:
parent
8dc3f3f209
commit
ef8f25601a
@ -1859,6 +1859,23 @@ def sample_ar_video(model, x, sigmas, extra_args=None, callback=None, disable=No
|
|||||||
output = torch.zeros_like(x)
|
output = torch.zeros_like(x)
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
current_start_frame = 0
|
current_start_frame = 0
|
||||||
|
|
||||||
|
# I2V: seed KV cache with the initial image latent before the denoising loop
|
||||||
|
initial_latent = transformer_options.get("ar_config", {}).get("initial_latent", None)
|
||||||
|
if initial_latent is not None:
|
||||||
|
initial_latent = inner_model.process_latent_in(initial_latent).to(device=device, dtype=model_dtype)
|
||||||
|
n_init = initial_latent.shape[2]
|
||||||
|
output[:, :, :n_init] = initial_latent
|
||||||
|
|
||||||
|
ar_state = {"start_frame": 0, "kv_caches": kv_caches, "crossattn_caches": crossattn_caches}
|
||||||
|
transformer_options["ar_state"] = ar_state
|
||||||
|
zero_sigma = sigmas.new_zeros([1])
|
||||||
|
_ = model(initial_latent, zero_sigma * s_in, **extra_args)
|
||||||
|
|
||||||
|
current_start_frame = n_init
|
||||||
|
remaining = lat_t - n_init
|
||||||
|
num_blocks = -(-remaining // num_frame_per_block)
|
||||||
|
|
||||||
num_sigma_steps = len(sigmas) - 1
|
num_sigma_steps = len(sigmas) - 1
|
||||||
total_real_steps = num_blocks * num_sigma_steps
|
total_real_steps = num_blocks * num_sigma_steps
|
||||||
step_count = 0
|
step_count = 0
|
||||||
|
|||||||
@ -2,6 +2,7 @@
|
|||||||
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
ComfyUI nodes for autoregressive video generation (Causal Forcing, Self-Forcing, etc.).
|
||||||
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
|
- EmptyARVideoLatent: create 5D [B, C, T, H, W] video latent tensors
|
||||||
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
|
- SamplerARVideo: SAMPLER for the block-by-block autoregressive denoising loop
|
||||||
|
- ARVideoI2V: image-to-video conditioning for AR models (seeds KV cache with start image)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -9,6 +10,7 @@ from typing_extensions import override
|
|||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.samplers
|
import comfy.samplers
|
||||||
|
import comfy.utils
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
|
||||||
|
|
||||||
@ -71,12 +73,62 @@ class SamplerARVideo(io.ComfyNode):
|
|||||||
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
|
return io.NodeOutput(comfy.samplers.ksampler("ar_video", extra_options))
|
||||||
|
|
||||||
|
|
||||||
|
class ARVideoI2V(io.ComfyNode):
    """Image-to-video conditioning node for AR video models (Causal Forcing, Self-Forcing).

    The first image of ``start_image`` is resized, VAE-encoded and stored on a
    clone of the model under
    ``transformer_options["ar_config"]["initial_latent"]``, which is where
    ``sample_ar_video`` looks for the latent used to seed its KV cache before
    the denoising loop. The regular T2V checkpoint is reused; no dedicated I2V
    architecture is required.

    Outputs the patched model clone and a zero-filled video latent sized from
    ``width``, ``height``, ``length`` and ``batch_size``.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node's identifier, category, inputs and outputs."""
        node_inputs = [
            io.Model.Input("model"),
            io.Vae.Input("vae"),
            io.Image.Input("start_image"),
            io.Int.Input("width", default=832, min=16, max=8192, step=16),
            io.Int.Input("height", default=480, min=16, max=8192, step=16),
            io.Int.Input("length", default=81, min=1, max=1024, step=4),
            io.Int.Input("batch_size", default=1, min=1, max=64),
        ]
        node_outputs = [
            io.Model.Output(display_name="MODEL"),
            io.Latent.Output(display_name="LATENT"),
        ]
        return io.Schema(
            node_id="ARVideoI2V",
            category="conditioning/video_models",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, model, vae, start_image, width, height, length, batch_size) -> io.NodeOutput:
        """Encode the start image, attach it to a model clone and build the empty latent.

        Returns:
            An ``io.NodeOutput`` holding the cloned model (carrying the encoded
            start image in its transformer options) and a latent dict whose
            ``"samples"`` entry is a zero tensor of shape
            ``[batch_size, 16, (length - 1) // 4 + 1, height // 8, width // 8]``.
        """
        # Only the first image of the batch is used. The last axis is moved to
        # position 1 for the resize call and moved back afterwards.
        first_image = start_image[:1].movedim(-1, 1)
        resized = comfy.utils.common_upscale(first_image, width, height, "bilinear", "center")
        resized = resized.movedim(1, -1)

        # Only the first three entries of the last axis are passed to the VAE.
        initial_latent = vae.encode(resized[:, :, :, :3])

        # Patch a clone so the incoming model object is not the one modified here.
        patched_model = model.clone()
        transformer_options = patched_model.model_options.setdefault("transformer_options", {})
        transformer_options.setdefault("ar_config", {})["initial_latent"] = initial_latent

        latent_frames = (length - 1) // 4 + 1
        latent_shape = [batch_size, 16, latent_frames, height // 8, width // 8]
        empty_latent = torch.zeros(
            latent_shape,
            device=comfy.model_management.intermediate_device(),
        )
        return io.NodeOutput(patched_model, {"samples": empty_latent})
|
||||||
|
|
||||||
|
|
||||||
class ARVideoExtension(ComfyExtension):
    """Extension entry that registers this module's autoregressive video nodes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension, in registration order."""
        node_classes: list[type[io.ComfyNode]] = [
            EmptyARVideoLatent,
            SamplerARVideo,
            ARVideoI2V,
        ]
        return node_classes
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user