Refactors and cleanups.

This commit is contained in:
comfyanonymous 2026-07-02 22:59:38 -04:00
parent 77d42ed7e9
commit c7b2c3b569
19 changed files with 436 additions and 729 deletions

View File

@ -781,6 +781,7 @@ class ACEAudio(LatentFormat):
class SeedVR2(LatentFormat):
latent_channels = 16
latent_dimensions = 3
class ACEAudio15(LatentFormat):
latent_channels = 64

View File

@ -22,33 +22,14 @@ def _var_attention_output(out, heads, head_dim, skip_output_reshape):
return out.reshape(-1, heads * head_dim)
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
if cu_seqlens.dtype not in (torch.int32, torch.int64):
raise ValueError(f"{name} must use an integer dtype")
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
if cu_seqlens[0].item() != 0:
raise ValueError(f"{name} must start at 0")
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
raise ValueError(f"{name} must be strictly increasing")
if cu_seqlens[-1].item() != token_count:
raise ValueError(f"{name} does not match token count")
def _split_indices(cu_seqlens):
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
if cu_seqlens_k[-1].item() != v.shape[0]:
q_split_indices = cu_seqlens_q[1:-1]
k_split_indices = cu_seqlens_k[1:-1]
if k.shape[0] != v.shape[0]:
raise ValueError("cu_seqlens_k does not match v token count")
q_split_indices = _split_indices(cu_seqlens_q)
k_split_indices = _split_indices(cu_seqlens_k)
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
v_splits = torch.tensor_split(v, k_split_indices, dim=0)

View File

@ -45,7 +45,6 @@ def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS):
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
if content_feat.shape != style_feat.shape:
# Resize style to match content spatial dimensions
if len(content_feat.shape) >= 3:
style_feat = F.interpolate(
style_feat,
@ -54,12 +53,11 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
align_corners=False
)
# Decompose both features into frequency components
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq # Free memory immediately
del content_low_freq
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq # Free memory immediately
del style_high_freq
if content_high_freq.shape != style_low_freq.shape:
style_low_freq = F.interpolate(
@ -73,27 +71,23 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
return content_high_freq.clamp_(-1.0, 1.0)
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
def _histogram_matching_channel(source: Tensor, reference: Tensor) -> Tensor:
original_shape = source.shape
# Flatten
source_flat = source.flatten()
reference_flat = reference.flatten()
# Sort both arrays
source_sorted, source_indices = torch.sort(source_flat)
reference_sorted, _ = torch.sort(reference_flat)
del reference_flat
# Quantile mapping
n_source = len(source_sorted)
n_reference = len(reference_sorted)
if n_source == n_reference:
matched_sorted = reference_sorted
else:
# Interpolate reference to match source quantiles
source_quantiles = torch.linspace(0, 1, n_source, device=device)
source_quantiles = torch.linspace(0, 1, n_source, device=source.device)
ref_indices = (source_quantiles * (n_reference - 1)).long()
ref_indices.clamp_(0, n_reference - 1)
matched_sorted = reference_sorted[ref_indices]
@ -101,7 +95,6 @@ def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch
del source_sorted, source_flat
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
inverse_indices = torch.argsort(source_indices)
del source_indices
matched_flat = matched_sorted[inverse_indices]
@ -109,17 +102,14 @@ def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch
return matched_flat.reshape(original_shape)
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of CIELAB images to RGB color space."""
def _lab_to_rgb_batch(lab: Tensor, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
# LAB to XYZ
fy = (L + 16.0) / 116.0
fx = a.div(500.0).add_(fy)
fz = fy - b / 200.0
del L, a, b
# XYZ transformation
x = torch.where(
fx > epsilon,
torch.pow(fx, 3.0),
@ -137,20 +127,16 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
)
del fx, fy, fz
# Apply D65 white point (in-place)
x.mul_(D65_WHITE_X)
# y *= 1.00000 # (no-op, skip)
z.mul_(D65_WHITE_Z)
xyz = torch.stack([x, y, z], dim=1)
del x, y, z
# Matrix multiplication: XYZ -> RGB
B, C, H, W = xyz.shape
B, _, H, W = xyz.shape
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
del xyz
# Ensure dtype consistency for matrix multiplication
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
del xyz_flat
@ -158,7 +144,6 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del rgb_linear_flat
# Apply inverse gamma correction (delinearize)
mask = rgb_linear > 0.0031308
rgb = torch.where(
mask,
@ -169,9 +154,7 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps
return torch.clamp(rgb, 0.0, 1.0)
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
# Apply sRGB gamma correction (linearize)
def _rgb_to_lab_batch(rgb: Tensor, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
mask = rgb > 0.04045
rgb_linear = torch.where(
mask,
@ -180,12 +163,10 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
)
del mask
# Matrix multiplication: RGB -> XYZ
B, C, H, W = rgb_linear.shape
B, _, H, W = rgb_linear.shape
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
del rgb_linear
# Ensure dtype consistency for matrix multiplication
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
xyz_flat = torch.matmul(rgb_flat, matrix.T)
del rgb_flat
@ -193,12 +174,9 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del xyz_flat
# Normalize by D65 white point (in-place)
xyz[:, 0].div_(D65_WHITE_X) # X
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
xyz[:, 2].div_(D65_WHITE_Z) # Z
xyz[:, 0].div_(D65_WHITE_X)
xyz[:, 2].div_(D65_WHITE_Z)
# XYZ to LAB transformation
epsilon_cubed = epsilon ** 3
mask = xyz > epsilon_cubed
f_xyz = torch.where(
@ -208,10 +186,9 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon
)
del xyz, mask
# Extract channels and compute LAB
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
L = f_xyz[:, 1].mul(116.0).sub_(16.0)
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0)
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0)
del f_xyz
return torch.stack([L, a, b], dim=1)
@ -232,13 +209,9 @@ def lab_color_transfer(
)
device = content_feat.device
def ensure_float32_precision(c):
orig_dtype = c.dtype
c = c.float()
return c, orig_dtype
content_feat, original_dtype = ensure_float32_precision(content_feat)
style_feat, _ = ensure_float32_precision(style_feat)
original_dtype = content_feat.dtype
content_feat = content_feat.float()
style_feat = style_feat.float()
rgb_to_xyz_matrix = torch.tensor([
[0.4124564, 0.3575761, 0.1804375],
@ -258,39 +231,30 @@ def lab_color_transfer(
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
# Convert to LAB color space
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
content_lab = _rgb_to_lab_batch(content_feat, rgb_to_xyz_matrix, epsilon, kappa)
del content_feat
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
style_lab = _rgb_to_lab_batch(style_feat, rgb_to_xyz_matrix, epsilon, kappa)
del style_feat, rgb_to_xyz_matrix
# Match chrominance channels (a*, b*) for accurate color transfer
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1])
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2])
# Handle luminance with weighted blending
if luminance_weight < 1.0:
# Partially match luminance for better overall color accuracy
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
# Blend: preserve some content L* for detail, adopt some style L* for color
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0])
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
del matched_L
else:
# Fully preserve content luminance
result_L = content_lab[:, 0]
del content_lab, style_lab
# Reconstruct LAB with corrected channels
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
del result_L, matched_a, matched_b
# Convert back to RGB
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
result_rgb = _lab_to_rgb_batch(result_lab, xyz_to_rgb_matrix, epsilon, kappa)
del result_lab, xyz_to_rgb_matrix
# Convert back to [-1, 1] range (in-place)
result = result_rgb.mul_(2.0).sub_(1.0)
del result_rgb

View File

@ -1,34 +1,21 @@
"""Named constants for the SeedVR2 integration, grouped by provenance.
"""SeedVR2 constants."""
Provenance prefixes:
- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline.
- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites
the upstream config/source path it was lifted from.
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
ISO / CIE values; cite the standard.
"""
SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim.
# (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.)
SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # OOM retry backoff: halve the chunk and retry.
SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case).
SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM.
SEEDVR2_7B_VID_DIM = 3072
SEEDVR2_OOM_BACKOFF_DIVISOR = 2
SEEDVR2_DTYPE_BYTES_FLOOR = 4
SEEDVR2_7B_MLP_CHUNK = 8192
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk.
SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels).
SEEDVR2_LATENT_CHANNELS = 16
# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing)
SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk.
SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path.
SEEDVR2_COLOR_MEM_HEADROOM = 0.75
SEEDVR2_LAB_SCALE_MULTIPLIER = 13
SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path.
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6
# --------------------------------------------------------------------------------------
# ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR)
# --------------------------------------------------------------------------------------
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem).
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem).
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57.
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5
BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16).
@ -42,18 +29,10 @@ BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal wind
BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency).
BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim).
# --------------------------------------------------------------------------------------
# Published standards (cite the literature)
# --------------------------------------------------------------------------------------
ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864.
# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65).
CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta).
CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa).
D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1).
D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn.
WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR).
# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and
# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the
# exact existing coefficients move verbatim rather than being retyped here.

View File

@ -3,7 +3,7 @@ from typing import Optional, Tuple, Union, List, Dict, Any, Callable
import torch.nn.functional as F
from math import ceil, pi
import torch
from itertools import chain
from itertools import accumulate, chain
from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding
from comfy.ldm.seedvr.attention import optimized_var_attention
from torch.nn.modules.utils import _triple
@ -18,6 +18,7 @@ from comfy.ldm.seedvr.constants import (
ROPE_THETA,
SEEDVR2_7B_MLP_CHUNK,
SEEDVR2_7B_VID_DIM,
SEEDVR2_LATENT_CHANNELS,
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS,
)
import comfy.model_management
@ -70,7 +71,7 @@ def repeat_concat_idx(
vid_idx = torch.arange(vid_len.sum(), device=device)
txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
txt_repeat_list = txt_repeat.tolist()
tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat)
tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat_list)
src_idx = torch.argsort(tgt_idx)
txt_idx_len = len(tgt_idx) - len(vid_idx)
repeat_txt_len = (txt_len * txt_repeat).tolist()
@ -88,6 +89,9 @@ def repeat_concat_idx(
lambda all: unconcat_coalesce(all),
)
def cumulative_lengths(lengths):
return [0, *accumulate(lengths)]
@dataclass
class MMArg:
@ -110,16 +114,14 @@ def get_window_op(name: str):
raise ValueError(f"Unknown windowing method: {name}")
# -------------------------------- Windowing -------------------------------- #
def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
t, h, w = size
resized_nt, resized_nh, resized_nw = num_windows
#cal windows under 720p
scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w))
resized_h, resized_w = round(h * scale), round(w * scale)
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size.
nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size.
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)
wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt)
nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww)
return [
(
slice(it * wt, min((it + 1) * wt, t)),
@ -137,19 +139,18 @@ def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int,
def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
t, h, w = size
resized_nt, resized_nh, resized_nw = num_windows
#cal windows under 720p
scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w))
resized_h, resized_w = round(h * scale), round(w * scale)
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size.
wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size.
wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw)
wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt)
st, sh, sw = ( # shift size.
st, sh, sw = (
0.5 if wt < t else 0,
0.5 if wh < h else 0,
0.5 if ww < w else 0,
)
nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size.
nt, nh, nw = ( # number of window.
nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww)
nt, nh, nw = (
nt + 1 if st > 0 else 1,
nh + 1 if sh > 0 else 1,
nw + 1 if sw > 0 else 1,
@ -175,7 +176,6 @@ class RotaryEmbedding(nn.Module):
freqs_for = 'lang',
theta = 10000,
max_freq = 10,
learned_freq = False,
):
super().__init__()
@ -185,18 +185,14 @@ class RotaryEmbedding(nn.Module):
freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
elif freqs_for == 'pixel':
freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
else:
raise ValueError(f"Unknown rotary frequency type: {freqs_for}")
self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
self.learned_freq = learned_freq
# dummy for device
self.register_buffer('dummy', torch.tensor(0), persistent = False)
self.register_buffer("freqs", freqs)
@property
def device(self):
return self.dummy.device
return self.freqs.device
def get_axial_freqs(
self,
@ -206,10 +202,9 @@ class RotaryEmbedding(nn.Module):
Colon = slice(None)
all_freqs = []
# handle offset
if exists(offsets):
assert len(offsets) == len(dims)
if len(offsets) != len(dims):
raise ValueError(f"SeedVR2 rotary offsets length must match dims length, got {len(offsets)} and {len(dims)}.")
for ind, dim in enumerate(dims):
@ -224,7 +219,7 @@ class RotaryEmbedding(nn.Module):
pos = pos + offset
freqs = self.forward(pos, seq_len = dim)
freqs = self.forward(pos)
all_axis = [None] * len(dims)
all_axis[ind] = Colon
@ -232,16 +227,12 @@ class RotaryEmbedding(nn.Module):
new_axis_slice = (Ellipsis, *all_axis, Colon)
all_freqs.append(freqs[new_axis_slice])
# concat all freqs
all_freqs = torch.broadcast_tensors(*all_freqs)
return torch.cat(all_freqs, dim = -1)
def forward(
self,
t,
seq_len: int | None = None,
offset = 0
):
freqs = self.freqs
@ -258,9 +249,6 @@ class RotaryEmbeddingBase(nn.Module):
freqs_for="pixel",
max_freq=BYTEDANCE_ROPE_MAX_FREQ,
)
freqs = self.rope.freqs
del self.rope.freqs
self.rope.register_buffer("freqs", freqs.detach())
def get_axial_freqs(self, *dims):
return self.rope.get_axial_freqs(*dims)
@ -306,7 +294,7 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d):
freqs_for="pixel",
max_freq=BYTEDANCE_ROPE_MAX_FREQ,
)
plain_rope = plain_rope.to(self.rope.dummy.device)
plain_rope = plain_rope.to(self.rope.device)
freq_list = []
for f, h, w in shape.tolist():
freqs = plain_rope.get_axial_freqs(f, h, w)
@ -322,9 +310,6 @@ class MMRotaryEmbeddingBase(RotaryEmbeddingBase):
freqs_for="lang",
theta=ROPE_THETA,
)
freqs = self.rope.freqs
del self.rope.freqs
self.rope.register_buffer("freqs", freqs.detach())
self.mm = True
def slice_at_dim(t, dim_slice: slice, *, dim):
@ -333,8 +318,6 @@ def slice_at_dim(t, dim_slice: slice, *, dim):
colons[dim] = dim_slice
return t[tuple(colons)]
# rotary embedding helper functions
def rotate_half(x):
x = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2)
x1, x2 = x.unbind(dim = -1)
@ -373,7 +356,6 @@ def _apply_seedvr2_rotary_emb(
return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype)
def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
"""Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`."""
angles = freqs_interleaved[..., ::2].float()
cos = torch.cos(angles)
sin = torch.sin(angles)
@ -382,12 +364,6 @@ def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor:
def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""Rotate the leading ``rot_d = 2 * freqs_cis.shape[-3]`` dims of ``t`` and pass the rest
through; in-place for inference, cloned for training (autograd). Mirrors the legacy
``apply_rotary_emb`` ``t_left``/``t_middle``/``t_right`` split: 3B ``rope_dim=128`` gives
``42*3 = 126`` rotated of head_dim 128 (trailing 2 unrotated). Fast path skips the cat when
``rot_d == t.shape[-1]``.
"""
out = t.clone() if t.requires_grad or comfy.model_management.in_training else t
rot_d = 2 * freqs_cis.shape[-3]
seq_len = out.shape[-2]
@ -454,14 +430,13 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
torch.Tensor,
]:
# Calculate actual max dimensions needed for this batch
max_temporal = 0
max_height = 0
max_width = 0
max_txt_len = 0
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal
max_temporal = max(max_temporal, l + f)
max_height = max(max_height, h)
max_width = max(max_width, w)
max_txt_len = max(max_txt_len, l)
@ -475,7 +450,6 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
).float()
txt_freqs = self.get_axial_freqs(max_txt_len + 16)
# Now slice as before
vid_freq_list, txt_freq_list = [], []
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1))
@ -485,13 +459,6 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
vid_freqs_interleaved = torch.cat(vid_freq_list, dim=0)
txt_freqs_interleaved = torch.cat(txt_freq_list, dim=0)
# Convert from lucidrains-interleaved layout `[θ0, θ0, θ1, θ1, ...]`
# (produced by `repeat(freqs, '... n -> ... (n r)', r=2)` in the
# upstream `RotaryEmbedding.forward`) to flux-canonical `freqs_cis`
# in shape `[..., d/2, 2, 2]` with `cos/-sin/sin/cos` baked in.
# Mirrors `comfy/ldm/flux/math.py:rope` (line 27) so the trailing
# 2x2 is the per-frequency rotation matrix that
# `comfy.ldm.flux.math.apply_rope1` expects.
return _to_flux_freqs_cis(vid_freqs_interleaved), _to_flux_freqs_cis(txt_freqs_interleaved)
class MMModule(nn.Module):
@ -507,8 +474,10 @@ class MMModule(nn.Module):
self.shared_weights = shared_weights
self.vid_only = vid_only
if self.shared_weights:
assert get_args("vid", args) == get_args("txt", args)
assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs)
if get_args("vid", args) != get_args("txt", args):
raise ValueError("SeedVR2 shared MMModule requires matching vid/txt args.")
if get_kwargs("vid", kwargs) != get_kwargs("txt", kwargs):
raise ValueError("SeedVR2 shared MMModule requires matching vid/txt kwargs.")
self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
else:
self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs))
@ -543,6 +512,7 @@ def get_na_rope(rope_type: Optional[str], dim: int):
return NaRotaryEmbedding3d(dim=dim)
if rope_type == "mmrope3d":
return NaMMRotaryEmbedding3d(dim=dim)
raise ValueError(f"Unknown SeedVR2 rope type: {rope_type}")
class NaMMAttention(nn.Module):
def __init__(
@ -558,7 +528,6 @@ class NaMMAttention(nn.Module):
rope_dim: int,
shared_weights: bool,
device, dtype, operations,
**kwargs,
):
super().__init__()
dim = MMArg(vid_dim, txt_dim)
@ -597,16 +566,19 @@ def window(
):
hid = unflatten(hid, hid_shape)
hid = list(map(window_fn, hid))
hid_windows = torch.as_tensor([len(x) for x in hid], device=hid_shape.device)
hid, hid_shape = flatten(list(chain(*hid)))
return hid, hid_shape, hid_windows
hid_windows_list = [len(x) for x in hid]
hid_windows = torch.as_tensor(hid_windows_list, device=hid_shape.device)
hid = list(chain(*hid))
hid_len_list = [math.prod(x.shape[:-1]) for x in hid]
hid, hid_shape = flatten(hid)
return hid, hid_shape, hid_windows, hid_len_list, hid_windows_list
def window_idx(
hid_shape: torch.LongTensor, # (b n)
window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn)
tgt_idx, tgt_shape, tgt_windows, tgt_len_list, tgt_windows_list = window(hid_idx, hid_shape, window_fn)
tgt_idx = tgt_idx.squeeze(-1)
src_idx = torch.argsort(tgt_idx)
return (
@ -614,6 +586,8 @@ def window_idx(
lambda hid: torch.index_select(hid, 0, src_idx),
tgt_shape,
tgt_windows,
tgt_len_list,
tgt_windows_list,
)
class NaSwinAttention(NaMMAttention):
@ -622,13 +596,15 @@ class NaSwinAttention(NaMMAttention):
*args,
window: Union[int, Tuple[int, int, int]],
window_method: str,
version: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.version_7b = kwargs.get("version", False)
self.version_7b = version
self.window = _triple(window)
self.window_method = window_method
assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window))
if not all(isinstance(v, int) and v >= 0 for v in self.window):
raise ValueError(f"SeedVR2 window must contain non-negative integers, got {self.window}.")
self.window_op = get_window_op(window_method)
@ -646,7 +622,6 @@ class NaSwinAttention(NaMMAttention):
vid_qkv, txt_qkv = self.proj_qkv(vid, txt)
# re-org the input seq for window attn
cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3")
def make_window(x: torch.Tensor):
@ -654,7 +629,7 @@ class NaSwinAttention(NaMMAttention):
window_slices = self.window_op((t, h, w), self.window)
return [x[st, sh, sw] for (st, sh, sw) in window_slices]
window_partition, window_reverse, window_shape, window_count = cache_win(
window_partition, window_reverse, window_shape, window_count, vid_len_win_list, window_count_list = cache_win(
"win_transform",
lambda: window_idx(vid_shape, make_window),
)
@ -674,23 +649,21 @@ class NaSwinAttention(NaMMAttention):
vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1))
txt_len = txt_len.to(window_count.device)
# window rope
if self.rope:
if self.version_7b:
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
elif self.rope.mm:
# repeat text q and k for window mmrope
_, num_h, _ = txt_q.shape
txt_q_repeat = txt_q.flatten(1, 2)
txt_q_repeat = unflatten(txt_q_repeat, txt_shape)
txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)]
txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count_list)]
txt_q_repeat = list(chain(*txt_q_repeat))
txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat)
txt_q_repeat = txt_q_repeat.reshape(txt_q_repeat.shape[0], num_h, self.head_dim)
txt_k_repeat = txt_k.flatten(1, 2)
txt_k_repeat = unflatten(txt_k_repeat, txt_shape)
txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)]
txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count_list)]
txt_k_repeat = list(chain(*txt_k_repeat))
txt_k_repeat, _ = flatten(txt_k_repeat)
txt_k_repeat = txt_k_repeat.reshape(txt_k_repeat.shape[0], num_h, self.head_dim)
@ -702,7 +675,11 @@ class NaSwinAttention(NaMMAttention):
vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win)
txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count))
all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win)
txt_len_win_list = cache_win(
"txt_len_list",
lambda: [txt_len for txt_len, window_count in zip(txt_len.tolist(), window_count_list) for _ in range(window_count)],
)
all_len_win = cache_win("all_len", lambda: [vid_len + txt_len for vid_len, txt_len in zip(vid_len_win_list, txt_len_win_list)])
concat_win, unconcat_win = cache_win(
"mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count)
)
@ -711,12 +688,8 @@ class NaSwinAttention(NaMMAttention):
k=concat_win(vid_k, txt_k),
v=concat_win(vid_v, txt_v),
heads=self.heads, skip_reshape=True, skip_output_reshape=True,
cu_seqlens_q=cache_win(
"vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
),
cu_seqlens_k=cache_win(
"vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int()
),
cu_seqlens_q=cache_win("vid_seqlens_q", lambda: cumulative_lengths(all_len_win)),
cu_seqlens_k=cache_win("vid_seqlens_k", lambda: cumulative_lengths(all_len_win)),
)
vid_out, txt_out = unconcat_win(out)
@ -766,11 +739,11 @@ class SwiGLUMLP(nn.Module):
return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x))
def get_mlp(mlp_type: Optional[str] = "normal"):
# 3b and 7b uses different mlp types
if mlp_type == "normal":
return MLP
elif mlp_type == "swiglu":
if mlp_type == "swiglu":
return SwiGLUMLP
raise ValueError(f"Unknown SeedVR2 MLP type: {mlp_type}")
class NaMMSRTransformerBlock(nn.Module):
def __init__(
@ -792,11 +765,12 @@ class NaMMSRTransformerBlock(nn.Module):
rope_type: str,
rope_dim: int,
is_last_layer: bool,
window: Union[int, Tuple[int, int, int]],
window_method: str,
version: bool,
device, dtype, operations,
**kwargs,
):
super().__init__()
version = kwargs.get("version", False)
dim = MMArg(vid_dim, txt_dim)
self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype)
@ -811,8 +785,8 @@ class NaMMSRTransformerBlock(nn.Module):
rope_type=rope_type,
rope_dim=rope_dim,
shared_weights=shared_weights,
window=kwargs.pop("window", None),
window_method=kwargs.pop("window_method", None),
window=window,
window_method=window_method,
version=version,
device=device, dtype=dtype, operations=operations
)
@ -930,12 +904,14 @@ class NaPatchOut(PatchOut):
self,
vid: torch.FloatTensor, # l c
vid_shape: torch.LongTensor,
cache: Cache = Cache(disable=True),
cache: Optional[Cache] = None,
vid_shape_before_patchify = None
) -> Tuple[
torch.FloatTensor,
torch.LongTensor,
]:
if cache is None:
cache = Cache(disable=True)
t, h, w = self.patch_size
vid = self.proj(vid)
@ -971,7 +947,10 @@ class PatchIn(nn.Module):
) -> torch.Tensor:
t, h, w = self.patch_size
if t > 1:
assert vid.size(2) % t == 1
if vid.size(2) % t != 1:
raise ValueError(
f"SeedVR2 patch input temporal size must satisfy T % {t} == 1, got {vid.size(2)}."
)
vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2)
b, c, Tt, Hh, Ww = vid.shape
vid = vid.view(b, c, Tt // t, t, Hh // h, h, Ww // w, w).permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(b, Tt // t, Hh // h, Ww // w, t * h * w * c)
@ -983,8 +962,10 @@ class NaPatchIn(PatchIn):
self,
vid: torch.Tensor, # l c
vid_shape: torch.LongTensor,
cache: Cache = Cache(disable=True),
cache: Optional[Cache] = None,
) -> torch.Tensor:
if cache is None:
cache = Cache(disable=True)
cache = cache.namespace("patch")
vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape)
t, h, w = self.patch_size
@ -1012,10 +993,11 @@ class AdaSingle(nn.Module):
dim: int,
emb_dim: int,
layers: List[str],
modes: List[str] = ["in", "out"],
modes: Tuple[str, ...] = ("in", "out"),
device = None, dtype = None,
):
assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim"
if emb_dim != 6 * dim:
raise ValueError(f"SeedVR2 AdaSingle requires emb_dim == 6 * dim, got emb_dim={emb_dim}, dim={dim}.")
super().__init__()
self.dim = dim
self.emb_dim = emb_dim
@ -1036,22 +1018,20 @@ class AdaSingle(nn.Module):
emb: torch.FloatTensor, # b d
layer: str,
mode: str,
cache: Cache = Cache(disable=True),
cache: Optional[Cache] = None,
branch_tag: str = "",
hid_len: Optional[torch.LongTensor] = None, # b
) -> torch.FloatTensor:
if cache is None:
cache = Cache(disable=True)
idx = self.layers.index(layer)
emb = emb.reshape(emb.shape[0], -1, len(self.layers), 3)[:, :, idx, :]
emb = expand_dims(emb, 1, hid.ndim + 1)
if hid_len is not None:
slice_inputs = lambda x, dim: x
emb = cache(
f"emb_repeat_{idx}_{branch_tag}",
lambda: slice_inputs(
torch.repeat_interleave(emb, hid_len, dim=0),
dim=0,
),
lambda: torch.repeat_interleave(emb, hid_len, dim=0),
)
shiftA, scaleA, gateA = emb.unbind(-1)
@ -1069,7 +1049,7 @@ class AdaSingle(nn.Module):
else:
return hid.mul_(gateA)
raise NotImplementedError
raise ValueError(f"Unknown AdaSingle mode: {mode}")
class TimeEmbedding(nn.Module):
@ -1117,7 +1097,8 @@ def flatten(
torch.FloatTensor, # (L c)
torch.LongTensor, # (b n)
]:
assert len(hid) > 0
if len(hid) == 0:
raise ValueError("SeedVR2 flatten requires at least one tensor.")
shape = torch.as_tensor([x.shape[:-1] for x in hid], device=hid[0].device)
hid = torch.cat([x.flatten(0, -2) for x in hid])
return hid, shape
@ -1140,7 +1121,7 @@ class NaDiT(nn.Module):
num_layers,
mlp_type,
vid_in_channels = 33,
vid_out_channels = 16,
vid_out_channels = SEEDVR2_LATENT_CHANNELS,
vid_dim = 2560,
txt_in_dim = 5120,
heads = 20,
@ -1148,15 +1129,17 @@ class NaDiT(nn.Module):
mm_layers = 10,
expand_ratio = 4,
qk_bias = False,
patch_size = [ 1,2,2 ],
patch_size = (1, 2, 2),
rope_dim = 128,
rope_type = "mmrope3d",
vid_out_norm: Optional[str] = None,
image_model = None,
device = None,
dtype = None,
operations = None,
**kwargs,
):
if image_model not in (None, "seedvr2"):
raise ValueError(f"SeedVR2 NaDiT expected image_model='seedvr2', got {image_model!r}.")
self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM
if self._7b_version:
rope_type = "rope3d"
@ -1212,14 +1195,13 @@ class NaDiT(nn.Module):
rope_dim = rope_dim,
window=window[i],
window_method=window_method[i],
version = self._7b_version,
is_last_layer=(i == num_layers - 1) and not self._7b_version,
rope_type = rope_type,
shared_weights=not (
(i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i]
),
version = self._7b_version,
operations = operations,
**kwargs,
**factory_kwargs
)
for i in range(num_layers)
@ -1272,13 +1254,17 @@ class NaDiT(nn.Module):
first = cond_or_uncond[0]
return all(entry == first for entry in cond_or_uncond)
@staticmethod
def _check_seedvr2_video_latent(x, channels, name):
if x.ndim != 5:
raise ValueError(f"SeedVR2 expected {name} to be 5-D native latent, got shape {tuple(x.shape)}.")
if x.shape[1] != channels:
raise ValueError(f"SeedVR2 expected {name} channels to be {channels}, got shape {tuple(x.shape)}.")
return x
def _swap_pos_neg_halves(self, out, cond_or_uncond=None):
if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond):
return out
# ``dim=0`` is explicit on both calls. The contract is "split
# the batch axis into two halves and swap them"; making the
# axis load-bearing in source guards against silent drift if a
# future refactor reorders tensor axes.
pos, neg = out.chunk(2, dim=0)
return torch.cat([neg, pos], dim=0)
@ -1294,9 +1280,15 @@ class NaDiT(nn.Module):
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
conditions = kwargs.get("condition")
b, tc, h, w = x.shape
x = x.view(b, 16, -1, h, w)
conditions = conditions.view(b, 17, -1, h, w)
if conditions is None:
raise ValueError("SeedVR2 requires conditioning latents from the SeedVR2Conditioning node.")
x = self._check_seedvr2_video_latent(x, SEEDVR2_LATENT_CHANNELS, "latent")
conditions = self._check_seedvr2_video_latent(conditions, SEEDVR2_LATENT_CHANNELS + 1, "conditioning")
b, _, t, h, w = x.shape
if conditions.shape[0] != b or conditions.shape[2:] != (t, h, w):
raise ValueError(
f"SeedVR2 conditioning shape must match latent batch/temporal/spatial dimensions; got latent {tuple(x.shape)} and conditioning {tuple(conditions.shape)}."
)
x = x.movedim(1, -1)
conditions = conditions.movedim(1, -1)
cache = Cache(disable=disable_cache)
@ -1361,7 +1353,6 @@ class NaDiT(nn.Module):
vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify)
vid = unflatten(vid, vid_shape)
out = torch.stack(vid)
out = torch.stack(vid)
out = out.movedim(-1, 1)
out = out.reshape(out.shape[0], out.shape[1] * out.shape[2], out.shape[3], out.shape[4])
return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond"))

View File

@ -62,7 +62,6 @@ def tiled_vae(
temporal_size=16,
temporal_overlap=0,
encode=True,
**kwargs,
):
if x.ndim != 5:
x = x.unsqueeze(2)
@ -166,8 +165,8 @@ def tiled_vae(
if single_spatial_tile:
result = tile_out[:, :, :target_d, :target_h, :target_w]
if result.device != x.device:
result = result.to(x.device).to(x.dtype)
if result.device != x.device or result.dtype != x.dtype:
result = result.to(device=x.device, dtype=x.dtype)
if x.shape[2] == 1 and sf_t == 1:
result = result.squeeze(2)
bar.update(1)
@ -221,8 +220,8 @@ def tiled_vae(
result.div_(count.clamp(min=1e-6))
if result.device != x.device:
result = result.to(x.device).to(x.dtype)
if result.device != x.device or result.dtype != x.dtype:
result = result.to(device=x.device, dtype=x.dtype)
if x.shape[2] == 1 and sf_t == 1:
result = result.squeeze(2)
@ -256,15 +255,18 @@ class MemoryState(Enum):
UNSET = 3
def get_cache_size(conv_module, input_len, pad_len, dim=0):
dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1
output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1
dilated_kernel_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1
output_len = (input_len + pad_len - dilated_kernel_size) // conv_module.stride[dim] + 1
remain_len = (
input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size)
input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernel_size)
)
overlap_len = dilated_kernerl_size - conv_module.stride[dim]
cache_len = overlap_len + remain_len # >= 0
overlap_len = dilated_kernel_size - conv_module.stride[dim]
cache_len = overlap_len + remain_len
assert output_len > 0
if output_len <= 0:
raise ValueError(
f"SeedVR2 VAE cache input is too short for convolution: input_len={input_len}, pad_len={pad_len}."
)
return cache_len
class DiagonalGaussianDistribution(object):
@ -294,52 +296,27 @@ class SpatialNorm(nn.Module):
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
# partial implementation of diffusers's Attention for comfyui
class Attention(nn.Module):
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
kv_heads: Optional[int] = None,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
upcast_softmax: bool = False,
norm_num_groups: Optional[int] = None,
spatial_norm_dim: Optional[int] = None,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 1e-5,
rescale_output_factor: float = 1.0,
residual_connection: bool = False,
out_dim: int = None,
pre_only=False,
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
self.query_dim = query_dim
self.use_bias = bias
self.is_cross_attention = cross_attention_dim is not None
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_softmax = upcast_softmax
self.inner_dim = dim_head * heads
self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.pre_only = pre_only
self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = out_dim // dim_head if out_dim is not None else heads
self.sliceable_head_dim = heads
self.only_cross_attention = only_cross_attention
self.out_dim = query_dim
self.heads = heads
if norm_num_groups is not None:
self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True)
@ -351,37 +328,19 @@ class Attention(nn.Module):
else:
self.spatial_norm = None
self.norm_q = None
self.norm_k = None
self.norm_cross = None
self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
else:
self.to_k = None
self.to_v = None
if not self.pre_only:
self.to_out = nn.ModuleList([])
self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
else:
self.to_out = None
self.to_k = ops.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = ops.Linear(query_dim, self.inner_dim, bias=bias)
self.to_out = nn.ModuleList([])
self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Identity())
self.optimized_vae_attention = vae_attention()
def __call__(
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
*args,
**kwargs,
) -> torch.Tensor:
residual = hidden_states
@ -394,20 +353,14 @@ class Attention(nn.Module):
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
batch_size = hidden_states.shape[0]
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // self.heads
@ -417,25 +370,18 @@ class Attention(nn.Module):
key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2)
if self.norm_q is not None:
query = self.norm_q(query)
if self.norm_k is not None:
key = self.norm_k(key)
if input_ndim == 4 and encoder_hidden_states is hidden_states and attention_mask is None and self.heads == 1:
if input_ndim == 4 and self.heads == 1:
query = query.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width)
key = key.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width)
value = value.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width)
hidden_states = self.optimized_vae_attention(query, key, value).reshape(batch_size, self.heads, head_dim, height * width).transpose(2, 3)
else:
hidden_states = optimized_attention(query, key, value, heads = self.heads, mask = attention_mask, skip_reshape=True, skip_output_reshape=True)
hidden_states = optimized_attention(query, key, value, heads = self.heads, skip_reshape=True, skip_output_reshape=True)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
if input_ndim == 4:
@ -471,7 +417,10 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
memory_occupy = x.numel() * x.element_size() / 1024**3
if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit():
num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups)
assert norm_layer.num_groups % num_chunks == 0
if norm_layer.num_groups % num_chunks != 0:
raise ValueError(
f"SeedVR2 VAE GroupNorm groups must divide chunks: groups={norm_layer.num_groups}, chunks={num_chunks}."
)
num_groups_per_chunk = norm_layer.num_groups // num_chunks
x = list(x.chunk(num_chunks, dim=1))
@ -485,14 +434,15 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor:
x = norm_layer(x)
x = x.reshape((b, t, x.size(1), x.size(2), x.size(3))).transpose(1, 2)
return x.to(input_dtype)
raise NotImplementedError
raise TypeError(f"SeedVR2 VAE unsupported norm layer type: {type(norm_layer).__name__}")
_receptive_field_t = Literal["half", "full"]
def extend_head(tensor, times: int = 2, memory = None):
if memory is not None:
return torch.cat((memory.to(tensor), tensor), dim=2)
assert times >= 0, "Invalid input for function 'extend_head'!"
if times < 0:
raise ValueError(f"SeedVR2 VAE extend_head expected times >= 0, got {times}.")
if times == 0:
return tensor
else:
@ -547,13 +497,11 @@ class InflatedCausalConv3d(ops.Conv3d):
padding=(0, 0, 0, 0, 0, 0),
prev_cache=None,
):
# Compatible with no limit.
if math.isinf(self.memory_limit):
if prev_cache is not None:
x = torch.cat([prev_cache, x], dim=split_dim - 1)
return super().forward(x)
# Compute tensor shape after concat & padding.
shape = list(x.size())
if prev_cache is not None:
shape[split_dim - 1] += prev_cache.size(split_dim - 1)
@ -597,16 +545,19 @@ class InflatedCausalConv3d(ops.Conv3d):
next_cache = None
cache_len = cache.size(split_dim) if cache is not None else 0
next_catch_size = get_cache_size(
next_cache_size = get_cache_size(
conv_module=self,
input_len=x[idx].size(split_dim) + cache_len,
pad_len=pad_len,
dim=split_dim - 2,
)
if next_catch_size != 0:
assert next_catch_size <= x[idx].size(split_dim)
if next_cache_size != 0:
if next_cache_size > x[idx].size(split_dim):
raise ValueError(
f"SeedVR2 VAE cache size {next_cache_size} exceeds split size {x[idx].size(split_dim)}."
)
next_cache = (
x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim)
x[idx].transpose(0, split_dim)[-next_cache_size:].transpose(0, split_dim)
)
x[idx] = self.memory_limit_conv(
@ -627,7 +578,8 @@ class InflatedCausalConv3d(ops.Conv3d):
memory_state: MemoryState = MemoryState.UNSET,
memory_cache = None,
) -> Tensor:
assert memory_state != MemoryState.UNSET
if memory_state == MemoryState.UNSET:
raise ValueError("SeedVR2 VAE convolution requires an explicit MemoryState.")
if memory_cache is None:
memory_cache = {}
if memory_state != MemoryState.ACTIVE:
@ -677,9 +629,8 @@ class InflatedCausalConv3d(ops.Conv3d):
input, cache_size=cache_size, memory=memory, times=self.temporal_padding * 2
)
# Single GPU inference - simplified memory management
if (
memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing
memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE]
and cache_size != 0
):
if cache_size > input[-1].size(2) and cache is not None and len(input) == 1:
@ -690,7 +641,6 @@ class InflatedCausalConv3d(ops.Conv3d):
padding = tuple(x for x in reversed(self.padding) for _ in range(2))
for i in range(len(input)):
# Prepare cache for next input slice.
next_cache = None
cache_size = 0
if i < len(input) - 1:
@ -700,17 +650,16 @@ class InflatedCausalConv3d(ops.Conv3d):
if cache_size > input[i].size(2) and cache is not None:
input[i] = torch.cat([cache, input[i]], dim=2)
cache = None
assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}"
if cache_size > input[i].size(2):
raise ValueError(f"SeedVR2 VAE cache size {cache_size} exceeds input length {input[i].size(2)}.")
next_cache = input[i][:, :, -cache_size:]
# Conv forward for this input slice.
input[i] = self.memory_limit_conv(
input[i],
padding=padding,
prev_cache=cache
)
# Update cache.
cache = next_cache
return input[0] if squeeze_out else input
@ -729,7 +678,6 @@ class Upsample3D(nn.Module):
inflation_mode = "tail",
temporal_up: bool = False,
spatial_up: bool = True,
**kwargs,
):
super().__init__()
self.channels = channels
@ -760,9 +708,9 @@ class Upsample3D(nn.Module):
hidden_states: torch.FloatTensor,
memory_state=None,
memory_cache=None,
**kwargs,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if hidden_states.shape[1] != self.channels:
raise ValueError(f"SeedVR2 upsample expected {self.channels} channels, got {hidden_states.shape[1]}.")
hidden_states = self.upscale_conv(hidden_states)
b, channels, f, h, w = hidden_states.shape
@ -785,8 +733,6 @@ class Upsample3D(nn.Module):
class Downsample3D(nn.Module):
"""A 3D downsampling layer with an optional convolution."""
def __init__(
self,
channels,
@ -794,7 +740,6 @@ class Downsample3D(nn.Module):
inflation_mode = "tail",
spatial_down: bool = False,
temporal_down: bool = False,
**kwargs,
):
super().__init__()
self.channels = channels
@ -823,20 +768,17 @@ class Downsample3D(nn.Module):
hidden_states: torch.FloatTensor,
memory_state = None,
memory_cache = None,
**kwargs,
) -> torch.FloatTensor:
assert hidden_states.shape[1] == self.channels
if hasattr(self, "norm") and self.norm is not None:
# [Overridden] change to causal norm.
hidden_states = causal_norm_wrapper(self.norm, hidden_states)
if hidden_states.shape[1] != self.channels:
raise ValueError(f"SeedVR2 downsample expected {self.channels} channels, got {hidden_states.shape[1]}.")
if self.spatial_down:
pad = (0, 1, 0, 1)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
assert hidden_states.shape[1] == self.channels
if hidden_states.shape[1] != self.channels:
raise ValueError(f"SeedVR2 downsample expected {self.channels} channels after padding, got {hidden_states.shape[1]}.")
hidden_states = self.conv(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
@ -848,7 +790,6 @@ class ResnetBlock3D(nn.Module):
self,
in_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
temb_channels: int = 512,
groups: int = 32,
groups_out: Optional[int] = None,
@ -857,7 +798,6 @@ class ResnetBlock3D(nn.Module):
skip_time_act: bool = False,
inflation_mode = "tail",
time_receptive_field: _receptive_field_t = "half",
**kwargs,
):
super().__init__()
self.in_channels = in_channels
@ -866,15 +806,14 @@ class ResnetBlock3D(nn.Module):
self.skip_time_act = skip_time_act
self.nonlinearity = nn.SiLU()
if temb_channels is not None:
self.time_emb_proj = ops.Linear(temb_channels, out_channels)
self.time_emb_proj = ops.Linear(temb_channels, self.out_channels)
else:
self.time_emb_proj = None
self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
if groups_out is None:
groups_out = groups
self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.use_in_shortcut = self.in_channels != out_channels
self.dropout = torch.nn.Dropout(dropout)
self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=self.out_channels, eps=eps, affine=True)
self.use_in_shortcut = self.in_channels != self.out_channels
self.conv1 = InflatedCausalConv3d(
self.in_channels,
self.out_channels,
@ -886,7 +825,7 @@ class ResnetBlock3D(nn.Module):
self.conv2 = InflatedCausalConv3d(
self.out_channels,
out_channels,
self.out_channels,
kernel_size=3,
stride=1,
padding=1,
@ -897,7 +836,7 @@ class ResnetBlock3D(nn.Module):
if self.use_in_shortcut:
self.conv_shortcut = InflatedCausalConv3d(
self.in_channels,
out_channels,
self.out_channels,
kernel_size=1,
stride=1,
padding=0,
@ -905,9 +844,7 @@ class ResnetBlock3D(nn.Module):
inflation_mode=inflation_mode,
)
def forward(
self, input_tensor, temb, memory_state = None, memory_cache = None, **kwargs
):
def forward(self, input_tensor, temb, memory_state = None, memory_cache = None):
hidden_states = input_tensor
hidden_states = causal_norm_wrapper(self.norm1, hidden_states)
@ -928,7 +865,6 @@ class ResnetBlock3D(nn.Module):
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, memory_state=memory_state, memory_cache=memory_cache)
if self.conv_shortcut is not None:
@ -944,7 +880,6 @@ class DownEncoderBlock3D(nn.Module):
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_groups: int = 32,
@ -957,28 +892,23 @@ class DownEncoderBlock3D(nn.Module):
):
super().__init__()
resnets = []
temporal_modules = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
# [Override] Replace module.
ResnetBlock3D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=None,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
output_scale_factor=output_scale_factor,
inflation_mode=inflation_mode,
time_receptive_field=time_receptive_field,
)
)
temporal_modules.append(nn.Identity())
self.resnets = nn.ModuleList(resnets)
self.temporal_modules = nn.ModuleList(temporal_modules)
if add_downsample:
self.downsamplers = nn.ModuleList(
@ -1000,11 +930,9 @@ class DownEncoderBlock3D(nn.Module):
hidden_states: torch.FloatTensor,
memory_state = None,
memory_cache = None,
**kwargs,
) -> torch.FloatTensor:
for resnet, temporal in zip(self.resnets, self.temporal_modules):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache)
hidden_states = temporal(hidden_states)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
@ -1018,7 +946,6 @@ class UpDecoderBlock3D(nn.Module):
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_groups: int = 32,
@ -1032,33 +959,26 @@ class UpDecoderBlock3D(nn.Module):
):
super().__init__()
resnets = []
temporal_modules = []
for i in range(num_layers):
input_channels = in_channels if i == 0 else out_channels
resnets.append(
# [Override] Replace module.
ResnetBlock3D(
in_channels=input_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
output_scale_factor=output_scale_factor,
inflation_mode=inflation_mode,
time_receptive_field=time_receptive_field,
)
)
temporal_modules.append(nn.Identity())
self.resnets = nn.ModuleList(resnets)
self.temporal_modules = nn.ModuleList(temporal_modules)
if add_upsample:
# [Override] Replace module & use learnable upsample
self.upsamplers = nn.ModuleList(
[
Upsample3D(
@ -1080,9 +1000,8 @@ class UpDecoderBlock3D(nn.Module):
memory_state=None,
memory_cache=None,
) -> torch.FloatTensor:
for resnet, temporal in zip(self.resnets, self.temporal_modules):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache)
hidden_states = temporal(hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@ -1096,7 +1015,6 @@ class UNetMidBlock3D(nn.Module):
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default", # default, spatial
@ -1111,16 +1029,13 @@ class UNetMidBlock3D(nn.Module):
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.add_attention = add_attention
# there is always at least one resnet
resnets = [
# [Override] Replace module.
ResnetBlock3D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
output_scale_factor=output_scale_factor,
inflation_mode=inflation_mode,
time_receptive_field=time_receptive_field,
@ -1148,7 +1063,6 @@ class UNetMidBlock3D(nn.Module):
),
residual_connection=True,
bias=True,
upcast_softmax=True,
)
)
else:
@ -1161,7 +1075,6 @@ class UNetMidBlock3D(nn.Module):
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
output_scale_factor=output_scale_factor,
inflation_mode=inflation_mode,
time_receptive_field=time_receptive_field,
@ -1172,7 +1085,7 @@ class UNetMidBlock3D(nn.Module):
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, memory_state=None, memory_cache=None):
video_length, frame_height, frame_width = hidden_states.size()[-3:]
video_length = hidden_states.size(2)
hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state, memory_cache=memory_cache)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if attn is not None:
@ -1195,7 +1108,6 @@ class Encoder3D(nn.Module):
layers_per_block: int = 2,
norm_num_groups: int = 32,
mid_block_add_attention=True,
# [Override] add temporal down num
temporal_down_num: int = 2,
inflation_mode = "tail",
time_receptive_field: _receptive_field_t = "half",
@ -1216,17 +1128,15 @@ class Encoder3D(nn.Module):
self.mid_block = None
self.down_blocks = nn.ModuleList([])
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
# [Override] to support temporal down block design
is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1
# Note: take the last ones
assert down_block_type == "DownEncoderBlock3D"
if down_block_type != "DownEncoderBlock3D":
raise ValueError(f"SeedVR2 encoder only supports DownEncoderBlock3D, got {down_block_type}.")
down_block = DownEncoderBlock3D(
num_layers=self.layers_per_block,
@ -1242,7 +1152,6 @@ class Encoder3D(nn.Module):
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock3D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
@ -1256,7 +1165,6 @@ class Encoder3D(nn.Module):
time_receptive_field=time_receptive_field,
)
# out
self.conv_norm_out = ops.GroupNorm(
num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
)
@ -1274,17 +1182,13 @@ class Encoder3D(nn.Module):
memory_state = None,
memory_cache = None,
) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = sample.to(next(self.parameters()).device)
sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache)
# down
for down_block in self.down_blocks:
sample = down_block(sample, memory_state=memory_state, memory_cache=memory_cache)
# middle
sample = self.mid_block(sample, memory_state=memory_state, memory_cache=memory_cache)
# post-process
sample = causal_norm_wrapper(self.conv_norm_out, sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache)
@ -1303,7 +1207,6 @@ class Decoder3D(nn.Module):
layers_per_block: int = 2,
norm_num_groups: int = 32,
mid_block_add_attention=True,
# [Override] add temporal up block
inflation_mode = "tail",
time_receptive_field: _receptive_field_t = "half",
temporal_up_num: int = 2,
@ -1326,7 +1229,6 @@ class Decoder3D(nn.Module):
temb_channels = None
# mid
self.mid_block = UNetMidBlock3D(
in_channels=block_out_channels[-1],
resnet_eps=1e-6,
@ -1340,7 +1242,6 @@ class Decoder3D(nn.Module):
time_receptive_field=time_receptive_field,
)
# up
reversed_block_out_channels = list(reversed(block_out_channels))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
@ -1349,7 +1250,8 @@ class Decoder3D(nn.Module):
is_final_block = i == len(block_out_channels) - 1
is_temporal_up_block = i < self.temporal_up_num
assert up_block_type == "UpDecoderBlock3D"
if up_block_type != "UpDecoderBlock3D":
raise ValueError(f"SeedVR2 decoder only supports UpDecoderBlock3D, got {up_block_type}.")
up_block = UpDecoderBlock3D(
num_layers=self.layers_per_block + 1,
in_channels=prev_output_channel,
@ -1365,7 +1267,6 @@ class Decoder3D(nn.Module):
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
self.conv_norm_out = ops.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
)
@ -1375,7 +1276,6 @@ class Decoder3D(nn.Module):
)
# Note: Just copy from Decoder.
def forward(
self,
sample: torch.FloatTensor,
@ -1388,15 +1288,12 @@ class Decoder3D(nn.Module):
sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
# middle
sample = self.mid_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache)
sample = sample.to(upscale_dtype)
# up
for up_block in self.up_blocks:
sample = up_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache)
# post-process
sample = causal_norm_wrapper(self.conv_norm_out, sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache)
@ -1415,8 +1312,6 @@ class VideoAutoencoderKL(nn.Module):
inflation_mode = "pad",
time_receptive_field: _receptive_field_t = "full",
slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN,
*args,
**kwargs,
):
self.slicing_sample_min_size = slicing_sample_min_size
self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)
@ -1425,7 +1320,6 @@ class VideoAutoencoderKL(nn.Module):
up_block_types = ("UpDecoderBlock3D",) * 4
super().__init__()
# pass init params to Encoder
self.encoder = Encoder3D(
in_channels=in_channels,
out_channels=latent_channels,
@ -1433,13 +1327,11 @@ class VideoAutoencoderKL(nn.Module):
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
# [Override] add temporal_down_num parameter
temporal_down_num=temporal_scale_num,
inflation_mode=inflation_mode,
time_receptive_field=time_receptive_field,
)
# pass init params to Decoder
self.decoder = Decoder3D(
in_channels=latent_channels,
out_channels=out_channels,
@ -1447,7 +1339,6 @@ class VideoAutoencoderKL(nn.Module):
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
# [Override] add temporal_up_num parameter
temporal_up_num=temporal_scale_num,
inflation_mode=inflation_mode,
time_receptive_field=time_receptive_field,
@ -1489,11 +1380,10 @@ class VideoAutoencoderKL(nn.Module):
return output.to(z.device)
def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
sp_size =1
if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size:
memory_cache = {}
split_size = max(
self.slicing_sample_min_size * sp_size,
self.slicing_sample_min_size,
getattr(self, "temporal_downsample_factor", 1),
)
x_slices = list(x[:, :, 1:].split(split_size=split_size, dim=2))
@ -1518,10 +1408,9 @@ class VideoAutoencoderKL(nn.Module):
return self._encode(x)
def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
sp_size = 1
if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size:
memory_cache = {}
z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size, dim=2)
decoded_slices = [
self._decode(
torch.cat((z[:, :, :1], z_slices[0]), dim=2),
@ -1538,33 +1427,28 @@ class VideoAutoencoderKL(nn.Module):
else:
return self._decode(z)
def forward(
self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs
):
# x: [b c t h w]
def forward(self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all"):
def _unwrap(value):
return value[0] if isinstance(value, tuple) else value
if mode == "encode":
return _unwrap(self.encode(x))
elif mode == "decode":
if mode == "decode":
return _unwrap(self.decode_(x))
else:
if mode == "all":
latent = _unwrap(self.encode(x))
return _unwrap(self.decode_(latent))
raise ValueError(f"Unknown SeedVR2 VAE forward mode: {mode}")
class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
def __init__(
self,
*args,
spatial_downsample_factor = 8,
temporal_downsample_factor = 4,
**kwargs,
):
self.spatial_downsample_factor = spatial_downsample_factor
self.temporal_downsample_factor = temporal_downsample_factor
self.enable_tiling = False
super().__init__(*args, **kwargs)
super().__init__()
self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB)
def forward(self, x: torch.FloatTensor):
@ -1581,7 +1465,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
z = p.squeeze(2)
return z, p
def encode(self, x, orig_dims=None):
def encode(self, x):
z, _ = self._encode_with_raw_latent(x)
return z
@ -1594,26 +1478,27 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
)
if z.ndim == 5:
b, c, t_latent, h, w = z.shape
if c != 16:
_, c, _, _, _ = z.shape
if c != SEEDVR2_LATENT_CHANNELS:
raise RuntimeError(
"SeedVR2 VideoAutoencoderKLWrapper.decode: 5-D latent input must "
f"have 16 channels; got shape {tuple(z.shape)}."
f"have {SEEDVR2_LATENT_CHANNELS} channels; got shape {tuple(z.shape)}."
)
latent = z
elif z.ndim == 4:
b, tc, h, w = z.shape
if tc % 16 != 0:
if tc % SEEDVR2_LATENT_CHANNELS != 0:
raise RuntimeError(
"SeedVR2 VideoAutoencoderKLWrapper.decode: 4-D latent input must "
"use collapsed channel layout (B, 16*T, H, W); "
f"use collapsed channel layout (B, {SEEDVR2_LATENT_CHANNELS}*T, H, W); "
f"got shape {tuple(z.shape)}."
)
latent = z.reshape(b, 16, -1, h, w)
latent = z.reshape(b, SEEDVR2_LATENT_CHANNELS, -1, h, w)
else:
raise RuntimeError(
"SeedVR2 VideoAutoencoderKLWrapper.decode: latent input must be "
"4-D collapsed (B, 16*T, H, W) or 5-D (B, 16, T, H, W); "
f"4-D collapsed (B, {SEEDVR2_LATENT_CHANNELS}*T, H, W) or "
f"5-D (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
f"got shape {tuple(z.shape)}."
)
scale = BYTEDANCE_VAE_SCALING_FACTOR
@ -1621,10 +1506,11 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
latent = latent / scale + shift
self.device = latent.device
self.enable_tiling = seedvr2_tiling.get("enable_tiling", False)
enable_tiling = seedvr2_tiling.get("enable_tiling", False)
if self.enable_tiling:
if enable_tiling:
decode_seedvr2_args = dict(seedvr2_tiling)
decode_seedvr2_args.pop("enable_tiling", None)
tile_h, tile_w = decode_seedvr2_args.get("tile_size", (512, 512))
ov_h, ov_w = decode_seedvr2_args.get("tile_overlap", (64, 64))
decode_seedvr2_args["tile_overlap"] = (
@ -1641,7 +1527,6 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
else:
x = super().decode_(latent)
# ensure even dims for save video
h, w = x.shape[-2:]
w2 = w - (w % 2)
h2 = h - (h % 2)
@ -1693,7 +1578,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
if samples.ndim == 4:
samples = samples.unsqueeze(2)
samples = samples.contiguous()
samples = samples * 0.9152
samples = samples * BYTEDANCE_VAE_SCALING_FACTOR
return samples
def comfy_memory_used_decode(self, shape):
@ -1707,15 +1592,15 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
# plus int64 sort indices dominate peak memory, not the VAE weight dtype.
if len(shape) == 5:
candidates = []
if shape[1] == 16:
if shape[1] == SEEDVR2_LATENT_CHANNELS:
candidates.append((shape[2], shape[3], shape[4]))
if shape[-1] == 16:
if shape[-1] == SEEDVR2_LATENT_CHANNELS:
candidates.append((shape[1], shape[2], shape[3]))
if len(candidates) == 0:
candidates.append((shape[2], shape[3], shape[4]))
pixels = max(output_pixels(*candidate) for candidate in candidates)
elif len(shape) == 4:
latent_t = max(1, (shape[1] + 15) // 16)
latent_t = max(1, (shape[1] + SEEDVR2_LATENT_CHANNELS - 1) // SEEDVR2_LATENT_CHANNELS)
pixels = output_pixels(latent_t, shape[2], shape[3])
else:
pixels = output_pixels(1, shape[-2], shape[-1])

View File

@ -933,7 +933,8 @@ class HunyuanDiT(BaseModel):
class SeedVR2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.seedvr.model.NaDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
condition = kwargs.get("condition", None)

View File

@ -598,43 +598,34 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config
if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b
seedvr2_7b_separate_key = "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)
if seedvr2_7b_separate_key in state_dict_keys and state_dict[seedvr2_7b_separate_key].shape[1] == 3072: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
# 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.``
# submodules) at EVERY block — verified by inspecting the 7B
# state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means
# ``MMModule.shared_weights=False``). Native NaDiT computes
# per-block ``shared_weights = not (i < mm_layers)``, so to keep
# every block non-shared we set ``mm_layers = num_layers``.
# Without this, blocks at index >= mm_layers (default 10) try to
# load ``blocks.N.*.all.*`` keys that don't exist in the file,
# silently miss-load → all-black output.
# This checkpoint uses separate vid/txt MMModule keys in every block.
dit_config["mm_layers"] = 36
dit_config["norm_eps"] = 1e-5
dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "normal"
return dit_config
elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
if "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
# This checkpoint layout carries shared ``all.`` MMModule keys.
# Preserve the historical split: the initial blocks use separate
# vid/txt modules, later blocks use shared modules.
# This checkpoint uses shared all.* MMModule keys after the initial blocks.
dit_config["mm_layers"] = 10
dit_config["norm_eps"] = 1e-5
dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "swiglu"
return dit_config
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
if "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 2560
@ -1150,8 +1141,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config["heatmap_head"] = True
return unet_config
def normalize_seedvr2_unet_config(unet_config):
if unet_config.get("image_model") != "seedvr2" or "num_heads" not in unet_config:
return unet_config
unet_config = dict(unet_config)
num_heads = unet_config.pop("num_heads")
if "heads" in unet_config and unet_config["heads"] != num_heads:
raise ValueError(
f"SeedVR2 config has conflicting heads={unet_config['heads']} and num_heads={num_heads}."
)
unet_config["heads"] = num_heads
return unet_config
def model_config_from_unet_config(unet_config, state_dict=None, unet_key_prefix=""):
unet_config = normalize_seedvr2_unet_config(unet_config)
for model_config in comfy.supported_models.models:
if model_config.matches(unet_config, state_dict, unet_key_prefix=unet_key_prefix):
return model_config(unet_config)

View File

@ -472,8 +472,7 @@ class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd
if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
if metadata is None or metadata.get("keep_diffusers_format") != "true":
sd = diffusers_convert.convert_vae_state_dict(sd)
sd = diffusers_convert.convert_vae_state_dict(sd)
if model_management.is_amd():
VAE_KL_MEM_RATIO = 2.73
@ -549,7 +548,7 @@ class VAE:
self.latent_channels = 16
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
self.latent_channels = 16
self.latent_channels = comfy.ldm.seedvr.vae.SEEDVR2_LATENT_CHANNELS
self.latent_dim = 3
self.disable_offload = True
self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape)
@ -1074,6 +1073,20 @@ class VAE:
out = self.first_stage_model.encode_tiled(x, **kwargs)
return out.to(device=self.output_device, dtype=self.vae_output_dtype())
def _owned_tiled_args(self, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
args = {}
if tile_x is not None:
args["tile_x"] = tile_x
if tile_y is not None:
args["tile_y"] = tile_y
if overlap is not None:
args["overlap"] = overlap
if tile_t is not None:
args["tile_t"] = tile_t
if overlap_t is not None:
args["overlap_t"] = overlap_t
return args
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
@ -1153,18 +1166,7 @@ class VAE:
with model_management.cuda_device_context(self.device):
if self.handles_tiling and dims in (2, 3):
tiled_args = {}
if tile_x is not None:
tiled_args["tile_x"] = tile_x
if tile_y is not None:
tiled_args["tile_y"] = tile_y
if overlap is not None:
tiled_args["overlap"] = overlap
if tile_t is not None:
tiled_args["tile_t"] = tile_t
if overlap_t is not None:
tiled_args["overlap_t"] = overlap_t
output = self._decode_tiled_owned(samples, **tiled_args)
output = self._decode_tiled_owned(samples, **self._owned_tiled_args(tile_x, tile_y, overlap, tile_t, overlap_t))
elif dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
@ -1269,18 +1271,7 @@ class VAE:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
if self.handles_tiling:
tiled_args = {}
if tile_x is not None:
tiled_args["tile_x"] = tile_x
if tile_y is not None:
tiled_args["tile_y"] = tile_y
if overlap is not None:
tiled_args["overlap"] = overlap
if tile_t is not None:
tiled_args["tile_t"] = tile_t
if overlap_t is not None:
tiled_args["overlap_t"] = overlap_t
samples = self._encode_tiled_owned(pixel_samples, **tiled_args)
samples = self._encode_tiled_owned(pixel_samples, **self._owned_tiled_args(tile_x, tile_y, overlap, tile_t, overlap_t))
else:
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
@ -1850,7 +1841,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)

View File

@ -1688,6 +1688,7 @@ class SeedVR2(supported_models_base.BASE):
unet_config = {
"image_model": "seedvr2"
}
unet_extra_config = {}
required_keys = {
"{}positive_conditioning",
"{}negative_conditioning",

View File

@ -19,21 +19,14 @@ from comfy.ldm.seedvr.constants import (
)
from torchvision.transforms import functional as TVF
from torchvision.transforms import Lambda
from torchvision.transforms.functional import InterpolationMode
_SEEDVR2_INVALID_MODEL_MSG_PREFIX = (
"SeedVR2Conditioning: model object does not match expected SeedVR2 structure"
)
# Private sentinel for getattr default: distinguishes "attribute missing"
# from "attribute present but None" so the failure message is accurate.
_SEEDVR2_INVALID_MODEL_MSG_PREFIX = "SeedVR2Conditioning: model object does not match expected SeedVR2 structure"
_ATTR_MISSING = object()
def _resolve_seedvr2_diffusion_model(model):
"""Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message."""
inner = getattr(model, "model", _ATTR_MISSING)
if inner is _ATTR_MISSING:
raise RuntimeError(
@ -59,15 +52,7 @@ def _resolve_seedvr2_diffusion_model(model):
return diffusion_model
def get_conditions(latent, latent_blur):
t, h, w, c = latent.shape
cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
cond[:, ..., :-1] = latent_blur[:]
cond[:, ..., -1:] = 1.0
return cond
def div_pad(image, factor):
height_factor, width_factor = factor
height, width = image.shape[-2:]
@ -77,31 +62,25 @@ def div_pad(image, factor):
if pad_height == 0 and pad_width == 0:
return image
if isinstance(image, torch.Tensor):
padding = (0, pad_width, 0, pad_height)
image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0)
return image
padding = (0, pad_width, 0, pad_height)
return torch.nn.functional.pad(image, padding, mode='constant', value=0.0)
def cut_videos(videos):
t = videos.size(1)
if t < 1:
raise ValueError("SeedVR2Preprocess expected at least one frame.")
if t == 1:
return videos
if t <= 4 :
padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1)
padding = torch.cat(padding, dim=1)
videos = torch.cat([videos, padding], dim=1)
return videos
if (t - 1) % (4) == 0:
return videos
else:
padding = [videos[:, -1].unsqueeze(1)] * (
4 - ((t - 1) % (4))
)
padding = torch.cat(padding, dim=1)
videos = torch.cat([videos, padding], dim=1)
assert (videos.size(1) - 1) % (4) == 0
if t <= 4:
padding = videos[:, -1:].repeat(1, 4 - t + 1, 1, 1, 1)
return torch.cat([videos, padding], dim=1)
if (t - 1) % 4 == 0:
return videos
padding = videos[:, -1:].repeat(1, 4 - ((t - 1) % 4), 1, 1, 1)
videos = torch.cat([videos, padding], dim=1)
if (videos.size(1) - 1) % 4 != 0:
raise ValueError(f"SeedVR2Preprocess failed to pad video length to 4n+1; got {videos.size(1)} frames.")
return videos
def _seedvr2_input_shorter_edge(images, node_name):
if images.dim() == 4:
@ -136,8 +115,7 @@ def _seedvr2_pad(images, upscaled_shorter_edge, node_name):
b, t, c, h, w = images.shape
images = images.reshape(b * t, c, h, w)
clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
images = clip(images)
images = torch.clamp(images, 0.0, 1.0)
images = div_pad(images, (16, 16))
_, _, new_h, new_w = images.shape
@ -295,7 +273,6 @@ class SeedVR2PostProcessing(io.ComfyNode):
def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method):
chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method)
while True:
next_chunk_size = None
try:
return cls._run_color_transfer_chunks(
decoded_flat, reference_flat, output_device, color_correction_method, chunk_size,
@ -307,9 +284,7 @@ class SeedVR2PostProcessing(io.ComfyNode):
"SeedVR2PostProcessing: color correction OOM at one frame; "
f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}."
) from e
next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR)
chunk_size = next_chunk_size
chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR)
@classmethod
def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size):
@ -392,10 +367,8 @@ class SeedVR2Conditioning(io.ComfyNode):
io.Latent.Input("vae_conditioning", display_name="latent"),
],
outputs=[
io.Model.Output(display_name="model", tooltip="The SeedVR2 model, passed through."),
io.Conditioning.Output(display_name="positive", tooltip="The positive conditioning for sampling."),
io.Conditioning.Output(display_name="negative", tooltip="The negative conditioning for sampling."),
io.Latent.Output(display_name="latent", tooltip="The latent to denoise."),
],
)
@ -408,29 +381,30 @@ class SeedVR2Conditioning(io.ComfyNode):
"SeedVR2Conditioning expects a 5-D VAE latent in Comfy "
f"channel-first layout; got shape {tuple(vae_conditioning.shape)}."
)
if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS:
if vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS:
if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS:
raise ValueError(
"SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy "
f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
f"got channel-last shape {tuple(vae_conditioning.shape)}."
)
raise ValueError(
"SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy "
f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
f"got channel-last shape {tuple(vae_conditioning.shape)}."
"SeedVR2Conditioning expects SeedVR2 VAE latents with "
f"{SEEDVR2_LATENT_CHANNELS} channels; got shape {tuple(vae_conditioning.shape)}."
)
vae_conditioning = vae_conditioning.movedim(1, -1).contiguous()
model_patcher = model
model = _resolve_seedvr2_diffusion_model(model_patcher)
model = _resolve_seedvr2_diffusion_model(model)
pos_cond = model.positive_conditioning
neg_cond = model.negative_conditioning
condition = torch.stack([get_conditions(c, c) for c in vae_conditioning])
mask = vae_conditioning.new_ones(vae_conditioning.shape[:-1] + (1,))
condition = torch.cat((vae_conditioning, mask), dim=-1)
condition = condition.movedim(-1, 1)
latent = vae_conditioning.movedim(-1, 1)
latent = latent.reshape(latent.shape[0], latent.shape[1] * latent.shape[2], latent.shape[3], latent.shape[4])
condition = condition.reshape(condition.shape[0], condition.shape[1] * condition.shape[2], condition.shape[3], condition.shape[4])
negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
return io.NodeOutput(model_patcher, positive, negative, {"samples": latent})
return io.NodeOutput(positive, negative)
class SeedVRExtension(ComfyExtension):
@override

View File

@ -1,20 +1,15 @@
"""Consolidated SeedVR2 conditioning and refactor regression tests.
Merges the prior test_seedvr2_refactor_nodes.py and
test_seedvr_conditioning_hardening.py modules. Refactor tests use the
top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests
use _import_nodes_seedvr_isolated() for sys.modules isolation when
mocking comfy.model_management.
"""
"""SeedVR2 conditioning node regression tests."""
import importlib
import sys
from unittest.mock import MagicMock
import pytest
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
from comfy.ldm.seedvr.constants import SEEDVR2_LATENT_CHANNELS
if not torch.cuda.is_available():
cli_args.cpu = True
@ -79,21 +74,18 @@ def _import_nodes_seedvr_isolated():
class _Rope(nn.Module):
"""Minimal RoPE stub exposing a `freqs` parameter."""
def __init__(self):
super().__init__()
self.freqs = nn.Parameter(torch.zeros(4))
class _Block(nn.Module):
"""Minimal transformer block stub holding a `_Rope`."""
def __init__(self):
super().__init__()
self.rope = _Rope()
class _DiffusionModel(nn.Module):
"""Stub diffusion model with N blocks and pos/neg conditioning buffers."""
def __init__(self, n_blocks=3, conditioning_dtype=torch.float32):
super().__init__()
self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])
@ -102,18 +94,16 @@ class _DiffusionModel(nn.Module):
class _ModelInner:
"""Inner model wrapper exposing `.diffusion_model`."""
def __init__(self, diffusion_model):
self.diffusion_model = diffusion_model
class _ModelPatcher:
"""ModelPatcher stub exposing `.model._ModelInner`."""
def __init__(self, diffusion_model):
self.model = _ModelInner(diffusion_model)
def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
def test_seedvr2_conditioning_schema_exposes_conditioning_outputs():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
schema = nodes_seedvr.SeedVR2Conditioning.define_schema()
@ -123,37 +113,50 @@ def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
]
assert schema.inputs[1].display_name == "latent"
assert [output.display_name for output in schema.outputs] == [
"model",
"positive",
"negative",
"latent",
]
finally:
restore()
def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
def test_seedvr2_conditioning_rejects_wrong_latent_channels():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
patcher = _ModelPatcher(_DiffusionModel())
vae_conditioning = {"samples": torch.zeros(1, 8, 2, 2, 2)}
with pytest.raises(ValueError, match=f"{SEEDVR2_LATENT_CHANNELS} channels"):
nodes_seedvr.SeedVR2Conditioning.execute(patcher, vae_conditioning)
finally:
restore()
def test_seedvr2_conditioning_returns_conditioning_deterministically():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
diffusion_model = _DiffusionModel()
patcher = _ModelPatcher(diffusion_model)
samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2)
samples = torch.arange(
1,
1 + SEEDVR2_LATENT_CHANNELS * 3 * 2 * 2,
dtype=torch.float32,
).reshape(1, SEEDVR2_LATENT_CHANNELS, 3, 2, 2)
vae_conditioning = {"samples": samples}
_, first_positive, first_negative, first_latent = (
first_positive, first_negative = (
nodes_seedvr.SeedVR2Conditioning.execute(
patcher,
vae_conditioning,
)
)
_, second_positive, second_negative, second_latent = (
second_positive, second_negative = (
nodes_seedvr.SeedVR2Conditioning.execute(
patcher,
vae_conditioning,
)
)
expected_latent = samples.reshape(1, 6, 2, 2)
channel_last = samples.movedim(1, -1).contiguous()
expected_condition = torch.cat(
[
@ -161,10 +164,8 @@ def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
torch.ones((*channel_last.shape[:-1], 1)),
],
dim=-1,
).movedim(-1, 1).reshape(1, 9, 2, 2)
).movedim(-1, 1)
assert torch.equal(first_latent["samples"], expected_latent)
assert torch.equal(second_latent["samples"], expected_latent)
assert torch.equal(
first_positive[0][1]["condition"],
expected_condition,

View File

@ -201,6 +201,17 @@ class TestModelDetection:
del sd["positive_conditioning"]
assert model_config_from_unet_config(unet_config, sd) is None
def test_seedvr2_model_match_normalizes_num_heads(self):
sd = _make_seedvr2_7b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
unet_config["num_heads"] = unet_config.pop("heads")
model_config = model_config_from_unet_config(unet_config, sd)
assert type(model_config).__name__ == "SeedVR2"
assert model_config.unet_config["heads"] == 24
assert "num_heads" not in model_config.unet_config
def test_seedvr2_model_match_accepts_full_checkpoint_prefix(self):
sd = _add_model_diffusion_prefix(_make_seedvr2_7b_shared_mm_sd())

View File

@ -1,22 +1,6 @@
"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must
honor the actual tensor/tuple return contract of ``encode()`` and
``decode_()`` and must NOT dereference diffusers-style ``.latent_dist``
or ``.sample`` attributes on those returns.
The pre-fix body raised ``AttributeError: 'Tensor' object has no
attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and
``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'``
for ``mode == "decode"`` (the class only defines ``decode_`` with a
trailing underscore). The post-fix body unwraps the optional one-element
tuple shape that ``return_dict=False`` produces and returns the tensor
directly.
Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses
the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and
overrides ``encode``/``decode_`` with known tensors so the contract can
be probed without loading any real VAE weights.
"""
"""Regression tests for the SeedVR2 VAE forward return contract."""
import pytest
import torch
import torch.nn as nn
@ -25,13 +9,13 @@ from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402
from comfy.ldm.seedvr.vae import SEEDVR2_LATENT_CHANNELS, VideoAutoencoderKL # noqa: E402
_LATENT_SHAPE = (1, 16, 2, 2, 2)
_LATENT_SHAPE = (1, SEEDVR2_LATENT_CHANNELS, 2, 2, 2)
_DECODED_SHAPE = (1, 3, 5, 16, 16)
_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16)
_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2)
_INPUT_DECODE_SHAPE = _LATENT_SHAPE
class _StubVAE(VideoAutoencoderKL):
@ -64,8 +48,6 @@ def test_forward_decode_returns_tensor():
class _TupleReturningStubVAE(VideoAutoencoderKL):
"""Stub whose ``encode``/``decode_`` return the ``(tensor,)`` tuple of ``return_dict=False``, exercising the unwrap branch of ``VideoAutoencoderKL.forward``."""
def __init__(self):
nn.Module.__init__(self)
self._encode_tensor = torch.zeros(*_LATENT_SHAPE)
@ -84,3 +66,9 @@ def test_forward_all_unwraps_one_tuple_at_each_step():
result = vae.forward(x, mode="all")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_DECODED_SHAPE)
def test_forward_rejects_unknown_mode():
vae = _StubVAE()
with pytest.raises(ValueError, match="Unknown SeedVR2 VAE forward mode"):
vae.forward(torch.zeros(*_INPUT_ENCODE_SHAPE), mode="bogus")

View File

@ -41,8 +41,9 @@ def test_seedvr2_text_conditioning_accepts_cfg1_single_branch():
def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer():
wrapper = seedvr_vae.VideoAutoencoderKLWrapper.__new__(seedvr_vae.VideoAutoencoderKLWrapper)
estimate = wrapper.comfy_memory_used_decode((1, 16, 26, 120, 160))
old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2
latent_channels = seedvr_vae.SEEDVR2_LATENT_CHANNELS
estimate = wrapper.comfy_memory_used_decode((1, latent_channels, 26, 120, 160))
old_estimate = latent_channels * 120 * 160 * (4 * 8 * 8) * 2
assert estimate == 101 * 960 * 1280 * 160
assert estimate > 15 * 1024 ** 3

View File

@ -1,16 +1,4 @@
"""Consolidated SeedVR2 internals regression tests.
Sources (all merged verbatim, helper names disambiguated where colliding):
* GroupNorm limit gate causal_norm_wrapper at vae.py:509 must compare
memory_occupy against get_norm_limit(), not float('inf').
* SeedVR2 variable-length attention split-loop contract.
Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
comfy.ldm.modules.attention transitively pull in comfy.model_management,
which probes torch.cuda.current_device() at import time unless args.cpu is
set first.
"""
"""SeedVR2 internals regression tests."""
from __future__ import annotations
@ -35,10 +23,6 @@ from comfy.ldm.seedvr.vae import ( # noqa: E402
from comfy.ldm.seedvr.attention import var_attention_optimized_split # noqa: E402
# ---------------------------------------------------------------------------
# GroupNorm limit tests (test_seedvr_groupnorm_limit.py)
# ---------------------------------------------------------------------------
_NUM_CHANNELS = 8
_NUM_GROUPS = 4
_TENSOR_SHAPE = (1, 8, 2, 4, 4)
@ -89,10 +73,6 @@ def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
set_norm_limit(None)
# ---------------------------------------------------------------------------
# SeedVR2 var_attention split-loop tests
# ---------------------------------------------------------------------------
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
dim = 8
heads = 2
@ -140,18 +120,8 @@ def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypa
assert call["heads"] == heads
assert call["skip_reshape"] is True
assert call["skip_output_reshape"] is True
torch.testing.assert_close(
call["cu_seqlens_q"],
torch.tensor([0, 7, 14], dtype=torch.int32),
rtol=0,
atol=0,
)
torch.testing.assert_close(
call["cu_seqlens_k"],
torch.tensor([0, 7, 14], dtype=torch.int32),
rtol=0,
atol=0,
)
assert call["cu_seqlens_q"] == [0, 7, 14]
assert call["cu_seqlens_k"] == [0, 7, 14]
def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch):
@ -160,7 +130,7 @@ def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatc
q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim)
k = q + 100
v = q + 200
cu = torch.tensor([0, 2, 5], dtype=torch.int32)
cu = [0, 2, 5]
calls = []
def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs):
@ -197,20 +167,3 @@ def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatc
assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls)
torch.testing.assert_close(out, q + v, rtol=0, atol=0)
def test_var_attention_optimized_split_rejects_bad_offsets():
q = torch.randn(5, 2, 3)
cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32)
cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32)
with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"):
var_attention_optimized_split(
q,
q,
q,
2,
cu_bad,
cu_ok,
skip_reshape=True,
skip_output_reshape=True,
)

View File

@ -1,17 +1,10 @@
"""Consolidated SeedVR2 model/graph/forward regression tests.
Merged from:
- seedvr_model_test.py
- test_seedvr_7b_final_block_text_path.py
- test_seedvr_forward_no_device_cast.py
- test_seedvr_latent_format.py
- test_seedvr2_vae_graph_boundaries.py
"""
"""SeedVR2 model, latent-format, and VAE graph regression tests."""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
import torch
from torch import nn
@ -22,7 +15,6 @@ if not torch.cuda.is_available():
import comfy # noqa: E402
import comfy.latent_formats # noqa: E402
import comfy.ldm.seedvr.model # noqa: E402
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.model_management # noqa: E402
@ -33,9 +25,7 @@ import nodes as nodes_mod # noqa: E402
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
# ---------------------------------------------------------------------------
# Helpers from seedvr_model_test.py
# ---------------------------------------------------------------------------
_LATENT_CHANNELS = seedvr_vae_mod.SEEDVR2_LATENT_CHANNELS
def _make_standin(positive_conditioning):
@ -51,11 +41,6 @@ def _make_standin(positive_conditioning):
return _StandIn()
# ---------------------------------------------------------------------------
# Helpers from test_seedvr_7b_final_block_text_path.py
# ---------------------------------------------------------------------------
class _StubModule(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
@ -88,11 +73,6 @@ def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> lis
return flags
# ---------------------------------------------------------------------------
# Helpers from test_seedvr_latent_format.py
# ---------------------------------------------------------------------------
class _Model:
def __init__(self, latent_format):
self._latent_format = latent_format
@ -102,11 +82,6 @@ class _Model:
return self._latent_format
# ---------------------------------------------------------------------------
# Helpers from test_seedvr2_vae_graph_boundaries.py
# ---------------------------------------------------------------------------
class _Patcher:
def get_free_memory(self, device):
return 1024 * 1024 * 1024
@ -136,14 +111,14 @@ class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
if z.ndim == 4:
b, tc, h, w = z.shape
t = tc // 16
t = tc // _LATENT_CHANNELS
else:
b, _, t, h, w = z.shape
return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch):
raw_latent = torch.full((1, 16, 1, 4, 5), 2.0)
raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 2.0)
seen_shapes = []
def base_encode(self, x):
@ -159,12 +134,12 @@ def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch):
latent = vae.encode(torch.zeros(1, 3, 32, 40))
assert type(latent) is torch.Tensor
assert tuple(latent.shape) == (1, 16, 4, 5)
assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5)
assert seen_shapes == [(1, 3, 1, 32, 40)]
def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch):
raw_latent = torch.full((1, 16, 1, 4, 5), 3.0)
raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 3.0)
def base_encode(self, x):
return raw_latent.to(device=x.device, dtype=x.dtype)
@ -177,8 +152,8 @@ def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch):
latent, raw = vae._encode_with_raw_latent(torch.zeros(1, 3, 32, 40))
assert tuple(latent.shape) == (1, 16, 4, 5)
assert tuple(raw.shape) == (1, 16, 1, 4, 5)
assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5)
assert tuple(raw.shape) == (1, _LATENT_CHANNELS, 1, 4, 5)
assert torch.equal(raw, raw_latent)
@ -188,7 +163,7 @@ def _make_vae(wrapper):
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.latent_channels = 16
vae.latent_channels = _LATENT_CHANNELS
vae.latent_dim = 3
vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8)
vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
@ -212,13 +187,7 @@ def _make_vae(wrapper):
return vae
# ---------------------------------------------------------------------------
# Tests from seedvr_model_test.py
# ---------------------------------------------------------------------------
def test_missing_context_falls_back_to_positive_buffer():
"""``context is None`` falls back to the registered ``positive_conditioning`` buffer and runs to completion."""
pos_buffer = torch.full((58, 5120), 7.0)
standin = _make_standin(pos_buffer)
txt, txt_shape = standin._resolve_text_conditioning(None)
@ -231,11 +200,6 @@ def test_missing_context_falls_back_to_positive_buffer():
assert txt_shape[0, 0].item() == 58
# ---------------------------------------------------------------------------
# Tests from test_seedvr_7b_final_block_text_path.py
# ---------------------------------------------------------------------------
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [
False,
@ -268,43 +232,49 @@ def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0)
# ---------------------------------------------------------------------------
# Tests from test_seedvr_latent_format.py
# ---------------------------------------------------------------------------
def test_seedvr2_forward_requires_conditioning_latents():
model = NaDiT.__new__(NaDiT)
x = torch.zeros(1, _LATENT_CHANNELS, 1, 4, 5)
with pytest.raises(ValueError, match="requires conditioning latents"):
NaDiT.forward(model, x, timestep=torch.tensor([1.0]), context=None)
def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion():
def test_seedvr2_latent_format_uses_native_video_latent_shape():
latent_format = comfy.latent_formats.SeedVR2()
latent_image = torch.zeros(1, 1, 4, 5)
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
assert latent_format.latent_channels == 16
assert latent_format.latent_dimensions == 2
assert fixed.shape == (1, 16, 4, 5)
assert latent_format.latent_channels == _LATENT_CHANNELS
assert latent_format.latent_dimensions == 3
assert fixed.shape == (1, _LATENT_CHANNELS, 1, 4, 5)
# ---------------------------------------------------------------------------
# Tests from test_seedvr2_vae_graph_boundaries.py
# ---------------------------------------------------------------------------
def test_seedvr2_model_requires_native_5d_latent():
latent = torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)
assert NaDiT._check_seedvr2_video_latent(latent, _LATENT_CHANNELS, "latent") is latent
with pytest.raises(ValueError, match="5-D native latent"):
NaDiT._check_seedvr2_video_latent(torch.zeros(1, _LATENT_CHANNELS * 2, 4, 5), _LATENT_CHANNELS, "latent")
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
encoded = torch.full((1, 16, 2, 4, 5), 2.0)
encoded = torch.full((1, _LATENT_CHANNELS, 2, 4, 5), 2.0)
vae = _make_vae(_EncodeWrapper(encoded))
pixels = torch.zeros(1, 5, 32, 40, 3)
node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
node_latent = node_output["samples"]
assert set(node_output) == {"samples"}
assert tuple(node_latent.shape) == (1, 16, 2, 4, 5)
assert tuple(node_latent.shape) == (1, _LATENT_CHANNELS, 2, 4, 5)
assert node_latent.dtype == torch.float32
assert node_latent.stride()[-1] == 1
assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152))
assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR))
tiled = torch.full((1, 16, 2, 4, 5), 3.0)
tiled = torch.full((1, _LATENT_CHANNELS, 2, 4, 5), 3.0)
monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
tiled_output = nodes_mod.VAEEncodeTiled().encode(
vae,
@ -316,9 +286,9 @@ def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeyp
)[0]
tiled_latent = tiled_output["samples"]
assert set(tiled_output) == {"samples"}
assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5)
assert tuple(tiled_latent.shape) == (1, _LATENT_CHANNELS, 2, 4, 5)
assert tiled_latent.dtype == torch.float32
assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152))
assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR))
def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch):
@ -327,7 +297,7 @@ def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch):
nodes_mod.VAEDecodeTiled().decode(
vae,
{"samples": torch.zeros(1, 16, 2, 4, 5)},
{"samples": torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)},
tile_size=512,
overlap=64,
temporal_size=16,
@ -339,7 +309,7 @@ def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch):
# knobs are no-ops at the wrapper.
assert vae.first_stage_model.calls == [
{
"shape": (1, 16, 2, 4, 5),
"shape": (1, _LATENT_CHANNELS, 2, 4, 5),
"seedvr2_tiling": {
"enable_tiling": True,
"tile_size": (512, 512),

View File

@ -13,6 +13,9 @@ import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
from comfy_extras import nodes_seedvr # noqa: E402
_LATENT_CHANNELS = vae_mod.SEEDVR2_LATENT_CHANNELS
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
vae_mod.VideoAutoencoderKLWrapper
@ -40,7 +43,7 @@ def _decode_with_patches(wrapper, z):
def test_decode_b2_t3_multi_frame_batch_unchanged():
wrapper = _make_wrapper()
out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2))
out = _decode_with_patches(wrapper, torch.zeros(2, _LATENT_CHANNELS * 3, 2, 2))
assert tuple(out.shape) == (2, 3, 3, 16, 16)
@ -62,17 +65,17 @@ def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preproc
wrapper = _Wrapper()
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5))
out = wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5))
assert tuple(out.shape) == (1, 3, 2, 32, 40)
assert wrapper.calls == [(1, 16, 2, 4, 5)]
assert wrapper.calls == [(1, _LATENT_CHANNELS, 2, 4, 5)]
def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents():
wrapper = _Wrapper()
with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"):
wrapper.decode(torch.zeros(1, 16, 4))
wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 4))
def _t_padded(t_in: int) -> int:

View File

@ -16,9 +16,7 @@ import comfy.sd as sd_mod # noqa: E402
from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402
# ---------------------------------------------------------------------------
# From test_seedvr_vae_tiled_decode_latent_min_size_override.py
# ---------------------------------------------------------------------------
_LATENT_CHANNELS = seedvr_vae_mod.SEEDVR2_LATENT_CHANNELS
def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
@ -44,7 +42,7 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype)
vae = StubVAEModel()
z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32)
z = torch.zeros((1, _LATENT_CHANNELS, 5, 8, 8), dtype=torch.float32)
tiled_vae(
z,
@ -61,11 +59,6 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
assert vae.slicing_latent_min_size == 2
# ---------------------------------------------------------------------------
# From test_seedvr_vae_tiled_encode_runt_slice_override.py
# ---------------------------------------------------------------------------
def test_zero_temporal_size_preserves_min_size_when_encode_raises():
class RaisingVAEModel(torch.nn.Module):
def __init__(self):
@ -110,7 +103,7 @@ def test_tiled_vae_encode_uses_tensor_return_without_indexing():
def encode(self, t_chunk):
self.calls.append(tuple(t_chunk.shape))
b, _, _, h, w = t_chunk.shape
return torch.ones((b, 16, 1, h // 8, w // 8), dtype=t_chunk.dtype)
return torch.ones((b, _LATENT_CHANNELS, 1, h // 8, w // 8), dtype=t_chunk.dtype)
vae = TensorEncodeVAEModel()
x = torch.zeros((2, 3, 1, 64, 64), dtype=torch.float32)
@ -126,12 +119,34 @@ def test_tiled_vae_encode_uses_tensor_return_without_indexing():
)
assert vae.calls == [(2, 3, 1, 64, 64)]
assert tuple(out.shape) == (2, 16, 1, 8, 8)
assert tuple(out.shape) == (2, _LATENT_CHANNELS, 1, 8, 8)
# ---------------------------------------------------------------------------
# From test_seedvr_vae_tiled_temporal_slicing.py
# ---------------------------------------------------------------------------
def test_tiled_vae_preserves_input_dtype_on_single_tile():
class FloatOutputVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_sample_min_size = 4
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
def encode(self, t_chunk):
b, _, _, h, w = t_chunk.shape
return torch.ones((b, _LATENT_CHANNELS, 1, h // 8, w // 8), dtype=torch.float32)
out = tiled_vae(
torch.zeros((1, 3, 1, 64, 64), dtype=torch.float16),
FloatOutputVAEModel(),
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=True,
)
assert out.dtype == torch.float16
class _SlicingDecodeVAE(nn.Module):
@ -164,7 +179,10 @@ class _SlicingDecodeVAE(nn.Module):
def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
vae = _SlicingDecodeVAE(slicing_latent_min_size=2)
z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8)
z = torch.arange(
_LATENT_CHANNELS * 5 * 8 * 8,
dtype=torch.float32,
).reshape(1, _LATENT_CHANNELS, 5, 8, 8)
tiled_vae(
z,
@ -199,16 +217,11 @@ def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
return torch.zeros(1, 3, 1, 16, 16)
with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae):
wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling)
wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 2, 2), seedvr2_tiling=seedvr2_tiling)
assert captured["temporal_overlap"] == 7
# ---------------------------------------------------------------------------
# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py
# ---------------------------------------------------------------------------
def _force_oom(*a, **k):
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
@ -256,10 +269,10 @@ def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode):
def test_4d_seedvr2_latent_routes_to_owned_decode_tiled():
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
seedvr_vae_mod.VideoAutoencoderKLWrapper)
vae = _make_vae(wrapper, latent_channels=16, latent_dim=3)
vae = _make_vae(wrapper, latent_channels=_LATENT_CHANNELS, latent_dim=3)
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
_dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True)
_dispatch(vae, torch.zeros(1, _LATENT_CHANNELS * 3, 8, 8), seedvr2_call, generic_call, True)
assert seedvr2_call.call_count == 1
assert generic_call.call_count == 0
@ -275,11 +288,6 @@ def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled():
assert seedvr2_call.call_count == 0
# ---------------------------------------------------------------------------
# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py
# ---------------------------------------------------------------------------
def _populate_common_vae_attrs_fallback(vae):
vae.patcher = MagicMock()
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
@ -291,7 +299,7 @@ def _populate_common_vae_attrs_fallback(vae):
vae.upscale_ratio = 8
vae.upscale_index_formula = None
vae.output_channels = 3
vae.latent_channels = 16
vae.latent_channels = _LATENT_CHANNELS
vae.latent_dim = 3
vae.downscale_ratio = 8
vae.downscale_index_formula = None
@ -334,8 +342,8 @@ def test_seedvr2_3d_routes_to_owned_encode_tiled_on_oom():
vae = _make_seedvr2_vae_fallback()
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
seedvr2_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8))
generic_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8))
with patch.object(sd_mod.model_management, "raise_non_oom",
lambda e: None), \
@ -363,7 +371,7 @@ def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete():
vae = _make_non_seedvr2_vae_fallback()
vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8)
vae.upscale_ratio = (lambda a: a * 4, 8, 8)
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
generic_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8))
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
with patch.object(sd_mod.model_management, "load_models_gpu",