mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-26 17:59:54 +08:00
Refine SeedVR2 alpha channel handling and node UX
This commit is contained in:
parent
7431bef672
commit
22078c799b
340
comfy/ldm/seedvr/color_fix.py
Normal file
340
comfy/ldm/seedvr/color_fix.py
Normal file
@ -0,0 +1,340 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.ldm.seedvr.model import safe_pad_operation
|
||||
from comfy.ldm.seedvr.vae import safe_interpolate_operation
|
||||
from comfy.ldm.seedvr.constants import (
|
||||
CIELAB_DELTA,
|
||||
CIELAB_KAPPA,
|
||||
D65_WHITE_X,
|
||||
D65_WHITE_Z,
|
||||
WAVELET_DECOMP_LEVELS,
|
||||
)
|
||||
|
||||
|
||||
def wavelet_blur(image: Tensor, radius):
    """Blur ``image`` with a dilated 3x3 binomial kernel applied per channel.

    ``radius`` serves as both the replicate-padding width on every side and
    the dilation of the convolution. It is capped at
    ``max(1, min(H, W) // 8)``, where ``H, W`` are the last two dims.

    Args:
        image: tensor whose dim 1 is used as the channel (group) count and
            whose last two dims are spatial; expected ``(N, C, H, W)``.
        radius: requested dilation / padding width.

    Returns:
        The blurred tensor produced by the grouped convolution.
    """
    height, width = image.shape[-2:]
    # Cap the dilation relative to the smaller spatial side.
    dilation = min(radius, max(1, min(height, width) // 8))

    channels = image.shape[1]

    # Separable binomial taps; their outer product is the 3x3 kernel
    # [[1, 2, 1], [2, 4, 2], [1, 2, 1]] / 16 (all exact powers of two).
    taps = torch.tensor([0.25, 0.5, 0.25], dtype=image.dtype, device=image.device)
    kernel = taps[:, None] * taps[None, :]
    # One identical filter per channel for the grouped (depthwise) conv.
    kernel = kernel.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)

    padded = safe_pad_operation(
        image, (dilation, dilation, dilation, dilation), mode='replicate'
    )
    return F.conv2d(padded, kernel, groups=channels, dilation=dilation)
|
||||
|
||||
def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS):
    """Split ``image`` into accumulated high-frequency detail and a low-frequency residual.

    At level ``i`` the current low-frequency image is blurred with
    ``wavelet_blur(..., radius=2 ** i)``; the difference between the input and
    output of that blur is accumulated into the high-frequency result, and the
    blurred image becomes the input of the next level.

    Args:
        image: tensor to decompose (passed unchanged to ``wavelet_blur``).
        levels: number of blur levels. With ``levels <= 0`` no blur is applied
            and the identity decomposition ``(zeros, image)`` is returned
            (previously this raised ``UnboundLocalError``).

    Returns:
        ``(high_freq, low_freq)``. ``high_freq`` is always a new tensor;
        ``low_freq`` is the input tensor itself when ``levels <= 0``.
    """
    high_freq = torch.zeros_like(image)
    # Defined up front so the return value exists even when the loop never runs.
    low_freq = image

    for level in range(levels):
        blurred = wavelet_blur(low_freq, 2 ** level)
        # Detail removed by this level: (input of the level) - (its blur).
        high_freq.add_(low_freq).sub_(blurred)
        low_freq = blurred

    return high_freq, low_freq
|
||||
|
||||
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
    """Combine the high frequencies of ``content_feat`` with the low frequencies of ``style_feat``.

    ``style_feat`` is bilinearly resized to the spatial size of
    ``content_feat`` when their shapes differ. The sum of the content detail
    and the style residual is clamped to ``[-1, 1]`` and returned as a new
    tensor (the inputs are not written to by this function).
    """

    def _resize_like(tensor, target):
        # Bilinear resize of ``tensor`` to the last two dims of ``target``.
        return safe_interpolate_operation(
            tensor,
            size=target.shape[-2:],
            mode='bilinear',
            align_corners=False
        )

    shapes_differ = content_feat.shape != style_feat.shape
    if shapes_differ and len(content_feat.shape) >= 3:
        style_feat = _resize_like(style_feat, content_feat)

    # Only the detail of the content and the residual of the style are kept;
    # indexing the returned pair drops the unused half right away.
    detail = wavelet_decomposition(content_feat)[0]
    base = wavelet_decomposition(style_feat)[1]

    if detail.shape != base.shape:
        base = _resize_like(base, detail)

    detail.add_(base)
    return detail.clamp_(-1.0, 1.0)
|
||||
|
||||
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
|
||||
original_shape = source.shape
|
||||
|
||||
# Flatten
|
||||
source_flat = source.flatten()
|
||||
reference_flat = reference.flatten()
|
||||
|
||||
# Sort both arrays
|
||||
source_sorted, source_indices = torch.sort(source_flat)
|
||||
reference_sorted, _ = torch.sort(reference_flat)
|
||||
del reference_flat
|
||||
|
||||
# Quantile mapping
|
||||
n_source = len(source_sorted)
|
||||
n_reference = len(reference_sorted)
|
||||
|
||||
if n_source == n_reference:
|
||||
matched_sorted = reference_sorted
|
||||
else:
|
||||
# Interpolate reference to match source quantiles
|
||||
source_quantiles = torch.linspace(0, 1, n_source, device=device)
|
||||
ref_indices = (source_quantiles * (n_reference - 1)).long()
|
||||
ref_indices.clamp_(0, n_reference - 1)
|
||||
matched_sorted = reference_sorted[ref_indices]
|
||||
del source_quantiles, ref_indices, reference_sorted
|
||||
|
||||
del source_sorted, source_flat
|
||||
|
||||
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
|
||||
inverse_indices = torch.argsort(source_indices)
|
||||
del source_indices
|
||||
matched_flat = matched_sorted[inverse_indices]
|
||||
del matched_sorted, inverse_indices
|
||||
|
||||
return matched_flat.reshape(original_shape)
|
||||
|
||||
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
    """Convert a batch of CIELAB images to gamma-encoded RGB clamped to ``[0, 1]``.

    Args:
        lab: ``(B, 3, H, W)`` tensor with L, a, b stacked on dim 1.
        device: not used by this function.
        matrix_inv: 3x3 XYZ -> linear RGB matrix; its dtype is used for the
            matrix product.
        epsilon: threshold on the f-values selecting the cubic branch.
        kappa: divisor of the linear branch.

    Returns:
        ``(B, 3, H, W)`` RGB tensor in ``[0, 1]``.
    """

    def _f_inverse(f):
        # Piecewise inverse of the CIELAB f(t): cubic above epsilon, linear below.
        return torch.where(
            f > epsilon,
            torch.pow(f, 3.0),
            f.mul(116.0).sub_(16.0).div_(kappa)
        )

    lightness, chroma_a, chroma_b = lab[:, 0], lab[:, 1], lab[:, 2]

    # LAB -> f-values
    fy = (lightness + 16.0) / 116.0
    fx = chroma_a.div(500.0).add_(fy)
    fz = fy - chroma_b / 200.0
    del lightness, chroma_a, chroma_b

    # f-values -> XYZ, scaled by the D65 white point (Y scale is 1, so y is untouched)
    x = _f_inverse(fx).mul_(D65_WHITE_X)
    y = _f_inverse(fy)
    z = _f_inverse(fz).mul_(D65_WHITE_Z)
    del fx, fy, fz

    xyz = torch.stack([x, y, z], dim=1)
    del x, y, z

    # XYZ -> linear RGB: move channels last, multiply every pixel by the matrix.
    B, C, H, W = xyz.shape
    pixels = xyz.movedim(1, -1).reshape(-1, 3).to(dtype=matrix_inv.dtype)
    del xyz
    rgb_linear = (pixels @ matrix_inv.T).reshape(B, H, W, 3).movedim(-1, 1)
    del pixels

    # Linear RGB -> gamma-encoded RGB (piecewise sRGB transfer function).
    encoded = rgb_linear.clamp(min=0.0).pow(1.0 / 2.4).mul_(1.055).sub_(0.055)
    rgb = torch.where(rgb_linear > 0.0031308, encoded, rgb_linear * 12.92)
    del encoded, rgb_linear

    return rgb.clamp(0.0, 1.0)
|
||||
|
||||
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
    """Convert a batch of gamma-encoded RGB images to CIELAB using the D65 white point.

    Args:
        rgb: ``(B, 3, H, W)`` tensor; the caller in this file passes values in
            ``[0, 1]``.
        device: not used by this function.
        matrix: 3x3 linear RGB -> XYZ matrix; its dtype is used for the matrix
            product.
        epsilon: f-value threshold; its cube is the threshold applied to XYZ.
        kappa: slope of the linear branch.

    Returns:
        ``(B, 3, H, W)`` tensor with L, a, b stacked on dim 1.
    """
    # Gamma-encoded RGB -> linear RGB (piecewise sRGB transfer function).
    linear = torch.where(
        rgb > 0.04045,
        torch.pow((rgb + 0.055) / 1.055, 2.4),
        rgb / 12.92
    )

    # Linear RGB -> XYZ: move channels last, multiply every pixel by the matrix.
    B, C, H, W = linear.shape
    pixels = linear.movedim(1, -1).reshape(-1, 3).to(dtype=matrix.dtype)
    del linear
    xyz = (pixels @ matrix.T).reshape(B, H, W, 3).movedim(-1, 1)
    del pixels

    # Normalize by the D65 white point in place (the Y scale is 1, so Y is untouched).
    xyz[:, 0].div_(D65_WHITE_X)
    xyz[:, 2].div_(D65_WHITE_Z)

    # CIELAB f(t): cube root above epsilon**3, linear segment below.
    threshold = epsilon ** 3
    f_xyz = torch.where(
        xyz > threshold,
        torch.pow(xyz, 1.0 / 3.0),
        xyz.mul(kappa).add_(16.0).div_(116.0)
    )
    del xyz

    fx, fy, fz = f_xyz[:, 0], f_xyz[:, 1], f_xyz[:, 2]
    lightness = fy.mul(116.0).sub_(16.0)
    green_red = (fx - fy).mul_(500.0)
    blue_yellow = (fy - fz).mul_(200.0)
    del fx, fy, fz, f_xyz

    return torch.stack([lightness, green_red, blue_yellow], dim=1)
|
||||
|
||||
def lab_color_transfer(
    content_feat: Tensor,
    style_feat: Tensor,
    luminance_weight: float = 0.8
) -> Tensor:
    """Transfer the color distribution of ``style_feat`` onto ``content_feat`` in CIELAB space.

    Steps:
      1. ``wavelet_reconstruction`` merges content detail with style low frequencies.
      2. Both images are mapped from ``[-1, 1]`` to ``[0, 1]`` and converted to LAB.
      3. The a and b channels of the content are histogram-matched to the style.
         Matching is done over the whole tensor (all batch items together).
      4. L is a blend of the content L and its histogram-matched version.
      5. The result is converted back to RGB, mapped to ``[-1, 1]`` and cast to
         the dtype returned by step 1.

    Neither input tensor is modified. (The previous in-place normalization
    wrote into the caller's ``style_feat`` when it was already float32 and
    needed no resize.)

    Args:
        content_feat: image batch in ``[-1, 1]`` providing structure.
        style_feat: image batch in ``[-1, 1]`` providing the target colors;
            resized bilinearly when its shape differs from the content.
        luminance_weight: fraction of content L that is kept; ``>= 1.0``
            keeps content L unchanged and skips the L matching.

    Returns:
        Color-corrected tensor in ``[-1, 1]``.
    """
    content_feat = wavelet_reconstruction(content_feat, style_feat)

    if content_feat.shape != style_feat.shape:
        style_feat = safe_interpolate_operation(
            style_feat,
            size=content_feat.shape[-2:],
            mode='bilinear',
            align_corners=False
        )

    device = content_feat.device
    original_dtype = content_feat.dtype

    # content_feat is the fresh tensor returned by wavelet_reconstruction, so
    # normalizing it in place cannot touch the caller's data.
    content_feat = content_feat.float()
    # style_feat may still be the caller's own tensor: always work on a copy.
    style_feat = style_feat.to(dtype=torch.float32, copy=True)

    rgb_to_xyz_matrix = torch.tensor([
        [0.4124564, 0.3575761, 0.1804375],
        [0.2126729, 0.7151522, 0.0721750],
        [0.0193339, 0.1191920, 0.9503041]
    ], dtype=torch.float32, device=device)

    xyz_to_rgb_matrix = torch.tensor([
        [ 3.2404542, -1.5371385, -0.4985314],
        [-0.9692660,  1.8760108,  0.0415560],
        [ 0.0556434, -0.2040259,  1.0572252]
    ], dtype=torch.float32, device=device)

    epsilon = CIELAB_DELTA
    kappa = CIELAB_KAPPA

    # [-1, 1] -> [0, 1]
    content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
    style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)

    # Convert to LAB color space
    content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
    del content_feat

    style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
    del style_feat, rgb_to_xyz_matrix

    # Match chrominance channels (a*, b*)
    matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
    matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)

    if luminance_weight < 1.0:
        # Blend: keep ``luminance_weight`` of the content L*, take the rest
        # from the style-matched L*.
        matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
        result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
        del matched_L
    else:
        # Fully preserve content luminance
        result_L = content_lab[:, 0]

    del content_lab, style_lab

    # Reassemble LAB from the corrected channels
    result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
    del result_L, matched_a, matched_b

    # Convert back to RGB
    result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
    del result_lab, xyz_to_rgb_matrix

    # [0, 1] -> [-1, 1] (in place on the fresh RGB tensor), then restore dtype
    result = result_rgb.mul_(2.0).sub_(1.0)
    del result_rgb

    return result.to(original_dtype)
|
||||
|
||||
|
||||
def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor:
    """Wavelet color transfer; delegates directly to ``wavelet_reconstruction``."""
    corrected = wavelet_reconstruction(content_feat, style_feat)
    return corrected
|
||||
|
||||
|
||||
def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor:
    """Match the per-channel mean and standard deviation of the content to the style (AdaIN).

    Statistics are computed in float32 per (batch item, channel) over all
    remaining dims, using the biased variance plus ``eps``. ``style_feat`` is
    resized bilinearly when its shape differs from ``content_feat``. The result
    is clamped to ``[-1, 1]`` and returned in the dtype of ``content_feat``.
    """
    if content_feat.shape != style_feat.shape:
        style_feat = safe_interpolate_operation(
            style_feat,
            size=content_feat.shape[-2:],
            mode='bilinear',
            align_corners=False,
        )

    original_dtype = content_feat.dtype
    content32 = content_feat.float()
    style32 = style_feat.float()
    batch, channels = content32.shape[:2]

    def _channel_stats(feat):
        # Mean and eps-stabilized std per (batch, channel), shaped to broadcast.
        flat = feat.reshape(batch, channels, -1)
        mean = flat.mean(dim=2).reshape(batch, channels, 1, 1)
        std = (flat.var(dim=2, correction=0) + eps).sqrt().reshape(batch, channels, 1, 1)
        return mean, std

    content_mean, content_std = _channel_stats(content32)
    style_mean, style_std = _channel_stats(style32)

    # Whiten with the content statistics, then re-color with the style statistics.
    whitened = (content32 - content_mean) / content_std
    result = whitened * style_std + style_mean
    del whitened

    result = result.clamp_(-1.0, 1.0)
    # ``.to`` returns the same tensor when the dtype already matches.
    return result.to(original_dtype)
|
||||
82
comfy/ldm/seedvr/constants.py
Normal file
82
comfy/ldm/seedvr/constants.py
Normal file
@ -0,0 +1,82 @@
|
||||
"""Named constants for the SeedVR2 integration, grouped by provenance.
|
||||
|
||||
Provenance prefixes:
|
||||
- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline.
|
||||
- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites
|
||||
the upstream config/source path it was lifted from.
|
||||
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
|
||||
ISO / CIE values; cite the standard.
|
||||
|
||||
The numz/AInVFX custom node is used only as a behavioral-parity benchmark; none of these
values originate from it, and it is not referenced anywhere else in this module.
|
||||
"""
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment)
|
||||
# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN)
|
||||
# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070
|
||||
# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT).
|
||||
# --------------------------------------------------------------------------------------
|
||||
SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB)
|
||||
SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# B. Fork heuristics (SEEDVR2 - this integration)
|
||||
# --------------------------------------------------------------------------------------
|
||||
SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim.
|
||||
# (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.)
|
||||
SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry.
|
||||
SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case).
|
||||
SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM.
|
||||
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk.
|
||||
SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels).
|
||||
SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16).
|
||||
SEEDVR2_DEFAULT_TEMPORAL_SIZE = 16 # default VAE temporal tile when unset.
|
||||
|
||||
# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing)
|
||||
SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk.
|
||||
SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path.
|
||||
SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path.
|
||||
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# C. ByteDance config / source (BYTEDANCE - cite myseedvr2/SeedVR)
|
||||
# --------------------------------------------------------------------------------------
|
||||
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
|
||||
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
|
||||
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem).
|
||||
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem).
|
||||
BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28.
|
||||
BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28.
|
||||
BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16).
|
||||
BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32).
|
||||
BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD = 64 # attn_video_vae.py:308 (force .contiguous() above this b*t).
|
||||
BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11.
|
||||
BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size).
|
||||
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor.
|
||||
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor).
|
||||
BYTEDANCE_SCHEDULE_T = 1000.0 # configs_3b/main.yaml:65 (schedule.T); timestep range.
|
||||
BYTEDANCE_SPATIAL_DIVISOR = 16 # inference_seedvr2_3b.py:241 (DivisibleCrop((16,16))).
|
||||
BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling).
|
||||
BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames).
|
||||
BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency).
|
||||
BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim).
|
||||
# Resolution-dependent timestep-shift linear fits: (x1, y1, x2, y2) for get_lin_function.
|
||||
BYTEDANCE_IMG_SHIFT_FIT = (256 * 256, 1.0, 1024 * 1024, 3.2) # infer.py:242.
|
||||
BYTEDANCE_VID_SHIFT_FIT = (256 * 256 * 37, 1.0, 1280 * 720 * 145, 5.0) # infer.py:243.
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# D. Published standards (cite the literature)
|
||||
# --------------------------------------------------------------------------------------
|
||||
ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864.
|
||||
|
||||
# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65).
|
||||
CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta).
|
||||
CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa).
|
||||
D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1).
|
||||
D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn.
|
||||
WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR).
|
||||
|
||||
# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and
|
||||
# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the
|
||||
# exact existing coefficients move verbatim rather than being retyped here.
|
||||
@ -12,6 +12,16 @@ from torch.nn.modules.utils import _triple
|
||||
from torch import nn
|
||||
import math
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.seedvr.constants import (
|
||||
BYTEDANCE_720P_REF_AREA,
|
||||
BYTEDANCE_MAX_TEMPORAL_WINDOW,
|
||||
BYTEDANCE_ROPE_MAX_FREQ,
|
||||
BYTEDANCE_SINUSOIDAL_DIM,
|
||||
ROPE_THETA,
|
||||
SEEDVR2_7B_MLP_CHUNK,
|
||||
SEEDVR2_7B_VID_DIM,
|
||||
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS,
|
||||
)
|
||||
import comfy.model_management
|
||||
import numbers
|
||||
|
||||
@ -203,10 +213,10 @@ def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int,
|
||||
t, h, w = size
|
||||
resized_nt, resized_nh, resized_nw = num_windows
|
||||
#cal windows under 720p
|
||||
scale = math.sqrt((45 * 80) / (h * w))
|
||||
scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w))
|
||||
resized_h, resized_w = round(h * scale), round(w * scale)
|
||||
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
|
||||
wt = ceil(min(t, 30) / resized_nt) # window size.
|
||||
wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size.
|
||||
nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size.
|
||||
return [
|
||||
(
|
||||
@ -226,10 +236,10 @@ def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tup
|
||||
t, h, w = size
|
||||
resized_nt, resized_nh, resized_nw = num_windows
|
||||
#cal windows under 720p
|
||||
scale = math.sqrt((45 * 80) / (h * w))
|
||||
scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w))
|
||||
resized_h, resized_w = round(h * scale), round(w * scale)
|
||||
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
|
||||
wt = ceil(min(t, 30) / resized_nt) # window size.
|
||||
wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size.
|
||||
|
||||
st, sh, sw = ( # shift size.
|
||||
0.5 if wt < t else 0,
|
||||
@ -412,7 +422,7 @@ class RotaryEmbeddingBase(nn.Module):
|
||||
self.rope = RotaryEmbedding(
|
||||
dim=dim // rope_dim,
|
||||
freqs_for="pixel",
|
||||
max_freq=256,
|
||||
max_freq=BYTEDANCE_ROPE_MAX_FREQ,
|
||||
)
|
||||
freqs = self.rope.freqs
|
||||
del self.rope.freqs
|
||||
@ -486,7 +496,7 @@ class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
|
||||
self.rope = RotaryEmbedding(
|
||||
dim=dim // rope_dim,
|
||||
freqs_for="lang",
|
||||
theta=10000,
|
||||
theta=ROPE_THETA,
|
||||
cache_if_possible=False,
|
||||
)
|
||||
freqs = self.rope.freqs
|
||||
@ -547,14 +557,7 @@ def apply_rotary_emb(
|
||||
return out.type(dtype)
|
||||
|
||||
def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert lucidrains-interleaved freqs `[..., d]` (`[θ0, θ0, θ1, θ1, ...]`
|
||||
from `RotaryEmbedding.forward`'s `repeat(freqs, '... n -> ... (n r)', r=2)`)
|
||||
into flux-canonical `freqs_cis` of shape `[..., d/2, 2, 2]` with the
|
||||
`cos/-sin/sin/cos` rotation matrix baked in. Output dtype is fp32 to
|
||||
match `comfy/ldm/flux/math.py:rope` precision; `apply_rope1` consumes
|
||||
the matrix layout via `freqs_cis[..., 0]` (column 0) and
|
||||
`freqs_cis[..., 1]` (column 1) of the 2x2 rotation matrix.
|
||||
"""
|
||||
"""Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`."""
|
||||
angles = freqs_interleaved[..., ::2].float()
|
||||
cos = torch.cos(angles)
|
||||
sin = torch.sin(angles)
|
||||
@ -562,27 +565,18 @@ def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
|
||||
return rearrange(out, "... d (i j) -> ... d i j", i=2, j=2)
|
||||
|
||||
|
||||
_ROPE1_PARTIAL_CHUNK_TOKENS = 4096
|
||||
SEEDVR2_7B_MLP_CHUNK = 8192
|
||||
|
||||
|
||||
def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
"""Apply ``apply_rope1`` to the leading ``rot_d = 2 * freqs_cis.shape[-3]``
|
||||
components of ``t``'s last dim, passing through the remaining dims
|
||||
untouched in-place for inference tensors. Training tensors are cloned
|
||||
before slice assignment to preserve autograd correctness. Mirrors the partial-rope contract of the legacy
|
||||
``apply_rotary_emb`` wrapper at line 470 (``t_left``/``t_middle``/``t_right``
|
||||
split). For SeedVR2-3B this matters because ``rope_dim=128`` integer-
|
||||
divides into 3 axes as ``128 // 3 = 42`` per-axis, total ``42 * 3 = 126``;
|
||||
head_dim is 128, so the trailing 2 dims are unrotated. The fast path
|
||||
triggers when ``rot_d == t.shape[-1]`` (e.g. test rigs where dim is
|
||||
chosen divisible by 6) and avoids the cat entirely.
|
||||
"""Rotate the leading ``rot_d = 2 * freqs_cis.shape[-3]`` dims of ``t`` and pass the rest
|
||||
through; in-place for inference, cloned for training (autograd). Mirrors the legacy
|
||||
``apply_rotary_emb`` ``t_left``/``t_middle``/``t_right`` split: 3B ``rope_dim=128`` gives
|
||||
``42*3 = 126`` rotated of head_dim 128 (trailing 2 unrotated). Fast path skips the cat when
|
||||
``rot_d == t.shape[-1]``.
|
||||
"""
|
||||
out = t.clone() if t.requires_grad or comfy.model_management.in_training else t
|
||||
rot_d = 2 * freqs_cis.shape[-3]
|
||||
seq_len = out.shape[-2]
|
||||
for start in range(0, seq_len, _ROPE1_PARTIAL_CHUNK_TOKENS):
|
||||
end = min(start + _ROPE1_PARTIAL_CHUNK_TOKENS, seq_len)
|
||||
for start in range(0, seq_len, SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS):
|
||||
end = min(start + SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, seq_len)
|
||||
freqs_chunk = freqs_cis[start:end]
|
||||
if rot_d == out.shape[-1]:
|
||||
out[..., start:end, :] = apply_rope1(out[..., start:end, :], freqs_chunk).to(out.dtype)
|
||||
@ -1385,7 +1379,7 @@ class NaDiT(nn.Module):
|
||||
operations = None,
|
||||
**kwargs,
|
||||
):
|
||||
self._7b_version = vid_dim == 3072
|
||||
self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM
|
||||
self.dtype = dtype
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"]
|
||||
@ -1427,7 +1421,7 @@ class NaDiT(nn.Module):
|
||||
else nn.Identity()
|
||||
)
|
||||
self.emb_in = TimeEmbedding(
|
||||
sinusoidal_dim=256,
|
||||
sinusoidal_dim=BYTEDANCE_SINUSOIDAL_DIM,
|
||||
hidden_dim=max(vid_dim, txt_dim),
|
||||
output_dim=emb_dim,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
|
||||
@ -10,6 +10,27 @@ from contextlib import contextmanager
|
||||
from comfy.utils import ProgressBar
|
||||
|
||||
from comfy.ldm.seedvr.model import safe_pad_operation
|
||||
from comfy.ldm.seedvr.constants import (
|
||||
BYTEDANCE_BLOCK_OUT_CHANNELS,
|
||||
BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD,
|
||||
BYTEDANCE_GN_CHUNKS_FP16,
|
||||
BYTEDANCE_GN_CHUNKS_FP32,
|
||||
BYTEDANCE_LOGVAR_CLAMP_MAX,
|
||||
BYTEDANCE_LOGVAR_CLAMP_MIN,
|
||||
BYTEDANCE_SLICING_SAMPLE_MIN,
|
||||
BYTEDANCE_VAE_CONV_MEM_GIB,
|
||||
BYTEDANCE_VAE_NORM_MEM_GIB,
|
||||
BYTEDANCE_VAE_SCALING_FACTOR,
|
||||
BYTEDANCE_VAE_SHIFTING_FACTOR,
|
||||
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE,
|
||||
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE,
|
||||
CIELAB_DELTA,
|
||||
CIELAB_KAPPA,
|
||||
D65_WHITE_X,
|
||||
D65_WHITE_Z,
|
||||
SEEDVR2_LATENT_CHANNELS,
|
||||
WAVELET_DECOMP_LEVELS,
|
||||
)
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.modules.diffusionmodules.model import vae_attention
|
||||
|
||||
@ -70,8 +91,8 @@ def tiled_vae(
|
||||
|
||||
_, _, d, h, w = x.shape
|
||||
|
||||
sf_s = getattr(vae_model, "spatial_downsample_factor", 8)
|
||||
sf_t = getattr(vae_model, "temporal_downsample_factor", 4)
|
||||
sf_s = getattr(vae_model, "spatial_downsample_factor", BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE)
|
||||
sf_t = getattr(vae_model, "temporal_downsample_factor", BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE)
|
||||
if encode:
|
||||
slicing_attr = "slicing_sample_min_size"
|
||||
slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap)
|
||||
@ -278,7 +299,7 @@ class DiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.logvar = torch.clamp(self.logvar, BYTEDANCE_LOGVAR_CLAMP_MIN, BYTEDANCE_LOGVAR_CLAMP_MAX)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
@ -569,7 +590,7 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
memory_occupy = x.numel() * x.element_size() / 1024**3
|
||||
if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit():
|
||||
num_chunks = min(4 if x.element_size() == 2 else 2, norm_layer.num_groups)
|
||||
num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups)
|
||||
assert norm_layer.num_groups % num_chunks == 0
|
||||
num_groups_per_chunk = norm_layer.num_groups // num_chunks
|
||||
|
||||
@ -1189,7 +1210,7 @@ class ResnetBlock3D(nn.Module):
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
if self.upsample is not None:
|
||||
if hidden_states.shape[0] >= 64:
|
||||
if hidden_states.shape[0] >= BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD:
|
||||
input_tensor = input_tensor.contiguous()
|
||||
hidden_states = hidden_states.contiguous()
|
||||
input_tensor = self.upsample(input_tensor, memory_state=memory_state)
|
||||
@ -1780,333 +1801,6 @@ class Decoder3D(nn.Module):
|
||||
|
||||
return sample
|
||||
|
||||
def wavelet_blur(image: Tensor, radius):
    """Blur ``image`` with a dilated 3x3 binomial kernel applied per channel.

    ``radius`` serves as both the replicate-padding width on every side and
    the dilation of the convolution. It is capped at
    ``max(1, min(H, W) // 8)``, where ``H, W`` are the last two dims.

    Args:
        image: tensor whose dim 1 is used as the channel (group) count and
            whose last two dims are spatial; expected ``(N, C, H, W)``.
        radius: requested dilation / padding width.

    Returns:
        The blurred tensor produced by the grouped convolution.
    """
    height, width = image.shape[-2:]
    # Cap the dilation relative to the smaller spatial side.
    dilation = min(radius, max(1, min(height, width) // 8))

    channels = image.shape[1]

    # Separable binomial taps; their outer product is the 3x3 kernel
    # [[1, 2, 1], [2, 4, 2], [1, 2, 1]] / 16 (all exact powers of two).
    taps = torch.tensor([0.25, 0.5, 0.25], dtype=image.dtype, device=image.device)
    kernel = taps[:, None] * taps[None, :]
    # One identical filter per channel for the grouped (depthwise) conv.
    kernel = kernel.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)

    padded = safe_pad_operation(
        image, (dilation, dilation, dilation, dilation), mode='replicate'
    )
    return F.conv2d(padded, kernel, groups=channels, dilation=dilation)
|
||||
|
||||
def wavelet_decomposition(image: Tensor, levels: int = 5):
    """Split ``image`` into accumulated high-frequency detail and a low-frequency residual.

    At level ``i`` the current low-frequency image is blurred with
    ``wavelet_blur(..., radius=2 ** i)``; the difference between the input and
    output of that blur is accumulated into the high-frequency result, and the
    blurred image becomes the input of the next level.

    Args:
        image: tensor to decompose (passed unchanged to ``wavelet_blur``).
        levels: number of blur levels. With ``levels <= 0`` no blur is applied
            and the identity decomposition ``(zeros, image)`` is returned
            (previously this raised ``UnboundLocalError``).

    Returns:
        ``(high_freq, low_freq)``. ``high_freq`` is always a new tensor;
        ``low_freq`` is the input tensor itself when ``levels <= 0``.
    """
    high_freq = torch.zeros_like(image)
    # Defined up front so the return value exists even when the loop never runs.
    low_freq = image

    for level in range(levels):
        blurred = wavelet_blur(low_freq, 2 ** level)
        # Detail removed by this level: (input of the level) - (its blur).
        high_freq.add_(low_freq).sub_(blurred)
        low_freq = blurred

    return high_freq, low_freq
|
||||
|
||||
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
    """Combine the high frequencies of ``content_feat`` with the low frequencies of ``style_feat``.

    ``style_feat`` is bilinearly resized to the spatial size of
    ``content_feat`` when their shapes differ. The sum of the content detail
    and the style residual is clamped to ``[-1, 1]`` and returned as a new
    tensor (the inputs are not written to by this function).
    """

    def _resize_like(tensor, target):
        # Bilinear resize of ``tensor`` to the last two dims of ``target``.
        return safe_interpolate_operation(
            tensor,
            size=target.shape[-2:],
            mode='bilinear',
            align_corners=False
        )

    shapes_differ = content_feat.shape != style_feat.shape
    if shapes_differ and len(content_feat.shape) >= 3:
        style_feat = _resize_like(style_feat, content_feat)

    # Only the detail of the content and the residual of the style are kept;
    # indexing the returned pair drops the unused half right away.
    detail = wavelet_decomposition(content_feat)[0]
    base = wavelet_decomposition(style_feat)[1]

    if detail.shape != base.shape:
        base = _resize_like(base, detail)

    detail.add_(base)
    return detail.clamp_(-1.0, 1.0)
|
||||
|
||||
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
|
||||
original_shape = source.shape
|
||||
|
||||
# Flatten
|
||||
source_flat = source.flatten()
|
||||
reference_flat = reference.flatten()
|
||||
|
||||
# Sort both arrays
|
||||
source_sorted, source_indices = torch.sort(source_flat)
|
||||
reference_sorted, _ = torch.sort(reference_flat)
|
||||
del reference_flat
|
||||
|
||||
# Quantile mapping
|
||||
n_source = len(source_sorted)
|
||||
n_reference = len(reference_sorted)
|
||||
|
||||
if n_source == n_reference:
|
||||
matched_sorted = reference_sorted
|
||||
else:
|
||||
# Interpolate reference to match source quantiles
|
||||
source_quantiles = torch.linspace(0, 1, n_source, device=device)
|
||||
ref_indices = (source_quantiles * (n_reference - 1)).long()
|
||||
ref_indices.clamp_(0, n_reference - 1)
|
||||
matched_sorted = reference_sorted[ref_indices]
|
||||
del source_quantiles, ref_indices, reference_sorted
|
||||
|
||||
del source_sorted, source_flat
|
||||
|
||||
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
|
||||
inverse_indices = torch.argsort(source_indices)
|
||||
del source_indices
|
||||
matched_flat = matched_sorted[inverse_indices]
|
||||
del matched_sorted, inverse_indices
|
||||
|
||||
return matched_flat.reshape(original_shape)
|
||||
|
||||
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
"""Convert batch of CIELAB images to RGB color space."""
|
||||
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
|
||||
|
||||
# LAB to XYZ
|
||||
fy = (L + 16.0) / 116.0
|
||||
fx = a.div(500.0).add_(fy)
|
||||
fz = fy - b / 200.0
|
||||
del L, a, b
|
||||
|
||||
# XYZ transformation
|
||||
x = torch.where(
|
||||
fx > epsilon,
|
||||
torch.pow(fx, 3.0),
|
||||
fx.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
y = torch.where(
|
||||
fy > epsilon,
|
||||
torch.pow(fy, 3.0),
|
||||
fy.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
z = torch.where(
|
||||
fz > epsilon,
|
||||
torch.pow(fz, 3.0),
|
||||
fz.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
del fx, fy, fz
|
||||
|
||||
# Apply D65 white point (in-place)
|
||||
x.mul_(0.95047)
|
||||
# y *= 1.00000 # (no-op, skip)
|
||||
z.mul_(1.08883)
|
||||
|
||||
xyz = torch.stack([x, y, z], dim=1)
|
||||
del x, y, z
|
||||
|
||||
# Matrix multiplication: XYZ -> RGB
|
||||
B, C, H, W = xyz.shape
|
||||
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||
del xyz
|
||||
|
||||
# Ensure dtype consistency for matrix multiplication
|
||||
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
|
||||
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
|
||||
del xyz_flat
|
||||
|
||||
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||
del rgb_linear_flat
|
||||
|
||||
# Apply inverse gamma correction (delinearize)
|
||||
mask = rgb_linear > 0.0031308
|
||||
rgb = torch.where(
|
||||
mask,
|
||||
torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055),
|
||||
rgb_linear * 12.92
|
||||
)
|
||||
del mask, rgb_linear
|
||||
|
||||
return torch.clamp(rgb, 0.0, 1.0)
|
||||
|
||||
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
|
||||
# Apply sRGB gamma correction (linearize)
|
||||
mask = rgb > 0.04045
|
||||
rgb_linear = torch.where(
|
||||
mask,
|
||||
torch.pow((rgb + 0.055) / 1.055, 2.4),
|
||||
rgb / 12.92
|
||||
)
|
||||
del mask
|
||||
|
||||
# Matrix multiplication: RGB -> XYZ
|
||||
B, C, H, W = rgb_linear.shape
|
||||
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||
del rgb_linear
|
||||
|
||||
# Ensure dtype consistency for matrix multiplication
|
||||
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
|
||||
xyz_flat = torch.matmul(rgb_flat, matrix.T)
|
||||
del rgb_flat
|
||||
|
||||
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||
del xyz_flat
|
||||
|
||||
# Normalize by D65 white point (in-place)
|
||||
xyz[:, 0].div_(0.95047) # X
|
||||
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
|
||||
xyz[:, 2].div_(1.08883) # Z
|
||||
|
||||
# XYZ to LAB transformation
|
||||
epsilon_cubed = epsilon ** 3
|
||||
mask = xyz > epsilon_cubed
|
||||
f_xyz = torch.where(
|
||||
mask,
|
||||
torch.pow(xyz, 1.0 / 3.0),
|
||||
xyz.mul(kappa).add_(16.0).div_(116.0)
|
||||
)
|
||||
del xyz, mask
|
||||
|
||||
# Extract channels and compute LAB
|
||||
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
|
||||
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
|
||||
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
|
||||
del f_xyz
|
||||
|
||||
return torch.stack([L, a, b], dim=1)
|
||||
|
||||
def lab_color_transfer(
    content_feat: Tensor,
    style_feat: Tensor,
    luminance_weight: float = 0.8
) -> Tensor:
    """Transfer the colors of ``style_feat`` onto ``content_feat`` in CIELAB space.

    The content is first passed through ``wavelet_reconstruction``; both images
    are then converted to LAB, the a*/b* channels of the content are
    histogram-matched to the style, and L* is blended between the content's own
    lightness and its histogram-matched lightness.

    Neither input tensor is modified: ``style_feat`` is range-shifted
    out-of-place, because ``Tensor.float()`` hands back the caller's own tensor
    when it is already float32.

    Args:
        content_feat: 4-D image tensor ``(B, 3, H, W)`` in the [-1, 1] range.
        style_feat: Color reference in the same range; resized bilinearly to the
            content's spatial size when the shapes differ.
        luminance_weight: Fraction of the content's L* kept in the result.
            Values ``>= 1.0`` keep the content lightness untouched.

    Returns:
        Color-corrected tensor in [-1, 1], cast back to the dtype the content
        had after wavelet reconstruction.
    """
    content_feat = wavelet_reconstruction(content_feat, style_feat)

    if content_feat.shape != style_feat.shape:
        style_feat = safe_interpolate_operation(
            style_feat,
            size=content_feat.shape[-2:],
            mode='bilinear',
            align_corners=False
        )

    device = content_feat.device

    # Color math runs in float32; remember the dtype to restore at the end.
    original_dtype = content_feat.dtype
    content_feat = content_feat.float()
    style_feat = style_feat.float()

    rgb_to_xyz_matrix = torch.tensor([
        [0.4124564, 0.3575761, 0.1804375],
        [0.2126729, 0.7151522, 0.0721750],
        [0.0193339, 0.1191920, 0.9503041]
    ], dtype=torch.float32, device=device)

    xyz_to_rgb_matrix = torch.tensor([
        [ 3.2404542, -1.5371385, -0.4985314],
        [-0.9692660,  1.8760108,  0.0415560],
        [ 0.0556434, -0.2040259,  1.0572252]
    ], dtype=torch.float32, device=device)

    epsilon = 6.0 / 29.0
    kappa = (29.0 / 3.0) ** 3

    # [-1, 1] -> [0, 1]. content_feat is the tensor produced by
    # wavelet_reconstruction above, so it can be updated in place.
    content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
    # style_feat may still be the caller's tensor (float32 input, no resize),
    # so the first op allocates a new tensor instead of writing into it.
    style_feat = style_feat.add(1.0).mul_(0.5).clamp_(0.0, 1.0)

    # Convert to LAB color space
    content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
    del content_feat

    style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
    del style_feat, rgb_to_xyz_matrix

    # Match chrominance channels (a*, b*) for accurate color transfer
    matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
    matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)

    # Handle luminance with weighted blending
    if luminance_weight < 1.0:
        # Partially match luminance for better overall color accuracy
        matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
        # Blend: preserve some content L* for detail, adopt some style L* for color
        result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
        del matched_L
    else:
        # Fully preserve content luminance
        result_L = content_lab[:, 0]

    del content_lab, style_lab

    # Reconstruct LAB with corrected channels
    result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
    del result_L, matched_a, matched_b

    # Convert back to RGB
    result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
    del result_lab, xyz_to_rgb_matrix

    # Convert back to [-1, 1] range (in-place on the freshly created RGB tensor)
    result = result_rgb.mul_(2.0).sub_(1.0)
    del result_rgb

    return result.to(original_dtype)
|
||||
|
||||
|
||||
def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor:
    """Color-correct ``content_feat`` against ``style_feat`` via ``wavelet_reconstruction``.

    Thin public alias: both arguments are forwarded unchanged and the
    reconstruction result is returned as-is.
    """
    corrected = wavelet_reconstruction(content_feat, style_feat)
    return corrected
|
||||
|
||||
|
||||
def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor:
    """Match the per-channel mean and standard deviation of the content to the style.

    Each ``(batch, channel)`` plane of the content is normalized with its own
    statistics and re-scaled with the style's statistics, then clamped to
    [-1, 1]. Computation runs in float32 and the result is cast back to the
    content's dtype.

    Args:
        content_feat: Image tensor whose first two dims are batch and channel.
        style_feat: Color reference; resized bilinearly to the content's spatial
            size when the shapes differ.
        eps: Added to the (population) variance before the square root.

    Returns:
        Re-colored tensor with the shape and dtype of ``content_feat``.
    """
    if content_feat.shape != style_feat.shape:
        style_feat = safe_interpolate_operation(
            style_feat,
            size=content_feat.shape[-2:],
            mode='bilinear',
            align_corners=False,
        )

    out_dtype = content_feat.dtype
    content_f32 = content_feat.float()
    style_f32 = style_feat.float()

    batch, chans = content_f32.shape[:2]

    def _channel_stats(feat: Tensor):
        # Mean and eps-stabilized population std per (batch, channel) plane,
        # shaped (B, C, 1, 1) for broadcasting over the spatial dims.
        flat = feat.reshape(batch, chans, -1)
        mean = flat.mean(dim=2).reshape(batch, chans, 1, 1)
        std = (flat.var(dim=2, correction=0) + eps).sqrt().reshape(batch, chans, 1, 1)
        return mean, std

    c_mean, c_std = _channel_stats(content_f32)
    s_mean, s_std = _channel_stats(style_f32)

    whitened = (content_f32 - c_mean) / c_std
    del c_mean, c_std
    recolored = whitened * s_std + s_mean
    del whitened, s_mean, s_std

    recolored = recolored.clamp_(-1.0, 1.0)
    if recolored.dtype == out_dtype:
        return recolored
    return recolored.to(out_dtype)
|
||||
|
||||
|
||||
class VideoAutoencoderKL(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -2114,7 +1808,7 @@ class VideoAutoencoderKL(nn.Module):
|
||||
out_channels: int = 3,
|
||||
layers_per_block: int = 2,
|
||||
act_fn: str = "silu",
|
||||
latent_channels: int = 16,
|
||||
latent_channels: int = SEEDVR2_LATENT_CHANNELS,
|
||||
norm_num_groups: int = 32,
|
||||
attention: bool = True,
|
||||
temporal_scale_num: int = 2,
|
||||
@ -2124,14 +1818,14 @@ class VideoAutoencoderKL(nn.Module):
|
||||
time_receptive_field: _receptive_field_t = "full",
|
||||
use_quant_conv: bool = False,
|
||||
use_post_quant_conv: bool = False,
|
||||
slicing_sample_min_size = 4,
|
||||
slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.slicing_sample_min_size = slicing_sample_min_size
|
||||
self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)
|
||||
extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None
|
||||
block_out_channels = (128, 256, 512, 512)
|
||||
block_out_channels = BYTEDANCE_BLOCK_OUT_CHANNELS
|
||||
down_block_types = ("DownEncoderBlock3D",) * 4
|
||||
up_block_types = ("UpDecoderBlock3D",) * 4
|
||||
super().__init__()
|
||||
@ -2329,7 +2023,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
self.freeze_encoder = freeze_encoder
|
||||
self.enable_tiling = False
|
||||
super().__init__(*args, **kwargs)
|
||||
self.set_memory_limit(0.5, 0.5)
|
||||
self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB)
|
||||
|
||||
def forward(self, x: torch.FloatTensor):
|
||||
with torch.no_grad() if self.freeze_encoder else nullcontext():
|
||||
@ -2377,8 +2071,8 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
"4-D collapsed (B, 16*T, H, W) or 5-D (B, 16, T, H, W); "
|
||||
f"got shape {tuple(z.shape)}."
|
||||
)
|
||||
scale = 0.9152
|
||||
shift = 0
|
||||
scale = BYTEDANCE_VAE_SCALING_FACTOR
|
||||
shift = BYTEDANCE_VAE_SHIFTING_FACTOR
|
||||
latent = latent / scale + shift
|
||||
|
||||
self.device = latent.device
|
||||
|
||||
@ -9,11 +9,24 @@ import gc
|
||||
import comfy.model_management
|
||||
import comfy.sample
|
||||
import comfy.samplers
|
||||
from comfy.ldm.seedvr.vae import (
|
||||
from comfy.ldm.seedvr.color_fix import (
|
||||
adain_color_transfer,
|
||||
lab_color_transfer,
|
||||
wavelet_color_transfer,
|
||||
)
|
||||
from comfy.ldm.seedvr.constants import (
|
||||
BYTEDANCE_IMG_SHIFT_FIT,
|
||||
BYTEDANCE_SCHEDULE_T,
|
||||
BYTEDANCE_VID_SHIFT_FIT,
|
||||
SEEDVR2_ADAIN_SCALE_MULTIPLIER,
|
||||
SEEDVR2_COLOR_MEM_HEADROOM,
|
||||
SEEDVR2_COND_CHANNELS,
|
||||
SEEDVR2_DTYPE_BYTES_FLOOR,
|
||||
SEEDVR2_LAB_SCALE_MULTIPLIER,
|
||||
SEEDVR2_LATENT_CHANNELS,
|
||||
SEEDVR2_OOM_BACKOFF_DIVISOR,
|
||||
SEEDVR2_WAVELET_SCALE_MULTIPLIER,
|
||||
)
|
||||
|
||||
from torchvision.transforms import functional as TVF
|
||||
from torchvision.transforms import Lambda
|
||||
@ -23,10 +36,6 @@ from torchvision.transforms.functional import InterpolationMode
|
||||
_SEEDVR2_INVALID_MODEL_MSG_PREFIX = (
|
||||
"SeedVR2Conditioning: model object does not match expected SeedVR2 structure"
|
||||
)
|
||||
LAB_SCALE_MULTIPLIER = 13
|
||||
WAVELET_SCALE_MULTIPLIER = 10
|
||||
ADAIN_SCALE_MULTIPLIER = 6
|
||||
COLOR_CORRECTION_MEMORY_HEADROOM = 0.75
|
||||
|
||||
# Private sentinel for getattr default: distinguishes "attribute missing"
|
||||
# from "attribute present but None" so the failure message is accurate.
|
||||
@ -57,17 +66,7 @@ def _seedvr2_auto_chunk_attempts(t_latent, t_pixel, frames_per_chunk):
|
||||
|
||||
|
||||
def _resolve_seedvr2_diffusion_model(model):
|
||||
"""Resolve the inner SeedVR2 diffusion-model module from a ComfyUI model
|
||||
patcher object. Fails loud with a ``RuntimeError`` whose message begins
|
||||
with ``_SEEDVR2_INVALID_MODEL_MSG_PREFIX`` when the expected wrapper
|
||||
shape (``model.model.diffusion_model``) is absent.
|
||||
|
||||
Distinguishes four failure modes via the ``_ATTR_MISSING`` sentinel:
|
||||
``model.model`` missing, ``model.model is None``,
|
||||
``model.model.diffusion_model`` missing, ``model.model.diffusion_model
|
||||
is None``. Each mode produces an accurate error message rather than
|
||||
conflating "attribute missing" with "attribute is None".
|
||||
"""
|
||||
"""Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message."""
|
||||
inner = getattr(model, "model", _ATTR_MISSING)
|
||||
if inner is _ATTR_MISSING:
|
||||
raise RuntimeError(
|
||||
@ -94,15 +93,7 @@ def _resolve_seedvr2_diffusion_model(model):
|
||||
|
||||
|
||||
def _apply_rope_freqs_float32_cast(diffusion_model):
|
||||
"""Cast every nested module's ``rope.freqs`` parameter data to ``float32``
|
||||
when it is not already in float32. Idempotency is per-tensor by dtype
|
||||
check, NOT a per-instance sentinel attribute — a sentinel would survive
|
||||
Comfy's dynamic model unload/reload cycle while ``rope.freqs`` itself
|
||||
is restored from the archived dtype, leaving RoPE running in fp16/bf16
|
||||
on subsequent calls. The dtype check makes the cast self-correcting
|
||||
against weight-restore lifecycle events. Iteration cost is one walk of
|
||||
the diffusion-model module tree per ``execute()`` call (microseconds).
|
||||
"""
|
||||
"""Cast every module's ``rope.freqs`` to float32; the per-tensor dtype check (not a sentinel attr) self-corrects across Comfy's unload/reload, which would otherwise restore the archived fp16/bf16 dtype."""
|
||||
for module in diffusion_model.modules():
|
||||
if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'):
|
||||
if module.rope.freqs.data.dtype != torch.float32:
|
||||
@ -140,8 +131,8 @@ def timestep_transform(timesteps, latents_shapes):
|
||||
b = y1 - m * x1
|
||||
return lambda x: m * x + b
|
||||
|
||||
img_shift_fn = get_lin_function(x1=256 * 256, y1=1.0, x2=1024 * 1024, y2=3.2)
|
||||
vid_shift_fn = get_lin_function(x1=256 * 256 * 37, y1=1.0, x2=1280 * 720 * 145, y2=5.0)
|
||||
img_shift_fn = get_lin_function(*BYTEDANCE_IMG_SHIFT_FIT)
|
||||
vid_shift_fn = get_lin_function(*BYTEDANCE_VID_SHIFT_FIT)
|
||||
shift = torch.where(
|
||||
frames > 1,
|
||||
vid_shift_fn(heights * widths * frames),
|
||||
@ -149,7 +140,7 @@ def timestep_transform(timesteps, latents_shapes):
|
||||
).to(timesteps.device)
|
||||
|
||||
# Shift timesteps.
|
||||
T = 1000.0
|
||||
T = BYTEDANCE_SCHEDULE_T
|
||||
timesteps = timesteps / T
|
||||
timesteps = shift * timesteps / (1 + (shift - 1) * timesteps)
|
||||
timesteps = timesteps * T
|
||||
@ -157,7 +148,7 @@ def timestep_transform(timesteps, latents_shapes):
|
||||
|
||||
def inter(x_0, x_T, t):
|
||||
t = expand_dims(t, x_0.ndim)
|
||||
T = 1000.0
|
||||
T = BYTEDANCE_SCHEDULE_T
|
||||
B = lambda t: t / T
|
||||
A = lambda t: 1 - (t / T)
|
||||
return A(t) * x_0 + B(t) * x_T
|
||||
@ -235,6 +226,8 @@ def _seedvr2_resize_and_pad(images, upscaled_shorter_edge, node_name):
|
||||
f"got {upscaled_shorter_edge}."
|
||||
)
|
||||
original_image = images
|
||||
if images.shape[-1] > 3:
|
||||
images = images[..., :3]
|
||||
if images.dim() == 4:
|
||||
# Comfy video components arrive as a 4-D IMAGE frame sequence:
|
||||
# (frames, H, W, C). SeedVR2 consumes that as one video.
|
||||
@ -268,10 +261,12 @@ class SeedVR2Resize(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SeedVR2Resize",
|
||||
category="image/video",
|
||||
display_name="Resize Image for SeedVR2",
|
||||
category="image/upscaling",
|
||||
description="Resize an image to a SeedVR2-compatible size by a multiplier.",
|
||||
inputs=[
|
||||
io.Image.Input("images"),
|
||||
io.Float.Input("multiplier", default=4.0, min=0.01),
|
||||
io.Image.Input("images", tooltip="The image(s) to resize."),
|
||||
io.Float.Input("multiplier", default=4.0, min=0.01, tooltip="Upscale factor applied to the shorter edge."),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("input_pixels"),
|
||||
@ -304,10 +299,12 @@ class SeedVR2ResizeAdvanced(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SeedVR2ResizeAdvanced",
|
||||
category="image/video",
|
||||
display_name="Resize Image for SeedVR2 (Advanced)",
|
||||
category="image/upscaling",
|
||||
description="Resize an image to an exact shorter-edge size for SeedVR2.",
|
||||
inputs=[
|
||||
io.Image.Input("images"),
|
||||
io.Int.Input("shorter_edge", default=1280, min=2),
|
||||
io.Image.Input("images", tooltip="The image(s) to resize."),
|
||||
io.Int.Input("shorter_edge", default=1280, min=2, tooltip="Target length of the shorter edge, in pixels."),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("input_pixels"),
|
||||
@ -323,17 +320,30 @@ class SeedVR2ResizeAdvanced(io.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
def _edge_guided_alpha_upscale(alpha, out_h, out_w):
|
||||
a = alpha.float()
|
||||
extreme_fraction = ((a < 0.1) | (a > 0.9)).float().mean()
|
||||
if extreme_fraction > 0.9:
|
||||
up = torch.nn.functional.interpolate(a, size=(out_h, out_w), mode="bilinear", align_corners=False, antialias=True)
|
||||
up = torch.clamp((up - 0.5) * 4.0 + 0.5, 0.0, 1.0)
|
||||
else:
|
||||
up = torch.nn.functional.interpolate(a, size=(out_h, out_w), mode="bicubic", align_corners=False, antialias=True).clamp(0.0, 1.0)
|
||||
return up
|
||||
|
||||
|
||||
class SeedVR2PostProcessing(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SeedVR2PostProcessing",
|
||||
category="image/video",
|
||||
display_name="Post-Process SeedVR2 Output",
|
||||
category="image/upscaling",
|
||||
description="Align the upscaled output to the original's geometry and optionally color-correct it against the original.",
|
||||
inputs=[
|
||||
io.Image.Input("decoded"),
|
||||
io.Image.Input("original_image"),
|
||||
io.Int.Input("upscaled_shorter_edge", min=2, force_input=True),
|
||||
io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab"),
|
||||
io.Image.Input("decoded", tooltip="The decoded upscaled image to color-correct."),
|
||||
io.Image.Input("original_image", tooltip="The original image used as the color reference."),
|
||||
io.Int.Input("upscaled_shorter_edge", min=2, force_input=True, tooltip="Shorter-edge size from the resize node."),
|
||||
io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="How to match the output's color to the original. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."),
|
||||
],
|
||||
outputs=[io.Image.Output()],
|
||||
)
|
||||
@ -341,6 +351,10 @@ class SeedVR2PostProcessing(io.ComfyNode):
|
||||
@classmethod
|
||||
def execute(cls, decoded, original_image, upscaled_shorter_edge, color_correction_method):
|
||||
cls._validate_upscaled_shorter_edge(upscaled_shorter_edge)
|
||||
alpha_input = None
|
||||
if original_image.shape[-1] == 4:
|
||||
alpha_input = original_image[..., 3:4]
|
||||
original_image = original_image[..., :3]
|
||||
decoded_5d, decoded_was_4d = cls._as_bthwc(decoded)
|
||||
original_5d, _ = cls._as_bthwc(original_image)
|
||||
decoded_5d = cls._restore_reference_batch_time(decoded_5d, original_5d)
|
||||
@ -374,6 +388,13 @@ class SeedVR2PostProcessing(io.ComfyNode):
|
||||
else:
|
||||
raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
|
||||
|
||||
if alpha_input is not None:
|
||||
ab, at = output.shape[0], output.shape[1]
|
||||
alpha_5d, _ = cls._as_bthwc(alpha_input)
|
||||
alpha_flat = rearrange(alpha_5d[:ab, :at], "b t h w c -> (b t) c h w")
|
||||
alpha_up = _edge_guided_alpha_upscale(alpha_flat, output.shape[2], output.shape[3])
|
||||
alpha_up = rearrange(alpha_up, "(b t) c h w -> b t h w c", b=ab, t=at)
|
||||
output = torch.cat([output, alpha_up.to(dtype=output.dtype, device=output.device)], dim=-1)
|
||||
h2 = output.shape[-3] - (output.shape[-3] % 2)
|
||||
w2 = output.shape[-2] - (output.shape[-2] % 2)
|
||||
output = output[:, :, :h2, :w2, :]
|
||||
@ -472,7 +493,7 @@ class SeedVR2PostProcessing(io.ComfyNode):
|
||||
"SeedVR2PostProcessing: color correction OOM at one frame; "
|
||||
f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}."
|
||||
) from e
|
||||
next_chunk_size = max(1, chunk_size // 2)
|
||||
next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR)
|
||||
|
||||
comfy.model_management.soft_empty_cache()
|
||||
chunk_size = next_chunk_size
|
||||
@ -510,23 +531,23 @@ class SeedVR2PostProcessing(io.ComfyNode):
|
||||
multiplier = cls._color_correction_memory_multiplier(color_correction_method)
|
||||
frames = decoded_flat.shape[0]
|
||||
_, channels, height, width = decoded_flat.shape
|
||||
dtype_bytes = max(decoded_flat.element_size(), 4)
|
||||
dtype_bytes = max(decoded_flat.element_size(), SEEDVR2_DTYPE_BYTES_FLOOR)
|
||||
bytes_per_frame = height * width * channels * dtype_bytes * multiplier
|
||||
if bytes_per_frame <= 0:
|
||||
return frames
|
||||
color_device = comfy.model_management.vae_device()
|
||||
free_memory = comfy.model_management.get_free_memory(color_device)
|
||||
chunk_size = int((free_memory * COLOR_CORRECTION_MEMORY_HEADROOM) // bytes_per_frame)
|
||||
chunk_size = int((free_memory * SEEDVR2_COLOR_MEM_HEADROOM) // bytes_per_frame)
|
||||
return max(1, min(frames, chunk_size))
|
||||
|
||||
@staticmethod
|
||||
def _color_correction_memory_multiplier(color_correction_method):
|
||||
if color_correction_method == "lab":
|
||||
return LAB_SCALE_MULTIPLIER
|
||||
return SEEDVR2_LAB_SCALE_MULTIPLIER
|
||||
if color_correction_method == "wavelet":
|
||||
return WAVELET_SCALE_MULTIPLIER
|
||||
return SEEDVR2_WAVELET_SCALE_MULTIPLIER
|
||||
if color_correction_method == "adain":
|
||||
return ADAIN_SCALE_MULTIPLIER
|
||||
return SEEDVR2_ADAIN_SCALE_MULTIPLIER
|
||||
raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
|
||||
|
||||
@staticmethod
|
||||
@ -549,10 +570,12 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SeedVR2Conditioning",
|
||||
category="image/video",
|
||||
display_name="Apply SeedVR2 Conditioning",
|
||||
category="conditioning",
|
||||
description="Build SeedVR2 positive/negative conditioning from a VAE latent.",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Latent.Input("vae_conditioning", display_name="LATENT"),
|
||||
io.Model.Input("model", tooltip="The SeedVR2 model."),
|
||||
io.Latent.Input("vae_conditioning", tooltip="The VAE-encoded latent to condition on."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(display_name = "model"),
|
||||
@ -571,10 +594,10 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
"SeedVR2Conditioning expects a 5-D VAE latent in Comfy "
|
||||
f"channel-first layout; got shape {tuple(vae_conditioning.shape)}."
|
||||
)
|
||||
if vae_conditioning.shape[-1] == _SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != _SEEDVR2_LATENT_CHANNELS:
|
||||
if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS:
|
||||
raise ValueError(
|
||||
"SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy "
|
||||
f"channel-first layout (B, {_SEEDVR2_LATENT_CHANNELS}, T, H, W); "
|
||||
f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
|
||||
f"got channel-last shape {tuple(vae_conditioning.shape)}."
|
||||
)
|
||||
vae_conditioning = vae_conditioning.movedim(1, -1).contiguous()
|
||||
@ -622,27 +645,9 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
|
||||
return io.NodeOutput(model_patcher, positive, negative, {"samples": latent})
|
||||
|
||||
# SeedVR2 latent / conditioning channel constants. The SeedVR2 conditioning
|
||||
# stage collapses ``(B, C, T, H, W) -> (B, C*T, H, W)`` for both the latent
|
||||
# (C=16) and the per-frame condition tensor (C=17 = 16 latent + 1 mask), as
|
||||
# required by ``NaDiT.forward`` which un-collapses via
|
||||
# ``view(B, 16, -1, H, W)`` and ``view(B, 17, -1, H, W)`` respectively.
|
||||
_SEEDVR2_LATENT_CHANNELS = 16
|
||||
_SEEDVR2_CONDITION_CHANNELS = 17
|
||||
|
||||
|
||||
def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int,
|
||||
t_end: int, channels: int) -> torch.Tensor:
|
||||
"""Slice a SeedVR2-style collapsed 4D tensor ``(B, channels*T, H, W)``
|
||||
along the latent T axis, returning ``(B, channels*(t_end - t_start), H, W)``.
|
||||
|
||||
Reshape -> slice -> ``.contiguous()`` -> re-collapse. ``reshape`` is
|
||||
used for the un-collapse so non-contiguous incoming tensors from
|
||||
cropping or slicing nodes are accepted. The
|
||||
``.contiguous()`` is mandatory: T-axis slicing of a 5D tensor produces a
|
||||
non-contiguous view, and the subsequent re-collapse requires contiguous
|
||||
storage.
|
||||
"""
|
||||
"""Slice collapsed ``(B, channels*T, H, W)`` along latent T: reshape (accepts non-contiguous inputs), slice, ``.contiguous()`` (T-slice of 5D is a non-contiguous view; re-collapse needs contiguous), re-collapse."""
|
||||
B, CT, H, W = tensor_4d.shape
|
||||
if CT % channels != 0:
|
||||
raise ValueError(
|
||||
@ -661,19 +666,7 @@ def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int,
|
||||
|
||||
|
||||
def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int):
|
||||
"""Build a new SeedVR2 conditioning list with the per-frame ``condition``
|
||||
tensor sliced along the latent T axis.
|
||||
|
||||
SeedVR2 conditioning entries have the shape
|
||||
``[text_cond_tensor, options_dict]`` where ``options_dict["condition"]``
|
||||
is a 4D collapsed ``(B, 17*T, H, W)`` tensor; the text tensor itself has
|
||||
no temporal axis and is passed through unchanged. Other keys in the
|
||||
options dict (controlnets, etc.) are also passed through unchanged. If
|
||||
an entry has no ``"condition"`` key, the entry is forwarded verbatim.
|
||||
|
||||
A new list of ``[text_cond, new_options_dict]`` pairs is returned; the
|
||||
original ``cond_list`` and its options dicts are not mutated.
|
||||
"""
|
||||
"""Return a new conditioning list with each entry's ``options["condition"]`` (collapsed ``(B, 17*T, H, W)``) sliced along latent T; text tensors, other option keys, and condition-less entries pass through unchanged and inputs are not mutated."""
|
||||
new_list = []
|
||||
for entry in cond_list:
|
||||
text_cond, options = entry[0], entry[1]
|
||||
@ -683,7 +676,7 @@ def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int):
|
||||
new_options = options.copy()
|
||||
new_options["condition"] = _slice_collapsed_4d_along_t(
|
||||
new_options["condition"], t_start, t_end,
|
||||
_SEEDVR2_CONDITION_CHANNELS,
|
||||
SEEDVR2_COND_CHANNELS,
|
||||
)
|
||||
new_list.append([text_cond, new_options])
|
||||
return new_list
|
||||
@ -693,24 +686,16 @@ def _slice_seedvr2_noise_mask_along_t(noise_mask: torch.Tensor,
|
||||
samples_4d: torch.Tensor,
|
||||
t_start: int,
|
||||
t_end: int):
|
||||
"""Slice collapsed SeedVR2 masks and preserve standard masks.
|
||||
|
||||
``SetLatentNoiseMask`` produces ``(B, 1, H, W)`` masks that KSampler
|
||||
expands to the latent shape. Only masks already expanded to the full
|
||||
collapsed ``(B, 16*T, H, W)`` shape need temporal slicing here.
|
||||
"""
|
||||
"""Slice only masks already expanded to collapsed ``(B, 16*T, H, W)``; pass standard ``(B, 1, H, W)`` ``SetLatentNoiseMask`` outputs through for KSampler to expand."""
|
||||
if noise_mask.ndim == samples_4d.ndim and noise_mask.shape[1] == samples_4d.shape[1]:
|
||||
return _slice_collapsed_4d_along_t(
|
||||
noise_mask, t_start, t_end, _SEEDVR2_LATENT_CHANNELS,
|
||||
noise_mask, t_start, t_end, SEEDVR2_LATENT_CHANNELS,
|
||||
)
|
||||
return noise_mask
|
||||
|
||||
|
||||
def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor:
|
||||
"""Concatenate a list of SeedVR2-style collapsed 4D tensors
|
||||
``(B, channels*T_i, H, W)`` along the latent T axis. Each chunk is
|
||||
un-collapsed to 5D, concatenated on ``dim=2``, then re-collapsed to 4D.
|
||||
"""
|
||||
"""Concatenate collapsed ``(B, channels*T_i, H, W)`` chunks along latent T: un-collapse to 5D, cat on ``dim=2``, re-collapse to 4D."""
|
||||
if len(chunks_4d) == 0:
|
||||
raise ValueError("_concat_chunks_along_t: empty chunk list.")
|
||||
fives = []
|
||||
@ -729,19 +714,10 @@ def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor:
|
||||
|
||||
|
||||
def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor:
|
||||
"""Build a 1D crossfade weight tensor of length ``overlap`` for the
|
||||
*previous* chunk's contribution; the current chunk's weight is
|
||||
``1 - w_prev``.
|
||||
|
||||
Mirrors the numz ``blend_overlapping_frames`` shape
|
||||
(AInVFX/numz fork ``src/core/generation_utils.py``,
|
||||
``blend_overlapping_frames``): a Hann window with a ``[1/3, 2/3]``
|
||||
dead-band when ``overlap >= 3``, and a plain linear ramp when
|
||||
``overlap < 3`` (the dead-band would collapse the transition for
|
||||
very small overlap counts). The numz reference operates on
|
||||
pixel-space tensors ``[overlap, H, W, C]``; this 1D form is
|
||||
reshaped by the caller to broadcast across the latent's
|
||||
``(B, C, T_overlap, H, W)`` axes.
|
||||
"""1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``):
|
||||
Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3``
|
||||
(dead-band would collapse a tiny transition). Window shape matched to numz ``blend_overlapping_frames``
|
||||
for parity (reference, not source); caller broadcasts across ``(B, C, T_overlap, H, W)``.
|
||||
"""
|
||||
if overlap < 1:
|
||||
raise ValueError(
|
||||
@ -758,14 +734,7 @@ def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor:
|
||||
|
||||
def _blend_overlap_region(prev_tail_5d: torch.Tensor,
|
||||
cur_head_5d: torch.Tensor) -> torch.Tensor:
|
||||
"""Blend two 5D ``(B, C, T_overlap, H, W)`` tensors of equal shape
|
||||
using a 1D Hann/linear ramp along the T axis. ``prev_tail_5d``
|
||||
receives the descending weight; ``cur_head_5d`` receives
|
||||
``1 - w_prev``.
|
||||
|
||||
The caller is responsible for ensuring both inputs have identical
|
||||
shape and dtype/device.
|
||||
"""
|
||||
"""Blend two equal-shape 5D ``(B, C, T_overlap, H, W)`` tensors with a 1D Hann/linear T-ramp: ``prev_tail_5d`` takes the descending weight, ``cur_head_5d`` takes ``1 - w_prev`` (caller ensures matching shape/dtype/device)."""
|
||||
if prev_tail_5d.shape != cur_head_5d.shape:
|
||||
raise ValueError(
|
||||
f"_blend_overlap_region: shape mismatch "
|
||||
@ -784,20 +753,7 @@ def _blend_overlap_region(prev_tail_5d: torch.Tensor,
|
||||
|
||||
def _concat_chunks_with_overlap_blend(chunk_specs, channels: int,
|
||||
overlap_latent: int) -> torch.Tensor:
|
||||
"""Concatenate temporally-overlapping chunks back into a single
|
||||
collapsed 4D tensor, blending overlap regions with a Hann/linear
|
||||
crossfade.
|
||||
|
||||
``chunk_specs`` is a list of ``(t_start, t_end, chunk_4d)`` tuples
|
||||
in source-latent T coordinates. ``overlap_latent == 0`` is a fast
|
||||
path that delegates to plain concatenation (and produces output
|
||||
bit-identical to ``_concat_chunks_along_t`` of the same chunks).
|
||||
|
||||
The blend at each pair of adjacent chunks acts on the actual
|
||||
overlap region width ``min(prev_end - cur_start, current chunk
|
||||
length)``, which may be smaller than ``overlap_latent`` when the
|
||||
final chunk is a runt shorter than the configured overlap.
|
||||
"""
|
||||
"""Concatenate overlapping ``(t_start, t_end, chunk_4d)`` specs (source-latent T coords) into one collapsed 4D tensor, Hann/linear-blending overlaps; ``overlap_latent == 0`` fast-paths to plain concat (bit-identical to ``_concat_chunks_along_t``). Each blend uses the actual width ``min(prev_end - cur_start, chunk length)``, smaller than ``overlap_latent`` for a runt final chunk."""
|
||||
if len(chunk_specs) == 0:
|
||||
raise ValueError("_concat_chunks_with_overlap_blend: empty chunk list.")
|
||||
if overlap_latent < 0:
|
||||
@ -877,12 +833,7 @@ def _run_standard_sample(model, seed: int, steps: int, cfg: float,
|
||||
sampler_name: str, scheduler: str,
|
||||
positive, negative, latent_image: dict,
|
||||
denoise: float) -> dict:
|
||||
"""Single-shot delegation that mirrors the standard ``common_ksampler``
|
||||
flow (``nodes.py:common_ksampler``): generate noise from seed, run
|
||||
``comfy.sample.sample``, return a latent dict. Used by the
|
||||
ProgressiveSampler short-circuit when the full sequence fits in one
|
||||
chunk so chunking introduces no overhead for small videos.
|
||||
"""
|
||||
"""Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk."""
|
||||
samples_in = latent_image["samples"]
|
||||
samples_in = comfy.sample.fix_empty_latent_channels(
|
||||
model, samples_in, latent_image.get("downscale_ratio_spacial", None),
|
||||
@ -929,43 +880,45 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SeedVR2ProgressiveSampler",
|
||||
display_name="Sample SeedVR2 (Progressive)",
|
||||
category="sampling",
|
||||
description="Sample a SeedVR2 latent in sequential temporal chunks to allow longer videos to fit into VRAM via frame blending the resulting upscaled latents.",
|
||||
inputs=[
|
||||
io.Model.Input("model"),
|
||||
io.Model.Input("model", tooltip="The model used for denoising the input latent."),
|
||||
io.Int.Input("seed", default=0, min=0,
|
||||
max=0xffffffffffffffff,
|
||||
control_after_generate=True),
|
||||
io.Int.Input("steps", default=20, min=1, max=10000),
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise."),
|
||||
io.Int.Input("steps", default=20, min=1, max=10000,
|
||||
tooltip="The number of steps used in the denoising process."),
|
||||
io.Float.Input("cfg", default=1.0, min=0.0, max=100.0,
|
||||
step=0.1, round=0.01),
|
||||
step=0.1, round=0.01,
|
||||
tooltip="The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."),
|
||||
io.Combo.Input("sampler_name",
|
||||
options=comfy.samplers.SAMPLER_NAMES),
|
||||
options=comfy.samplers.SAMPLER_NAMES,
|
||||
tooltip="The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."),
|
||||
io.Combo.Input("scheduler",
|
||||
options=comfy.samplers.SCHEDULER_NAMES),
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Latent.Input("latent_image"),
|
||||
options=comfy.samplers.SCHEDULER_NAMES,
|
||||
tooltip="The scheduler controls how noise is gradually removed to form the image."),
|
||||
io.Conditioning.Input("positive",
|
||||
tooltip="The conditioning describing the attributes you want to include in the image."),
|
||||
io.Conditioning.Input("negative",
|
||||
tooltip="The conditioning describing the attributes you want to exclude from the image."),
|
||||
io.Latent.Input("latent_image",
|
||||
tooltip="The latent image to denoise."),
|
||||
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0,
|
||||
step=0.01),
|
||||
step=0.01,
|
||||
tooltip="The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."),
|
||||
io.Int.Input("frames_per_chunk", default=21, min=1,
|
||||
max=16384, step=4),
|
||||
max=16384, step=4,
|
||||
tooltip="Pixel frames per temporal chunk (4n+1: 1, 5, 9, 13, ...)."),
|
||||
io.Int.Input("temporal_overlap", default=0, min=0,
|
||||
max=16384,
|
||||
tooltip="Latent-frame overlap between "
|
||||
"adjacent chunks; blended with a "
|
||||
"Hann window (linear for overlap "
|
||||
"< 3). 0 = no blend, pure concat. "
|
||||
"Values >= the chunk's latent-frame "
|
||||
"length use the maximum valid "
|
||||
"overlap; 1 latent frame corresponds "
|
||||
"to ~4 pixel frames."),
|
||||
tooltip="Latent frames blended between adjacent chunks to hide the seam; 0 = no blend."),
|
||||
io.Combo.Input("chunking_mode",
|
||||
options=["manual", "auto"],
|
||||
default="manual",
|
||||
tooltip="manual = use frames_per_chunk "
|
||||
"exactly; auto = retry only real OOM "
|
||||
"failures with progressively smaller "
|
||||
"temporal chunks."),
|
||||
tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."),
|
||||
],
|
||||
outputs=[io.Latent.Output()],
|
||||
)
|
||||
@ -999,14 +952,14 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
|
||||
f"(B, 16*T, H, W); got shape {tuple(samples_4d.shape)}."
|
||||
)
|
||||
B, CT, H, W = samples_4d.shape
|
||||
if CT % _SEEDVR2_LATENT_CHANNELS != 0:
|
||||
if CT % SEEDVR2_LATENT_CHANNELS != 0:
|
||||
raise ValueError(
|
||||
f"SeedVR2ProgressiveSampler: collapsed channel dim {CT} is "
|
||||
f"not divisible by SeedVR2 latent channels "
|
||||
f"{_SEEDVR2_LATENT_CHANNELS}; latent does not appear to be "
|
||||
f"{SEEDVR2_LATENT_CHANNELS}; latent does not appear to be "
|
||||
f"SeedVR2-shaped."
|
||||
)
|
||||
T_latent = CT // _SEEDVR2_LATENT_CHANNELS
|
||||
T_latent = CT // SEEDVR2_LATENT_CHANNELS
|
||||
T_pixel = 4 * (T_latent - 1) + 1
|
||||
|
||||
if chunking_mode not in ("manual", "auto"):
|
||||
@ -1106,11 +1059,11 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
|
||||
def _sample_one_chunk(chunk_start, chunk_end):
|
||||
samples_chunk = _slice_collapsed_4d_along_t(
|
||||
samples_4d, chunk_start, chunk_end,
|
||||
_SEEDVR2_LATENT_CHANNELS,
|
||||
SEEDVR2_LATENT_CHANNELS,
|
||||
)
|
||||
noise_chunk = _slice_collapsed_4d_along_t(
|
||||
noise_full, chunk_start, chunk_end,
|
||||
_SEEDVR2_LATENT_CHANNELS,
|
||||
SEEDVR2_LATENT_CHANNELS,
|
||||
)
|
||||
positive_chunk = _slice_seedvr2_cond_along_t(
|
||||
positive, chunk_start, chunk_end,
|
||||
@ -1140,7 +1093,7 @@ class SeedVR2ProgressiveSampler(io.ComfyNode):
|
||||
chunk_specs.append((chunk_start, chunk_end, chunk_samples))
|
||||
|
||||
final = _concat_chunks_with_overlap_blend(
|
||||
chunk_specs, _SEEDVR2_LATENT_CHANNELS, temporal_overlap,
|
||||
chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap,
|
||||
)
|
||||
|
||||
out = latent_image.copy()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user