Refine SeedVR2 alpha channel handling and node UX

This commit is contained in:
John Pollock 2026-06-02 21:12:34 -05:00
parent 7431bef672
commit 22078c799b
5 changed files with 598 additions and 535 deletions

View File

@ -0,0 +1,340 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from comfy.ldm.seedvr.model import safe_pad_operation
from comfy.ldm.seedvr.vae import safe_interpolate_operation
from comfy.ldm.seedvr.constants import (
CIELAB_DELTA,
CIELAB_KAPPA,
D65_WHITE_X,
D65_WHITE_Z,
WAVELET_DECOMP_LEVELS,
)
def wavelet_blur(image: Tensor, radius):
max_safe_radius = max(1, min(image.shape[-2:]) // 8)
if radius > max_safe_radius:
radius = max_safe_radius
num_channels = image.shape[1]
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS):
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq.add_(image).sub_(low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
if content_feat.shape != style_feat.shape:
# Resize style to match content spatial dimensions
if len(content_feat.shape) >= 3:
# safe_interpolate_operation handles FP16 conversion automatically
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False
)
# Decompose both features into frequency components
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq # Free memory immediately
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq # Free memory immediately
if content_high_freq.shape != style_low_freq.shape:
style_low_freq = safe_interpolate_operation(
style_low_freq,
size=content_high_freq.shape[-2:],
mode='bilinear',
align_corners=False
)
content_high_freq.add_(style_low_freq)
return content_high_freq.clamp_(-1.0, 1.0)
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
original_shape = source.shape
# Flatten
source_flat = source.flatten()
reference_flat = reference.flatten()
# Sort both arrays
source_sorted, source_indices = torch.sort(source_flat)
reference_sorted, _ = torch.sort(reference_flat)
del reference_flat
# Quantile mapping
n_source = len(source_sorted)
n_reference = len(reference_sorted)
if n_source == n_reference:
matched_sorted = reference_sorted
else:
# Interpolate reference to match source quantiles
source_quantiles = torch.linspace(0, 1, n_source, device=device)
ref_indices = (source_quantiles * (n_reference - 1)).long()
ref_indices.clamp_(0, n_reference - 1)
matched_sorted = reference_sorted[ref_indices]
del source_quantiles, ref_indices, reference_sorted
del source_sorted, source_flat
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
inverse_indices = torch.argsort(source_indices)
del source_indices
matched_flat = matched_sorted[inverse_indices]
del matched_sorted, inverse_indices
return matched_flat.reshape(original_shape)
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of CIELAB images to RGB color space."""
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
# LAB to XYZ
fy = (L + 16.0) / 116.0
fx = a.div(500.0).add_(fy)
fz = fy - b / 200.0
del L, a, b
# XYZ transformation
x = torch.where(
fx > epsilon,
torch.pow(fx, 3.0),
fx.mul(116.0).sub_(16.0).div_(kappa)
)
y = torch.where(
fy > epsilon,
torch.pow(fy, 3.0),
fy.mul(116.0).sub_(16.0).div_(kappa)
)
z = torch.where(
fz > epsilon,
torch.pow(fz, 3.0),
fz.mul(116.0).sub_(16.0).div_(kappa)
)
del fx, fy, fz
# Apply D65 white point (in-place)
x.mul_(D65_WHITE_X)
# y *= 1.00000 # (no-op, skip)
z.mul_(D65_WHITE_Z)
xyz = torch.stack([x, y, z], dim=1)
del x, y, z
# Matrix multiplication: XYZ -> RGB
B, C, H, W = xyz.shape
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
del xyz
# Ensure dtype consistency for matrix multiplication
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
del xyz_flat
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del rgb_linear_flat
# Apply inverse gamma correction (delinearize)
mask = rgb_linear > 0.0031308
rgb = torch.where(
mask,
torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055),
rgb_linear * 12.92
)
del mask, rgb_linear
return torch.clamp(rgb, 0.0, 1.0)
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
# Apply sRGB gamma correction (linearize)
mask = rgb > 0.04045
rgb_linear = torch.where(
mask,
torch.pow((rgb + 0.055) / 1.055, 2.4),
rgb / 12.92
)
del mask
# Matrix multiplication: RGB -> XYZ
B, C, H, W = rgb_linear.shape
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
del rgb_linear
# Ensure dtype consistency for matrix multiplication
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
xyz_flat = torch.matmul(rgb_flat, matrix.T)
del rgb_flat
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del xyz_flat
# Normalize by D65 white point (in-place)
xyz[:, 0].div_(D65_WHITE_X) # X
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
xyz[:, 2].div_(D65_WHITE_Z) # Z
# XYZ to LAB transformation
epsilon_cubed = epsilon ** 3
mask = xyz > epsilon_cubed
f_xyz = torch.where(
mask,
torch.pow(xyz, 1.0 / 3.0),
xyz.mul(kappa).add_(16.0).div_(116.0)
)
del xyz, mask
# Extract channels and compute LAB
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
del f_xyz
return torch.stack([L, a, b], dim=1)
def lab_color_transfer(
content_feat: Tensor,
style_feat: Tensor,
luminance_weight: float = 0.8
) -> Tensor:
content_feat = wavelet_reconstruction(content_feat, style_feat)
if content_feat.shape != style_feat.shape:
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False
)
device = content_feat.device
def ensure_float32_precision(c):
orig_dtype = c.dtype
c = c.float()
return c, orig_dtype
content_feat, original_dtype = ensure_float32_precision(content_feat)
style_feat, _ = ensure_float32_precision(style_feat)
rgb_to_xyz_matrix = torch.tensor([
[0.4124564, 0.3575761, 0.1804375],
[0.2126729, 0.7151522, 0.0721750],
[0.0193339, 0.1191920, 0.9503041]
], dtype=torch.float32, device=device)
xyz_to_rgb_matrix = torch.tensor([
[ 3.2404542, -1.5371385, -0.4985314],
[-0.9692660, 1.8760108, 0.0415560],
[ 0.0556434, -0.2040259, 1.0572252]
], dtype=torch.float32, device=device)
epsilon = CIELAB_DELTA
kappa = CIELAB_KAPPA
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
# Convert to LAB color space
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
del content_feat
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
del style_feat, rgb_to_xyz_matrix
# Match chrominance channels (a*, b*) for accurate color transfer
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
# Handle luminance with weighted blending
if luminance_weight < 1.0:
# Partially match luminance for better overall color accuracy
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
# Blend: preserve some content L* for detail, adopt some style L* for color
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
del matched_L
else:
# Fully preserve content luminance
result_L = content_lab[:, 0]
del content_lab, style_lab
# Reconstruct LAB with corrected channels
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
del result_L, matched_a, matched_b
# Convert back to RGB
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
del result_lab, xyz_to_rgb_matrix
# Convert back to [-1, 1] range (in-place)
result = result_rgb.mul_(2.0).sub_(1.0)
del result_rgb
result = result.to(original_dtype)
return result
def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor:
return wavelet_reconstruction(content_feat, style_feat)
def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor:
if content_feat.shape != style_feat.shape:
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False,
)
original_dtype = content_feat.dtype
content_feat = content_feat.float()
style_feat = style_feat.float()
b, c = content_feat.shape[:2]
content_flat = content_feat.reshape(b, c, -1)
style_flat = style_feat.reshape(b, c, -1)
content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1)
content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1)
style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
del content_flat, style_flat
normalized = (content_feat - content_mean) / content_std
del content_mean, content_std
result = normalized * style_std + style_mean
del normalized, style_mean, style_std
result = result.clamp_(-1.0, 1.0)
if result.dtype != original_dtype:
result = result.to(original_dtype)
return result

View File

@ -0,0 +1,82 @@
"""Named constants for the SeedVR2 integration, grouped by provenance.
Provenance prefixes:
- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline.
- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites
the upstream config/source path it was lifted from.
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
ISO / CIE values; cite the standard.
The numz/AInVFX custom node is used only as a behavioral-parity benchmark; it is the
origin of none of these values and appears here nowhere.
"""
# --------------------------------------------------------------------------------------
# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment)
# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN)
# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070
# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT).
# --------------------------------------------------------------------------------------
SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB)
SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB
# --------------------------------------------------------------------------------------
# B. Fork heuristics (SEEDVR2 - this integration)
# --------------------------------------------------------------------------------------
SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim.
# (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.)
SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry.
SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case).
SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM.
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk.
SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels).
SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16).
SEEDVR2_DEFAULT_TEMPORAL_SIZE = 16 # default VAE temporal tile when unset.
# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing)
SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk.
SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path.
SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path.
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
# --------------------------------------------------------------------------------------
# C. ByteDance config / source (BYTEDANCE - cite myseedvr2/SeedVR)
# --------------------------------------------------------------------------------------
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem).
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem).
BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16).
BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32).
BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD = 64 # attn_video_vae.py:308 (force .contiguous() above this b*t).
BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11.
BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size).
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor.
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor).
BYTEDANCE_SCHEDULE_T = 1000.0 # configs_3b/main.yaml:65 (schedule.T); timestep range.
BYTEDANCE_SPATIAL_DIVISOR = 16 # inference_seedvr2_3b.py:241 (DivisibleCrop((16,16))).
BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling).
BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames).
BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency).
BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim).
# Resolution-dependent timestep-shift linear fits: (x1, y1, x2, y2) for get_lin_function.
BYTEDANCE_IMG_SHIFT_FIT = (256 * 256, 1.0, 1024 * 1024, 3.2) # infer.py:242.
BYTEDANCE_VID_SHIFT_FIT = (256 * 256 * 37, 1.0, 1280 * 720 * 145, 5.0) # infer.py:243.
# --------------------------------------------------------------------------------------
# D. Published standards (cite the literature)
# --------------------------------------------------------------------------------------
ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864.
# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65).
CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta).
CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa).
D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1).
D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn.
WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR).
# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and
# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the
# exact existing coefficients move verbatim rather than being retyped here.

View File

@ -12,6 +12,16 @@ from torch.nn.modules.utils import _triple
from torch import nn
import math
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.seedvr.constants import (
BYTEDANCE_720P_REF_AREA,
BYTEDANCE_MAX_TEMPORAL_WINDOW,
BYTEDANCE_ROPE_MAX_FREQ,
BYTEDANCE_SINUSOIDAL_DIM,
ROPE_THETA,
SEEDVR2_7B_MLP_CHUNK,
SEEDVR2_7B_VID_DIM,
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS,
)
import comfy.model_management
import numbers
@ -203,10 +213,10 @@ def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int,
t, h, w = size
resized_nt, resized_nh, resized_nw = num_windows
#cal windows under 720p
scale = math.sqrt((45 * 80) / (h * w))
scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w))
resized_h, resized_w = round(h * scale), round(w * scale)
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
wt = ceil(min(t, 30) / resized_nt) # window size.
wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size.
nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size.
return [
(
@ -226,10 +236,10 @@ def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tup
t, h, w = size
resized_nt, resized_nh, resized_nw = num_windows
#cal windows under 720p
scale = math.sqrt((45 * 80) / (h * w))
scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w))
resized_h, resized_w = round(h * scale), round(w * scale)
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
wt = ceil(min(t, 30) / resized_nt) # window size.
wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size.
st, sh, sw = ( # shift size.
0.5 if wt < t else 0,
@ -412,7 +422,7 @@ class RotaryEmbeddingBase(nn.Module):
self.rope = RotaryEmbedding(
dim=dim // rope_dim,
freqs_for="pixel",
max_freq=256,
max_freq=BYTEDANCE_ROPE_MAX_FREQ,
)
freqs = self.rope.freqs
del self.rope.freqs
@ -486,7 +496,7 @@ class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
self.rope = RotaryEmbedding(
dim=dim // rope_dim,
freqs_for="lang",
theta=10000,
theta=ROPE_THETA,
cache_if_possible=False,
)
freqs = self.rope.freqs
@ -547,14 +557,7 @@ def apply_rotary_emb(
return out.type(dtype)
def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
"""Convert lucidrains-interleaved freqs `[..., d]` (`[θ0, θ0, θ1, θ1, ...]`
from `RotaryEmbedding.forward`'s `repeat(freqs, '... n -> ... (n r)', r=2)`)
into flux-canonical `freqs_cis` of shape `[..., d/2, 2, 2]` with the
`cos/-sin/sin/cos` rotation matrix baked in. Output dtype is fp32 to
match `comfy/ldm/flux/math.py:rope` precision; `apply_rope1` consumes
the matrix layout via `freqs_cis[..., 0]` (column 0) and
`freqs_cis[..., 1]` (column 1) of the 2x2 rotation matrix.
"""
"""Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`."""
angles = freqs_interleaved[..., ::2].float()
cos = torch.cos(angles)
sin = torch.sin(angles)
@ -562,27 +565,18 @@ def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
return rearrange(out, "... d (i j) -> ... d i j", i=2, j=2)
_ROPE1_PARTIAL_CHUNK_TOKENS = 4096
SEEDVR2_7B_MLP_CHUNK = 8192
def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""Apply ``apply_rope1`` to the leading ``rot_d = 2 * freqs_cis.shape[-3]``
components of ``t``'s last dim, passing through the remaining dims
untouched in-place for inference tensors. Training tensors are cloned
before slice assignment to preserve autograd correctness. Mirrors the partial-rope contract of the legacy
``apply_rotary_emb`` wrapper at line 470 (``t_left``/``t_middle``/``t_right``
split). For SeedVR2-3B this matters because ``rope_dim=128`` integer-
divides into 3 axes as ``128 // 3 = 42`` per-axis, total ``42 * 3 = 126``;
head_dim is 128, so the trailing 2 dims are unrotated. The fast path
triggers when ``rot_d == t.shape[-1]`` (e.g. test rigs where dim is
chosen divisible by 6) and avoids the cat entirely.
"""Rotate the leading ``rot_d = 2 * freqs_cis.shape[-3]`` dims of ``t`` and pass the rest
through; in-place for inference, cloned for training (autograd). Mirrors the legacy
``apply_rotary_emb`` ``t_left``/``t_middle``/``t_right`` split: 3B ``rope_dim=128`` gives
``42*3 = 126`` rotated of head_dim 128 (trailing 2 unrotated). Fast path skips the cat when
``rot_d == t.shape[-1]``.
"""
out = t.clone() if t.requires_grad or comfy.model_management.in_training else t
rot_d = 2 * freqs_cis.shape[-3]
seq_len = out.shape[-2]
for start in range(0, seq_len, _ROPE1_PARTIAL_CHUNK_TOKENS):
end = min(start + _ROPE1_PARTIAL_CHUNK_TOKENS, seq_len)
for start in range(0, seq_len, SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS):
end = min(start + SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, seq_len)
freqs_chunk = freqs_cis[start:end]
if rot_d == out.shape[-1]:
out[..., start:end, :] = apply_rope1(out[..., start:end, :], freqs_chunk).to(out.dtype)
@ -1385,7 +1379,7 @@ class NaDiT(nn.Module):
operations = None,
**kwargs,
):
self._7b_version = vid_dim == 3072
self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM
self.dtype = dtype
factory_kwargs = {"device": device, "dtype": dtype}
window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
@ -1427,7 +1421,7 @@ class NaDiT(nn.Module):
else nn.Identity()
)
self.emb_in = TimeEmbedding(
sinusoidal_dim=256,
sinusoidal_dim=BYTEDANCE_SINUSOIDAL_DIM,
hidden_dim=max(vid_dim, txt_dim),
output_dim=emb_dim,
device=device, dtype=dtype, operations=operations

View File

@ -10,6 +10,27 @@ from contextlib import contextmanager
from comfy.utils import ProgressBar
from comfy.ldm.seedvr.model import safe_pad_operation
from comfy.ldm.seedvr.constants import (
BYTEDANCE_BLOCK_OUT_CHANNELS,
BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD,
BYTEDANCE_GN_CHUNKS_FP16,
BYTEDANCE_GN_CHUNKS_FP32,
BYTEDANCE_LOGVAR_CLAMP_MAX,
BYTEDANCE_LOGVAR_CLAMP_MIN,
BYTEDANCE_SLICING_SAMPLE_MIN,
BYTEDANCE_VAE_CONV_MEM_GIB,
BYTEDANCE_VAE_NORM_MEM_GIB,
BYTEDANCE_VAE_SCALING_FACTOR,
BYTEDANCE_VAE_SHIFTING_FACTOR,
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE,
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE,
CIELAB_DELTA,
CIELAB_KAPPA,
D65_WHITE_X,
D65_WHITE_Z,
SEEDVR2_LATENT_CHANNELS,
WAVELET_DECOMP_LEVELS,
)
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.modules.diffusionmodules.model import vae_attention
@ -70,8 +91,8 @@ def tiled_vae(
_, _, d, h, w = x.shape
sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
sf_s = getattr(vae_model, "spatial_downsample_factor", BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE)
sf_t = getattr(vae_model, "temporal_downsample_factor", BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE)
if encode:
slicing_attr = "slicing_sample_min_size"
slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap)
@ -278,7 +299,7 @@ class DiagonalGaussianDistribution(object):
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.logvar = torch.clamp(self.logvar, BYTEDANCE_LOGVAR_CLAMP_MIN, BYTEDANCE_LOGVAR_CLAMP_MAX)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
@ -569,7 +590,7 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
x = rearrange(x, "b c t h w -> (b t) c h w")
memory_occupy = x.numel() * x.element_size() / 1024**3
if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit():
num_chunks = min(4 if x.element_size() == 2 else 2, norm_layer.num_groups)
num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups)
assert norm_layer.num_groups % num_chunks == 0
num_groups_per_chunk = norm_layer.num_groups // num_chunks
@ -1189,7 +1210,7 @@ class ResnetBlock3D(nn.Module):
hidden_states = self.nonlinearity(hidden_states)
if self.upsample is not None:
if hidden_states.shape[0] >= 64:
if hidden_states.shape[0] >= BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD:
input_tensor = input_tensor.contiguous()
hidden_states = hidden_states.contiguous()
input_tensor = self.upsample(input_tensor, memory_state=memory_state)
@ -1780,333 +1801,6 @@ class Decoder3D(nn.Module):
return sample
def wavelet_blur(image: Tensor, radius):
max_safe_radius = max(1, min(image.shape[-2:]) // 8)
if radius > max_safe_radius:
radius = max_safe_radius
num_channels = image.shape[1]
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels: int = 5):
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq.add_(image).sub_(low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
if content_feat.shape != style_feat.shape:
# Resize style to match content spatial dimensions
if len(content_feat.shape) >= 3:
# safe_interpolate_operation handles FP16 conversion automatically
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False
)
# Decompose both features into frequency components
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq # Free memory immediately
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq # Free memory immediately
if content_high_freq.shape != style_low_freq.shape:
style_low_freq = safe_interpolate_operation(
style_low_freq,
size=content_high_freq.shape[-2:],
mode='bilinear',
align_corners=False
)
content_high_freq.add_(style_low_freq)
return content_high_freq.clamp_(-1.0, 1.0)
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
original_shape = source.shape
# Flatten
source_flat = source.flatten()
reference_flat = reference.flatten()
# Sort both arrays
source_sorted, source_indices = torch.sort(source_flat)
reference_sorted, _ = torch.sort(reference_flat)
del reference_flat
# Quantile mapping
n_source = len(source_sorted)
n_reference = len(reference_sorted)
if n_source == n_reference:
matched_sorted = reference_sorted
else:
# Interpolate reference to match source quantiles
source_quantiles = torch.linspace(0, 1, n_source, device=device)
ref_indices = (source_quantiles * (n_reference - 1)).long()
ref_indices.clamp_(0, n_reference - 1)
matched_sorted = reference_sorted[ref_indices]
del source_quantiles, ref_indices, reference_sorted
del source_sorted, source_flat
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
inverse_indices = torch.argsort(source_indices)
del source_indices
matched_flat = matched_sorted[inverse_indices]
del matched_sorted, inverse_indices
return matched_flat.reshape(original_shape)
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of CIELAB images to RGB color space."""
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
# LAB to XYZ
fy = (L + 16.0) / 116.0
fx = a.div(500.0).add_(fy)
fz = fy - b / 200.0
del L, a, b
# XYZ transformation
x = torch.where(
fx > epsilon,
torch.pow(fx, 3.0),
fx.mul(116.0).sub_(16.0).div_(kappa)
)
y = torch.where(
fy > epsilon,
torch.pow(fy, 3.0),
fy.mul(116.0).sub_(16.0).div_(kappa)
)
z = torch.where(
fz > epsilon,
torch.pow(fz, 3.0),
fz.mul(116.0).sub_(16.0).div_(kappa)
)
del fx, fy, fz
# Apply D65 white point (in-place)
x.mul_(0.95047)
# y *= 1.00000 # (no-op, skip)
z.mul_(1.08883)
xyz = torch.stack([x, y, z], dim=1)
del x, y, z
# Matrix multiplication: XYZ -> RGB
B, C, H, W = xyz.shape
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
del xyz
# Ensure dtype consistency for matrix multiplication
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
del xyz_flat
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del rgb_linear_flat
# Apply inverse gamma correction (delinearize)
mask = rgb_linear > 0.0031308
rgb = torch.where(
mask,
torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055),
rgb_linear * 12.92
)
del mask, rgb_linear
return torch.clamp(rgb, 0.0, 1.0)
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
# Apply sRGB gamma correction (linearize)
mask = rgb > 0.04045
rgb_linear = torch.where(
mask,
torch.pow((rgb + 0.055) / 1.055, 2.4),
rgb / 12.92
)
del mask
# Matrix multiplication: RGB -> XYZ
B, C, H, W = rgb_linear.shape
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
del rgb_linear
# Ensure dtype consistency for matrix multiplication
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
xyz_flat = torch.matmul(rgb_flat, matrix.T)
del rgb_flat
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del xyz_flat
# Normalize by D65 white point (in-place)
xyz[:, 0].div_(0.95047) # X
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
xyz[:, 2].div_(1.08883) # Z
# XYZ to LAB transformation
epsilon_cubed = epsilon ** 3
mask = xyz > epsilon_cubed
f_xyz = torch.where(
mask,
torch.pow(xyz, 1.0 / 3.0),
xyz.mul(kappa).add_(16.0).div_(116.0)
)
del xyz, mask
# Extract channels and compute LAB
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
del f_xyz
return torch.stack([L, a, b], dim=1)
def lab_color_transfer(
content_feat: Tensor,
style_feat: Tensor,
luminance_weight: float = 0.8
) -> Tensor:
content_feat = wavelet_reconstruction(content_feat, style_feat)
if content_feat.shape != style_feat.shape:
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False
)
device = content_feat.device
def ensure_float32_precision(c):
orig_dtype = c.dtype
c = c.float()
return c, orig_dtype
content_feat, original_dtype = ensure_float32_precision(content_feat)
style_feat, _ = ensure_float32_precision(style_feat)
rgb_to_xyz_matrix = torch.tensor([
[0.4124564, 0.3575761, 0.1804375],
[0.2126729, 0.7151522, 0.0721750],
[0.0193339, 0.1191920, 0.9503041]
], dtype=torch.float32, device=device)
xyz_to_rgb_matrix = torch.tensor([
[ 3.2404542, -1.5371385, -0.4985314],
[-0.9692660, 1.8760108, 0.0415560],
[ 0.0556434, -0.2040259, 1.0572252]
], dtype=torch.float32, device=device)
epsilon = 6.0 / 29.0
kappa = (29.0 / 3.0) ** 3
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
# Convert to LAB color space
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
del content_feat
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
del style_feat, rgb_to_xyz_matrix
# Match chrominance channels (a*, b*) for accurate color transfer
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
# Handle luminance with weighted blending
if luminance_weight < 1.0:
# Partially match luminance for better overall color accuracy
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
# Blend: preserve some content L* for detail, adopt some style L* for color
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
del matched_L
else:
# Fully preserve content luminance
result_L = content_lab[:, 0]
del content_lab, style_lab
# Reconstruct LAB with corrected channels
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
del result_L, matched_a, matched_b
# Convert back to RGB
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
del result_lab, xyz_to_rgb_matrix
# Convert back to [-1, 1] range (in-place)
result = result_rgb.mul_(2.0).sub_(1.0)
del result_rgb
result = result.to(original_dtype)
return result
def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor:
return wavelet_reconstruction(content_feat, style_feat)
def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor:
if content_feat.shape != style_feat.shape:
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False,
)
original_dtype = content_feat.dtype
content_feat = content_feat.float()
style_feat = style_feat.float()
b, c = content_feat.shape[:2]
content_flat = content_feat.reshape(b, c, -1)
style_flat = style_feat.reshape(b, c, -1)
content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1)
content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1)
style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
del content_flat, style_flat
normalized = (content_feat - content_mean) / content_std
del content_mean, content_std
result = normalized * style_std + style_mean
del normalized, style_mean, style_std
result = result.clamp_(-1.0, 1.0)
if result.dtype != original_dtype:
result = result.to(original_dtype)
return result
class VideoAutoencoderKL(nn.Module):
def __init__(
self,
@ -2114,7 +1808,7 @@ class VideoAutoencoderKL(nn.Module):
out_channels: int = 3,
layers_per_block: int = 2,
act_fn: str = "silu",
latent_channels: int = 16,
latent_channels: int = SEEDVR2_LATENT_CHANNELS,
norm_num_groups: int = 32,
attention: bool = True,
temporal_scale_num: int = 2,
@ -2124,14 +1818,14 @@ class VideoAutoencoderKL(nn.Module):
time_receptive_field: _receptive_field_t = "full",
use_quant_conv: bool = False,
use_post_quant_conv: bool = False,
slicing_sample_min_size = 4,
slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN,
*args,
**kwargs,
):
self.slicing_sample_min_size = slicing_sample_min_size
self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)
extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None
block_out_channels = (128, 256, 512, 512)
block_out_channels = BYTEDANCE_BLOCK_OUT_CHANNELS
down_block_types = ("DownEncoderBlock3D",) * 4
up_block_types = ("UpDecoderBlock3D",) * 4
super().__init__()
@ -2329,7 +2023,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
self.freeze_encoder = freeze_encoder
self.enable_tiling = False
super().__init__(*args, **kwargs)
self.set_memory_limit(0.5, 0.5)
self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB)
def forward(self, x: torch.FloatTensor):
with torch.no_grad() if self.freeze_encoder else nullcontext():
@ -2377,8 +2071,8 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
"4-D collapsed (B, 16*T, H, W) or 5-D (B, 16, T, H, W); "
f"got shape {tuple(z.shape)}."
)
scale = 0.9152
shift = 0
scale = BYTEDANCE_VAE_SCALING_FACTOR
shift = BYTEDANCE_VAE_SHIFTING_FACTOR
latent = latent / scale + shift
self.device = latent.device

View File

@ -9,11 +9,24 @@ import gc
import comfy.model_management
import comfy.sample
import comfy.samplers
from comfy.ldm.seedvr.vae import (
from comfy.ldm.seedvr.color_fix import (
adain_color_transfer,
lab_color_transfer,
wavelet_color_transfer,
)
from comfy.ldm.seedvr.constants import (
BYTEDANCE_IMG_SHIFT_FIT,
BYTEDANCE_SCHEDULE_T,
BYTEDANCE_VID_SHIFT_FIT,
SEEDVR2_ADAIN_SCALE_MULTIPLIER,
SEEDVR2_COLOR_MEM_HEADROOM,
SEEDVR2_COND_CHANNELS,
SEEDVR2_DTYPE_BYTES_FLOOR,
SEEDVR2_LAB_SCALE_MULTIPLIER,
SEEDVR2_LATENT_CHANNELS,
SEEDVR2_OOM_BACKOFF_DIVISOR,
SEEDVR2_WAVELET_SCALE_MULTIPLIER,
)
from torchvision.transforms import functional as TVF
from torchvision.transforms import Lambda
@ -23,10 +36,6 @@ from torchvision.transforms.functional import InterpolationMode
_SEEDVR2_INVALID_MODEL_MSG_PREFIX = (
"SeedVR2Conditioning: model object does not match expected SeedVR2 structure"
)
LAB_SCALE_MULTIPLIER = 13
WAVELET_SCALE_MULTIPLIER = 10
ADAIN_SCALE_MULTIPLIER = 6
COLOR_CORRECTION_MEMORY_HEADROOM = 0.75
# Private sentinel for getattr default: distinguishes "attribute missing"
# from "attribute present but None" so the failure message is accurate.
@ -57,17 +66,7 @@ def _seedvr2_auto_chunk_attempts(t_latent, t_pixel, frames_per_chunk):
def _resolve_seedvr2_diffusion_model(model):
"""Resolve the inner SeedVR2 diffusion-model module from a ComfyUI model
patcher object. Fails loud with a ``RuntimeError`` whose message begins
with ``_SEEDVR2_INVALID_MODEL_MSG_PREFIX`` when the expected wrapper
shape (``model.model.diffusion_model``) is absent.
Distinguishes four failure modes via the ``_ATTR_MISSING`` sentinel:
``model.model`` missing, ``model.model is None``,
``model.model.diffusion_model`` missing, ``model.model.diffusion_model
is None``. Each mode produces an accurate error message rather than
conflating "attribute missing" with "attribute is None".
"""
"""Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message."""
inner = getattr(model, "model", _ATTR_MISSING)
if inner is _ATTR_MISSING:
raise RuntimeError(
@ -94,15 +93,7 @@ def _resolve_seedvr2_diffusion_model(model):
def _apply_rope_freqs_float32_cast(diffusion_model):
"""Cast every nested module's ``rope.freqs`` parameter data to ``float32``
when it is not already in float32. Idempotency is per-tensor by dtype
check, NOT a per-instance sentinel attribute a sentinel would survive
Comfy's dynamic model unload/reload cycle while ``rope.freqs`` itself
is restored from the archived dtype, leaving RoPE running in fp16/bf16
on subsequent calls. The dtype check makes the cast self-correcting
against weight-restore lifecycle events. Iteration cost is one walk of
the diffusion-model module tree per ``execute()`` call (microseconds).
"""
"""Cast every module's ``rope.freqs`` to float32; the per-tensor dtype check (not a sentinel attr) self-corrects across Comfy's unload/reload, which would otherwise restore the archived fp16/bf16 dtype."""
for module in diffusion_model.modules():
if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'):
if module.rope.freqs.data.dtype != torch.float32:
@ -140,8 +131,8 @@ def timestep_transform(timesteps, latents_shapes):
b = y1 - m * x1
return lambda x: m * x + b
img_shift_fn = get_lin_function(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2)
vid_shift_fn = get_lin_function(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0)
img_shift_fn = get_lin_function(*BYTEDANCE_IMG_SHIFT_FIT)
vid_shift_fn = get_lin_function(*BYTEDANCE_VID_SHIFT_FIT)
shift = torch.where(
frames > 1,
vid_shift_fn(heights * widths * frames),
@ -149,7 +140,7 @@ def timestep_transform(timesteps, latents_shapes):
).to(timesteps.device)
# Shift timesteps.
T = 1000.0
T = BYTEDANCE_SCHEDULE_T
timesteps = timesteps / T
timesteps = shift * timesteps / (1 + (shift - 1) * timesteps)
timesteps = timesteps * T
@ -157,7 +148,7 @@ def timestep_transform(timesteps, latents_shapes):
def inter(x_0, x_T, t):
t = expand_dims(t, x_0.ndim)
T = 1000.0
T = BYTEDANCE_SCHEDULE_T
B = lambda t: t / T
A = lambda t: 1 - (t / T)
return A(t) * x_0 + B(t) * x_T
@ -235,6 +226,8 @@ def _seedvr2_resize_and_pad(images, upscaled_shorter_edge, node_name):
f"got {upscaled_shorter_edge}."
)
original_image = images
if images.shape[-1] > 3:
images = images[..., :3]
if images.dim() == 4:
# Comfy video components arrive as a 4-D IMAGE frame sequence:
# (frames, H, W, C). SeedVR2 consumes that as one video.
@ -268,10 +261,12 @@ class SeedVR2Resize(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SeedVR2Resize",
category="image/video",
display_name="Resize Image for SeedVR2",
category="image/upscaling",
description="Resize an image to a SeedVR2-compatible size by a multiplier.",
inputs=[
io.Image.Input("images"),
io.Float.Input("multiplier", default=4.0, min=0.01),
io.Image.Input("images", tooltip="The image(s) to resize."),
io.Float.Input("multiplier", default=4.0, min=0.01, tooltip="Upscale factor applied to the shorter edge."),
],
outputs=[
io.Image.Output("input_pixels"),
@ -304,10 +299,12 @@ class SeedVR2ResizeAdvanced(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SeedVR2ResizeAdvanced",
category="image/video",
display_name="Resize Image for SeedVR2 (Advanced)",
category="image/upscaling",
description="Resize an image to an exact shorter-edge size for SeedVR2.",
inputs=[
io.Image.Input("images"),
io.Int.Input("shorter_edge", default=1280, min=2),
io.Image.Input("images", tooltip="The image(s) to resize."),
io.Int.Input("shorter_edge", default=1280, min=2, tooltip="Target length of the shorter edge, in pixels."),
],
outputs=[
io.Image.Output("input_pixels"),
@ -323,17 +320,30 @@ class SeedVR2ResizeAdvanced(io.ComfyNode):
)
def _edge_guided_alpha_upscale(alpha, out_h, out_w):
a = alpha.float()
extreme_fraction = ((a < 0.1) | (a > 0.9)).float().mean()
if extreme_fraction > 0.9:
up = torch.nn.functional.interpolate(a, size=(out_h, out_w), mode="bilinear", align_corners=False, antialias=True)
up = torch.clamp((up - 0.5) * 4.0 + 0.5, 0.0, 1.0)
else:
up = torch.nn.functional.interpolate(a, size=(out_h, out_w), mode="bicubic", align_corners=False, antialias=True).clamp(0.0, 1.0)
return up
class SeedVR2PostProcessing(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SeedVR2PostProcessing",
category="image/video",
display_name="Post-Process SeedVR2 Output",
category="image/upscaling",
description="Align the upscaled output to the original's geometry and optionally color-correct it against the original.",
inputs=[
io.Image.Input("decoded"),
io.Image.Input("original_image"),
io.Int.Input("upscaled_shorter_edge", min=2, force_input=True),
io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab"),
io.Image.Input("decoded", tooltip="The decoded upscaled image to color-correct."),
io.Image.Input("original_image", tooltip="The original image used as the color reference."),
io.Int.Input("upscaled_shorter_edge", min=2, force_input=True, tooltip="Shorter-edge size from the resize node."),
io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="How to match the output's color to the original. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."),
],
outputs=[io.Image.Output()],
)
@ -341,6 +351,10 @@ class SeedVR2PostProcessing(io.ComfyNode):
@classmethod
def execute(cls, decoded, original_image, upscaled_shorter_edge, color_correction_method):
cls._validate_upscaled_shorter_edge(upscaled_shorter_edge)
alpha_input = None
if original_image.shape[-1] == 4:
alpha_input = original_image[..., 3:4]
original_image = original_image[..., :3]
decoded_5d, decoded_was_4d = cls._as_bthwc(decoded)
original_5d, _ = cls._as_bthwc(original_image)
decoded_5d = cls._restore_reference_batch_time(decoded_5d, original_5d)
@ -374,6 +388,13 @@ class SeedVR2PostProcessing(io.ComfyNode):
else:
raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
if alpha_input is not None:
ab, at = output.shape[0], output.shape[1]
alpha_5d, _ = cls._as_bthwc(alpha_input)
alpha_flat = rearrange(alpha_5d[:ab, :at], "b t h w c -> (b t) c h w")
alpha_up = _edge_guided_alpha_upscale(alpha_flat, output.shape[2], output.shape[3])
alpha_up = rearrange(alpha_up, "(b t) c h w -> b t h w c", b=ab, t=at)
output = torch.cat([output, alpha_up.to(dtype=output.dtype, device=output.device)], dim=-1)
h2 = output.shape[-3] - (output.shape[-3] % 2)
w2 = output.shape[-2] - (output.shape[-2] % 2)
output = output[:, :, :h2, :w2, :]
@ -472,7 +493,7 @@ class SeedVR2PostProcessing(io.ComfyNode):
"SeedVR2PostProcessing: color correction OOM at one frame; "
f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}."
) from e
next_chunk_size = max(1, chunk_size // 2)
next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR)
comfy.model_management.soft_empty_cache()
chunk_size = next_chunk_size
@ -510,23 +531,23 @@ class SeedVR2PostProcessing(io.ComfyNode):
multiplier = cls._color_correction_memory_multiplier(color_correction_method)
frames = decoded_flat.shape[0]
_, channels, height, width = decoded_flat.shape
dtype_bytes = max(decoded_flat.element_size(), 4)
dtype_bytes = max(decoded_flat.element_size(), SEEDVR2_DTYPE_BYTES_FLOOR)
bytes_per_frame = height * width * channels * dtype_bytes * multiplier
if bytes_per_frame <= 0:
return frames
color_device = comfy.model_management.vae_device()
free_memory = comfy.model_management.get_free_memory(color_device)
chunk_size = int((free_memory * COLOR_CORRECTION_MEMORY_HEADROOM) // bytes_per_frame)
chunk_size = int((free_memory * SEEDVR2_COLOR_MEM_HEADROOM) // bytes_per_frame)
return max(1, min(frames, chunk_size))
@staticmethod
def _color_correction_memory_multiplier(color_correction_method):
if color_correction_method == "lab":
return LAB_SCALE_MULTIPLIER
return SEEDVR2_LAB_SCALE_MULTIPLIER
if color_correction_method == "wavelet":
return WAVELET_SCALE_MULTIPLIER
return SEEDVR2_WAVELET_SCALE_MULTIPLIER
if color_correction_method == "adain":
return ADAIN_SCALE_MULTIPLIER
return SEEDVR2_ADAIN_SCALE_MULTIPLIER
raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
@staticmethod
@ -549,10 +570,12 @@ class SeedVR2Conditioning(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SeedVR2Conditioning",
category="image/video",
display_name="Apply SeedVR2 Conditioning",
category="conditioning",
description="Build SeedVR2 positive/negative conditioning from a VAE latent.",
inputs=[
io.Model.Input("model"),
io.Latent.Input("vae_conditioning", display_name="LATENT"),
io.Model.Input("model", tooltip="The SeedVR2 model."),
io.Latent.Input("vae_conditioning", tooltip="The VAE-encoded latent to condition on."),
],
outputs=[
io.Model.Output(display_name = "model"),
@ -571,10 +594,10 @@ class SeedVR2Conditioning(io.ComfyNode):
"SeedVR2Conditioning expects a 5-D VAE latent in Comfy "
f"channel-first layout; got shape {tuple(vae_conditioning.shape)}."
)
if vae_conditioning.shape[-1] == _SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != _SEEDVR2_LATENT_CHANNELS:
if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS:
raise ValueError(
"SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy "
f"channel-first layout (B, {_SEEDVR2_LATENT_CHANNELS}, T, H, W); "
f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
f"got channel-last shape {tuple(vae_conditioning.shape)}."
)
vae_conditioning = vae_conditioning.movedim(1, -1).contiguous()
@ -622,27 +645,9 @@ class SeedVR2Conditioning(io.ComfyNode):
return io.NodeOutput(model_patcher, positive, negative, {"samples": latent})
# SeedVR2 latent / conditioning channel constants. The SeedVR2 conditioning
# stage collapses ``(B, C, T, H, W) -> (B, C*T, H, W)`` for both the latent
# (C=16) and the per-frame condition tensor (C=17 = 16 latent + 1 mask), as
# required by ``NaDiT.forward`` which un-collapses via
# ``view(B, 16, -1, H, W)`` and ``view(B, 17, -1, H, W)`` respectively.
_SEEDVR2_LATENT_CHANNELS = 16
_SEEDVR2_CONDITION_CHANNELS = 17
def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int,
t_end: int, channels: int) -> torch.Tensor:
"""Slice a SeedVR2-style collapsed 4D tensor ``(B, channels*T, H, W)``
along the latent T axis, returning ``(B, channels*(t_end - t_start), H, W)``.
Reshape -> slice -> ``.contiguous()`` -> re-collapse. ``reshape`` is
used for the un-collapse so non-contiguous incoming tensors from
cropping or slicing nodes are accepted. The
``.contiguous()`` is mandatory: T-axis slicing of a 5D tensor produces a
non-contiguous view, and the subsequent re-collapse requires contiguous
storage.
"""
"""Slice collapsed ``(B, channels*T, H, W)`` along latent T: reshape (accepts non-contiguous inputs), slice, ``.contiguous()`` (T-slice of 5D is a non-contiguous view; re-collapse needs contiguous), re-collapse."""
B, CT, H, W = tensor_4d.shape
if CT % channels != 0:
raise ValueError(
@ -661,19 +666,7 @@ def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int,
def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int):
"""Build a new SeedVR2 conditioning list with the per-frame ``condition``
tensor sliced along the latent T axis.
SeedVR2 conditioning entries have the shape
``[text_cond_tensor, options_dict]`` where ``options_dict["condition"]``
is a 4D collapsed ``(B, 17*T, H, W)`` tensor; the text tensor itself has
no temporal axis and is passed through unchanged. Other keys in the
options dict (controlnets, etc.) are also passed through unchanged. If
an entry has no ``"condition"`` key, the entry is forwarded verbatim.
A new list of ``[text_cond, new_options_dict]`` pairs is returned; the
original ``cond_list`` and its options dicts are not mutated.
"""
"""Return a new conditioning list with each entry's ``options["condition"]`` (collapsed ``(B, 17*T, H, W)``) sliced along latent T; text tensors, other option keys, and condition-less entries pass through unchanged and inputs are not mutated."""
new_list = []
for entry in cond_list:
text_cond, options = entry[0], entry[1]
@ -683,7 +676,7 @@ def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int):
new_options = options.copy()
new_options["condition"] = _slice_collapsed_4d_along_t(
new_options["condition"], t_start, t_end,
_SEEDVR2_CONDITION_CHANNELS,
SEEDVR2_COND_CHANNELS,
)
new_list.append([text_cond, new_options])
return new_list
@ -693,24 +686,16 @@ def _slice_seedvr2_noise_mask_along_t(noise_mask: torch.Tensor,
samples_4d: torch.Tensor,
t_start: int,
t_end: int):
"""Slice collapsed SeedVR2 masks and preserve standard masks.
``SetLatentNoiseMask`` produces ``(B, 1, H, W)`` masks that KSampler
expands to the latent shape. Only masks already expanded to the full
collapsed ``(B, 16*T, H, W)`` shape need temporal slicing here.
"""
"""Slice only masks already expanded to collapsed ``(B, 16*T, H, W)``; pass standard ``(B, 1, H, W)`` ``SetLatentNoiseMask`` outputs through for KSampler to expand."""
if noise_mask.ndim == samples_4d.ndim and noise_mask.shape[1] == samples_4d.shape[1]:
return _slice_collapsed_4d_along_t(
noise_mask, t_start, t_end, _SEEDVR2_LATENT_CHANNELS,
noise_mask, t_start, t_end, SEEDVR2_LATENT_CHANNELS,
)
return noise_mask
def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor:
"""Concatenate a list of SeedVR2-style collapsed 4D tensors
``(B, channels*T_i, H, W)`` along the latent T axis. Each chunk is
un-collapsed to 5D, concatenated on ``dim=2``, then re-collapsed to 4D.
"""
"""Concatenate collapsed ``(B, channels*T_i, H, W)`` chunks along latent T: un-collapse to 5D, cat on ``dim=2``, re-collapse to 4D."""
if len(chunks_4d) == 0:
raise ValueError("_concat_chunks_along_t: empty chunk list.")
fives = []
@ -729,19 +714,10 @@ def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor:
def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor:
"""Build a 1D crossfade weight tensor of length ``overlap`` for the
*previous* chunk's contribution; the current chunk's weight is
``1 - w_prev``.
Mirrors the numz ``blend_overlapping_frames`` shape
(AInVFX/numz fork ``src/core/generation_utils.py``,
``blend_overlapping_frames``): a Hann window with a ``[1/3, 2/3]``
dead-band when ``overlap >= 3``, and a plain linear ramp when
``overlap < 3`` (the dead-band would collapse the transition for
very small overlap counts). The numz reference operates on
pixel-space tensors ``[overlap, H, W, C]``; this 1D form is
reshaped by the caller to broadcast across the latent's
``(B, C, T_overlap, H, W)`` axes.
"""1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``):
Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3``
(dead-band would collapse a tiny transition). Window shape matched to numz ``blend_overlapping_frames``
for parity (reference, not source); caller broadcasts across ``(B, C, T_overlap, H, W)``.
"""
if overlap < 1:
raise ValueError(
@ -758,14 +734,7 @@ def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor:
def _blend_overlap_region(prev_tail_5d: torch.Tensor,
cur_head_5d: torch.Tensor) -> torch.Tensor:
"""Blend two 5D ``(B, C, T_overlap, H, W)`` tensors of equal shape
using a 1D Hann/linear ramp along the T axis. ``prev_tail_5d``
receives the descending weight; ``cur_head_5d`` receives
``1 - w_prev``.
The caller is responsible for ensuring both inputs have identical
shape and dtype/device.
"""
"""Blend two equal-shape 5D ``(B, C, T_overlap, H, W)`` tensors with a 1D Hann/linear T-ramp: ``prev_tail_5d`` takes the descending weight, ``cur_head_5d`` takes ``1 - w_prev`` (caller ensures matching shape/dtype/device)."""
if prev_tail_5d.shape != cur_head_5d.shape:
raise ValueError(
f"_blend_overlap_region: shape mismatch "
@ -784,20 +753,7 @@ def _blend_overlap_region(prev_tail_5d: torch.Tensor,
def _concat_chunks_with_overlap_blend(chunk_specs, channels: int,
overlap_latent: int) -> torch.Tensor:
"""Concatenate temporally-overlapping chunks back into a single
collapsed 4D tensor, blending overlap regions with a Hann/linear
crossfade.
``chunk_specs`` is a list of ``(t_start, t_end, chunk_4d)`` tuples
in source-latent T coordinates. ``overlap_latent == 0`` is a fast
path that delegates to plain concatenation (and produces output
bit-identical to ``_concat_chunks_along_t`` of the same chunks).
The blend at each pair of adjacent chunks acts on the actual
overlap region width ``min(prev_end - cur_start, current chunk
length)``, which may be smaller than ``overlap_latent`` when the
final chunk is a runt shorter than the configured overlap.
"""
"""Concatenate overlapping ``(t_start, t_end, chunk_4d)`` specs (source-latent T coords) into one collapsed 4D tensor, Hann/linear-blending overlaps; ``overlap_latent == 0`` fast-paths to plain concat (bit-identical to ``_concat_chunks_along_t``). Each blend uses the actual width ``min(prev_end - cur_start, chunk length)``, smaller than ``overlap_latent`` for a runt final chunk."""
if len(chunk_specs) == 0:
raise ValueError("_concat_chunks_with_overlap_blend: empty chunk list.")
if overlap_latent < 0:
@ -877,12 +833,7 @@ def _run_standard_sample(model, seed: int, steps: int, cfg: float,
sampler_name: str, scheduler: str,
positive, negative, latent_image: dict,
denoise: float) -> dict:
"""Single-shot delegation that mirrors the standard ``common_ksampler``
flow (``nodes.py:common_ksampler``): generate noise from seed, run
``comfy.sample.sample``, return a latent dict. Used by the
ProgressiveSampler short-circuit when the full sequence fits in one
chunk so chunking introduces no overhead for small videos.
"""
"""Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk."""
samples_in = latent_image["samples"]
samples_in = comfy.sample.fix_empty_latent_channels(
model, samples_in, latent_image.get("downscale_ratio_spacial", None),
@ -929,43 +880,45 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
def define_schema(cls):
return io.Schema(
node_id="SeedVR2ProgressiveSampler",
display_name="Sample SeedVR2 (Progressive)",
category="sampling",
description="Sample a SeedVR2 latent in sequential temporal chunks to allow longer videos to fit into VRAM via frame blending the resulting upscaled latents.",
inputs=[
io.Model.Input("model"),
io.Model.Input("model", tooltip="The model used for denoising the input latent."),
io.Int.Input("seed", default=0, min=0,
max=0xffffffffffffffff,
control_after_generate=True),
io.Int.Input("steps", default=20, min=1, max=10000),
control_after_generate=True,
tooltip="The random seed used for creating the noise."),
io.Int.Input("steps", default=20, min=1, max=10000,
tooltip="The number of steps used in the denoising process."),
io.Float.Input("cfg", default=1.0, min=0.0, max=100.0,
step=0.1, round=0.01),
step=0.1, round=0.01,
tooltip="The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."),
io.Combo.Input("sampler_name",
options=comfy.samplers.SAMPLER_NAMES),
options=comfy.samplers.SAMPLER_NAMES,
tooltip="The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."),
io.Combo.Input("scheduler",
options=comfy.samplers.SCHEDULER_NAMES),
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Latent.Input("latent_image"),
options=comfy.samplers.SCHEDULER_NAMES,
tooltip="The scheduler controls how noise is gradually removed to form the image."),
io.Conditioning.Input("positive",
tooltip="The conditioning describing the attributes you want to include in the image."),
io.Conditioning.Input("negative",
tooltip="The conditioning describing the attributes you want to exclude from the image."),
io.Latent.Input("latent_image",
tooltip="The latent image to denoise."),
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0,
step=0.01),
step=0.01,
tooltip="The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."),
io.Int.Input("frames_per_chunk", default=21, min=1,
max=16384, step=4),
max=16384, step=4,
tooltip="Pixel frames per temporal chunk (4n+1: 1, 5, 9, 13, ...)."),
io.Int.Input("temporal_overlap", default=0, min=0,
max=16384,
tooltip="Latent-frame overlap between "
"adjacent chunks; blended with a "
"Hann window (linear for overlap "
"< 3). 0 = no blend, pure concat. "
"Values >= the chunk's latent-frame "
"length use the maximum valid "
"overlap; 1 latent frame corresponds "
"to ~4 pixel frames."),
tooltip="Latent frames blended between adjacent chunks to hide the seam; 0 = no blend."),
io.Combo.Input("chunking_mode",
options=["manual", "auto"],
default="manual",
tooltip="manual = use frames_per_chunk "
"exactly; auto = retry only real OOM "
"failures with progressively smaller "
"temporal chunks."),
tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."),
],
outputs=[io.Latent.Output()],
)
@ -999,14 +952,14 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
f"(B, 16*T, H, W); got shape {tuple(samples_4d.shape)}."
)
B, CT, H, W = samples_4d.shape
if CT % _SEEDVR2_LATENT_CHANNELS != 0:
if CT % SEEDVR2_LATENT_CHANNELS != 0:
raise ValueError(
f"SeedVR2ProgressiveSampler: collapsed channel dim {CT} is "
f"not divisible by SeedVR2 latent channels "
f"{_SEEDVR2_LATENT_CHANNELS}; latent does not appear to be "
f"{SEEDVR2_LATENT_CHANNELS}; latent does not appear to be "
f"SeedVR2-shaped."
)
T_latent = CT // _SEEDVR2_LATENT_CHANNELS
T_latent = CT // SEEDVR2_LATENT_CHANNELS
T_pixel = 4 * (T_latent - 1) + 1
if chunking_mode not in ("manual", "auto"):
@ -1106,11 +1059,11 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
def _sample_one_chunk(chunk_start, chunk_end):
samples_chunk = _slice_collapsed_4d_along_t(
samples_4d, chunk_start, chunk_end,
_SEEDVR2_LATENT_CHANNELS,
SEEDVR2_LATENT_CHANNELS,
)
noise_chunk = _slice_collapsed_4d_along_t(
noise_full, chunk_start, chunk_end,
_SEEDVR2_LATENT_CHANNELS,
SEEDVR2_LATENT_CHANNELS,
)
positive_chunk = _slice_seedvr2_cond_along_t(
positive, chunk_start, chunk_end,
@ -1140,7 +1093,7 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
chunk_specs.append((chunk_start, chunk_end, chunk_samples))
final = _concat_chunks_with_overlap_blend(
chunk_specs, _SEEDVR2_LATENT_CHANNELS, temporal_overlap,
chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap,
)
out = latent_image.copy()