Refactors and cleanups.

This commit is contained in:
comfyanonymous 2026-07-02 22:59:38 -04:00
parent 77d42ed7e9
commit c7b2c3b569
19 changed files with 436 additions and 729 deletions

View File

@ -781,6 +781,7 @@ class ACEAudio(LatentFormat):
class SeedVR2(LatentFormat): class SeedVR2(LatentFormat):
latent_channels = 16 latent_channels = 16
latent_dimensions = 3
class ACEAudio15(LatentFormat): class ACEAudio15(LatentFormat):
latent_channels = 64 latent_channels = 64

View File

@ -22,33 +22,14 @@ def _var_attention_output(out, heads, head_dim, skip_output_reshape):
return out.reshape(-1, heads * head_dim) return out.reshape(-1, heads * head_dim)
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
if cu_seqlens.dtype not in (torch.int32, torch.int64):
raise ValueError(f"{name} must use an integer dtype")
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
if cu_seqlens[0].item() != 0:
raise ValueError(f"{name} must start at 0")
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
raise ValueError(f"{name} must be strictly increasing")
if cu_seqlens[-1].item() != token_count:
raise ValueError(f"{name} does not match token count")
def _split_indices(cu_seqlens):
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0]) q_split_indices = cu_seqlens_q[1:-1]
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0]) k_split_indices = cu_seqlens_k[1:-1]
if cu_seqlens_k[-1].item() != v.shape[0]: if k.shape[0] != v.shape[0]:
raise ValueError("cu_seqlens_k does not match v token count") raise ValueError("cu_seqlens_k does not match v token count")
q_split_indices = _split_indices(cu_seqlens_q)
k_split_indices = _split_indices(cu_seqlens_k)
q_splits = torch.tensor_split(q, q_split_indices, dim=0) q_splits = torch.tensor_split(q, q_split_indices, dim=0)
k_splits = torch.tensor_split(k, k_split_indices, dim=0) k_splits = torch.tensor_split(k, k_split_indices, dim=0)
v_splits = torch.tensor_split(v, k_split_indices, dim=0) v_splits = torch.tensor_split(v, k_split_indices, dim=0)

View File

@ -45,7 +45,6 @@ def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS):
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
if content_feat.shape != style_feat.shape: if content_feat.shape != style_feat.shape:
# Resize style to match content spatial dimensions
if len(content_feat.shape) >= 3: if len(content_feat.shape) >= 3:
style_feat = F.interpolate( style_feat = F.interpolate(
style_feat, style_feat,
@ -54,12 +53,11 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
align_corners=False align_corners=False
) )
# Decompose both features into frequency components
content_high_freq, content_low_freq = wavelet_decomposition(content_feat) content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq # Free memory immediately del content_low_freq
style_high_freq, style_low_freq = wavelet_decomposition(style_feat) style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq # Free memory immediately del style_high_freq
if content_high_freq.shape != style_low_freq.shape: if content_high_freq.shape != style_low_freq.shape:
style_low_freq = F.interpolate( style_low_freq = F.interpolate(
@ -73,27 +71,23 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
return content_high_freq.clamp_(-1.0, 1.0) return content_high_freq.clamp_(-1.0, 1.0)
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor: def _histogram_matching_channel(source: Tensor, reference: Tensor) -> Tensor:
original_shape = source.shape original_shape = source.shape
# Flatten
source_flat = source.flatten() source_flat = source.flatten()
reference_flat = reference.flatten() reference_flat = reference.flatten()
# Sort both arrays
source_sorted, source_indices = torch.sort(source_flat) source_sorted, source_indices = torch.sort(source_flat)
reference_sorted, _ = torch.sort(reference_flat) reference_sorted, _ = torch.sort(reference_flat)
del reference_flat del reference_flat
# Quantile mapping
n_source = len(source_sorted) n_source = len(source_sorted)
n_reference = len(reference_sorted) n_reference = len(reference_sorted)
if n_source == n_reference: if n_source == n_reference:
matched_sorted = reference_sorted matched_sorted = reference_sorted
else: else:
# Interpolate reference to match source quantiles source_quantiles = torch.linspace(0, 1, n_source, device=source.device)
source_quantiles = torch.linspace(0, 1, n_source, device=device)
ref_indices = (source_quantiles * (n_reference - 1)).long() ref_indices = (source_quantiles * (n_reference - 1)).long()
ref_indices.clamp_(0, n_reference - 1) ref_indices.clamp_(0, n_reference - 1)
matched_sorted = reference_sorted[ref_indices] matched_sorted = reference_sorted[ref_indices]
@ -101,7 +95,6 @@ def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch
del source_sorted, source_flat del source_sorted, source_flat
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
inverse_indices = torch.argsort(source_indices) inverse_indices = torch.argsort(source_indices)
del source_indices del source_indices
matched_flat = matched_sorted[inverse_indices] matched_flat = matched_sorted[inverse_indices]
@ -109,17 +102,14 @@ def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch
return matched_flat.reshape(original_shape) return matched_flat.reshape(original_shape)
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor: def _lab_to_rgb_batch(lab: Tensor, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of CIELAB images to RGB color space."""
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2] L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
# LAB to XYZ
fy = (L + 16.0) / 116.0 fy = (L + 16.0) / 116.0
fx = a.div(500.0).add_(fy) fx = a.div(500.0).add_(fy)
fz = fy - b / 200.0 fz = fy - b / 200.0
del L, a, b del L, a, b
# XYZ transformation
x = torch.where( x = torch.where(
fx > epsilon, fx > epsilon,
torch.pow(fx, 3.0), torch.pow(fx, 3.0),
@ -137,20 +127,16 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
) )
del fx, fy, fz del fx, fy, fz
# Apply D65 white point (in-place)
x.mul_(D65_WHITE_X) x.mul_(D65_WHITE_X)
# y *= 1.00000 # (no-op, skip)
z.mul_(D65_WHITE_Z) z.mul_(D65_WHITE_Z)
xyz = torch.stack([x, y, z], dim=1) xyz = torch.stack([x, y, z], dim=1)
del x, y, z del x, y, z
# Matrix multiplication: XYZ -> RGB B, _, H, W = xyz.shape
B, C, H, W = xyz.shape
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3) xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
del xyz del xyz
# Ensure dtype consistency for matrix multiplication
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype) xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T) rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
del xyz_flat del xyz_flat
@ -158,7 +144,6 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del rgb_linear_flat del rgb_linear_flat
# Apply inverse gamma correction (delinearize)
mask = rgb_linear > 0.0031308 mask = rgb_linear > 0.0031308
rgb = torch.where( rgb = torch.where(
mask, mask,
@ -169,9 +154,7 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
return torch.clamp(rgb, 0.0, 1.0) return torch.clamp(rgb, 0.0, 1.0)
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor: def _rgb_to_lab_batch(rgb: Tensor, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
# Apply sRGB gamma correction (linearize)
mask = rgb > 0.04045 mask = rgb > 0.04045
rgb_linear = torch.where( rgb_linear = torch.where(
mask, mask,
@ -180,12 +163,10 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
) )
del mask del mask
# Matrix multiplication: RGB -> XYZ B, _, H, W = rgb_linear.shape
B, C, H, W = rgb_linear.shape
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3) rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
del rgb_linear del rgb_linear
# Ensure dtype consistency for matrix multiplication
rgb_flat = rgb_flat.to(dtype=matrix.dtype) rgb_flat = rgb_flat.to(dtype=matrix.dtype)
xyz_flat = torch.matmul(rgb_flat, matrix.T) xyz_flat = torch.matmul(rgb_flat, matrix.T)
del rgb_flat del rgb_flat
@ -193,12 +174,9 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del xyz_flat del xyz_flat
# Normalize by D65 white point (in-place) xyz[:, 0].div_(D65_WHITE_X)
xyz[:, 0].div_(D65_WHITE_X) # X xyz[:, 2].div_(D65_WHITE_Z)
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
xyz[:, 2].div_(D65_WHITE_Z) # Z
# XYZ to LAB transformation
epsilon_cubed = epsilon ** 3 epsilon_cubed = epsilon ** 3
mask = xyz > epsilon_cubed mask = xyz > epsilon_cubed
f_xyz = torch.where( f_xyz = torch.where(
@ -208,10 +186,9 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
) )
del xyz, mask del xyz, mask
# Extract channels and compute LAB L = f_xyz[:, 1].mul(116.0).sub_(16.0)
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100] a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0)
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127] b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0)
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
del f_xyz del f_xyz
return torch.stack([L, a, b], dim=1) return torch.stack([L, a, b], dim=1)
@ -232,13 +209,9 @@ def lab_color_transfer(
) )
device = content_feat.device device = content_feat.device
original_dtype = content_feat.dtype
def ensure_float32_precision(c): content_feat = content_feat.float()
orig_dtype = c.dtype style_feat = style_feat.float()
c = c.float()
return c, orig_dtype
content_feat, original_dtype = ensure_float32_precision(content_feat)
style_feat, _ = ensure_float32_precision(style_feat)
rgb_to_xyz_matrix = torch.tensor([ rgb_to_xyz_matrix = torch.tensor([
[0.4124564, 0.3575761, 0.1804375], [0.4124564, 0.3575761, 0.1804375],
@ -258,39 +231,30 @@ def lab_color_transfer(
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
# Convert to LAB color space content_lab = _rgb_to_lab_batch(content_feat, rgb_to_xyz_matrix, epsilon, kappa)
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
del content_feat del content_feat
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa) style_lab = _rgb_to_lab_batch(style_feat, rgb_to_xyz_matrix, epsilon, kappa)
del style_feat, rgb_to_xyz_matrix del style_feat, rgb_to_xyz_matrix
# Match chrominance channels (a*, b*) for accurate color transfer matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1])
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device) matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2])
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
# Handle luminance with weighted blending
if luminance_weight < 1.0: if luminance_weight < 1.0:
# Partially match luminance for better overall color accuracy matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0])
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
# Blend: preserve some content L* for detail, adopt some style L* for color
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight)) result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
del matched_L del matched_L
else: else:
# Fully preserve content luminance
result_L = content_lab[:, 0] result_L = content_lab[:, 0]
del content_lab, style_lab del content_lab, style_lab
# Reconstruct LAB with corrected channels
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1) result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
del result_L, matched_a, matched_b del result_L, matched_a, matched_b
# Convert back to RGB result_rgb = _lab_to_rgb_batch(result_lab, xyz_to_rgb_matrix, epsilon, kappa)
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
del result_lab, xyz_to_rgb_matrix del result_lab, xyz_to_rgb_matrix
# Convert back to [-1, 1] range (in-place)
result = result_rgb.mul_(2.0).sub_(1.0) result = result_rgb.mul_(2.0).sub_(1.0)
del result_rgb del result_rgb

View File

@ -1,34 +1,21 @@
"""Named constants for the SeedVR2 integration, grouped by provenance. """SeedVR2 constants."""
Provenance prefixes: SEEDVR2_7B_VID_DIM = 3072
- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline. SEEDVR2_OOM_BACKOFF_DIVISOR = 2
- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites SEEDVR2_DTYPE_BYTES_FLOOR = 4
the upstream config/source path it was lifted from. SEEDVR2_7B_MLP_CHUNK = 8192
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
ISO / CIE values; cite the standard.
"""
SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim.
# (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.)
SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # OOM retry backoff: halve the chunk and retry.
SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case).
SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM.
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk. SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk.
SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels). SEEDVR2_LATENT_CHANNELS = 16
# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing) SEEDVR2_COLOR_MEM_HEADROOM = 0.75
SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk. SEEDVR2_LAB_SCALE_MULTIPLIER = 13
SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path.
SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path. SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path.
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path. SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6
# -------------------------------------------------------------------------------------- BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57.
# ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR) BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0
# -------------------------------------------------------------------------------------- BYTEDANCE_VAE_CONV_MEM_GIB = 0.5
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm. BYTEDANCE_VAE_NORM_MEM_GIB = 0.5
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem).
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem).
BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28. BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28. BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16). BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16).
@ -42,18 +29,10 @@ BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal wind
BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency). BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency).
BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim). BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim).
# --------------------------------------------------------------------------------------
# Published standards (cite the literature)
# --------------------------------------------------------------------------------------
ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864. ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864.
# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65).
CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta). CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta).
CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa). CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa).
D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1). D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1).
D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn. D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn.
WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR). WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR).
# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and
# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the
# exact existing coefficients move verbatim rather than being retyped here.

View File

@ -3,7 +3,7 @@ from typing import Optional, Tuple, Union, List, Dict, Any, Callable
import torch.nn.functional as F import torch.nn.functional as F
from math import ceil, pi from math import ceil, pi
import torch import torch
from itertools import chain from itertools import accumulate, chain
from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
from comfy.ldm.seedvr.attention import optimized_var_attention from comfy.ldm.seedvr.attention import optimized_var_attention
from torch.nn.modules.utils import _triple from torch.nn.modules.utils import _triple
@ -18,6 +18,7 @@ from comfy.ldm.seedvr.constants import (
ROPE_THETA, ROPE_THETA,
SEEDVR2_7B_MLP_CHUNK, SEEDVR2_7B_MLP_CHUNK,
SEEDVR2_7B_VID_DIM, SEEDVR2_7B_VID_DIM,
SEEDVR2_LATENT_CHANNELS,
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS,
) )
import comfy.model_management import comfy.model_management
@ -70,7 +71,7 @@ def repeat_concat_idx(
vid_idx = torch.arange(vid_len.sum(), device=device) vid_idx = torch.arange(vid_len.sum(), device=device)
txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
txt_repeat_list = txt_repeat.tolist() txt_repeat_list = txt_repeat.tolist()
tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat) tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat_list)
src_idx = torch.argsort(tgt_idx) src_idx = torch.argsort(tgt_idx)
txt_idx_len = len(tgt_idx) - len(vid_idx) txt_idx_len = len(tgt_idx) - len(vid_idx)
repeat_txt_len = (txt_len * txt_repeat).tolist() repeat_txt_len = (txt_len * txt_repeat).tolist()
@ -88,6 +89,9 @@ def repeat_concat_idx(
lambda all: unconcat_coalesce(all), lambda all: unconcat_coalesce(all),
) )
def cumulative_lengths(lengths):
return [0, *accumulate(lengths)]
@dataclass @dataclass
class MMArg: class MMArg:
@ -110,16 +114,14 @@ def get_window_op(name: str):
raise ValueError(f"Unknown windowing method: {name}") raise ValueError(f"Unknown windowing method: {name}")
# -------------------------------- Windowing -------------------------------- #
def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
t, h, w = size t, h, w = size
resized_nt, resized_nh, resized_nw = num_windows resized_nt, resized_nh, resized_nw = num_windows
#cal windows under 720p
scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w))
resized_h, resized_w = round(h * scale), round(w * scale) resized_h, resized_w = round(h * scale), round(w * scale)
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)
wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt)
nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww)
return [ return [
( (
slice(it * wt, min((it + 1) * wt, t)), slice(it * wt, min((it + 1) * wt, t)),
@ -137,19 +139,18 @@ def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int,
def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
t, h, w = size t, h, w = size
resized_nt, resized_nh, resized_nw = num_windows resized_nt, resized_nh, resized_nw = num_windows
#cal windows under 720p
scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w))
resized_h, resized_w = round(h * scale), round(w * scale) resized_h, resized_w = round(h * scale), round(w * scale)
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)
wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt)
st, sh, sw = ( # shift size. st, sh, sw = (
0.5 if wt < t else 0, 0.5 if wt < t else 0,
0.5 if wh < h else 0, 0.5 if wh < h else 0,
0.5 if ww < w else 0, 0.5 if ww < w else 0,
) )
nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size. nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww)
nt, nh, nw = ( # number of window. nt, nh, nw = (
nt + 1 if st > 0 else 1, nt + 1 if st > 0 else 1,
nh + 1 if sh > 0 else 1, nh + 1 if sh > 0 else 1,
nw + 1 if sw > 0 else 1, nw + 1 if sw > 0 else 1,
@ -175,7 +176,6 @@ class RotaryEmbedding(nn.Module):
freqs_for = 'lang', freqs_for = 'lang',
theta = 10000, theta = 10000,
max_freq = 10, max_freq = 10,
learned_freq = False,
): ):
super().__init__() super().__init__()
@ -185,18 +185,14 @@ class RotaryEmbedding(nn.Module):
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
elif freqs_for == 'pixel': elif freqs_for == 'pixel':
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
else:
raise ValueError(f"Unknown rotary frequency type: {freqs_for}")
self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) self.register_buffer("freqs", freqs)
self.learned_freq = learned_freq
# dummy for device
self.register_buffer('dummy', torch.tensor(0), persistent = False)
@property @property
def device(self): def device(self):
return self.dummy.device return self.freqs.device
def get_axial_freqs( def get_axial_freqs(
self, self,
@ -206,10 +202,9 @@ class RotaryEmbedding(nn.Module):
Colon = slice(None) Colon = slice(None)
all_freqs = [] all_freqs = []
# handle offset
if exists(offsets): if exists(offsets):
assert len(offsets) == len(dims) if len(offsets) != len(dims):
raise ValueError(f"SeedVR2 rotary offsets length must match dims length, got {len(offsets)} and {len(dims)}.")
for ind, dim in enumerate(dims): for ind, dim in enumerate(dims):
@ -224,7 +219,7 @@ class RotaryEmbedding(nn.Module):
pos = pos + offset pos = pos + offset
freqs = self.forward(pos, seq_len = dim) freqs = self.forward(pos)
all_axis = [None] * len(dims) all_axis = [None] * len(dims)
all_axis[ind] = Colon all_axis[ind] = Colon
@ -232,16 +227,12 @@ class RotaryEmbedding(nn.Module):
new_axis_slice = (Ellipsis, *all_axis, Colon) new_axis_slice = (Ellipsis, *all_axis, Colon)
all_freqs.append(freqs[new_axis_slice]) all_freqs.append(freqs[new_axis_slice])
# concat all freqs
all_freqs = torch.broadcast_tensors(*all_freqs) all_freqs = torch.broadcast_tensors(*all_freqs)
return torch.cat(all_freqs, dim = -1) return torch.cat(all_freqs, dim = -1)
def forward( def forward(
self, self,
t, t,
seq_len: int | None = None,
offset = 0
): ):
freqs = self.freqs freqs = self.freqs
@ -258,9 +249,6 @@ class RotaryEmbeddingBase(nn.Module):
freqs_for="pixel", freqs_for="pixel",
max_freq=BYTEDANCE_ROPE_MAX_FREQ, max_freq=BYTEDANCE_ROPE_MAX_FREQ,
) )
freqs = self.rope.freqs
del self.rope.freqs
self.rope.register_buffer("freqs", freqs.detach())
def get_axial_freqs(self, *dims): def get_axial_freqs(self, *dims):
return self.rope.get_axial_freqs(*dims) return self.rope.get_axial_freqs(*dims)
@ -306,7 +294,7 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d):
freqs_for="pixel", freqs_for="pixel",
max_freq=BYTEDANCE_ROPE_MAX_FREQ, max_freq=BYTEDANCE_ROPE_MAX_FREQ,
) )
plain_rope = plain_rope.to(self.rope.dummy.device) plain_rope = plain_rope.to(self.rope.device)
freq_list = [] freq_list = []
for f, h, w in shape.tolist(): for f, h, w in shape.tolist():
freqs = plain_rope.get_axial_freqs(f, h, w) freqs = plain_rope.get_axial_freqs(f, h, w)
@ -322,9 +310,6 @@ class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
freqs_for="lang", freqs_for="lang",
theta=ROPE_THETA, theta=ROPE_THETA,
) )
freqs = self.rope.freqs
del self.rope.freqs
self.rope.register_buffer("freqs", freqs.detach())
self.mm = True self.mm = True
def slice_at_dim(t, dim_slice: slice, *, dim): def slice_at_dim(t, dim_slice: slice, *, dim):
@ -333,8 +318,6 @@ def slice_at_dim(t, dim_slice: slice, *, dim):
colons[dim] = dim_slice colons[dim] = dim_slice
return t[tuple(colons)] return t[tuple(colons)]
# rotary embedding helper functions
def rotate_half(x): def rotate_half(x):
x = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2) x = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
x1, x2 = x.unbind(dim = -1) x1, x2 = x.unbind(dim = -1)
@ -373,7 +356,6 @@ def _apply_seedvr2_rotary_emb(
return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype) return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype)
def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
"""Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`."""
angles = freqs_interleaved[..., ::2].float() angles = freqs_interleaved[..., ::2].float()
cos = torch.cos(angles) cos = torch.cos(angles)
sin = torch.sin(angles) sin = torch.sin(angles)
@ -382,12 +364,6 @@ def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""Rotate the leading ``rot_d = 2 * freqs_cis.shape[-3]`` dims of ``t`` and pass the rest
through; in-place for inference, cloned for training (autograd). Mirrors the legacy
``apply_rotary_emb`` ``t_left``/``t_middle``/``t_right`` split: 3B ``rope_dim=128`` gives
``42*3 = 126`` rotated of head_dim 128 (trailing 2 unrotated). Fast path skips the cat when
``rot_d == t.shape[-1]``.
"""
out = t.clone() if t.requires_grad or comfy.model_management.in_training else t out = t.clone() if t.requires_grad or comfy.model_management.in_training else t
rot_d = 2 * freqs_cis.shape[-3] rot_d = 2 * freqs_cis.shape[-3]
seq_len = out.shape[-2] seq_len = out.shape[-2]
@ -454,14 +430,13 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
torch.Tensor, torch.Tensor,
]: ]:
# Calculate actual max dimensions needed for this batch
max_temporal = 0 max_temporal = 0
max_height = 0 max_height = 0
max_width = 0 max_width = 0
max_txt_len = 0 max_txt_len = 0
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal max_temporal = max(max_temporal, l + f)
max_height = max(max_height, h) max_height = max(max_height, h)
max_width = max(max_width, w) max_width = max(max_width, w)
max_txt_len = max(max_txt_len, l) max_txt_len = max(max_txt_len, l)
@ -475,7 +450,6 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
).float() ).float()
txt_freqs = self.get_axial_freqs(max_txt_len + 16) txt_freqs = self.get_axial_freqs(max_txt_len + 16)
# Now slice as before
vid_freq_list, txt_freq_list = [], [] vid_freq_list, txt_freq_list = [], []
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1))
@ -485,13 +459,6 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
vid_freqs_interleaved = torch.cat(vid_freq_list, dim=0) vid_freqs_interleaved = torch.cat(vid_freq_list, dim=0)
txt_freqs_interleaved = torch.cat(txt_freq_list, dim=0) txt_freqs_interleaved = torch.cat(txt_freq_list, dim=0)
# Convert from lucidrains-interleaved layout `[θ0, θ0, θ1, θ1, ...]`
# (produced by `repeat(freqs, '... n -> ... (n r)', r=2)` in the
# upstream `RotaryEmbedding.forward`) to flux-canonical `freqs_cis`
# in shape `[..., d/2, 2, 2]` with `cos/-sin/sin/cos` baked in.
# Mirrors `comfy/ldm/flux/math.py:rope` (line 27) so the trailing
# 2x2 is the per-frequency rotation matrix that
# `comfy.ldm.flux.math.apply_rope1` expects.
return _to_flux_freqs_cis(vid_freqs_interleaved), _to_flux_freqs_cis(txt_freqs_interleaved) return _to_flux_freqs_cis(vid_freqs_interleaved), _to_flux_freqs_cis(txt_freqs_interleaved)
class MMModule(nn.Module): class MMModule(nn.Module):
@ -507,8 +474,10 @@ class MMModule(nn.Module):
self.shared_weights = shared_weights self.shared_weights = shared_weights
self.vid_only = vid_only self.vid_only = vid_only
if self.shared_weights: if self.shared_weights:
assert get_args("vid", args) == get_args("txt", args) if get_args("vid", args) != get_args("txt", args):
assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) raise ValueError("SeedVR2 shared MMModule requires matching vid/txt args.")
if get_kwargs("vid", kwargs) != get_kwargs("txt", kwargs):
raise ValueError("SeedVR2 shared MMModule requires matching vid/txt kwargs.")
self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
else: else:
self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
@ -543,6 +512,7 @@ def get_na_rope(rope_type: Optional[str], dim: int):
return NaRotaryEmbedding3d(dim=dim) return NaRotaryEmbedding3d(dim=dim)
if rope_type == "mmrope3d": if rope_type == "mmrope3d":
return NaMMRotaryEmbedding3d(dim=dim) return NaMMRotaryEmbedding3d(dim=dim)
raise ValueError(f"Unknown SeedVR2 rope type: {rope_type}")
class NaMMAttention(nn.Module): class NaMMAttention(nn.Module):
def __init__( def __init__(
@ -558,7 +528,6 @@ class NaMMAttention(nn.Module):
rope_dim: int, rope_dim: int,
shared_weights: bool, shared_weights: bool,
device, dtype, operations, device, dtype, operations,
**kwargs,
): ):
super().__init__() super().__init__()
dim = MMArg(vid_dim, txt_dim) dim = MMArg(vid_dim, txt_dim)
@ -597,16 +566,19 @@ def window(
): ):
hid = unflatten(hid, hid_shape) hid = unflatten(hid, hid_shape)
hid = list(map(window_fn, hid)) hid = list(map(window_fn, hid))
hid_windows = torch.as_tensor([len(x) for x in hid], device=hid_shape.device) hid_windows_list = [len(x) for x in hid]
hid, hid_shape = flatten(list(chain(*hid))) hid_windows = torch.as_tensor(hid_windows_list, device=hid_shape.device)
return hid, hid_shape, hid_windows hid = list(chain(*hid))
hid_len_list = [math.prod(x.shape[:-1]) for x in hid]
hid, hid_shape = flatten(hid)
return hid, hid_shape, hid_windows, hid_len_list, hid_windows_list
def window_idx( def window_idx(
hid_shape: torch.LongTensor, # (b n) hid_shape: torch.LongTensor, # (b n)
window_fn: Callable[[torch.Tensor], List[torch.Tensor]], window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
): ):
hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn) tgt_idx, tgt_shape, tgt_windows, tgt_len_list, tgt_windows_list = window(hid_idx, hid_shape, window_fn)
tgt_idx = tgt_idx.squeeze(-1) tgt_idx = tgt_idx.squeeze(-1)
src_idx = torch.argsort(tgt_idx) src_idx = torch.argsort(tgt_idx)
return ( return (
@ -614,6 +586,8 @@ def window_idx(
lambda hid: torch.index_select(hid, 0, src_idx), lambda hid: torch.index_select(hid, 0, src_idx),
tgt_shape, tgt_shape,
tgt_windows, tgt_windows,
tgt_len_list,
tgt_windows_list,
) )
class NaSwinAttention(NaMMAttention): class NaSwinAttention(NaMMAttention):
@ -622,13 +596,15 @@ class NaSwinAttention(NaMMAttention):
*args, *args,
window: Union[int, Tuple[int, int, int]], window: Union[int, Tuple[int, int, int]],
window_method: str, window_method: str,
version: bool = False,
**kwargs, **kwargs,
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.version_7b = kwargs.get("version", False) self.version_7b = version
self.window = _triple(window) self.window = _triple(window)
self.window_method = window_method self.window_method = window_method
assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window)) if not all(isinstance(v, int) and v >= 0 for v in self.window):
raise ValueError(f"SeedVR2 window must contain non-negative integers, got {self.window}.")
self.window_op = get_window_op(window_method) self.window_op = get_window_op(window_method)
@ -646,7 +622,6 @@ class NaSwinAttention(NaMMAttention):
vid_qkv, txt_qkv = self.proj_qkv(vid, txt) vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
# re-org the input seq for window attn
cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3") cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3")
def make_window(x: torch.Tensor): def make_window(x: torch.Tensor):
@ -654,7 +629,7 @@ class NaSwinAttention(NaMMAttention):
window_slices = self.window_op((t, h, w), self.window) window_slices = self.window_op((t, h, w), self.window)
return [x[st, sh, sw] for (st, sh, sw) in window_slices] return [x[st, sh, sw] for (st, sh, sw) in window_slices]
window_partition, window_reverse, window_shape, window_count = cache_win( window_partition, window_reverse, window_shape, window_count, vid_len_win_list, window_count_list = cache_win(
"win_transform", "win_transform",
lambda: window_idx(vid_shape, make_window), lambda: window_idx(vid_shape, make_window),
) )
@ -674,23 +649,21 @@ class NaSwinAttention(NaMMAttention):
vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
txt_len = txt_len.to(window_count.device) txt_len = txt_len.to(window_count.device)
# window rope
if self.rope: if self.rope:
if self.version_7b: if self.version_7b:
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
elif self.rope.mm: elif self.rope.mm:
# repeat text q and k for window mmrope
_, num_h, _ = txt_q.shape _, num_h, _ = txt_q.shape
txt_q_repeat = txt_q.flatten(1, 2) txt_q_repeat = txt_q.flatten(1, 2)
txt_q_repeat = unflatten(txt_q_repeat, txt_shape) txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)] txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count_list)]
txt_q_repeat = list(chain(*txt_q_repeat)) txt_q_repeat = list(chain(*txt_q_repeat))
txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat) txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
txt_q_repeat = txt_q_repeat.reshape(txt_q_repeat.shape[0], num_h, self.head_dim) txt_q_repeat = txt_q_repeat.reshape(txt_q_repeat.shape[0], num_h, self.head_dim)
txt_k_repeat = txt_k.flatten(1, 2) txt_k_repeat = txt_k.flatten(1, 2)
txt_k_repeat = unflatten(txt_k_repeat, txt_shape) txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)] txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count_list)]
txt_k_repeat = list(chain(*txt_k_repeat)) txt_k_repeat = list(chain(*txt_k_repeat))
txt_k_repeat, _ = flatten(txt_k_repeat) txt_k_repeat, _ = flatten(txt_k_repeat)
txt_k_repeat = txt_k_repeat.reshape(txt_k_repeat.shape[0], num_h, self.head_dim) txt_k_repeat = txt_k_repeat.reshape(txt_k_repeat.shape[0], num_h, self.head_dim)
@ -702,7 +675,11 @@ class NaSwinAttention(NaMMAttention):
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) txt_len_win_list = cache_win(
"txt_len_list",
lambda: [txt_len for txt_len, window_count in zip(txt_len.tolist(), window_count_list) for _ in range(window_count)],
)
all_len_win = cache_win("all_len", lambda: [vid_len + txt_len for vid_len, txt_len in zip(vid_len_win_list, txt_len_win_list)])
concat_win, unconcat_win = cache_win( concat_win, unconcat_win = cache_win(
"mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count) "mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count)
) )
@ -711,12 +688,8 @@ class NaSwinAttention(NaMMAttention):
k=concat_win(vid_k, txt_k), k=concat_win(vid_k, txt_k),
v=concat_win(vid_v, txt_v), v=concat_win(vid_v, txt_v),
heads=self.heads, skip_reshape=True, skip_output_reshape=True, heads=self.heads, skip_reshape=True, skip_output_reshape=True,
cu_seqlens_q=cache_win( cu_seqlens_q=cache_win("vid_seqlens_q", lambda: cumulative_lengths(all_len_win)),
"vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() cu_seqlens_k=cache_win("vid_seqlens_k", lambda: cumulative_lengths(all_len_win)),
),
cu_seqlens_k=cache_win(
"vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
),
) )
vid_out, txt_out = unconcat_win(out) vid_out, txt_out = unconcat_win(out)
@ -766,11 +739,11 @@ class SwiGLUMLP(nn.Module):
return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
def get_mlp(mlp_type: Optional[str] = "normal"): def get_mlp(mlp_type: Optional[str] = "normal"):
# 3b and 7b uses different mlp types
if mlp_type == "normal": if mlp_type == "normal":
return MLP return MLP
elif mlp_type == "swiglu": if mlp_type == "swiglu":
return SwiGLUMLP return SwiGLUMLP
raise ValueError(f"Unknown SeedVR2 MLP type: {mlp_type}")
class NaMMSRTransformerBlock(nn.Module): class NaMMSRTransformerBlock(nn.Module):
def __init__( def __init__(
@ -792,11 +765,12 @@ class NaMMSRTransformerBlock(nn.Module):
rope_type: str, rope_type: str,
rope_dim: int, rope_dim: int,
is_last_layer: bool, is_last_layer: bool,
window: Union[int, Tuple[int, int, int]],
window_method: str,
version: bool,
device, dtype, operations, device, dtype, operations,
**kwargs,
): ):
super().__init__() super().__init__()
version = kwargs.get("version", False)
dim = MMArg(vid_dim, txt_dim) dim = MMArg(vid_dim, txt_dim)
self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype) self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype)
@ -811,8 +785,8 @@ class NaMMSRTransformerBlock(nn.Module):
rope_type=rope_type, rope_type=rope_type,
rope_dim=rope_dim, rope_dim=rope_dim,
shared_weights=shared_weights, shared_weights=shared_weights,
window=kwargs.pop("window", None), window=window,
window_method=kwargs.pop("window_method", None), window_method=window_method,
version=version, version=version,
device=device, dtype=dtype, operations=operations device=device, dtype=dtype, operations=operations
) )
@ -930,12 +904,14 @@ class NaPatchOut(PatchOut):
self, self,
vid: torch.FloatTensor, # l c vid: torch.FloatTensor, # l c
vid_shape: torch.LongTensor, vid_shape: torch.LongTensor,
cache: Cache = Cache(disable=True), cache: Optional[Cache] = None,
vid_shape_before_patchify = None vid_shape_before_patchify = None
) -> Tuple[ ) -> Tuple[
torch.FloatTensor, torch.FloatTensor,
torch.LongTensor, torch.LongTensor,
]: ]:
if cache is None:
cache = Cache(disable=True)
t, h, w = self.patch_size t, h, w = self.patch_size
vid = self.proj(vid) vid = self.proj(vid)
@ -971,7 +947,10 @@ class PatchIn(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
t, h, w = self.patch_size t, h, w = self.patch_size
if t > 1: if t > 1:
assert vid.size(2) % t == 1 if vid.size(2) % t != 1:
raise ValueError(
f"SeedVR2 patch input temporal size must satisfy T % {t} == 1, got {vid.size(2)}."
)
vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2)
b, c, Tt, Hh, Ww = vid.shape b, c, Tt, Hh, Ww = vid.shape
vid = vid.view(b, c, Tt // t, t, Hh // h, h, Ww // w, w).permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(b, Tt // t, Hh // h, Ww // w, t * h * w * c) vid = vid.view(b, c, Tt // t, t, Hh // h, h, Ww // w, w).permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(b, Tt // t, Hh // h, Ww // w, t * h * w * c)
@ -983,8 +962,10 @@ class NaPatchIn(PatchIn):
self, self,
vid: torch.Tensor, # l c vid: torch.Tensor, # l c
vid_shape: torch.LongTensor, vid_shape: torch.LongTensor,
cache: Cache = Cache(disable=True), cache: Optional[Cache] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if cache is None:
cache = Cache(disable=True)
cache = cache.namespace("patch") cache = cache.namespace("patch")
vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape)
t, h, w = self.patch_size t, h, w = self.patch_size
@ -1012,10 +993,11 @@ class AdaSingle(nn.Module):
dim: int, dim: int,
emb_dim: int, emb_dim: int,
layers: List[str], layers: List[str],
modes: List[str] = ["in", "out"], modes: Tuple[str, ...] = ("in", "out"),
device = None, dtype = None, device = None, dtype = None,
): ):
assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" if emb_dim != 6 * dim:
raise ValueError(f"SeedVR2 AdaSingle requires emb_dim == 6 * dim, got emb_dim={emb_dim}, dim={dim}.")
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.emb_dim = emb_dim self.emb_dim = emb_dim
@ -1036,22 +1018,20 @@ class AdaSingle(nn.Module):
emb: torch.FloatTensor, # b d emb: torch.FloatTensor, # b d
layer: str, layer: str,
mode: str, mode: str,
cache: Cache = Cache(disable=True), cache: Optional[Cache] = None,
branch_tag: str = "", branch_tag: str = "",
hid_len: Optional[torch.LongTensor] = None, # b hid_len: Optional[torch.LongTensor] = None, # b
) -> torch.FloatTensor: ) -> torch.FloatTensor:
if cache is None:
cache = Cache(disable=True)
idx = self.layers.index(layer) idx = self.layers.index(layer)
emb = emb.reshape(emb.shape[0], -1, len(self.layers), 3)[:, :, idx, :] emb = emb.reshape(emb.shape[0], -1, len(self.layers), 3)[:, :, idx, :]
emb = expand_dims(emb, 1, hid.ndim + 1) emb = expand_dims(emb, 1, hid.ndim + 1)
if hid_len is not None: if hid_len is not None:
slice_inputs = lambda x, dim: x
emb = cache( emb = cache(
f"emb_repeat_{idx}_{branch_tag}", f"emb_repeat_{idx}_{branch_tag}",
lambda: slice_inputs( lambda: torch.repeat_interleave(emb, hid_len, dim=0),
torch.repeat_interleave(emb, hid_len, dim=0),
dim=0,
),
) )
shiftA, scaleA, gateA = emb.unbind(-1) shiftA, scaleA, gateA = emb.unbind(-1)
@ -1069,7 +1049,7 @@ class AdaSingle(nn.Module):
else: else:
return hid.mul_(gateA) return hid.mul_(gateA)
raise NotImplementedError raise ValueError(f"Unknown AdaSingle mode: {mode}")
class TimeEmbedding(nn.Module): class TimeEmbedding(nn.Module):
@ -1117,7 +1097,8 @@ def flatten(
torch.FloatTensor, # (L c) torch.FloatTensor, # (L c)
torch.LongTensor, # (b n) torch.LongTensor, # (b n)
]: ]:
assert len(hid) > 0 if len(hid) == 0:
raise ValueError("SeedVR2 flatten requires at least one tensor.")
shape = torch.as_tensor([x.shape[:-1] for x in hid], device=hid[0].device) shape = torch.as_tensor([x.shape[:-1] for x in hid], device=hid[0].device)
hid = torch.cat([x.flatten(0, -2) for x in hid]) hid = torch.cat([x.flatten(0, -2) for x in hid])
return hid, shape return hid, shape
@ -1140,7 +1121,7 @@ class NaDiT(nn.Module):
num_layers, num_layers,
mlp_type, mlp_type,
vid_in_channels = 33, vid_in_channels = 33,
vid_out_channels = 16, vid_out_channels = SEEDVR2_LATENT_CHANNELS,
vid_dim = 2560, vid_dim = 2560,
txt_in_dim = 5120, txt_in_dim = 5120,
heads = 20, heads = 20,
@ -1148,15 +1129,17 @@ class NaDiT(nn.Module):
mm_layers = 10, mm_layers = 10,
expand_ratio = 4, expand_ratio = 4,
qk_bias = False, qk_bias = False,
patch_size = [ 1,2,2 ], patch_size = (1, 2, 2),
rope_dim = 128, rope_dim = 128,
rope_type = "mmrope3d", rope_type = "mmrope3d",
vid_out_norm: Optional[str] = None, vid_out_norm: Optional[str] = None,
image_model = None,
device = None, device = None,
dtype = None, dtype = None,
operations = None, operations = None,
**kwargs,
): ):
if image_model not in (None, "seedvr2"):
raise ValueError(f"SeedVR2 NaDiT expected image_model='seedvr2', got {image_model!r}.")
self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM
if self._7b_version: if self._7b_version:
rope_type = "rope3d" rope_type = "rope3d"
@ -1212,14 +1195,13 @@ class NaDiT(nn.Module):
rope_dim = rope_dim, rope_dim = rope_dim,
window=window[i], window=window[i],
window_method=window_method[i], window_method=window_method[i],
version = self._7b_version,
is_last_layer=(i == num_layers - 1) and not self._7b_version, is_last_layer=(i == num_layers - 1) and not self._7b_version,
rope_type = rope_type, rope_type = rope_type,
shared_weights=not ( shared_weights=not (
(i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i] (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
), ),
version = self._7b_version,
operations = operations, operations = operations,
**kwargs,
**factory_kwargs **factory_kwargs
) )
for i in range(num_layers) for i in range(num_layers)
@ -1272,13 +1254,17 @@ class NaDiT(nn.Module):
first = cond_or_uncond[0] first = cond_or_uncond[0]
return all(entry == first for entry in cond_or_uncond) return all(entry == first for entry in cond_or_uncond)
@staticmethod
def _check_seedvr2_video_latent(x, channels, name):
if x.ndim != 5:
raise ValueError(f"SeedVR2 expected {name} to be 5-D native latent, got shape {tuple(x.shape)}.")
if x.shape[1] != channels:
raise ValueError(f"SeedVR2 expected {name} channels to be {channels}, got shape {tuple(x.shape)}.")
return x
def _swap_pos_neg_halves(self, out, cond_or_uncond=None): def _swap_pos_neg_halves(self, out, cond_or_uncond=None):
if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond):
return out return out
# ``dim=0`` is explicit on both calls. The contract is "split
# the batch axis into two halves and swap them"; making the
# axis load-bearing in source guards against silent drift if a
# future refactor reorders tensor axes.
pos, neg = out.chunk(2, dim=0) pos, neg = out.chunk(2, dim=0)
return torch.cat([neg, pos], dim=0) return torch.cat([neg, pos], dim=0)
@ -1294,9 +1280,15 @@ class NaDiT(nn.Module):
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
conditions = kwargs.get("condition") conditions = kwargs.get("condition")
b, tc, h, w = x.shape if conditions is None:
x = x.view(b, 16, -1, h, w) raise ValueError("SeedVR2 requires conditioning latents from the SeedVR2Conditioning node.")
conditions = conditions.view(b, 17, -1, h, w) x = self._check_seedvr2_video_latent(x, SEEDVR2_LATENT_CHANNELS, "latent")
conditions = self._check_seedvr2_video_latent(conditions, SEEDVR2_LATENT_CHANNELS + 1, "conditioning")
b, _, t, h, w = x.shape
if conditions.shape[0] != b or conditions.shape[2:] != (t, h, w):
raise ValueError(
f"SeedVR2 conditioning shape must match latent batch/temporal/spatial dimensions; got latent {tuple(x.shape)} and conditioning {tuple(conditions.shape)}."
)
x = x.movedim(1, -1) x = x.movedim(1, -1)
conditions = conditions.movedim(1, -1) conditions = conditions.movedim(1, -1)
cache = Cache(disable=disable_cache) cache = Cache(disable=disable_cache)
@ -1361,7 +1353,6 @@ class NaDiT(nn.Module):
vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify) vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
vid = unflatten(vid, vid_shape) vid = unflatten(vid, vid_shape)
out = torch.stack(vid) out = torch.stack(vid)
out = out.movedim(-1, 1) out = out.movedim(-1, 1)
out = out.reshape(out.shape[0], out.shape[1] * out.shape[2], out.shape[3], out.shape[4])
return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond")) return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond"))

View File

@ -62,7 +62,6 @@ def tiled_vae(
temporal_size=16, temporal_size=16,
temporal_overlap=0, temporal_overlap=0,
encode=True, encode=True,
**kwargs,
): ):
if x.ndim != 5: if x.ndim != 5:
x = x.unsqueeze(2) x = x.unsqueeze(2)
@ -166,8 +165,8 @@ def tiled_vae(
if single_spatial_tile: if single_spatial_tile:
result = tile_out[:, :, :target_d, :target_h, :target_w] result = tile_out[:, :, :target_d, :target_h, :target_w]
if result.device != x.device: if result.device != x.device or result.dtype != x.dtype:
result = result.to(x.device).to(x.dtype) result = result.to(device=x.device, dtype=x.dtype)
if x.shape[2] == 1 and sf_t == 1: if x.shape[2] == 1 and sf_t == 1:
result = result.squeeze(2) result = result.squeeze(2)
bar.update(1) bar.update(1)
@ -221,8 +220,8 @@ def tiled_vae(
result.div_(count.clamp(min=1e-6)) result.div_(count.clamp(min=1e-6))
if result.device != x.device: if result.device != x.device or result.dtype != x.dtype:
result = result.to(x.device).to(x.dtype) result = result.to(device=x.device, dtype=x.dtype)
if x.shape[2] == 1 and sf_t == 1: if x.shape[2] == 1 and sf_t == 1:
result = result.squeeze(2) result = result.squeeze(2)
@ -256,15 +255,18 @@ class MemoryState(Enum):
UNSET = 3 UNSET = 3
def get_cache_size(conv_module, input_len, pad_len, dim=0): def get_cache_size(conv_module, input_len, pad_len, dim=0):
dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 dilated_kernel_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1
output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 output_len = (input_len + pad_len - dilated_kernel_size) // conv_module.stride[dim] + 1
remain_len = ( remain_len = (
input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size) input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernel_size)
) )
overlap_len = dilated_kernerl_size - conv_module.stride[dim] overlap_len = dilated_kernel_size - conv_module.stride[dim]
cache_len = overlap_len + remain_len # >= 0 cache_len = overlap_len + remain_len
assert output_len > 0 if output_len <= 0:
raise ValueError(
f"SeedVR2 VAE cache input is too short for convolution: input_len={input_len}, pad_len={pad_len}."
)
return cache_len return cache_len
class DiagonalGaussianDistribution(object): class DiagonalGaussianDistribution(object):
@ -294,52 +296,27 @@ class SpatialNorm(nn.Module):
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f return new_f
# partial implementation of diffusers's Attention for comfyui
class Attention(nn.Module): class Attention(nn.Module):
def __init__( def __init__(
self, self,
query_dim: int, query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8, heads: int = 8,
kv_heads: Optional[int] = None,
dim_head: int = 64, dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False, bias: bool = False,
upcast_softmax: bool = False,
norm_num_groups: Optional[int] = None, norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None, spatial_norm_dim: Optional[int] = None,
out_bias: bool = True, out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5, eps: float = 1e-5,
rescale_output_factor: float = 1.0, rescale_output_factor: float = 1.0,
residual_connection: bool = False, residual_connection: bool = False,
out_dim: int = None,
pre_only=False,
): ):
super().__init__() super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads self.inner_dim = dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
self.use_bias = bias
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection self.residual_connection = residual_connection
self.dropout = dropout self.out_dim = query_dim
self.fused_projections = False self.heads = heads
self.out_dim = out_dim if out_dim is not None else query_dim
self.pre_only = pre_only
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
self.sliceable_head_dim = heads
self.only_cross_attention = only_cross_attention
if norm_num_groups is not None: if norm_num_groups is not None:
self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
@ -351,37 +328,19 @@ class Attention(nn.Module):
else: else:
self.spatial_norm = None self.spatial_norm = None
self.norm_q = None
self.norm_k = None
self.norm_cross = None
self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias) self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = ops.Linear(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention: self.to_v = ops.Linear(query_dim, self.inner_dim, bias=bias)
# only relevant for the `AddedKVProcessor` classes self.to_out = nn.ModuleList([])
self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) self.to_out.append(nn.Identity())
else:
self.to_k = None
self.to_v = None
if not self.pre_only:
self.to_out = nn.ModuleList([])
self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
else:
self.to_out = None
self.optimized_vae_attention = vae_attention() self.optimized_vae_attention = vae_attention()
def __call__( def forward(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
@ -394,20 +353,14 @@ class Attention(nn.Module):
batch_size, channel, height, width = hidden_states.shape batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = ( batch_size = hidden_states.shape[0]
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if self.group_norm is not None: if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states) query = self.to_q(hidden_states)
key = self.to_k(hidden_states)
if encoder_hidden_states is None: value = self.to_v(hidden_states)
encoder_hidden_states = hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
inner_dim = key.shape[-1] inner_dim = key.shape[-1]
head_dim = inner_dim // self.heads head_dim = inner_dim // self.heads
@ -417,25 +370,18 @@ class Attention(nn.Module):
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
if self.norm_q is not None: if input_ndim == 4 and self.heads == 1:
query = self.norm_q(query)
if self.norm_k is not None:
key = self.norm_k(key)
if input_ndim == 4 and encoder_hidden_states is hidden_states and attention_mask is None and self.heads == 1:
query = query.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) query = query.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width)
key = key.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) key = key.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width)
value = value.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) value = value.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width)
hidden_states = self.optimized_vae_attention(query, key, value).reshape(batch_size, self.heads, head_dim, height * width).transpose(2, 3) hidden_states = self.optimized_vae_attention(query, key, value).reshape(batch_size, self.heads, head_dim, height * width).transpose(2, 3)
else: else:
hidden_states = optimized_attention(query, key, value, heads = self.heads, mask = attention_mask, skip_reshape=True, skip_output_reshape=True) hidden_states = optimized_attention(query, key, value, heads = self.heads, skip_reshape=True, skip_output_reshape=True)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
hidden_states = hidden_states.to(query.dtype) hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = self.to_out[0](hidden_states) hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states) hidden_states = self.to_out[1](hidden_states)
if input_ndim == 4: if input_ndim == 4:
@ -471,7 +417,10 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
memory_occupy = x.numel() * x.element_size() / 1024**3 memory_occupy = x.numel() * x.element_size() / 1024**3
if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit(): if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit():
num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups) num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups)
assert norm_layer.num_groups % num_chunks == 0 if norm_layer.num_groups % num_chunks != 0:
raise ValueError(
f"SeedVR2 VAE GroupNorm groups must divide chunks: groups={norm_layer.num_groups}, chunks={num_chunks}."
)
num_groups_per_chunk = norm_layer.num_groups // num_chunks num_groups_per_chunk = norm_layer.num_groups // num_chunks
x = list(x.chunk(num_chunks, dim=1)) x = list(x.chunk(num_chunks, dim=1))
@ -485,14 +434,15 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
x = norm_layer(x) x = norm_layer(x)
x = x.reshape((b, t, x.size(1), x.size(2), x.size(3))).transpose(1, 2) x = x.reshape((b, t, x.size(1), x.size(2), x.size(3))).transpose(1, 2)
return x.to(input_dtype) return x.to(input_dtype)
raise NotImplementedError raise TypeError(f"SeedVR2 VAE unsupported norm layer type: {type(norm_layer).__name__}")
_receptive_field_t = Literal["half", "full"] _receptive_field_t = Literal["half", "full"]
def extend_head(tensor, times: int = 2, memory = None): def extend_head(tensor, times: int = 2, memory = None):
if memory is not None: if memory is not None:
return torch.cat((memory.to(tensor), tensor), dim=2) return torch.cat((memory.to(tensor), tensor), dim=2)
assert times >= 0, "Invalid input for function 'extend_head'!" if times < 0:
raise ValueError(f"SeedVR2 VAE extend_head expected times >= 0, got {times}.")
if times == 0: if times == 0:
return tensor return tensor
else: else:
@ -547,13 +497,11 @@ class InflatedCausalConv3d(ops.Conv3d):
padding=(0, 0, 0, 0, 0, 0), padding=(0, 0, 0, 0, 0, 0),
prev_cache=None, prev_cache=None,
): ):
# Compatible with no limit.
if math.isinf(self.memory_limit): if math.isinf(self.memory_limit):
if prev_cache is not None: if prev_cache is not None:
x = torch.cat([prev_cache, x], dim=split_dim - 1) x = torch.cat([prev_cache, x], dim=split_dim - 1)
return super().forward(x) return super().forward(x)
# Compute tensor shape after concat & padding.
shape = list(x.size()) shape = list(x.size())
if prev_cache is not None: if prev_cache is not None:
shape[split_dim - 1] += prev_cache.size(split_dim - 1) shape[split_dim - 1] += prev_cache.size(split_dim - 1)
@ -597,16 +545,19 @@ class InflatedCausalConv3d(ops.Conv3d):
next_cache = None next_cache = None
cache_len = cache.size(split_dim) if cache is not None else 0 cache_len = cache.size(split_dim) if cache is not None else 0
next_catch_size = get_cache_size( next_cache_size = get_cache_size(
conv_module=self, conv_module=self,
input_len=x[idx].size(split_dim) + cache_len, input_len=x[idx].size(split_dim) + cache_len,
pad_len=pad_len, pad_len=pad_len,
dim=split_dim - 2, dim=split_dim - 2,
) )
if next_catch_size != 0: if next_cache_size != 0:
assert next_catch_size <= x[idx].size(split_dim) if next_cache_size > x[idx].size(split_dim):
raise ValueError(
f"SeedVR2 VAE cache size {next_cache_size} exceeds split size {x[idx].size(split_dim)}."
)
next_cache = ( next_cache = (
x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim) x[idx].transpose(0, split_dim)[-next_cache_size:].transpose(0, split_dim)
) )
x[idx] = self.memory_limit_conv( x[idx] = self.memory_limit_conv(
@ -627,7 +578,8 @@ class InflatedCausalConv3d(ops.Conv3d):
memory_state: MemoryState = MemoryState.UNSET, memory_state: MemoryState = MemoryState.UNSET,
memory_cache = None, memory_cache = None,
) -> Tensor: ) -> Tensor:
assert memory_state != MemoryState.UNSET if memory_state == MemoryState.UNSET:
raise ValueError("SeedVR2 VAE convolution requires an explicit MemoryState.")
if memory_cache is None: if memory_cache is None:
memory_cache = {} memory_cache = {}
if memory_state != MemoryState.ACTIVE: if memory_state != MemoryState.ACTIVE:
@ -677,9 +629,8 @@ class InflatedCausalConv3d(ops.Conv3d):
input, cache_size=cache_size, memory=memory, times=self.temporal_padding * 2 input, cache_size=cache_size, memory=memory, times=self.temporal_padding * 2
) )
# Single GPU inference - simplified memory management
if ( if (
memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE]
and cache_size != 0 and cache_size != 0
): ):
if cache_size > input[-1].size(2) and cache is not None and len(input) == 1: if cache_size > input[-1].size(2) and cache is not None and len(input) == 1:
@ -690,7 +641,6 @@ class InflatedCausalConv3d(ops.Conv3d):
padding = tuple(x for x in reversed(self.padding) for _ in range(2)) padding = tuple(x for x in reversed(self.padding) for _ in range(2))
for i in range(len(input)): for i in range(len(input)):
# Prepare cache for next input slice.
next_cache = None next_cache = None
cache_size = 0 cache_size = 0
if i < len(input) - 1: if i < len(input) - 1:
@ -700,17 +650,16 @@ class InflatedCausalConv3d(ops.Conv3d):
if cache_size > input[i].size(2) and cache is not None: if cache_size > input[i].size(2) and cache is not None:
input[i] = torch.cat([cache, input[i]], dim=2) input[i] = torch.cat([cache, input[i]], dim=2)
cache = None cache = None
assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}" if cache_size > input[i].size(2):
raise ValueError(f"SeedVR2 VAE cache size {cache_size} exceeds input length {input[i].size(2)}.")
next_cache = input[i][:, :, -cache_size:] next_cache = input[i][:, :, -cache_size:]
# Conv forward for this input slice.
input[i] = self.memory_limit_conv( input[i] = self.memory_limit_conv(
input[i], input[i],
padding=padding, padding=padding,
prev_cache=cache prev_cache=cache
) )
# Update cache.
cache = next_cache cache = next_cache
return input[0] if squeeze_out else input return input[0] if squeeze_out else input
@ -729,7 +678,6 @@ class Upsample3D(nn.Module):
inflation_mode = "tail", inflation_mode = "tail",
temporal_up: bool = False, temporal_up: bool = False,
spatial_up: bool = True, spatial_up: bool = True,
**kwargs,
): ):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
@ -760,9 +708,9 @@ class Upsample3D(nn.Module):
hidden_states: torch.FloatTensor, hidden_states: torch.FloatTensor,
memory_state=None, memory_state=None,
memory_cache=None, memory_cache=None,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels if hidden_states.shape[1] != self.channels:
raise ValueError(f"SeedVR2 upsample expected {self.channels} channels, got {hidden_states.shape[1]}.")
hidden_states = self.upscale_conv(hidden_states) hidden_states = self.upscale_conv(hidden_states)
b, channels, f, h, w = hidden_states.shape b, channels, f, h, w = hidden_states.shape
@ -785,8 +733,6 @@ class Upsample3D(nn.Module):
class Downsample3D(nn.Module): class Downsample3D(nn.Module):
"""A 3D downsampling layer with an optional convolution."""
def __init__( def __init__(
self, self,
channels, channels,
@ -794,7 +740,6 @@ class Downsample3D(nn.Module):
inflation_mode = "tail", inflation_mode = "tail",
spatial_down: bool = False, spatial_down: bool = False,
temporal_down: bool = False, temporal_down: bool = False,
**kwargs,
): ):
super().__init__() super().__init__()
self.channels = channels self.channels = channels
@ -823,20 +768,17 @@ class Downsample3D(nn.Module):
hidden_states: torch.FloatTensor, hidden_states: torch.FloatTensor,
memory_state = None, memory_state = None,
memory_cache = None, memory_cache = None,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels if hidden_states.shape[1] != self.channels:
raise ValueError(f"SeedVR2 downsample expected {self.channels} channels, got {hidden_states.shape[1]}.")
if hasattr(self, "norm") and self.norm is not None:
# [Overridden] change to causal norm.
hidden_states = causal_norm_wrapper(self.norm, hidden_states)
if self.spatial_down: if self.spatial_down:
pad = (0, 1, 0, 1) pad = (0, 1, 0, 1)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
assert hidden_states.shape[1] == self.channels if hidden_states.shape[1] != self.channels:
raise ValueError(f"SeedVR2 downsample expected {self.channels} channels after padding, got {hidden_states.shape[1]}.")
hidden_states = self.conv(hidden_states, memory_state=memory_state, memory_cache=memory_cache) hidden_states = self.conv(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
@ -848,7 +790,6 @@ class ResnetBlock3D(nn.Module):
self, self,
in_channels: int, in_channels: int,
out_channels: Optional[int] = None, out_channels: Optional[int] = None,
dropout: float = 0.0,
temb_channels: int = 512, temb_channels: int = 512,
groups: int = 32, groups: int = 32,
groups_out: Optional[int] = None, groups_out: Optional[int] = None,
@ -857,7 +798,6 @@ class ResnetBlock3D(nn.Module):
skip_time_act: bool = False, skip_time_act: bool = False,
inflation_mode = "tail", inflation_mode = "tail",
time_receptive_field: _receptive_field_t = "half", time_receptive_field: _receptive_field_t = "half",
**kwargs,
): ):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
@ -866,15 +806,14 @@ class ResnetBlock3D(nn.Module):
self.skip_time_act = skip_time_act self.skip_time_act = skip_time_act
self.nonlinearity = nn.SiLU() self.nonlinearity = nn.SiLU()
if temb_channels is not None: if temb_channels is not None:
self.time_emb_proj = ops.Linear(temb_channels, out_channels) self.time_emb_proj = ops.Linear(temb_channels, self.out_channels)
else: else:
self.time_emb_proj = None self.time_emb_proj = None
self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
if groups_out is None: if groups_out is None:
groups_out = groups groups_out = groups
self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=self.out_channels, eps=eps, affine=True)
self.use_in_shortcut = self.in_channels != out_channels self.use_in_shortcut = self.in_channels != self.out_channels
self.dropout = torch.nn.Dropout(dropout)
self.conv1 = InflatedCausalConv3d( self.conv1 = InflatedCausalConv3d(
self.in_channels, self.in_channels,
self.out_channels, self.out_channels,
@ -886,7 +825,7 @@ class ResnetBlock3D(nn.Module):
self.conv2 = InflatedCausalConv3d( self.conv2 = InflatedCausalConv3d(
self.out_channels, self.out_channels,
out_channels, self.out_channels,
kernel_size=3, kernel_size=3,
stride=1, stride=1,
padding=1, padding=1,
@ -897,7 +836,7 @@ class ResnetBlock3D(nn.Module):
if self.use_in_shortcut: if self.use_in_shortcut:
self.conv_shortcut = InflatedCausalConv3d( self.conv_shortcut = InflatedCausalConv3d(
self.in_channels, self.in_channels,
out_channels, self.out_channels,
kernel_size=1, kernel_size=1,
stride=1, stride=1,
padding=0, padding=0,
@ -905,9 +844,7 @@ class ResnetBlock3D(nn.Module):
inflation_mode=inflation_mode, inflation_mode=inflation_mode,
) )
def forward( def forward(self, input_tensor, temb, memory_state = None, memory_cache = None):
self, input_tensor, temb, memory_state = None, memory_cache = None, **kwargs
):
hidden_states = input_tensor hidden_states = input_tensor
hidden_states = causal_norm_wrapper(self.norm1, hidden_states) hidden_states = causal_norm_wrapper(self.norm1, hidden_states)
@ -928,7 +865,6 @@ class ResnetBlock3D(nn.Module):
hidden_states = self.nonlinearity(hidden_states) hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, memory_state=memory_state, memory_cache=memory_cache) hidden_states = self.conv2(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
if self.conv_shortcut is not None: if self.conv_shortcut is not None:
@ -944,7 +880,6 @@ class DownEncoderBlock3D(nn.Module):
self, self,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
resnet_eps: float = 1e-6, resnet_eps: float = 1e-6,
resnet_groups: int = 32, resnet_groups: int = 32,
@ -957,28 +892,23 @@ class DownEncoderBlock3D(nn.Module):
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
temporal_modules = []
for i in range(num_layers): for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels in_channels = in_channels if i == 0 else out_channels
resnets.append( resnets.append(
# [Override] Replace module.
ResnetBlock3D( ResnetBlock3D(
in_channels=in_channels, in_channels=in_channels,
out_channels=out_channels, out_channels=out_channels,
temb_channels=None, temb_channels=None,
eps=resnet_eps, eps=resnet_eps,
groups=resnet_groups, groups=resnet_groups,
dropout=dropout,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
inflation_mode=inflation_mode, inflation_mode=inflation_mode,
time_receptive_field=time_receptive_field, time_receptive_field=time_receptive_field,
) )
) )
temporal_modules.append(nn.Identity())
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
self.temporal_modules = nn.ModuleList(temporal_modules)
if add_downsample: if add_downsample:
self.downsamplers = nn.ModuleList( self.downsamplers = nn.ModuleList(
@ -1000,11 +930,9 @@ class DownEncoderBlock3D(nn.Module):
hidden_states: torch.FloatTensor, hidden_states: torch.FloatTensor,
memory_state = None, memory_state = None,
memory_cache = None, memory_cache = None,
**kwargs,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
for resnet, temporal in zip(self.resnets, self.temporal_modules): for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache) hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache)
hidden_states = temporal(hidden_states)
if self.downsamplers is not None: if self.downsamplers is not None:
for downsampler in self.downsamplers: for downsampler in self.downsamplers:
@ -1018,7 +946,6 @@ class UpDecoderBlock3D(nn.Module):
self, self,
in_channels: int, in_channels: int,
out_channels: int, out_channels: int,
dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
resnet_eps: float = 1e-6, resnet_eps: float = 1e-6,
resnet_groups: int = 32, resnet_groups: int = 32,
@ -1032,33 +959,26 @@ class UpDecoderBlock3D(nn.Module):
): ):
super().__init__() super().__init__()
resnets = [] resnets = []
temporal_modules = []
for i in range(num_layers): for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels input_channels = in_channels if i == 0 else out_channels
resnets.append( resnets.append(
# [Override] Replace module.
ResnetBlock3D( ResnetBlock3D(
in_channels=input_channels, in_channels=input_channels,
out_channels=out_channels, out_channels=out_channels,
temb_channels=temb_channels, temb_channels=temb_channels,
eps=resnet_eps, eps=resnet_eps,
groups=resnet_groups, groups=resnet_groups,
dropout=dropout,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
inflation_mode=inflation_mode, inflation_mode=inflation_mode,
time_receptive_field=time_receptive_field, time_receptive_field=time_receptive_field,
) )
) )
temporal_modules.append(nn.Identity())
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
self.temporal_modules = nn.ModuleList(temporal_modules)
if add_upsample: if add_upsample:
# [Override] Replace module & use learnable upsample
self.upsamplers = nn.ModuleList( self.upsamplers = nn.ModuleList(
[ [
Upsample3D( Upsample3D(
@ -1080,9 +1000,8 @@ class UpDecoderBlock3D(nn.Module):
memory_state=None, memory_state=None,
memory_cache=None, memory_cache=None,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
for resnet, temporal in zip(self.resnets, self.temporal_modules): for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache) hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache)
hidden_states = temporal(hidden_states)
if self.upsamplers is not None: if self.upsamplers is not None:
for upsampler in self.upsamplers: for upsampler in self.upsamplers:
@ -1096,7 +1015,6 @@ class UNetMidBlock3D(nn.Module):
self, self,
in_channels: int, in_channels: int,
temb_channels: int, temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1, num_layers: int = 1,
resnet_eps: float = 1e-6, resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial resnet_time_scale_shift: str = "default", # default, spatial
@ -1111,16 +1029,13 @@ class UNetMidBlock3D(nn.Module):
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.add_attention = add_attention self.add_attention = add_attention
# there is always at least one resnet
resnets = [ resnets = [
# [Override] Replace module.
ResnetBlock3D( ResnetBlock3D(
in_channels=in_channels, in_channels=in_channels,
out_channels=in_channels, out_channels=in_channels,
temb_channels=temb_channels, temb_channels=temb_channels,
eps=resnet_eps, eps=resnet_eps,
groups=resnet_groups, groups=resnet_groups,
dropout=dropout,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
inflation_mode=inflation_mode, inflation_mode=inflation_mode,
time_receptive_field=time_receptive_field, time_receptive_field=time_receptive_field,
@ -1148,7 +1063,6 @@ class UNetMidBlock3D(nn.Module):
), ),
residual_connection=True, residual_connection=True,
bias=True, bias=True,
upcast_softmax=True,
) )
) )
else: else:
@ -1161,7 +1075,6 @@ class UNetMidBlock3D(nn.Module):
temb_channels=temb_channels, temb_channels=temb_channels,
eps=resnet_eps, eps=resnet_eps,
groups=resnet_groups, groups=resnet_groups,
dropout=dropout,
output_scale_factor=output_scale_factor, output_scale_factor=output_scale_factor,
inflation_mode=inflation_mode, inflation_mode=inflation_mode,
time_receptive_field=time_receptive_field, time_receptive_field=time_receptive_field,
@ -1172,7 +1085,7 @@ class UNetMidBlock3D(nn.Module):
self.resnets = nn.ModuleList(resnets) self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, memory_state=None, memory_cache=None): def forward(self, hidden_states, temb=None, memory_state=None, memory_cache=None):
video_length, frame_height, frame_width = hidden_states.size()[-3:] video_length = hidden_states.size(2)
hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state, memory_cache=memory_cache) hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state, memory_cache=memory_cache)
for attn, resnet in zip(self.attentions, self.resnets[1:]): for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None: if attn is not None:
@ -1195,7 +1108,6 @@ class Encoder3D(nn.Module):
layers_per_block: int = 2, layers_per_block: int = 2,
norm_num_groups: int = 32, norm_num_groups: int = 32,
mid_block_add_attention=True, mid_block_add_attention=True,
# [Override] add temporal down num
temporal_down_num: int = 2, temporal_down_num: int = 2,
inflation_mode = "tail", inflation_mode = "tail",
time_receptive_field: _receptive_field_t = "half", time_receptive_field: _receptive_field_t = "half",
@ -1216,17 +1128,15 @@ class Encoder3D(nn.Module):
self.mid_block = None self.mid_block = None
self.down_blocks = nn.ModuleList([]) self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0] output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types): for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel input_channel = output_channel
output_channel = block_out_channels[i] output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1 is_final_block = i == len(block_out_channels) - 1
# [Override] to support temporal down block design
is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1
# Note: take the last ones
assert down_block_type == "DownEncoderBlock3D" if down_block_type != "DownEncoderBlock3D":
raise ValueError(f"SeedVR2 encoder only supports DownEncoderBlock3D, got {down_block_type}.")
down_block = DownEncoderBlock3D( down_block = DownEncoderBlock3D(
num_layers=self.layers_per_block, num_layers=self.layers_per_block,
@ -1242,7 +1152,6 @@ class Encoder3D(nn.Module):
) )
self.down_blocks.append(down_block) self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock3D( self.mid_block = UNetMidBlock3D(
in_channels=block_out_channels[-1], in_channels=block_out_channels[-1],
resnet_eps=1e-6, resnet_eps=1e-6,
@ -1256,7 +1165,6 @@ class Encoder3D(nn.Module):
time_receptive_field=time_receptive_field, time_receptive_field=time_receptive_field,
) )
# out
self.conv_norm_out = ops.GroupNorm( self.conv_norm_out = ops.GroupNorm(
num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
) )
@ -1274,17 +1182,13 @@ class Encoder3D(nn.Module):
memory_state = None, memory_state = None,
memory_cache = None, memory_cache = None,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = sample.to(next(self.parameters()).device) sample = sample.to(next(self.parameters()).device)
sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache) sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache)
# down
for down_block in self.down_blocks: for down_block in self.down_blocks:
sample = down_block(sample, memory_state=memory_state, memory_cache=memory_cache) sample = down_block(sample, memory_state=memory_state, memory_cache=memory_cache)
# middle
sample = self.mid_block(sample, memory_state=memory_state, memory_cache=memory_cache) sample = self.mid_block(sample, memory_state=memory_state, memory_cache=memory_cache)
# post-process
sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = causal_norm_wrapper(self.conv_norm_out, sample)
sample = self.conv_act(sample) sample = self.conv_act(sample)
sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache) sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache)
@ -1303,7 +1207,6 @@ class Decoder3D(nn.Module):
layers_per_block: int = 2, layers_per_block: int = 2,
norm_num_groups: int = 32, norm_num_groups: int = 32,
mid_block_add_attention=True, mid_block_add_attention=True,
# [Override] add temporal up block
inflation_mode = "tail", inflation_mode = "tail",
time_receptive_field: _receptive_field_t = "half", time_receptive_field: _receptive_field_t = "half",
temporal_up_num: int = 2, temporal_up_num: int = 2,
@ -1326,7 +1229,6 @@ class Decoder3D(nn.Module):
temb_channels = None temb_channels = None
# mid
self.mid_block = UNetMidBlock3D( self.mid_block = UNetMidBlock3D(
in_channels=block_out_channels[-1], in_channels=block_out_channels[-1],
resnet_eps=1e-6, resnet_eps=1e-6,
@ -1340,7 +1242,6 @@ class Decoder3D(nn.Module):
time_receptive_field=time_receptive_field, time_receptive_field=time_receptive_field,
) )
# up
reversed_block_out_channels = list(reversed(block_out_channels)) reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0] output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types): for i, up_block_type in enumerate(up_block_types):
@ -1349,7 +1250,8 @@ class Decoder3D(nn.Module):
is_final_block = i == len(block_out_channels) - 1 is_final_block = i == len(block_out_channels) - 1
is_temporal_up_block = i < self.temporal_up_num is_temporal_up_block = i < self.temporal_up_num
assert up_block_type == "UpDecoderBlock3D" if up_block_type != "UpDecoderBlock3D":
raise ValueError(f"SeedVR2 decoder only supports UpDecoderBlock3D, got {up_block_type}.")
up_block = UpDecoderBlock3D( up_block = UpDecoderBlock3D(
num_layers=self.layers_per_block + 1, num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel, in_channels=prev_output_channel,
@ -1365,7 +1267,6 @@ class Decoder3D(nn.Module):
self.up_blocks.append(up_block) self.up_blocks.append(up_block)
prev_output_channel = output_channel prev_output_channel = output_channel
# out
self.conv_norm_out = ops.GroupNorm( self.conv_norm_out = ops.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
) )
@ -1375,7 +1276,6 @@ class Decoder3D(nn.Module):
) )
# Note: Just copy from Decoder.
def forward( def forward(
self, self,
sample: torch.FloatTensor, sample: torch.FloatTensor,
@ -1388,15 +1288,12 @@ class Decoder3D(nn.Module):
sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache) sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
# middle
sample = self.mid_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache) sample = self.mid_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache)
sample = sample.to(upscale_dtype) sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks: for up_block in self.up_blocks:
sample = up_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache) sample = up_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache)
# post-process
sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = causal_norm_wrapper(self.conv_norm_out, sample)
sample = self.conv_act(sample) sample = self.conv_act(sample)
sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache) sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache)
@ -1415,8 +1312,6 @@ class VideoAutoencoderKL(nn.Module):
inflation_mode = "pad", inflation_mode = "pad",
time_receptive_field: _receptive_field_t = "full", time_receptive_field: _receptive_field_t = "full",
slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN, slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN,
*args,
**kwargs,
): ):
self.slicing_sample_min_size = slicing_sample_min_size self.slicing_sample_min_size = slicing_sample_min_size
self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)
@ -1425,7 +1320,6 @@ class VideoAutoencoderKL(nn.Module):
up_block_types = ("UpDecoderBlock3D",) * 4 up_block_types = ("UpDecoderBlock3D",) * 4
super().__init__() super().__init__()
# pass init params to Encoder
self.encoder = Encoder3D( self.encoder = Encoder3D(
in_channels=in_channels, in_channels=in_channels,
out_channels=latent_channels, out_channels=latent_channels,
@ -1433,13 +1327,11 @@ class VideoAutoencoderKL(nn.Module):
block_out_channels=block_out_channels, block_out_channels=block_out_channels,
layers_per_block=layers_per_block, layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups, norm_num_groups=norm_num_groups,
# [Override] add temporal_down_num parameter
temporal_down_num=temporal_scale_num, temporal_down_num=temporal_scale_num,
inflation_mode=inflation_mode, inflation_mode=inflation_mode,
time_receptive_field=time_receptive_field, time_receptive_field=time_receptive_field,
) )
# pass init params to Decoder
self.decoder = Decoder3D( self.decoder = Decoder3D(
in_channels=latent_channels, in_channels=latent_channels,
out_channels=out_channels, out_channels=out_channels,
@ -1447,7 +1339,6 @@ class VideoAutoencoderKL(nn.Module):
block_out_channels=block_out_channels, block_out_channels=block_out_channels,
layers_per_block=layers_per_block, layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups, norm_num_groups=norm_num_groups,
# [Override] add temporal_up_num parameter
temporal_up_num=temporal_scale_num, temporal_up_num=temporal_scale_num,
inflation_mode=inflation_mode, inflation_mode=inflation_mode,
time_receptive_field=time_receptive_field, time_receptive_field=time_receptive_field,
@ -1489,11 +1380,10 @@ class VideoAutoencoderKL(nn.Module):
return output.to(z.device) return output.to(z.device)
def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
sp_size =1 if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size:
if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
memory_cache = {} memory_cache = {}
split_size = max( split_size = max(
self.slicing_sample_min_size * sp_size, self.slicing_sample_min_size,
getattr(self, "temporal_downsample_factor", 1), getattr(self, "temporal_downsample_factor", 1),
) )
x_slices = list(x[:, :, 1:].split(split_size=split_size, dim=2)) x_slices = list(x[:, :, 1:].split(split_size=split_size, dim=2))
@ -1518,10 +1408,9 @@ class VideoAutoencoderKL(nn.Module):
return self._encode(x) return self._encode(x)
def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
sp_size = 1 if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size:
if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
memory_cache = {} memory_cache = {}
z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size, dim=2)
decoded_slices = [ decoded_slices = [
self._decode( self._decode(
torch.cat((z[:, :, :1], z_slices[0]), dim=2), torch.cat((z[:, :, :1], z_slices[0]), dim=2),
@ -1538,33 +1427,28 @@ class VideoAutoencoderKL(nn.Module):
else: else:
return self._decode(z) return self._decode(z)
def forward( def forward(self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all"):
self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs
):
# x: [b c t h w]
def _unwrap(value): def _unwrap(value):
return value[0] if isinstance(value, tuple) else value return value[0] if isinstance(value, tuple) else value
if mode == "encode": if mode == "encode":
return _unwrap(self.encode(x)) return _unwrap(self.encode(x))
elif mode == "decode": if mode == "decode":
return _unwrap(self.decode_(x)) return _unwrap(self.decode_(x))
else: if mode == "all":
latent = _unwrap(self.encode(x)) latent = _unwrap(self.encode(x))
return _unwrap(self.decode_(latent)) return _unwrap(self.decode_(latent))
raise ValueError(f"Unknown SeedVR2 VAE forward mode: {mode}")
class VideoAutoencoderKLWrapper(VideoAutoencoderKL): class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
def __init__( def __init__(
self, self,
*args,
spatial_downsample_factor = 8, spatial_downsample_factor = 8,
temporal_downsample_factor = 4, temporal_downsample_factor = 4,
**kwargs,
): ):
self.spatial_downsample_factor = spatial_downsample_factor self.spatial_downsample_factor = spatial_downsample_factor
self.temporal_downsample_factor = temporal_downsample_factor self.temporal_downsample_factor = temporal_downsample_factor
self.enable_tiling = False super().__init__()
super().__init__(*args, **kwargs)
self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB) self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB)
def forward(self, x: torch.FloatTensor): def forward(self, x: torch.FloatTensor):
@ -1581,7 +1465,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
z = p.squeeze(2) z = p.squeeze(2)
return z, p return z, p
def encode(self, x, orig_dims=None): def encode(self, x):
z, _ = self._encode_with_raw_latent(x) z, _ = self._encode_with_raw_latent(x)
return z return z
@ -1594,26 +1478,27 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
) )
if z.ndim == 5: if z.ndim == 5:
b, c, t_latent, h, w = z.shape _, c, _, _, _ = z.shape
if c != 16: if c != SEEDVR2_LATENT_CHANNELS:
raise RuntimeError( raise RuntimeError(
"SeedVR2 VideoAutoencoderKLWrapper.decode: 5-D latent input must " "SeedVR2 VideoAutoencoderKLWrapper.decode: 5-D latent input must "
f"have 16 channels; got shape {tuple(z.shape)}." f"have {SEEDVR2_LATENT_CHANNELS} channels; got shape {tuple(z.shape)}."
) )
latent = z latent = z
elif z.ndim == 4: elif z.ndim == 4:
b, tc, h, w = z.shape b, tc, h, w = z.shape
if tc % 16 != 0: if tc % SEEDVR2_LATENT_CHANNELS != 0:
raise RuntimeError( raise RuntimeError(
"SeedVR2 VideoAutoencoderKLWrapper.decode: 4-D latent input must " "SeedVR2 VideoAutoencoderKLWrapper.decode: 4-D latent input must "
"use collapsed channel layout (B, 16*T, H, W); " f"use collapsed channel layout (B, {SEEDVR2_LATENT_CHANNELS}*T, H, W); "
f"got shape {tuple(z.shape)}." f"got shape {tuple(z.shape)}."
) )
latent = z.reshape(b, 16, -1, h, w) latent = z.reshape(b, SEEDVR2_LATENT_CHANNELS, -1, h, w)
else: else:
raise RuntimeError( raise RuntimeError(
"SeedVR2 VideoAutoencoderKLWrapper.decode: latent input must be " "SeedVR2 VideoAutoencoderKLWrapper.decode: latent input must be "
"4-D collapsed (B, 16*T, H, W) or 5-D (B, 16, T, H, W); " f"4-D collapsed (B, {SEEDVR2_LATENT_CHANNELS}*T, H, W) or "
f"5-D (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
f"got shape {tuple(z.shape)}." f"got shape {tuple(z.shape)}."
) )
scale = BYTEDANCE_VAE_SCALING_FACTOR scale = BYTEDANCE_VAE_SCALING_FACTOR
@ -1621,10 +1506,11 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
latent = latent / scale + shift latent = latent / scale + shift
self.device = latent.device self.device = latent.device
self.enable_tiling = seedvr2_tiling.get("enable_tiling", False) enable_tiling = seedvr2_tiling.get("enable_tiling", False)
if self.enable_tiling: if enable_tiling:
decode_seedvr2_args = dict(seedvr2_tiling) decode_seedvr2_args = dict(seedvr2_tiling)
decode_seedvr2_args.pop("enable_tiling", None)
tile_h, tile_w = decode_seedvr2_args.get("tile_size", (512, 512)) tile_h, tile_w = decode_seedvr2_args.get("tile_size", (512, 512))
ov_h, ov_w = decode_seedvr2_args.get("tile_overlap", (64, 64)) ov_h, ov_w = decode_seedvr2_args.get("tile_overlap", (64, 64))
decode_seedvr2_args["tile_overlap"] = ( decode_seedvr2_args["tile_overlap"] = (
@ -1641,7 +1527,6 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
else: else:
x = super().decode_(latent) x = super().decode_(latent)
# ensure even dims for save video
h, w = x.shape[-2:] h, w = x.shape[-2:]
w2 = w - (w % 2) w2 = w - (w % 2)
h2 = h - (h % 2) h2 = h - (h % 2)
@ -1693,7 +1578,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
if samples.ndim == 4: if samples.ndim == 4:
samples = samples.unsqueeze(2) samples = samples.unsqueeze(2)
samples = samples.contiguous() samples = samples.contiguous()
samples = samples * 0.9152 samples = samples * BYTEDANCE_VAE_SCALING_FACTOR
return samples return samples
def comfy_memory_used_decode(self, shape): def comfy_memory_used_decode(self, shape):
@ -1707,15 +1592,15 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
# plus int64 sort indices dominate peak memory, not the VAE weight dtype. # plus int64 sort indices dominate peak memory, not the VAE weight dtype.
if len(shape) == 5: if len(shape) == 5:
candidates = [] candidates = []
if shape[1] == 16: if shape[1] == SEEDVR2_LATENT_CHANNELS:
candidates.append((shape[2], shape[3], shape[4])) candidates.append((shape[2], shape[3], shape[4]))
if shape[-1] == 16: if shape[-1] == SEEDVR2_LATENT_CHANNELS:
candidates.append((shape[1], shape[2], shape[3])) candidates.append((shape[1], shape[2], shape[3]))
if len(candidates) == 0: if len(candidates) == 0:
candidates.append((shape[2], shape[3], shape[4])) candidates.append((shape[2], shape[3], shape[4]))
pixels = max(output_pixels(*candidate) for candidate in candidates) pixels = max(output_pixels(*candidate) for candidate in candidates)
elif len(shape) == 4: elif len(shape) == 4:
latent_t = max(1, (shape[1] + 15) // 16) latent_t = max(1, (shape[1] + SEEDVR2_LATENT_CHANNELS - 1) // SEEDVR2_LATENT_CHANNELS)
pixels = output_pixels(latent_t, shape[2], shape[3]) pixels = output_pixels(latent_t, shape[2], shape[3])
else: else:
pixels = output_pixels(1, shape[-2], shape[-1]) pixels = output_pixels(1, shape[-2], shape[-1])

View File

@ -933,7 +933,8 @@ class HunyuanDiT(BaseModel):
class SeedVR2(BaseModel): class SeedVR2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.seedvr.model.NaDiT)
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs) out = super().extra_conds(**kwargs)
condition = kwargs.get("condition", None) condition = kwargs.get("condition", None)

View File

@ -598,43 +598,34 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config return dit_config
if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b seedvr2_7b_separate_key = "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)
if seedvr2_7b_separate_key in state_dict_keys and state_dict[seedvr2_7b_separate_key].shape[1] == 3072: # seedvr2 7b
dit_config = {} dit_config = {}
dit_config["image_model"] = "seedvr2" dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072 dit_config["vid_dim"] = 3072
dit_config["heads"] = 24 dit_config["heads"] = 24
dit_config["num_layers"] = 36 dit_config["num_layers"] = 36
# 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.`` # This checkpoint uses separate vid/txt MMModule keys in every block.
# submodules) at EVERY block — verified by inspecting the 7B
# state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means
# ``MMModule.shared_weights=False``). Native NaDiT computes
# per-block ``shared_weights = not (i < mm_layers)``, so to keep
# every block non-shared we set ``mm_layers = num_layers``.
# Without this, blocks at index >= mm_layers (default 10) try to
# load ``blocks.N.*.all.*`` keys that don't exist in the file,
# silently miss-load → all-black output.
dit_config["mm_layers"] = 36 dit_config["mm_layers"] = 36
dit_config["norm_eps"] = 1e-5 dit_config["norm_eps"] = 1e-5
dit_config["rope_type"] = "rope3d" dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64 dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "normal" dit_config["mlp_type"] = "normal"
return dit_config return dit_config
elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b if "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
dit_config = {} dit_config = {}
dit_config["image_model"] = "seedvr2" dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072 dit_config["vid_dim"] = 3072
dit_config["heads"] = 24 dit_config["heads"] = 24
dit_config["num_layers"] = 36 dit_config["num_layers"] = 36
# This checkpoint layout carries shared ``all.`` MMModule keys. # This checkpoint uses shared all.* MMModule keys after the initial blocks.
# Preserve the historical split: the initial blocks use separate
# vid/txt modules, later blocks use shared modules.
dit_config["mm_layers"] = 10 dit_config["mm_layers"] = 10
dit_config["norm_eps"] = 1e-5 dit_config["norm_eps"] = 1e-5
dit_config["rope_type"] = "rope3d" dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64 dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "swiglu" dit_config["mlp_type"] = "swiglu"
return dit_config return dit_config
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b if "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
dit_config = {} dit_config = {}
dit_config["image_model"] = "seedvr2" dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 2560 dit_config["vid_dim"] = 2560
@ -1150,8 +1141,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config["heatmap_head"] = True unet_config["heatmap_head"] = True
return unet_config return unet_config
def normalize_seedvr2_unet_config(unet_config):
if unet_config.get("image_model") != "seedvr2" or "num_heads" not in unet_config:
return unet_config
unet_config = dict(unet_config)
num_heads = unet_config.pop("num_heads")
if "heads" in unet_config and unet_config["heads"] != num_heads:
raise ValueError(
f"SeedVR2 config has conflicting heads={unet_config['heads']} and num_heads={num_heads}."
)
unet_config["heads"] = num_heads
return unet_config
def model_config_from_unet_config(unet_config, state_dict=None, unet_key_prefix=""): def model_config_from_unet_config(unet_config, state_dict=None, unet_key_prefix=""):
unet_config = normalize_seedvr2_unet_config(unet_config)
for model_config in comfy.supported_models.models: for model_config in comfy.supported_models.models:
if model_config.matches(unet_config, state_dict, unet_key_prefix=unet_key_prefix): if model_config.matches(unet_config, state_dict, unet_key_prefix=unet_key_prefix):
return model_config(unet_config) return model_config(unet_config)

View File

@ -472,8 +472,7 @@ class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd
if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
if metadata is None or metadata.get("keep_diffusers_format") != "true": sd = diffusers_convert.convert_vae_state_dict(sd)
sd = diffusers_convert.convert_vae_state_dict(sd)
if model_management.is_amd(): if model_management.is_amd():
VAE_KL_MEM_RATIO = 2.73 VAE_KL_MEM_RATIO = 2.73
@ -549,7 +548,7 @@ class VAE:
self.latent_channels = 16 self.latent_channels = 16
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
self.latent_channels = 16 self.latent_channels = comfy.ldm.seedvr.vae.SEEDVR2_LATENT_CHANNELS
self.latent_dim = 3 self.latent_dim = 3
self.disable_offload = True self.disable_offload = True
self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape) self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape)
@ -1074,6 +1073,20 @@ class VAE:
out = self.first_stage_model.encode_tiled(x, **kwargs) out = self.first_stage_model.encode_tiled(x, **kwargs)
return out.to(device=self.output_device, dtype=self.vae_output_dtype()) return out.to(device=self.output_device, dtype=self.vae_output_dtype())
def _owned_tiled_args(self, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
args = {}
if tile_x is not None:
args["tile_x"] = tile_x
if tile_y is not None:
args["tile_y"] = tile_y
if overlap is not None:
args["overlap"] = overlap
if tile_t is not None:
args["tile_t"] = tile_t
if overlap_t is not None:
args["overlap_t"] = overlap_t
return args
def decode(self, samples_in, vae_options={}): def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid() self.throw_exception_if_invalid()
pixel_samples = None pixel_samples = None
@ -1153,18 +1166,7 @@ class VAE:
with model_management.cuda_device_context(self.device): with model_management.cuda_device_context(self.device):
if self.handles_tiling and dims in (2, 3): if self.handles_tiling and dims in (2, 3):
tiled_args = {} output = self._decode_tiled_owned(samples, **self._owned_tiled_args(tile_x, tile_y, overlap, tile_t, overlap_t))
if tile_x is not None:
tiled_args["tile_x"] = tile_x
if tile_y is not None:
tiled_args["tile_y"] = tile_y
if overlap is not None:
tiled_args["overlap"] = overlap
if tile_t is not None:
tiled_args["tile_t"] = tile_t
if overlap_t is not None:
tiled_args["overlap_t"] = overlap_t
output = self._decode_tiled_owned(samples, **tiled_args)
elif dims == 1 or self.extra_1d_channel is not None: elif dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y") args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args) output = self.decode_tiled_1d(samples, **args)
@ -1269,18 +1271,7 @@ class VAE:
samples = self.encode_tiled_(pixel_samples, **args) samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3: elif dims == 3:
if self.handles_tiling: if self.handles_tiling:
tiled_args = {} samples = self._encode_tiled_owned(pixel_samples, **self._owned_tiled_args(tile_x, tile_y, overlap, tile_t, overlap_t))
if tile_x is not None:
tiled_args["tile_x"] = tile_x
if tile_y is not None:
tiled_args["tile_y"] = tile_y
if overlap is not None:
tiled_args["overlap"] = overlap
if tile_t is not None:
tiled_args["tile_t"] = tile_t
if overlap_t is not None:
tiled_args["overlap_t"] = overlap_t
samples = self._encode_tiled_owned(pixel_samples, **tiled_args)
else: else:
if tile_t is not None: if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
@ -1850,7 +1841,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae) return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)

View File

@ -1688,6 +1688,7 @@ class SeedVR2(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "seedvr2" "image_model": "seedvr2"
} }
unet_extra_config = {}
required_keys = { required_keys = {
"{}positive_conditioning", "{}positive_conditioning",
"{}negative_conditioning", "{}negative_conditioning",

View File

@ -19,21 +19,14 @@ from comfy.ldm.seedvr.constants import (
) )
from torchvision.transforms import functional as TVF from torchvision.transforms import functional as TVF
from torchvision.transforms import Lambda
from torchvision.transforms.functional import InterpolationMode from torchvision.transforms.functional import InterpolationMode
_SEEDVR2_INVALID_MODEL_MSG_PREFIX = ( _SEEDVR2_INVALID_MODEL_MSG_PREFIX = "SeedVR2Conditioning: model object does not match expected SeedVR2 structure"
"SeedVR2Conditioning: model object does not match expected SeedVR2 structure"
)
# Private sentinel for getattr default: distinguishes "attribute missing"
# from "attribute present but None" so the failure message is accurate.
_ATTR_MISSING = object() _ATTR_MISSING = object()
def _resolve_seedvr2_diffusion_model(model): def _resolve_seedvr2_diffusion_model(model):
"""Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message."""
inner = getattr(model, "model", _ATTR_MISSING) inner = getattr(model, "model", _ATTR_MISSING)
if inner is _ATTR_MISSING: if inner is _ATTR_MISSING:
raise RuntimeError( raise RuntimeError(
@ -59,15 +52,7 @@ def _resolve_seedvr2_diffusion_model(model):
return diffusion_model return diffusion_model
def get_conditions(latent, latent_blur):
t, h, w, c = latent.shape
cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
cond[:, ..., :-1] = latent_blur[:]
cond[:, ..., -1:] = 1.0
return cond
def div_pad(image, factor): def div_pad(image, factor):
height_factor, width_factor = factor height_factor, width_factor = factor
height, width = image.shape[-2:] height, width = image.shape[-2:]
@ -77,31 +62,25 @@ def div_pad(image, factor):
if pad_height == 0 and pad_width == 0: if pad_height == 0 and pad_width == 0:
return image return image
if isinstance(image, torch.Tensor): padding = (0, pad_width, 0, pad_height)
padding = (0, pad_width, 0, pad_height) return torch.nn.functional.pad(image, padding, mode='constant', value=0.0)
image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0)
return image
def cut_videos(videos): def cut_videos(videos):
t = videos.size(1) t = videos.size(1)
if t < 1:
raise ValueError("SeedVR2Preprocess expected at least one frame.")
if t == 1: if t == 1:
return videos return videos
if t <= 4 : if t <= 4:
padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1) padding = videos[:, -1:].repeat(1, 4 - t + 1, 1, 1, 1)
padding = torch.cat(padding, dim=1) return torch.cat([videos, padding], dim=1)
videos = torch.cat([videos, padding], dim=1) if (t - 1) % 4 == 0:
return videos
if (t - 1) % (4) == 0:
return videos
else:
padding = [videos[:, -1].unsqueeze(1)] * (
4 - ((t - 1) % (4))
)
padding = torch.cat(padding, dim=1)
videos = torch.cat([videos, padding], dim=1)
assert (videos.size(1) - 1) % (4) == 0
return videos return videos
padding = videos[:, -1:].repeat(1, 4 - ((t - 1) % 4), 1, 1, 1)
videos = torch.cat([videos, padding], dim=1)
if (videos.size(1) - 1) % 4 != 0:
raise ValueError(f"SeedVR2Preprocess failed to pad video length to 4n+1; got {videos.size(1)} frames.")
return videos
def _seedvr2_input_shorter_edge(images, node_name): def _seedvr2_input_shorter_edge(images, node_name):
if images.dim() == 4: if images.dim() == 4:
@ -136,8 +115,7 @@ def _seedvr2_pad(images, upscaled_shorter_edge, node_name):
b, t, c, h, w = images.shape b, t, c, h, w = images.shape
images = images.reshape(b * t, c, h, w) images = images.reshape(b * t, c, h, w)
clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) images = torch.clamp(images, 0.0, 1.0)
images = clip(images)
images = div_pad(images, (16, 16)) images = div_pad(images, (16, 16))
_, _, new_h, new_w = images.shape _, _, new_h, new_w = images.shape
@ -295,7 +273,6 @@ class SeedVR2PostProcessing(io.ComfyNode):
def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method): def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method):
chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method) chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method)
while True: while True:
next_chunk_size = None
try: try:
return cls._run_color_transfer_chunks( return cls._run_color_transfer_chunks(
decoded_flat, reference_flat, output_device, color_correction_method, chunk_size, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size,
@ -307,9 +284,7 @@ class SeedVR2PostProcessing(io.ComfyNode):
"SeedVR2PostProcessing: color correction OOM at one frame; " "SeedVR2PostProcessing: color correction OOM at one frame; "
f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}." f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}."
) from e ) from e
next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR) chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR)
chunk_size = next_chunk_size
@classmethod @classmethod
def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size): def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size):
@ -392,10 +367,8 @@ class SeedVR2Conditioning(io.ComfyNode):
io.Latent.Input("vae_conditioning", display_name="latent"), io.Latent.Input("vae_conditioning", display_name="latent"),
], ],
outputs=[ outputs=[
io.Model.Output(display_name="model", tooltip="The SeedVR2 model, passed through."),
io.Conditioning.Output(display_name="positive", tooltip="The positive conditioning for sampling."), io.Conditioning.Output(display_name="positive", tooltip="The positive conditioning for sampling."),
io.Conditioning.Output(display_name="negative", tooltip="The negative conditioning for sampling."), io.Conditioning.Output(display_name="negative", tooltip="The negative conditioning for sampling."),
io.Latent.Output(display_name="latent", tooltip="The latent to denoise."),
], ],
) )
@ -408,29 +381,30 @@ class SeedVR2Conditioning(io.ComfyNode):
"SeedVR2Conditioning expects a 5-D VAE latent in Comfy " "SeedVR2Conditioning expects a 5-D VAE latent in Comfy "
f"channel-first layout; got shape {tuple(vae_conditioning.shape)}." f"channel-first layout; got shape {tuple(vae_conditioning.shape)}."
) )
if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS: if vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS:
if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS:
raise ValueError(
"SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy "
f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
f"got channel-last shape {tuple(vae_conditioning.shape)}."
)
raise ValueError( raise ValueError(
"SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy " "SeedVR2Conditioning expects SeedVR2 VAE latents with "
f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); " f"{SEEDVR2_LATENT_CHANNELS} channels; got shape {tuple(vae_conditioning.shape)}."
f"got channel-last shape {tuple(vae_conditioning.shape)}."
) )
vae_conditioning = vae_conditioning.movedim(1, -1).contiguous() vae_conditioning = vae_conditioning.movedim(1, -1).contiguous()
model_patcher = model model = _resolve_seedvr2_diffusion_model(model)
model = _resolve_seedvr2_diffusion_model(model_patcher)
pos_cond = model.positive_conditioning pos_cond = model.positive_conditioning
neg_cond = model.negative_conditioning neg_cond = model.negative_conditioning
condition = torch.stack([get_conditions(c, c) for c in vae_conditioning]) mask = vae_conditioning.new_ones(vae_conditioning.shape[:-1] + (1,))
condition = torch.cat((vae_conditioning, mask), dim=-1)
condition = condition.movedim(-1, 1) condition = condition.movedim(-1, 1)
latent = vae_conditioning.movedim(-1, 1)
latent = latent.reshape(latent.shape[0], latent.shape[1] * latent.shape[2], latent.shape[3], latent.shape[4])
condition = condition.reshape(condition.shape[0], condition.shape[1] * condition.shape[2], condition.shape[3], condition.shape[4])
negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
return io.NodeOutput(model_patcher, positive, negative, {"samples": latent}) return io.NodeOutput(positive, negative)
class SeedVRExtension(ComfyExtension): class SeedVRExtension(ComfyExtension):
@override @override

View File

@ -1,20 +1,15 @@
"""Consolidated SeedVR2 conditioning and refactor regression tests. """SeedVR2 conditioning node regression tests."""
Merges the prior test_seedvr2_refactor_nodes.py and
test_seedvr_conditioning_hardening.py modules. Refactor tests use the
top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests
use _import_nodes_seedvr_isolated() for sys.modules isolation when
mocking comfy.model_management.
"""
import importlib import importlib
import sys import sys
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from comfy.cli_args import args as cli_args from comfy.cli_args import args as cli_args
from comfy.ldm.seedvr.constants import SEEDVR2_LATENT_CHANNELS
if not torch.cuda.is_available(): if not torch.cuda.is_available():
cli_args.cpu = True cli_args.cpu = True
@ -79,21 +74,18 @@ def _import_nodes_seedvr_isolated():
class _Rope(nn.Module): class _Rope(nn.Module):
"""Minimal RoPE stub exposing a `freqs` parameter."""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.freqs = nn.Parameter(torch.zeros(4)) self.freqs = nn.Parameter(torch.zeros(4))
class _Block(nn.Module): class _Block(nn.Module):
"""Minimal transformer block stub holding a `_Rope`."""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.rope = _Rope() self.rope = _Rope()
class _DiffusionModel(nn.Module): class _DiffusionModel(nn.Module):
"""Stub diffusion model with N blocks and pos/neg conditioning buffers."""
def __init__(self, n_blocks=3, conditioning_dtype=torch.float32): def __init__(self, n_blocks=3, conditioning_dtype=torch.float32):
super().__init__() super().__init__()
self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)]) self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])
@ -102,18 +94,16 @@ class _DiffusionModel(nn.Module):
class _ModelInner: class _ModelInner:
"""Inner model wrapper exposing `.diffusion_model`."""
def __init__(self, diffusion_model): def __init__(self, diffusion_model):
self.diffusion_model = diffusion_model self.diffusion_model = diffusion_model
class _ModelPatcher: class _ModelPatcher:
"""ModelPatcher stub exposing `.model._ModelInner`."""
def __init__(self, diffusion_model): def __init__(self, diffusion_model):
self.model = _ModelInner(diffusion_model) self.model = _ModelInner(diffusion_model)
def test_seedvr2_conditioning_schema_exposes_model_passthrough_output(): def test_seedvr2_conditioning_schema_exposes_conditioning_outputs():
nodes_seedvr, restore = _import_nodes_seedvr_isolated() nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try: try:
schema = nodes_seedvr.SeedVR2Conditioning.define_schema() schema = nodes_seedvr.SeedVR2Conditioning.define_schema()
@ -123,37 +113,50 @@ def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
] ]
assert schema.inputs[1].display_name == "latent" assert schema.inputs[1].display_name == "latent"
assert [output.display_name for output in schema.outputs] == [ assert [output.display_name for output in schema.outputs] == [
"model",
"positive", "positive",
"negative", "negative",
"latent",
] ]
finally: finally:
restore() restore()
def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): def test_seedvr2_conditioning_rejects_wrong_latent_channels():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
patcher = _ModelPatcher(_DiffusionModel())
vae_conditioning = {"samples": torch.zeros(1, 8, 2, 2, 2)}
with pytest.raises(ValueError, match=f"{SEEDVR2_LATENT_CHANNELS} channels"):
nodes_seedvr.SeedVR2Conditioning.execute(patcher, vae_conditioning)
finally:
restore()
def test_seedvr2_conditioning_returns_conditioning_deterministically():
nodes_seedvr, restore = _import_nodes_seedvr_isolated() nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try: try:
diffusion_model = _DiffusionModel() diffusion_model = _DiffusionModel()
patcher = _ModelPatcher(diffusion_model) patcher = _ModelPatcher(diffusion_model)
samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2) samples = torch.arange(
1,
1 + SEEDVR2_LATENT_CHANNELS * 3 * 2 * 2,
dtype=torch.float32,
).reshape(1, SEEDVR2_LATENT_CHANNELS, 3, 2, 2)
vae_conditioning = {"samples": samples} vae_conditioning = {"samples": samples}
_, first_positive, first_negative, first_latent = ( first_positive, first_negative = (
nodes_seedvr.SeedVR2Conditioning.execute( nodes_seedvr.SeedVR2Conditioning.execute(
patcher, patcher,
vae_conditioning, vae_conditioning,
) )
) )
_, second_positive, second_negative, second_latent = ( second_positive, second_negative = (
nodes_seedvr.SeedVR2Conditioning.execute( nodes_seedvr.SeedVR2Conditioning.execute(
patcher, patcher,
vae_conditioning, vae_conditioning,
) )
) )
expected_latent = samples.reshape(1, 6, 2, 2)
channel_last = samples.movedim(1, -1).contiguous() channel_last = samples.movedim(1, -1).contiguous()
expected_condition = torch.cat( expected_condition = torch.cat(
[ [
@ -161,10 +164,8 @@ def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
torch.ones((*channel_last.shape[:-1], 1)), torch.ones((*channel_last.shape[:-1], 1)),
], ],
dim=-1, dim=-1,
).movedim(-1, 1).reshape(1, 9, 2, 2) ).movedim(-1, 1)
assert torch.equal(first_latent["samples"], expected_latent)
assert torch.equal(second_latent["samples"], expected_latent)
assert torch.equal( assert torch.equal(
first_positive[0][1]["condition"], first_positive[0][1]["condition"],
expected_condition, expected_condition,

View File

@ -201,6 +201,17 @@ class TestModelDetection:
del sd["positive_conditioning"] del sd["positive_conditioning"]
assert model_config_from_unet_config(unet_config, sd) is None assert model_config_from_unet_config(unet_config, sd) is None
def test_seedvr2_model_match_normalizes_num_heads(self):
sd = _make_seedvr2_7b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
unet_config["num_heads"] = unet_config.pop("heads")
model_config = model_config_from_unet_config(unet_config, sd)
assert type(model_config).__name__ == "SeedVR2"
assert model_config.unet_config["heads"] == 24
assert "num_heads" not in model_config.unet_config
def test_seedvr2_model_match_accepts_full_checkpoint_prefix(self): def test_seedvr2_model_match_accepts_full_checkpoint_prefix(self):
sd = _add_model_diffusion_prefix(_make_seedvr2_7b_shared_mm_sd()) sd = _add_model_diffusion_prefix(_make_seedvr2_7b_shared_mm_sd())

View File

@ -1,22 +1,6 @@
"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must """Regression tests for the SeedVR2 VAE forward return contract."""
honor the actual tensor/tuple return contract of ``encode()`` and
``decode_()`` and must NOT dereference diffusers-style ``.latent_dist``
or ``.sample`` attributes on those returns.
The pre-fix body raised ``AttributeError: 'Tensor' object has no
attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and
``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'``
for ``mode == "decode"`` (the class only defines ``decode_`` with a
trailing underscore). The post-fix body unwraps the optional one-element
tuple shape that ``return_dict=False`` produces and returns the tensor
directly.
Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses
the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and
overrides ``encode``/``decode_`` with known tensors so the contract can
be probed without loading any real VAE weights.
"""
import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -25,13 +9,13 @@ from comfy.cli_args import args as cli_args
if not torch.cuda.is_available(): if not torch.cuda.is_available():
cli_args.cpu = True cli_args.cpu = True
from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402 from comfy.ldm.seedvr.vae import SEEDVR2_LATENT_CHANNELS, VideoAutoencoderKL # noqa: E402
_LATENT_SHAPE = (1, 16, 2, 2, 2) _LATENT_SHAPE = (1, SEEDVR2_LATENT_CHANNELS, 2, 2, 2)
_DECODED_SHAPE = (1, 3, 5, 16, 16) _DECODED_SHAPE = (1, 3, 5, 16, 16)
_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16) _INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16)
_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2) _INPUT_DECODE_SHAPE = _LATENT_SHAPE
class _StubVAE(VideoAutoencoderKL): class _StubVAE(VideoAutoencoderKL):
@ -64,8 +48,6 @@ def test_forward_decode_returns_tensor():
class _TupleReturningStubVAE(VideoAutoencoderKL): class _TupleReturningStubVAE(VideoAutoencoderKL):
"""Stub whose ``encode``/``decode_`` return the ``(tensor,)`` tuple of ``return_dict=False``, exercising the unwrap branch of ``VideoAutoencoderKL.forward``."""
def __init__(self): def __init__(self):
nn.Module.__init__(self) nn.Module.__init__(self)
self._encode_tensor = torch.zeros(*_LATENT_SHAPE) self._encode_tensor = torch.zeros(*_LATENT_SHAPE)
@ -84,3 +66,9 @@ def test_forward_all_unwraps_one_tuple_at_each_step():
result = vae.forward(x, mode="all") result = vae.forward(x, mode="all")
assert type(result) is torch.Tensor assert type(result) is torch.Tensor
assert result.shape == torch.Size(_DECODED_SHAPE) assert result.shape == torch.Size(_DECODED_SHAPE)
def test_forward_rejects_unknown_mode():
vae = _StubVAE()
with pytest.raises(ValueError, match="Unknown SeedVR2 VAE forward mode"):
vae.forward(torch.zeros(*_INPUT_ENCODE_SHAPE), mode="bogus")

View File

@ -41,8 +41,9 @@ def test_seedvr2_text_conditioning_accepts_cfg1_single_branch():
def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer(): def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer():
wrapper = seedvr_vae.VideoAutoencoderKLWrapper.__new__(seedvr_vae.VideoAutoencoderKLWrapper) wrapper = seedvr_vae.VideoAutoencoderKLWrapper.__new__(seedvr_vae.VideoAutoencoderKLWrapper)
estimate = wrapper.comfy_memory_used_decode((1, 16, 26, 120, 160)) latent_channels = seedvr_vae.SEEDVR2_LATENT_CHANNELS
old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2 estimate = wrapper.comfy_memory_used_decode((1, latent_channels, 26, 120, 160))
old_estimate = latent_channels * 120 * 160 * (4 * 8 * 8) * 2
assert estimate == 101 * 960 * 1280 * 160 assert estimate == 101 * 960 * 1280 * 160
assert estimate > 15 * 1024 ** 3 assert estimate > 15 * 1024 ** 3

View File

@ -1,16 +1,4 @@
"""Consolidated SeedVR2 internals regression tests. """SeedVR2 internals regression tests."""
Sources (all merged verbatim, helper names disambiguated where colliding):
* GroupNorm limit gate causal_norm_wrapper at vae.py:509 must compare
memory_occupy against get_norm_limit(), not float('inf').
* SeedVR2 variable-length attention split-loop contract.
Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
comfy.ldm.modules.attention transitively pull in comfy.model_management,
which probes torch.cuda.current_device() at import time unless args.cpu is
set first.
"""
from __future__ import annotations from __future__ import annotations
@ -35,10 +23,6 @@ from comfy.ldm.seedvr.vae import ( # noqa: E402
from comfy.ldm.seedvr.attention import var_attention_optimized_split # noqa: E402 from comfy.ldm.seedvr.attention import var_attention_optimized_split # noqa: E402
# ---------------------------------------------------------------------------
# GroupNorm limit tests (test_seedvr_groupnorm_limit.py)
# ---------------------------------------------------------------------------
_NUM_CHANNELS = 8 _NUM_CHANNELS = 8
_NUM_GROUPS = 4 _NUM_GROUPS = 4
_TENSOR_SHAPE = (1, 8, 2, 4, 4) _TENSOR_SHAPE = (1, 8, 2, 4, 4)
@ -89,10 +73,6 @@ def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
set_norm_limit(None) set_norm_limit(None)
# ---------------------------------------------------------------------------
# SeedVR2 var_attention split-loop tests
# ---------------------------------------------------------------------------
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch): def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
dim = 8 dim = 8
heads = 2 heads = 2
@ -140,18 +120,8 @@ def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypa
assert call["heads"] == heads assert call["heads"] == heads
assert call["skip_reshape"] is True assert call["skip_reshape"] is True
assert call["skip_output_reshape"] is True assert call["skip_output_reshape"] is True
torch.testing.assert_close( assert call["cu_seqlens_q"] == [0, 7, 14]
call["cu_seqlens_q"], assert call["cu_seqlens_k"] == [0, 7, 14]
torch.tensor([0, 7, 14], dtype=torch.int32),
rtol=0,
atol=0,
)
torch.testing.assert_close(
call["cu_seqlens_k"],
torch.tensor([0, 7, 14], dtype=torch.int32),
rtol=0,
atol=0,
)
def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch): def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch):
@ -160,7 +130,7 @@ def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatc
q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim) q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim)
k = q + 100 k = q + 100
v = q + 200 v = q + 200
cu = torch.tensor([0, 2, 5], dtype=torch.int32) cu = [0, 2, 5]
calls = [] calls = []
def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs): def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs):
@ -197,20 +167,3 @@ def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatc
assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls) assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls)
torch.testing.assert_close(out, q + v, rtol=0, atol=0) torch.testing.assert_close(out, q + v, rtol=0, atol=0)
def test_var_attention_optimized_split_rejects_bad_offsets():
q = torch.randn(5, 2, 3)
cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32)
cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32)
with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"):
var_attention_optimized_split(
q,
q,
q,
2,
cu_bad,
cu_ok,
skip_reshape=True,
skip_output_reshape=True,
)

View File

@ -1,17 +1,10 @@
"""Consolidated SeedVR2 model/graph/forward regression tests. """SeedVR2 model, latent-format, and VAE graph regression tests."""
Merged from:
- seedvr_model_test.py
- test_seedvr_7b_final_block_text_path.py
- test_seedvr_forward_no_device_cast.py
- test_seedvr_latent_format.py
- test_seedvr2_vae_graph_boundaries.py
"""
from __future__ import annotations from __future__ import annotations
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest
import torch import torch
from torch import nn from torch import nn
@ -22,7 +15,6 @@ if not torch.cuda.is_available():
import comfy # noqa: E402 import comfy # noqa: E402
import comfy.latent_formats # noqa: E402 import comfy.latent_formats # noqa: E402
import comfy.ldm.seedvr.model # noqa: E402
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.model_management # noqa: E402 import comfy.model_management # noqa: E402
@ -33,9 +25,7 @@ import nodes as nodes_mod # noqa: E402
from comfy.ldm.seedvr.model import NaDiT # noqa: E402 from comfy.ldm.seedvr.model import NaDiT # noqa: E402
# --------------------------------------------------------------------------- _LATENT_CHANNELS = seedvr_vae_mod.SEEDVR2_LATENT_CHANNELS
# Helpers from seedvr_model_test.py
# ---------------------------------------------------------------------------
def _make_standin(positive_conditioning): def _make_standin(positive_conditioning):
@ -51,11 +41,6 @@ def _make_standin(positive_conditioning):
return _StandIn() return _StandIn()
# ---------------------------------------------------------------------------
# Helpers from test_seedvr_7b_final_block_text_path.py
# ---------------------------------------------------------------------------
class _StubModule(nn.Module): class _StubModule(nn.Module):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__() super().__init__()
@ -88,11 +73,6 @@ def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> lis
return flags return flags
# ---------------------------------------------------------------------------
# Helpers from test_seedvr_latent_format.py
# ---------------------------------------------------------------------------
class _Model: class _Model:
def __init__(self, latent_format): def __init__(self, latent_format):
self._latent_format = latent_format self._latent_format = latent_format
@ -102,11 +82,6 @@ class _Model:
return self._latent_format return self._latent_format
# ---------------------------------------------------------------------------
# Helpers from test_seedvr2_vae_graph_boundaries.py
# ---------------------------------------------------------------------------
class _Patcher: class _Patcher:
def get_free_memory(self, device): def get_free_memory(self, device):
return 1024 * 1024 * 1024 return 1024 * 1024 * 1024
@ -136,14 +111,14 @@ class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling}) self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
if z.ndim == 4: if z.ndim == 4:
b, tc, h, w = z.shape b, tc, h, w = z.shape
t = tc // 16 t = tc // _LATENT_CHANNELS
else: else:
b, _, t, h, w = z.shape b, _, t, h, w = z.shape
return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device) return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch): def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch):
raw_latent = torch.full((1, 16, 1, 4, 5), 2.0) raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 2.0)
seen_shapes = [] seen_shapes = []
def base_encode(self, x): def base_encode(self, x):
@ -159,12 +134,12 @@ def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch):
latent = vae.encode(torch.zeros(1, 3, 32, 40)) latent = vae.encode(torch.zeros(1, 3, 32, 40))
assert type(latent) is torch.Tensor assert type(latent) is torch.Tensor
assert tuple(latent.shape) == (1, 16, 4, 5) assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5)
assert seen_shapes == [(1, 3, 1, 32, 40)] assert seen_shapes == [(1, 3, 1, 32, 40)]
def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch): def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch):
raw_latent = torch.full((1, 16, 1, 4, 5), 3.0) raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 3.0)
def base_encode(self, x): def base_encode(self, x):
return raw_latent.to(device=x.device, dtype=x.dtype) return raw_latent.to(device=x.device, dtype=x.dtype)
@ -177,8 +152,8 @@ def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch):
latent, raw = vae._encode_with_raw_latent(torch.zeros(1, 3, 32, 40)) latent, raw = vae._encode_with_raw_latent(torch.zeros(1, 3, 32, 40))
assert tuple(latent.shape) == (1, 16, 4, 5) assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5)
assert tuple(raw.shape) == (1, 16, 1, 4, 5) assert tuple(raw.shape) == (1, _LATENT_CHANNELS, 1, 4, 5)
assert torch.equal(raw, raw_latent) assert torch.equal(raw, raw_latent)
@ -188,7 +163,7 @@ def _make_vae(wrapper):
vae.device = torch.device("cpu") vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu") vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32 vae.vae_dtype = torch.float32
vae.latent_channels = 16 vae.latent_channels = _LATENT_CHANNELS
vae.latent_dim = 3 vae.latent_dim = 3
vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8) vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8)
vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
@ -212,13 +187,7 @@ def _make_vae(wrapper):
return vae return vae
# ---------------------------------------------------------------------------
# Tests from seedvr_model_test.py
# ---------------------------------------------------------------------------
def test_missing_context_falls_back_to_positive_buffer(): def test_missing_context_falls_back_to_positive_buffer():
"""``context is None`` falls back to the registered ``positive_conditioning`` buffer and runs to completion."""
pos_buffer = torch.full((58, 5120), 7.0) pos_buffer = torch.full((58, 5120), 7.0)
standin = _make_standin(pos_buffer) standin = _make_standin(pos_buffer)
txt, txt_shape = standin._resolve_text_conditioning(None) txt, txt_shape = standin._resolve_text_conditioning(None)
@ -231,11 +200,6 @@ def test_missing_context_falls_back_to_positive_buffer():
assert txt_shape[0, 0].item() == 58 assert txt_shape[0, 0].item() == 58
# ---------------------------------------------------------------------------
# Tests from test_seedvr_7b_final_block_text_path.py
# ---------------------------------------------------------------------------
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch): def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [ assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [
False, False,
@ -268,43 +232,49 @@ def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0) torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0)
# --------------------------------------------------------------------------- def test_seedvr2_forward_requires_conditioning_latents():
# Tests from test_seedvr_latent_format.py model = NaDiT.__new__(NaDiT)
# --------------------------------------------------------------------------- x = torch.zeros(1, _LATENT_CHANNELS, 1, 4, 5)
with pytest.raises(ValueError, match="requires conditioning latents"):
NaDiT.forward(model, x, timestep=torch.tensor([1.0]), context=None)
def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion(): def test_seedvr2_latent_format_uses_native_video_latent_shape():
latent_format = comfy.latent_formats.SeedVR2() latent_format = comfy.latent_formats.SeedVR2()
latent_image = torch.zeros(1, 1, 4, 5) latent_image = torch.zeros(1, 1, 4, 5)
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image) fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
assert latent_format.latent_channels == 16 assert latent_format.latent_channels == _LATENT_CHANNELS
assert latent_format.latent_dimensions == 2 assert latent_format.latent_dimensions == 3
assert fixed.shape == (1, 16, 4, 5) assert fixed.shape == (1, _LATENT_CHANNELS, 1, 4, 5)
# --------------------------------------------------------------------------- def test_seedvr2_model_requires_native_5d_latent():
# Tests from test_seedvr2_vae_graph_boundaries.py latent = torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)
# --------------------------------------------------------------------------- assert NaDiT._check_seedvr2_video_latent(latent, _LATENT_CHANNELS, "latent") is latent
with pytest.raises(ValueError, match="5-D native latent"):
NaDiT._check_seedvr2_video_latent(torch.zeros(1, _LATENT_CHANNELS * 2, 4, 5), _LATENT_CHANNELS, "latent")
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch): def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
encoded = torch.full((1, 16, 2, 4, 5), 2.0) encoded = torch.full((1, _LATENT_CHANNELS, 2, 4, 5), 2.0)
vae = _make_vae(_EncodeWrapper(encoded)) vae = _make_vae(_EncodeWrapper(encoded))
pixels = torch.zeros(1, 5, 32, 40, 3) pixels = torch.zeros(1, 5, 32, 40, 3)
node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0] node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
node_latent = node_output["samples"] node_latent = node_output["samples"]
assert set(node_output) == {"samples"} assert set(node_output) == {"samples"}
assert tuple(node_latent.shape) == (1, 16, 2, 4, 5) assert tuple(node_latent.shape) == (1, _LATENT_CHANNELS, 2, 4, 5)
assert node_latent.dtype == torch.float32 assert node_latent.dtype == torch.float32
assert node_latent.stride()[-1] == 1 assert node_latent.stride()[-1] == 1
assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152)) assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR))
tiled = torch.full((1, 16, 2, 4, 5), 3.0) tiled = torch.full((1, _LATENT_CHANNELS, 2, 4, 5), 3.0)
monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled)) monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
tiled_output = nodes_mod.VAEEncodeTiled().encode( tiled_output = nodes_mod.VAEEncodeTiled().encode(
vae, vae,
@ -316,9 +286,9 @@ def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeyp
)[0] )[0]
tiled_latent = tiled_output["samples"] tiled_latent = tiled_output["samples"]
assert set(tiled_output) == {"samples"} assert set(tiled_output) == {"samples"}
assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5) assert tuple(tiled_latent.shape) == (1, _LATENT_CHANNELS, 2, 4, 5)
assert tiled_latent.dtype == torch.float32 assert tiled_latent.dtype == torch.float32
assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152)) assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR))
def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch): def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch):
@ -327,7 +297,7 @@ def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch):
nodes_mod.VAEDecodeTiled().decode( nodes_mod.VAEDecodeTiled().decode(
vae, vae,
{"samples": torch.zeros(1, 16, 2, 4, 5)}, {"samples": torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)},
tile_size=512, tile_size=512,
overlap=64, overlap=64,
temporal_size=16, temporal_size=16,
@ -339,7 +309,7 @@ def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch):
# knobs are no-ops at the wrapper. # knobs are no-ops at the wrapper.
assert vae.first_stage_model.calls == [ assert vae.first_stage_model.calls == [
{ {
"shape": (1, 16, 2, 4, 5), "shape": (1, _LATENT_CHANNELS, 2, 4, 5),
"seedvr2_tiling": { "seedvr2_tiling": {
"enable_tiling": True, "enable_tiling": True,
"tile_size": (512, 512), "tile_size": (512, 512),

View File

@ -13,6 +13,9 @@ import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
from comfy_extras import nodes_seedvr # noqa: E402 from comfy_extras import nodes_seedvr # noqa: E402
_LATENT_CHANNELS = vae_mod.SEEDVR2_LATENT_CHANNELS
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper: def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
vae_mod.VideoAutoencoderKLWrapper vae_mod.VideoAutoencoderKLWrapper
@ -40,7 +43,7 @@ def _decode_with_patches(wrapper, z):
def test_decode_b2_t3_multi_frame_batch_unchanged(): def test_decode_b2_t3_multi_frame_batch_unchanged():
wrapper = _make_wrapper() wrapper = _make_wrapper()
out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2)) out = _decode_with_patches(wrapper, torch.zeros(2, _LATENT_CHANNELS * 3, 2, 2))
assert tuple(out.shape) == (2, 3, 3, 16, 16) assert tuple(out.shape) == (2, 3, 3, 16, 16)
@ -62,17 +65,17 @@ def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preproc
wrapper = _Wrapper() wrapper = _Wrapper()
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub): with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5)) out = wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5))
assert tuple(out.shape) == (1, 3, 2, 32, 40) assert tuple(out.shape) == (1, 3, 2, 32, 40)
assert wrapper.calls == [(1, 16, 2, 4, 5)] assert wrapper.calls == [(1, _LATENT_CHANNELS, 2, 4, 5)]
def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents(): def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents():
wrapper = _Wrapper() wrapper = _Wrapper()
with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"): with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"):
wrapper.decode(torch.zeros(1, 16, 4)) wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 4))
def _t_padded(t_in: int) -> int: def _t_padded(t_in: int) -> int:

View File

@ -16,9 +16,7 @@ import comfy.sd as sd_mod # noqa: E402
from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402 from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402
# --------------------------------------------------------------------------- _LATENT_CHANNELS = seedvr_vae_mod.SEEDVR2_LATENT_CHANNELS
# From test_seedvr_vae_tiled_decode_latent_min_size_override.py
# ---------------------------------------------------------------------------
def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
@ -44,7 +42,7 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype) return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype)
vae = StubVAEModel() vae = StubVAEModel()
z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32) z = torch.zeros((1, _LATENT_CHANNELS, 5, 8, 8), dtype=torch.float32)
tiled_vae( tiled_vae(
z, z,
@ -61,11 +59,6 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
assert vae.slicing_latent_min_size == 2 assert vae.slicing_latent_min_size == 2
# ---------------------------------------------------------------------------
# From test_seedvr_vae_tiled_encode_runt_slice_override.py
# ---------------------------------------------------------------------------
def test_zero_temporal_size_preserves_min_size_when_encode_raises(): def test_zero_temporal_size_preserves_min_size_when_encode_raises():
class RaisingVAEModel(torch.nn.Module): class RaisingVAEModel(torch.nn.Module):
def __init__(self): def __init__(self):
@ -110,7 +103,7 @@ def test_tiled_vae_encode_uses_tensor_return_without_indexing():
def encode(self, t_chunk): def encode(self, t_chunk):
self.calls.append(tuple(t_chunk.shape)) self.calls.append(tuple(t_chunk.shape))
b, _, _, h, w = t_chunk.shape b, _, _, h, w = t_chunk.shape
return torch.ones((b, 16, 1, h // 8, w // 8), dtype=t_chunk.dtype) return torch.ones((b, _LATENT_CHANNELS, 1, h // 8, w // 8), dtype=t_chunk.dtype)
vae = TensorEncodeVAEModel() vae = TensorEncodeVAEModel()
x = torch.zeros((2, 3, 1, 64, 64), dtype=torch.float32) x = torch.zeros((2, 3, 1, 64, 64), dtype=torch.float32)
@ -126,12 +119,34 @@ def test_tiled_vae_encode_uses_tensor_return_without_indexing():
) )
assert vae.calls == [(2, 3, 1, 64, 64)] assert vae.calls == [(2, 3, 1, 64, 64)]
assert tuple(out.shape) == (2, 16, 1, 8, 8) assert tuple(out.shape) == (2, _LATENT_CHANNELS, 1, 8, 8)
# --------------------------------------------------------------------------- def test_tiled_vae_preserves_input_dtype_on_single_tile():
# From test_seedvr_vae_tiled_temporal_slicing.py class FloatOutputVAEModel(torch.nn.Module):
# --------------------------------------------------------------------------- def __init__(self):
super().__init__()
self.slicing_sample_min_size = 4
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
def encode(self, t_chunk):
b, _, _, h, w = t_chunk.shape
return torch.ones((b, _LATENT_CHANNELS, 1, h // 8, w // 8), dtype=torch.float32)
out = tiled_vae(
torch.zeros((1, 3, 1, 64, 64), dtype=torch.float16),
FloatOutputVAEModel(),
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=True,
)
assert out.dtype == torch.float16
class _SlicingDecodeVAE(nn.Module): class _SlicingDecodeVAE(nn.Module):
@ -164,7 +179,10 @@ class _SlicingDecodeVAE(nn.Module):
def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size(): def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
vae = _SlicingDecodeVAE(slicing_latent_min_size=2) vae = _SlicingDecodeVAE(slicing_latent_min_size=2)
z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8) z = torch.arange(
_LATENT_CHANNELS * 5 * 8 * 8,
dtype=torch.float32,
).reshape(1, _LATENT_CHANNELS, 5, 8, 8)
tiled_vae( tiled_vae(
z, z,
@ -199,16 +217,11 @@ def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
return torch.zeros(1, 3, 1, 16, 16) return torch.zeros(1, 3, 1, 16, 16)
with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae): with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae):
wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling) wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 2, 2), seedvr2_tiling=seedvr2_tiling)
assert captured["temporal_overlap"] == 7 assert captured["temporal_overlap"] == 7
# ---------------------------------------------------------------------------
# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py
# ---------------------------------------------------------------------------
def _force_oom(*a, **k): def _force_oom(*a, **k):
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
@ -256,10 +269,10 @@ def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode):
def test_4d_seedvr2_latent_routes_to_owned_decode_tiled(): def test_4d_seedvr2_latent_routes_to_owned_decode_tiled():
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
seedvr_vae_mod.VideoAutoencoderKLWrapper) seedvr_vae_mod.VideoAutoencoderKLWrapper)
vae = _make_vae(wrapper, latent_channels=16, latent_dim=3) vae = _make_vae(wrapper, latent_channels=_LATENT_CHANNELS, latent_dim=3)
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
_dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True) _dispatch(vae, torch.zeros(1, _LATENT_CHANNELS * 3, 8, 8), seedvr2_call, generic_call, True)
assert seedvr2_call.call_count == 1 assert seedvr2_call.call_count == 1
assert generic_call.call_count == 0 assert generic_call.call_count == 0
@ -275,11 +288,6 @@ def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled():
assert seedvr2_call.call_count == 0 assert seedvr2_call.call_count == 0
# ---------------------------------------------------------------------------
# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py
# ---------------------------------------------------------------------------
def _populate_common_vae_attrs_fallback(vae): def _populate_common_vae_attrs_fallback(vae):
vae.patcher = MagicMock() vae.patcher = MagicMock()
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
@ -291,7 +299,7 @@ def _populate_common_vae_attrs_fallback(vae):
vae.upscale_ratio = 8 vae.upscale_ratio = 8
vae.upscale_index_formula = None vae.upscale_index_formula = None
vae.output_channels = 3 vae.output_channels = 3
vae.latent_channels = 16 vae.latent_channels = _LATENT_CHANNELS
vae.latent_dim = 3 vae.latent_dim = 3
vae.downscale_ratio = 8 vae.downscale_ratio = 8
vae.downscale_index_formula = None vae.downscale_index_formula = None
@ -334,8 +342,8 @@ def test_seedvr2_3d_routes_to_owned_encode_tiled_on_oom():
vae = _make_seedvr2_vae_fallback() vae = _make_seedvr2_vae_fallback()
pixel_samples = torch.zeros((1, 8, 64, 64, 3)) pixel_samples = torch.zeros((1, 8, 64, 64, 3))
seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) seedvr2_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8))
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) generic_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8))
with patch.object(sd_mod.model_management, "raise_non_oom", with patch.object(sd_mod.model_management, "raise_non_oom",
lambda e: None), \ lambda e: None), \
@ -363,7 +371,7 @@ def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete():
vae = _make_non_seedvr2_vae_fallback() vae = _make_non_seedvr2_vae_fallback()
vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8) vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8)
vae.upscale_ratio = (lambda a: a * 4, 8, 8) vae.upscale_ratio = (lambda a: a * 4, 8, 8)
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) generic_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8))
pixel_samples = torch.zeros((1, 8, 64, 64, 3)) pixel_samples = torch.zeros((1, 8, 64, 64, 3))
with patch.object(sd_mod.model_management, "load_models_gpu", with patch.object(sd_mod.model_management, "load_models_gpu",