diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index fc5b13c21..8a16cfe55 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -781,6 +781,7 @@ class ACEAudio(LatentFormat): class SeedVR2(LatentFormat): latent_channels = 16 + latent_dimensions = 3 class ACEAudio15(LatentFormat): latent_channels = 64 diff --git a/comfy/ldm/seedvr/attention.py b/comfy/ldm/seedvr/attention.py index 5d4054ab9..11b4c1e4a 100644 --- a/comfy/ldm/seedvr/attention.py +++ b/comfy/ldm/seedvr/attention.py @@ -22,33 +22,14 @@ def _var_attention_output(out, heads, head_dim, skip_output_reshape): return out.reshape(-1, heads * head_dim) -def _validate_split_cu_seqlens(name, cu_seqlens, token_count): - if cu_seqlens.dtype not in (torch.int32, torch.int64): - raise ValueError(f"{name} must use an integer dtype") - if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2: - raise ValueError(f"{name} must be a 1D tensor with at least two offsets") - if cu_seqlens[0].item() != 0: - raise ValueError(f"{name} must start at 0") - if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item(): - raise ValueError(f"{name} must be strictly increasing") - if cu_seqlens[-1].item() != token_count: - raise ValueError(f"{name} does not match token count") - - -def _split_indices(cu_seqlens): - return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long) - - def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) - _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0]) - _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0]) - if cu_seqlens_k[-1].item() != v.shape[0]: + q_split_indices = cu_seqlens_q[1:-1] + k_split_indices = cu_seqlens_k[1:-1] + if k.shape[0] != v.shape[0]: raise ValueError("cu_seqlens_k does not match v token count") - q_split_indices = _split_indices(cu_seqlens_q) - k_split_indices = _split_indices(cu_seqlens_k) q_splits = torch.tensor_split(q, q_split_indices, dim=0) k_splits = torch.tensor_split(k, k_split_indices, dim=0) v_splits = torch.tensor_split(v, k_split_indices, dim=0) diff --git a/comfy/ldm/seedvr/color_fix.py b/comfy/ldm/seedvr/color_fix.py index 440b3d26c..a43cb5270 100644 --- a/comfy/ldm/seedvr/color_fix.py +++ b/comfy/ldm/seedvr/color_fix.py @@ -45,7 +45,6 @@ def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS): def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: if content_feat.shape != style_feat.shape: - # Resize style to match content spatial dimensions if len(content_feat.shape) >= 3: style_feat = F.interpolate( style_feat, @@ -54,12 +53,11 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: align_corners=False ) - # Decompose both features into frequency components content_high_freq, content_low_freq = wavelet_decomposition(content_feat) - del content_low_freq # Free memory immediately + del content_low_freq style_high_freq, style_low_freq = wavelet_decomposition(style_feat) - del style_high_freq # Free memory immediately + del style_high_freq if content_high_freq.shape != style_low_freq.shape: style_low_freq = F.interpolate( @@ -73,27 +71,23 @@ def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: return content_high_freq.clamp_(-1.0, 1.0) -def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor: +def _histogram_matching_channel(source: Tensor, reference: Tensor) -> Tensor: original_shape = source.shape - # Flatten source_flat = source.flatten() reference_flat = reference.flatten() - # Sort both arrays source_sorted, source_indices = torch.sort(source_flat) reference_sorted, _ = torch.sort(reference_flat) del reference_flat - # Quantile mapping n_source = len(source_sorted) n_reference = len(reference_sorted) if n_source == n_reference: matched_sorted = reference_sorted else: - # Interpolate reference to match source quantiles - source_quantiles = torch.linspace(0, 1, n_source, device=device) + source_quantiles = torch.linspace(0, 1, n_source, device=source.device) ref_indices = (source_quantiles * (n_reference - 1)).long() ref_indices.clamp_(0, n_reference - 1) matched_sorted = reference_sorted[ref_indices] @@ -101,7 +95,6 @@ def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch del source_sorted, source_flat - # Reconstruct using argsort (portable across CUDA/ROCm/MPS) inverse_indices = torch.argsort(source_indices) del source_indices matched_flat = matched_sorted[inverse_indices] @@ -109,17 +102,14 @@ def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch return matched_flat.reshape(original_shape) -def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor: - """Convert batch of CIELAB images to RGB color space.""" +def _lab_to_rgb_batch(lab: Tensor, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor: L, a, b = lab[:, 0], lab[:, 1], lab[:, 2] - # LAB to XYZ fy = (L + 16.0) / 116.0 fx = a.div(500.0).add_(fy) fz = fy - b / 200.0 del L, a, b - # XYZ transformation x = torch.where( fx > epsilon, torch.pow(fx, 3.0), @@ -137,20 +127,16 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps ) del fx, fy, fz - # Apply D65 white point (in-place) x.mul_(D65_WHITE_X) - # y *= 1.00000 # (no-op, skip) z.mul_(D65_WHITE_Z) xyz = torch.stack([x, y, z], dim=1) del x, y, z - # Matrix multiplication: XYZ -> RGB - B, C, H, W = xyz.shape + B, _, H, W = xyz.shape xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3) del xyz - # Ensure dtype consistency for matrix multiplication xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype) rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T) del xyz_flat @@ -158,7 +144,6 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) del rgb_linear_flat - # Apply inverse gamma correction (delinearize) mask = rgb_linear > 0.0031308 rgb = torch.where( mask, @@ -169,9 +154,7 @@ def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, eps return torch.clamp(rgb, 0.0, 1.0) -def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor: - """Convert batch of RGB images to CIELAB color space using D65 illuminant.""" - # Apply sRGB gamma correction (linearize) +def _rgb_to_lab_batch(rgb: Tensor, matrix: Tensor, epsilon: float, kappa: float) -> Tensor: mask = rgb > 0.04045 rgb_linear = torch.where( mask, @@ -180,12 +163,10 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon ) del mask - # Matrix multiplication: RGB -> XYZ - B, C, H, W = rgb_linear.shape + B, _, H, W = rgb_linear.shape rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3) del rgb_linear - # Ensure dtype consistency for matrix multiplication rgb_flat = rgb_flat.to(dtype=matrix.dtype) xyz_flat = torch.matmul(rgb_flat, matrix.T) del rgb_flat @@ -193,12 +174,9 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) del xyz_flat - # Normalize by D65 white point (in-place) - xyz[:, 0].div_(D65_WHITE_X) # X - # xyz[:, 1] /= 1.00000 # Y (no-op, skip) - xyz[:, 2].div_(D65_WHITE_Z) # Z + xyz[:, 0].div_(D65_WHITE_X) + xyz[:, 2].div_(D65_WHITE_Z) - # XYZ to LAB transformation epsilon_cubed = epsilon ** 3 mask = xyz > epsilon_cubed f_xyz = torch.where( @@ -208,10 +186,9 @@ def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon ) del xyz, mask - # Extract channels and compute LAB - L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100] - a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127] - b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127] + L = f_xyz[:, 1].mul(116.0).sub_(16.0) + a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) + b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) del f_xyz return torch.stack([L, a, b], dim=1) @@ -232,13 +209,9 @@ def lab_color_transfer( ) device = content_feat.device - - def ensure_float32_precision(c): - orig_dtype = c.dtype - c = c.float() - return c, orig_dtype - content_feat, original_dtype = ensure_float32_precision(content_feat) - style_feat, _ = ensure_float32_precision(style_feat) + original_dtype = content_feat.dtype + content_feat = content_feat.float() + style_feat = style_feat.float() rgb_to_xyz_matrix = torch.tensor([ [0.4124564, 0.3575761, 0.1804375], @@ -258,39 +231,30 @@ def lab_color_transfer( content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) - # Convert to LAB color space - content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa) + content_lab = _rgb_to_lab_batch(content_feat, rgb_to_xyz_matrix, epsilon, kappa) del content_feat - style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa) + style_lab = _rgb_to_lab_batch(style_feat, rgb_to_xyz_matrix, epsilon, kappa) del style_feat, rgb_to_xyz_matrix - # Match chrominance channels (a*, b*) for accurate color transfer - matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device) - matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device) + matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1]) + matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2]) - # Handle luminance with weighted blending if luminance_weight < 1.0: - # Partially match luminance for better overall color accuracy - matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device) - # Blend: preserve some content L* for detail, adopt some style L* for color + matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0]) result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight)) del matched_L else: - # Fully preserve content luminance result_L = content_lab[:, 0] del content_lab, style_lab - # Reconstruct LAB with corrected channels result_lab = torch.stack([result_L, matched_a, matched_b], dim=1) del result_L, matched_a, matched_b - # Convert back to RGB - result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa) + result_rgb = _lab_to_rgb_batch(result_lab, xyz_to_rgb_matrix, epsilon, kappa) del result_lab, xyz_to_rgb_matrix - # Convert back to [-1, 1] range (in-place) result = result_rgb.mul_(2.0).sub_(1.0) del result_rgb diff --git a/comfy/ldm/seedvr/constants.py b/comfy/ldm/seedvr/constants.py index df91c7772..b8b300388 100644 --- a/comfy/ldm/seedvr/constants.py +++ b/comfy/ldm/seedvr/constants.py @@ -1,34 +1,21 @@ -"""Named constants for the SeedVR2 integration, grouped by provenance. +"""SeedVR2 constants.""" -Provenance prefixes: -- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline. -- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites - the upstream config/source path it was lifted from. -- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature / - ISO / CIE values; cite the standard. -""" - -SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim. - # (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.) -SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # OOM retry backoff: halve the chunk and retry. -SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case). -SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM. +SEEDVR2_7B_VID_DIM = 3072 +SEEDVR2_OOM_BACKOFF_DIVISOR = 2 +SEEDVR2_DTYPE_BYTES_FLOOR = 4 +SEEDVR2_7B_MLP_CHUNK = 8192 SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk. -SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels). +SEEDVR2_LATENT_CHANNELS = 16 -# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing) -SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk. -SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path. +SEEDVR2_COLOR_MEM_HEADROOM = 0.75 +SEEDVR2_LAB_SCALE_MULTIPLIER = 13 SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path. -SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path. +SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 -# -------------------------------------------------------------------------------------- -# ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR) -# -------------------------------------------------------------------------------------- -BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm. -BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift. -BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem). -BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem). +BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57. +BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 +BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 +BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28. BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28. BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16). @@ -42,18 +29,10 @@ BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal wind BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency). BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim). -# -------------------------------------------------------------------------------------- -# Published standards (cite the literature) -# -------------------------------------------------------------------------------------- ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864. -# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65). CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta). CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa). D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1). D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn. WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR). - -# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and -# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the -# exact existing coefficients move verbatim rather than being retyped here. diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py index ee50449a4..872140558 100644 --- a/comfy/ldm/seedvr/model.py +++ b/comfy/ldm/seedvr/model.py @@ -3,7 +3,7 @@ from typing import Optional, Tuple, Union, List, Dict, Any, Callable import torch.nn.functional as F from math import ceil, pi import torch -from itertools import chain +from itertools import accumulate, chain from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding from comfy.ldm.seedvr.attention import optimized_var_attention from torch.nn.modules.utils import _triple @@ -18,6 +18,7 @@ from comfy.ldm.seedvr.constants import ( ROPE_THETA, SEEDVR2_7B_MLP_CHUNK, SEEDVR2_7B_VID_DIM, + SEEDVR2_LATENT_CHANNELS, SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, ) import comfy.model_management @@ -70,7 +71,7 @@ def repeat_concat_idx( vid_idx = torch.arange(vid_len.sum(), device=device) txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) txt_repeat_list = txt_repeat.tolist() - tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat) + tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat_list) src_idx = torch.argsort(tgt_idx) txt_idx_len = len(tgt_idx) - len(vid_idx) repeat_txt_len = (txt_len * txt_repeat).tolist() @@ -88,6 +89,9 @@ def repeat_concat_idx( lambda all: unconcat_coalesce(all), ) +def cumulative_lengths(lengths): + return [0, *accumulate(lengths)] + @dataclass class MMArg: @@ -110,16 +114,14 @@ def get_window_op(name: str): raise ValueError(f"Unknown windowing method: {name}") -# -------------------------------- Windowing -------------------------------- # def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): t, h, w = size resized_nt, resized_nh, resized_nw = num_windows - #cal windows under 720p scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) resized_h, resized_w = round(h * scale), round(w * scale) - wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. - nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. + wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) + wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) + nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) return [ ( slice(it * wt, min((it + 1) * wt, t)), @@ -137,19 +139,18 @@ def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): t, h, w = size resized_nt, resized_nh, resized_nw = num_windows - #cal windows under 720p scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) resized_h, resized_w = round(h * scale), round(w * scale) - wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. + wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) + wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) - st, sh, sw = ( # shift size. + st, sh, sw = ( 0.5 if wt < t else 0, 0.5 if wh < h else 0, 0.5 if ww < w else 0, ) - nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size. - nt, nh, nw = ( # number of window. + nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) + nt, nh, nw = ( nt + 1 if st > 0 else 1, nh + 1 if sh > 0 else 1, nw + 1 if sw > 0 else 1, @@ -175,7 +176,6 @@ class RotaryEmbedding(nn.Module): freqs_for = 'lang', theta = 10000, max_freq = 10, - learned_freq = False, ): super().__init__() @@ -185,18 +185,14 @@ class RotaryEmbedding(nn.Module): freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) elif freqs_for == 'pixel': freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + else: + raise ValueError(f"Unknown rotary frequency type: {freqs_for}") - self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) - - self.learned_freq = learned_freq - - # dummy for device - - self.register_buffer('dummy', torch.tensor(0), persistent = False) + self.register_buffer("freqs", freqs) @property def device(self): - return self.dummy.device + return self.freqs.device def get_axial_freqs( self, @@ -206,10 +202,9 @@ class RotaryEmbedding(nn.Module): Colon = slice(None) all_freqs = [] - # handle offset - if exists(offsets): - assert len(offsets) == len(dims) + if len(offsets) != len(dims): + raise ValueError(f"SeedVR2 rotary offsets length must match dims length, got {len(offsets)} and {len(dims)}.") for ind, dim in enumerate(dims): @@ -224,7 +219,7 @@ class RotaryEmbedding(nn.Module): pos = pos + offset - freqs = self.forward(pos, seq_len = dim) + freqs = self.forward(pos) all_axis = [None] * len(dims) all_axis[ind] = Colon @@ -232,16 +227,12 @@ class RotaryEmbedding(nn.Module): new_axis_slice = (Ellipsis, *all_axis, Colon) all_freqs.append(freqs[new_axis_slice]) - # concat all freqs - all_freqs = torch.broadcast_tensors(*all_freqs) return torch.cat(all_freqs, dim = -1) def forward( self, t, - seq_len: int | None = None, - offset = 0 ): freqs = self.freqs @@ -258,9 +249,6 @@ class RotaryEmbeddingBase(nn.Module): freqs_for="pixel", max_freq=BYTEDANCE_ROPE_MAX_FREQ, ) - freqs = self.rope.freqs - del self.rope.freqs - self.rope.register_buffer("freqs", freqs.detach()) def get_axial_freqs(self, *dims): return self.rope.get_axial_freqs(*dims) @@ -306,7 +294,7 @@ class NaRotaryEmbedding3d(RotaryEmbedding3d): freqs_for="pixel", max_freq=BYTEDANCE_ROPE_MAX_FREQ, ) - plain_rope = plain_rope.to(self.rope.dummy.device) + plain_rope = plain_rope.to(self.rope.device) freq_list = [] for f, h, w in shape.tolist(): freqs = plain_rope.get_axial_freqs(f, h, w) @@ -322,9 +310,6 @@ class MMRotaryEmbeddingBase(RotaryEmbeddingBase): freqs_for="lang", theta=ROPE_THETA, ) - freqs = self.rope.freqs - del self.rope.freqs - self.rope.register_buffer("freqs", freqs.detach()) self.mm = True def slice_at_dim(t, dim_slice: slice, *, dim): @@ -333,8 +318,6 @@ def slice_at_dim(t, dim_slice: slice, *, dim): colons[dim] = dim_slice return t[tuple(colons)] -# rotary embedding helper functions - def rotate_half(x): x = x.reshape(*x.shape[:-1], x.shape[-1] // 2, 2) x1, x2 = x.unbind(dim = -1) @@ -373,7 +356,6 @@ def _apply_seedvr2_rotary_emb( return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype) def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: - """Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`.""" angles = freqs_interleaved[..., ::2].float() cos = torch.cos(angles) sin = torch.sin(angles) @@ -382,12 +364,6 @@ def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - """Rotate the leading ``rot_d = 2 * freqs_cis.shape[-3]`` dims of ``t`` and pass the rest - through; in-place for inference, cloned for training (autograd). Mirrors the legacy - ``apply_rotary_emb`` ``t_left``/``t_middle``/``t_right`` split: 3B ``rope_dim=128`` gives - ``42*3 = 126`` rotated of head_dim 128 (trailing 2 unrotated). Fast path skips the cat when - ``rot_d == t.shape[-1]``. - """ out = t.clone() if t.requires_grad or comfy.model_management.in_training else t rot_d = 2 * freqs_cis.shape[-3] seq_len = out.shape[-2] @@ -454,14 +430,13 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): torch.Tensor, ]: - # Calculate actual max dimensions needed for this batch max_temporal = 0 max_height = 0 max_width = 0 max_txt_len = 0 for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal + max_temporal = max(max_temporal, l + f) max_height = max(max_height, h) max_width = max(max_width, w) max_txt_len = max(max_txt_len, l) @@ -475,7 +450,6 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): ).float() txt_freqs = self.get_axial_freqs(max_txt_len + 16) - # Now slice as before vid_freq_list, txt_freq_list = [], [] for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) @@ -485,13 +459,6 @@ class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): vid_freqs_interleaved = torch.cat(vid_freq_list, dim=0) txt_freqs_interleaved = torch.cat(txt_freq_list, dim=0) - # Convert from lucidrains-interleaved layout `[θ0, θ0, θ1, θ1, ...]` - # (produced by `repeat(freqs, '... n -> ... (n r)', r=2)` in the - # upstream `RotaryEmbedding.forward`) to flux-canonical `freqs_cis` - # in shape `[..., d/2, 2, 2]` with `cos/-sin/sin/cos` baked in. - # Mirrors `comfy/ldm/flux/math.py:rope` (line 27) so the trailing - # 2x2 is the per-frequency rotation matrix that - # `comfy.ldm.flux.math.apply_rope1` expects. return _to_flux_freqs_cis(vid_freqs_interleaved), _to_flux_freqs_cis(txt_freqs_interleaved) class MMModule(nn.Module): @@ -507,8 +474,10 @@ class MMModule(nn.Module): self.shared_weights = shared_weights self.vid_only = vid_only if self.shared_weights: - assert get_args("vid", args) == get_args("txt", args) - assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) + if get_args("vid", args) != get_args("txt", args): + raise ValueError("SeedVR2 shared MMModule requires matching vid/txt args.") + if get_kwargs("vid", kwargs) != get_kwargs("txt", kwargs): + raise ValueError("SeedVR2 shared MMModule requires matching vid/txt kwargs.") self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) else: self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) @@ -543,6 +512,7 @@ def get_na_rope(rope_type: Optional[str], dim: int): return NaRotaryEmbedding3d(dim=dim) if rope_type == "mmrope3d": return NaMMRotaryEmbedding3d(dim=dim) + raise ValueError(f"Unknown SeedVR2 rope type: {rope_type}") class NaMMAttention(nn.Module): def __init__( @@ -558,7 +528,6 @@ class NaMMAttention(nn.Module): rope_dim: int, shared_weights: bool, device, dtype, operations, - **kwargs, ): super().__init__() dim = MMArg(vid_dim, txt_dim) @@ -597,16 +566,19 @@ def window( ): hid = unflatten(hid, hid_shape) hid = list(map(window_fn, hid)) - hid_windows = torch.as_tensor([len(x) for x in hid], device=hid_shape.device) - hid, hid_shape = flatten(list(chain(*hid))) - return hid, hid_shape, hid_windows + hid_windows_list = [len(x) for x in hid] + hid_windows = torch.as_tensor(hid_windows_list, device=hid_shape.device) + hid = list(chain(*hid)) + hid_len_list = [math.prod(x.shape[:-1]) for x in hid] + hid, hid_shape = flatten(hid) + return hid, hid_shape, hid_windows, hid_len_list, hid_windows_list def window_idx( hid_shape: torch.LongTensor, # (b n) window_fn: Callable[[torch.Tensor], List[torch.Tensor]], ): hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) - tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn) + tgt_idx, tgt_shape, tgt_windows, tgt_len_list, tgt_windows_list = window(hid_idx, hid_shape, window_fn) tgt_idx = tgt_idx.squeeze(-1) src_idx = torch.argsort(tgt_idx) return ( @@ -614,6 +586,8 @@ def window_idx( lambda hid: torch.index_select(hid, 0, src_idx), tgt_shape, tgt_windows, + tgt_len_list, + tgt_windows_list, ) class NaSwinAttention(NaMMAttention): @@ -622,13 +596,15 @@ class NaSwinAttention(NaMMAttention): *args, window: Union[int, Tuple[int, int, int]], window_method: str, + version: bool = False, **kwargs, ): super().__init__(*args, **kwargs) - self.version_7b = kwargs.get("version", False) + self.version_7b = version self.window = _triple(window) self.window_method = window_method - assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window)) + if not all(isinstance(v, int) and v >= 0 for v in self.window): + raise ValueError(f"SeedVR2 window must contain non-negative integers, got {self.window}.") self.window_op = get_window_op(window_method) @@ -646,7 +622,6 @@ class NaSwinAttention(NaMMAttention): vid_qkv, txt_qkv = self.proj_qkv(vid, txt) - # re-org the input seq for window attn cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3") def make_window(x: torch.Tensor): @@ -654,7 +629,7 @@ class NaSwinAttention(NaMMAttention): window_slices = self.window_op((t, h, w), self.window) return [x[st, sh, sw] for (st, sh, sw) in window_slices] - window_partition, window_reverse, window_shape, window_count = cache_win( + window_partition, window_reverse, window_shape, window_count, vid_len_win_list, window_count_list = cache_win( "win_transform", lambda: window_idx(vid_shape, make_window), ) @@ -674,23 +649,21 @@ class NaSwinAttention(NaMMAttention): vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) txt_len = txt_len.to(window_count.device) - # window rope if self.rope: if self.version_7b: vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) elif self.rope.mm: - # repeat text q and k for window mmrope _, num_h, _ = txt_q.shape txt_q_repeat = txt_q.flatten(1, 2) txt_q_repeat = unflatten(txt_q_repeat, txt_shape) - txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)] + txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count_list)] txt_q_repeat = list(chain(*txt_q_repeat)) txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat) txt_q_repeat = txt_q_repeat.reshape(txt_q_repeat.shape[0], num_h, self.head_dim) txt_k_repeat = txt_k.flatten(1, 2) txt_k_repeat = unflatten(txt_k_repeat, txt_shape) - txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)] + txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count_list)] txt_k_repeat = list(chain(*txt_k_repeat)) txt_k_repeat, _ = flatten(txt_k_repeat) txt_k_repeat = txt_k_repeat.reshape(txt_k_repeat.shape[0], num_h, self.head_dim) @@ -702,7 +675,11 @@ class NaSwinAttention(NaMMAttention): vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) - all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) + txt_len_win_list = cache_win( + "txt_len_list", + lambda: [txt_len for txt_len, window_count in zip(txt_len.tolist(), window_count_list) for _ in range(window_count)], + ) + all_len_win = cache_win("all_len", lambda: [vid_len + txt_len for vid_len, txt_len in zip(vid_len_win_list, txt_len_win_list)]) concat_win, unconcat_win = cache_win( "mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count) ) @@ -711,12 +688,8 @@ class NaSwinAttention(NaMMAttention): k=concat_win(vid_k, txt_k), v=concat_win(vid_v, txt_v), heads=self.heads, skip_reshape=True, skip_output_reshape=True, - cu_seqlens_q=cache_win( - "vid_seqlens_q", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() - ), - cu_seqlens_k=cache_win( - "vid_seqlens_k", lambda: F.pad(all_len_win.cumsum(0), (1, 0)).int() - ), + cu_seqlens_q=cache_win("vid_seqlens_q", lambda: cumulative_lengths(all_len_win)), + cu_seqlens_k=cache_win("vid_seqlens_k", lambda: cumulative_lengths(all_len_win)), ) vid_out, txt_out = unconcat_win(out) @@ -766,11 +739,11 @@ class SwiGLUMLP(nn.Module): return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) def get_mlp(mlp_type: Optional[str] = "normal"): - # 3b and 7b uses different mlp types if mlp_type == "normal": return MLP - elif mlp_type == "swiglu": + if mlp_type == "swiglu": return SwiGLUMLP + raise ValueError(f"Unknown SeedVR2 MLP type: {mlp_type}") class NaMMSRTransformerBlock(nn.Module): def __init__( @@ -792,11 +765,12 @@ class NaMMSRTransformerBlock(nn.Module): rope_type: str, rope_dim: int, is_last_layer: bool, + window: Union[int, Tuple[int, int, int]], + window_method: str, + version: bool, device, dtype, operations, - **kwargs, ): super().__init__() - version = kwargs.get("version", False) dim = MMArg(vid_dim, txt_dim) self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype) @@ -811,8 +785,8 @@ class NaMMSRTransformerBlock(nn.Module): rope_type=rope_type, rope_dim=rope_dim, shared_weights=shared_weights, - window=kwargs.pop("window", None), - window_method=kwargs.pop("window_method", None), + window=window, + window_method=window_method, version=version, device=device, dtype=dtype, operations=operations ) @@ -930,12 +904,14 @@ class NaPatchOut(PatchOut): self, vid: torch.FloatTensor, # l c vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), + cache: Optional[Cache] = None, vid_shape_before_patchify = None ) -> Tuple[ torch.FloatTensor, torch.LongTensor, ]: + if cache is None: + cache = Cache(disable=True) t, h, w = self.patch_size vid = self.proj(vid) @@ -971,7 +947,10 @@ class PatchIn(nn.Module): ) -> torch.Tensor: t, h, w = self.patch_size if t > 1: - assert vid.size(2) % t == 1 + if vid.size(2) % t != 1: + raise ValueError( + f"SeedVR2 patch input temporal size must satisfy T % {t} == 1, got {vid.size(2)}." + ) vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) b, c, Tt, Hh, Ww = vid.shape vid = vid.view(b, c, Tt // t, t, Hh // h, h, Ww // w, w).permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(b, Tt // t, Hh // h, Ww // w, t * h * w * c) @@ -983,8 +962,10 @@ class NaPatchIn(PatchIn): self, vid: torch.Tensor, # l c vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), + cache: Optional[Cache] = None, ) -> torch.Tensor: + if cache is None: + cache = Cache(disable=True) cache = cache.namespace("patch") vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) t, h, w = self.patch_size @@ -1012,10 +993,11 @@ class AdaSingle(nn.Module): dim: int, emb_dim: int, layers: List[str], - modes: List[str] = ["in", "out"], + modes: Tuple[str, ...] = ("in", "out"), device = None, dtype = None, ): - assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" + if emb_dim != 6 * dim: + raise ValueError(f"SeedVR2 AdaSingle requires emb_dim == 6 * dim, got emb_dim={emb_dim}, dim={dim}.") super().__init__() self.dim = dim self.emb_dim = emb_dim @@ -1036,22 +1018,20 @@ class AdaSingle(nn.Module): emb: torch.FloatTensor, # b d layer: str, mode: str, - cache: Cache = Cache(disable=True), + cache: Optional[Cache] = None, branch_tag: str = "", hid_len: Optional[torch.LongTensor] = None, # b ) -> torch.FloatTensor: + if cache is None: + cache = Cache(disable=True) idx = self.layers.index(layer) emb = emb.reshape(emb.shape[0], -1, len(self.layers), 3)[:, :, idx, :] emb = expand_dims(emb, 1, hid.ndim + 1) if hid_len is not None: - slice_inputs = lambda x, dim: x emb = cache( f"emb_repeat_{idx}_{branch_tag}", - lambda: slice_inputs( - torch.repeat_interleave(emb, hid_len, dim=0), - dim=0, - ), + lambda: torch.repeat_interleave(emb, hid_len, dim=0), ) shiftA, scaleA, gateA = emb.unbind(-1) @@ -1069,7 +1049,7 @@ class AdaSingle(nn.Module): else: return hid.mul_(gateA) - raise NotImplementedError + raise ValueError(f"Unknown AdaSingle mode: {mode}") class TimeEmbedding(nn.Module): @@ -1117,7 +1097,8 @@ def flatten( torch.FloatTensor, # (L c) torch.LongTensor, # (b n) ]: - assert len(hid) > 0 + if len(hid) == 0: + raise ValueError("SeedVR2 flatten requires at least one tensor.") shape = torch.as_tensor([x.shape[:-1] for x in hid], device=hid[0].device) hid = torch.cat([x.flatten(0, -2) for x in hid]) return hid, shape @@ -1140,7 +1121,7 @@ class NaDiT(nn.Module): num_layers, mlp_type, vid_in_channels = 33, - vid_out_channels = 16, + vid_out_channels = SEEDVR2_LATENT_CHANNELS, vid_dim = 2560, txt_in_dim = 5120, heads = 20, @@ -1148,15 +1129,17 @@ class NaDiT(nn.Module): mm_layers = 10, expand_ratio = 4, qk_bias = False, - patch_size = [ 1,2,2 ], + patch_size = (1, 2, 2), rope_dim = 128, rope_type = "mmrope3d", vid_out_norm: Optional[str] = None, + image_model = None, device = None, dtype = None, operations = None, - **kwargs, ): + if image_model not in (None, "seedvr2"): + raise ValueError(f"SeedVR2 NaDiT expected image_model='seedvr2', got {image_model!r}.") self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM if self._7b_version: rope_type = "rope3d" @@ -1212,14 +1195,13 @@ class NaDiT(nn.Module): rope_dim = rope_dim, window=window[i], window_method=window_method[i], + version = self._7b_version, is_last_layer=(i == num_layers - 1) and not self._7b_version, rope_type = rope_type, shared_weights=not ( (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i] ), - version = self._7b_version, operations = operations, - **kwargs, **factory_kwargs ) for i in range(num_layers) @@ -1272,13 +1254,17 @@ class NaDiT(nn.Module): first = cond_or_uncond[0] return all(entry == first for entry in cond_or_uncond) + @staticmethod + def _check_seedvr2_video_latent(x, channels, name): + if x.ndim != 5: + raise ValueError(f"SeedVR2 expected {name} to be 5-D native latent, got shape {tuple(x.shape)}.") + if x.shape[1] != channels: + raise ValueError(f"SeedVR2 expected {name} channels to be {channels}, got shape {tuple(x.shape)}.") + return x + def _swap_pos_neg_halves(self, out, cond_or_uncond=None): if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): return out - # ``dim=0`` is explicit on both calls. The contract is "split - # the batch axis into two halves and swap them"; making the - # axis load-bearing in source guards against silent drift if a - # future refactor reorders tensor axes. pos, neg = out.chunk(2, dim=0) return torch.cat([neg, pos], dim=0) @@ -1294,9 +1280,15 @@ class NaDiT(nn.Module): patches_replace = transformer_options.get("patches_replace", {}) blocks_replace = patches_replace.get("dit", {}) conditions = kwargs.get("condition") - b, tc, h, w = x.shape - x = x.view(b, 16, -1, h, w) - conditions = conditions.view(b, 17, -1, h, w) + if conditions is None: + raise ValueError("SeedVR2 requires conditioning latents from the SeedVR2Conditioning node.") + x = self._check_seedvr2_video_latent(x, SEEDVR2_LATENT_CHANNELS, "latent") + conditions = self._check_seedvr2_video_latent(conditions, SEEDVR2_LATENT_CHANNELS + 1, "conditioning") + b, _, t, h, w = x.shape + if conditions.shape[0] != b or conditions.shape[2:] != (t, h, w): + raise ValueError( + f"SeedVR2 conditioning shape must match latent batch/temporal/spatial dimensions; got latent {tuple(x.shape)} and conditioning {tuple(conditions.shape)}." + ) x = x.movedim(1, -1) conditions = conditions.movedim(1, -1) cache = Cache(disable=disable_cache) @@ -1361,7 +1353,6 @@ class NaDiT(nn.Module): vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify) vid = unflatten(vid, vid_shape) - out = torch.stack(vid) + out = torch.stack(vid) out = out.movedim(-1, 1) - out = out.reshape(out.shape[0], out.shape[1] * out.shape[2], out.shape[3], out.shape[4]) return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond")) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py index 5daab022a..3f23a4691 100644 --- a/comfy/ldm/seedvr/vae.py +++ b/comfy/ldm/seedvr/vae.py @@ -62,7 +62,6 @@ def tiled_vae( temporal_size=16, temporal_overlap=0, encode=True, - **kwargs, ): if x.ndim != 5: x = x.unsqueeze(2) @@ -166,8 +165,8 @@ def tiled_vae( if single_spatial_tile: result = tile_out[:, :, :target_d, :target_h, :target_w] - if result.device != x.device: - result = result.to(x.device).to(x.dtype) + if result.device != x.device or result.dtype != x.dtype: + result = result.to(device=x.device, dtype=x.dtype) if x.shape[2] == 1 and sf_t == 1: result = result.squeeze(2) bar.update(1) @@ -221,8 +220,8 @@ def tiled_vae( result.div_(count.clamp(min=1e-6)) - if result.device != x.device: - result = result.to(x.device).to(x.dtype) + if result.device != x.device or result.dtype != x.dtype: + result = result.to(device=x.device, dtype=x.dtype) if x.shape[2] == 1 and sf_t == 1: result = result.squeeze(2) @@ -256,15 +255,18 @@ class MemoryState(Enum): UNSET = 3 def get_cache_size(conv_module, input_len, pad_len, dim=0): - dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 - output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 + dilated_kernel_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 + output_len = (input_len + pad_len - dilated_kernel_size) // conv_module.stride[dim] + 1 remain_len = ( - input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size) + input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernel_size) ) - overlap_len = dilated_kernerl_size - conv_module.stride[dim] - cache_len = overlap_len + remain_len # >= 0 + overlap_len = dilated_kernel_size - conv_module.stride[dim] + cache_len = overlap_len + remain_len - assert output_len > 0 + if output_len <= 0: + raise ValueError( + f"SeedVR2 VAE cache input is too short for convolution: input_len={input_len}, pad_len={pad_len}." + ) return cache_len class DiagonalGaussianDistribution(object): @@ -294,52 +296,27 @@ class SpatialNorm(nn.Module): new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) return new_f -# partial implementation of diffusers's Attention for comfyui class Attention(nn.Module): def __init__( self, query_dim: int, - cross_attention_dim: Optional[int] = None, heads: int = 8, - kv_heads: Optional[int] = None, dim_head: int = 64, - dropout: float = 0.0, bias: bool = False, - upcast_softmax: bool = False, norm_num_groups: Optional[int] = None, spatial_norm_dim: Optional[int] = None, out_bias: bool = True, - scale_qk: bool = True, - only_cross_attention: bool = False, eps: float = 1e-5, rescale_output_factor: float = 1.0, residual_connection: bool = False, - out_dim: int = None, - pre_only=False, ): super().__init__() - self.inner_dim = out_dim if out_dim is not None else dim_head * heads - self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads - self.query_dim = query_dim - self.use_bias = bias - self.is_cross_attention = cross_attention_dim is not None - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_softmax = upcast_softmax + self.inner_dim = dim_head * heads self.rescale_output_factor = rescale_output_factor self.residual_connection = residual_connection - self.dropout = dropout - self.fused_projections = False - self.out_dim = out_dim if out_dim is not None else query_dim - self.pre_only = pre_only - - self.scale_qk = scale_qk - self.scale = dim_head**-0.5 if self.scale_qk else 1.0 - - self.heads = out_dim // dim_head if out_dim is not None else heads - self.sliceable_head_dim = heads - - self.only_cross_attention = only_cross_attention + self.out_dim = query_dim + self.heads = heads if norm_num_groups is not None: self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) @@ -351,37 +328,19 @@ class Attention(nn.Module): else: self.spatial_norm = None - self.norm_q = None - self.norm_k = None - - self.norm_cross = None self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias) - - if not self.only_cross_attention: - # only relevant for the `AddedKVProcessor` classes - self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - else: - self.to_k = None - self.to_v = None - - if not self.pre_only: - self.to_out = nn.ModuleList([]) - self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(nn.Dropout(dropout)) - else: - self.to_out = None + self.to_k = ops.Linear(query_dim, self.inner_dim, bias=bias) + self.to_v = ops.Linear(query_dim, self.inner_dim, bias=bias) + self.to_out = nn.ModuleList([]) + self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Identity()) self.optimized_vae_attention = vae_attention() - def __call__( + def forward( self, hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None, - *args, - **kwargs, ) -> torch.Tensor: residual = hidden_states @@ -394,20 +353,14 @@ class Attention(nn.Module): batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) + batch_size = hidden_states.shape[0] if self.group_norm is not None: hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = self.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - - key = self.to_k(encoder_hidden_states) - value = self.to_v(encoder_hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // self.heads @@ -417,25 +370,18 @@ class Attention(nn.Module): key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - if self.norm_q is not None: - query = self.norm_q(query) - if self.norm_k is not None: - key = self.norm_k(key) - - if input_ndim == 4 and encoder_hidden_states is hidden_states and attention_mask is None and self.heads == 1: + if input_ndim == 4 and self.heads == 1: query = query.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) key = key.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) value = value.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) hidden_states = self.optimized_vae_attention(query, key, value).reshape(batch_size, self.heads, head_dim, height * width).transpose(2, 3) else: - hidden_states = optimized_attention(query, key, value, heads = self.heads, mask = attention_mask, skip_reshape=True, skip_output_reshape=True) + hidden_states = optimized_attention(query, key, value, heads = self.heads, skip_reshape=True, skip_output_reshape=True) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) hidden_states = hidden_states.to(query.dtype) - # linear proj hidden_states = self.to_out[0](hidden_states) - # dropout hidden_states = self.to_out[1](hidden_states) if input_ndim == 4: @@ -471,7 +417,10 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: memory_occupy = x.numel() * x.element_size() / 1024**3 if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit(): num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups) - assert norm_layer.num_groups % num_chunks == 0 + if norm_layer.num_groups % num_chunks != 0: + raise ValueError( + f"SeedVR2 VAE GroupNorm groups must divide chunks: groups={norm_layer.num_groups}, chunks={num_chunks}." + ) num_groups_per_chunk = norm_layer.num_groups // num_chunks x = list(x.chunk(num_chunks, dim=1)) @@ -485,14 +434,15 @@ def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: x = norm_layer(x) x = x.reshape((b, t, x.size(1), x.size(2), x.size(3))).transpose(1, 2) return x.to(input_dtype) - raise NotImplementedError + raise TypeError(f"SeedVR2 VAE unsupported norm layer type: {type(norm_layer).__name__}") _receptive_field_t = Literal["half", "full"] def extend_head(tensor, times: int = 2, memory = None): if memory is not None: return torch.cat((memory.to(tensor), tensor), dim=2) - assert times >= 0, "Invalid input for function 'extend_head'!" + if times < 0: + raise ValueError(f"SeedVR2 VAE extend_head expected times >= 0, got {times}.") if times == 0: return tensor else: @@ -547,13 +497,11 @@ class InflatedCausalConv3d(ops.Conv3d): padding=(0, 0, 0, 0, 0, 0), prev_cache=None, ): - # Compatible with no limit. if math.isinf(self.memory_limit): if prev_cache is not None: x = torch.cat([prev_cache, x], dim=split_dim - 1) return super().forward(x) - # Compute tensor shape after concat & padding. shape = list(x.size()) if prev_cache is not None: shape[split_dim - 1] += prev_cache.size(split_dim - 1) @@ -597,16 +545,19 @@ class InflatedCausalConv3d(ops.Conv3d): next_cache = None cache_len = cache.size(split_dim) if cache is not None else 0 - next_catch_size = get_cache_size( + next_cache_size = get_cache_size( conv_module=self, input_len=x[idx].size(split_dim) + cache_len, pad_len=pad_len, dim=split_dim - 2, ) - if next_catch_size != 0: - assert next_catch_size <= x[idx].size(split_dim) + if next_cache_size != 0: + if next_cache_size > x[idx].size(split_dim): + raise ValueError( + f"SeedVR2 VAE cache size {next_cache_size} exceeds split size {x[idx].size(split_dim)}." + ) next_cache = ( - x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim) + x[idx].transpose(0, split_dim)[-next_cache_size:].transpose(0, split_dim) ) x[idx] = self.memory_limit_conv( @@ -627,7 +578,8 @@ class InflatedCausalConv3d(ops.Conv3d): memory_state: MemoryState = MemoryState.UNSET, memory_cache = None, ) -> Tensor: - assert memory_state != MemoryState.UNSET + if memory_state == MemoryState.UNSET: + raise ValueError("SeedVR2 VAE convolution requires an explicit MemoryState.") if memory_cache is None: memory_cache = {} if memory_state != MemoryState.ACTIVE: @@ -677,9 +629,8 @@ class InflatedCausalConv3d(ops.Conv3d): input, cache_size=cache_size, memory=memory, times=self.temporal_padding * 2 ) - # Single GPU inference - simplified memory management if ( - memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing + memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] and cache_size != 0 ): if cache_size > input[-1].size(2) and cache is not None and len(input) == 1: @@ -690,7 +641,6 @@ class InflatedCausalConv3d(ops.Conv3d): padding = tuple(x for x in reversed(self.padding) for _ in range(2)) for i in range(len(input)): - # Prepare cache for next input slice. next_cache = None cache_size = 0 if i < len(input) - 1: @@ -700,17 +650,16 @@ class InflatedCausalConv3d(ops.Conv3d): if cache_size > input[i].size(2) and cache is not None: input[i] = torch.cat([cache, input[i]], dim=2) cache = None - assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}" + if cache_size > input[i].size(2): + raise ValueError(f"SeedVR2 VAE cache size {cache_size} exceeds input length {input[i].size(2)}.") next_cache = input[i][:, :, -cache_size:] - # Conv forward for this input slice. input[i] = self.memory_limit_conv( input[i], padding=padding, prev_cache=cache ) - # Update cache. cache = next_cache return input[0] if squeeze_out else input @@ -729,7 +678,6 @@ class Upsample3D(nn.Module): inflation_mode = "tail", temporal_up: bool = False, spatial_up: bool = True, - **kwargs, ): super().__init__() self.channels = channels @@ -760,9 +708,9 @@ class Upsample3D(nn.Module): hidden_states: torch.FloatTensor, memory_state=None, memory_cache=None, - **kwargs, ) -> torch.FloatTensor: - assert hidden_states.shape[1] == self.channels + if hidden_states.shape[1] != self.channels: + raise ValueError(f"SeedVR2 upsample expected {self.channels} channels, got {hidden_states.shape[1]}.") hidden_states = self.upscale_conv(hidden_states) b, channels, f, h, w = hidden_states.shape @@ -785,8 +733,6 @@ class Upsample3D(nn.Module): class Downsample3D(nn.Module): - """A 3D downsampling layer with an optional convolution.""" - def __init__( self, channels, @@ -794,7 +740,6 @@ class Downsample3D(nn.Module): inflation_mode = "tail", spatial_down: bool = False, temporal_down: bool = False, - **kwargs, ): super().__init__() self.channels = channels @@ -823,20 +768,17 @@ class Downsample3D(nn.Module): hidden_states: torch.FloatTensor, memory_state = None, memory_cache = None, - **kwargs, ) -> torch.FloatTensor: - assert hidden_states.shape[1] == self.channels - - if hasattr(self, "norm") and self.norm is not None: - # [Overridden] change to causal norm. - hidden_states = causal_norm_wrapper(self.norm, hidden_states) + if hidden_states.shape[1] != self.channels: + raise ValueError(f"SeedVR2 downsample expected {self.channels} channels, got {hidden_states.shape[1]}.") if self.spatial_down: pad = (0, 1, 0, 1) hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) - assert hidden_states.shape[1] == self.channels + if hidden_states.shape[1] != self.channels: + raise ValueError(f"SeedVR2 downsample expected {self.channels} channels after padding, got {hidden_states.shape[1]}.") hidden_states = self.conv(hidden_states, memory_state=memory_state, memory_cache=memory_cache) @@ -848,7 +790,6 @@ class ResnetBlock3D(nn.Module): self, in_channels: int, out_channels: Optional[int] = None, - dropout: float = 0.0, temb_channels: int = 512, groups: int = 32, groups_out: Optional[int] = None, @@ -857,7 +798,6 @@ class ResnetBlock3D(nn.Module): skip_time_act: bool = False, inflation_mode = "tail", time_receptive_field: _receptive_field_t = "half", - **kwargs, ): super().__init__() self.in_channels = in_channels @@ -866,15 +806,14 @@ class ResnetBlock3D(nn.Module): self.skip_time_act = skip_time_act self.nonlinearity = nn.SiLU() if temb_channels is not None: - self.time_emb_proj = ops.Linear(temb_channels, out_channels) + self.time_emb_proj = ops.Linear(temb_channels, self.out_channels) else: self.time_emb_proj = None self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) if groups_out is None: groups_out = groups - self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) - self.use_in_shortcut = self.in_channels != out_channels - self.dropout = torch.nn.Dropout(dropout) + self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=self.out_channels, eps=eps, affine=True) + self.use_in_shortcut = self.in_channels != self.out_channels self.conv1 = InflatedCausalConv3d( self.in_channels, self.out_channels, @@ -886,7 +825,7 @@ class ResnetBlock3D(nn.Module): self.conv2 = InflatedCausalConv3d( self.out_channels, - out_channels, + self.out_channels, kernel_size=3, stride=1, padding=1, @@ -897,7 +836,7 @@ class ResnetBlock3D(nn.Module): if self.use_in_shortcut: self.conv_shortcut = InflatedCausalConv3d( self.in_channels, - out_channels, + self.out_channels, kernel_size=1, stride=1, padding=0, @@ -905,9 +844,7 @@ class ResnetBlock3D(nn.Module): inflation_mode=inflation_mode, ) - def forward( - self, input_tensor, temb, memory_state = None, memory_cache = None, **kwargs - ): + def forward(self, input_tensor, temb, memory_state = None, memory_cache = None): hidden_states = input_tensor hidden_states = causal_norm_wrapper(self.norm1, hidden_states) @@ -928,7 +865,6 @@ class ResnetBlock3D(nn.Module): hidden_states = self.nonlinearity(hidden_states) - hidden_states = self.dropout(hidden_states) hidden_states = self.conv2(hidden_states, memory_state=memory_state, memory_cache=memory_cache) if self.conv_shortcut is not None: @@ -944,7 +880,6 @@ class DownEncoderBlock3D(nn.Module): self, in_channels: int, out_channels: int, - dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_groups: int = 32, @@ -957,28 +892,23 @@ class DownEncoderBlock3D(nn.Module): ): super().__init__() resnets = [] - temporal_modules = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - # [Override] Replace module. ResnetBlock3D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, eps=resnet_eps, groups=resnet_groups, - dropout=dropout, output_scale_factor=output_scale_factor, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) ) - temporal_modules.append(nn.Identity()) self.resnets = nn.ModuleList(resnets) - self.temporal_modules = nn.ModuleList(temporal_modules) if add_downsample: self.downsamplers = nn.ModuleList( @@ -1000,11 +930,9 @@ class DownEncoderBlock3D(nn.Module): hidden_states: torch.FloatTensor, memory_state = None, memory_cache = None, - **kwargs, ) -> torch.FloatTensor: - for resnet, temporal in zip(self.resnets, self.temporal_modules): + for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache) - hidden_states = temporal(hidden_states) if self.downsamplers is not None: for downsampler in self.downsamplers: @@ -1018,7 +946,6 @@ class UpDecoderBlock3D(nn.Module): self, in_channels: int, out_channels: int, - dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_groups: int = 32, @@ -1032,33 +959,26 @@ class UpDecoderBlock3D(nn.Module): ): super().__init__() resnets = [] - temporal_modules = [] for i in range(num_layers): input_channels = in_channels if i == 0 else out_channels resnets.append( - # [Override] Replace module. ResnetBlock3D( in_channels=input_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, - dropout=dropout, output_scale_factor=output_scale_factor, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) ) - temporal_modules.append(nn.Identity()) - self.resnets = nn.ModuleList(resnets) - self.temporal_modules = nn.ModuleList(temporal_modules) if add_upsample: - # [Override] Replace module & use learnable upsample self.upsamplers = nn.ModuleList( [ Upsample3D( @@ -1080,9 +1000,8 @@ class UpDecoderBlock3D(nn.Module): memory_state=None, memory_cache=None, ) -> torch.FloatTensor: - for resnet, temporal in zip(self.resnets, self.temporal_modules): + for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state, memory_cache=memory_cache) - hidden_states = temporal(hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1096,7 +1015,6 @@ class UNetMidBlock3D(nn.Module): self, in_channels: int, temb_channels: int, - dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", # default, spatial @@ -1111,16 +1029,13 @@ class UNetMidBlock3D(nn.Module): resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) self.add_attention = add_attention - # there is always at least one resnet resnets = [ - # [Override] Replace module. ResnetBlock3D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, - dropout=dropout, output_scale_factor=output_scale_factor, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, @@ -1148,7 +1063,6 @@ class UNetMidBlock3D(nn.Module): ), residual_connection=True, bias=True, - upcast_softmax=True, ) ) else: @@ -1161,7 +1075,6 @@ class UNetMidBlock3D(nn.Module): temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, - dropout=dropout, output_scale_factor=output_scale_factor, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, @@ -1172,7 +1085,7 @@ class UNetMidBlock3D(nn.Module): self.resnets = nn.ModuleList(resnets) def forward(self, hidden_states, temb=None, memory_state=None, memory_cache=None): - video_length, frame_height, frame_width = hidden_states.size()[-3:] + video_length = hidden_states.size(2) hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state, memory_cache=memory_cache) for attn, resnet in zip(self.attentions, self.resnets[1:]): if attn is not None: @@ -1195,7 +1108,6 @@ class Encoder3D(nn.Module): layers_per_block: int = 2, norm_num_groups: int = 32, mid_block_add_attention=True, - # [Override] add temporal down num temporal_down_num: int = 2, inflation_mode = "tail", time_receptive_field: _receptive_field_t = "half", @@ -1216,17 +1128,15 @@ class Encoder3D(nn.Module): self.mid_block = None self.down_blocks = nn.ModuleList([]) - # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 - # [Override] to support temporal down block design is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 - # Note: take the last ones - assert down_block_type == "DownEncoderBlock3D" + if down_block_type != "DownEncoderBlock3D": + raise ValueError(f"SeedVR2 encoder only supports DownEncoderBlock3D, got {down_block_type}.") down_block = DownEncoderBlock3D( num_layers=self.layers_per_block, @@ -1242,7 +1152,6 @@ class Encoder3D(nn.Module): ) self.down_blocks.append(down_block) - # mid self.mid_block = UNetMidBlock3D( in_channels=block_out_channels[-1], resnet_eps=1e-6, @@ -1256,7 +1165,6 @@ class Encoder3D(nn.Module): time_receptive_field=time_receptive_field, ) - # out self.conv_norm_out = ops.GroupNorm( num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 ) @@ -1274,17 +1182,13 @@ class Encoder3D(nn.Module): memory_state = None, memory_cache = None, ) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" sample = sample.to(next(self.parameters()).device) sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache) - # down for down_block in self.down_blocks: sample = down_block(sample, memory_state=memory_state, memory_cache=memory_cache) - # middle sample = self.mid_block(sample, memory_state=memory_state, memory_cache=memory_cache) - # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache) @@ -1303,7 +1207,6 @@ class Decoder3D(nn.Module): layers_per_block: int = 2, norm_num_groups: int = 32, mid_block_add_attention=True, - # [Override] add temporal up block inflation_mode = "tail", time_receptive_field: _receptive_field_t = "half", temporal_up_num: int = 2, @@ -1326,7 +1229,6 @@ class Decoder3D(nn.Module): temb_channels = None - # mid self.mid_block = UNetMidBlock3D( in_channels=block_out_channels[-1], resnet_eps=1e-6, @@ -1340,7 +1242,6 @@ class Decoder3D(nn.Module): time_receptive_field=time_receptive_field, ) - # up reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): @@ -1349,7 +1250,8 @@ class Decoder3D(nn.Module): is_final_block = i == len(block_out_channels) - 1 is_temporal_up_block = i < self.temporal_up_num - assert up_block_type == "UpDecoderBlock3D" + if up_block_type != "UpDecoderBlock3D": + raise ValueError(f"SeedVR2 decoder only supports UpDecoderBlock3D, got {up_block_type}.") up_block = UpDecoderBlock3D( num_layers=self.layers_per_block + 1, in_channels=prev_output_channel, @@ -1365,7 +1267,6 @@ class Decoder3D(nn.Module): self.up_blocks.append(up_block) prev_output_channel = output_channel - # out self.conv_norm_out = ops.GroupNorm( num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 ) @@ -1375,7 +1276,6 @@ class Decoder3D(nn.Module): ) - # Note: Just copy from Decoder. def forward( self, sample: torch.FloatTensor, @@ -1388,15 +1288,12 @@ class Decoder3D(nn.Module): sample = self.conv_in(sample, memory_state=memory_state, memory_cache=memory_cache) upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - # middle sample = self.mid_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache) sample = sample.to(upscale_dtype) - # up for up_block in self.up_blocks: sample = up_block(sample, latent_embeds, memory_state=memory_state, memory_cache=memory_cache) - # post-process sample = causal_norm_wrapper(self.conv_norm_out, sample) sample = self.conv_act(sample) sample = self.conv_out(sample, memory_state=memory_state, memory_cache=memory_cache) @@ -1415,8 +1312,6 @@ class VideoAutoencoderKL(nn.Module): inflation_mode = "pad", time_receptive_field: _receptive_field_t = "full", slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN, - *args, - **kwargs, ): self.slicing_sample_min_size = slicing_sample_min_size self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) @@ -1425,7 +1320,6 @@ class VideoAutoencoderKL(nn.Module): up_block_types = ("UpDecoderBlock3D",) * 4 super().__init__() - # pass init params to Encoder self.encoder = Encoder3D( in_channels=in_channels, out_channels=latent_channels, @@ -1433,13 +1327,11 @@ class VideoAutoencoderKL(nn.Module): block_out_channels=block_out_channels, layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, - # [Override] add temporal_down_num parameter temporal_down_num=temporal_scale_num, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, ) - # pass init params to Decoder self.decoder = Decoder3D( in_channels=latent_channels, out_channels=out_channels, @@ -1447,7 +1339,6 @@ class VideoAutoencoderKL(nn.Module): block_out_channels=block_out_channels, layers_per_block=layers_per_block, norm_num_groups=norm_num_groups, - # [Override] add temporal_up_num parameter temporal_up_num=temporal_scale_num, inflation_mode=inflation_mode, time_receptive_field=time_receptive_field, @@ -1489,11 +1380,10 @@ class VideoAutoencoderKL(nn.Module): return output.to(z.device) def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: - sp_size =1 - if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: + if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size: memory_cache = {} split_size = max( - self.slicing_sample_min_size * sp_size, + self.slicing_sample_min_size, getattr(self, "temporal_downsample_factor", 1), ) x_slices = list(x[:, :, 1:].split(split_size=split_size, dim=2)) @@ -1518,10 +1408,9 @@ class VideoAutoencoderKL(nn.Module): return self._encode(x) def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: - sp_size = 1 - if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: + if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size: memory_cache = {} - z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) + z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size, dim=2) decoded_slices = [ self._decode( torch.cat((z[:, :, :1], z_slices[0]), dim=2), @@ -1538,33 +1427,28 @@ class VideoAutoencoderKL(nn.Module): else: return self._decode(z) - def forward( - self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs - ): - # x: [b c t h w] + def forward(self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all"): def _unwrap(value): return value[0] if isinstance(value, tuple) else value if mode == "encode": return _unwrap(self.encode(x)) - elif mode == "decode": + if mode == "decode": return _unwrap(self.decode_(x)) - else: + if mode == "all": latent = _unwrap(self.encode(x)) return _unwrap(self.decode_(latent)) + raise ValueError(f"Unknown SeedVR2 VAE forward mode: {mode}") class VideoAutoencoderKLWrapper(VideoAutoencoderKL): def __init__( self, - *args, spatial_downsample_factor = 8, temporal_downsample_factor = 4, - **kwargs, ): self.spatial_downsample_factor = spatial_downsample_factor self.temporal_downsample_factor = temporal_downsample_factor - self.enable_tiling = False - super().__init__(*args, **kwargs) + super().__init__() self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB) def forward(self, x: torch.FloatTensor): @@ -1581,7 +1465,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): z = p.squeeze(2) return z, p - def encode(self, x, orig_dims=None): + def encode(self, x): z, _ = self._encode_with_raw_latent(x) return z @@ -1594,26 +1478,27 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): ) if z.ndim == 5: - b, c, t_latent, h, w = z.shape - if c != 16: + _, c, _, _, _ = z.shape + if c != SEEDVR2_LATENT_CHANNELS: raise RuntimeError( "SeedVR2 VideoAutoencoderKLWrapper.decode: 5-D latent input must " - f"have 16 channels; got shape {tuple(z.shape)}." + f"have {SEEDVR2_LATENT_CHANNELS} channels; got shape {tuple(z.shape)}." ) latent = z elif z.ndim == 4: b, tc, h, w = z.shape - if tc % 16 != 0: + if tc % SEEDVR2_LATENT_CHANNELS != 0: raise RuntimeError( "SeedVR2 VideoAutoencoderKLWrapper.decode: 4-D latent input must " - "use collapsed channel layout (B, 16*T, H, W); " + f"use collapsed channel layout (B, {SEEDVR2_LATENT_CHANNELS}*T, H, W); " f"got shape {tuple(z.shape)}." ) - latent = z.reshape(b, 16, -1, h, w) + latent = z.reshape(b, SEEDVR2_LATENT_CHANNELS, -1, h, w) else: raise RuntimeError( "SeedVR2 VideoAutoencoderKLWrapper.decode: latent input must be " - "4-D collapsed (B, 16*T, H, W) or 5-D (B, 16, T, H, W); " + f"4-D collapsed (B, {SEEDVR2_LATENT_CHANNELS}*T, H, W) or " + f"5-D (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); " f"got shape {tuple(z.shape)}." ) scale = BYTEDANCE_VAE_SCALING_FACTOR @@ -1621,10 +1506,11 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): latent = latent / scale + shift self.device = latent.device - self.enable_tiling = seedvr2_tiling.get("enable_tiling", False) + enable_tiling = seedvr2_tiling.get("enable_tiling", False) - if self.enable_tiling: + if enable_tiling: decode_seedvr2_args = dict(seedvr2_tiling) + decode_seedvr2_args.pop("enable_tiling", None) tile_h, tile_w = decode_seedvr2_args.get("tile_size", (512, 512)) ov_h, ov_w = decode_seedvr2_args.get("tile_overlap", (64, 64)) decode_seedvr2_args["tile_overlap"] = ( @@ -1641,7 +1527,6 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): else: x = super().decode_(latent) - # ensure even dims for save video h, w = x.shape[-2:] w2 = w - (w % 2) h2 = h - (h % 2) @@ -1693,7 +1578,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): if samples.ndim == 4: samples = samples.unsqueeze(2) samples = samples.contiguous() - samples = samples * 0.9152 + samples = samples * BYTEDANCE_VAE_SCALING_FACTOR return samples def comfy_memory_used_decode(self, shape): @@ -1707,15 +1592,15 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL): # plus int64 sort indices dominate peak memory, not the VAE weight dtype. if len(shape) == 5: candidates = [] - if shape[1] == 16: + if shape[1] == SEEDVR2_LATENT_CHANNELS: candidates.append((shape[2], shape[3], shape[4])) - if shape[-1] == 16: + if shape[-1] == SEEDVR2_LATENT_CHANNELS: candidates.append((shape[1], shape[2], shape[3])) if len(candidates) == 0: candidates.append((shape[2], shape[3], shape[4])) pixels = max(output_pixels(*candidate) for candidate in candidates) elif len(shape) == 4: - latent_t = max(1, (shape[1] + 15) // 16) + latent_t = max(1, (shape[1] + SEEDVR2_LATENT_CHANNELS - 1) // SEEDVR2_LATENT_CHANNELS) pixels = output_pixels(latent_t, shape[2], shape[3]) else: pixels = output_pixels(1, shape[-2], shape[-1]) diff --git a/comfy/model_base.py b/comfy/model_base.py index 0004b339a..4b02e5bb4 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -933,7 +933,8 @@ class HunyuanDiT(BaseModel): class SeedVR2(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): - super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT) + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.seedvr.model.NaDiT) + def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) condition = kwargs.get("condition", None) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index bf44b832c..97673988b 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -598,43 +598,34 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config - if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b + seedvr2_7b_separate_key = "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) + if seedvr2_7b_separate_key in state_dict_keys and state_dict[seedvr2_7b_separate_key].shape[1] == 3072: # seedvr2 7b dit_config = {} dit_config["image_model"] = "seedvr2" dit_config["vid_dim"] = 3072 dit_config["heads"] = 24 dit_config["num_layers"] = 36 - # 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.`` - # submodules) at EVERY block — verified by inspecting the 7B - # state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means - # ``MMModule.shared_weights=False``). Native NaDiT computes - # per-block ``shared_weights = not (i < mm_layers)``, so to keep - # every block non-shared we set ``mm_layers = num_layers``. - # Without this, blocks at index >= mm_layers (default 10) try to - # load ``blocks.N.*.all.*`` keys that don't exist in the file, - # silently miss-load → all-black output. + # This checkpoint uses separate vid/txt MMModule keys in every block. dit_config["mm_layers"] = 36 dit_config["norm_eps"] = 1e-5 dit_config["rope_type"] = "rope3d" dit_config["rope_dim"] = 64 dit_config["mlp_type"] = "normal" return dit_config - elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b + if "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b dit_config = {} dit_config["image_model"] = "seedvr2" dit_config["vid_dim"] = 3072 dit_config["heads"] = 24 dit_config["num_layers"] = 36 - # This checkpoint layout carries shared ``all.`` MMModule keys. - # Preserve the historical split: the initial blocks use separate - # vid/txt modules, later blocks use shared modules. + # This checkpoint uses shared all.* MMModule keys after the initial blocks. dit_config["mm_layers"] = 10 dit_config["norm_eps"] = 1e-5 dit_config["rope_type"] = "rope3d" dit_config["rope_dim"] = 64 dit_config["mlp_type"] = "swiglu" return dit_config - elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b + if "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b dit_config = {} dit_config["image_model"] = "seedvr2" dit_config["vid_dim"] = 2560 @@ -1150,8 +1141,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config["heatmap_head"] = True return unet_config +def normalize_seedvr2_unet_config(unet_config): + if unet_config.get("image_model") != "seedvr2" or "num_heads" not in unet_config: + return unet_config + + unet_config = dict(unet_config) + num_heads = unet_config.pop("num_heads") + if "heads" in unet_config and unet_config["heads"] != num_heads: + raise ValueError( + f"SeedVR2 config has conflicting heads={unet_config['heads']} and num_heads={num_heads}." + ) + unet_config["heads"] = num_heads + return unet_config + def model_config_from_unet_config(unet_config, state_dict=None, unet_key_prefix=""): + unet_config = normalize_seedvr2_unet_config(unet_config) for model_config in comfy.supported_models.models: if model_config.matches(unet_config, state_dict, unet_key_prefix=unet_key_prefix): return model_config(unet_config) diff --git a/comfy/sd.py b/comfy/sd.py index 06c6196d3..20726d782 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -472,8 +472,7 @@ class VAE: def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format - if metadata is None or metadata.get("keep_diffusers_format") != "true": - sd = diffusers_convert.convert_vae_state_dict(sd) + sd = diffusers_convert.convert_vae_state_dict(sd) if model_management.is_amd(): VAE_KL_MEM_RATIO = 2.73 @@ -549,7 +548,7 @@ class VAE: self.latent_channels = 16 elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() - self.latent_channels = 16 + self.latent_channels = comfy.ldm.seedvr.vae.SEEDVR2_LATENT_CHANNELS self.latent_dim = 3 self.disable_offload = True self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape) @@ -1074,6 +1073,20 @@ class VAE: out = self.first_stage_model.encode_tiled(x, **kwargs) return out.to(device=self.output_device, dtype=self.vae_output_dtype()) + def _owned_tiled_args(self, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + args = {} + if tile_x is not None: + args["tile_x"] = tile_x + if tile_y is not None: + args["tile_y"] = tile_y + if overlap is not None: + args["overlap"] = overlap + if tile_t is not None: + args["tile_t"] = tile_t + if overlap_t is not None: + args["overlap_t"] = overlap_t + return args + def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() pixel_samples = None @@ -1153,18 +1166,7 @@ class VAE: with model_management.cuda_device_context(self.device): if self.handles_tiling and dims in (2, 3): - tiled_args = {} - if tile_x is not None: - tiled_args["tile_x"] = tile_x - if tile_y is not None: - tiled_args["tile_y"] = tile_y - if overlap is not None: - tiled_args["overlap"] = overlap - if tile_t is not None: - tiled_args["tile_t"] = tile_t - if overlap_t is not None: - tiled_args["overlap_t"] = overlap_t - output = self._decode_tiled_owned(samples, **tiled_args) + output = self._decode_tiled_owned(samples, **self._owned_tiled_args(tile_x, tile_y, overlap, tile_t, overlap_t)) elif dims == 1 or self.extra_1d_channel is not None: args.pop("tile_y") output = self.decode_tiled_1d(samples, **args) @@ -1269,18 +1271,7 @@ class VAE: samples = self.encode_tiled_(pixel_samples, **args) elif dims == 3: if self.handles_tiling: - tiled_args = {} - if tile_x is not None: - tiled_args["tile_x"] = tile_x - if tile_y is not None: - tiled_args["tile_y"] = tile_y - if overlap is not None: - tiled_args["overlap"] = overlap - if tile_t is not None: - tiled_args["tile_t"] = tile_t - if overlap_t is not None: - tiled_args["overlap_t"] = overlap_t - samples = self._encode_tiled_owned(pixel_samples, **tiled_args) + samples = self._encode_tiled_owned(pixel_samples, **self._owned_tiled_args(tile_x, tile_y, overlap, tile_t, overlap_t)) else: if tile_t is not None: tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) @@ -1850,7 +1841,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (model, clip, vae) - def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 5c849358e..65cc79bb2 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1688,6 +1688,7 @@ class SeedVR2(supported_models_base.BASE): unet_config = { "image_model": "seedvr2" } + unet_extra_config = {} required_keys = { "{}positive_conditioning", "{}negative_conditioning", diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py index bf5b3c15c..933b76237 100644 --- a/comfy_extras/nodes_seedvr.py +++ b/comfy_extras/nodes_seedvr.py @@ -19,21 +19,14 @@ from comfy.ldm.seedvr.constants import ( ) from torchvision.transforms import functional as TVF -from torchvision.transforms import Lambda from torchvision.transforms.functional import InterpolationMode -_SEEDVR2_INVALID_MODEL_MSG_PREFIX = ( - "SeedVR2Conditioning: model object does not match expected SeedVR2 structure" -) - -# Private sentinel for getattr default: distinguishes "attribute missing" -# from "attribute present but None" so the failure message is accurate. +_SEEDVR2_INVALID_MODEL_MSG_PREFIX = "SeedVR2Conditioning: model object does not match expected SeedVR2 structure" _ATTR_MISSING = object() def _resolve_seedvr2_diffusion_model(model): - """Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message.""" inner = getattr(model, "model", _ATTR_MISSING) if inner is _ATTR_MISSING: raise RuntimeError( @@ -59,15 +52,7 @@ def _resolve_seedvr2_diffusion_model(model): return diffusion_model -def get_conditions(latent, latent_blur): - t, h, w, c = latent.shape - cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) - cond[:, ..., :-1] = latent_blur[:] - cond[:, ..., -1:] = 1.0 - return cond - def div_pad(image, factor): - height_factor, width_factor = factor height, width = image.shape[-2:] @@ -77,31 +62,25 @@ def div_pad(image, factor): if pad_height == 0 and pad_width == 0: return image - if isinstance(image, torch.Tensor): - padding = (0, pad_width, 0, pad_height) - image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0) - - return image + padding = (0, pad_width, 0, pad_height) + return torch.nn.functional.pad(image, padding, mode='constant', value=0.0) def cut_videos(videos): t = videos.size(1) + if t < 1: + raise ValueError("SeedVR2Preprocess expected at least one frame.") if t == 1: return videos - if t <= 4 : - padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - return videos - if (t - 1) % (4) == 0: - return videos - else: - padding = [videos[:, -1].unsqueeze(1)] * ( - 4 - ((t - 1) % (4)) - ) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - assert (videos.size(1) - 1) % (4) == 0 + if t <= 4: + padding = videos[:, -1:].repeat(1, 4 - t + 1, 1, 1, 1) + return torch.cat([videos, padding], dim=1) + if (t - 1) % 4 == 0: return videos + padding = videos[:, -1:].repeat(1, 4 - ((t - 1) % 4), 1, 1, 1) + videos = torch.cat([videos, padding], dim=1) + if (videos.size(1) - 1) % 4 != 0: + raise ValueError(f"SeedVR2Preprocess failed to pad video length to 4n+1; got {videos.size(1)} frames.") + return videos def _seedvr2_input_shorter_edge(images, node_name): if images.dim() == 4: @@ -136,8 +115,7 @@ def _seedvr2_pad(images, upscaled_shorter_edge, node_name): b, t, c, h, w = images.shape images = images.reshape(b * t, c, h, w) - clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) - images = clip(images) + images = torch.clamp(images, 0.0, 1.0) images = div_pad(images, (16, 16)) _, _, new_h, new_w = images.shape @@ -295,7 +273,6 @@ class SeedVR2PostProcessing(io.ComfyNode): def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method): chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method) while True: - next_chunk_size = None try: return cls._run_color_transfer_chunks( decoded_flat, reference_flat, output_device, color_correction_method, chunk_size, @@ -307,9 +284,7 @@ class SeedVR2PostProcessing(io.ComfyNode): "SeedVR2PostProcessing: color correction OOM at one frame; " f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}." ) from e - next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR) - - chunk_size = next_chunk_size + chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR) @classmethod def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size): @@ -392,10 +367,8 @@ class SeedVR2Conditioning(io.ComfyNode): io.Latent.Input("vae_conditioning", display_name="latent"), ], outputs=[ - io.Model.Output(display_name="model", tooltip="The SeedVR2 model, passed through."), io.Conditioning.Output(display_name="positive", tooltip="The positive conditioning for sampling."), io.Conditioning.Output(display_name="negative", tooltip="The negative conditioning for sampling."), - io.Latent.Output(display_name="latent", tooltip="The latent to denoise."), ], ) @@ -408,29 +381,30 @@ class SeedVR2Conditioning(io.ComfyNode): "SeedVR2Conditioning expects a 5-D VAE latent in Comfy " f"channel-first layout; got shape {tuple(vae_conditioning.shape)}." ) - if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS: + if vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS: + if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS: + raise ValueError( + "SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy " + f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); " + f"got channel-last shape {tuple(vae_conditioning.shape)}." + ) raise ValueError( - "SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy " - f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); " - f"got channel-last shape {tuple(vae_conditioning.shape)}." + "SeedVR2Conditioning expects SeedVR2 VAE latents with " + f"{SEEDVR2_LATENT_CHANNELS} channels; got shape {tuple(vae_conditioning.shape)}." ) vae_conditioning = vae_conditioning.movedim(1, -1).contiguous() - model_patcher = model - model = _resolve_seedvr2_diffusion_model(model_patcher) + model = _resolve_seedvr2_diffusion_model(model) pos_cond = model.positive_conditioning neg_cond = model.negative_conditioning - condition = torch.stack([get_conditions(c, c) for c in vae_conditioning]) + mask = vae_conditioning.new_ones(vae_conditioning.shape[:-1] + (1,)) + condition = torch.cat((vae_conditioning, mask), dim=-1) condition = condition.movedim(-1, 1) - latent = vae_conditioning.movedim(-1, 1) - - latent = latent.reshape(latent.shape[0], latent.shape[1] * latent.shape[2], latent.shape[3], latent.shape[4]) - condition = condition.reshape(condition.shape[0], condition.shape[1] * condition.shape[2], condition.shape[3], condition.shape[4]) negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] - return io.NodeOutput(model_patcher, positive, negative, {"samples": latent}) + return io.NodeOutput(positive, negative) class SeedVRExtension(ComfyExtension): @override diff --git a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py index d36e50428..045502b5b 100644 --- a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py +++ b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py @@ -1,20 +1,15 @@ -"""Consolidated SeedVR2 conditioning and refactor regression tests. - -Merges the prior test_seedvr2_refactor_nodes.py and -test_seedvr_conditioning_hardening.py modules. Refactor tests use the -top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests -use _import_nodes_seedvr_isolated() for sys.modules isolation when -mocking comfy.model_management. -""" +"""SeedVR2 conditioning node regression tests.""" import importlib import sys from unittest.mock import MagicMock +import pytest import torch import torch.nn as nn from comfy.cli_args import args as cli_args +from comfy.ldm.seedvr.constants import SEEDVR2_LATENT_CHANNELS if not torch.cuda.is_available(): cli_args.cpu = True @@ -79,21 +74,18 @@ def _import_nodes_seedvr_isolated(): class _Rope(nn.Module): - """Minimal RoPE stub exposing a `freqs` parameter.""" def __init__(self): super().__init__() self.freqs = nn.Parameter(torch.zeros(4)) class _Block(nn.Module): - """Minimal transformer block stub holding a `_Rope`.""" def __init__(self): super().__init__() self.rope = _Rope() class _DiffusionModel(nn.Module): - """Stub diffusion model with N blocks and pos/neg conditioning buffers.""" def __init__(self, n_blocks=3, conditioning_dtype=torch.float32): super().__init__() self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)]) @@ -102,18 +94,16 @@ class _DiffusionModel(nn.Module): class _ModelInner: - """Inner model wrapper exposing `.diffusion_model`.""" def __init__(self, diffusion_model): self.diffusion_model = diffusion_model class _ModelPatcher: - """ModelPatcher stub exposing `.model._ModelInner`.""" def __init__(self, diffusion_model): self.model = _ModelInner(diffusion_model) -def test_seedvr2_conditioning_schema_exposes_model_passthrough_output(): +def test_seedvr2_conditioning_schema_exposes_conditioning_outputs(): nodes_seedvr, restore = _import_nodes_seedvr_isolated() try: schema = nodes_seedvr.SeedVR2Conditioning.define_schema() @@ -123,37 +113,50 @@ def test_seedvr2_conditioning_schema_exposes_model_passthrough_output(): ] assert schema.inputs[1].display_name == "latent" assert [output.display_name for output in schema.outputs] == [ - "model", "positive", "negative", - "latent", ] finally: restore() -def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): +def test_seedvr2_conditioning_rejects_wrong_latent_channels(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + patcher = _ModelPatcher(_DiffusionModel()) + vae_conditioning = {"samples": torch.zeros(1, 8, 2, 2, 2)} + + with pytest.raises(ValueError, match=f"{SEEDVR2_LATENT_CHANNELS} channels"): + nodes_seedvr.SeedVR2Conditioning.execute(patcher, vae_conditioning) + finally: + restore() + + +def test_seedvr2_conditioning_returns_conditioning_deterministically(): nodes_seedvr, restore = _import_nodes_seedvr_isolated() try: diffusion_model = _DiffusionModel() patcher = _ModelPatcher(diffusion_model) - samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2) + samples = torch.arange( + 1, + 1 + SEEDVR2_LATENT_CHANNELS * 3 * 2 * 2, + dtype=torch.float32, + ).reshape(1, SEEDVR2_LATENT_CHANNELS, 3, 2, 2) vae_conditioning = {"samples": samples} - _, first_positive, first_negative, first_latent = ( + first_positive, first_negative = ( nodes_seedvr.SeedVR2Conditioning.execute( patcher, vae_conditioning, ) ) - _, second_positive, second_negative, second_latent = ( + second_positive, second_negative = ( nodes_seedvr.SeedVR2Conditioning.execute( patcher, vae_conditioning, ) ) - expected_latent = samples.reshape(1, 6, 2, 2) channel_last = samples.movedim(1, -1).contiguous() expected_condition = torch.cat( [ @@ -161,10 +164,8 @@ def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): torch.ones((*channel_last.shape[:-1], 1)), ], dim=-1, - ).movedim(-1, 1).reshape(1, 9, 2, 2) + ).movedim(-1, 1) - assert torch.equal(first_latent["samples"], expected_latent) - assert torch.equal(second_latent["samples"], expected_latent) assert torch.equal( first_positive[0][1]["condition"], expected_condition, diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py index 587c393c9..192fdbfe5 100644 --- a/tests-unit/comfy_test/model_detection_test.py +++ b/tests-unit/comfy_test/model_detection_test.py @@ -201,6 +201,17 @@ class TestModelDetection: del sd["positive_conditioning"] assert model_config_from_unet_config(unet_config, sd) is None + def test_seedvr2_model_match_normalizes_num_heads(self): + sd = _make_seedvr2_7b_shared_mm_sd() + unet_config = detect_unet_config(sd, "") + unet_config["num_heads"] = unet_config.pop("heads") + + model_config = model_config_from_unet_config(unet_config, sd) + + assert type(model_config).__name__ == "SeedVR2" + assert model_config.unet_config["heads"] == 24 + assert "num_heads" not in model_config.unet_config + def test_seedvr2_model_match_accepts_full_checkpoint_prefix(self): sd = _add_model_diffusion_prefix(_make_seedvr2_7b_shared_mm_sd()) diff --git a/tests-unit/comfy_test/seedvr_vae_forward_test.py b/tests-unit/comfy_test/seedvr_vae_forward_test.py index d4af4c2b1..7ea7a143e 100644 --- a/tests-unit/comfy_test/seedvr_vae_forward_test.py +++ b/tests-unit/comfy_test/seedvr_vae_forward_test.py @@ -1,22 +1,6 @@ -"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must -honor the actual tensor/tuple return contract of ``encode()`` and -``decode_()`` and must NOT dereference diffusers-style ``.latent_dist`` -or ``.sample`` attributes on those returns. - -The pre-fix body raised ``AttributeError: 'Tensor' object has no -attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and -``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'`` -for ``mode == "decode"`` (the class only defines ``decode_`` with a -trailing underscore). The post-fix body unwraps the optional one-element -tuple shape that ``return_dict=False`` produces and returns the tensor -directly. - -Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses -the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and -overrides ``encode``/``decode_`` with known tensors so the contract can -be probed without loading any real VAE weights. -""" +"""Regression tests for the SeedVR2 VAE forward return contract.""" +import pytest import torch import torch.nn as nn @@ -25,13 +9,13 @@ from comfy.cli_args import args as cli_args if not torch.cuda.is_available(): cli_args.cpu = True -from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402 +from comfy.ldm.seedvr.vae import SEEDVR2_LATENT_CHANNELS, VideoAutoencoderKL # noqa: E402 -_LATENT_SHAPE = (1, 16, 2, 2, 2) +_LATENT_SHAPE = (1, SEEDVR2_LATENT_CHANNELS, 2, 2, 2) _DECODED_SHAPE = (1, 3, 5, 16, 16) _INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16) -_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2) +_INPUT_DECODE_SHAPE = _LATENT_SHAPE class _StubVAE(VideoAutoencoderKL): @@ -64,8 +48,6 @@ def test_forward_decode_returns_tensor(): class _TupleReturningStubVAE(VideoAutoencoderKL): - """Stub whose ``encode``/``decode_`` return the ``(tensor,)`` tuple of ``return_dict=False``, exercising the unwrap branch of ``VideoAutoencoderKL.forward``.""" - def __init__(self): nn.Module.__init__(self) self._encode_tensor = torch.zeros(*_LATENT_SHAPE) @@ -84,3 +66,9 @@ def test_forward_all_unwraps_one_tuple_at_each_step(): result = vae.forward(x, mode="all") assert type(result) is torch.Tensor assert result.shape == torch.Size(_DECODED_SHAPE) + + +def test_forward_rejects_unknown_mode(): + vae = _StubVAE() + with pytest.raises(ValueError, match="Unknown SeedVR2 VAE forward mode"): + vae.forward(torch.zeros(*_INPUT_ENCODE_SHAPE), mode="bogus") diff --git a/tests-unit/comfy_test/test_seedvr2_dtype.py b/tests-unit/comfy_test/test_seedvr2_dtype.py index f03c0406c..8e08b6dde 100644 --- a/tests-unit/comfy_test/test_seedvr2_dtype.py +++ b/tests-unit/comfy_test/test_seedvr2_dtype.py @@ -41,8 +41,9 @@ def test_seedvr2_text_conditioning_accepts_cfg1_single_branch(): def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer(): wrapper = seedvr_vae.VideoAutoencoderKLWrapper.__new__(seedvr_vae.VideoAutoencoderKLWrapper) - estimate = wrapper.comfy_memory_used_decode((1, 16, 26, 120, 160)) - old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2 + latent_channels = seedvr_vae.SEEDVR2_LATENT_CHANNELS + estimate = wrapper.comfy_memory_used_decode((1, latent_channels, 26, 120, 160)) + old_estimate = latent_channels * 120 * 160 * (4 * 8 * 8) * 2 assert estimate == 101 * 960 * 1280 * 160 assert estimate > 15 * 1024 ** 3 diff --git a/tests-unit/comfy_test/test_seedvr2_internals.py b/tests-unit/comfy_test/test_seedvr2_internals.py index 966e9465d..fe4bde1c4 100644 --- a/tests-unit/comfy_test/test_seedvr2_internals.py +++ b/tests-unit/comfy_test/test_seedvr2_internals.py @@ -1,16 +1,4 @@ -"""Consolidated SeedVR2 internals regression tests. - -Sources (all merged verbatim, helper names disambiguated where colliding): - - * GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare - memory_occupy against get_norm_limit(), not float('inf'). - * SeedVR2 variable-length attention split-loop contract. - -Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and -comfy.ldm.modules.attention transitively pull in comfy.model_management, -which probes torch.cuda.current_device() at import time unless args.cpu is -set first. -""" +"""SeedVR2 internals regression tests.""" from __future__ import annotations @@ -35,10 +23,6 @@ from comfy.ldm.seedvr.vae import ( # noqa: E402 from comfy.ldm.seedvr.attention import var_attention_optimized_split # noqa: E402 -# --------------------------------------------------------------------------- -# GroupNorm limit tests (test_seedvr_groupnorm_limit.py) -# --------------------------------------------------------------------------- - _NUM_CHANNELS = 8 _NUM_GROUPS = 4 _TENSOR_SHAPE = (1, 8, 2, 4, 4) @@ -89,10 +73,6 @@ def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls): set_norm_limit(None) -# --------------------------------------------------------------------------- -# SeedVR2 var_attention split-loop tests -# --------------------------------------------------------------------------- - def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch): dim = 8 heads = 2 @@ -140,18 +120,8 @@ def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypa assert call["heads"] == heads assert call["skip_reshape"] is True assert call["skip_output_reshape"] is True - torch.testing.assert_close( - call["cu_seqlens_q"], - torch.tensor([0, 7, 14], dtype=torch.int32), - rtol=0, - atol=0, - ) - torch.testing.assert_close( - call["cu_seqlens_k"], - torch.tensor([0, 7, 14], dtype=torch.int32), - rtol=0, - atol=0, - ) + assert call["cu_seqlens_q"] == [0, 7, 14] + assert call["cu_seqlens_k"] == [0, 7, 14] def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch): @@ -160,7 +130,7 @@ def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatc q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim) k = q + 100 v = q + 200 - cu = torch.tensor([0, 2, 5], dtype=torch.int32) + cu = [0, 2, 5] calls = [] def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs): @@ -197,20 +167,3 @@ def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatc assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls) torch.testing.assert_close(out, q + v, rtol=0, atol=0) - -def test_var_attention_optimized_split_rejects_bad_offsets(): - q = torch.randn(5, 2, 3) - cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32) - cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32) - - with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"): - var_attention_optimized_split( - q, - q, - q, - 2, - cu_bad, - cu_ok, - skip_reshape=True, - skip_output_reshape=True, - ) diff --git a/tests-unit/comfy_test/test_seedvr2_model.py b/tests-unit/comfy_test/test_seedvr2_model.py index 06b2f1564..421e98e43 100644 --- a/tests-unit/comfy_test/test_seedvr2_model.py +++ b/tests-unit/comfy_test/test_seedvr2_model.py @@ -1,17 +1,10 @@ -"""Consolidated SeedVR2 model/graph/forward regression tests. - -Merged from: -- seedvr_model_test.py -- test_seedvr_7b_final_block_text_path.py -- test_seedvr_forward_no_device_cast.py -- test_seedvr_latent_format.py -- test_seedvr2_vae_graph_boundaries.py -""" +"""SeedVR2 model, latent-format, and VAE graph regression tests.""" from __future__ import annotations from unittest.mock import MagicMock +import pytest import torch from torch import nn @@ -22,7 +15,6 @@ if not torch.cuda.is_available(): import comfy # noqa: E402 import comfy.latent_formats # noqa: E402 -import comfy.ldm.seedvr.model # noqa: E402 import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 import comfy.model_management # noqa: E402 @@ -33,9 +25,7 @@ import nodes as nodes_mod # noqa: E402 from comfy.ldm.seedvr.model import NaDiT # noqa: E402 -# --------------------------------------------------------------------------- -# Helpers from seedvr_model_test.py -# --------------------------------------------------------------------------- +_LATENT_CHANNELS = seedvr_vae_mod.SEEDVR2_LATENT_CHANNELS def _make_standin(positive_conditioning): @@ -51,11 +41,6 @@ def _make_standin(positive_conditioning): return _StandIn() -# --------------------------------------------------------------------------- -# Helpers from test_seedvr_7b_final_block_text_path.py -# --------------------------------------------------------------------------- - - class _StubModule(nn.Module): def __init__(self, *args, **kwargs): super().__init__() @@ -88,11 +73,6 @@ def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> lis return flags -# --------------------------------------------------------------------------- -# Helpers from test_seedvr_latent_format.py -# --------------------------------------------------------------------------- - - class _Model: def __init__(self, latent_format): self._latent_format = latent_format @@ -102,11 +82,6 @@ class _Model: return self._latent_format -# --------------------------------------------------------------------------- -# Helpers from test_seedvr2_vae_graph_boundaries.py -# --------------------------------------------------------------------------- - - class _Patcher: def get_free_memory(self, device): return 1024 * 1024 * 1024 @@ -136,14 +111,14 @@ class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling}) if z.ndim == 4: b, tc, h, w = z.shape - t = tc // 16 + t = tc // _LATENT_CHANNELS else: b, _, t, h, w = z.shape return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device) def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch): - raw_latent = torch.full((1, 16, 1, 4, 5), 2.0) + raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 2.0) seen_shapes = [] def base_encode(self, x): @@ -159,12 +134,12 @@ def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch): latent = vae.encode(torch.zeros(1, 3, 32, 40)) assert type(latent) is torch.Tensor - assert tuple(latent.shape) == (1, 16, 4, 5) + assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5) assert seen_shapes == [(1, 3, 1, 32, 40)] def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch): - raw_latent = torch.full((1, 16, 1, 4, 5), 3.0) + raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 3.0) def base_encode(self, x): return raw_latent.to(device=x.device, dtype=x.dtype) @@ -177,8 +152,8 @@ def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch): latent, raw = vae._encode_with_raw_latent(torch.zeros(1, 3, 32, 40)) - assert tuple(latent.shape) == (1, 16, 4, 5) - assert tuple(raw.shape) == (1, 16, 1, 4, 5) + assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5) + assert tuple(raw.shape) == (1, _LATENT_CHANNELS, 1, 4, 5) assert torch.equal(raw, raw_latent) @@ -188,7 +163,7 @@ def _make_vae(wrapper): vae.device = torch.device("cpu") vae.output_device = torch.device("cpu") vae.vae_dtype = torch.float32 - vae.latent_channels = 16 + vae.latent_channels = _LATENT_CHANNELS vae.latent_dim = 3 vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8) vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) @@ -212,13 +187,7 @@ def _make_vae(wrapper): return vae -# --------------------------------------------------------------------------- -# Tests from seedvr_model_test.py -# --------------------------------------------------------------------------- - - def test_missing_context_falls_back_to_positive_buffer(): - """``context is None`` falls back to the registered ``positive_conditioning`` buffer and runs to completion.""" pos_buffer = torch.full((58, 5120), 7.0) standin = _make_standin(pos_buffer) txt, txt_shape = standin._resolve_text_conditioning(None) @@ -231,11 +200,6 @@ def test_missing_context_falls_back_to_positive_buffer(): assert txt_shape[0, 0].item() == 58 -# --------------------------------------------------------------------------- -# Tests from test_seedvr_7b_final_block_text_path.py -# --------------------------------------------------------------------------- - - def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch): assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [ False, @@ -268,43 +232,49 @@ def test_seedvr2_7b_rope3d_matches_wrapper_oracle(): torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0) -# --------------------------------------------------------------------------- -# Tests from test_seedvr_latent_format.py -# --------------------------------------------------------------------------- +def test_seedvr2_forward_requires_conditioning_latents(): + model = NaDiT.__new__(NaDiT) + x = torch.zeros(1, _LATENT_CHANNELS, 1, 4, 5) + + with pytest.raises(ValueError, match="requires conditioning latents"): + NaDiT.forward(model, x, timestep=torch.tensor([1.0]), context=None) -def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion(): +def test_seedvr2_latent_format_uses_native_video_latent_shape(): latent_format = comfy.latent_formats.SeedVR2() latent_image = torch.zeros(1, 1, 4, 5) fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image) - assert latent_format.latent_channels == 16 - assert latent_format.latent_dimensions == 2 - assert fixed.shape == (1, 16, 4, 5) + assert latent_format.latent_channels == _LATENT_CHANNELS + assert latent_format.latent_dimensions == 3 + assert fixed.shape == (1, _LATENT_CHANNELS, 1, 4, 5) -# --------------------------------------------------------------------------- -# Tests from test_seedvr2_vae_graph_boundaries.py -# --------------------------------------------------------------------------- +def test_seedvr2_model_requires_native_5d_latent(): + latent = torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5) + assert NaDiT._check_seedvr2_video_latent(latent, _LATENT_CHANNELS, "latent") is latent + + with pytest.raises(ValueError, match="5-D native latent"): + NaDiT._check_seedvr2_video_latent(torch.zeros(1, _LATENT_CHANNELS * 2, 4, 5), _LATENT_CHANNELS, "latent") def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch): monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) - encoded = torch.full((1, 16, 2, 4, 5), 2.0) + encoded = torch.full((1, _LATENT_CHANNELS, 2, 4, 5), 2.0) vae = _make_vae(_EncodeWrapper(encoded)) pixels = torch.zeros(1, 5, 32, 40, 3) node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0] node_latent = node_output["samples"] assert set(node_output) == {"samples"} - assert tuple(node_latent.shape) == (1, 16, 2, 4, 5) + assert tuple(node_latent.shape) == (1, _LATENT_CHANNELS, 2, 4, 5) assert node_latent.dtype == torch.float32 assert node_latent.stride()[-1] == 1 - assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152)) + assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR)) - tiled = torch.full((1, 16, 2, 4, 5), 3.0) + tiled = torch.full((1, _LATENT_CHANNELS, 2, 4, 5), 3.0) monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled)) tiled_output = nodes_mod.VAEEncodeTiled().encode( vae, @@ -316,9 +286,9 @@ def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeyp )[0] tiled_latent = tiled_output["samples"] assert set(tiled_output) == {"samples"} - assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5) + assert tuple(tiled_latent.shape) == (1, _LATENT_CHANNELS, 2, 4, 5) assert tiled_latent.dtype == torch.float32 - assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152)) + assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR)) def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch): @@ -327,7 +297,7 @@ def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch): nodes_mod.VAEDecodeTiled().decode( vae, - {"samples": torch.zeros(1, 16, 2, 4, 5)}, + {"samples": torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)}, tile_size=512, overlap=64, temporal_size=16, @@ -339,7 +309,7 @@ def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch): # knobs are no-ops at the wrapper. assert vae.first_stage_model.calls == [ { - "shape": (1, 16, 2, 4, 5), + "shape": (1, _LATENT_CHANNELS, 2, 4, 5), "seedvr2_tiling": { "enable_tiling": True, "tile_size": (512, 512), diff --git a/tests-unit/comfy_test/test_seedvr2_vae_decode.py b/tests-unit/comfy_test/test_seedvr2_vae_decode.py index ea9f978f3..c486b9195 100644 --- a/tests-unit/comfy_test/test_seedvr2_vae_decode.py +++ b/tests-unit/comfy_test/test_seedvr2_vae_decode.py @@ -13,6 +13,9 @@ import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 from comfy_extras import nodes_seedvr # noqa: E402 +_LATENT_CHANNELS = vae_mod.SEEDVR2_LATENT_CHANNELS + + def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper: wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( vae_mod.VideoAutoencoderKLWrapper @@ -40,7 +43,7 @@ def _decode_with_patches(wrapper, z): def test_decode_b2_t3_multi_frame_batch_unchanged(): wrapper = _make_wrapper() - out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2)) + out = _decode_with_patches(wrapper, torch.zeros(2, _LATENT_CHANNELS * 3, 2, 2)) assert tuple(out.shape) == (2, 3, 3, 16, 16) @@ -62,17 +65,17 @@ def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preproc wrapper = _Wrapper() with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub): - out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5)) + out = wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)) assert tuple(out.shape) == (1, 3, 2, 32, 40) - assert wrapper.calls == [(1, 16, 2, 4, 5)] + assert wrapper.calls == [(1, _LATENT_CHANNELS, 2, 4, 5)] def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents(): wrapper = _Wrapper() with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"): - wrapper.decode(torch.zeros(1, 16, 4)) + wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 4)) def _t_padded(t_in: int) -> int: diff --git a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py index 0d3c97e4a..33c6d8915 100644 --- a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py +++ b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py @@ -16,9 +16,7 @@ import comfy.sd as sd_mod # noqa: E402 from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402 -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_decode_latent_min_size_override.py -# --------------------------------------------------------------------------- +_LATENT_CHANNELS = seedvr_vae_mod.SEEDVR2_LATENT_CHANNELS def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): @@ -44,7 +42,7 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype) vae = StubVAEModel() - z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32) + z = torch.zeros((1, _LATENT_CHANNELS, 5, 8, 8), dtype=torch.float32) tiled_vae( z, @@ -61,11 +59,6 @@ def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): assert vae.slicing_latent_min_size == 2 -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_encode_runt_slice_override.py -# --------------------------------------------------------------------------- - - def test_zero_temporal_size_preserves_min_size_when_encode_raises(): class RaisingVAEModel(torch.nn.Module): def __init__(self): @@ -110,7 +103,7 @@ def test_tiled_vae_encode_uses_tensor_return_without_indexing(): def encode(self, t_chunk): self.calls.append(tuple(t_chunk.shape)) b, _, _, h, w = t_chunk.shape - return torch.ones((b, 16, 1, h // 8, w // 8), dtype=t_chunk.dtype) + return torch.ones((b, _LATENT_CHANNELS, 1, h // 8, w // 8), dtype=t_chunk.dtype) vae = TensorEncodeVAEModel() x = torch.zeros((2, 3, 1, 64, 64), dtype=torch.float32) @@ -126,12 +119,34 @@ def test_tiled_vae_encode_uses_tensor_return_without_indexing(): ) assert vae.calls == [(2, 3, 1, 64, 64)] - assert tuple(out.shape) == (2, 16, 1, 8, 8) + assert tuple(out.shape) == (2, _LATENT_CHANNELS, 1, 8, 8) -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_temporal_slicing.py -# --------------------------------------------------------------------------- +def test_tiled_vae_preserves_input_dtype_on_single_tile(): + class FloatOutputVAEModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.slicing_sample_min_size = 4 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + + def encode(self, t_chunk): + b, _, _, h, w = t_chunk.shape + return torch.ones((b, _LATENT_CHANNELS, 1, h // 8, w // 8), dtype=torch.float32) + + out = tiled_vae( + torch.zeros((1, 3, 1, 64, 64), dtype=torch.float16), + FloatOutputVAEModel(), + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=True, + ) + + assert out.dtype == torch.float16 class _SlicingDecodeVAE(nn.Module): @@ -164,7 +179,10 @@ class _SlicingDecodeVAE(nn.Module): def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size(): vae = _SlicingDecodeVAE(slicing_latent_min_size=2) - z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8) + z = torch.arange( + _LATENT_CHANNELS * 5 * 8 * 8, + dtype=torch.float32, + ).reshape(1, _LATENT_CHANNELS, 5, 8, 8) tiled_vae( z, @@ -199,16 +217,11 @@ def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size(): return torch.zeros(1, 3, 1, 16, 16) with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae): - wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling) + wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 2, 2), seedvr2_tiling=seedvr2_tiling) assert captured["temporal_overlap"] == 7 -# --------------------------------------------------------------------------- -# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py -# --------------------------------------------------------------------------- - - def _force_oom(*a, **k): raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") @@ -256,10 +269,10 @@ def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode): def test_4d_seedvr2_latent_routes_to_owned_decode_tiled(): wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( seedvr_vae_mod.VideoAutoencoderKLWrapper) - vae = _make_vae(wrapper, latent_channels=16, latent_dim=3) + vae = _make_vae(wrapper, latent_channels=_LATENT_CHANNELS, latent_dim=3) seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) - _dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True) + _dispatch(vae, torch.zeros(1, _LATENT_CHANNELS * 3, 8, 8), seedvr2_call, generic_call, True) assert seedvr2_call.call_count == 1 assert generic_call.call_count == 0 @@ -275,11 +288,6 @@ def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled(): assert seedvr2_call.call_count == 0 -# --------------------------------------------------------------------------- -# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py -# --------------------------------------------------------------------------- - - def _populate_common_vae_attrs_fallback(vae): vae.patcher = MagicMock() vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) @@ -291,7 +299,7 @@ def _populate_common_vae_attrs_fallback(vae): vae.upscale_ratio = 8 vae.upscale_index_formula = None vae.output_channels = 3 - vae.latent_channels = 16 + vae.latent_channels = _LATENT_CHANNELS vae.latent_dim = 3 vae.downscale_ratio = 8 vae.downscale_index_formula = None @@ -334,8 +342,8 @@ def test_seedvr2_3d_routes_to_owned_encode_tiled_on_oom(): vae = _make_seedvr2_vae_fallback() pixel_samples = torch.zeros((1, 8, 64, 64, 3)) - seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) - generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + seedvr2_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8)) + generic_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8)) with patch.object(sd_mod.model_management, "raise_non_oom", lambda e: None), \ @@ -363,7 +371,7 @@ def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete(): vae = _make_non_seedvr2_vae_fallback() vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8) vae.upscale_ratio = (lambda a: a * 4, 8, 8) - generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + generic_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8)) pixel_samples = torch.zeros((1, 8, 64, 64, 3)) with patch.object(sd_mod.model_management, "load_models_gpu",