ComfyUI/comfy_extras/nodes_seedvr.py
2026-05-26 00:27:57 -05:00

1165 lines
49 KiB
Python

from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import torch
import math
import logging
from einops import rearrange
import gc
import comfy.model_management
import comfy.sample
import comfy.samplers
from comfy.ldm.seedvr.vae import (
adain_color_transfer,
lab_color_transfer,
wavelet_color_transfer,
)
from torchvision.transforms import functional as TVF
from torchvision.transforms import Lambda
from torchvision.transforms.functional import InterpolationMode
# Shared prefix of every RuntimeError raised in this module when the object
# wired into the ``model`` input does not look like a usable SeedVR2 model.
_SEEDVR2_INVALID_MODEL_MSG_PREFIX = (
    "SeedVR2Conditioning: model object does not match expected SeedVR2 structure"
)
# Per-method multipliers applied to the raw per-frame byte count when
# SeedVR2PostProcessing estimates how many frames fit in one
# color-correction chunk.
LAB_SCALE_MULTIPLIER = 13
WAVELET_SCALE_MULTIPLIER = 10
ADAIN_SCALE_MULTIPLIER = 6
# Fraction of the reported free device memory that the color-correction
# chunk-size estimate is allowed to budget for.
COLOR_CORRECTION_MEMORY_HEADROOM = 0.75
# Private sentinel for getattr default: distinguishes "attribute missing"
# from "attribute present but None" so the failure message is accurate.
_ATTR_MISSING = object()
def _seedvr2_auto_chunk_attempts(t_latent, t_pixel, frames_per_chunk):
"""Return stricter 4n+1 frame chunk sizes for auto OOM retries."""
attempts = [frames_per_chunk]
current_chunk_latent = (
t_latent if t_pixel <= frames_per_chunk
else (frames_per_chunk - 1) // 4 + 1
)
current_chunk_count = max(1, math.ceil(t_latent / current_chunk_latent))
seen = {frames_per_chunk}
for target_chunks in range(max(2, current_chunk_count + 1), t_latent + 1):
chunk_latent = max(1, math.ceil(t_latent / target_chunks))
candidate = 4 * (chunk_latent - 1) + 1
if candidate in seen:
continue
if candidate >= attempts[-1]:
continue
attempts.append(candidate)
seen.add(candidate)
return attempts
def _resolve_seedvr2_diffusion_model(model):
    """Return ``model.model.diffusion_model`` or fail with a precise error.

    Four distinct failure modes are reported, each with its own message that
    starts with ``_SEEDVR2_INVALID_MODEL_MSG_PREFIX``:

    * ``model`` has no ``model`` attribute,
    * ``model.model`` is ``None``,
    * ``model.model`` has no ``diffusion_model`` attribute,
    * ``model.model.diffusion_model`` is ``None``.

    The ``_ATTR_MISSING`` sentinel is what lets "missing" be told apart from
    "present but None".

    Raises:
        RuntimeError: when any of the four conditions above holds.
    """
    inner = getattr(model, "model", _ATTR_MISSING)
    if inner is _ATTR_MISSING:
        message = (
            f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input has no 'model' attribute "
            f"(got type {type(model).__name__})."
        )
    elif inner is None:
        message = (
            f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input.model is None "
            f"(input type {type(model).__name__})."
        )
    else:
        diffusion_model = getattr(inner, "diffusion_model", _ATTR_MISSING)
        if diffusion_model is _ATTR_MISSING:
            message = (
                f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model' has no "
                f"'diffusion_model' attribute (got type {type(inner).__name__})."
            )
        elif diffusion_model is None:
            message = (
                f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model.diffusion_model' "
                f"is None (model.model type {type(inner).__name__})."
            )
        else:
            return diffusion_model
    raise RuntimeError(message)
def _apply_rope_freqs_float32_cast(diffusion_model):
"""Cast every nested module's ``rope.freqs`` parameter data to ``float32``
when it is not already in float32. Idempotency is per-tensor by dtype
check, NOT a per-instance sentinel attribute — a sentinel would survive
Comfy's dynamic model unload/reload cycle while ``rope.freqs`` itself
is restored from the archived dtype, leaving RoPE running in fp16/bf16
on subsequent calls. The dtype check makes the cast self-correcting
against weight-restore lifecycle events. Iteration cost is one walk of
the diffusion-model module tree per ``execute()`` call (microseconds).
"""
for module in diffusion_model.modules():
if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'):
if module.rope.freqs.data.dtype != torch.float32:
module.rope.freqs.data = module.rope.freqs.data.to(torch.float32)
def clear_vae_memory(vae_model):
    """Reset every submodule's ``memory`` attribute and release cached memory.

    After the attributes are cleared, a garbage-collection pass is run and
    ComfyUI's ``soft_empty_cache`` is invoked.
    """
    with_memory = (m for m in vae_model.modules() if hasattr(m, "memory"))
    for holder in with_memory:
        holder.memory = None
    gc.collect()
    comfy.model_management.soft_empty_cache()
def expand_dims(tensor, ndim):
    """Append trailing size-1 axes until ``tensor`` has ``ndim`` dimensions.

    When ``tensor`` already has ``ndim`` or more dimensions its shape is
    left as is.
    """
    missing_axes = ndim - tensor.ndim
    target_shape = tuple(tensor.shape) + (1,) * missing_axes
    return tensor.reshape(target_shape)
def get_conditions(latent, latent_blur):
    """Build the per-frame condition tensor for one ``(t, h, w, c)`` latent.

    The result has ``c + 1`` channels on the last axis: the first ``c`` hold
    ``latent_blur`` and the final one is a mask channel filled with ones. The
    output uses the device and dtype of ``latent``.
    """
    t, h, w, c = latent.shape
    # Start from all ones so the trailing mask channel is already set.
    cond = latent.new_ones((t, h, w, c + 1))
    cond[..., :c] = latent_blur
    return cond
def timestep_transform(timesteps, latents_shapes):
    """Apply a resolution-dependent shift to timesteps on a 0..1000 scale.

    ``latents_shapes`` holds one ``(frames, height, width)`` row per sample
    in latent units. Rows are converted to pixel units (temporal factor 4,
    spatial factor 8), a shift factor is looked up on a linear ramp (one
    ramp for single-frame inputs, another for multi-frame inputs), and the
    timesteps are warped as ``shift * t / (1 + (shift - 1) * t)``.
    """
    temporal_factor = 4
    spatial_factor = 8
    frames = (latents_shapes[:, 0] - 1) * temporal_factor + 1
    heights = latents_shapes[:, 1] * spatial_factor
    widths = latents_shapes[:, 2] * spatial_factor

    def line_through(x1, y1, x2, y2):
        """Return the linear function passing through both points."""
        slope = (y2 - y1) / (x2 - x1)
        intercept = y1 - slope * x1

        def evaluate(x):
            return slope * x + intercept

        return evaluate

    image_ramp = line_through(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2)
    video_ramp = line_through(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0)
    video_shift = video_ramp(heights * widths * frames)
    image_shift = image_ramp(heights * widths)
    shift = torch.where(frames > 1, video_shift, image_shift).to(timesteps.device)

    # Warp in the normalized 0..1 domain, then return to the 0..1000 scale.
    full_scale = 1000.0
    normalized = timesteps / full_scale
    warped = shift * normalized / (1 + (shift - 1) * normalized)
    return warped * full_scale
def inter(x_0, x_T, t):
    """Linearly interpolate between ``x_0`` (at t=0) and ``x_T`` (at t=1000).

    ``t`` is given trailing singleton axes so it broadcasts against ``x_0``.
    """
    t = expand_dims(t, x_0.ndim)
    horizon = 1000.0
    weight_x0 = 1 - (t / horizon)
    weight_xT = t / horizon
    return weight_x0 * x_0 + weight_xT * x_T
def area_resize(image, max_area):
    """Bicubically resize ``image`` so its area is about ``max_area`` pixels.

    Both sides are scaled by the same factor, so the aspect ratio is kept up
    to rounding of the target height and width.
    """
    src_h, src_w = image.shape[-2:]
    factor = math.sqrt(max_area / (src_h * src_w))
    target_size = (round(src_h * factor), round(src_w * factor))
    return TVF.resize(
        image,
        size=target_size,
        interpolation=InterpolationMode.BICUBIC,
    )
def div_pad(image, factor):
    """Zero-pad the bottom and right so H and W become multiples of ``factor``.

    ``factor`` is a ``(height_factor, width_factor)`` pair. The input is
    returned unchanged when no padding is needed, or when it is not a
    ``torch.Tensor``.
    """
    height_factor, width_factor = factor
    height, width = image.shape[-2:]
    extra_rows = -height % height_factor
    extra_cols = -width % width_factor
    needs_padding = extra_rows != 0 or extra_cols != 0
    if needs_padding and isinstance(image, torch.Tensor):
        # F.pad order for the last two dims: (left, right, top, bottom).
        image = torch.nn.functional.pad(
            image, (0, extra_cols, 0, extra_rows), mode='constant', value=0.0,
        )
    return image
def cut_videos(videos):
    """Pad the frame axis (dim 1) up to the next ``4n+1`` length.

    Padding repeats the last frame. Inputs whose frame count is already of
    the form ``4n+1`` (including a single frame) are returned unchanged.
    """
    frame_count = videos.size(1)
    # Frames needed to reach the next 4n+1 length (0 when already aligned).
    missing = -(frame_count - 1) % 4
    if missing == 0:
        return videos
    last_frame = videos[:, -1].unsqueeze(1)
    return torch.cat([videos] + [last_frame] * missing, dim=1)
def side_resize(image, size):
    """Bicubically resize ``image`` to ``size`` via torchvision.

    Antialiasing is enabled except for tensors that live on an ``mps``
    device.
    """
    on_mps = isinstance(image, torch.Tensor) and image.device.type == 'mps'
    return TVF.resize(image, size, InterpolationMode.BICUBIC, antialias=not on_mps)
def _seedvr2_input_shorter_edge(images, node_name):
if images.dim() == 4:
return min(images.shape[1], images.shape[2])
if images.dim() == 5:
return min(images.shape[2], images.shape[3])
raise ValueError(
f"{node_name}: expected 4-D or 5-D IMAGE tensor, "
f"got shape {tuple(images.shape)}"
)
def _seedvr2_resize_and_pad(images, upscaled_shorter_edge, node_name):
    """Shared resize pipeline behind both SeedVR2 resize nodes.

    Steps: resize every frame so its shorter edge equals
    ``upscaled_shorter_edge``, clamp to ``[0, 1]``, zero-pad H and W to
    multiples of 16, then pad the frame axis to a ``4n+1`` length.

    Returns:
        ``io.NodeOutput`` carrying the processed ``(B, T, H, W, C)`` frames,
        the untouched input tensor, and ``upscaled_shorter_edge``.

    Raises:
        ValueError: if ``upscaled_shorter_edge`` is below 2, or ``images`` is
            neither 4-D nor 5-D.
    """
    if upscaled_shorter_edge < 2:
        raise ValueError(
            f"{node_name}: resolved upscaled_shorter_edge must be at least 2 pixels; "
            f"got {upscaled_shorter_edge}."
        )
    original_image = images
    rank = images.dim()
    if rank == 4:
        # A 4-D IMAGE is a frame sequence (frames, H, W, C); treat it as a
        # single video by adding a leading batch axis.
        images = images.unsqueeze(0)
    elif rank != 5:
        raise ValueError(
            f"{node_name}: expected 4-D or 5-D IMAGE tensor, "
            f"got shape {tuple(images.shape)}"
        )

    # Channel-first, with batch and time folded together for per-frame ops.
    b, t, h, w, c = images.shape
    frames = images.permute(0, 1, 4, 2, 3).reshape(b * t, c, h, w)
    frames = side_resize(frames, upscaled_shorter_edge)
    frames = torch.clamp(frames, 0.0, 1.0)
    frames = div_pad(frames, (16, 16))

    new_h, new_w = frames.shape[-2:]
    video = cut_videos(frames.reshape(b, t, c, new_h, new_w))
    images_bthwc = rearrange(video, "b t c h w -> b t h w c")
    return io.NodeOutput(images_bthwc, original_image, upscaled_shorter_edge)
class SeedVR2Resize(io.ComfyNode):
    """Resize node whose target shorter edge is a multiple of the input's."""

    @classmethod
    def define_schema(cls):
        """Declare the ``images`` and ``multiplier`` inputs and three outputs."""
        schema_inputs = [
            io.Image.Input("images"),
            io.Float.Input("multiplier", default=4.0, min=0.01),
        ]
        schema_outputs = [
            io.Image.Output("input_pixels"),
            io.Image.Output("original_image"),
            io.Int.Output("upscaled_shorter_edge"),
        ]
        return io.Schema(
            node_id="SeedVR2Resize",
            category="image/video",
            inputs=schema_inputs,
            outputs=schema_outputs,
        )

    @classmethod
    def execute(cls, images, multiplier=4.0):
        """Scale the shorter edge by ``multiplier`` and run the resize pipeline.

        Raises:
            ValueError: if ``multiplier`` is not positive, or the rounded
                target shorter edge is below 2 pixels.
        """
        node_name = "SeedVR2Resize"
        if multiplier <= 0:
            raise ValueError(
                f"SeedVR2Resize: multiplier must be > 0; got {multiplier}."
            )
        source_edge = _seedvr2_input_shorter_edge(images, node_name)
        upscaled_shorter_edge = int(round(source_edge * multiplier))
        if upscaled_shorter_edge < 2:
            raise ValueError(
                "SeedVR2Resize: multiplier resolved upscaled_shorter_edge "
                f"to {upscaled_shorter_edge}; use a multiplier that resolves "
                "to at least 2 pixels."
            )
        return _seedvr2_resize_and_pad(images, upscaled_shorter_edge, node_name)
class SeedVR2ResizeAdvanced(io.ComfyNode):
    """Resize node that takes the target shorter edge directly in pixels."""

    @classmethod
    def define_schema(cls):
        """Declare the ``images`` and ``shorter_edge`` inputs and three outputs."""
        schema_inputs = [
            io.Image.Input("images"),
            io.Int.Input("shorter_edge", default=1280, min=2),
        ]
        schema_outputs = [
            io.Image.Output("input_pixels"),
            io.Image.Output("original_image"),
            io.Int.Output("upscaled_shorter_edge"),
        ]
        return io.Schema(
            node_id="SeedVR2ResizeAdvanced",
            category="image/video",
            inputs=schema_inputs,
            outputs=schema_outputs,
        )

    @classmethod
    def execute(cls, images, shorter_edge):
        """Run the shared resize pipeline with ``shorter_edge`` as the target."""
        node_name = "SeedVR2ResizeAdvanced"
        return _seedvr2_resize_and_pad(images, shorter_edge, node_name)
class SeedVR2PostProcessing(io.ComfyNode):
    """Crop a decoded SeedVR2 result to the reference size and optionally
    color-correct it against the original input image.
    """
    @classmethod
    def define_schema(cls):
        """Declare the node inputs (decoded frames, original frames, target
        shorter edge, color-correction method) and its single IMAGE output.
        """
        return io.Schema(
            node_id="SeedVR2PostProcessing",
            category="image/video",
            inputs=[
                io.Image.Input("decoded"),
                io.Image.Input("original_image"),
                io.Int.Input("upscaled_shorter_edge", min=2, force_input=True),
                io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab"),
            ],
            outputs=[io.Image.Output()],
        )
    @classmethod
    def execute(cls, decoded, original_image, upscaled_shorter_edge, color_correction_method):
        """Crop ``decoded`` to the reference extent and apply color correction.

        Both inputs are normalized to 5-D ``(B, T, H, W, C)``. The decoded
        tensor is cropped to the batch/time extent shared with the original
        and to at most the height/width the original would have after a
        shorter-edge resize to ``upscaled_shorter_edge``. The result is then
        cropped to even height/width and returned as 4-D if ``decoded`` was.

        Raises:
            ValueError: for an invalid ``upscaled_shorter_edge``, a tensor
                that is neither 4-D nor 5-D, or an unknown
                ``color_correction_method``.
        """
        cls._validate_upscaled_shorter_edge(upscaled_shorter_edge)
        decoded_5d, decoded_was_4d = cls._as_bthwc(decoded)
        original_5d, _ = cls._as_bthwc(original_image)
        decoded_5d = cls._restore_reference_batch_time(decoded_5d, original_5d)
        # Batch/time extent common to both tensors.
        b = min(decoded_5d.shape[0], original_5d.shape[0])
        t = min(decoded_5d.shape[1], original_5d.shape[1])
        reference_h, reference_w = cls._resized_shorter_edge_dims(
            original_5d.shape[2], original_5d.shape[3], upscaled_shorter_edge,
        )
        decoded_5d = decoded_5d[:b, :t, :, :, :]
        # Never crop beyond what was actually decoded.
        target_h = min(decoded_5d.shape[2], reference_h)
        target_w = min(decoded_5d.shape[3], reference_w)
        decoded_5d = decoded_5d[:, :, :target_h, :target_w, :]
        if color_correction_method in ("lab", "wavelet", "adain"):
            # Build a reference with the same (b, t, target_h, target_w) extent.
            reference_5d = cls._resize_original_reference(original_image, upscaled_shorter_edge)
            reference_5d = reference_5d[:b, :t, :, :, :]
            reference_5d = cls._resize_reference(reference_5d, target_h, target_w)
            output_device = decoded_5d.device
            # Color transfer runs on values mapped by x * 2 - 1.
            decoded_raw = cls._to_seedvr2_raw(decoded_5d)
            reference_raw = cls._to_seedvr2_raw(reference_5d)
            decoded_flat = rearrange(decoded_raw, "b t h w c -> (b t) c h w")
            reference_flat = rearrange(reference_raw, "b t h w c -> (b t) c h w")
            output = cls._color_transfer_chunked(
                decoded_flat, reference_flat, output_device, color_correction_method,
            )
            output = rearrange(output, "(b t) c h w -> b t h w c", b=b, t=t)
            # Undo the x * 2 - 1 mapping and clamp to [0, 1].
            output = output.add(1.0).div(2.0).clamp(0.0, 1.0)
        elif color_correction_method == "none":
            output = decoded_5d
        else:
            raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
        # Drop a trailing row/column so height and width are both even.
        h2 = output.shape[-3] - (output.shape[-3] % 2)
        w2 = output.shape[-2] - (output.shape[-2] % 2)
        output = output[:, :, :h2, :w2, :]
        if decoded_was_4d:
            # Fold batch and time back into one leading axis.
            output = output.reshape(-1, output.shape[-3], output.shape[-2], output.shape[-1])
        return io.NodeOutput(output)
    @staticmethod
    def _as_bthwc(images):
        """Return ``(images_5d, was_4d)``; a 4-D input gets a leading axis of
        size 1. Raises ``ValueError`` for any other rank.
        """
        if images.ndim == 4:
            return images.unsqueeze(0), True
        if images.ndim == 5:
            return images, False
        raise ValueError(
            f"SeedVR2PostProcessing: expected 4-D or 5-D IMAGE tensor, got shape {tuple(images.shape)}"
        )
    @staticmethod
    def _restore_reference_batch_time(decoded, reference):
        """Regroup a batch-1 ``decoded`` into the reference's batch count.

        The reshape happens only when ``decoded`` has batch size 1, its time
        axis divides evenly by the reference batch size, and the per-batch
        time is at least the reference time; otherwise ``decoded`` is
        returned unchanged.
        """
        # NOTE(review): presumably undoes an upstream flattening of
        # (batch, time) into a single time axis — confirm against producer.
        if decoded.shape[0] != 1:
            return decoded
        ref_b, ref_t = reference.shape[:2]
        if ref_b < 1 or decoded.shape[1] % ref_b != 0:
            return decoded
        decoded_t = decoded.shape[1] // ref_b
        if decoded_t < ref_t:
            return decoded
        return decoded.reshape(ref_b, decoded_t, decoded.shape[2], decoded.shape[3], decoded.shape[4])
    @staticmethod
    def _to_seedvr2_raw(images):
        """Map values by ``x * 2 - 1`` (so ``[0, 1]`` becomes ``[-1, 1]``)."""
        return images.mul(2.0).sub(1.0)
    @staticmethod
    def _validate_upscaled_shorter_edge(upscaled_shorter_edge):
        """Raise ``ValueError`` unless the value is an ``int`` of at least 2."""
        if not isinstance(upscaled_shorter_edge, int) or upscaled_shorter_edge < 2:
            raise ValueError(
                "SeedVR2PostProcessing: upscaled_shorter_edge must be an integer "
                f"of at least 2 pixels; got {upscaled_shorter_edge!r}."
            )
    @staticmethod
    def _resized_shorter_edge_dims(height, width, upscaled_shorter_edge):
        """Return ``(height, width)`` after scaling the shorter side to
        ``upscaled_shorter_edge``; the longer side is truncated with ``int``.
        """
        if height <= width:
            return upscaled_shorter_edge, int(upscaled_shorter_edge * width / height)
        return int(upscaled_shorter_edge * height / width), upscaled_shorter_edge
    @classmethod
    def _resize_original_reference(cls, original, upscaled_shorter_edge):
        """Resize every frame of ``original`` with ``side_resize`` to the
        target shorter edge, clamp to ``[0, 1]``, and return it as 5-D
        ``(B, T, H, W, C)``.
        """
        original_5d, _ = cls._as_bthwc(original)
        b, t = original_5d.shape[:2]
        original_flat = rearrange(original_5d, "b t h w c -> (b t) c h w")
        resized_flat = side_resize(original_flat, upscaled_shorter_edge).clamp(0.0, 1.0)
        return rearrange(resized_flat, "(b t) c h w -> b t h w c", b=b, t=t)
    @staticmethod
    def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn):
        """Run ``transfer_fn(decoded, reference)`` on the device reported by
        ``comfy.model_management.vae_device()`` and move the result to
        ``output_device``.
        """
        color_device = comfy.model_management.vae_device()
        decoded_flat = decoded_flat.to(device=color_device)
        reference_flat = reference_flat.to(device=color_device)
        output = transfer_fn(decoded_flat, reference_flat)
        return output.to(device=output_device)
    @staticmethod
    def _lab_color_transfer_on_vae_device(decoded_flat, reference_flat, output_device):
        """Apply ``lab_color_transfer`` one frame at a time on the VAE device,
        collecting results in a buffer on ``output_device``.

        Raises ``ValueError`` when there are no frames.
        """
        color_device = comfy.model_management.vae_device()
        result = None
        for start in range(decoded_flat.shape[0]):
            # NOTE(review): the clone presumably protects the inputs from
            # in-place edits inside lab_color_transfer — confirm.
            decoded_frame = decoded_flat[start:start + 1].to(device=color_device).clone()
            reference_frame = reference_flat[start:start + 1].to(device=color_device).clone()
            output = lab_color_transfer(decoded_frame, reference_frame).to(device=output_device)
            if result is None:
                # Allocate lazily: shape and dtype come from the first output.
                result = torch.empty(
                    (decoded_flat.shape[0],) + tuple(output.shape[1:]),
                    device=output_device,
                    dtype=output.dtype,
                )
            result[start:start + 1].copy_(output)
        if result is None:
            raise ValueError("SeedVR2PostProcessing: LAB color correction requires at least one frame.")
        return result
    @classmethod
    def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method):
        """Run color transfer in frame chunks, halving the chunk size and
        retrying after each out-of-memory failure.

        Raises:
            RuntimeError: when the failure persists at a chunk size of one.
        """
        chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method)
        while True:
            next_chunk_size = None
            try:
                return cls._run_color_transfer_chunks(
                    decoded_flat, reference_flat, output_device, color_correction_method, chunk_size,
                )
            except Exception as e:
                # Going by the helper's name, anything that is not an OOM is
                # re-raised here and never reaches the retry logic below.
                comfy.model_management.raise_non_oom(e)
                if chunk_size <= 1:
                    raise RuntimeError(
                        "SeedVR2PostProcessing: color correction OOM at one frame; "
                        f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}."
                    ) from e
                next_chunk_size = max(1, chunk_size // 2)
            # Reached only after an OOM: release cached memory, then retry
            # with the smaller chunk size.
            comfy.model_management.soft_empty_cache()
            chunk_size = next_chunk_size
    @classmethod
    def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size):
        """Process frames in slices of ``chunk_size`` with the selected method
        and write each slice into one preallocated result tensor.

        Any method other than ``"lab"`` or ``"wavelet"`` falls through to
        AdaIN. Raises ``ValueError`` when there are no frames.
        """
        result = None
        for start in range(0, decoded_flat.shape[0], chunk_size):
            end = min(start + chunk_size, decoded_flat.shape[0])
            decoded_chunk = decoded_flat[start:end]
            reference_chunk = reference_flat[start:end]
            if color_correction_method == "lab":
                output = cls._lab_color_transfer_on_vae_device(decoded_chunk, reference_chunk, output_device)
            elif color_correction_method == "wavelet":
                output = cls._color_transfer_on_vae_device(
                    decoded_chunk, reference_chunk, output_device, wavelet_color_transfer,
                )
            else:
                output = cls._color_transfer_on_vae_device(
                    decoded_chunk, reference_chunk, output_device, adain_color_transfer,
                )
            if result is None:
                # Allocate lazily: shape and dtype come from the first chunk.
                result = torch.empty(
                    (decoded_flat.shape[0],) + tuple(output.shape[1:]),
                    device=output_device,
                    dtype=output.dtype,
                )
            result[start:end].copy_(output)
        if result is None:
            raise ValueError("SeedVR2PostProcessing: color correction requires at least one frame.")
        return result
    @classmethod
    def _estimate_color_correction_chunk_size(cls, decoded_flat, color_correction_method):
        """Estimate how many frames fit in the memory budget for one chunk.

        Per-frame cost is ``H * W * C * max(element_size, 4) * multiplier``;
        the budget is the free memory on the VAE device scaled by
        ``COLOR_CORRECTION_MEMORY_HEADROOM``. The result is clamped to
        ``[1, frames]``.
        """
        multiplier = cls._color_correction_memory_multiplier(color_correction_method)
        frames = decoded_flat.shape[0]
        _, channels, height, width = decoded_flat.shape
        # Budget at least 4 bytes per element even for narrower dtypes.
        dtype_bytes = max(decoded_flat.element_size(), 4)
        bytes_per_frame = height * width * channels * dtype_bytes * multiplier
        if bytes_per_frame <= 0:
            return frames
        color_device = comfy.model_management.vae_device()
        free_memory = comfy.model_management.get_free_memory(color_device)
        chunk_size = int((free_memory * COLOR_CORRECTION_MEMORY_HEADROOM) // bytes_per_frame)
        return max(1, min(frames, chunk_size))
    @staticmethod
    def _color_correction_memory_multiplier(color_correction_method):
        """Return the memory multiplier constant for the given method; raises
        ``ValueError`` for an unknown method.
        """
        if color_correction_method == "lab":
            return LAB_SCALE_MULTIPLIER
        if color_correction_method == "wavelet":
            return WAVELET_SCALE_MULTIPLIER
        if color_correction_method == "adain":
            return ADAIN_SCALE_MULTIPLIER
        raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
    @staticmethod
    def _resize_reference(reference, height, width):
        """Bicubically resize a 5-D ``(B, T, H, W, C)`` reference to exactly
        ``(height, width)``; returned unchanged when it already matches.
        Antialiasing is disabled for tensors on an ``mps`` device.
        """
        if reference.shape[2] == height and reference.shape[3] == width:
            return reference
        b, t = reference.shape[:2]
        reference_flat = rearrange(reference, "b t h w c -> (b t) c h w")
        resized = TVF.resize(
            reference_flat,
            size=(height, width),
            interpolation=InterpolationMode.BICUBIC,
            antialias=not (isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"),
        )
        return rearrange(resized, "(b t) c h w -> b t h w c", b=b, t=t)
class SeedVR2Conditioning(io.ComfyNode):
    """Turn a SeedVR2 VAE latent into sampler-ready conditioning and latent."""

    @classmethod
    def define_schema(cls):
        """Declare the model/latent inputs and the four node outputs."""
        schema_inputs = [
            io.Model.Input("model"),
            io.Latent.Input("vae_conditioning", display_name="LATENT"),
        ]
        schema_outputs = [
            io.Model.Output(display_name = "model"),
            io.Conditioning.Output(display_name = "positive"),
            io.Conditioning.Output(display_name = "negative"),
            io.Latent.Output(display_name = "latent"),
        ]
        return io.Schema(
            node_id="SeedVR2Conditioning",
            category="image/video",
            inputs=schema_inputs,
            outputs=schema_outputs,
        )

    @classmethod
    def execute(cls, model, vae_conditioning) -> io.NodeOutput:
        """Validate the latent and model, then build conditioning outputs.

        The latent must be 5-D channel-first. It is collapsed to
        ``(B, C*T, H, W)`` for the sampler. The per-frame condition tensor
        (latent channels plus one mask channel) is collapsed the same way and
        attached to both the positive and negative conditioning entries.

        Raises:
            ValueError: if the latent is not 5-D or looks channel-last.
            RuntimeError: if the model lacks the expected structure, or both
                conditioning buffers are entirely zero.
        """
        vae_conditioning = vae_conditioning["samples"]
        if vae_conditioning.ndim != 5:
            raise ValueError(
                "SeedVR2Conditioning expects a 5-D VAE latent in Comfy "
                f"channel-first layout; got shape {tuple(vae_conditioning.shape)}."
            )
        looks_channel_last = (
            vae_conditioning.shape[-1] == _SEEDVR2_LATENT_CHANNELS
            and vae_conditioning.shape[1] != _SEEDVR2_LATENT_CHANNELS
        )
        if looks_channel_last:
            raise ValueError(
                "SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy "
                f"channel-first layout (B, {_SEEDVR2_LATENT_CHANNELS}, T, H, W); "
                f"got channel-last shape {tuple(vae_conditioning.shape)}."
            )
        # Work channel-last, (B, T, H, W, C), while building the condition.
        vae_conditioning = vae_conditioning.movedim(1, -1).contiguous()

        model_patcher = model
        model = _resolve_seedvr2_diffusion_model(model_patcher)
        pos_cond = model.positive_conditioning
        neg_cond = model.negative_conditioning

        # Fail loudly when both conditioning buffers are still all zeros.
        # That is the state left behind when a DiT-only export without the
        # ``positive_conditioning`` / ``negative_conditioning`` keys is
        # loaded non-strictly over zero-initialised buffers. Checking once
        # here avoids repeating the test on every sampling step and keeps
        # the structural resolver free of content checks.
        both_zero = all(
            buffer.float().abs().sum().item() == 0
            for buffer in (pos_cond, neg_cond)
        )
        if both_zero:
            raise RuntimeError(
                f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning "
                f"and negative_conditioning buffers are zero-valued — model "
                f"file appears to be a numz-format DiT-only export missing "
                f"the SeedVR2 conditioning tensors. "
                f"Re-bake the file with ``positive_conditioning`` (58, 5120) "
                f"and ``negative_conditioning`` (64, 5120) keys at top level, "
                f"or load via CheckpointLoaderSimple from a bundled "
                f"checkpoint."
            )
        _apply_rope_freqs_float32_cast(model)

        per_sample = [get_conditions(sample, sample) for sample in vae_conditioning]
        condition = rearrange(
            torch.stack(per_sample).movedim(-1, 1), "b c t h w -> b (c t) h w",
        )
        latent = rearrange(
            vae_conditioning.movedim(-1, 1), "b c t h w -> b (c t) h w",
        )
        negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
        positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
        return io.NodeOutput(model_patcher, positive, negative, {"samples": latent})
# SeedVR2 latent / conditioning channel constants. The SeedVR2 conditioning
# stage collapses ``(B, C, T, H, W) -> (B, C*T, H, W)`` for both the latent
# (C=16) and the per-frame condition tensor (C=17 = 16 latent + 1 mask), as
# required by ``NaDiT.forward`` which un-collapses via
# ``view(B, 16, -1, H, W)`` and ``view(B, 17, -1, H, W)`` respectively.
# The temporal slicing helpers in this module use these values to recover T
# from a collapsed channel dimension.
_SEEDVR2_LATENT_CHANNELS = 16
_SEEDVR2_CONDITION_CHANNELS = 17
def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int,
t_end: int, channels: int) -> torch.Tensor:
"""Slice a SeedVR2-style collapsed 4D tensor ``(B, channels*T, H, W)``
along the latent T axis, returning ``(B, channels*(t_end - t_start), H, W)``.
Reshape -> slice -> ``.contiguous()`` -> re-collapse. ``reshape`` is
used for the un-collapse so non-contiguous incoming tensors from
cropping or slicing nodes are accepted. The
``.contiguous()`` is mandatory: T-axis slicing of a 5D tensor produces a
non-contiguous view, and the subsequent re-collapse requires contiguous
storage.
"""
B, CT, H, W = tensor_4d.shape
if CT % channels != 0:
raise ValueError(
f"_slice_collapsed_4d_along_t: collapsed channel dim {CT} is not "
f"divisible by channels={channels}; tensor shape {tuple(tensor_4d.shape)}."
)
T = CT // channels
if not (0 <= t_start < t_end <= T):
raise ValueError(
f"_slice_collapsed_4d_along_t: slice [{t_start}:{t_end}] out of "
f"range for T={T}."
)
new_T = t_end - t_start
sliced = tensor_4d.reshape(B, channels, T, H, W)[:, :, t_start:t_end, :, :].contiguous()
return sliced.reshape(B, channels * new_T, H, W)
def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int):
    """Return a copy of ``cond_list`` whose ``"condition"`` tensors are
    restricted to latent frames ``[t_start, t_end)``.

    Each entry is a ``[text_cond, options]`` pair. Only
    ``options["condition"]`` — a collapsed ``(B, 17*T, H, W)`` tensor — is
    sliced; the text tensor and every other option are carried over as they
    are. Entries without a ``"condition"`` key are forwarded verbatim. The
    input list and its option dicts are never mutated.
    """
    def slice_entry(entry):
        text_cond, options = entry[0], entry[1]
        if "condition" not in options:
            return entry
        # Shallow-copy so the caller's dict keeps the full-length tensor.
        new_options = options.copy()
        new_options["condition"] = _slice_collapsed_4d_along_t(
            new_options["condition"], t_start, t_end,
            _SEEDVR2_CONDITION_CHANNELS,
        )
        return [text_cond, new_options]

    return [slice_entry(entry) for entry in cond_list]
def _slice_seedvr2_noise_mask_along_t(noise_mask: torch.Tensor,
                                      samples_4d: torch.Tensor,
                                      t_start: int,
                                      t_end: int):
    """Temporally slice a noise mask only if it is already fully expanded.

    A mask whose rank and channel dimension match ``samples_4d`` is treated
    as a collapsed ``(B, 16*T, H, W)`` mask and sliced to frames
    ``[t_start, t_end)``. Any other mask (for example a ``(B, 1, H, W)``
    mask that the sampler expands later) is returned untouched.
    """
    matches_collapsed_latent = (
        noise_mask.ndim == samples_4d.ndim
        and noise_mask.shape[1] == samples_4d.shape[1]
    )
    if not matches_collapsed_latent:
        return noise_mask
    return _slice_collapsed_4d_along_t(
        noise_mask, t_start, t_end, _SEEDVR2_LATENT_CHANNELS,
    )
def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor:
"""Concatenate a list of SeedVR2-style collapsed 4D tensors
``(B, channels*T_i, H, W)`` along the latent T axis. Each chunk is
un-collapsed to 5D, concatenated on ``dim=2``, then re-collapsed to 4D.
"""
if len(chunks_4d) == 0:
raise ValueError("_concat_chunks_along_t: empty chunk list.")
fives = []
for ch in chunks_4d:
B, CT, H, W = ch.shape
if CT % channels != 0:
raise ValueError(
f"_concat_chunks_along_t: chunk shape {tuple(ch.shape)} "
f"channel dim {CT} not divisible by channels={channels}."
)
T = CT // channels
fives.append(ch.reshape(B, channels, T, H, W))
cat = torch.cat(fives, dim=2).contiguous()
B, C, T_total, H, W = cat.shape
return cat.reshape(B, C * T_total, H, W)
def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor:
"""Build a 1D crossfade weight tensor of length ``overlap`` for the
*previous* chunk's contribution; the current chunk's weight is
``1 - w_prev``.
Mirrors the numz ``blend_overlapping_frames`` shape
(AInVFX/numz fork ``src/core/generation_utils.py``,
``blend_overlapping_frames``): a Hann window with a ``[1/3, 2/3]``
dead-band when ``overlap >= 3``, and a plain linear ramp when
``overlap < 3`` (the dead-band would collapse the transition for
very small overlap counts). The numz reference operates on
pixel-space tensors ``[overlap, H, W, C]``; this 1D form is
reshaped by the caller to broadcast across the latent's
``(B, C, T_overlap, H, W)`` axes.
"""
if overlap < 1:
raise ValueError(
f"_hann_blend_weights_1d: overlap must be >= 1; got {overlap}."
)
if overlap >= 3:
t = torch.linspace(0.0, 1.0, steps=overlap, device=device, dtype=dtype)
blend_start = 1.0 / 3.0
blend_end = 2.0 / 3.0
u = ((t - blend_start) / (blend_end - blend_start)).clamp(0.0, 1.0)
return 0.5 + 0.5 * torch.cos(torch.pi * u)
return torch.linspace(1.0, 0.0, steps=overlap, device=device, dtype=dtype)
def _blend_overlap_region(prev_tail_5d: torch.Tensor,
                          cur_head_5d: torch.Tensor) -> torch.Tensor:
    """Crossfade two equally-shaped ``(B, C, T_overlap, H, W)`` tensors along T.

    ``prev_tail_5d`` is weighted by the descending ramp from
    ``_hann_blend_weights_1d``; ``cur_head_5d`` gets the complementary
    weight. Weights are created on the device and in the dtype of
    ``prev_tail_5d``; the caller must ensure ``cur_head_5d`` is compatible.

    Raises:
        ValueError: if the two shapes differ.
    """
    if prev_tail_5d.shape != cur_head_5d.shape:
        raise ValueError(
            f"_blend_overlap_region: shape mismatch "
            f"prev {tuple(prev_tail_5d.shape)} vs "
            f"cur {tuple(cur_head_5d.shape)}."
        )
    overlap = int(prev_tail_5d.shape[2])
    # Shape (1, 1, overlap, 1, 1) broadcasts over B, C, H and W.
    fade_out = _hann_blend_weights_1d(
        overlap, prev_tail_5d.device, prev_tail_5d.dtype,
    ).view(1, 1, overlap, 1, 1)
    fade_in = 1.0 - fade_out
    return prev_tail_5d * fade_out + cur_head_5d * fade_in
def _concat_chunks_with_overlap_blend(chunk_specs, channels: int,
                                      overlap_latent: int) -> torch.Tensor:
    """Concatenate temporally-overlapping chunks back into a single
    collapsed 4D tensor, blending overlap regions with a Hann/linear
    crossfade.
    ``chunk_specs`` is a list of ``(t_start, t_end, chunk_4d)`` tuples
    in source-latent T coordinates. ``overlap_latent == 0`` is a fast
    path that delegates to plain concatenation (and produces output
    bit-identical to ``_concat_chunks_along_t`` of the same chunks).
    The blend at each pair of adjacent chunks acts on the actual
    overlap region width ``min(prev_end - cur_start, current chunk
    length)``, which may be smaller than ``overlap_latent`` when the
    final chunk is a runt shorter than the configured overlap.

    Returns a ``(B, channels * T_total, H, W)`` tensor where ``T_total`` is
    the largest ``t_end`` among the chunks.

    Raises ``ValueError`` for an empty chunk list, a negative
    ``overlap_latent``, a chunk whose channel dim is not divisible by
    ``channels``, or a chunk whose T does not match its declared range.
    """
    if len(chunk_specs) == 0:
        raise ValueError("_concat_chunks_with_overlap_blend: empty chunk list.")
    if overlap_latent < 0:
        raise ValueError(
            f"_concat_chunks_with_overlap_blend: overlap_latent must be "
            f">= 0; got {overlap_latent}."
        )
    # Validate channel divisibility once and capture per-chunk T.
    chunk_5d = []
    for t_start, t_end, ch in chunk_specs:
        B, CT, H, W = ch.shape
        if CT % channels != 0:
            raise ValueError(
                f"_concat_chunks_with_overlap_blend: chunk shape "
                f"{tuple(ch.shape)} channel dim {CT} not divisible "
                f"by channels={channels}."
            )
        T = CT // channels
        if t_end - t_start != T:
            raise ValueError(
                f"_concat_chunks_with_overlap_blend: chunk T={T} mismatches "
                f"declared range [{t_start}:{t_end}]."
            )
        chunk_5d.append((t_start, t_end, ch.reshape(B, channels, T, H, W)))
    if overlap_latent == 0:
        # Fast path: pure concat in the caller-provided chunk order.
        return _concat_chunks_along_t(
            [c.reshape(c.shape[0], channels * c.shape[2], c.shape[3], c.shape[4])
             for _, _, c in chunk_5d],
            channels,
        )
    T_total = max(t_end for _, t_end, _ in chunk_5d)
    first_5d = chunk_5d[0][2]
    B = first_5d.shape[0]
    H = first_5d.shape[3]
    W = first_5d.shape[4]
    # The buffer is uninitialized; every T index must be written below.
    # NOTE(review): this relies on the chunks jointly covering
    # [0, T_total) without gaps — confirm against the caller.
    result = torch.empty(
        (B, channels, T_total, H, W),
        device=first_5d.device, dtype=first_5d.dtype,
    )
    # Exclusive end of the region written by the most recent chunk.
    filled_until = 0
    for i, (cs, ce, ct_5d) in enumerate(chunk_5d):
        chunk_T = int(ct_5d.shape[2])
        if i == 0:
            result[:, :, cs:ce, :, :] = ct_5d
            filled_until = ce
            continue
        # Overlap region width is bounded by both the previous fill
        # frontier and the current chunk's actual length (for runt
        # final chunks shorter than the configured overlap).
        overlap_len = min(filled_until - cs, chunk_T)
        if overlap_len > 0:
            # Blend what is already in the buffer with the head of this chunk.
            prev_tail = result[:, :, cs:cs + overlap_len, :, :].contiguous()
            cur_head = ct_5d[:, :, :overlap_len, :, :].contiguous()
            blended = _blend_overlap_region(prev_tail, cur_head)
            result[:, :, cs:cs + overlap_len, :, :] = blended
            # The rest of the chunk, past the overlap, is copied verbatim.
            tail_start = cs + overlap_len
            tail_end = ce
            if tail_end > tail_start:
                result[:, :, tail_start:tail_end, :, :] = (
                    ct_5d[:, :, overlap_len:, :, :]
                )
        else:
            # Disjoint chunks (overlap_latent set but this pair did not
            # actually overlap, e.g. step_latent equal to chunk_latent
            # in a degenerate config). Treat as concat.
            result[:, :, cs:ce, :, :] = ct_5d
        # NOTE(review): the frontier is set to this chunk's end even when
        # that is lower than the previous value (a chunk lying entirely
        # inside the filled region) — confirm that cannot precede another
        # chunk.
        filled_until = ce
    return result.contiguous().reshape(B, channels * T_total, H, W)
def _run_standard_sample(model, seed: int, steps: int, cfg: float,
                         sampler_name: str, scheduler: str,
                         positive, negative, latent_image: dict,
                         denoise: float) -> dict:
    """Sample the whole latent in a single ``comfy.sample.sample`` call.

    Follows the standard KSampler flow: fix empty latent channels, derive
    the noise from ``seed`` (honouring an optional ``batch_index``), sample
    with the optional ``noise_mask``, and return a copy of ``latent_image``
    holding the new ``samples`` without the ``downscale_ratio_spacial`` key.
    Intended for the case where the full sequence fits in one chunk, so
    chunking adds no overhead.
    """
    latent_samples = comfy.sample.fix_empty_latent_channels(
        model,
        latent_image["samples"],
        latent_image.get("downscale_ratio_spacial", None),
    )
    noise = comfy.sample.prepare_noise(
        latent_samples, seed, latent_image.get("batch_index", None),
    )
    sampled = comfy.sample.sample(
        model, noise, steps, cfg, sampler_name, scheduler,
        positive, negative, latent_samples,
        denoise=denoise,
        noise_mask=latent_image.get("noise_mask", None),
        seed=seed,
    )
    # Copy so the caller's latent dict is left untouched.
    result = latent_image.copy()
    result.pop("downscale_ratio_spacial", None)
    result["samples"] = sampled
    return result
class SeedVR2ProgressiveSampler(io.ComfyNode):
    """Sample a SeedVR2 latent in sequential temporal chunks.

    Meant as a stand-in for ``KSampler`` in SeedVR2 native workflows whose
    full sequence would run out of memory. The incoming latent uses
    SeedVR2's collapsed layout ``(B, 16*T, H, W)`` (collapsed upstream by
    ``SeedVR2Conditioning`` with ``rearrange(b c t h w -> b (c t) h w)``).
    The node cuts that tensor into windows along the temporal axis, calls
    ``comfy.sample.sample`` once per window in order, and stitches the
    per-window results back into one ``(B, 16*T_total, H, W)`` latent.

    Units: ``frames_per_chunk`` counts *pixel* frames and must sit on the
    SeedVR2 4n+1 lattice (imposed upstream by ``cut_videos`` and the VAE's
    ``temporal_downsample_factor=4``); a pixel chunk of ``F`` frames covers
    ``(F - 1) // 4 + 1`` latent frames. ``temporal_overlap`` counts
    *latent* frames.

    Determinism: noise is drawn once from the user seed for the whole
    sequence and sliced per window instead of re-seeding each window, so
    the noise samples do not depend on how the sequence is partitioned.
    Results can still differ at window boundaries because each window is
    denoised independently along T.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node: id, category, inputs and the latent output."""
        overlap_tooltip = (
            "Latent-frame overlap between adjacent chunks; blended with "
            "a Hann window (linear for overlap < 3). 0 = no blend, pure "
            "concat. Values >= the chunk's latent-frame length use the "
            "maximum valid overlap; 1 latent frame corresponds to ~4 "
            "pixel frames."
        )
        mode_tooltip = (
            "manual = use frames_per_chunk exactly; auto = retry only "
            "real OOM failures with progressively smaller temporal chunks."
        )
        inputs = [
            io.Model.Input("model"),
            io.Int.Input(
                "seed", default=0, min=0, max=0xffffffffffffffff,
                control_after_generate=True,
            ),
            io.Int.Input("steps", default=20, min=1, max=10000),
            io.Float.Input(
                "cfg", default=1.0, min=0.0, max=100.0, step=0.1,
                round=0.01,
            ),
            io.Combo.Input(
                "sampler_name", options=comfy.samplers.SAMPLER_NAMES,
            ),
            io.Combo.Input(
                "scheduler", options=comfy.samplers.SCHEDULER_NAMES,
            ),
            io.Conditioning.Input("positive"),
            io.Conditioning.Input("negative"),
            io.Latent.Input("latent_image"),
            io.Float.Input(
                "denoise", default=1.0, min=0.0, max=1.0, step=0.01,
            ),
            # step=4 starting from min=1 keeps the widget on 4n+1 values.
            io.Int.Input(
                "frames_per_chunk", default=21, min=1, max=16384, step=4,
            ),
            io.Int.Input(
                "temporal_overlap", default=0, min=0, max=16384,
                tooltip=overlap_tooltip,
            ),
            io.Combo.Input(
                "chunking_mode", options=["manual", "auto"],
                default="manual", tooltip=mode_tooltip,
            ),
        ]
        outputs = [io.Latent.Output()]
        return io.Schema(
            node_id="SeedVR2ProgressiveSampler",
            category="sampling",
            inputs=inputs,
            outputs=outputs,
        )

    @classmethod
    def execute(cls, model, seed, steps, cfg, sampler_name, scheduler,
                positive, negative, latent_image, denoise,
                frames_per_chunk, temporal_overlap,
                chunking_mode="manual") -> io.NodeOutput:
        """Run the sampler, chunking along T when the sequence is too long.

        Args:
            model, seed, steps, cfg, sampler_name, scheduler, positive,
            negative, denoise: forwarded to the underlying sampling call.
            latent_image: latent dict; ``"samples"`` must be the collapsed
                4D SeedVR2 latent. ``"batch_index"``, ``"noise_mask"`` and
                ``"downscale_ratio_spacial"`` are read when present.
            frames_per_chunk: chunk size in pixel frames (4n+1).
            temporal_overlap: overlap between adjacent chunks, in latent
                frames; clamped to ``chunk_latent - 1``.
            chunking_mode: ``"manual"`` uses ``frames_per_chunk`` as given;
                ``"auto"`` retries with smaller chunk sizes after an OOM.

        Returns:
            ``io.NodeOutput`` wrapping a copy of ``latent_image`` whose
            ``"samples"`` entry is the sampled latent and from which
            ``"downscale_ratio_spacial"`` has been removed.

        Raises:
            ValueError: ``frames_per_chunk`` is not 4n+1, the latent is
                not 4D, its channel dim is not a multiple of
                ``_SEEDVR2_LATENT_CHANNELS``, ``chunking_mode`` is
                unknown, or (multi-chunk path only) ``temporal_overlap``
                is negative.
            RuntimeError: auto mode hit OOM on every candidate chunk size.
        """
        # Reject off-lattice chunk sizes before touching the model: a
        # silent rounding would misalign chunk boundaries with 4n+1.
        on_lattice = (
            frames_per_chunk >= 1 and (frames_per_chunk - 1) % 4 == 0
        )
        if not on_lattice:
            raise ValueError(
                f"SeedVR2ProgressiveSampler: frames_per_chunk must be a "
                f"4n+1 pixel-frame count (1, 5, 9, 13, 17, 21, ...); got "
                f"{frames_per_chunk}."
            )

        collapsed = comfy.sample.fix_empty_latent_channels(
            model,
            latent_image["samples"],
            latent_image.get("downscale_ratio_spacial", None),
        )
        if collapsed.ndim != 4:
            raise ValueError(
                f"SeedVR2ProgressiveSampler: expected 4D collapsed latent "
                f"(B, 16*T, H, W); got shape {tuple(collapsed.shape)}."
            )

        collapsed_channels = collapsed.shape[1]
        if collapsed_channels % _SEEDVR2_LATENT_CHANNELS != 0:
            raise ValueError(
                f"SeedVR2ProgressiveSampler: collapsed channel dim "
                f"{collapsed_channels} is not divisible by SeedVR2 latent "
                f"channels {_SEEDVR2_LATENT_CHANNELS}; latent does not "
                f"appear to be SeedVR2-shaped."
            )
        T_latent = collapsed_channels // _SEEDVR2_LATENT_CHANNELS
        T_pixel = 4 * (T_latent - 1) + 1

        if chunking_mode not in ("manual", "auto"):
            raise ValueError(
                f"SeedVR2ProgressiveSampler: chunking_mode must be "
                f"'manual' or 'auto'; got {chunking_mode!r}."
            )

        if chunking_mode == "auto":
            # Auto mode re-enters this method in manual mode once per
            # candidate chunk size until one attempt does not OOM.
            attempts = _seedvr2_auto_chunk_attempts(
                T_latent, T_pixel, frames_per_chunk,
            )
            final_index = len(attempts) - 1
            base_kwargs = dict(
                model=model, seed=seed, steps=steps, cfg=cfg,
                sampler_name=sampler_name, scheduler=scheduler,
                positive=positive, negative=negative,
                latent_image=latent_image, denoise=denoise,
                temporal_overlap=temporal_overlap,
            )
            for attempt_index, attempt_size in enumerate(attempts):
                try:
                    return cls.execute(
                        frames_per_chunk=attempt_size,
                        chunking_mode="manual",
                        **base_kwargs,
                    )
                except Exception as exc:
                    # Hand the error to ComfyUI's OOM filter; it is
                    # expected to re-raise anything that is not an OOM.
                    comfy.model_management.raise_non_oom(exc)
                    if attempt_index == final_index:
                        raise RuntimeError(
                            "SeedVR2ProgressiveSampler: exhausted auto "
                            "chunking attempts after OOM. Tried "
                            f"frames_per_chunk values {attempts}."
                        ) from exc
                # Only reachable after a swallowed OOM on a non-final
                # attempt. The cleanup runs outside the ``except`` block so
                # the exception object has already been released.
                logging.warning(
                    "SeedVR2ProgressiveSampler auto chunking OOM at "
                    "frames_per_chunk=%s; retrying with "
                    "frames_per_chunk=%s.",
                    attempt_size, attempts[attempt_index + 1],
                )
                comfy.model_management.soft_empty_cache()

        # Whole sequence fits in one chunk: take the plain sampling path,
        # which is intended to match the built-in KSampler for the same
        # inputs and adds no chunking overhead.
        if T_pixel <= frames_per_chunk:
            single_pass = _run_standard_sample(
                model, seed, steps, cfg, sampler_name, scheduler,
                positive, negative, latent_image, denoise,
            )
            return io.NodeOutput(single_pass)

        # Pixel chunk -> latent chunk. Any positive latent length maps to
        # a 4n+1 pixel count, so a shorter final chunk stays on-lattice.
        chunk_latent = (frames_per_chunk - 1) // 4 + 1

        # Overlap is given in latent frames, but the caller does not see
        # the derived latent chunk length. Oversized values are clamped to
        # the largest overlap that still leaves a stride of at least 1.
        if temporal_overlap < 0:
            raise ValueError(
                f"SeedVR2ProgressiveSampler: temporal_overlap must be >= 0; "
                f"got {temporal_overlap}."
            )
        overlap = min(temporal_overlap, chunk_latent - 1)
        stride = chunk_latent - overlap

        # One global noise tensor, sliced per chunk, so the noise for a
        # given (seed, T_latent) does not depend on the partitioning.
        batch_inds = latent_image.get("batch_index", None)
        noise_full = comfy.sample.prepare_noise(collapsed, seed, batch_inds)
        noise_mask = latent_image.get("noise_mask", None)

        # Settle the full chunk geometry before the first sampling call.
        # ``start < T_latent`` and ``chunk_latent >= 1`` guarantee every
        # window is non-empty.
        windows = []
        start = 0
        while start < T_latent:
            stop = min(start + chunk_latent, T_latent)
            windows.append((start, stop))
            if stop >= T_latent:
                break
            start += stride

        chunk_specs = []
        for win_start, win_stop in windows:
            latent_window = _slice_collapsed_4d_along_t(
                collapsed, win_start, win_stop, _SEEDVR2_LATENT_CHANNELS,
            )
            noise_window = _slice_collapsed_4d_along_t(
                noise_full, win_start, win_stop, _SEEDVR2_LATENT_CHANNELS,
            )
            positive_window = _slice_seedvr2_cond_along_t(
                positive, win_start, win_stop,
            )
            negative_window = _slice_seedvr2_cond_along_t(
                negative, win_start, win_stop,
            )
            # The mask helper receives the full latent alongside the
            # window bounds and returns the mask to use for this chunk.
            mask_window = (
                None if noise_mask is None
                else _slice_seedvr2_noise_mask_along_t(
                    noise_mask, collapsed, win_start, win_stop,
                )
            )
            sampled_window = comfy.sample.sample(
                model, noise_window, steps, cfg, sampler_name, scheduler,
                positive_window, negative_window, latent_window,
                denoise=denoise, noise_mask=mask_window, seed=seed,
            )
            chunk_specs.append((win_start, win_stop, sampled_window))

        stitched = _concat_chunks_with_overlap_blend(
            chunk_specs, _SEEDVR2_LATENT_CHANNELS, overlap,
        )

        result = latent_image.copy()
        result.pop("downscale_ratio_spacial", None)
        result["samples"] = stitched
        return io.NodeOutput(result)
class SeedVRExtension(ComfyExtension):
    """Extension object that lists the SeedVR2 node classes of this module."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes contributed by this extension."""
        node_classes: list[type[io.ComfyNode]] = [
            SeedVR2Conditioning,
            SeedVR2Resize,
            SeedVR2ResizeAdvanced,
            SeedVR2PostProcessing,
            SeedVR2ProgressiveSampler,
        ]
        return node_classes
async def comfy_entrypoint() -> SeedVRExtension:
    """Create and return a new ``SeedVRExtension`` instance."""
    extension = SeedVRExtension()
    return extension