Fix VOID last-frame glitch by enforcing even latent_t.

This commit is contained in:
Talmaj Marinc 2026-04-16 19:30:20 +02:00
parent fe8906144c
commit 6c4b1cebd1

View File

@ -1,3 +1,5 @@
import logging
import nodes
import node_helpers
import torch
@ -9,6 +11,29 @@ from comfy.utils import model_trange as trange
from comfy_api.latest import io, ComfyExtension
from typing_extensions import override
TEMPORAL_COMPRESSION = 4
PATCH_SIZE_T = 2
def _valid_void_length(length: int) -> int:
"""Round ``length`` down to a value that produces an even latent_t.
VOID / CogVideoX-Fun-V1.5 uses patch_size_t=2, so the VAE-encoded latent
must have an even temporal dimension. If latent_t is odd, the transformer
pad_to_patch_size circular-wraps an extra latent frame onto the end; after
the post-transformer crop the last real latent frame has been influenced
by the wrapped phantom frame, producing visible jitter and "disappearing"
subjects near the end of the decoded video. Rounding down fixes this.
"""
latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1
if latent_t % PATCH_SIZE_T == 0:
return length
# Round latent_t down to the nearest multiple of PATCH_SIZE_T, then invert
# the ((length - 1) // TEMPORAL_COMPRESSION) + 1 formula. Floor at 1 frame
# so we never return a non-positive length.
target_latent_t = max(PATCH_SIZE_T, (latent_t // PATCH_SIZE_T) * PATCH_SIZE_T)
return (target_latent_t - 1) * TEMPORAL_COMPRESSION + 1
class VOIDQuadmaskPreprocess(io.ComfyNode):
"""Preprocess a quadmask video for VOID inpainting.
@ -88,8 +113,10 @@ class VOIDInpaintConditioning(io.ComfyNode):
io.Mask.Input("quadmask", tooltip="Preprocessed quadmask from VOIDQuadmaskPreprocess [T, H, W]"),
io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=1,
tooltip="Number of pixel frames to process"),
io.Int.Input("length", default=45, min=1, max=nodes.MAX_RESOLUTION, step=1,
tooltip="Number of pixel frames to process. For CogVideoX-Fun-V1.5 "
"(patch_size_t=2), latent_t must be even — lengths that "
"produce odd latent_t are rounded down (e.g. 49 → 45)."),
io.Int.Input("batch_size", default=1, min=1, max=64),
],
outputs=[
@ -103,8 +130,17 @@ class VOIDInpaintConditioning(io.ComfyNode):
def execute(cls, positive, negative, vae, video, quadmask,
width, height, length, batch_size) -> io.NodeOutput:
temporal_compression = 4
latent_t = ((length - 1) // temporal_compression) + 1
adjusted_length = _valid_void_length(length)
if adjusted_length != length:
logging.warning(
"VOIDInpaintConditioning: rounding length %d down to %d so that "
"latent_t is even (required by CogVideoX-Fun-V1.5 patch_size_t=2). "
"Using odd latent_t causes the last frame to be corrupted by "
"circular padding.", length, adjusted_length,
)
length = adjusted_length
latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1
latent_h = height // 8
latent_w = width // 8
@ -188,8 +224,9 @@ class VOIDWarpedNoise(io.ComfyNode):
io.Image.Input("video", tooltip="Pass 1 output video frames [T, H, W, 3]"),
io.Int.Input("width", default=672, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("height", default=384, min=16, max=nodes.MAX_RESOLUTION, step=8),
io.Int.Input("length", default=49, min=1, max=nodes.MAX_RESOLUTION, step=1,
tooltip="Number of pixel frames"),
io.Int.Input("length", default=45, min=1, max=nodes.MAX_RESOLUTION, step=1,
tooltip="Number of pixel frames. Rounded down to make latent_t "
"even (patch_size_t=2 requirement), e.g. 49 → 45."),
io.Int.Input("batch_size", default=1, min=1, max=64),
],
outputs=[
@ -211,8 +248,16 @@ class VOIDWarpedNoise(io.ComfyNode):
"VOIDWarpedNoise requires the 'rp' package. Install with: pip install rp"
)
temporal_compression = 4
latent_t = ((length - 1) // temporal_compression) + 1
adjusted_length = _valid_void_length(length)
if adjusted_length != length:
logging.warning(
"VOIDWarpedNoise: rounding length %d down to %d so that "
"latent_t is even (required by CogVideoX-Fun-V1.5 patch_size_t=2).",
length, adjusted_length,
)
length = adjusted_length
latent_t = ((length - 1) // TEMPORAL_COMPRESSION) + 1
latent_h = height // 8
latent_w = width // 8