diff --git a/comfy/ldm/seedvr/color_fix.py b/comfy/ldm/seedvr/color_fix.py new file mode 100644 index 000000000..7ddfc03af --- /dev/null +++ b/comfy/ldm/seedvr/color_fix.py @@ -0,0 +1,340 @@ +import torch +import torch.nn.functional as F +from torch import Tensor + +from comfy.ldm.seedvr.model import safe_pad_operation +from comfy.ldm.seedvr.vae import safe_interpolate_operation +from comfy.ldm.seedvr.constants import ( + CIELAB_DELTA, + CIELAB_KAPPA, + D65_WHITE_X, + D65_WHITE_Z, + WAVELET_DECOMP_LEVELS, +) + + +def wavelet_blur(image: Tensor, radius): + max_safe_radius = max(1, min(image.shape[-2:]) // 8) + if radius > max_safe_radius: + radius = max_safe_radius + + num_channels = image.shape[1] + + kernel_vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125], + [0.0625, 0.125, 0.0625], + ] + kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) + kernel = kernel[None, None].repeat(num_channels, 1, 1, 1) + + image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate') + output = F.conv2d(image, kernel, groups=num_channels, dilation=radius) + + return output + +def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS): + high_freq = torch.zeros_like(image) + + for i in range(levels): + radius = 2 ** i + low_freq = wavelet_blur(image, radius) + high_freq.add_(image).sub_(low_freq) + image = low_freq + + return high_freq, low_freq + +def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: + + if content_feat.shape != style_feat.shape: + # Resize style to match content spatial dimensions + if len(content_feat.shape) >= 3: + # safe_interpolate_operation handles FP16 conversion automatically + style_feat = safe_interpolate_operation( + style_feat, + size=content_feat.shape[-2:], + mode='bilinear', + align_corners=False + ) + + # Decompose both features into frequency components + content_high_freq, content_low_freq = wavelet_decomposition(content_feat) + del content_low_freq # Free memory immediately + + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq # Free memory immediately + + if content_high_freq.shape != style_low_freq.shape: + style_low_freq = safe_interpolate_operation( + style_low_freq, + size=content_high_freq.shape[-2:], + mode='bilinear', + align_corners=False + ) + + content_high_freq.add_(style_low_freq) + + return content_high_freq.clamp_(-1.0, 1.0) + +def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor: + original_shape = source.shape + + # Flatten + source_flat = source.flatten() + reference_flat = reference.flatten() + + # Sort both arrays + source_sorted, source_indices = torch.sort(source_flat) + reference_sorted, _ = torch.sort(reference_flat) + del reference_flat + + # Quantile mapping + n_source = len(source_sorted) + n_reference = len(reference_sorted) + + if n_source == n_reference: + matched_sorted = reference_sorted + else: + # Interpolate reference to match source quantiles + source_quantiles = torch.linspace(0, 1, n_source, device=device) + ref_indices = (source_quantiles * (n_reference - 1)).long() + ref_indices.clamp_(0, n_reference - 1) + matched_sorted = reference_sorted[ref_indices] + del source_quantiles, ref_indices, reference_sorted + + del source_sorted, source_flat + + # Reconstruct using argsort (portable across CUDA/ROCm/MPS) + inverse_indices = torch.argsort(source_indices) + del source_indices + matched_flat = matched_sorted[inverse_indices] + del matched_sorted, inverse_indices + + return matched_flat.reshape(original_shape) + +def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor: + """Convert batch of CIELAB images to RGB color space.""" + L, a, b = lab[:, 0], lab[:, 1], lab[:, 2] + + # LAB to XYZ + fy = (L + 16.0) / 116.0 + fx = a.div(500.0).add_(fy) + fz = fy - b / 200.0 + del L, a, b + + # XYZ transformation + x = torch.where( + fx > epsilon, + torch.pow(fx, 3.0), + fx.mul(116.0).sub_(16.0).div_(kappa) + ) + y = torch.where( + fy > epsilon, + torch.pow(fy, 3.0), + fy.mul(116.0).sub_(16.0).div_(kappa) + ) + z = torch.where( + fz > epsilon, + torch.pow(fz, 3.0), + fz.mul(116.0).sub_(16.0).div_(kappa) + ) + del fx, fy, fz + + # Apply D65 white point (in-place) + x.mul_(D65_WHITE_X) + # y *= 1.00000 # (no-op, skip) + z.mul_(D65_WHITE_Z) + + xyz = torch.stack([x, y, z], dim=1) + del x, y, z + + # Matrix multiplication: XYZ -> RGB + B, C, H, W = xyz.shape + xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3) + del xyz + + # Ensure dtype consistency for matrix multiplication + xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype) + rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T) + del xyz_flat + + rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) + del rgb_linear_flat + + # Apply inverse gamma correction (delinearize) + mask = rgb_linear > 0.0031308 + rgb = torch.where( + mask, + torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055), + rgb_linear * 12.92 + ) + del mask, rgb_linear + + return torch.clamp(rgb, 0.0, 1.0) + +def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor: + """Convert batch of RGB images to CIELAB color space using D65 illuminant.""" + # Apply sRGB gamma correction (linearize) + mask = rgb > 0.04045 + rgb_linear = torch.where( + mask, + torch.pow((rgb + 0.055) / 1.055, 2.4), + rgb / 12.92 + ) + del mask + + # Matrix multiplication: RGB -> XYZ + B, C, H, W = rgb_linear.shape + rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3) + del rgb_linear + + # Ensure dtype consistency for matrix multiplication + rgb_flat = rgb_flat.to(dtype=matrix.dtype) + xyz_flat = torch.matmul(rgb_flat, matrix.T) + del rgb_flat + + xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) + del xyz_flat + + # Normalize by D65 white point (in-place) + xyz[:, 0].div_(D65_WHITE_X) # X + # xyz[:, 1] /= 1.00000 # Y (no-op, skip) + xyz[:, 2].div_(D65_WHITE_Z) # Z + + # XYZ to LAB transformation + epsilon_cubed = epsilon ** 3 + mask = xyz > epsilon_cubed + f_xyz = torch.where( + mask, + torch.pow(xyz, 1.0 / 3.0), + xyz.mul(kappa).add_(16.0).div_(116.0) + ) + del xyz, mask + + # Extract channels and compute LAB + L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100] + a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127] + b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127] + del f_xyz + + return torch.stack([L, a, b], dim=1) + +def lab_color_transfer( + content_feat: Tensor, + style_feat: Tensor, + luminance_weight: float = 0.8 +) -> Tensor: + content_feat = wavelet_reconstruction(content_feat, style_feat) + + if content_feat.shape != style_feat.shape: + style_feat = safe_interpolate_operation( + style_feat, + size=content_feat.shape[-2:], + mode='bilinear', + align_corners=False + ) + + device = content_feat.device + + def ensure_float32_precision(c): + orig_dtype = c.dtype + c = c.float() + return c, orig_dtype + content_feat, original_dtype = ensure_float32_precision(content_feat) + style_feat, _ = ensure_float32_precision(style_feat) + + rgb_to_xyz_matrix = torch.tensor([ + [0.4124564, 0.3575761, 0.1804375], + [0.2126729, 0.7151522, 0.0721750], + [0.0193339, 0.1191920, 0.9503041] + ], dtype=torch.float32, device=device) + + xyz_to_rgb_matrix = torch.tensor([ + [ 3.2404542, -1.5371385, -0.4985314], + [-0.9692660, 1.8760108, 0.0415560], + [ 0.0556434, -0.2040259, 1.0572252] + ], dtype=torch.float32, device=device) + + epsilon = CIELAB_DELTA + kappa = CIELAB_KAPPA + + content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) + style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) + + # Convert to LAB color space + content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa) + del content_feat + + style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa) + del style_feat, rgb_to_xyz_matrix + + # Match chrominance channels (a*, b*) for accurate color transfer + matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device) + matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device) + + # Handle luminance with weighted blending + if luminance_weight < 1.0: + # Partially match luminance for better overall color accuracy + matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device) + # Blend: preserve some content L* for detail, adopt some style L* for color + result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight)) + del matched_L + else: + # Fully preserve content luminance + result_L = content_lab[:, 0] + + del content_lab, style_lab + + # Reconstruct LAB with corrected channels + result_lab = torch.stack([result_L, matched_a, matched_b], dim=1) + del result_L, matched_a, matched_b + + # Convert back to RGB + result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa) + del result_lab, xyz_to_rgb_matrix + + # Convert back to [-1, 1] range (in-place) + result = result_rgb.mul_(2.0).sub_(1.0) + del result_rgb + + result = result.to(original_dtype) + + return result + + +def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor: + return wavelet_reconstruction(content_feat, style_feat) + + +def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor: + if content_feat.shape != style_feat.shape: + style_feat = safe_interpolate_operation( + style_feat, + size=content_feat.shape[-2:], + mode='bilinear', + align_corners=False, + ) + + original_dtype = content_feat.dtype + content_feat = content_feat.float() + style_feat = style_feat.float() + + b, c = content_feat.shape[:2] + content_flat = content_feat.reshape(b, c, -1) + style_flat = style_feat.reshape(b, c, -1) + + content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1) + content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) + style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1) + style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) + del content_flat, style_flat + + normalized = (content_feat - content_mean) / content_std + del content_mean, content_std + result = normalized * style_std + style_mean + del normalized, style_mean, style_std + + result = result.clamp_(-1.0, 1.0) + if result.dtype != original_dtype: + result = result.to(original_dtype) + return result diff --git a/comfy/ldm/seedvr/constants.py b/comfy/ldm/seedvr/constants.py new file mode 100644 index 000000000..bfd72f1a2 --- /dev/null +++ b/comfy/ldm/seedvr/constants.py @@ -0,0 +1,82 @@ +"""Named constants for the SeedVR2 integration, grouped by provenance. + +Provenance prefixes: +- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline. +- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites + the upstream config/source path it was lifted from. +- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature / + ISO / CIE values; cite the standard. + +The numz/AInVFX custom node is used only as a behavioral-parity benchmark; it is the +origin of none of these values and appears here nowhere. +""" + +# -------------------------------------------------------------------------------------- +# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment) +# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN) +# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070 +# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT). +# -------------------------------------------------------------------------------------- +SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB) +SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB + +# -------------------------------------------------------------------------------------- +# B. Fork heuristics (SEEDVR2 - this integration) +# -------------------------------------------------------------------------------------- +SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim. + # (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.) +SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry. +SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case). +SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM. +SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk. +SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels). +SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16). +SEEDVR2_DEFAULT_TEMPORAL_SIZE = 16 # default VAE temporal tile when unset. + +# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing) +SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk. +SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path. +SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path. +SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path. + +# -------------------------------------------------------------------------------------- +# C. ByteDance config / source (BYTEDANCE - cite myseedvr2/SeedVR) +# -------------------------------------------------------------------------------------- +BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm. +BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift. +BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem). +BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem). +BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28. +BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28. +BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16). +BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32). +BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD = 64 # attn_video_vae.py:308 (force .contiguous() above this b*t). +BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11. +BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size). +BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor. +BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor). +BYTEDANCE_SCHEDULE_T = 1000.0 # configs_3b/main.yaml:65 (schedule.T); timestep range. +BYTEDANCE_SPATIAL_DIVISOR = 16 # inference_seedvr2_3b.py:241 (DivisibleCrop((16,16))). +BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling). +BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames). +BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency). +BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim). +# Resolution-dependent timestep-shift linear fits: (x1, y1, x2, y2) for get_lin_function. +BYTEDANCE_IMG_SHIFT_FIT = (256 * 256, 1.0, 1024 * 1024, 3.2) # infer.py:242. +BYTEDANCE_VID_SHIFT_FIT = (256 * 256 * 37, 1.0, 1280 * 720 * 145, 5.0) # infer.py:243. + +# -------------------------------------------------------------------------------------- +# D. Published standards (cite the literature) +# -------------------------------------------------------------------------------------- +ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864. + +# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65). +CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta). +CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa). +D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1). +D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn. +WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR). + +# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and +# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the +# exact existing coefficients move verbatim rather than being retyped here. diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index 32a1c2134..8f248a4d2 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -12,6 +12,16 @@ from torch.nn.modules.utils import _triple from torch import nn import math from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.seedvr.constants import ( + BYTEDANCE_720P_REF_AREA, + BYTEDANCE_MAX_TEMPORAL_WINDOW, + BYTEDANCE_ROPE_MAX_FREQ, + BYTEDANCE_SINUSOIDAL_DIM, + ROPE_THETA, + SEEDVR2_7B_MLP_CHUNK, + SEEDVR2_7B_VID_DIM, + SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, +) import comfy.model_management import numbers @@ -203,10 +213,10 @@ def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, t, h, w = size resized_nt, resized_nh, resized_nw = num_windows #cal windows under 720p - scale = math.sqrt((45 * 80) / (h * w)) + scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) resized_h, resized_w = round(h * scale), round(w * scale) wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, 30) / resized_nt) # window size. + wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. return [ ( @@ -226,10 +236,10 @@ def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tup t, h, w = size resized_nt, resized_nh, resized_nw = num_windows #cal windows under 720p - scale = math.sqrt((45 * 80) / (h * w)) + scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) resized_h, resized_w = round(h * scale), round(w * scale) wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, 30) / resized_nt) # window size. + wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. st, sh, sw = ( # shift size. 0.5 if wt < t else 0, @@ -412,7 +422,7 @@ class RotaryEmbeddingBase(nn.Module): self.rope = RotaryEmbedding( dim=dim // rope_dim, freqs_for="pixel", - max_freq=256, + max_freq=BYTEDANCE_ROPE_MAX_FREQ, ) freqs = self.rope.freqs del self.rope.freqs @@ -486,7 +496,7 @@ class MMRotaryEmbeddingBase(RotaryEmbeddingBase): self.rope = RotaryEmbedding( dim=dim // rope_dim, freqs_for="lang", - theta=10000, + theta=ROPE_THETA, cache_if_possible=False, ) freqs = self.rope.freqs @@ -547,14 +557,7 @@ def apply_rotary_emb( return out.type(dtype) def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: - """Convert lucidrains-interleaved freqs `[..., d]` (`[θ0, θ0, θ1, θ1, ...]` - from `RotaryEmbedding.forward`'s `repeat(freqs, '... n -> ... (n r)', r=2)`) - into flux-canonical `freqs_cis` of shape `[..., d/2, 2, 2]` with the - `cos/-sin/sin/cos` rotation matrix baked in. Output dtype is fp32 to - match `comfy/ldm/flux/math.py:rope` precision; `apply_rope1` consumes - the matrix layout via `freqs_cis[..., 0]` (column 0) and - `freqs_cis[..., 1]` (column 1) of the 2x2 rotation matrix. - """ + """Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`.""" angles = freqs_interleaved[..., ::2].float() cos = torch.cos(angles) sin = torch.sin(angles) @@ -562,27 +565,18 @@ def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: return rearrange(out, "... d (i j) -> ... d i j", i=2, j=2) -_ROPE1_PARTIAL_CHUNK_TOKENS = 4096 -SEEDVR2_7B_MLP_CHUNK = 8192 - - def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - """Apply ``apply_rope1`` to the leading ``rot_d = 2 * freqs_cis.shape[-3]`` - components of ``t``'s last dim, passing through the remaining dims - untouched in-place for inference tensors. Training tensors are cloned - before slice assignment to preserve autograd correctness. Mirrors the partial-rope contract of the legacy - ``apply_rotary_emb`` wrapper at line 470 (``t_left``/``t_middle``/``t_right`` - split). For SeedVR2-3B this matters because ``rope_dim=128`` integer- - divides into 3 axes as ``128 // 3 = 42`` per-axis, total ``42 * 3 = 126``; - head_dim is 128, so the trailing 2 dims are unrotated. The fast path - triggers when ``rot_d == t.shape[-1]`` (e.g. test rigs where dim is - chosen divisible by 6) and avoids the cat entirely. + """Rotate the leading ``rot_d = 2 * freqs_cis.shape[-3]`` dims of ``t`` and pass the rest + through; in-place for inference, cloned for training (autograd). Mirrors the legacy + ``apply_rotary_emb`` ``t_left``/``t_middle``/``t_right`` split: 3B ``rope_dim=128`` gives + ``42*3 = 126`` rotated of head_dim 128 (trailing 2 unrotated). Fast path skips the cat when + ``rot_d == t.shape[-1]``. """ out = t.clone() if t.requires_grad or comfy.model_management.in_training else t rot_d = 2 * freqs_cis.shape[-3] seq_len = out.shape[-2] - for start in range(0, seq_len, _ROPE1_PARTIAL_CHUNK_TOKENS): - end = min(start + _ROPE1_PARTIAL_CHUNK_TOKENS, seq_len) + for start in range(0, seq_len, SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS): + end = min(start + SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, seq_len) freqs_chunk = freqs_cis[start:end] if rot_d == out.shape[-1]: out[..., start:end, :] = apply_rope1(out[..., start:end, :], freqs_chunk).to(out.dtype) @@ -1385,7 +1379,7 @@ class NaDiT(nn.Module): operations = None, **kwargs, ): - self._7b_version = vid_dim == 3072 + self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM self.dtype = dtype factory_kwargs = {"device": device, "dtype": dtype} window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"] @@ -1427,7 +1421,7 @@ class NaDiT(nn.Module): else nn.Identity() ) self.emb_in = TimeEmbedding( - sinusoidal_dim=256, + sinusoidal_dim=BYTEDANCE_SINUSOIDAL_DIM, hidden_dim=max(vid_dim, txt_dim), output_dim=emb_dim, device=device, dtype=dtype, operations=operations diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 0593fa547..d6d07fe1c 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -10,6 +10,27 @@ from contextlib import contextmanager from comfy.utils import ProgressBar from comfy.ldm.seedvr.model import safe_pad_operation +from comfy.ldm.seedvr.constants import ( + BYTEDANCE_BLOCK_OUT_CHANNELS, + BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD, + BYTEDANCE_GN_CHUNKS_FP16, + BYTEDANCE_GN_CHUNKS_FP32, + BYTEDANCE_LOGVAR_CLAMP_MAX, + BYTEDANCE_LOGVAR_CLAMP_MIN, + BYTEDANCE_SLICING_SAMPLE_MIN, + BYTEDANCE_VAE_CONV_MEM_GIB, + BYTEDANCE_VAE_NORM_MEM_GIB, + BYTEDANCE_VAE_SCALING_FACTOR, + BYTEDANCE_VAE_SHIFTING_FACTOR, + BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE, + BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE, + CIELAB_DELTA, + CIELAB_KAPPA, + D65_WHITE_X, + D65_WHITE_Z, + SEEDVR2_LATENT_CHANNELS, + WAVELET_DECOMP_LEVELS, +) from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.modules.diffusionmodules.model import vae_attention @@ -70,8 +91,8 @@ def tiled_vae( _, _, d, h, w = x.shape - sf_s = getattr(vae_model, "spatial_downsample_factor", 8) - sf_t = getattr(vae_model, "temporal_downsample_factor", 4) + sf_s = getattr(vae_model, "spatial_downsample_factor", BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE) + sf_t = getattr(vae_model, "temporal_downsample_factor", BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE) if encode: slicing_attr = "slicing_sample_min_size" slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap) @@ -278,7 +299,7 @@ class DiagonalGaussianDistribution(object): def __init__(self, parameters: torch.Tensor, deterministic: bool = False): self.parameters = parameters self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.logvar = torch.clamp(self.logvar, BYTEDANCE_LOGVAR_CLAMP_MIN, BYTEDANCE_LOGVAR_CLAMP_MAX) self.deterministic = deterministic self.std = torch.exp(0.5 * self.logvar) self.var = torch.exp(self.logvar) @@ -569,7 +590,7 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: x = rearrange(x, "b c t h w -> (b t) c h w") memory_occupy = x.numel() * x.element_size() / 1024**3 if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit(): - num_chunks = min(4 if x.element_size() == 2 else 2, norm_layer.num_groups) + num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups) assert norm_layer.num_groups % num_chunks == 0 num_groups_per_chunk = norm_layer.num_groups // num_chunks @@ -1189,7 +1210,7 @@ class ResnetBlock3D(nn.Module): hidden_states = self.nonlinearity(hidden_states) if self.upsample is not None: - if hidden_states.shape[0] >= 64: + if hidden_states.shape[0] >= BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD: input_tensor = input_tensor.contiguous() hidden_states = hidden_states.contiguous() input_tensor = self.upsample(input_tensor, memory_state=memory_state) @@ -1780,333 +1801,6 @@ class Decoder3D(nn.Module): return sample -def wavelet_blur(image: Tensor, radius): - max_safe_radius = max(1, min(image.shape[-2:]) // 8) - if radius > max_safe_radius: - radius = max_safe_radius - - num_channels = image.shape[1] - - kernel_vals = [ - [0.0625, 0.125, 0.0625], - [0.125, 0.25, 0.125], - [0.0625, 0.125, 0.0625], - ] - kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) - kernel = kernel[None, None].repeat(num_channels, 1, 1, 1) - - image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate') - output = F.conv2d(image, kernel, groups=num_channels, dilation=radius) - - return output - -def wavelet_decomposition(image: Tensor, levels: int = 5): - high_freq = torch.zeros_like(image) - - for i in range(levels): - radius = 2 ** i - low_freq = wavelet_blur(image, radius) - high_freq.add_(image).sub_(low_freq) - image = low_freq - - return high_freq, low_freq - -def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: - - if content_feat.shape != style_feat.shape: - # Resize style to match content spatial dimensions - if len(content_feat.shape) >= 3: - # safe_interpolate_operation handles FP16 conversion automatically - style_feat = safe_interpolate_operation( - style_feat, - size=content_feat.shape[-2:], - mode='bilinear', - align_corners=False - ) - - # Decompose both features into frequency components - content_high_freq, content_low_freq = wavelet_decomposition(content_feat) - del content_low_freq # Free memory immediately - - style_high_freq, style_low_freq = wavelet_decomposition(style_feat) - del style_high_freq # Free memory immediately - - if content_high_freq.shape != style_low_freq.shape: - style_low_freq = safe_interpolate_operation( - style_low_freq, - size=content_high_freq.shape[-2:], - mode='bilinear', - align_corners=False - ) - - content_high_freq.add_(style_low_freq) - - return content_high_freq.clamp_(-1.0, 1.0) - -def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor: - original_shape = source.shape - - # Flatten - source_flat = source.flatten() - reference_flat = reference.flatten() - - # Sort both arrays - source_sorted, source_indices = torch.sort(source_flat) - reference_sorted, _ = torch.sort(reference_flat) - del reference_flat - - # Quantile mapping - n_source = len(source_sorted) - n_reference = len(reference_sorted) - - if n_source == n_reference: - matched_sorted = reference_sorted - else: - # Interpolate reference to match source quantiles - source_quantiles = torch.linspace(0, 1, n_source, device=device) - ref_indices = (source_quantiles * (n_reference - 1)).long() - ref_indices.clamp_(0, n_reference - 1) - matched_sorted = reference_sorted[ref_indices] - del source_quantiles, ref_indices, reference_sorted - - del source_sorted, source_flat - - # Reconstruct using argsort (portable across CUDA/ROCm/MPS) - inverse_indices = torch.argsort(source_indices) - del source_indices - matched_flat = matched_sorted[inverse_indices] - del matched_sorted, inverse_indices - - return matched_flat.reshape(original_shape) - -def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor: - """Convert batch of CIELAB images to RGB color space.""" - L, a, b = lab[:, 0], lab[:, 1], lab[:, 2] - - # LAB to XYZ - fy = (L + 16.0) / 116.0 - fx = a.div(500.0).add_(fy) - fz = fy - b / 200.0 - del L, a, b - - # XYZ transformation - x = torch.where( - fx > epsilon, - torch.pow(fx, 3.0), - fx.mul(116.0).sub_(16.0).div_(kappa) - ) - y = torch.where( - fy > epsilon, - torch.pow(fy, 3.0), - fy.mul(116.0).sub_(16.0).div_(kappa) - ) - z = torch.where( - fz > epsilon, - torch.pow(fz, 3.0), - fz.mul(116.0).sub_(16.0).div_(kappa) - ) - del fx, fy, fz - - # Apply D65 white point (in-place) - x.mul_(0.95047) - # y *= 1.00000 # (no-op, skip) - z.mul_(1.08883) - - xyz = torch.stack([x, y, z], dim=1) - del x, y, z - - # Matrix multiplication: XYZ -> RGB - B, C, H, W = xyz.shape - xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3) - del xyz - - # Ensure dtype consistency for matrix multiplication - xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype) - rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T) - del xyz_flat - - rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) - del rgb_linear_flat - - # Apply inverse gamma correction (delinearize) - mask = rgb_linear > 0.0031308 - rgb = torch.where( - mask, - torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055), - rgb_linear * 12.92 - ) - del mask, rgb_linear - - return torch.clamp(rgb, 0.0, 1.0) - -def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor: - """Convert batch of RGB images to CIELAB color space using D65 illuminant.""" - # Apply sRGB gamma correction (linearize) - mask = rgb > 0.04045 - rgb_linear = torch.where( - mask, - torch.pow((rgb + 0.055) / 1.055, 2.4), - rgb / 12.92 - ) - del mask - - # Matrix multiplication: RGB -> XYZ - B, C, H, W = rgb_linear.shape - rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3) - del rgb_linear - - # Ensure dtype consistency for matrix multiplication - rgb_flat = rgb_flat.to(dtype=matrix.dtype) - xyz_flat = torch.matmul(rgb_flat, matrix.T) - del rgb_flat - - xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) - del xyz_flat - - # Normalize by D65 white point (in-place) - xyz[:, 0].div_(0.95047) # X - # xyz[:, 1] /= 1.00000 # Y (no-op, skip) - xyz[:, 2].div_(1.08883) # Z - - # XYZ to LAB transformation - epsilon_cubed = epsilon ** 3 - mask = xyz > epsilon_cubed - f_xyz = torch.where( - mask, - torch.pow(xyz, 1.0 / 3.0), - xyz.mul(kappa).add_(16.0).div_(116.0) - ) - del xyz, mask - - # Extract channels and compute LAB - L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100] - a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127] - b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127] - del f_xyz - - return torch.stack([L, a, b], dim=1) - -def lab_color_transfer( - content_feat: Tensor, - style_feat: Tensor, - luminance_weight: float = 0.8 -) -> Tensor: - content_feat = wavelet_reconstruction(content_feat, style_feat) - - if content_feat.shape != style_feat.shape: - style_feat = safe_interpolate_operation( - style_feat, - size=content_feat.shape[-2:], - mode='bilinear', - align_corners=False - ) - - device = content_feat.device - - def ensure_float32_precision(c): - orig_dtype = c.dtype - c = c.float() - return c, orig_dtype - content_feat, original_dtype = ensure_float32_precision(content_feat) - style_feat, _ = ensure_float32_precision(style_feat) - - rgb_to_xyz_matrix = torch.tensor([ - [0.4124564, 0.3575761, 0.1804375], - [0.2126729, 0.7151522, 0.0721750], - [0.0193339, 0.1191920, 0.9503041] - ], dtype=torch.float32, device=device) - - xyz_to_rgb_matrix = torch.tensor([ - [ 3.2404542, -1.5371385, -0.4985314], - [-0.9692660, 1.8760108, 0.0415560], - [ 0.0556434, -0.2040259, 1.0572252] - ], dtype=torch.float32, device=device) - - epsilon = 6.0 / 29.0 - kappa = (29.0 / 3.0) ** 3 - - content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) - style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) - - # Convert to LAB color space - content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa) - del content_feat - - style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa) - del style_feat, rgb_to_xyz_matrix - - # Match chrominance channels (a*, b*) for accurate color transfer - matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device) - matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device) - - # Handle luminance with weighted blending - if luminance_weight < 1.0: - # Partially match luminance for better overall color accuracy - matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device) - # Blend: preserve some content L* for detail, adopt some style L* for color - result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight)) - del matched_L - else: - # Fully preserve content luminance - result_L = content_lab[:, 0] - - del content_lab, style_lab - - # Reconstruct LAB with corrected channels - result_lab = torch.stack([result_L, matched_a, matched_b], dim=1) - del result_L, matched_a, matched_b - - # Convert back to RGB - result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa) - del result_lab, xyz_to_rgb_matrix - - # Convert back to [-1, 1] range (in-place) - result = result_rgb.mul_(2.0).sub_(1.0) - del result_rgb - - result = result.to(original_dtype) - - return result - - -def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor: - return wavelet_reconstruction(content_feat, style_feat) - - -def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor: - if content_feat.shape != style_feat.shape: - style_feat = safe_interpolate_operation( - style_feat, - size=content_feat.shape[-2:], - mode='bilinear', - align_corners=False, - ) - - original_dtype = content_feat.dtype - content_feat = content_feat.float() - style_feat = style_feat.float() - - b, c = content_feat.shape[:2] - content_flat = content_feat.reshape(b, c, -1) - style_flat = style_feat.reshape(b, c, -1) - - content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1) - content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) - style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1) - style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) - del content_flat, style_flat - - normalized = (content_feat - content_mean) / content_std - del content_mean, content_std - result = normalized * style_std + style_mean - del normalized, style_mean, style_std - - result = result.clamp_(-1.0, 1.0) - if result.dtype != original_dtype: - result = result.to(original_dtype) - return result - - class VideoAutoencoderKL(nn.Module): def __init__( self, @@ -2114,7 +1808,7 @@ class VideoAutoencoderKL(nn.Module): out_channels: int = 3, layers_per_block: int = 2, act_fn: str = "silu", - latent_channels: int = 16, + latent_channels: int = SEEDVR2_LATENT_CHANNELS, norm_num_groups: int = 32, attention: bool = True, temporal_scale_num: int = 2, @@ -2124,14 +1818,14 @@ class VideoAutoencoderKL(nn.Module): time_receptive_field: _receptive_field_t = "full", use_quant_conv: bool = False, use_post_quant_conv: bool = False, - slicing_sample_min_size = 4, + slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN, *args, **kwargs, ): self.slicing_sample_min_size = slicing_sample_min_size self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None - block_out_channels = (128, 256, 512, 512) + block_out_channels = BYTEDANCE_BLOCK_OUT_CHANNELS down_block_types = ("DownEncoderBlock3D",) * 4 up_block_types = ("UpDecoderBlock3D",) * 4 super().__init__() @@ -2329,7 +2023,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): self.freeze_encoder = freeze_encoder self.enable_tiling = False super().__init__(*args, **kwargs) - self.set_memory_limit(0.5, 0.5) + self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB) def forward(self, x: torch.FloatTensor): with torch.no_grad() if self.freeze_encoder else nullcontext(): @@ -2377,8 +2071,8 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): "4-D collapsed (B, 16*T, H, W) or 5-D (B, 16, T, H, W); " f"got shape {tuple(z.shape)}." ) - scale = 0.9152 - shift = 0 + scale = BYTEDANCE_VAE_SCALING_FACTOR + shift = BYTEDANCE_VAE_SHIFTING_FACTOR latent = latent / scale + shift self.device = latent.device diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index 6bc2de17f..e48d9e463 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -9,11 +9,24 @@ import gc import comfy.model_management import comfy.sample import comfy.samplers -from comfy.ldm.seedvr.vae import ( +from comfy.ldm.seedvr.color_fix import ( adain_color_transfer, lab_color_transfer, wavelet_color_transfer, ) +from comfy.ldm.seedvr.constants import ( + BYTEDANCE_IMG_SHIFT_FIT, + BYTEDANCE_SCHEDULE_T, + BYTEDANCE_VID_SHIFT_FIT, + SEEDVR2_ADAIN_SCALE_MULTIPLIER, + SEEDVR2_COLOR_MEM_HEADROOM, + SEEDVR2_COND_CHANNELS, + SEEDVR2_DTYPE_BYTES_FLOOR, + SEEDVR2_LAB_SCALE_MULTIPLIER, + SEEDVR2_LATENT_CHANNELS, + SEEDVR2_OOM_BACKOFF_DIVISOR, + SEEDVR2_WAVELET_SCALE_MULTIPLIER, +) from torchvision.transforms import functional as TVF from torchvision.transforms import Lambda @@ -23,10 +36,6 @@ from torchvision.transforms.functional import InterpolationMode _SEEDVR2_INVALID_MODEL_MSG_PREFIX = ( "SeedVR2Conditioning: model object does not match expected SeedVR2 structure" ) -LAB_SCALE_MULTIPLIER = 13 -WAVELET_SCALE_MULTIPLIER = 10 -ADAIN_SCALE_MULTIPLIER = 6 -COLOR_CORRECTION_MEMORY_HEADROOM = 0.75 # Private sentinel for getattr default: distinguishes "attribute missing" # from "attribute present but None" so the failure message is accurate. @@ -57,17 +66,7 @@ def _seedvr2_auto_chunk_attempts(t_latent, t_pixel, frames_per_chunk): def _resolve_seedvr2_diffusion_model(model): - """Resolve the inner SeedVR2 diffusion-model module from a ComfyUI model - patcher object. Fails loud with a ``RuntimeError`` whose message begins - with ``_SEEDVR2_INVALID_MODEL_MSG_PREFIX`` when the expected wrapper - shape (``model.model.diffusion_model``) is absent. - - Distinguishes four failure modes via the ``_ATTR_MISSING`` sentinel: - ``model.model`` missing, ``model.model is None``, - ``model.model.diffusion_model`` missing, ``model.model.diffusion_model - is None``. Each mode produces an accurate error message rather than - conflating "attribute missing" with "attribute is None". - """ + """Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message.""" inner = getattr(model, "model", _ATTR_MISSING) if inner is _ATTR_MISSING: raise RuntimeError( @@ -94,15 +93,7 @@ def _resolve_seedvr2_diffusion_model(model): def _apply_rope_freqs_float32_cast(diffusion_model): - """Cast every nested module's ``rope.freqs`` parameter data to ``float32`` - when it is not already in float32. Idempotency is per-tensor by dtype - check, NOT a per-instance sentinel attribute — a sentinel would survive - Comfy's dynamic model unload/reload cycle while ``rope.freqs`` itself - is restored from the archived dtype, leaving RoPE running in fp16/bf16 - on subsequent calls. The dtype check makes the cast self-correcting - against weight-restore lifecycle events. Iteration cost is one walk of - the diffusion-model module tree per ``execute()`` call (microseconds). - """ + """Cast every module's ``rope.freqs`` to float32; the per-tensor dtype check (not a sentinel attr) self-corrects across Comfy's unload/reload, which would otherwise restore the archived fp16/bf16 dtype.""" for module in diffusion_model.modules(): if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'): if module.rope.freqs.data.dtype != torch.float32: @@ -140,8 +131,8 @@ def timestep_transform(timesteps, latents_shapes): b = y1 - m * x1 return lambda x: m * x + b - img_shift_fn = get_lin_function(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2) - vid_shift_fn = get_lin_function(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0) + img_shift_fn = get_lin_function(*BYTEDANCE_IMG_SHIFT_FIT) + vid_shift_fn = get_lin_function(*BYTEDANCE_VID_SHIFT_FIT) shift = torch.where( frames > 1, vid_shift_fn(heights * widths * frames), @@ -149,7 +140,7 @@ def timestep_transform(timesteps, latents_shapes): ).to(timesteps.device) # Shift timesteps. - T = 1000.0 + T = BYTEDANCE_SCHEDULE_T timesteps = timesteps / T timesteps = shift * timesteps / (1 + (shift - 1) * timesteps) timesteps = timesteps * T @@ -157,7 +148,7 @@ def timestep_transform(timesteps, latents_shapes): def inter(x_0, x_T, t): t = expand_dims(t, x_0.ndim) - T = 1000.0 + T = BYTEDANCE_SCHEDULE_T B = lambda t: t / T A = lambda t: 1 - (t / T) return A(t) * x_0 + B(t) * x_T @@ -235,6 +226,8 @@ def _seedvr2_resize_and_pad(images, upscaled_shorter_edge, node_name): f"got {upscaled_shorter_edge}." ) original_image = images + if images.shape[-1] > 3: + images = images[..., :3] if images.dim() == 4: # Comfy video components arrive as a 4-D IMAGE frame sequence: # (frames, H, W, C). SeedVR2 consumes that as one video. @@ -268,10 +261,12 @@ class SeedVR2Resize(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SeedVR2Resize", - category="image/video", + display_name="Resize Image for SeedVR2", + category="image/upscaling", + description="Resize an image to a SeedVR2-compatible size by a multiplier.", inputs=[ - io.Image.Input("images"), - io.Float.Input("multiplier", default=4.0, min=0.01), + io.Image.Input("images", tooltip="The image(s) to resize."), + io.Float.Input("multiplier", default=4.0, min=0.01, tooltip="Upscale factor applied to the shorter edge."), ], outputs=[ io.Image.Output("input_pixels"), @@ -304,10 +299,12 @@ class SeedVR2ResizeAdvanced(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SeedVR2ResizeAdvanced", - category="image/video", + display_name="Resize Image for SeedVR2 (Advanced)", + category="image/upscaling", + description="Resize an image to an exact shorter-edge size for SeedVR2.", inputs=[ - io.Image.Input("images"), - io.Int.Input("shorter_edge", default=1280, min=2), + io.Image.Input("images", tooltip="The image(s) to resize."), + io.Int.Input("shorter_edge", default=1280, min=2, tooltip="Target length of the shorter edge, in pixels."), ], outputs=[ io.Image.Output("input_pixels"), @@ -323,17 +320,30 @@ class SeedVR2ResizeAdvanced(io.ComfyNode): ) +def _edge_guided_alpha_upscale(alpha, out_h, out_w): + a = alpha.float() + extreme_fraction = ((a < 0.1) | (a > 0.9)).float().mean() + if extreme_fraction > 0.9: + up = torch.nn.functional.interpolate(a, size=(out_h, out_w), mode="bilinear", align_corners=False, antialias=True) + up = torch.clamp((up - 0.5) * 4.0 + 0.5, 0.0, 1.0) + else: + up = torch.nn.functional.interpolate(a, size=(out_h, out_w), mode="bicubic", align_corners=False, antialias=True).clamp(0.0, 1.0) + return up + + class SeedVR2PostProcessing(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="SeedVR2PostProcessing", - category="image/video", + display_name="Post-Process SeedVR2 Output", + category="image/upscaling", + description="Align the upscaled output to the original's geometry and optionally color-correct it against the original.", inputs=[ - io.Image.Input("decoded"), - io.Image.Input("original_image"), - io.Int.Input("upscaled_shorter_edge", min=2, force_input=True), - io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab"), + io.Image.Input("decoded", tooltip="The decoded upscaled image to color-correct."), + io.Image.Input("original_image", tooltip="The original image used as the color reference."), + io.Int.Input("upscaled_shorter_edge", min=2, force_input=True, tooltip="Shorter-edge size from the resize node."), + io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="How to match the output's color to the original. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."), ], outputs=[io.Image.Output()], ) @@ -341,6 +351,10 @@ class SeedVR2PostProcessing(io.ComfyNode): @classmethod def execute(cls, decoded, original_image, upscaled_shorter_edge, color_correction_method): cls._validate_upscaled_shorter_edge(upscaled_shorter_edge) + alpha_input = None + if original_image.shape[-1] == 4: + alpha_input = original_image[..., 3:4] + original_image = original_image[..., :3] decoded_5d, decoded_was_4d = cls._as_bthwc(decoded) original_5d, _ = cls._as_bthwc(original_image) decoded_5d = cls._restore_reference_batch_time(decoded_5d, original_5d) @@ -374,6 +388,13 @@ class SeedVR2PostProcessing(io.ComfyNode): else: raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") + if alpha_input is not None: + ab, at = output.shape[0], output.shape[1] + alpha_5d, _ = cls._as_bthwc(alpha_input) + alpha_flat = rearrange(alpha_5d[:ab, :at], "b t h w c -> (b t) c h w") + alpha_up = _edge_guided_alpha_upscale(alpha_flat, output.shape[2], output.shape[3]) + alpha_up = rearrange(alpha_up, "(b t) c h w -> b t h w c", b=ab, t=at) + output = torch.cat([output, alpha_up.to(dtype=output.dtype, device=output.device)], dim=-1) h2 = output.shape[-3] - (output.shape[-3] % 2) w2 = output.shape[-2] - (output.shape[-2] % 2) output = output[:, :, :h2, :w2, :] @@ -472,7 +493,7 @@ class SeedVR2PostProcessing(io.ComfyNode): "SeedVR2PostProcessing: color correction OOM at one frame; " f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}." ) from e - next_chunk_size = max(1, chunk_size // 2) + next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR) comfy.model_management.soft_empty_cache() chunk_size = next_chunk_size @@ -510,23 +531,23 @@ class SeedVR2PostProcessing(io.ComfyNode): multiplier = cls._color_correction_memory_multiplier(color_correction_method) frames = decoded_flat.shape[0] _, channels, height, width = decoded_flat.shape - dtype_bytes = max(decoded_flat.element_size(), 4) + dtype_bytes = max(decoded_flat.element_size(), SEEDVR2_DTYPE_BYTES_FLOOR) bytes_per_frame = height * width * channels * dtype_bytes * multiplier if bytes_per_frame <= 0: return frames color_device = comfy.model_management.vae_device() free_memory = comfy.model_management.get_free_memory(color_device) - chunk_size = int((free_memory * COLOR_CORRECTION_MEMORY_HEADROOM) // bytes_per_frame) + chunk_size = int((free_memory * SEEDVR2_COLOR_MEM_HEADROOM) // bytes_per_frame) return max(1, min(frames, chunk_size)) @staticmethod def _color_correction_memory_multiplier(color_correction_method): if color_correction_method == "lab": - return LAB_SCALE_MULTIPLIER + return SEEDVR2_LAB_SCALE_MULTIPLIER if color_correction_method == "wavelet": - return WAVELET_SCALE_MULTIPLIER + return SEEDVR2_WAVELET_SCALE_MULTIPLIER if color_correction_method == "adain": - return ADAIN_SCALE_MULTIPLIER + return SEEDVR2_ADAIN_SCALE_MULTIPLIER raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") @staticmethod @@ -549,10 +570,12 @@ class SeedVR2Conditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SeedVR2Conditioning", - category="image/video", + display_name="Apply SeedVR2 Conditioning", + category="conditioning", + description="Build SeedVR2 positive/negative conditioning from a VAE latent.", inputs=[ - io.Model.Input("model"), - io.Latent.Input("vae_conditioning", display_name="LATENT"), + io.Model.Input("model", tooltip="The SeedVR2 model."), + io.Latent.Input("vae_conditioning", tooltip="The VAE-encoded latent to condition on."), ], outputs=[ io.Model.Output(display_name = "model"), @@ -571,10 +594,10 @@ class SeedVR2Conditioning(io.ComfyNode): "SeedVR2Conditioning expects a 5-D VAE latent in Comfy " f"channel-first layout; got shape {tuple(vae_conditioning.shape)}." ) - if vae_conditioning.shape[-1] == _SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != _SEEDVR2_LATENT_CHANNELS: + if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS: raise ValueError( "SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy " - f"channel-first layout (B, {_SEEDVR2_LATENT_CHANNELS}, T, H, W); " + f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); " f"got channel-last shape {tuple(vae_conditioning.shape)}." ) vae_conditioning = vae_conditioning.movedim(1, -1).contiguous() @@ -622,27 +645,9 @@ class SeedVR2Conditioning(io.ComfyNode): return io.NodeOutput(model_patcher, positive, negative, {"samples": latent}) -# SeedVR2 latent / conditioning channel constants. The SeedVR2 conditioning -# stage collapses ``(B, C, T, H, W) -> (B, C*T, H, W)`` for both the latent -# (C=16) and the per-frame condition tensor (C=17 = 16 latent + 1 mask), as -# required by ``NaDiT.forward`` which un-collapses via -# ``view(B, 16, -1, H, W)`` and ``view(B, 17, -1, H, W)`` respectively. -_SEEDVR2_LATENT_CHANNELS = 16 -_SEEDVR2_CONDITION_CHANNELS = 17 - - def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int, t_end: int, channels: int) -> torch.Tensor: - """Slice a SeedVR2-style collapsed 4D tensor ``(B, channels*T, H, W)`` - along the latent T axis, returning ``(B, channels*(t_end - t_start), H, W)``. - - Reshape -> slice -> ``.contiguous()`` -> re-collapse. ``reshape`` is - used for the un-collapse so non-contiguous incoming tensors from - cropping or slicing nodes are accepted. The - ``.contiguous()`` is mandatory: T-axis slicing of a 5D tensor produces a - non-contiguous view, and the subsequent re-collapse requires contiguous - storage. - """ + """Slice collapsed ``(B, channels*T, H, W)`` along latent T: reshape (accepts non-contiguous inputs), slice, ``.contiguous()`` (T-slice of 5D is a non-contiguous view; re-collapse needs contiguous), re-collapse.""" B, CT, H, W = tensor_4d.shape if CT % channels != 0: raise ValueError( @@ -661,19 +666,7 @@ def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int, def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int): - """Build a new SeedVR2 conditioning list with the per-frame ``condition`` - tensor sliced along the latent T axis. - - SeedVR2 conditioning entries have the shape - ``[text_cond_tensor, options_dict]`` where ``options_dict["condition"]`` - is a 4D collapsed ``(B, 17*T, H, W)`` tensor; the text tensor itself has - no temporal axis and is passed through unchanged. Other keys in the - options dict (controlnets, etc.) are also passed through unchanged. If - an entry has no ``"condition"`` key, the entry is forwarded verbatim. - - A new list of ``[text_cond, new_options_dict]`` pairs is returned; the - original ``cond_list`` and its options dicts are not mutated. - """ + """Return a new conditioning list with each entry's ``options["condition"]`` (collapsed ``(B, 17*T, H, W)``) sliced along latent T; text tensors, other option keys, and condition-less entries pass through unchanged and inputs are not mutated.""" new_list = [] for entry in cond_list: text_cond, options = entry[0], entry[1] @@ -683,7 +676,7 @@ def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int): new_options = options.copy() new_options["condition"] = _slice_collapsed_4d_along_t( new_options["condition"], t_start, t_end, - _SEEDVR2_CONDITION_CHANNELS, + SEEDVR2_COND_CHANNELS, ) new_list.append([text_cond, new_options]) return new_list @@ -693,24 +686,16 @@ def _slice_seedvr2_noise_mask_along_t(noise_mask: torch.Tensor, samples_4d: torch.Tensor, t_start: int, t_end: int): - """Slice collapsed SeedVR2 masks and preserve standard masks. - - ``SetLatentNoiseMask`` produces ``(B, 1, H, W)`` masks that KSampler - expands to the latent shape. Only masks already expanded to the full - collapsed ``(B, 16*T, H, W)`` shape need temporal slicing here. - """ + """Slice only masks already expanded to collapsed ``(B, 16*T, H, W)``; pass standard ``(B, 1, H, W)`` ``SetLatentNoiseMask`` outputs through for KSampler to expand.""" if noise_mask.ndim == samples_4d.ndim and noise_mask.shape[1] == samples_4d.shape[1]: return _slice_collapsed_4d_along_t( - noise_mask, t_start, t_end, _SEEDVR2_LATENT_CHANNELS, + noise_mask, t_start, t_end, SEEDVR2_LATENT_CHANNELS, ) return noise_mask def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor: - """Concatenate a list of SeedVR2-style collapsed 4D tensors - ``(B, channels*T_i, H, W)`` along the latent T axis. Each chunk is - un-collapsed to 5D, concatenated on ``dim=2``, then re-collapsed to 4D. - """ + """Concatenate collapsed ``(B, channels*T_i, H, W)`` chunks along latent T: un-collapse to 5D, cat on ``dim=2``, re-collapse to 4D.""" if len(chunks_4d) == 0: raise ValueError("_concat_chunks_along_t: empty chunk list.") fives = [] @@ -729,19 +714,10 @@ def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor: def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor: - """Build a 1D crossfade weight tensor of length ``overlap`` for the - *previous* chunk's contribution; the current chunk's weight is - ``1 - w_prev``. - - Mirrors the numz ``blend_overlapping_frames`` shape - (AInVFX/numz fork ``src/core/generation_utils.py``, - ``blend_overlapping_frames``): a Hann window with a ``[1/3, 2/3]`` - dead-band when ``overlap >= 3``, and a plain linear ramp when - ``overlap < 3`` (the dead-band would collapse the transition for - very small overlap counts). The numz reference operates on - pixel-space tensors ``[overlap, H, W, C]``; this 1D form is - reshaped by the caller to broadcast across the latent's - ``(B, C, T_overlap, H, W)`` axes. + """1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``): + Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3`` + (dead-band would collapse a tiny transition). Window shape matched to numz ``blend_overlapping_frames`` + for parity (reference, not source); caller broadcasts across ``(B, C, T_overlap, H, W)``. """ if overlap < 1: raise ValueError( @@ -758,14 +734,7 @@ def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor: def _blend_overlap_region(prev_tail_5d: torch.Tensor, cur_head_5d: torch.Tensor) -> torch.Tensor: - """Blend two 5D ``(B, C, T_overlap, H, W)`` tensors of equal shape - using a 1D Hann/linear ramp along the T axis. ``prev_tail_5d`` - receives the descending weight; ``cur_head_5d`` receives - ``1 - w_prev``. - - The caller is responsible for ensuring both inputs have identical - shape and dtype/device. - """ + """Blend two equal-shape 5D ``(B, C, T_overlap, H, W)`` tensors with a 1D Hann/linear T-ramp: ``prev_tail_5d`` takes the descending weight, ``cur_head_5d`` takes ``1 - w_prev`` (caller ensures matching shape/dtype/device).""" if prev_tail_5d.shape != cur_head_5d.shape: raise ValueError( f"_blend_overlap_region: shape mismatch " @@ -784,20 +753,7 @@ def _blend_overlap_region(prev_tail_5d: torch.Tensor, def _concat_chunks_with_overlap_blend(chunk_specs, channels: int, overlap_latent: int) -> torch.Tensor: - """Concatenate temporally-overlapping chunks back into a single - collapsed 4D tensor, blending overlap regions with a Hann/linear - crossfade. - - ``chunk_specs`` is a list of ``(t_start, t_end, chunk_4d)`` tuples - in source-latent T coordinates. ``overlap_latent == 0`` is a fast - path that delegates to plain concatenation (and produces output - bit-identical to ``_concat_chunks_along_t`` of the same chunks). - - The blend at each pair of adjacent chunks acts on the actual - overlap region width ``min(prev_end - cur_start, current chunk - length)``, which may be smaller than ``overlap_latent`` when the - final chunk is a runt shorter than the configured overlap. - """ + """Concatenate overlapping ``(t_start, t_end, chunk_4d)`` specs (source-latent T coords) into one collapsed 4D tensor, Hann/linear-blending overlaps; ``overlap_latent == 0`` fast-paths to plain concat (bit-identical to ``_concat_chunks_along_t``). Each blend uses the actual width ``min(prev_end - cur_start, chunk length)``, smaller than ``overlap_latent`` for a runt final chunk.""" if len(chunk_specs) == 0: raise ValueError("_concat_chunks_with_overlap_blend: empty chunk list.") if overlap_latent < 0: @@ -877,12 +833,7 @@ def _run_standard_sample(model, seed: int, steps: int, cfg: float, sampler_name: str, scheduler: str, positive, negative, latent_image: dict, denoise: float) -> dict: - """Single-shot delegation that mirrors the standard ``common_ksampler`` - flow (``nodes.py:common_ksampler``): generate noise from seed, run - ``comfy.sample.sample``, return a latent dict. Used by the - ProgressiveSampler short-circuit when the full sequence fits in one - chunk so chunking introduces no overhead for small videos. - """ + """Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk.""" samples_in = latent_image["samples"] samples_in = comfy.sample.fix_empty_latent_channels( model, samples_in, latent_image.get("downscale_ratio_spacial", None), @@ -929,43 +880,45 @@ class SeedVR2ProgressiveSampler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SeedVR2ProgressiveSampler", + display_name="Sample SeedVR2 (Progressive)", category="sampling", + description="Sample a SeedVR2 latent in sequential temporal chunks to allow longer videos to fit into VRAM via frame blending the resulting upscaled latents.", inputs=[ - io.Model.Input("model"), + io.Model.Input("model", tooltip="The model used for denoising the input latent."), io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, - control_after_generate=True), - io.Int.Input("steps", default=20, min=1, max=10000), + control_after_generate=True, + tooltip="The random seed used for creating the noise."), + io.Int.Input("steps", default=20, min=1, max=10000, + tooltip="The number of steps used in the denoising process."), io.Float.Input("cfg", default=1.0, min=0.0, max=100.0, - step=0.1, round=0.01), + step=0.1, round=0.01, + tooltip="The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."), io.Combo.Input("sampler_name", - options=comfy.samplers.SAMPLER_NAMES), + options=comfy.samplers.SAMPLER_NAMES, + tooltip="The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."), io.Combo.Input("scheduler", - options=comfy.samplers.SCHEDULER_NAMES), - io.Conditioning.Input("positive"), - io.Conditioning.Input("negative"), - io.Latent.Input("latent_image"), + options=comfy.samplers.SCHEDULER_NAMES, + tooltip="The scheduler controls how noise is gradually removed to form the image."), + io.Conditioning.Input("positive", + tooltip="The conditioning describing the attributes you want to include in the image."), + io.Conditioning.Input("negative", + tooltip="The conditioning describing the attributes you want to exclude from the image."), + io.Latent.Input("latent_image", + tooltip="The latent image to denoise."), io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, - step=0.01), + step=0.01, + tooltip="The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."), io.Int.Input("frames_per_chunk", default=21, min=1, - max=16384, step=4), + max=16384, step=4, + tooltip="Pixel frames per temporal chunk (4n+1: 1, 5, 9, 13, ...)."), io.Int.Input("temporal_overlap", default=0, min=0, max=16384, - tooltip="Latent-frame overlap between " - "adjacent chunks; blended with a " - "Hann window (linear for overlap " - "< 3). 0 = no blend, pure concat. " - "Values >= the chunk's latent-frame " - "length use the maximum valid " - "overlap; 1 latent frame corresponds " - "to ~4 pixel frames."), + tooltip="Latent frames blended between adjacent chunks to hide the seam; 0 = no blend."), io.Combo.Input("chunking_mode", options=["manual", "auto"], default="manual", - tooltip="manual = use frames_per_chunk " - "exactly; auto = retry only real OOM " - "failures with progressively smaller " - "temporal chunks."), + tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."), ], outputs=[io.Latent.Output()], ) @@ -999,14 +952,14 @@ class SeedVR2ProgressiveSampler(io.ComfyNode): f"(B, 16*T, H, W); got shape {tuple(samples_4d.shape)}." ) B, CT, H, W = samples_4d.shape - if CT % _SEEDVR2_LATENT_CHANNELS != 0: + if CT % SEEDVR2_LATENT_CHANNELS != 0: raise ValueError( f"SeedVR2ProgressiveSampler: collapsed channel dim {CT} is " f"not divisible by SeedVR2 latent channels " - f"{_SEEDVR2_LATENT_CHANNELS}; latent does not appear to be " + f"{SEEDVR2_LATENT_CHANNELS}; latent does not appear to be " f"SeedVR2-shaped." ) - T_latent = CT // _SEEDVR2_LATENT_CHANNELS + T_latent = CT // SEEDVR2_LATENT_CHANNELS T_pixel = 4 * (T_latent - 1) + 1 if chunking_mode not in ("manual", "auto"): @@ -1106,11 +1059,11 @@ class SeedVR2ProgressiveSampler(io.ComfyNode): def _sample_one_chunk(chunk_start, chunk_end): samples_chunk = _slice_collapsed_4d_along_t( samples_4d, chunk_start, chunk_end, - _SEEDVR2_LATENT_CHANNELS, + SEEDVR2_LATENT_CHANNELS, ) noise_chunk = _slice_collapsed_4d_along_t( noise_full, chunk_start, chunk_end, - _SEEDVR2_LATENT_CHANNELS, + SEEDVR2_LATENT_CHANNELS, ) positive_chunk = _slice_seedvr2_cond_along_t( positive, chunk_start, chunk_end, @@ -1140,7 +1093,7 @@ class SeedVR2ProgressiveSampler(io.ComfyNode): chunk_specs.append((chunk_start, chunk_end, chunk_samples)) final = _concat_chunks_with_overlap_blend( - chunk_specs, _SEEDVR2_LATENT_CHANNELS, temporal_overlap, + chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap, ) out = latent_image.copy()