From 00b633f368e68ffc229084ed819354c29006f92c Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Mon, 8 Jun 2026 15:00:20 -0700 Subject: [PATCH] Revert "Add SeedVR2 support (CORE-6) (#14110)" (#14359) This reverts commit 7863cf0e53ca599a84b3ec5bcda122e4ecc3765c. --- comfy/latent_formats.py | 5 - comfy/ldm/modules/attention.py | 84 +- comfy/ldm/modules/diffusionmodules/model.py | 8 +- comfy/ldm/seedvr/color_fix.py | 340 --- comfy/ldm/seedvr/constants.py | 79 - comfy/ldm/seedvr/model.py | 1665 ------------- comfy/ldm/seedvr/vae.py | 2110 ----------------- comfy/model_base.py | 12 - comfy/model_detection.py | 50 - comfy/sample.py | 8 +- comfy/sd.py | 237 +- comfy/supported_models.py | 31 +- comfy/supported_models_base.py | 2 +- comfy_extras/nodes_seedvr.py | 1015 -------- nodes.py | 42 +- .../test_seedvr2_conditioning.py | 213 -- .../comfy_extras_test/test_seedvr2_nodes.py | 55 - .../test_seedvr2_post_processing.py | 57 - tests-unit/comfy_test/model_detection_test.py | 60 - .../comfy_test/seedvr_vae_forward_test.py | 90 - tests-unit/comfy_test/test_seedvr2_dtype.py | 47 - .../comfy_test/test_seedvr2_internals.py | 341 --- tests-unit/comfy_test/test_seedvr2_model.py | 308 --- .../comfy_test/test_seedvr2_vae_decode.py | 91 - .../comfy_test/test_seedvr2_vae_tiled.py | 347 --- .../test_seedvr_progressive_sampler.py | 126 - 26 files changed, 40 insertions(+), 7383 deletions(-) delete mode 100644 comfy/ldm/seedvr/color_fix.py delete mode 100644 comfy/ldm/seedvr/constants.py delete mode 100644 comfy/ldm/seedvr/model.py delete mode 100644 comfy/ldm/seedvr/vae.py delete mode 100644 comfy_extras/nodes_seedvr.py delete mode 100644 tests-unit/comfy_extras_test/test_seedvr2_conditioning.py delete mode 100644 tests-unit/comfy_extras_test/test_seedvr2_nodes.py delete mode 100644 tests-unit/comfy_extras_test/test_seedvr2_post_processing.py delete mode 100644 tests-unit/comfy_test/seedvr_vae_forward_test.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_dtype.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_internals.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_model.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_vae_decode.py delete mode 100644 tests-unit/comfy_test/test_seedvr2_vae_tiled.py delete mode 100644 tests-unit/comfy_test/test_seedvr_progressive_sampler.py diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index fcbd97c59..bbdfd4bc2 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -4,7 +4,6 @@ class LatentFormat: scale_factor = 1.0 latent_channels = 4 latent_dimensions = 2 - preserve_empty_channel_multiples = False latent_rgb_factors = None latent_rgb_factors_bias = None latent_rgb_factors_reshape = None @@ -780,10 +779,6 @@ class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 -class SeedVR2(LatentFormat): - latent_channels = 16 - preserve_empty_channel_multiples = True - class ACEAudio15(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index b78e764c7..55360535a 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -735,86 +735,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape ) return out -def _var_attention_qkv(q, k, v, heads, skip_reshape): - if skip_reshape: - return q, k, v, q.shape[-1] - total_tokens, embed_dim = q.shape - head_dim = embed_dim // heads - return ( - q.view(total_tokens, heads, head_dim), - k.view(k.shape[0], heads, head_dim), - v.view(v.shape[0], heads, head_dim), - head_dim, - ) - -def _var_attention_output(out, heads, head_dim, skip_output_reshape): - if skip_output_reshape: - return out - return out.reshape(-1, heads * head_dim) - - -def _use_blackwell_attention(): - device = model_management.get_torch_device() - if device.type != "cuda": - return False - major, minor = torch.cuda.get_device_capability(device) - return (major, minor) >= (12, 0) - - -def _validate_split_cu_seqlens(name, cu_seqlens, token_count): - if cu_seqlens.dtype not in (torch.int32, torch.int64): - raise ValueError(f"{name} must use an integer dtype") - if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2: - raise ValueError(f"{name} must be a 1D tensor with at least two offsets") - if cu_seqlens[0].item() != 0: - raise ValueError(f"{name} must start at 0") - if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item(): - raise ValueError(f"{name} must be strictly increasing") - if cu_seqlens[-1].item() != token_count: - raise ValueError(f"{name} does not match token count") - - -def _split_indices(cu_seqlens): - return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long) - - -def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): - q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) - - _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0]) - _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0]) - if cu_seqlens_k[-1].item() != v.shape[0]: - raise ValueError("cu_seqlens_k does not match v token count") - - q_split_indices = _split_indices(cu_seqlens_q) - k_split_indices = _split_indices(cu_seqlens_k) - q_splits = torch.tensor_split(q, q_split_indices, dim=0) - k_splits = torch.tensor_split(k, k_split_indices, dim=0) - v_splits = torch.tensor_split(v, k_split_indices, dim=0) - if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits): - raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count") - - out = [] - for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits): - q_i = q_i.permute(1, 0, 2).unsqueeze(0) - k_i = k_i.permute(1, 0, 2).unsqueeze(0) - v_i = v_i.permute(1, 0, 2).unsqueeze(0) - out_dtype = q_i.dtype - if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16): - q_i = q_i.to(torch.bfloat16) - k_i = k_i.to(torch.bfloat16) - v_i = v_i.to(torch.bfloat16) - out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True) - if out_i.dtype != out_dtype: - out_i = out_i.to(out_dtype) - out.append(out_i.squeeze(0).permute(1, 0, 2)) - - out = torch.cat(out, dim=0) - return _var_attention_output(out, heads, head_dim, skip_output_reshape) - - -optimized_var_attention = var_attention_optimized_split optimized_attention = attention_basic if model_management.sage_attention_enabled(): @@ -837,8 +758,6 @@ else: logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad -logging.info("Using optimized_attention split-loop for variable-length attention") - optimized_attention_masked = optimized_attention @@ -854,7 +773,6 @@ if model_management.xformers_enabled(): register_attention_function("pytorch", attention_pytorch) register_attention_function("sub_quad", attention_sub_quad) register_attention_function("split", attention_split) -register_attention_function("var_attention_optimized_split", var_attention_optimized_split) def optimized_attention_for_device(device, mask=False, small_input=False): @@ -1291,3 +1209,5 @@ class SpatialVideoTransformer(SpatialTransformer): x = self.proj_out(x) out = x + x_in return out + + diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index 235df0b83..fcbaa074f 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -13,7 +13,6 @@ if model_management.xformers_enabled_vae(): import xformers import xformers.ops - def torch_cat_if_needed(xl, dim): xl = [x for x in xl if x is not None and x.shape[dim] > 0] if len(xl) > 1: @@ -23,8 +22,7 @@ def torch_cat_if_needed(xl, dim): else: return None - -def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1): +def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. @@ -35,13 +33,11 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down assert len(timesteps.shape) == 1 half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - downscale_freq_shift) + emb = math.log(10000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) - if flip_sin_to_cos: - emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0,1,0,0)) return emb diff --git a/comfy/ldm/seedvr/color_fix.py b/comfy/ldm/seedvr/color_fix.py deleted file mode 100644 index 7ddfc03af..000000000 --- a/comfy/ldm/seedvr/color_fix.py +++ /dev/null @@ -1,340 +0,0 @@ -import torch -import torch.nn.functional as F -from torch import Tensor - -from comfy.ldm.seedvr.model import safe_pad_operation -from comfy.ldm.seedvr.vae import safe_interpolate_operation -from comfy.ldm.seedvr.constants import ( - CIELAB_DELTA, - CIELAB_KAPPA, - D65_WHITE_X, - D65_WHITE_Z, - WAVELET_DECOMP_LEVELS, -) - - -def wavelet_blur(image: Tensor, radius): - max_safe_radius = max(1, min(image.shape[-2:]) // 8) - if radius > max_safe_radius: - radius = max_safe_radius - - num_channels = image.shape[1] - - kernel_vals = [ - [0.0625, 0.125, 0.0625], - [0.125, 0.25, 0.125], - [0.0625, 0.125, 0.0625], - ] - kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) - kernel = kernel[None, None].repeat(num_channels, 1, 1, 1) - - image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate') - output = F.conv2d(image, kernel, groups=num_channels, dilation=radius) - - return output - -def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS): - high_freq = torch.zeros_like(image) - - for i in range(levels): - radius = 2 ** i - low_freq = wavelet_blur(image, radius) - high_freq.add_(image).sub_(low_freq) - image = low_freq - - return high_freq, low_freq - -def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: - - if content_feat.shape != style_feat.shape: - # Resize style to match content spatial dimensions - if len(content_feat.shape) >= 3: - # safe_interpolate_operation handles FP16 conversion automatically - style_feat = safe_interpolate_operation( - style_feat, - size=content_feat.shape[-2:], - mode='bilinear', - align_corners=False - ) - - # Decompose both features into frequency components - content_high_freq, content_low_freq = wavelet_decomposition(content_feat) - del content_low_freq # Free memory immediately - - style_high_freq, style_low_freq = wavelet_decomposition(style_feat) - del style_high_freq # Free memory immediately - - if content_high_freq.shape != style_low_freq.shape: - style_low_freq = safe_interpolate_operation( - style_low_freq, - size=content_high_freq.shape[-2:], - mode='bilinear', - align_corners=False - ) - - content_high_freq.add_(style_low_freq) - - return content_high_freq.clamp_(-1.0, 1.0) - -def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor: - original_shape = source.shape - - # Flatten - source_flat = source.flatten() - reference_flat = reference.flatten() - - # Sort both arrays - source_sorted, source_indices = torch.sort(source_flat) - reference_sorted, _ = torch.sort(reference_flat) - del reference_flat - - # Quantile mapping - n_source = len(source_sorted) - n_reference = len(reference_sorted) - - if n_source == n_reference: - matched_sorted = reference_sorted - else: - # Interpolate reference to match source quantiles - source_quantiles = torch.linspace(0, 1, n_source, device=device) - ref_indices = (source_quantiles * (n_reference - 1)).long() - ref_indices.clamp_(0, n_reference - 1) - matched_sorted = reference_sorted[ref_indices] - del source_quantiles, ref_indices, reference_sorted - - del source_sorted, source_flat - - # Reconstruct using argsort (portable across CUDA/ROCm/MPS) - inverse_indices = torch.argsort(source_indices) - del source_indices - matched_flat = matched_sorted[inverse_indices] - del matched_sorted, inverse_indices - - return matched_flat.reshape(original_shape) - -def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor: - """Convert batch of CIELAB images to RGB color space.""" - L, a, b = lab[:, 0], lab[:, 1], lab[:, 2] - - # LAB to XYZ - fy = (L + 16.0) / 116.0 - fx = a.div(500.0).add_(fy) - fz = fy - b / 200.0 - del L, a, b - - # XYZ transformation - x = torch.where( - fx > epsilon, - torch.pow(fx, 3.0), - fx.mul(116.0).sub_(16.0).div_(kappa) - ) - y = torch.where( - fy > epsilon, - torch.pow(fy, 3.0), - fy.mul(116.0).sub_(16.0).div_(kappa) - ) - z = torch.where( - fz > epsilon, - torch.pow(fz, 3.0), - fz.mul(116.0).sub_(16.0).div_(kappa) - ) - del fx, fy, fz - - # Apply D65 white point (in-place) - x.mul_(D65_WHITE_X) - # y *= 1.00000 # (no-op, skip) - z.mul_(D65_WHITE_Z) - - xyz = torch.stack([x, y, z], dim=1) - del x, y, z - - # Matrix multiplication: XYZ -> RGB - B, C, H, W = xyz.shape - xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3) - del xyz - - # Ensure dtype consistency for matrix multiplication - xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype) - rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T) - del xyz_flat - - rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) - del rgb_linear_flat - - # Apply inverse gamma correction (delinearize) - mask = rgb_linear > 0.0031308 - rgb = torch.where( - mask, - torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055), - rgb_linear * 12.92 - ) - del mask, rgb_linear - - return torch.clamp(rgb, 0.0, 1.0) - -def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor: - """Convert batch of RGB images to CIELAB color space using D65 illuminant.""" - # Apply sRGB gamma correction (linearize) - mask = rgb > 0.04045 - rgb_linear = torch.where( - mask, - torch.pow((rgb + 0.055) / 1.055, 2.4), - rgb / 12.92 - ) - del mask - - # Matrix multiplication: RGB -> XYZ - B, C, H, W = rgb_linear.shape - rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3) - del rgb_linear - - # Ensure dtype consistency for matrix multiplication - rgb_flat = rgb_flat.to(dtype=matrix.dtype) - xyz_flat = torch.matmul(rgb_flat, matrix.T) - del rgb_flat - - xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) - del xyz_flat - - # Normalize by D65 white point (in-place) - xyz[:, 0].div_(D65_WHITE_X) # X - # xyz[:, 1] /= 1.00000 # Y (no-op, skip) - xyz[:, 2].div_(D65_WHITE_Z) # Z - - # XYZ to LAB transformation - epsilon_cubed = epsilon ** 3 - mask = xyz > epsilon_cubed - f_xyz = torch.where( - mask, - torch.pow(xyz, 1.0 / 3.0), - xyz.mul(kappa).add_(16.0).div_(116.0) - ) - del xyz, mask - - # Extract channels and compute LAB - L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100] - a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127] - b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127] - del f_xyz - - return torch.stack([L, a, b], dim=1) - -def lab_color_transfer( - content_feat: Tensor, - style_feat: Tensor, - luminance_weight: float = 0.8 -) -> Tensor: - content_feat = wavelet_reconstruction(content_feat, style_feat) - - if content_feat.shape != style_feat.shape: - style_feat = safe_interpolate_operation( - style_feat, - size=content_feat.shape[-2:], - mode='bilinear', - align_corners=False - ) - - device = content_feat.device - - def ensure_float32_precision(c): - orig_dtype = c.dtype - c = c.float() - return c, orig_dtype - content_feat, original_dtype = ensure_float32_precision(content_feat) - style_feat, _ = ensure_float32_precision(style_feat) - - rgb_to_xyz_matrix = torch.tensor([ - [0.4124564, 0.3575761, 0.1804375], - [0.2126729, 0.7151522, 0.0721750], - [0.0193339, 0.1191920, 0.9503041] - ], dtype=torch.float32, device=device) - - xyz_to_rgb_matrix = torch.tensor([ - [ 3.2404542, -1.5371385, -0.4985314], - [-0.9692660, 1.8760108, 0.0415560], - [ 0.0556434, -0.2040259, 1.0572252] - ], dtype=torch.float32, device=device) - - epsilon = CIELAB_DELTA - kappa = CIELAB_KAPPA - - content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) - style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) - - # Convert to LAB color space - content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa) - del content_feat - - style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa) - del style_feat, rgb_to_xyz_matrix - - # Match chrominance channels (a*, b*) for accurate color transfer - matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device) - matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device) - - # Handle luminance with weighted blending - if luminance_weight < 1.0: - # Partially match luminance for better overall color accuracy - matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device) - # Blend: preserve some content L* for detail, adopt some style L* for color - result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight)) - del matched_L - else: - # Fully preserve content luminance - result_L = content_lab[:, 0] - - del content_lab, style_lab - - # Reconstruct LAB with corrected channels - result_lab = torch.stack([result_L, matched_a, matched_b], dim=1) - del result_L, matched_a, matched_b - - # Convert back to RGB - result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa) - del result_lab, xyz_to_rgb_matrix - - # Convert back to [-1, 1] range (in-place) - result = result_rgb.mul_(2.0).sub_(1.0) - del result_rgb - - result = result.to(original_dtype) - - return result - - -def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor: - return wavelet_reconstruction(content_feat, style_feat) - - -def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor: - if content_feat.shape != style_feat.shape: - style_feat = safe_interpolate_operation( - style_feat, - size=content_feat.shape[-2:], - mode='bilinear', - align_corners=False, - ) - - original_dtype = content_feat.dtype - content_feat = content_feat.float() - style_feat = style_feat.float() - - b, c = content_feat.shape[:2] - content_flat = content_feat.reshape(b, c, -1) - style_flat = style_feat.reshape(b, c, -1) - - content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1) - content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) - style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1) - style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) - del content_flat, style_flat - - normalized = (content_feat - content_mean) / content_std - del content_mean, content_std - result = normalized * style_std + style_mean - del normalized, style_mean, style_std - - result = result.clamp_(-1.0, 1.0) - if result.dtype != original_dtype: - result = result.to(original_dtype) - return result diff --git a/comfy/ldm/seedvr/constants.py b/comfy/ldm/seedvr/constants.py deleted file mode 100644 index 95838d1dd..000000000 --- a/comfy/ldm/seedvr/constants.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Named constants for the SeedVR2 integration, grouped by provenance. - -Provenance prefixes: -- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline. -- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites - the upstream config/source path it was lifted from. -- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature / - ISO / CIE values; cite the standard. -""" - -# -------------------------------------------------------------------------------------- -# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment) -# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN) -# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070 -# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT). -# -------------------------------------------------------------------------------------- -SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB) -SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB - -# -------------------------------------------------------------------------------------- -# B. Fork heuristics (SEEDVR2 - this integration) -# -------------------------------------------------------------------------------------- -SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim. - # (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.) -SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry. -SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case). -SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM. -SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk. -SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels). -SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16). -SEEDVR2_DEFAULT_TEMPORAL_SIZE = 16 # default VAE temporal tile when unset. - -# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing) -SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk. -SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path. -SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path. -SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path. - -# -------------------------------------------------------------------------------------- -# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR) -# -------------------------------------------------------------------------------------- -BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm. -BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift. -BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem). -BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem). -BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28. -BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28. -BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16). -BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32). -BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD = 64 # attn_video_vae.py:308 (force .contiguous() above this b*t). -BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11. -BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size). -BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor. -BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor). -BYTEDANCE_SCHEDULE_T = 1000.0 # configs_3b/main.yaml:65 (schedule.T); timestep range. -BYTEDANCE_SPATIAL_DIVISOR = 16 # inference_seedvr2_3b.py:241 (DivisibleCrop((16,16))). -BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling). -BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames). -BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency). -BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim). -# Resolution-dependent timestep-shift linear fits: (x1, y1, x2, y2) for get_lin_function. -BYTEDANCE_IMG_SHIFT_FIT = (256 * 256, 1.0, 1024 * 1024, 3.2) # infer.py:242. -BYTEDANCE_VID_SHIFT_FIT = (256 * 256 * 37, 1.0, 1280 * 720 * 145, 5.0) # infer.py:243. - -# -------------------------------------------------------------------------------------- -# D. Published standards (cite the literature) -# -------------------------------------------------------------------------------------- -ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864. - -# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65). -CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta). -CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa). -D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1). -D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn. -WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR). - -# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and -# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the -# exact existing coefficients move verbatim rather than being retyped here. diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py deleted file mode 100644 index 3fa9fe07e..000000000 --- a/comfy/ldm/seedvr/model.py +++ /dev/null @@ -1,1665 +0,0 @@ -from dataclasses import dataclass -from typing import Optional, Tuple, Union, List, Dict, Any, Callable -import einops -from einops import rearrange -import torch.nn.functional as F -from math import ceil, pi -import torch -from itertools import chain -from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding -from comfy.ldm.modules.attention import optimized_var_attention -from torch.nn.modules.utils import _triple -from torch import nn -import math -from comfy.ldm.flux.math import apply_rope1 -from comfy.ldm.seedvr.constants import ( - BYTEDANCE_720P_REF_AREA, - BYTEDANCE_MAX_TEMPORAL_WINDOW, - BYTEDANCE_ROPE_MAX_FREQ, - BYTEDANCE_SINUSOIDAL_DIM, - ROPE_THETA, - SEEDVR2_7B_MLP_CHUNK, - SEEDVR2_7B_VID_DIM, - SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, -) -import comfy.model_management -import numbers - -def _torch_float8_types(): - return tuple( - getattr(torch, name) - for name in ( - "float8_e4m3fn", - "float8_e4m3fnuz", - "float8_e5m2", - "float8_e5m2fnuz", - "float8_e8m0fnu", - ) - if hasattr(torch, name) - ) - -class CustomRMSNorm(nn.Module): - - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None): - super(CustomRMSNorm, self).__init__() - - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape,) - self.normalized_shape = torch.Size(normalized_shape) - self.eps = eps - self.elementwise_affine = elementwise_affine - - if self.elementwise_affine: - self.weight = nn.Parameter(torch.ones(*normalized_shape, device=device, dtype=dtype)) - else: - self.register_parameter('weight', None) - - def forward(self, input): - - dims = tuple(range(-len(self.normalized_shape), 0)) - - normalized = input.float() - variance = normalized.pow(2).mean(dim=dims, keepdim=True) - rms = torch.sqrt(variance + self.eps) - - normalized = normalized / rms - - if self.elementwise_affine: - return normalized * self.weight.to(input.dtype) - return normalized - -class Cache: - def __init__(self, disable=False, prefix="", cache=None): - self.cache = cache if cache is not None else {} - self.disable = disable - self.prefix = prefix - - def __call__(self, key: str, fn: Callable): - if self.disable: - return fn() - - key = self.prefix + key - try: - result = self.cache[key] - except KeyError: - result = fn() - self.cache[key] = result - return result - - def namespace(self, namespace: str): - return Cache( - disable=self.disable, - prefix=self.prefix + namespace + ".", - cache=self.cache, - ) - - def get(self, key: str): - key = self.prefix + key - return self.cache[key] - -def repeat_concat( - vid: torch.FloatTensor, # (VL ... c) - txt: torch.FloatTensor, # (TL ... c) - vid_len: torch.LongTensor, # (n*b) - txt_len: torch.LongTensor, # (b) - txt_repeat: List, # (n) -) -> torch.FloatTensor: # (L ... c) - vid = torch.split(vid, vid_len.tolist()) - txt = torch.split(txt, txt_len.tolist()) - txt = [[x] * n for x, n in zip(txt, txt_repeat)] - txt = list(chain(*txt)) - return torch.cat(list(chain(*zip(vid, txt)))) - -def concat( - vid: torch.FloatTensor, # (VL ... c) - txt: torch.FloatTensor, # (TL ... c) - vid_len: torch.LongTensor, # (b) - txt_len: torch.LongTensor, # (b) -) -> torch.FloatTensor: # (L ... c) - vid = torch.split(vid, vid_len.tolist()) - txt = torch.split(txt, txt_len.tolist()) - return torch.cat(list(chain(*zip(vid, txt)))) - -def concat_idx( - vid_len: torch.LongTensor, # (b) - txt_len: torch.LongTensor, # (b) -) -> Tuple[ - Callable, - Callable, -]: - device = vid_len.device - vid_idx = torch.arange(vid_len.sum(), device=device) - txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) - tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len) - src_idx = torch.argsort(tgt_idx) - return ( - lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx), - lambda all: torch.index_select(all, 0, src_idx).split([len(vid_idx), len(txt_idx)]), - ) - - -def repeat_concat_idx( - vid_len: torch.LongTensor, # (n*b) - txt_len: torch.LongTensor, # (b) - txt_repeat: torch.LongTensor, # (n) -) -> Tuple[ - Callable, - Callable, -]: - device = vid_len.device - vid_idx = torch.arange(vid_len.sum(), device=device) - txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device) - txt_repeat_list = txt_repeat.tolist() - tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat) - src_idx = torch.argsort(tgt_idx) - txt_idx_len = len(tgt_idx) - len(vid_idx) - repeat_txt_len = (txt_len * txt_repeat).tolist() - - def unconcat_coalesce(all): - vid_out, txt_out = all[src_idx].split([len(vid_idx), txt_idx_len]) - txt_out_coalesced = [] - for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list): - txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1) - txt_out_coalesced.append(txt) - return vid_out, torch.cat(txt_out_coalesced) - - return ( - lambda vid, txt: torch.cat([vid, txt])[tgt_idx], - lambda all: unconcat_coalesce(all), - ) - - -@dataclass -class MMArg: - vid: Any - txt: Any - -def safe_pad_operation(x, padding, mode='constant', value=0.0): - """Safe padding operation that handles Half precision only for problematic modes""" - # Modes qui nécessitent le fix Half precision - problematic_modes = ['replicate', 'reflect', 'circular'] - - if mode in problematic_modes: - try: - return F.pad(x, padding, mode=mode, value=value) - except RuntimeError as e: - if "not implemented for 'Half'" in str(e): - original_dtype = x.dtype - return F.pad(x.float(), padding, mode=mode, value=value).to(original_dtype) - else: - raise e - else: - # Pour 'constant' et autres modes compatibles, pas de fix nécessaire - return F.pad(x, padding, mode=mode, value=value) - - -def get_args(key: str, args: List[Any]) -> List[Any]: - return [getattr(v, key) if isinstance(v, MMArg) else v for v in args] - - -def get_kwargs(key: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: - return {k: getattr(v, key) if isinstance(v, MMArg) else v for k, v in kwargs.items()} - - -def get_window_op(name: str): - if name == "720pwin_by_size_bysize": - return make_720Pwindows_bysize - if name == "720pswin_by_size_bysize": - return make_shifted_720Pwindows_bysize - raise ValueError(f"Unknown windowing method: {name}") - - -# -------------------------------- Windowing -------------------------------- # -def make_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): - t, h, w = size - resized_nt, resized_nh, resized_nw = num_windows - #cal windows under 720p - scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) - resized_h, resized_w = round(h * scale), round(w * scale) - wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. - nt, nh, nw = ceil(t / wt), ceil(h / wh), ceil(w / ww) # window size. - return [ - ( - slice(it * wt, min((it + 1) * wt, t)), - slice(ih * wh, min((ih + 1) * wh, h)), - slice(iw * ww, min((iw + 1) * ww, w)), - ) - for iw in range(nw) - if min((iw + 1) * ww, w) > iw * ww - for ih in range(nh) - if min((ih + 1) * wh, h) > ih * wh - for it in range(nt) - if min((it + 1) * wt, t) > it * wt - ] - -def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]): - t, h, w = size - resized_nt, resized_nh, resized_nw = num_windows - #cal windows under 720p - scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w)) - resized_h, resized_w = round(h * scale), round(w * scale) - wh, ww = ceil(resized_h / resized_nh), ceil(resized_w / resized_nw) # window size. - wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt) # window size. - - st, sh, sw = ( # shift size. - 0.5 if wt < t else 0, - 0.5 if wh < h else 0, - 0.5 if ww < w else 0, - ) - nt, nh, nw = ceil((t - st) / wt), ceil((h - sh) / wh), ceil((w - sw) / ww) # window size. - nt, nh, nw = ( # number of window. - nt + 1 if st > 0 else 1, - nh + 1 if sh > 0 else 1, - nw + 1 if sw > 0 else 1, - ) - return [ - ( - slice(max(int((it - st) * wt), 0), min(int((it - st + 1) * wt), t)), - slice(max(int((ih - sh) * wh), 0), min(int((ih - sh + 1) * wh), h)), - slice(max(int((iw - sw) * ww), 0), min(int((iw - sw + 1) * ww), w)), - ) - for iw in range(nw) - if min(int((iw - sw + 1) * ww), w) > max(int((iw - sw) * ww), 0) - for ih in range(nh) - if min(int((ih - sh + 1) * wh), h) > max(int((ih - sh) * wh), 0) - for it in range(nt) - if min(int((it - st + 1) * wt), t) > max(int((it - st) * wt), 0) - ] - -class RotaryEmbedding(nn.Module): - def __init__( - self, - dim, - custom_freqs = None, - freqs_for = 'lang', - theta = 10000, - max_freq = 10, - num_freqs = 1, - learned_freq = False, - use_xpos = False, - xpos_scale_base = 512, - interpolate_factor = 1., - theta_rescale_factor = 1., - seq_before_head_dim = False, - cache_if_possible = True, - cache_max_seq_len = 8192 - ): - super().__init__() - - theta *= theta_rescale_factor ** (dim / (dim - 2)) - - self.freqs_for = freqs_for - - if exists(custom_freqs): - freqs = custom_freqs - elif freqs_for == 'lang': - freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) - elif freqs_for == 'pixel': - freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi - elif freqs_for == 'constant': - freqs = torch.ones(num_freqs).float() - - self.cache_if_possible = cache_if_possible - self.cache_max_seq_len = cache_max_seq_len - - self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False) - self.cached_freqs_seq_len = 0 - - self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) - - self.learned_freq = learned_freq - - # dummy for device - - self.register_buffer('dummy', torch.tensor(0), persistent = False) - - # default sequence dimension - - self.seq_before_head_dim = seq_before_head_dim - self.default_seq_dim = -3 if seq_before_head_dim else -2 - - # interpolation factors - - assert interpolate_factor >= 1. - self.interpolate_factor = interpolate_factor - - # xpos - - self.use_xpos = use_xpos - - if not use_xpos: - return - - scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) - self.scale_base = xpos_scale_base - - self.register_buffer('scale', scale, persistent = False) - self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False) - self.cached_scales_seq_len = 0 - - # add apply_rotary_emb as static method - - self.apply_rotary_emb = staticmethod(apply_rotary_emb) - - @property - def device(self): - return self.dummy.device - - def get_axial_freqs( - self, - *dims, - offsets = None - ): - Colon = slice(None) - all_freqs = [] - - # handle offset - - if exists(offsets): - assert len(offsets) == len(dims) - - for ind, dim in enumerate(dims): - - offset = 0 - if exists(offsets): - offset = offsets[ind] - - if self.freqs_for == 'pixel': - pos = torch.linspace(-1, 1, steps = dim, device = self.device) - else: - pos = torch.arange(dim, device = self.device) - - pos = pos + offset - - freqs = self.forward(pos, seq_len = dim) - - all_axis = [None] * len(dims) - all_axis[ind] = Colon - - new_axis_slice = (Ellipsis, *all_axis, Colon) - all_freqs.append(freqs[new_axis_slice]) - - # concat all freqs - - all_freqs = torch.broadcast_tensors(*all_freqs) - return torch.cat(all_freqs, dim = -1) - - def forward( - self, - t, - seq_len: int | None = None, - offset = 0 - ): - should_cache = ( - self.cache_if_possible and - not self.learned_freq and - exists(seq_len) and - self.freqs_for != 'pixel' and - (offset + seq_len) <= self.cache_max_seq_len - ) - - if ( - should_cache and \ - exists(self.cached_freqs) and \ - (offset + seq_len) <= self.cached_freqs_seq_len - ): - return self.cached_freqs[offset:(offset + seq_len)].detach() - - freqs = self.freqs - - freqs = torch.einsum('..., f -> ... f', t.type(freqs.dtype), freqs) - freqs = einops.repeat(freqs, '... n -> ... (n r)', r = 2) - - if should_cache and offset == 0: - self.cached_freqs[:seq_len] = freqs.detach() - self.cached_freqs_seq_len = seq_len - - return freqs - -class RotaryEmbeddingBase(nn.Module): - def __init__(self, dim: int, rope_dim: int): - super().__init__() - self.rope = RotaryEmbedding( - dim=dim // rope_dim, - freqs_for="pixel", - max_freq=BYTEDANCE_ROPE_MAX_FREQ, - ) - freqs = self.rope.freqs - del self.rope.freqs - self.rope.register_buffer("freqs", freqs.data) - - def get_axial_freqs(self, *dims): - return self.rope.get_axial_freqs(*dims) - - -class RotaryEmbedding3d(RotaryEmbeddingBase): - def __init__(self, dim: int): - super().__init__(dim, rope_dim=3) - self.mm = False - - def forward( - self, - q: torch.FloatTensor, # b h l d - k: torch.FloatTensor, # b h l d - size: Tuple[int, int, int], - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - T, H, W = size - freqs = self.get_axial_freqs(T, H, W) - q = rearrange(q, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) - k = rearrange(k, "b h (T H W) d -> b h T H W d", T=T, H=H, W=W) - q = apply_rotary_emb(freqs, q.float()).to(q.dtype) - k = apply_rotary_emb(freqs, k.float()).to(k.dtype) - q = rearrange(q, "b h T H W d -> b h (T H W) d") - k = rearrange(k, "b h T H W d -> b h (T H W) d") - return q, k - - -class NaRotaryEmbedding3d(RotaryEmbedding3d): - def forward( - self, - q: torch.FloatTensor, - k: torch.FloatTensor, - shape: torch.LongTensor, - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - freqs = cache("rope_freqs_3d", lambda: self.get_freqs(shape)) - freqs = freqs.to(device=q.device) - q = rearrange(q, "L h d -> h L d") - k = rearrange(k, "L h d -> h L d") - q = _apply_seedvr2_rotary_emb(freqs, q.float()).to(q.dtype) - k = _apply_seedvr2_rotary_emb(freqs, k.float()).to(k.dtype) - q = rearrange(q, "h L d -> L h d") - k = rearrange(k, "h L d -> L h d") - return q, k - - @torch._dynamo.disable - def get_freqs( - self, - shape: torch.LongTensor, - ) -> torch.Tensor: - # Primary provenance: ByteDance-Seed/SeedVR models/dit/rope.py builds - # 7B pixel RoPE with the interleaved-angle convention, not Comfy's - # Flux freqs_cis matrix. - plain_rope = RotaryEmbedding( - dim=self.rope.freqs.numel() * 2, - freqs_for="pixel", - max_freq=BYTEDANCE_ROPE_MAX_FREQ, - ) - plain_rope = plain_rope.to(self.rope.dummy.device) - freq_list = [] - for f, h, w in shape.tolist(): - freqs = plain_rope.get_axial_freqs(f, h, w) - freq_list.append(freqs.view(-1, freqs.size(-1))) - return torch.cat(freq_list, dim=0) - - -class MMRotaryEmbeddingBase(RotaryEmbeddingBase): - def __init__(self, dim: int, rope_dim: int): - super().__init__(dim, rope_dim) - self.rope = RotaryEmbedding( - dim=dim // rope_dim, - freqs_for="lang", - theta=ROPE_THETA, - cache_if_possible=False, - ) - freqs = self.rope.freqs - del self.rope.freqs - self.rope.register_buffer("freqs", freqs.data) - self.mm = True - -def slice_at_dim(t, dim_slice: slice, *, dim): - dim += (t.ndim if dim < 0 else 0) - colons = [slice(None)] * t.ndim - colons[dim] = dim_slice - return t[tuple(colons)] - -# rotary embedding helper functions - -def rotate_half(x): - x = rearrange(x, '... (d r) -> ... d r', r = 2) - x1, x2 = x.unbind(dim = -1) - x = torch.stack((-x2, x1), dim = -1) - return rearrange(x, '... d r -> ... (d r)') -def exists(val): - return val is not None - -def apply_rotary_emb( - freqs, - t, - start_index = 0, - scale = 1., - seq_dim = -2, - freqs_seq_dim = None -): - dtype = t.dtype - if not exists(freqs_seq_dim): - if freqs.ndim == 2 or t.ndim == 3: - freqs_seq_dim = 0 - - if t.ndim == 3 or exists(freqs_seq_dim): - seq_len = t.shape[seq_dim] - freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim) - - rot_feats = freqs.shape[-1] - end_index = start_index + rot_feats - - t_left = t[..., :start_index] - t_middle = t[..., start_index:end_index] - t_right = t[..., end_index:] - - angles = freqs.to(t_middle.device)[..., ::2] - cos = torch.cos(angles) * scale - sin = torch.sin(angles) * scale - - col0 = torch.stack([cos, sin], dim=-1) - col1 = torch.stack([-sin, cos], dim=-1) - freqs_mat = torch.stack([col0, col1], dim=-1) - - t_middle_out = apply_rope1(t_middle, freqs_mat) - out = torch.cat((t_left, t_middle_out, t_right), dim=-1) - return out.type(dtype) - - -def _apply_seedvr2_rotary_emb( - freqs: torch.Tensor, - t: torch.Tensor, - start_index: int = 0, - scale: float = 1.0, - seq_dim: int = -2, - freqs_seq_dim: int | None = None, -) -> torch.Tensor: - dtype = t.dtype - if freqs_seq_dim is None and (freqs.ndim == 2 or t.ndim == 3): - freqs_seq_dim = 0 - - if t.ndim == 3 or freqs_seq_dim is not None: - seq_len = t.shape[seq_dim] - freqs = slice_at_dim(freqs, slice(-seq_len, None), dim=freqs_seq_dim) - - rot_feats = freqs.shape[-1] - end_index = start_index + rot_feats - - t_left = t[..., :start_index] - t_middle = t[..., start_index:end_index] - t_right = t[..., end_index:] - - freqs = freqs.to(device=t_middle.device, dtype=t_middle.dtype) - cos = freqs.cos() * scale - sin = freqs.sin() * scale - t_middle = (t_middle * cos) + (rotate_half(t_middle) * sin) - return torch.cat((t_left, t_middle, t_right), dim=-1).to(dtype) - -def _to_flux_freqs_cis(freqs_interleaved: torch.Tensor) -> torch.Tensor: - """Convert lucidrains-interleaved freqs to flux-canonical fp32 freqs_cis `[..., d/2, 2, 2]` (cos/-sin/sin/cos), per `comfy/ldm/flux/math.py:rope`.""" - angles = freqs_interleaved[..., ::2].float() - cos = torch.cos(angles) - sin = torch.sin(angles) - out = torch.stack([cos, -sin, sin, cos], dim=-1) - return rearrange(out, "... d (i j) -> ... d i j", i=2, j=2) - - -def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - """Rotate the leading ``rot_d = 2 * freqs_cis.shape[-3]`` dims of ``t`` and pass the rest - through; in-place for inference, cloned for training (autograd). Mirrors the legacy - ``apply_rotary_emb`` ``t_left``/``t_middle``/``t_right`` split: 3B ``rope_dim=128`` gives - ``42*3 = 126`` rotated of head_dim 128 (trailing 2 unrotated). Fast path skips the cat when - ``rot_d == t.shape[-1]``. - """ - out = t.clone() if t.requires_grad or comfy.model_management.in_training else t - rot_d = 2 * freqs_cis.shape[-3] - seq_len = out.shape[-2] - for start in range(0, seq_len, SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS): - end = min(start + SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, seq_len) - freqs_chunk = freqs_cis[start:end] - if rot_d == out.shape[-1]: - out[..., start:end, :] = apply_rope1(out[..., start:end, :], freqs_chunk).to(out.dtype) - else: - out[..., start:end, :rot_d] = apply_rope1(out[..., start:end, :rot_d], freqs_chunk).to(out.dtype) - return out - - -class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase): - def __init__(self, dim: int): - super().__init__(dim, rope_dim=3) - - def forward( - self, - vid_q: torch.FloatTensor, # L h d - vid_k: torch.FloatTensor, # L h d - vid_shape: torch.LongTensor, # B 3 - txt_q: torch.FloatTensor, # L h d - txt_k: torch.FloatTensor, # L h d - txt_shape: torch.LongTensor, # B 1 - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - torch.FloatTensor, - torch.FloatTensor, - ]: - vid_freqs, txt_freqs = cache( - "mmrope_freqs_3d", - lambda: self.get_freqs(vid_shape, txt_shape), - ) - target_device = vid_q.device - if vid_freqs.device != target_device: - vid_freqs = vid_freqs.to(target_device) - if txt_freqs.device != target_device: - txt_freqs = txt_freqs.to(target_device) - vid_q = rearrange(vid_q, "L h d -> h L d") - vid_k = rearrange(vid_k, "L h d -> h L d") - vid_q = _apply_rope1_partial(vid_q, vid_freqs) - vid_k = _apply_rope1_partial(vid_k, vid_freqs) - vid_q = rearrange(vid_q, "h L d -> L h d") - vid_k = rearrange(vid_k, "h L d -> L h d") - - txt_q = rearrange(txt_q, "L h d -> h L d") - txt_k = rearrange(txt_k, "L h d -> h L d") - txt_q = _apply_rope1_partial(txt_q, txt_freqs) - txt_k = _apply_rope1_partial(txt_k, txt_freqs) - txt_q = rearrange(txt_q, "h L d -> L h d") - txt_k = rearrange(txt_k, "h L d -> L h d") - return vid_q, vid_k, txt_q, txt_k - - @torch._dynamo.disable # Disable compilation: .tolist() is data-dependent and causes graph breaks - def get_freqs( - self, - vid_shape: torch.LongTensor, - txt_shape: torch.LongTensor, - ) -> Tuple[ - torch.Tensor, - torch.Tensor, - ]: - - # Calculate actual max dimensions needed for this batch - max_temporal = 0 - max_height = 0 - max_width = 0 - max_txt_len = 0 - - for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal - max_height = max(max_height, h) - max_width = max(max_width, w) - max_txt_len = max(max_txt_len, l) - - autocast_device = "cuda" if torch.cuda.is_available() else "cpu" - with torch.amp.autocast(autocast_device, enabled=False): - vid_freqs = self.get_axial_freqs( - max_temporal + 16, - max_height + 4, - max_width + 4, - ).float() - txt_freqs = self.get_axial_freqs(max_txt_len + 16) - - # Now slice as before - vid_freq_list, txt_freq_list = [], [] - for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) - txt_freq = txt_freqs[:l].repeat(1, 3).reshape(-1, vid_freqs.size(-1)) - vid_freq_list.append(vid_freq) - txt_freq_list.append(txt_freq) - vid_freqs_interleaved = torch.cat(vid_freq_list, dim=0) - txt_freqs_interleaved = torch.cat(txt_freq_list, dim=0) - - # Convert from lucidrains-interleaved layout `[θ0, θ0, θ1, θ1, ...]` - # (produced by `repeat(freqs, '... n -> ... (n r)', r=2)` in the - # upstream `RotaryEmbedding.forward`) to flux-canonical `freqs_cis` - # in shape `[..., d/2, 2, 2]` with `cos/-sin/sin/cos` baked in. - # Mirrors `comfy/ldm/flux/math.py:rope` (line 27) so the trailing - # 2x2 is the per-frequency rotation matrix that - # `comfy.ldm.flux.math.apply_rope1` expects. - return _to_flux_freqs_cis(vid_freqs_interleaved), _to_flux_freqs_cis(txt_freqs_interleaved) - -class MMModule(nn.Module): - def __init__( - self, - module: Callable[..., nn.Module], - *args, - shared_weights: bool = False, - vid_only: bool = False, - **kwargs, - ): - super().__init__() - self.shared_weights = shared_weights - self.vid_only = vid_only - if self.shared_weights: - assert get_args("vid", args) == get_args("txt", args) - assert get_kwargs("vid", kwargs) == get_kwargs("txt", kwargs) - self.all = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) - else: - self.vid = module(*get_args("vid", args), **get_kwargs("vid", kwargs)) - self.txt = ( - module(*get_args("txt", args), **get_kwargs("txt", kwargs)) - if not vid_only - else None - ) - - def forward( - self, - vid: torch.FloatTensor, - txt: torch.FloatTensor, - *args, - **kwargs, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - vid_module = self.vid if not self.shared_weights else self.all - vid = vid_module(vid, *get_args("vid", args), **get_kwargs("vid", kwargs)) - if not self.vid_only: - txt_module = self.txt if not self.shared_weights else self.all - txt = txt.to(device=vid.device, dtype=vid.dtype) - txt = txt_module(txt, *get_args("txt", args), **get_kwargs("txt", kwargs)) - return vid, txt - -def get_na_rope(rope_type: Optional[str], dim: int): - if rope_type is None: - return None - if rope_type == "rope3d": - return NaRotaryEmbedding3d(dim=dim) - if rope_type == "mmrope3d": - return NaMMRotaryEmbedding3d(dim=dim) - -class NaMMAttention(nn.Module): - def __init__( - self, - vid_dim: int, - txt_dim: int, - heads: int, - head_dim: int, - qk_bias: bool, - qk_norm, - qk_norm_eps: float, - rope_type: Optional[str], - rope_dim: int, - shared_weights: bool, - device, dtype, operations, - **kwargs, - ): - super().__init__() - dim = MMArg(vid_dim, txt_dim) - self.heads = heads - inner_dim = heads * head_dim - qkv_dim = inner_dim * 3 - self.head_dim = head_dim - self.proj_qkv = MMModule( - operations.Linear, dim, qkv_dim, bias=qk_bias, shared_weights=shared_weights, device=device, dtype=dtype - ) - self.proj_out = MMModule(operations.Linear, inner_dim, dim, shared_weights=shared_weights, device=device, dtype=dtype) - self.norm_q = MMModule( - qk_norm, - normalized_shape=head_dim, - eps=qk_norm_eps, - elementwise_affine=True, - shared_weights=shared_weights, - device=device, dtype=dtype - ) - self.norm_k = MMModule( - qk_norm, - normalized_shape=head_dim, - eps=qk_norm_eps, - elementwise_affine=True, - shared_weights=shared_weights, - device=device, dtype=dtype - ) - - - self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim) - - def forward(self): - pass - -def window( - hid: torch.FloatTensor, # (L c) - hid_shape: torch.LongTensor, # (b n) - window_fn: Callable[[torch.Tensor], List[torch.Tensor]], -): - hid = unflatten(hid, hid_shape) - hid = list(map(window_fn, hid)) - hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device) - hid, hid_shape = flatten(list(chain(*hid))) - return hid, hid_shape, hid_windows - -def window_idx( - hid_shape: torch.LongTensor, # (b n) - window_fn: Callable[[torch.Tensor], List[torch.Tensor]], -): - hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1) - tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn) - tgt_idx = tgt_idx.squeeze(-1) - src_idx = torch.argsort(tgt_idx) - return ( - lambda hid: torch.index_select(hid, 0, tgt_idx), - lambda hid: torch.index_select(hid, 0, src_idx), - tgt_shape, - tgt_windows, - ) - -class NaSwinAttention(NaMMAttention): - def __init__( - self, - *args, - window: Union[int, Tuple[int, int, int]], - window_method: bool, # shifted or not - **kwargs, - ): - super().__init__(*args, **kwargs) - self.version_7b = kwargs.get("version", False) - self.window = _triple(window) - self.window_method = window_method - assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window)) - - self.window_op = get_window_op(window_method) - - def forward( - self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - - vid_qkv, txt_qkv = self.proj_qkv(vid, txt) - - # re-org the input seq for window attn - cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3") - - def make_window(x: torch.Tensor): - t, h, w, _ = x.shape - window_slices = self.window_op((t, h, w), self.window) - return [x[st, sh, sw] for (st, sh, sw) in window_slices] - - window_partition, window_reverse, window_shape, window_count = cache_win( - "win_transform", - lambda: window_idx(vid_shape, make_window), - ) - vid_qkv_win = window_partition(vid_qkv) - - vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim) - txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) - - vid_q, vid_k, vid_v = vid_qkv_win.unbind(1) - txt_q, txt_k, txt_v = txt_qkv.unbind(1) - - vid_q, txt_q = self.norm_q(vid_q, txt_q) - vid_k, txt_k = self.norm_k(vid_k, txt_k) - - txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) - - vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) - txt_len = txt_len.to(window_count.device) - - # window rope - if self.rope: - if self.version_7b: - vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) - elif self.rope.mm: - # repeat text q and k for window mmrope - _, num_h, _ = txt_q.shape - txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)") - txt_q_repeat = unflatten(txt_q_repeat, txt_shape) - txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)] - txt_q_repeat = list(chain(*txt_q_repeat)) - txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat) - txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h) - - txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)") - txt_k_repeat = unflatten(txt_k_repeat, txt_shape) - txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)] - txt_k_repeat = list(chain(*txt_k_repeat)) - txt_k_repeat, _ = flatten(txt_k_repeat) - txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h) - - vid_q, vid_k, txt_q, txt_k = self.rope( - vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win - ) - else: - vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) - - txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) - all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) - concat_win, unconcat_win = cache_win( - "mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count) - ) - out = optimized_var_attention( - q=concat_win(vid_q, txt_q), - k=concat_win(vid_k, txt_k), - v=concat_win(vid_v, txt_v), - heads=self.heads, skip_reshape=True, skip_output_reshape=True, - cu_seqlens_q=cache_win( - "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() - ), - cu_seqlens_k=cache_win( - "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() - ), - ) - vid_out, txt_out = unconcat_win(out) - - vid_out = rearrange(vid_out, "l h d -> l (h d)") - txt_out = rearrange(txt_out, "l h d -> l (h d)") - vid_out = window_reverse(vid_out) - - vid_out, txt_out = self.proj_out(vid_out, txt_out) - - return vid_out, txt_out - -class MLP(nn.Module): - def __init__( - self, - dim: int, - expand_ratio: int, - device, dtype, operations - ): - super().__init__() - self.proj_in = operations.Linear(dim, dim * expand_ratio, device=device, dtype=dtype) - self.act = nn.GELU("tanh") - self.proj_out = operations.Linear(dim * expand_ratio, dim, device=device, dtype=dtype) - - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: - x = self.proj_in(x) - x = self.act(x) - x = self.proj_out(x) - return x - - -class SwiGLUMLP(nn.Module): - def __init__( - self, - dim: int, - expand_ratio: int, - multiple_of: int = 256, - device=None, dtype=None, operations=None - ): - super().__init__() - hidden_dim = int(2 * dim * expand_ratio / 3) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - self.proj_in_gate = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) - self.proj_out = operations.Linear(hidden_dim, dim, bias=False, device=device, dtype=dtype) - self.proj_in = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype) - - def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: - return self.proj_out(F.silu(self.proj_in_gate(x)) * self.proj_in(x)) - -def get_mlp(mlp_type: Optional[str] = "normal"): - # 3b and 7b uses different mlp types - if mlp_type == "normal": - return MLP - elif mlp_type == "swiglu": - return SwiGLUMLP - -class NaMMSRTransformerBlock(nn.Module): - def __init__( - self, - *, - vid_dim: int, - txt_dim: int, - emb_dim: int, - heads: int, - head_dim: int, - expand_ratio: int, - norm, - norm_eps: float, - ada, - qk_bias: bool, - qk_norm, - mlp_type: str, - shared_weights: bool, - rope_type: str, - rope_dim: int, - is_last_layer: bool, - device, dtype, operations, - **kwargs, - ): - super().__init__() - version = kwargs.get("version", False) - dim = MMArg(vid_dim, txt_dim) - self.attn_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype) - - self.attn = NaSwinAttention( - vid_dim=vid_dim, - txt_dim=txt_dim, - heads=heads, - head_dim=head_dim, - qk_bias=qk_bias, - qk_norm=qk_norm, - qk_norm_eps=norm_eps, - rope_type=rope_type, - rope_dim=rope_dim, - shared_weights=shared_weights, - window=kwargs.pop("window", None), - window_method=kwargs.pop("window_method", None), - version=version, - device=device, dtype=dtype, operations=operations - ) - - self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) - self.mlp = MMModule( - get_mlp(mlp_type), - dim=dim, - expand_ratio=expand_ratio, - shared_weights=shared_weights, - vid_only=is_last_layer, - device=device, dtype=dtype, operations=operations - ) - self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) - self.is_last_layer = is_last_layer - self.version = version - - def _seedvr2_7b_mlp( - self, - vid: torch.FloatTensor, - txt: torch.FloatTensor, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - ]: - vid_module = self.mlp.vid if not self.mlp.shared_weights else self.mlp.all - if comfy.model_management.in_training or vid.requires_grad: - vid = torch.cat([vid_module(chunk) for chunk in vid.split(SEEDVR2_7B_MLP_CHUNK, dim=0)], dim=0) - else: - vid_out = None - offset = 0 - for chunk in vid.split(SEEDVR2_7B_MLP_CHUNK, dim=0): - chunk_out = vid_module(chunk) - if vid_out is None: - vid_out = chunk_out.new_empty((vid.shape[0], *chunk_out.shape[1:])) - vid_out[offset:offset + chunk_out.shape[0]] = chunk_out - offset += chunk_out.shape[0] - vid = vid_out - if not self.mlp.vid_only: - txt_module = self.mlp.txt if not self.mlp.shared_weights else self.mlp.all - txt = txt.to(device=vid.device, dtype=vid.dtype) - txt = txt_module(txt) - return vid, txt - - def forward( - self, - vid: torch.FloatTensor, # l c - txt: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, # b 3 - txt_shape: torch.LongTensor, # b 1 - emb: torch.FloatTensor, - cache: Cache, - ) -> Tuple[ - torch.FloatTensor, - torch.FloatTensor, - torch.LongTensor, - torch.LongTensor, - ]: - hid_len = MMArg( - cache("vid_len", lambda: vid_shape.prod(-1)), - cache("txt_len", lambda: txt_shape.prod(-1)), - ) - ada_kwargs = { - "emb": emb, - "hid_len": hid_len, - "cache": cache, - "branch_tag": MMArg("vid", "txt"), - } - - vid_attn, txt_attn = self.attn_norm(vid, txt) - vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs) - vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache) - vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs) - vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) - - vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) - vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs) - if self.version: - vid_mlp, txt_mlp = self._seedvr2_7b_mlp(vid_mlp, txt_mlp) - else: - vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) - vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs) - vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) - - return vid_mlp, txt_mlp, vid_shape, txt_shape - -class PatchOut(nn.Module): - def __init__( - self, - out_channels: int, - patch_size: Union[int, Tuple[int, int, int]], - dim: int, - device, dtype, operations - ): - super().__init__() - t, h, w = _triple(patch_size) - self.patch_size = t, h, w - self.proj = operations.Linear(dim, out_channels * t * h * w, device=device, dtype=dtype) - - def forward( - self, - vid: torch.Tensor, - ) -> torch.Tensor: - t, h, w = self.patch_size - vid = self.proj(vid) - vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) - if t > 1: - vid = vid[:, :, (t - 1) :] - return vid - -class NaPatchOut(PatchOut): - def forward( - self, - vid: torch.FloatTensor, # l c - vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), # for test - vid_shape_before_patchify = None - ) -> Tuple[ - torch.FloatTensor, - torch.LongTensor, - ]: - - t, h, w = self.patch_size - vid = self.proj(vid) - - if not (t == h == w == 1): - vid = unflatten(vid, vid_shape) - for i in range(len(vid)): - vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w) - if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: - vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :] - vid, vid_shape = flatten(vid) - - return vid, vid_shape - -class PatchIn(nn.Module): - def __init__( - self, - in_channels: int, - patch_size: Union[int, Tuple[int, int, int]], - dim: int, - device, dtype, operations - ): - super().__init__() - t, h, w = _triple(patch_size) - self.patch_size = t, h, w - self.proj = operations.Linear(in_channels * t * h * w, dim, device=device, dtype=dtype) - - def forward( - self, - vid: torch.Tensor, - ) -> torch.Tensor: - t, h, w = self.patch_size - if t > 1: - assert vid.size(2) % t == 1 - vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) - vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) - vid = self.proj(vid) - return vid - -class NaPatchIn(PatchIn): - def forward( - self, - vid: torch.Tensor, # l c - vid_shape: torch.LongTensor, - cache: Cache = Cache(disable=True), # for test - ) -> torch.Tensor: - cache = cache.namespace("patch") - vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) - t, h, w = self.patch_size - if not (t == h == w == 1): - vid = unflatten(vid, vid_shape) - for i in range(len(vid)): - if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: - vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0) - vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w) - vid, vid_shape = flatten(vid) - - vid = self.proj(vid) - return vid, vid_shape - -def expand_dims(x: torch.Tensor, dim: int, ndim: int): - shape = x.shape - shape = shape[:dim] + (1,) * (ndim - len(shape)) + shape[dim:] - return x.reshape(shape) - - -class AdaSingle(nn.Module): - def __init__( - self, - dim: int, - emb_dim: int, - layers: List[str], - modes: List[str] = ["in", "out"], - device = None, dtype = None, - ): - assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim" - super().__init__() - self.dim = dim - self.emb_dim = emb_dim - self.layers = layers - - randn_kwargs = {"device": device} - fp8_types = _torch_float8_types() - if dtype is not None and dtype not in fp8_types: - randn_kwargs["dtype"] = dtype - - for l in layers: - if "in" in modes: - # Passing fp8 ``dtype=`` here would break CPU weight - # loads: CPU has no ``normal_kernel_cpu`` for fp8. - self.register_parameter(f"{l}_shift", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5)) - self.register_parameter( - f"{l}_scale", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5 + 1) - ) - if "out" in modes: - self.register_parameter(f"{l}_gate", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5)) - - def forward( - self, - hid: torch.FloatTensor, # b ... c - emb: torch.FloatTensor, # b d - layer: str, - mode: str, - cache: Cache = Cache(disable=True), - branch_tag: str = "", - hid_len: Optional[torch.LongTensor] = None, # b - ) -> torch.FloatTensor: - idx = self.layers.index(layer) - emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :] - emb = expand_dims(emb, 1, hid.ndim + 1) - - if hid_len is not None: - slice_inputs = lambda x, dim: x - emb = cache( - f"emb_repeat_{idx}_{branch_tag}", - lambda: slice_inputs( - torch.repeat_interleave(emb, hid_len, dim=0), - dim=0, - ), - ) - - shiftA, scaleA, gateA = emb.unbind(-1) - shiftB, scaleB, gateB = ( - getattr(self, f"{layer}_shift", None), - getattr(self, f"{layer}_scale", None), - getattr(self, f"{layer}_gate", None), - ) - - fp8_types = _torch_float8_types() - if fp8_types: - target_dtype = hid.dtype - - if shiftB is not None and shiftB.dtype in fp8_types: - shiftB = shiftB.to(target_dtype) - if scaleB is not None and scaleB.dtype in fp8_types: - scaleB = scaleB.to(target_dtype) - if gateB is not None and gateB.dtype in fp8_types: - gateB = gateB.to(target_dtype) - - if mode == "in": - return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB) - if mode == "out": - if gateB is not None: - return hid.mul_(gateA + gateB) - else: - return hid.mul_(gateA) - - raise NotImplementedError - - -def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]): - return emb1 if emb2 is None else emb1 + emb2 - - -class TimeEmbedding(nn.Module): - def __init__( - self, - sinusoidal_dim: int, - hidden_dim: int, - output_dim: int, - device, dtype, operations - ): - super().__init__() - self.sinusoidal_dim = sinusoidal_dim - self.proj_in = operations.Linear(sinusoidal_dim, hidden_dim, device=device, dtype=dtype) - self.proj_hid = operations.Linear(hidden_dim, hidden_dim, device=device, dtype=dtype) - self.proj_out = operations.Linear(hidden_dim, output_dim, device=device, dtype=dtype) - self.act = nn.SiLU() - - def forward( - self, - timestep: Union[int, float, torch.IntTensor, torch.FloatTensor], - device: torch.device, - dtype: torch.dtype, - ) -> torch.FloatTensor: - if not torch.is_tensor(timestep): - timestep = torch.tensor([timestep], device=device, dtype=dtype) - if timestep.ndim == 0: - timestep = timestep[None] - - emb = get_timestep_embedding( - timesteps=timestep, - embedding_dim=self.sinusoidal_dim, - flip_sin_to_cos=False, - downscale_freq_shift=0, - ).to(dtype) - emb = self.proj_in(emb) - emb = self.act(emb) - emb = self.proj_hid(emb) - emb = self.act(emb) - emb = self.proj_out(emb) - return emb - -def flatten( - hid: List[torch.FloatTensor], # List of (*** c) -) -> Tuple[ - torch.FloatTensor, # (L c) - torch.LongTensor, # (b n) -]: - assert len(hid) > 0 - shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid]) - hid = torch.cat([x.flatten(0, -2) for x in hid]) - return hid, shape - - -def unflatten( - hid: torch.FloatTensor, # (L c) or (L ... c) - hid_shape: torch.LongTensor, # (b n) -) -> List[torch.Tensor]: # List of (*** c) or (*** ... c) - hid_len = hid_shape.prod(-1) - hid = hid.split(hid_len.tolist()) - hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)] - return hid - -def repeat( - hid: torch.FloatTensor, # (L c) - hid_shape: torch.LongTensor, # (b n) - pattern: str, - **kwargs: Dict[str, torch.LongTensor], # (b) -) -> Tuple[ - torch.FloatTensor, - torch.LongTensor, -]: - hid = unflatten(hid, hid_shape) - kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))] - return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)]) - -class NaDiT(nn.Module): - - def __init__( - self, - norm_eps, - qk_rope, - num_layers, - mlp_type, - vid_in_channels = 33, - vid_out_channels = 16, - vid_dim = 2560, - txt_in_dim = 5120, - heads = 20, - head_dim = 128, - mm_layers = 10, - expand_ratio = 4, - qk_bias = False, - patch_size = [ 1,2,2 ], - shared_qkv: bool = False, - shared_mlp: bool = False, - window_method: Optional[Tuple[str]] = None, - temporal_window_size: int = None, - temporal_shifted: bool = False, - rope_dim = 128, - rope_type = "mmrope3d", - vid_out_norm: Optional[str] = None, - device = None, - dtype = None, - operations = None, - **kwargs, - ): - self._7b_version = vid_dim == SEEDVR2_7B_VID_DIM - if self._7b_version: - rope_type = "rope3d" - self.dtype = dtype - factory_kwargs = {"device": device, "dtype": dtype} - window_method = num_layers // 2 * ["720pwin_by_size_bysize","720pswin_by_size_bysize"] - txt_dim = vid_dim - emb_dim = vid_dim * 6 - block_type = ["mmdit_sr"] * num_layers - window = num_layers * [(4,3,3)] - ada = AdaSingle - norm = CustomRMSNorm - qk_norm = CustomRMSNorm - if isinstance(block_type, str): - block_type = [block_type] * num_layers - elif len(block_type) != num_layers: - raise ValueError("The ``block_type`` list should equal to ``num_layers``.") - super().__init__() - # ``torch.empty`` returns uninitialized memory, not zeros. The - # SeedVR2Conditioning fail-loud guard at - # ``comfy_extras/nodes_seedvr.py`` distinguishes "buffer was loaded" - # from "buffer was never populated by the file" by checking - # ``positive_conditioning.abs().sum() == 0``. That sentinel is only - # reliable if the post-construction buffer state is deterministically - # zero, so explicitly zero-fill here rather than relying on the - # allocator's zero-on-alloc behavior (allocator-dependent and not - # contractual). When ``load_state_dict`` populates these buffers - # from a properly-baked SeedVR2 .safetensors, the in-place copy - # overwrites the zeros with the universal SeedVR2 conditioning - # tensors (shape (58, 5120) and (64, 5120) bf16). - self.register_buffer("positive_conditioning", torch.zeros((58, 5120), device=device, dtype=dtype)) - self.register_buffer("negative_conditioning", torch.zeros((64, 5120), device=device, dtype=dtype)) - self.vid_in = NaPatchIn( - in_channels=vid_in_channels, - patch_size=patch_size, - dim=vid_dim, - device=device, dtype=dtype, operations=operations - ) - self.txt_in = ( - operations.Linear(txt_in_dim, txt_dim, **factory_kwargs) - if txt_in_dim and txt_in_dim != txt_dim - else nn.Identity() - ) - self.emb_in = TimeEmbedding( - sinusoidal_dim=BYTEDANCE_SINUSOIDAL_DIM, - hidden_dim=max(vid_dim, txt_dim), - output_dim=emb_dim, - device=device, dtype=dtype, operations=operations - ) - - if window is None or isinstance(window[0], int): - window = [window] * num_layers - if window_method is None or isinstance(window_method, str): - window_method = [window_method] * num_layers - if temporal_window_size is None or isinstance(temporal_window_size, int): - temporal_window_size = [temporal_window_size] * num_layers - if temporal_shifted is None or isinstance(temporal_shifted, bool): - temporal_shifted = [temporal_shifted] * num_layers - - rope_dim = rope_dim if rope_dim is not None else head_dim // 2 - self.blocks = nn.ModuleList( - [ - NaMMSRTransformerBlock( - vid_dim=vid_dim, - txt_dim=txt_dim, - emb_dim=emb_dim, - heads=heads, - head_dim=head_dim, - expand_ratio=expand_ratio, - norm=norm, - norm_eps=norm_eps, - ada=ada, - qk_bias=qk_bias, - qk_rope=qk_rope, - qk_norm=qk_norm, - shared_qkv=shared_qkv, - shared_mlp=shared_mlp, - mlp_type=mlp_type, - rope_dim = rope_dim, - window=window[i], - window_method=window_method[i], - temporal_window_size=temporal_window_size[i], - temporal_shifted=temporal_shifted[i], - is_last_layer=(i == num_layers - 1) and not self._7b_version, - rope_type = rope_type, - shared_weights=not ( - (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i] - ), - version = self._7b_version, - operations = operations, - **kwargs, - **factory_kwargs - ) - for i in range(num_layers) - ] - ) - self.vid_out = NaPatchOut( - out_channels=vid_out_channels, - patch_size=patch_size, - dim=vid_dim, - device=device, dtype=dtype, operations=operations - ) - - self.need_txt_repeat = block_type[0] in [ - "mmdit_stwin", - "mmdit_stwin_spatial", - "mmdit_stwin_3d_spatial", - ] - - self.vid_out_norm = None - if vid_out_norm is not None: - self.vid_out_norm = CustomRMSNorm( - normalized_shape=vid_dim, - eps=norm_eps, - elementwise_affine=True, - device=device, dtype=dtype - ) - self.vid_out_ada = ada( - dim=vid_dim, - emb_dim=emb_dim, - layers=["out"], - modes=["in"], - device=device, dtype=dtype - ) - - def _resolve_text_conditioning(self, context, cond_or_uncond=None): - if context is None or getattr(context, "numel", lambda: None)() == 0: - context = self.positive_conditioning - return flatten([context]) - if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): - if context.shape[0] == 1: - context = context.squeeze(0) - return flatten([context]) - return flatten(context.unbind(0)) - if context.shape[0] % 2 != 0: - raise ValueError(f"SeedVR2 expected an even text-conditioning batch, got shape {tuple(context.shape)}") - neg_cond, pos_cond = context.chunk(2, dim=0) - if pos_cond.shape[0] == 1: - pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0) - return flatten([pos_cond, neg_cond]) - return flatten((*pos_cond.unbind(0), *neg_cond.unbind(0))) - - @staticmethod - def _seedvr2_is_single_conditioning_branch(cond_or_uncond): - if cond_or_uncond is None or len(cond_or_uncond) == 0: - return False - first = cond_or_uncond[0] - return all(entry == first for entry in cond_or_uncond) - - def _swap_pos_neg_halves(self, out, cond_or_uncond=None): - if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): - return out - # ``dim=0`` is explicit on both calls. The contract is "split - # the batch axis into two halves and swap them"; making the - # axis load-bearing in source guards against silent drift if a - # future refactor reorders tensor axes. - pos, neg = out.chunk(2, dim=0) - return torch.cat([neg, pos], dim=0) - - def forward( - self, - x, - timestep, - context, # l c - disable_cache: bool = False, # for test # TODO ? // gives an error when set to True - **kwargs - ): - transformer_options = kwargs.get("transformer_options", {}) - patches_replace = transformer_options.get("patches_replace", {}) - blocks_replace = patches_replace.get("dit", {}) - conditions = kwargs.get("condition") - b, tc, h, w = x.shape - x = x.view(b, 16, -1, h, w) - conditions = conditions.view(b, 17, -1, h, w) - x = x.movedim(1, -1) - conditions = conditions.movedim(1, -1) - cache = Cache(disable=disable_cache) - - txt, txt_shape = self._resolve_text_conditioning(context, transformer_options.get("cond_or_uncond")) - - vid, vid_shape = flatten(x) - cond_latent, _ = flatten(conditions) - - vid = torch.cat([vid, cond_latent], dim=-1) - if txt_shape.size(-1) == 1 and self.need_txt_repeat: - txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) - - txt = self.txt_in(txt) - - vid_shape_before_patchify = vid_shape - vid, vid_shape = self.vid_in(vid, vid_shape, cache=cache) - - emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) - - for i, block in enumerate(self.blocks): - if ("block", i) in blocks_replace: - def block_wrap(args): - out = {} - out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] = block( - vid=args["vid"], - txt=args["txt"], - vid_shape=args["vid_shape"], - txt_shape=args["txt_shape"], - emb=args["emb"], - cache=args["cache"], - ) - return out - out = blocks_replace[("block", i)]({ - "vid":vid, - "txt":txt, - "vid_shape":vid_shape, - "txt_shape":txt_shape, - "emb":emb, - "cache":cache, - }, {"original_block": block_wrap}) - vid, txt, vid_shape, txt_shape = out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] - else: - vid, txt, vid_shape, txt_shape = block( - vid=vid, - txt=txt, - vid_shape=vid_shape, - txt_shape=txt_shape, - emb=emb, - cache=cache, - ) - - if self.vid_out_norm: - vid = self.vid_out_norm(vid) - vid = self.vid_out_ada( - vid, - emb=emb, - layer="out", - mode="in", - hid_len=cache("vid_len", lambda: vid_shape.prod(-1)), - cache=cache, - branch_tag="vid", - ) - - vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify) - vid = unflatten(vid, vid_shape) - out = torch.stack(vid) - out = out.movedim(-1, 1) - out = rearrange(out, "b c t h w -> b (c t) h w") - return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond")) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py deleted file mode 100644 index 68b11c0ff..000000000 --- a/comfy/ldm/seedvr/vae.py +++ /dev/null @@ -1,2110 +0,0 @@ -from contextlib import nullcontext -from typing import Literal, Optional, Tuple -import gc -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from torch import Tensor -from contextlib import contextmanager -from comfy.utils import ProgressBar - -from comfy.ldm.seedvr.model import safe_pad_operation -from comfy.ldm.seedvr.constants import ( - BYTEDANCE_BLOCK_OUT_CHANNELS, - BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD, - BYTEDANCE_GN_CHUNKS_FP16, - BYTEDANCE_GN_CHUNKS_FP32, - BYTEDANCE_LOGVAR_CLAMP_MAX, - BYTEDANCE_LOGVAR_CLAMP_MIN, - BYTEDANCE_SLICING_SAMPLE_MIN, - BYTEDANCE_VAE_CONV_MEM_GIB, - BYTEDANCE_VAE_NORM_MEM_GIB, - BYTEDANCE_VAE_SCALING_FACTOR, - BYTEDANCE_VAE_SHIFTING_FACTOR, - BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE, - BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE, - SEEDVR2_LATENT_CHANNELS, -) -from comfy.ldm.modules.attention import optimized_attention -from comfy.ldm.modules.diffusionmodules.model import vae_attention - -import math -from enum import Enum -from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND - -import logging -import comfy.model_management -import comfy.ops -ops = comfy.ops.disable_weight_init - - -def _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap, temporal_scale=1): - if temporal_size is None: - return None - - temporal_size = int(temporal_size) - if temporal_size <= 0: - return 0 - - temporal_overlap = max(0, int(temporal_overlap or 0)) - temporal_overlap = min(temporal_overlap, temporal_size - 1) - temporal_step = temporal_size - temporal_overlap - temporal_scale = max(1, int(temporal_scale)) - return max(1, math.ceil(temporal_step / temporal_scale)) - - -def _seedvr2_clamped_spatial_overlap(overlap, tile_size): - overlap = max(0, int(overlap)) - tile_size = max(1, int(tile_size)) - return min(overlap, tile_size - 1) - - -def _seedvr2_clear_temporal_memory(model): - for module in model.modules(): - if hasattr(module, "memory"): - module.memory = None - - -@torch.inference_mode() -def tiled_vae( - x, - vae_model, - tile_size=(512, 512), - tile_overlap=(64, 64), - temporal_size=16, - temporal_overlap=0, - encode=True, - **kwargs, -): - gc.collect() - comfy.model_management.soft_empty_cache() - - x = x.to(next(vae_model.parameters()).dtype) - if x.ndim != 5: - x = x.unsqueeze(2) - - _, _, d, h, w = x.shape - - sf_s = getattr(vae_model, "spatial_downsample_factor", BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE) - sf_t = getattr(vae_model, "temporal_downsample_factor", BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE) - if encode: - slicing_attr = "slicing_sample_min_size" - slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap) - else: - slicing_attr = "slicing_latent_min_size" - slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap, sf_t) - if encode: - ti_h, ti_w = tile_size - ov_h = _seedvr2_clamped_spatial_overlap(tile_overlap[0], ti_h) - ov_w = _seedvr2_clamped_spatial_overlap(tile_overlap[1], ti_w) - blend_ov_h = max(0, ov_h // sf_s) - blend_ov_w = max(0, ov_w // sf_s) - target_d = (d + sf_t - 1) // sf_t - target_h = (h + sf_s - 1) // sf_s - target_w = (w + sf_s - 1) // sf_s - else: - ti_h = max(1, tile_size[0] // sf_s) - ti_w = max(1, tile_size[1] // sf_s) - ov_h = _seedvr2_clamped_spatial_overlap(tile_overlap[0] // sf_s, ti_h) - ov_w = _seedvr2_clamped_spatial_overlap(tile_overlap[1] // sf_s, ti_w) - blend_ov_h = ov_h * sf_s - blend_ov_w = ov_w * sf_s - - target_d = max(1, d * sf_t - (sf_t - 1)) - target_h = h * sf_s - target_w = w * sf_s - - stride_h = max(1, ti_h - ov_h) - stride_w = max(1, ti_w - ov_w) - - storage_device = vae_model.device - result = None - count = None - def run_temporal_chunks(spatial_tile, model=vae_model, device=storage_device): - device = torch.device(device) - _seedvr2_clear_temporal_memory(model) - t_chunk = spatial_tile.to(device=device, dtype=next(model.parameters()).dtype, non_blocking=True).contiguous() - old_device = getattr(model, "device", None) - model.device = device - old_slicing_min_size = getattr(model, slicing_attr, None) - if old_slicing_min_size is not None and slicing_min_size is not None: - if slicing_min_size <= 0: - setattr(model, slicing_attr, t_chunk.shape[2]) - else: - setattr(model, slicing_attr, slicing_min_size) - try: - if encode: - out = model.encode(t_chunk)[0] - else: - out = model.decode_(t_chunk) - finally: - if old_slicing_min_size is not None and slicing_min_size is not None: - setattr(model, slicing_attr, old_slicing_min_size) - if old_device is not None: - model.device = old_device - if isinstance(out, (tuple, list)): - out = out[0] - if out.ndim == 4: - out = out.unsqueeze(2) - return out.to(storage_device) - - ramp_cache = {} - def get_ramp(steps): - if steps not in ramp_cache: - t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32) - ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi) - return ramp_cache[steps] - - tile_ranges = [] - for y_idx in range(0, h, stride_h): - y_end = min(y_idx + ti_h, h) - if y_idx > 0 and (y_end - y_idx) <= ov_h: - continue - for x_idx in range(0, w, stride_w): - x_end = min(x_idx + ti_w, w) - if x_idx > 0 and (x_end - x_idx) <= ov_w: - continue - tile_ranges.append((y_idx, y_end, x_idx, x_end)) - - total_tiles = len(tile_ranges) - bar = ProgressBar(total_tiles) - single_spatial_tile = h <= ti_h and w <= ti_w - - _seedvr2_clear_temporal_memory(vae_model) - - def run_tile(tile_index, tile_range): - y_idx, y_end, x_idx, x_end = tile_range - tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end] - tile_out = run_temporal_chunks(tile_x) - return tile_index, y_idx, y_end, x_idx, x_end, tile_out - - ordered_tile_outputs = ( - run_tile(tile_index, tile_range) - for tile_index, tile_range in enumerate(tile_ranges) - ) - - for _, y_idx, y_end, x_idx, x_end, tile_out in ordered_tile_outputs: - - if single_spatial_tile: - result = tile_out[:, :, :target_d, :target_h, :target_w] - if result.device != x.device: - result = result.to(x.device).to(x.dtype) - if x.shape[2] == 1 and sf_t == 1: - result = result.squeeze(2) - bar.update(1) - return result - - if result is None: - b_out, c_out = tile_out.shape[0], tile_out.shape[1] - result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32) - count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32) - - if encode: - ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3] - xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4] - cur_ov_h = max(0, min(blend_ov_h, tile_out.shape[3] // 2)) - cur_ov_w = max(0, min(blend_ov_w, tile_out.shape[4] // 2)) - else: - ys, ye = y_idx * sf_s, (y_idx * sf_s) + tile_out.shape[3] - xs, xe = x_idx * sf_s, (x_idx * sf_s) + tile_out.shape[4] - cur_ov_h = max(0, min(blend_ov_h, tile_out.shape[3] // 2)) - cur_ov_w = max(0, min(blend_ov_w, tile_out.shape[4] // 2)) - - w_h = torch.ones((tile_out.shape[3],), device=storage_device) - w_w = torch.ones((tile_out.shape[4],), device=storage_device) - - if cur_ov_h > 0: - r = get_ramp(cur_ov_h) - if y_idx > 0: - w_h[:cur_ov_h] = r - if y_end < h: - w_h[-cur_ov_h:] = 1.0 - r - - if cur_ov_w > 0: - r = get_ramp(cur_ov_w) - if x_idx > 0: - w_w[:cur_ov_w] = r - if x_end < w: - w_w[-cur_ov_w:] = 1.0 - r - - final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1) - - valid_d = min(tile_out.shape[2], result.shape[2]) - tile_out = tile_out[:, :, :valid_d, :, :] - - tile_out.mul_(final_weight) - - result[:, :, :valid_d, ys:ye, xs:xe] += tile_out - count[:, :, :, ys:ye, xs:xe] += final_weight - - del tile_out, final_weight, w_h, w_w - bar.update(1) - - result.div_(count.clamp(min=1e-6)) - _seedvr2_clear_temporal_memory(vae_model) - - if result.device != x.device: - result = result.to(x.device).to(x.dtype) - - if x.shape[2] == 1 and sf_t == 1: - result = result.squeeze(2) - - return result - -_NORM_LIMIT = float("inf") -def get_norm_limit(): - return _NORM_LIMIT - - -def set_norm_limit(value: Optional[float] = None): - global _NORM_LIMIT - if value is None: - value = float("inf") - _NORM_LIMIT = value - -@contextmanager -def ignore_padding(model): - orig_padding = model.padding - model.padding = (0, 0, 0) - try: - yield - finally: - model.padding = orig_padding - -class MemoryState(Enum): - DISABLED = 0 - INITIALIZING = 1 - ACTIVE = 2 - UNSET = 3 - -def get_cache_size(conv_module, input_len, pad_len, dim=0): - dilated_kernerl_size = conv_module.dilation[dim] * (conv_module.kernel_size[dim] - 1) + 1 - output_len = (input_len + pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 - remain_len = ( - input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size) - ) - overlap_len = dilated_kernerl_size - conv_module.stride[dim] - cache_len = overlap_len + remain_len # >= 0 - - assert output_len > 0 - return cache_len - -class DiagonalGaussianDistribution(object): - def __init__(self, parameters: torch.Tensor, deterministic: bool = False): - self.parameters = parameters - self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) - self.logvar = torch.clamp(self.logvar, BYTEDANCE_LOGVAR_CLAMP_MIN, BYTEDANCE_LOGVAR_CLAMP_MAX) - self.deterministic = deterministic - self.std = torch.exp(0.5 * self.logvar) - self.var = torch.exp(self.logvar) - if self.deterministic: - self.var = self.std = torch.zeros_like( - self.mean, device=self.parameters.device, dtype=self.parameters.dtype - ) - - def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: - sample = torch.randn( - self.mean.shape, - generator=generator, - device=self.parameters.device, - dtype=self.parameters.dtype, - ) - x = self.mean + self.std * sample - return x - - def mode(self): - return self.mean - -class SpatialNorm(nn.Module): - def __init__( - self, - f_channels: int, - zq_channels: int, - ): - super().__init__() - self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) - self.conv_y = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - self.conv_b = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: - f_size = f.shape[-2:] - zq = F.interpolate(zq, size=f_size, mode="nearest") - norm_f = self.norm_layer(f) - new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) - return new_f - -# partial implementation of diffusers's Attention for comfyui -class Attention(nn.Module): - def __init__( - self, - query_dim: int, - cross_attention_dim: Optional[int] = None, - heads: int = 8, - kv_heads: Optional[int] = None, - dim_head: int = 64, - dropout: float = 0.0, - bias: bool = False, - upcast_attention: bool = False, - upcast_softmax: bool = False, - added_kv_proj_dim: Optional[int] = None, - added_proj_bias: Optional[bool] = True, - norm_num_groups: Optional[int] = None, - spatial_norm_dim: Optional[int] = None, - out_bias: bool = True, - scale_qk: bool = True, - only_cross_attention: bool = False, - eps: float = 1e-5, - rescale_output_factor: float = 1.0, - residual_connection: bool = False, - _from_deprecated_attn_block: bool = False, - out_dim: int = None, - out_context_dim: int = None, - context_pre_only=None, - pre_only=False, - is_causal: bool = False, - ): - super().__init__() - - self.inner_dim = out_dim if out_dim is not None else dim_head * heads - self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads - self.query_dim = query_dim - self.use_bias = bias - self.is_cross_attention = cross_attention_dim is not None - self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim - self.upcast_attention = upcast_attention - self.upcast_softmax = upcast_softmax - self.rescale_output_factor = rescale_output_factor - self.residual_connection = residual_connection - self.dropout = dropout - self.fused_projections = False - self.out_dim = out_dim if out_dim is not None else query_dim - self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim - self.context_pre_only = context_pre_only - self.pre_only = pre_only - self.is_causal = is_causal - - # we make use of this private variable to know whether this class is loaded - # with an deprecated state dict so that we can convert it on the fly - self._from_deprecated_attn_block = _from_deprecated_attn_block - - self.scale_qk = scale_qk - self.scale = dim_head**-0.5 if self.scale_qk else 1.0 - - self.heads = out_dim // dim_head if out_dim is not None else heads - self.sliceable_head_dim = heads - - self.added_kv_proj_dim = added_kv_proj_dim - self.only_cross_attention = only_cross_attention - - if norm_num_groups is not None: - self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) - else: - self.group_norm = None - - if spatial_norm_dim is not None: - self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) - else: - self.spatial_norm = None - - self.norm_q = None - self.norm_k = None - - self.norm_cross = None - self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias) - - if not self.only_cross_attention: - # only relevant for the `AddedKVProcessor` classes - self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) - else: - self.to_k = None - self.to_v = None - - self.added_proj_bias = added_proj_bias - if self.added_kv_proj_dim is not None: - self.add_k_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) - self.add_v_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) - if self.context_pre_only is not None: - self.add_q_proj = ops.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) - else: - self.add_q_proj = None - self.add_k_proj = None - self.add_v_proj = None - - if not self.pre_only: - self.to_out = nn.ModuleList([]) - self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias)) - self.to_out.append(nn.Dropout(dropout)) - else: - self.to_out = None - - if self.context_pre_only is not None and not self.context_pre_only: - self.to_add_out = ops.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) - else: - self.to_add_out = None - - self.norm_added_q = None - self.norm_added_k = None - self.optimized_vae_attention = vae_attention() - - def __call__( - self, - hidden_states: torch.Tensor, - encoder_hidden_states: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - temb: Optional[torch.Tensor] = None, - *args, - **kwargs, - ) -> torch.Tensor: - - residual = hidden_states - if self.spatial_norm is not None: - hidden_states = self.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) - attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1]) - - if self.group_norm is not None: - hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = self.to_q(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif self.norm_cross: - encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) - - key = self.to_k(encoder_hidden_states) - value = self.to_v(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // self.heads - - query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) - - if self.norm_q is not None: - query = self.norm_q(query) - if self.norm_k is not None: - key = self.norm_k(key) - - if input_ndim == 4 and encoder_hidden_states is hidden_states and attention_mask is None and self.heads == 1: - query = query.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) - key = key.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) - value = value.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) - hidden_states = self.optimized_vae_attention(query, key, value).reshape(batch_size, self.heads, head_dim, height * width).transpose(2, 3) - else: - hidden_states = optimized_attention(query, key, value, heads = self.heads, mask = attention_mask, skip_reshape=True, skip_output_reshape=True) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # linear proj - hidden_states = self.to_out[0](hidden_states) - # dropout - hidden_states = self.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if self.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / self.rescale_output_factor - - return hidden_states - - -def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor): - with torch.no_grad(): - depth = weight_3d.size(2) - weight_3d.copy_(weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1) / depth) - return weight_3d - -def inflate_bias(bias_2d: torch.Tensor, bias_3d: torch.Tensor): - with torch.no_grad(): - bias_3d.copy_(bias_2d) - return bias_3d - - -def modify_state_dict(layer, state_dict, prefix, inflate_weight_fn, inflate_bias_fn): - weight_name = prefix + "weight" - bias_name = prefix + "bias" - if weight_name in state_dict: - weight_2d = state_dict[weight_name] - if weight_2d.dim() == 4: - weight_3d = inflate_weight_fn( - weight_2d=weight_2d, - weight_3d=layer.weight, - ) - state_dict[weight_name] = weight_3d - else: - return state_dict - if bias_name in state_dict: - bias_2d = state_dict[bias_name] - if bias_2d.dim() == 1: - bias_3d = inflate_bias_fn( - bias_2d=bias_2d, - bias_3d=layer.bias, - ) - state_dict[bias_name] = bias_3d - return state_dict - -def causal_norm_wrapper(norm_layer: nn.Module, x: torch.Tensor) -> torch.Tensor: - input_dtype = x.dtype - if isinstance(norm_layer, (ops.LayerNorm, ops.RMSNorm)): - if x.ndim == 4: - x = rearrange(x, "b c h w -> b h w c") - x = norm_layer(x) - x = rearrange(x, "b h w c -> b c h w") - return x.to(input_dtype) - if x.ndim == 5: - x = rearrange(x, "b c t h w -> b t h w c") - x = norm_layer(x) - x = rearrange(x, "b t h w c -> b c t h w") - return x.to(input_dtype) - if isinstance(norm_layer, (ops.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): - if x.ndim <= 4: - return norm_layer(x).to(input_dtype) - if x.ndim == 5: - t = x.size(2) - x = rearrange(x, "b c t h w -> (b t) c h w") - memory_occupy = x.numel() * x.element_size() / 1024**3 - if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit(): - num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups) - assert norm_layer.num_groups % num_chunks == 0 - num_groups_per_chunk = norm_layer.num_groups // num_chunks - - x = list(x.chunk(num_chunks, dim=1)) - weights = norm_layer.weight.chunk(num_chunks, dim=0) - biases = norm_layer.bias.chunk(num_chunks, dim=0) - for i, (w, b) in enumerate(zip(weights, biases)): - x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps) - x[i] = x[i].to(input_dtype) - x = torch.cat(x, dim=1) - else: - x = norm_layer(x) - x = rearrange(x, "(b t) c h w -> b c t h w", t=t) - return x.to(input_dtype) - raise NotImplementedError - -def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): - problematic_modes = ['bilinear', 'bicubic', 'trilinear'] - - if mode in problematic_modes: - try: - return F.interpolate( - x, - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ) - except RuntimeError as e: - if ("not implemented for 'Half'" in str(e) or - "compute_indices_weights" in str(e)): - original_dtype = x.dtype - return F.interpolate( - x.float(), - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ).to(original_dtype) - else: - raise e - else: - # Pour 'nearest' et autres modes compatibles, pas de fix nécessaire - return F.interpolate( - x, - size=size, - scale_factor=scale_factor, - mode=mode, - align_corners=align_corners, - recompute_scale_factor=recompute_scale_factor - ) - -_receptive_field_t = Literal["half", "full"] - -def extend_head(tensor, times: int = 2, memory = None): - if memory is not None: - return torch.cat((memory.to(tensor), tensor), dim=2) - assert times >= 0, "Invalid input for function 'extend_head'!" - if times == 0: - return tensor - else: - tile_repeat = [1] * tensor.ndim - tile_repeat[2] = times - return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2) - -def cache_send_recv(tensor, cache_size, times, memory=None): - recv_buffer = None - - if memory is not None: - recv_buffer = memory.to(tensor[0]) - elif times > 0: - tile_repeat = [1] * tensor[0].ndim - tile_repeat[2] = times - recv_buffer = torch.tile(tensor[0][:, :, :1], tile_repeat) - - return recv_buffer - -class InflatedCausalConv3d(ops.Conv3d): - def __init__( - self, - *args, - inflation_mode, - memory_device = "same", - **kwargs, - ): - self.inflation_mode = inflation_mode - self.memory = None - super().__init__(*args, **kwargs) - self.temporal_padding = self.padding[0] - self.memory_device = memory_device - self.padding = (0, *self.padding[1:]) - self.memory_limit = float("inf") - self.logged_once = False - - def set_memory_limit(self, value: float): - self.memory_limit = value - - def set_memory_device(self, memory_device): - self.memory_device = memory_device - - def _conv_forward(self, input, weight, bias, *args, **kwargs): - if (NVIDIA_MEMORY_CONV_BUG_WORKAROUND and - weight.dtype in (torch.float16, torch.bfloat16) and - hasattr(torch.backends.cudnn, 'is_available') and - torch.backends.cudnn.is_available() and - getattr(torch.backends.cudnn, 'enabled', True)): - try: - out = torch.cudnn_convolution( - input, weight, self.padding, self.stride, self.dilation, self.groups, - benchmark=False, deterministic=False, allow_tf32=True - ) - if bias is not None: - out += bias.reshape((1, -1) + (1,) * (out.ndim - 2)) - return out - except RuntimeError: - pass - except NotImplementedError: - pass - try: - return super()._conv_forward(input, weight, bias, *args, **kwargs) - except NotImplementedError: - # for: Could not run 'aten::cudnn_convolution' with arguments from the 'CPU' backend - if not self.logged_once: - logging.warning("VAE is on CPU for decoding. This is most likely due to not enough memory") - self.logged_once = True - return F.conv3d(input, weight, bias, *args, **kwargs) - - def memory_limit_conv( - self, - x, - *, - split_dim=3, - padding=(0, 0, 0, 0, 0, 0), - prev_cache=None, - ): - # Compatible with no limit. - if math.isinf(self.memory_limit): - if prev_cache is not None: - x = torch.cat([prev_cache, x], dim=split_dim - 1) - return super().forward(x) - - # Compute tensor shape after concat & padding. - shape = torch.tensor(x.size()) - if prev_cache is not None: - shape[split_dim - 1] += prev_cache.size(split_dim - 1) - shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0) - memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB - if memory_occupy < self.memory_limit or split_dim == x.ndim: - x_concat = x - if prev_cache is not None: - x_concat = torch.cat([prev_cache, x], dim=split_dim - 1) - - def pad_and_forward(): - padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0) - if not padded.is_contiguous(): - padded = padded.contiguous() - with ignore_padding(self): - return torch.nn.Conv3d.forward(self, padded) - - return pad_and_forward() - - num_splits = math.ceil(memory_occupy / self.memory_limit) - size_per_split = x.size(split_dim) // num_splits - split_sizes = [size_per_split] * (num_splits - 1) - split_sizes += [x.size(split_dim) - sum(split_sizes)] - - x = list(x.split(split_sizes, dim=split_dim)) - if prev_cache is not None: - prev_cache = list(prev_cache.split(split_sizes, dim=split_dim)) - cache = None - for idx in range(len(x)): - if prev_cache is not None: - x[idx] = torch.cat([prev_cache[idx], x[idx]], dim=split_dim - 1) - - lpad_dim = (x[idx].ndim - split_dim - 1) * 2 - rpad_dim = lpad_dim + 1 - padding = list(padding) - padding[lpad_dim] = self.padding[split_dim - 2] if idx == 0 else 0 - padding[rpad_dim] = self.padding[split_dim - 2] if idx == len(x) - 1 else 0 - pad_len = padding[lpad_dim] + padding[rpad_dim] - padding = tuple(padding) - - next_cache = None - cache_len = cache.size(split_dim) if cache is not None else 0 - next_catch_size = get_cache_size( - conv_module=self, - input_len=x[idx].size(split_dim) + cache_len, - pad_len=pad_len, - dim=split_dim - 2, - ) - if next_catch_size != 0: - assert next_catch_size <= x[idx].size(split_dim) - next_cache = ( - x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim) - ) - - x[idx] = self.memory_limit_conv( - x[idx], - split_dim=split_dim + 1, - padding=padding, - prev_cache=cache - ) - - cache = next_cache - - output = torch.cat(x, dim=split_dim) - return output - - def forward( - self, - input, - memory_state: MemoryState = MemoryState.UNSET - ) -> Tensor: - assert memory_state != MemoryState.UNSET - if memory_state != MemoryState.ACTIVE: - self.memory = None - if ( - math.isinf(self.memory_limit) - and torch.is_tensor(input) - ): - return self.basic_forward(input, memory_state) - return self.slicing_forward(input, memory_state) - - def basic_forward(self, input: Tensor, memory_state: MemoryState = MemoryState.UNSET): - mem_size = self.stride[0] - self.kernel_size[0] - if (self.memory is not None) and (memory_state == MemoryState.ACTIVE): - input = extend_head(input, memory=self.memory, times=-1) - else: - input = extend_head(input, times=self.temporal_padding * 2) - memory = ( - input[:, :, mem_size:].detach() - if (mem_size != 0 and memory_state != MemoryState.DISABLED) - else None - ) - if ( - memory_state != MemoryState.DISABLED - and not self.training - and (self.memory_device is not None) - ): - self.memory = memory - if self.memory_device == "cpu" and self.memory is not None: - self.memory = self.memory.to("cpu") - return super().forward(input) - - def slicing_forward( - self, - input, - memory_state: MemoryState = MemoryState.UNSET, - ) -> Tensor: - squeeze_out = False - if torch.is_tensor(input): - input = [input] - squeeze_out = True - - cache_size = self.kernel_size[0] - self.stride[0] - cache = cache_send_recv( - input, cache_size=cache_size, memory=self.memory, times=self.temporal_padding * 2 - ) - - # Single GPU inference - simplified memory management - if ( - memory_state in [MemoryState.INITIALIZING, MemoryState.ACTIVE] # use_slicing - and not self.training - and (self.memory_device is not None) - and cache_size != 0 - ): - if cache_size > input[-1].size(2) and cache is not None and len(input) == 1: - input[0] = torch.cat([cache, input[0]], dim=2) - cache = None - if cache_size <= input[-1].size(2): - self.memory = input[-1][:, :, -cache_size:].detach().contiguous() - if self.memory_device == "cpu" and self.memory is not None: - self.memory = self.memory.to("cpu") - - padding = tuple(x for x in reversed(self.padding) for _ in range(2)) - for i in range(len(input)): - # Prepare cache for next input slice. - next_cache = None - cache_size = 0 - if i < len(input) - 1: - cache_len = cache.size(2) if cache is not None else 0 - cache_size = get_cache_size(self, input[i].size(2) + cache_len, pad_len=0) - if cache_size != 0: - if cache_size > input[i].size(2) and cache is not None: - input[i] = torch.cat([cache, input[i]], dim=2) - cache = None - assert cache_size <= input[i].size(2), f"{cache_size} > {input[i].size(2)}" - next_cache = input[i][:, :, -cache_size:] - - # Conv forward for this input slice. - input[i] = self.memory_limit_conv( - input[i], - padding=padding, - prev_cache=cache - ) - - # Update cache. - cache = next_cache - - return input[0] if squeeze_out else input - -def remove_head(tensor: Tensor, times: int = 1) -> Tensor: - if times == 0: - return tensor - return torch.cat(tensors=(tensor[:, :, :1], tensor[:, :, times + 1 :]), dim=2) - -class Upsample3D(nn.Module): - - def __init__( - self, - channels, - out_channels = None, - inflation_mode = "tail", - temporal_up: bool = False, - spatial_up: bool = True, - slicing: bool = False, - interpolate = True, - name: str = "conv", - use_conv_transpose = False, - use_conv: bool = False, - padding = 1, - bias = True, - kernel_size = None, - **kwargs, - ): - super().__init__() - self.interpolate = interpolate - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv_transpose = use_conv_transpose - self.use_conv = use_conv - self.name = name - - self.conv = None - if use_conv_transpose: - if kernel_size is None: - kernel_size = 4 - self.conv = ops.ConvTranspose2d( - channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias - ) - elif use_conv: - if kernel_size is None: - kernel_size = 3 - self.conv = ops.Conv2d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) - - conv = self.conv if self.name == "conv" else self.Conv2d_0 - - # Note: lora_layer is not passed into constructor in the original implementation. - # So we make a simplification. - conv = InflatedCausalConv3d( - self.channels, - self.out_channels, - 3, - padding=1, - inflation_mode=inflation_mode, - ) - - self.temporal_up = temporal_up - self.spatial_up = spatial_up - self.temporal_ratio = 2 if temporal_up else 1 - self.spatial_ratio = 2 if spatial_up else 1 - self.slicing = slicing - - assert not self.interpolate - # [Override] MAGViT v2 implementation - if not self.interpolate: - upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio - self.upscale_conv = ops.Conv3d( - self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0 - ) - identity = ( - torch.eye(self.channels) - .repeat(upscale_ratio, 1) - .reshape_as(self.upscale_conv.weight) - ) - self.upscale_conv.weight.data.copy_(identity) - - if self.name == "conv": - self.conv = conv - else: - self.Conv2d_0 = conv - - self.norm = None - - def forward( - self, - hidden_states: torch.FloatTensor, - memory_state=None, - **kwargs, - ) -> torch.FloatTensor: - assert hidden_states.shape[1] == self.channels - - if hasattr(self, "norm") and self.norm is not None: - # [Overridden] change to causal norm. - hidden_states = causal_norm_wrapper(self.norm, hidden_states) - - if self.use_conv_transpose: - return self.conv(hidden_states) - - if self.slicing: - split_size = hidden_states.size(2) // 2 - hidden_states = list( - hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2) - ) - else: - hidden_states = [hidden_states] - - for i in range(len(hidden_states)): - hidden_states[i] = self.upscale_conv(hidden_states[i]) - hidden_states[i] = rearrange( - hidden_states[i], - "b (x y z c) f h w -> b c (f z) (h x) (w y)", - x=self.spatial_ratio, - y=self.spatial_ratio, - z=self.temporal_ratio, - ) - - if self.temporal_up and memory_state != MemoryState.ACTIVE: - hidden_states[0] = remove_head(hidden_states[0]) - - if not self.slicing: - hidden_states = hidden_states[0] - - if self.use_conv: - if self.name == "conv": - hidden_states = self.conv(hidden_states, memory_state=memory_state) - else: - hidden_states = self.Conv2d_0(hidden_states, memory_state=memory_state) - - if not self.slicing: - return hidden_states - else: - return torch.cat(hidden_states, dim=2) - - -class Downsample3D(nn.Module): - """A 3D downsampling layer with an optional convolution.""" - - def __init__( - self, - channels, - out_channels = None, - inflation_mode = "tail", - spatial_down: bool = False, - temporal_down: bool = False, - name: str = "conv", - kernel_size=3, - use_conv: bool = False, - padding = 1, - bias=True, - **kwargs, - ): - super().__init__() - self.padding = padding - self.name = name - self.channels = channels - self.out_channels = out_channels or channels - self.temporal_down = temporal_down - self.spatial_down = spatial_down - self.use_conv = use_conv - self.padding = padding - - self.temporal_ratio = 2 if temporal_down else 1 - self.spatial_ratio = 2 if spatial_down else 1 - - self.temporal_kernel = 3 if temporal_down else 1 - self.spatial_kernel = 3 if spatial_down else 1 - - if use_conv: - conv = InflatedCausalConv3d( - self.channels, - self.out_channels, - kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel), - stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), - padding=( - 1 if self.temporal_down else 0, - self.padding if self.spatial_down else 0, - self.padding if self.spatial_down else 0, - ), - inflation_mode=inflation_mode, - ) - else: - assert self.channels == self.out_channels - conv = nn.AvgPool3d( - kernel_size=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), - stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio), - ) - - self.conv = conv - - - def forward( - self, - hidden_states: torch.FloatTensor, - memory_state = None, - **kwargs, - ) -> torch.FloatTensor: - - assert hidden_states.shape[1] == self.channels - - if hasattr(self, "norm") and self.norm is not None: - # [Overridden] change to causal norm. - hidden_states = causal_norm_wrapper(self.norm, hidden_states) - - if self.use_conv and self.padding == 0 and self.spatial_down: - pad = (0, 1, 0, 1) - hidden_states = safe_pad_operation(hidden_states, pad, mode="constant", value=0) - - assert hidden_states.shape[1] == self.channels - - hidden_states = self.conv(hidden_states, memory_state=memory_state) - - return hidden_states - - -class ResnetBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: Optional[int] = None, - conv_shortcut: bool = False, - dropout: float = 0.0, - temb_channels: int = 512, - groups: int = 32, - groups_out: Optional[int] = None, - eps: float = 1e-6, - non_linearity: str = "swish", - time_embedding_norm: str = "default", - output_scale_factor: float = 1.0, - skip_time_act: bool = False, - use_in_shortcut: Optional[bool] = None, - up: bool = False, - down: bool = False, - conv_shortcut_bias: bool = True, - conv_2d_out_channels: Optional[int] = None, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - slicing: bool = False, - **kwargs, - ): - super().__init__() - self.up = up - self.down = down - self.in_channels = in_channels - self.out_channels = in_channels if out_channels is None else out_channels - conv_2d_out_channels = conv_2d_out_channels or out_channels - self.use_in_shortcut = use_in_shortcut - self.output_scale_factor = output_scale_factor - self.skip_time_act = skip_time_act - self.nonlinearity = nn.SiLU() - if temb_channels is not None: - self.time_emb_proj = ops.Linear(temb_channels, out_channels) - else: - self.time_emb_proj = None - self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - if groups_out is None: - groups_out = groups - self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) - self.use_in_shortcut = self.in_channels != out_channels - self.dropout = torch.nn.Dropout(dropout) - self.conv1 = InflatedCausalConv3d( - self.in_channels, - self.out_channels, - kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3), - stride=1, - padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1), - inflation_mode=inflation_mode, - ) - - self.conv2 = InflatedCausalConv3d( - self.out_channels, - conv_2d_out_channels, - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - self.upsample = self.downsample = None - if self.up: - self.upsample = Upsample3D( - self.in_channels, - use_conv=False, - inflation_mode=inflation_mode, - slicing=slicing, - ) - elif self.down: - self.downsample = Downsample3D( - self.in_channels, - use_conv=False, - padding=1, - name="op", - inflation_mode=inflation_mode, - ) - - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = InflatedCausalConv3d( - self.in_channels, - conv_2d_out_channels, - kernel_size=1, - stride=1, - padding=0, - bias=True, - inflation_mode=inflation_mode, - ) - - def forward( - self, input_tensor, temb, memory_state = None, **kwargs - ): - hidden_states = input_tensor - - hidden_states = causal_norm_wrapper(self.norm1, hidden_states) - - hidden_states = self.nonlinearity(hidden_states) - - if self.upsample is not None: - if hidden_states.shape[0] >= BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD: - input_tensor = input_tensor.contiguous() - hidden_states = hidden_states.contiguous() - input_tensor = self.upsample(input_tensor, memory_state=memory_state) - hidden_states = self.upsample(hidden_states, memory_state=memory_state) - elif self.downsample is not None: - input_tensor = self.downsample(input_tensor, memory_state=memory_state) - hidden_states = self.downsample(hidden_states, memory_state=memory_state) - - hidden_states = self.conv1(hidden_states, memory_state=memory_state) - - if self.time_emb_proj is not None: - if not self.skip_time_act: - temb = self.nonlinearity(temb) - temb = self.time_emb_proj(temb)[:, :, None, None] - - if temb is not None: - hidden_states = hidden_states + temb - - hidden_states = causal_norm_wrapper(self.norm2, hidden_states) - - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states, memory_state=memory_state) - - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state) - - output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - - return output_tensor - - -class DownEncoderBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_downsample: bool = True, - downsample_padding: int = 1, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_down: bool = True, - spatial_down: bool = True, - ): - super().__init__() - resnets = [] - temporal_modules = [] - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - # [Override] Replace module. - ResnetBlock3D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=None, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - ) - temporal_modules.append(nn.Identity()) - - self.resnets = nn.ModuleList(resnets) - self.temporal_modules = nn.ModuleList(temporal_modules) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample3D( - out_channels, - use_conv=True, - out_channels=out_channels, - padding=downsample_padding, - name="op", - temporal_down=temporal_down, - spatial_down=spatial_down, - inflation_mode=inflation_mode, - ) - ] - ) - else: - self.downsamplers = None - - def forward( - self, - hidden_states: torch.FloatTensor, - memory_state = None, - **kwargs, - ) -> torch.FloatTensor: - for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) - hidden_states = temporal(hidden_states) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states, memory_state=memory_state) - - return hidden_states - - -class UpDecoderBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - output_scale_factor: float = 1.0, - add_upsample: bool = True, - temb_channels: Optional[int] = None, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_up: bool = True, - spatial_up: bool = True, - slicing: bool = False, - ): - super().__init__() - resnets = [] - temporal_modules = [] - - for i in range(num_layers): - input_channels = in_channels if i == 0 else out_channels - - resnets.append( - # [Override] Replace module. - ResnetBlock3D( - in_channels=input_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - slicing=slicing, - ) - ) - - temporal_modules.append(nn.Identity()) - - self.resnets = nn.ModuleList(resnets) - self.temporal_modules = nn.ModuleList(temporal_modules) - - if add_upsample: - # [Override] Replace module & use learnable upsample - self.upsamplers = nn.ModuleList( - [ - Upsample3D( - out_channels, - use_conv=True, - out_channels=out_channels, - temporal_up=temporal_up, - spatial_up=spatial_up, - interpolate=False, - inflation_mode=inflation_mode, - slicing=slicing, - ) - ] - ) - else: - self.upsamplers = None - - def forward( - self, - hidden_states: torch.FloatTensor, - temb: Optional[torch.FloatTensor] = None, - memory_state=None - ) -> torch.FloatTensor: - for resnet, temporal in zip(self.resnets, self.temporal_modules): - hidden_states = resnet(hidden_states, temb=None, memory_state=memory_state) - hidden_states = temporal(hidden_states) - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, memory_state=memory_state) - - return hidden_states - - -class UNetMidBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - add_attention: bool = True, - attention_head_dim: int = 1, - output_scale_factor: float = 1.0, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - ): - super().__init__() - resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) - self.add_attention = add_attention - - # there is always at least one resnet - resnets = [ - # [Override] Replace module. - ResnetBlock3D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - ] - attentions = [] - - if attention_head_dim is None: - attention_head_dim = in_channels - - for _ in range(num_layers): - if self.add_attention: - attentions.append( - Attention( - in_channels, - heads=in_channels // attention_head_dim, - dim_head=attention_head_dim, - rescale_output_factor=output_scale_factor, - eps=resnet_eps, - norm_num_groups=( - resnet_groups if resnet_time_scale_shift == "default" else None - ), - spatial_norm_dim=( - temb_channels if resnet_time_scale_shift == "spatial" else None - ), - residual_connection=True, - bias=True, - upcast_softmax=True, - _from_deprecated_attn_block=True, - ) - ) - else: - attentions.append(None) - - resnets.append( - ResnetBlock3D( - in_channels=in_channels, - out_channels=in_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - - def forward(self, hidden_states, temb=None, memory_state=None): - video_length, frame_height, frame_width = hidden_states.size()[-3:] - hidden_states = self.resnets[0](hidden_states, temb, memory_state=memory_state) - for attn, resnet in zip(self.attentions, self.resnets[1:]): - if attn is not None: - hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") - hidden_states = attn(hidden_states, temb=temb) - hidden_states = rearrange( - hidden_states, "(b f) c h w -> b c f h w", f=video_length - ) - hidden_states = resnet(hidden_states, temb, memory_state=memory_state) - - return hidden_states - - -class Encoder3D(nn.Module): - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",), - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - double_z: bool = True, - mid_block_add_attention=True, - # [Override] add extra_cond_dim, temporal down num - temporal_down_num: int = 2, - extra_cond_dim: int = None, - gradient_checkpoint: bool = False, - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - ): - super().__init__() - self.layers_per_block = layers_per_block - self.temporal_down_num = temporal_down_num - - self.conv_in = InflatedCausalConv3d( - in_channels, - block_out_channels[0], - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - self.mid_block = None - self.down_blocks = nn.ModuleList([]) - self.extra_cond_dim = extra_cond_dim - - self.conv_extra_cond = nn.ModuleList([]) - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - # [Override] to support temporal down block design - is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1 - # Note: take the last ones - - assert down_block_type == "DownEncoderBlock3D" - - down_block = DownEncoderBlock3D( - num_layers=self.layers_per_block, - in_channels=input_channel, - out_channels=output_channel, - add_downsample=not is_final_block, - resnet_eps=1e-6, - downsample_padding=0, - # Note: Don't know why set it as 0 - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - temporal_down=is_temporal_down_block, - spatial_down=True, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - self.down_blocks.append(down_block) - - def zero_module(module): - # Zero out the parameters of a module and return it. - for p in module.parameters(): - p.detach().zero_() - return module - - self.conv_extra_cond.append( - zero_module( - ops.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0) - ) - if self.extra_cond_dim is not None and self.extra_cond_dim > 0 - else None - ) - - # mid - self.mid_block = UNetMidBlock3D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default", - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=None, - add_attention=mid_block_add_attention, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # out - self.conv_norm_out = ops.GroupNorm( - num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6 - ) - self.conv_act = nn.SiLU() - - conv_out_channels = 2 * out_channels if double_z else out_channels - self.conv_out = InflatedCausalConv3d( - block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode - ) - - self.gradient_checkpointing = gradient_checkpoint - - def forward( - self, - sample: torch.FloatTensor, - extra_cond=None, - memory_state = None - ) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - sample = sample.to(next(self.parameters()).device) - sample = self.conv_in(sample, memory_state = memory_state) - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module): - def custom_forward(*inputs): - return module(*inputs) - - return custom_forward - - # down - # [Override] add extra block and extra cond - for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): - sample = torch.utils.checkpoint.checkpoint( - create_custom_forward(down_block), sample, use_reentrant=False - ) - if extra_block is not None: - sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) - - # middle - sample = self.mid_block(sample) - - else: - # down - # [Override] add extra block and extra cond - for down_block, extra_block in zip(self.down_blocks, self.conv_extra_cond): - sample = down_block(sample, memory_state=memory_state) - if extra_block is not None: - sample = sample + safe_interpolate_operation(extra_block(extra_cond), size=sample.shape[2:]) - - # middle - sample = self.mid_block(sample, memory_state=memory_state) - - # post-process - sample = causal_norm_wrapper(self.conv_norm_out, sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample, memory_state = memory_state) - - return sample - - -class Decoder3D(nn.Module): - - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - up_block_types: Tuple[str, ...] = ("UpDecoderBlock3D",), - block_out_channels: Tuple[int, ...] = (64,), - layers_per_block: int = 2, - norm_num_groups: int = 32, - act_fn: str = "silu", - norm_type: str = "group", # group, spatial - mid_block_add_attention=True, - # [Override] add temporal up block - inflation_mode = "tail", - time_receptive_field: _receptive_field_t = "half", - temporal_up_num: int = 2, - slicing_up_num: int = 0, - gradient_checkpoint: bool = False, - ): - super().__init__() - self.layers_per_block = layers_per_block - self.temporal_up_num = temporal_up_num - - self.conv_in = InflatedCausalConv3d( - in_channels, - block_out_channels[-1], - kernel_size=3, - stride=1, - padding=1, - inflation_mode=inflation_mode, - ) - - self.mid_block = None - self.up_blocks = nn.ModuleList([]) - - temb_channels = in_channels if norm_type == "spatial" else None - - # mid - self.mid_block = UNetMidBlock3D( - in_channels=block_out_channels[-1], - resnet_eps=1e-6, - resnet_act_fn=act_fn, - output_scale_factor=1, - resnet_time_scale_shift="default" if norm_type == "group" else norm_type, - attention_head_dim=block_out_channels[-1], - resnet_groups=norm_num_groups, - temb_channels=temb_channels, - add_attention=mid_block_add_attention, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - - is_final_block = i == len(block_out_channels) - 1 - is_temporal_up_block = i < self.temporal_up_num - is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num - # Note: Keep symmetric - - assert up_block_type == "UpDecoderBlock3D" - up_block = UpDecoderBlock3D( - num_layers=self.layers_per_block + 1, - in_channels=prev_output_channel, - out_channels=output_channel, - add_upsample=not is_final_block, - resnet_eps=1e-6, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - resnet_time_scale_shift=norm_type, - temb_channels=temb_channels, - temporal_up=is_temporal_up_block, - slicing=is_slicing_up_block, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_type == "spatial": - self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) - else: - self.conv_norm_out = ops.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 - ) - self.conv_act = nn.SiLU() - self.conv_out = InflatedCausalConv3d( - block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode - ) - - self.gradient_checkpointing = gradient_checkpoint - - # Note: Just copy from Decoder. - def forward( - self, - sample: torch.FloatTensor, - latent_embeds: Optional[torch.FloatTensor] = None, - memory_state = None, - ) -> torch.FloatTensor: - - sample = sample.to(next(self.parameters()).device) - sample = self.conv_in(sample, memory_state=memory_state) - - upscale_dtype = next(iter(self.up_blocks.parameters())).dtype - # middle - sample = self.mid_block(sample, latent_embeds, memory_state=memory_state) - sample = sample.to(upscale_dtype) - - # up - for up_block in self.up_blocks: - sample = up_block(sample, latent_embeds, memory_state=memory_state) - - # post-process - sample = causal_norm_wrapper(self.conv_norm_out, sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample, memory_state=memory_state) - - return sample - -class VideoAutoencoderKL(nn.Module): - def __init__( - self, - in_channels: int = 3, - out_channels: int = 3, - layers_per_block: int = 2, - act_fn: str = "silu", - latent_channels: int = SEEDVR2_LATENT_CHANNELS, - norm_num_groups: int = 32, - attention: bool = True, - temporal_scale_num: int = 2, - slicing_up_num: int = 0, - gradient_checkpoint: bool = False, - inflation_mode = "pad", - time_receptive_field: _receptive_field_t = "full", - use_quant_conv: bool = False, - use_post_quant_conv: bool = False, - slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN, - *args, - **kwargs, - ): - self.slicing_sample_min_size = slicing_sample_min_size - self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num) - extra_cond_dim = kwargs.pop("extra_cond_dim") if "extra_cond_dim" in kwargs else None - block_out_channels = BYTEDANCE_BLOCK_OUT_CHANNELS - down_block_types = ("DownEncoderBlock3D",) * 4 - up_block_types = ("UpDecoderBlock3D",) * 4 - super().__init__() - - # pass init params to Encoder - self.encoder = Encoder3D( - in_channels=in_channels, - out_channels=latent_channels, - down_block_types=down_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - act_fn=act_fn, - norm_num_groups=norm_num_groups, - double_z=True, - extra_cond_dim=extra_cond_dim, - # [Override] add temporal_down_num parameter - temporal_down_num=temporal_scale_num, - gradient_checkpoint=gradient_checkpoint, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - # pass init params to Decoder - self.decoder = Decoder3D( - in_channels=latent_channels, - out_channels=out_channels, - up_block_types=up_block_types, - block_out_channels=block_out_channels, - layers_per_block=layers_per_block, - norm_num_groups=norm_num_groups, - act_fn=act_fn, - # [Override] add temporal_up_num parameter - temporal_up_num=temporal_scale_num, - slicing_up_num=slicing_up_num, - gradient_checkpoint=gradient_checkpoint, - inflation_mode=inflation_mode, - time_receptive_field=time_receptive_field, - ) - - self.quant_conv = ( - InflatedCausalConv3d( - in_channels=2 * latent_channels, - out_channels=2 * latent_channels, - kernel_size=1, - inflation_mode=inflation_mode, - ) - if use_quant_conv - else None - ) - self.post_quant_conv = ( - InflatedCausalConv3d( - in_channels=latent_channels, - out_channels=latent_channels, - kernel_size=1, - inflation_mode=inflation_mode, - ) - if use_post_quant_conv - else None - ) - - # A hacky way to remove attention. - if not attention: - self.encoder.mid_block.attentions = torch.nn.ModuleList([None]) - self.decoder.mid_block.attentions = torch.nn.ModuleList([None]) - - self.use_slicing = True - - def encode(self, x: torch.FloatTensor, return_dict: bool = True): - h = self.slicing_encode(x) - posterior = DiagonalGaussianDistribution(h).mode() - - if not return_dict: - return (posterior,) - - return posterior - - def decode_( - self, z: torch.Tensor, return_dict: bool = True - ): - decoded = self.slicing_decode(z) - - if not return_dict: - return (decoded,) - - return decoded - - def _encode( - self, x, memory_state = MemoryState.DISABLED - ) -> torch.Tensor: - _x = x.to(self.device) - h = self.encoder(_x, memory_state=memory_state) - if self.quant_conv is not None: - output = self.quant_conv(h, memory_state=memory_state) - else: - output = h - return output.to(x.device) - - def _decode( - self, z, memory_state = MemoryState.DISABLED - ) -> torch.Tensor: - _z = z.to(self.device) - - if self.post_quant_conv is not None: - _z = self.post_quant_conv(_z, memory_state=memory_state) - - output = self.decoder(_z, memory_state=memory_state) - return output.to(z.device) - - def slicing_encode(self, x: torch.Tensor) -> torch.Tensor: - sp_size =1 - if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size: - split_size = max( - self.slicing_sample_min_size * sp_size, - getattr(self, "temporal_downsample_factor", 1), - ) - x_slices = list(x[:, :, 1:].split(split_size=split_size, dim=2)) - min_active_len = getattr(self, "temporal_downsample_factor", 1) - if len(x_slices) > 1 and x_slices[-1].shape[2] < min_active_len: - x_slices[-2] = torch.cat((x_slices[-2], x_slices[-1]), dim=2) - x_slices.pop() - encoded_slices = [ - self._encode( - torch.cat((x[:, :, :1], x_slices[0]), dim=2), - memory_state=MemoryState.INITIALIZING, - ) - ] - for x_idx in range(1, len(x_slices)): - encoded_slices.append( - self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE) - ) - out = torch.cat(encoded_slices, dim=2) - modules_with_memory = [m for m in self.modules() - if isinstance(m, InflatedCausalConv3d) and m.memory is not None] - for m in modules_with_memory: - m.memory = None - return out - else: - return self._encode(x) - - def slicing_decode(self, z: torch.Tensor) -> torch.Tensor: - sp_size = 1 - if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size: - z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2) - decoded_slices = [ - self._decode( - torch.cat((z[:, :, :1], z_slices[0]), dim=2), - memory_state=MemoryState.INITIALIZING - ) - ] - for z_idx in range(1, len(z_slices)): - decoded_slices.append( - self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE) - ) - out = torch.cat(decoded_slices, dim=2) - modules_with_memory = [m for m in self.modules() - if isinstance(m, InflatedCausalConv3d) and m.memory is not None] - for m in modules_with_memory: - m.memory = None - return out - else: - return self._decode(z) - - def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor: - raise NotImplementedError - - def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor: - raise NotImplementedError - - def forward( - self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs - ): - # x: [b c t h w] - def _unwrap(value): - return value[0] if isinstance(value, tuple) else value - - if mode == "encode": - return _unwrap(self.encode(x)) - elif mode == "decode": - return _unwrap(self.decode_(x)) - else: - latent = _unwrap(self.encode(x)) - return _unwrap(self.decode_(latent)) - -class VideoAutoencoderKLWrapper(VideoAutoencoderKL): - def __init__( - self, - *args, - spatial_downsample_factor = 8, - temporal_downsample_factor = 4, - freeze_encoder = True, - **kwargs, - ): - self.spatial_downsample_factor = spatial_downsample_factor - self.temporal_downsample_factor = temporal_downsample_factor - self.freeze_encoder = freeze_encoder - self.enable_tiling = False - super().__init__(*args, **kwargs) - self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB) - - def forward(self, x: torch.FloatTensor): - with torch.no_grad() if self.freeze_encoder else nullcontext(): - z, p = self.encode(x) - x = self.decode(z) - return x, z, p - - def encode(self, x, orig_dims=None): - if x.ndim == 4: - x = x.unsqueeze(2) - x = x.to(dtype=next(self.parameters()).dtype) - self.device = x.device - p = super().encode(x) - z = p.squeeze(2) - return z, p - - def decode(self, z, seedvr2_tiling=None): - seedvr2_tiling = {} if seedvr2_tiling is None else seedvr2_tiling - if not isinstance(seedvr2_tiling, dict): - raise RuntimeError( - "SeedVR2 VideoAutoencoderKLWrapper.decode: `seedvr2_tiling` must be a dict; " - f"got {type(seedvr2_tiling).__name__} with value {seedvr2_tiling!r}." - ) - - if z.ndim == 5: - b, c, t_latent, h, w = z.shape - if c != 16: - raise RuntimeError( - "SeedVR2 VideoAutoencoderKLWrapper.decode: 5-D latent input must " - f"have 16 channels; got shape {tuple(z.shape)}." - ) - latent = z - elif z.ndim == 4: - b, tc, h, w = z.shape - if tc % 16 != 0: - raise RuntimeError( - "SeedVR2 VideoAutoencoderKLWrapper.decode: 4-D latent input must " - "use collapsed channel layout (B, 16*T, H, W); " - f"got shape {tuple(z.shape)}." - ) - latent = z.reshape(b, 16, -1, h, w) - else: - raise RuntimeError( - "SeedVR2 VideoAutoencoderKLWrapper.decode: latent input must be " - "4-D collapsed (B, 16*T, H, W) or 5-D (B, 16, T, H, W); " - f"got shape {tuple(z.shape)}." - ) - scale = BYTEDANCE_VAE_SCALING_FACTOR - shift = BYTEDANCE_VAE_SHIFTING_FACTOR - latent = latent / scale + shift - - self.device = latent.device - self.enable_tiling = seedvr2_tiling.get("enable_tiling", False) - - if self.enable_tiling: - decode_seedvr2_args = dict(seedvr2_tiling) - tile_h, tile_w = decode_seedvr2_args.get("tile_size", (512, 512)) - ov_h, ov_w = decode_seedvr2_args.get("tile_overlap", (64, 64)) - decode_seedvr2_args["tile_overlap"] = ( - min(ov_h, max(0, tile_h - 8)), - min(ov_w, max(0, tile_w - 8)), - ) - x = tiled_vae(latent, self, **decode_seedvr2_args, encode=False) - if x.ndim == 4: - # tiled_vae squeezes the temporal axis when - # temporal_downsample_factor == 1 AND latent T == 1 - # (see tiled_vae line 179-180); re-add it so the post-decode - # pipeline can keep batch and time distinct on the tiled path. - x = x.unsqueeze(2) - else: - x = super().decode_(latent) - - # ensure even dims for save video - h, w = x.shape[-2:] - w2 = w - (w % 2) - h2 = h - (h % 2) - x = x[..., :h2, :w2] - - return x - - def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float], memory_device = "same"): - set_norm_limit(norm_max_mem) - for m in self.modules(): - if isinstance(m, InflatedCausalConv3d): - m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) - - for module in self.modules(): - if isinstance(module, InflatedCausalConv3d): - module.set_memory_device(memory_device) diff --git a/comfy/model_base.py b/comfy/model_base.py index c084e23bb..042804771 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -54,8 +54,6 @@ import comfy.ldm.pixeldit.model import comfy.ldm.pixeldit.pid import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 -import comfy.ldm.seedvr.model - import comfy.ldm.qwen_image.model import comfy.ldm.ideogram4.model import comfy.ldm.kandinsky5.model @@ -930,16 +928,6 @@ class HunyuanDiT(BaseModel): out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]])) return out -class SeedVR2(BaseModel): - def __init__(self, model_config, model_type=ModelType.FLOW, device=None): - super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT) - def extra_conds(self, **kwargs): - out = super().extra_conds(**kwargs) - condition = kwargs.get("condition", None) - if condition is not None: - out["condition"] = comfy.conds.CONDRegular(condition) - return out - class PixArt(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 955581006..74c838d13 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -598,56 +598,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config - if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b - dit_config = {} - dit_config["image_model"] = "seedvr2" - dit_config["vid_dim"] = 3072 - dit_config["heads"] = 24 - dit_config["num_layers"] = 36 - # 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.`` - # submodules) at EVERY block — verified by inspecting the 7B - # state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means - # ``MMModule.shared_weights=False``). Native NaDiT computes - # per-block ``shared_weights = not (i < mm_layers)``, so to keep - # every block non-shared we set ``mm_layers = num_layers``. - # Without this, blocks at index >= mm_layers (default 10) try to - # load ``blocks.N.*.all.*`` keys that don't exist in the file, - # silently miss-load → all-black output. - dit_config["mm_layers"] = 36 - dit_config["norm_eps"] = 1e-5 - dit_config["qk_rope"] = True - dit_config["rope_type"] = "rope3d" - dit_config["rope_dim"] = 64 - dit_config["mlp_type"] = "normal" - return dit_config - elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b - dit_config = {} - dit_config["image_model"] = "seedvr2" - dit_config["vid_dim"] = 3072 - dit_config["heads"] = 24 - dit_config["num_layers"] = 36 - # This checkpoint layout carries shared ``all.`` MMModule keys. - # Preserve the historical split: the initial blocks use separate - # vid/txt modules, later blocks use shared modules. - dit_config["mm_layers"] = 10 - dit_config["norm_eps"] = 1e-5 - dit_config["qk_rope"] = True - dit_config["rope_type"] = "rope3d" - dit_config["rope_dim"] = 64 - dit_config["mlp_type"] = "swiglu" - return dit_config - elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b - dit_config = {} - dit_config["image_model"] = "seedvr2" - dit_config["vid_dim"] = 2560 - dit_config["heads"] = 20 - dit_config["num_layers"] = 32 - dit_config["norm_eps"] = 1.0e-05 - dit_config["qk_rope"] = None - dit_config["mlp_type"] = "swiglu" - dit_config["vid_out_norm"] = True - return dit_config - if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 dit_config = {} dit_config["image_model"] = "wan2.1" diff --git a/comfy/sample.py b/comfy/sample.py index de71596b3..2be0cae5f 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -44,13 +44,7 @@ def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None, is_empty = torch.count_nonzero(latent_image) == 0 if is_empty: if latent_format.latent_channels != latent_image.shape[1]: - preserves_collapsed_channels = ( - getattr(latent_format, "preserve_empty_channel_multiples", False) - and latent_image.ndim == 4 - and latent_image.shape[1] % latent_format.latent_channels == 0 - ) - if not preserves_collapsed_channels: - latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) + latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) if downscale_ratio_spacial is not None: if downscale_ratio_spacial != latent_format.spacial_downscale_ratio: ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio diff --git a/comfy/sd.py b/comfy/sd.py index 8ac08ac42..a66ba1bfb 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,4 +1,3 @@ -import inspect import json import torch from enum import Enum @@ -17,7 +16,6 @@ import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae -import comfy.ldm.seedvr.vae import comfy.ldm.triposplat.vae import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.cogvideo.vae @@ -86,36 +84,6 @@ import comfy.latent_formats import comfy.ldm.flux.redux -SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL = 160 - - -def _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w): - output_t = max(1, (latent_t - 1) * 4 + 1) - return output_t * latent_h * 8 * latent_w * 8 - - -def _seedvr2_vae_decode_memory_used(shape): - if len(shape) == 5: - candidates = [] - if shape[1] == 16: - candidates.append((shape[2], shape[3], shape[4])) - if shape[-1] == 16: - candidates.append((shape[1], shape[2], shape[3])) - if len(candidates) == 0: - candidates.append((shape[2], shape[3], shape[4])) - output_pixels = max(_seedvr2_vae_decode_output_pixels(*candidate) for candidate in candidates) - elif len(shape) == 4: - latent_t = max(1, (shape[1] + 15) // 16) - latent_h, latent_w = shape[2], shape[3] - output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w) - else: - latent_t, latent_h, latent_w = 1, shape[-2], shape[-1] - output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w) - # SeedVR2 decode performs full-frame LAB histogram matching: fp32 channels - # plus int64 sort indices dominate peak memory, not the VAE weight dtype. - return output_pixels * SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL - - def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None): key_map = {} if model is not None: @@ -499,10 +467,8 @@ class CLIP: class VAE: def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): - is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd - if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format - if metadata is None or metadata.get("keep_diffusers_format") != "true": - sd = diffusers_convert.convert_vae_state_dict(sd) + if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + sd = diffusers_convert.convert_vae_state_dict(sd) if model_management.is_amd(): VAE_KL_MEM_RATIO = 2.73 @@ -574,20 +540,6 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 - elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 - self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() - self.latent_channels = 16 - self.latent_dim = 3 - self.disable_offload = True - self.memory_used_decode = lambda shape, dtype: _seedvr2_vae_decode_memory_used(shape) - self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype) - self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] - self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) - self.downscale_index_formula = (4, 8, 8) - self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) - self.upscale_index_formula = (4, 8, 8) - self.process_input = lambda image: image * 2.0 - 1.0 - self.crop_input = False elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} @@ -715,7 +667,6 @@ class VAE: self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.downscale_index_formula = (8, 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] - elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] @@ -1055,40 +1006,6 @@ class VAE: decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) - def decode_tiled_seedvr2(self, samples, tile_x=32, tile_y=32, overlap=8, tile_t=16, overlap_t=4): - sf_s = getattr(self.first_stage_model, "spatial_downsample_factor", 8) - sf_t = getattr(self.first_stage_model, "temporal_downsample_factor", 4) - if tile_t is None: - tile_t = 16 - if overlap_t is None: - overlap_t = 4 - if tile_t > 0: - temporal_size = tile_t * sf_t - temporal_overlap = max(0, overlap_t) * sf_t - else: - temporal_size = 0 - temporal_overlap = 0 - args = { - "enable_tiling": True, - "tile_size": (tile_y * sf_s, tile_x * sf_s), - "tile_overlap": (overlap * sf_s, overlap * sf_s), - "temporal_size": temporal_size, - "temporal_overlap": temporal_overlap, - } - output = self.first_stage_model.decode( - samples.to(self.vae_dtype).to(self.device), - seedvr2_tiling=args, - ) - return self.process_output(output.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)) - - def _format_seedvr2_encoded_samples(self, samples): - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - if samples.ndim == 4: - samples = samples.unsqueeze(2) - samples = samples.contiguous() - samples = samples * 0.9152 - return samples - def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) @@ -1125,36 +1042,6 @@ class VAE: encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) - def encode_tiled_seedvr2(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): - if tile_y is None: - tile_y = 512 - if tile_x is None: - tile_x = 512 - if overlap is None: - overlap_y = 64 - overlap_x = 64 - else: - overlap_y = overlap - overlap_x = overlap - if tile_t is None: - tile_t = 9999 - if overlap_t is None: - overlap_t = 0 - overlap_y = min(overlap_y, max(0, tile_y - 8)) - overlap_x = min(overlap_x, max(0, tile_x - 8)) - self.first_stage_model.device = self.device - x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device) - output = comfy.ldm.seedvr.vae.tiled_vae( - x, - self.first_stage_model, - tile_size=(tile_y, tile_x), - tile_overlap=(overlap_y, overlap_x), - temporal_size=tile_t, - temporal_overlap=overlap_t, - encode=True, - ) - return output.to(device=self.output_device, dtype=self.vae_output_dtype()) - def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() pixel_samples = None @@ -1202,40 +1089,16 @@ class VAE: if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) elif dims == 2: - # SeedVR2 latents arrive in 4D collapsed form ``(B, 16*T, H, W)`` - # downstream of ``SeedVR2Conditioning`` (which performs the - # ``rearrange(b c t h w -> b (c t) h w)`` collapse). The - # generic ``decode_tiled_`` would treat the channel dim as - # spatial-only and crash on the collapsed (16, T) layout - # under ``tiled_scale``'s mask broadcast; route SeedVR2 4D - # latents to ``decode_tiled_seedvr2`` instead, whose wrapper - # dispatch handles both 4D and 5D inputs. - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - tile = 256 // self.spacial_compression_decode() - overlap = tile // 4 - pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) - else: - pixel_samples = self.decode_tiled_(samples_in) + pixel_samples = self.decode_tiled_(samples_in) elif dims == 3: tile = 256 // self.spacial_compression_decode() overlap = tile // 4 - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) - else: - pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples - def decode_tiled( - self, - samples, - tile_x=None, - tile_y=None, - overlap=None, - tile_t=None, - overlap_t=None, - ): + def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): self.throw_exception_if_invalid() memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -1249,20 +1112,7 @@ class VAE: args["overlap"] = overlap with model_management.cuda_device_context(self.device): - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper) and dims in (2, 3): - seedvr2_args = {} - if tile_x is not None: - seedvr2_args["tile_x"] = tile_x - if tile_y is not None: - seedvr2_args["tile_y"] = tile_y - if overlap is not None: - seedvr2_args["overlap"] = overlap - if tile_t is not None: - seedvr2_args["tile_t"] = tile_t - if overlap_t is not None: - seedvr2_args["overlap_t"] = overlap_t - output = self.decode_tiled_seedvr2(samples, **seedvr2_args) - elif dims == 1 or self.extra_1d_channel is not None: + if dims == 1 or self.extra_1d_channel is not None: args.pop("tile_y") output = self.decode_tiled_1d(samples, **args) elif dims == 2: @@ -1304,8 +1154,6 @@ class VAE: else: pixels_in = pixels_in.to(self.device) out = self.first_stage_model.encode(pixels_in) - if isinstance(out, tuple): - out = out[0] out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) @@ -1325,23 +1173,20 @@ class VAE: if self.latent_dim == 3: tile = 256 overlap = tile // 4 - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - samples = self.encode_tiled_seedvr2(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap) - else: - samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) elif self.latent_dim == 1 or self.extra_1d_channel is not None: samples = self.encode_tiled_1d(pixel_samples) else: samples = self.encode_tiled_(pixel_samples) - return self._format_seedvr2_encoded_samples(samples) + return samples def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) dims = self.latent_dim pixel_samples = pixel_samples.movedim(-1, 1) - if dims == 3 and pixel_samples.ndim < 5: + if dims == 3: if not self.not_video: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) else: @@ -1365,47 +1210,22 @@ class VAE: elif dims == 2: samples = self.encode_tiled_(pixel_samples, **args) elif dims == 3: - if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): - seedvr2_args = {} - if tile_x is not None: - seedvr2_args["tile_x"] = tile_x - else: - seedvr2_args["tile_x"] = 512 - if tile_y is not None: - seedvr2_args["tile_y"] = tile_y - else: - seedvr2_args["tile_y"] = 512 - if overlap is not None: - seedvr2_args["overlap"] = overlap - else: - seedvr2_args["overlap"] = 64 - if tile_t is not None: - seedvr2_args["tile_t"] = tile_t - else: - seedvr2_args["tile_t"] = 9999 - if overlap_t is not None: - seedvr2_args["overlap_t"] = overlap_t - else: - seedvr2_args["overlap_t"] = 0 - samples = self.encode_tiled_seedvr2(pixel_samples, **seedvr2_args) + if tile_t is not None: + tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) else: - if tile_t is not None: - tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) - else: - tile_t_latent = 9999 - args["tile_t"] = self.upscale_ratio[0](tile_t_latent) + tile_t_latent = 9999 + args["tile_t"] = self.upscale_ratio[0](tile_t_latent) - spatial_overlap = overlap if overlap is not None else 64 - if overlap_t is None: - args["overlap"] = (1, spatial_overlap, spatial_overlap) - else: - args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap) - maximum = pixel_samples.shape[2] - maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) + if overlap_t is None: + args["overlap"] = (1, overlap, overlap) + else: + args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) + maximum = pixel_samples.shape[2] + maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) - samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) + samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) - return self._format_seedvr2_encoded_samples(samples) + return samples def get_sd(self): return self.first_stage_model.state_dict() @@ -1932,17 +1752,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (model, clip, vae) - -def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device): - set_dtype = model_config.set_inference_dtype - parameters = inspect.signature(set_dtype).parameters - supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values()) - if supports_device: - set_dtype(dtype, manual_cast_dtype, device=device) - else: - set_dtype(dtype, manual_cast_dtype) - - def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) @@ -2050,7 +1859,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if model_config.clip_vision_prefix is not None: if output_clipvision: @@ -2191,7 +2000,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) + model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if custom_operations is not None: model_config.custom_operations = custom_operations diff --git a/comfy/supported_models.py b/comfy/supported_models.py index fa95003cc..7cf9c133b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1672,35 +1672,6 @@ class Chroma(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) -class SeedVR2(supported_models_base.BASE): - unet_config = { - "image_model": "seedvr2" - } - latent_format = comfy.latent_formats.SeedVR2 - - vae_key_prefix = ["vae."] - text_encoder_key_prefix = ["text_encoders."] - supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] - sampling_settings = { - "shift": 1.0, - } - - def set_inference_dtype(self, dtype, manual_cast_dtype, device=None): - if ( - dtype == torch.float16 - and manual_cast_dtype is None - and comfy.model_management.should_use_bf16(device) - ): - manual_cast_dtype = torch.bfloat16 - super().set_inference_dtype(dtype, manual_cast_dtype, device=device) - - def get_model(self, state_dict, prefix="", device=None): - out = model_base.SeedVR2(self, device=device) - return out - - def clip_target(self, state_dict={}): - return None - class ChromaRadiance(Chroma): unet_config = { "image_model": "chroma_radiance", @@ -2058,6 +2029,7 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) + class RT_DETR_v4(supported_models_base.BASE): unet_config = { "image_model": "RT_DETR_v4", @@ -2295,7 +2267,6 @@ models = [ HiDream, HiDreamO1, Chroma, - SeedVR2, ChromaRadiance, ACEStep, ACEStep15, diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 572f9984e..0e7a829ba 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -115,7 +115,7 @@ class BASE: replace_prefix = {"": self.vae_key_prefix[0]} return utils.state_dict_prefix_replace(state_dict, replace_prefix) - def set_inference_dtype(self, dtype, manual_cast_dtype, device=None): + def set_inference_dtype(self, dtype, manual_cast_dtype): self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py deleted file mode 100644 index d5cd029ba..000000000 --- a/comfy_extras/nodes_seedvr.py +++ /dev/null @@ -1,1015 +0,0 @@ -from typing_extensions import override -from comfy_api.latest import ComfyExtension, io -import torch -import math -import logging -from einops import rearrange - -import gc -import comfy.model_management -import comfy.sample -import comfy.samplers -from comfy.ldm.seedvr.color_fix import ( - adain_color_transfer, - lab_color_transfer, - wavelet_color_transfer, -) -from comfy.ldm.seedvr.constants import ( - BYTEDANCE_IMG_SHIFT_FIT, - BYTEDANCE_SCHEDULE_T, - BYTEDANCE_VID_SHIFT_FIT, - SEEDVR2_ADAIN_SCALE_MULTIPLIER, - SEEDVR2_COLOR_MEM_HEADROOM, - SEEDVR2_COND_CHANNELS, - SEEDVR2_DTYPE_BYTES_FLOOR, - SEEDVR2_LAB_SCALE_MULTIPLIER, - SEEDVR2_LATENT_CHANNELS, - SEEDVR2_OOM_BACKOFF_DIVISOR, - SEEDVR2_WAVELET_SCALE_MULTIPLIER, -) - -from torchvision.transforms import functional as TVF -from torchvision.transforms import Lambda -from torchvision.transforms.functional import InterpolationMode - - -_SEEDVR2_INVALID_MODEL_MSG_PREFIX = ( - "SeedVR2Conditioning: model object does not match expected SeedVR2 structure" -) - -# Private sentinel for getattr default: distinguishes "attribute missing" -# from "attribute present but None" so the failure message is accurate. -_ATTR_MISSING = object() - - -def _seedvr2_auto_chunk_attempts(t_latent, t_pixel, frames_per_chunk): - """Return stricter 4n+1 frame chunk sizes for auto OOM retries.""" - attempts = [frames_per_chunk] - current_chunk_latent = ( - t_latent if t_pixel <= frames_per_chunk - else (frames_per_chunk - 1) // 4 + 1 - ) - current_chunk_count = max(1, math.ceil(t_latent / current_chunk_latent)) - seen = {frames_per_chunk} - - for target_chunks in range(max(2, current_chunk_count + 1), t_latent + 1): - chunk_latent = max(1, math.ceil(t_latent / target_chunks)) - candidate = 4 * (chunk_latent - 1) + 1 - if candidate in seen: - continue - if candidate >= attempts[-1]: - continue - attempts.append(candidate) - seen.add(candidate) - - return attempts - - -def _resolve_seedvr2_diffusion_model(model): - """Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message.""" - inner = getattr(model, "model", _ATTR_MISSING) - if inner is _ATTR_MISSING: - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input has no 'model' attribute " - f"(got type {type(model).__name__})." - ) - if inner is None: - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input.model is None " - f"(input type {type(model).__name__})." - ) - diffusion_model = getattr(inner, "diffusion_model", _ATTR_MISSING) - if diffusion_model is _ATTR_MISSING: - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model' has no " - f"'diffusion_model' attribute (got type {type(inner).__name__})." - ) - if diffusion_model is None: - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model.diffusion_model' " - f"is None (model.model type {type(inner).__name__})." - ) - return diffusion_model - - -def _apply_rope_freqs_float32_cast(diffusion_model): - """Cast every module's ``rope.freqs`` to float32; the per-tensor dtype check (not a sentinel attr) self-corrects across Comfy's unload/reload, which would otherwise restore the archived fp16/bf16 dtype.""" - for module in diffusion_model.modules(): - if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'): - if module.rope.freqs.data.dtype != torch.float32: - module.rope.freqs.data = module.rope.freqs.data.to(torch.float32) - - -def clear_vae_memory(vae_model): - for module in vae_model.modules(): - if hasattr(module, "memory"): - module.memory = None - gc.collect() - comfy.model_management.soft_empty_cache() - -def expand_dims(tensor, ndim): - shape = tensor.shape + (1,) * (ndim - tensor.ndim) - return tensor.reshape(shape) - -def get_conditions(latent, latent_blur): - t, h, w, c = latent.shape - cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) - cond[:, ..., :-1] = latent_blur[:] - cond[:, ..., -1:] = 1.0 - return cond - -def timestep_transform(timesteps, latents_shapes): - vt = 4 - vs = 8 - frames = (latents_shapes[:, 0] - 1) * vt + 1 - heights = latents_shapes[:, 1] * vs - widths = latents_shapes[:, 2] * vs - - # Compute shift factor. - def get_lin_function(x1, y1, x2, y2): - m = (y2 - y1) / (x2 - x1) - b = y1 - m * x1 - return lambda x: m * x + b - - img_shift_fn = get_lin_function(*BYTEDANCE_IMG_SHIFT_FIT) - vid_shift_fn = get_lin_function(*BYTEDANCE_VID_SHIFT_FIT) - shift = torch.where( - frames > 1, - vid_shift_fn(heights * widths * frames), - img_shift_fn(heights * widths), - ).to(timesteps.device) - - # Shift timesteps. - T = BYTEDANCE_SCHEDULE_T - timesteps = timesteps / T - timesteps = shift * timesteps / (1 + (shift - 1) * timesteps) - timesteps = timesteps * T - return timesteps - -def inter(x_0, x_T, t): - t = expand_dims(t, x_0.ndim) - T = BYTEDANCE_SCHEDULE_T - B = lambda t: t / T - A = lambda t: 1 - (t / T) - return A(t) * x_0 + B(t) * x_T - -def div_pad(image, factor): - - height_factor, width_factor = factor - height, width = image.shape[-2:] - - pad_height = (height_factor - (height % height_factor)) % height_factor - pad_width = (width_factor - (width % width_factor)) % width_factor - - if pad_height == 0 and pad_width == 0: - return image - - if isinstance(image, torch.Tensor): - padding = (0, pad_width, 0, pad_height) - image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0) - - return image - -def cut_videos(videos): - t = videos.size(1) - if t == 1: - return videos - if t <= 4 : - padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - return videos - if (t - 1) % (4) == 0: - return videos - else: - padding = [videos[:, -1].unsqueeze(1)] * ( - 4 - ((t - 1) % (4)) - ) - padding = torch.cat(padding, dim=1) - videos = torch.cat([videos, padding], dim=1) - assert (videos.size(1) - 1) % (4) == 0 - return videos - -def _seedvr2_input_shorter_edge(images, node_name): - if images.dim() == 4: - return min(images.shape[1], images.shape[2]) - if images.dim() == 5: - return min(images.shape[2], images.shape[3]) - raise ValueError( - f"{node_name}: expected 4-D or 5-D IMAGE tensor, " - f"got shape {tuple(images.shape)}" - ) - - -def _seedvr2_pad(images, upscaled_shorter_edge, node_name): - if upscaled_shorter_edge < 2: - raise ValueError( - f"{node_name}: input shorter edge must be at least 2 pixels; " - f"got {upscaled_shorter_edge}." - ) - if images.shape[-1] > 3: - images = images[..., :3] - if images.dim() == 4: - # Comfy video components arrive as a 4-D IMAGE frame sequence: - # (frames, H, W, C). SeedVR2 consumes that as one video. - images = images.unsqueeze(0) - elif images.dim() != 5: - raise ValueError( - f"{node_name}: expected 4-D or 5-D IMAGE tensor, " - f"got shape {tuple(images.shape)}" - ) - images = images.permute(0, 1, 4, 2, 3) - - b, t, c, h, w = images.shape - images = images.reshape(b * t, c, h, w) - - clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) - images = clip(images) - images = div_pad(images, (16, 16)) - _, _, new_h, new_w = images.shape - - images = images.reshape(b, t, c, new_h, new_w) - images = cut_videos(images) - images_bthwc = rearrange(images, "b t c h w -> b t h w c") - - return io.NodeOutput(images_bthwc) - - -class SeedVR2Preprocess(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SeedVR2Preprocess", - display_name="Pre-Process SeedVR2 Input", - category="image/upscaling", - description="Pad a resized image for SeedVR2 model. Alpha channel is dropped. The node Post-Process SeedVR2 Output re-applies it from the original resized image.", - inputs=[ - io.Image.Input("resized_images", tooltip="The resized image to process."), - ], - outputs=[ - io.Image.Output("images"), - ] - ) - - @classmethod - def execute(cls, resized_images): - upscaled_shorter_edge = _seedvr2_input_shorter_edge(resized_images, "SeedVR2Preprocess") - return _seedvr2_pad( - resized_images, upscaled_shorter_edge, "SeedVR2Preprocess", - ) - - -class SeedVR2PostProcessing(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SeedVR2PostProcessing", - display_name="Post-Process SeedVR2 Output", - category="image/upscaling", - description="Align the generated image with the original resized image and apply color correction.", - inputs=[ - io.Image.Input("images", tooltip="The generated image to process."), - io.Image.Input("original_resized_images", tooltip="The original resized image before pre-processing, used as reference."), - io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="Method to match the generated image colors to the original image. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."), - ], - outputs=[io.Image.Output(display_name="images")], - ) - - @classmethod - def execute(cls, images, original_resized_images, color_correction_method): - alpha_input = None - if original_resized_images.shape[-1] == 4: - alpha_input = original_resized_images[..., 3:4] - original_resized_images = original_resized_images[..., :3] - decoded_5d, decoded_was_4d = cls._as_bthwc(images) - reference_full, _ = cls._as_bthwc(original_resized_images) - decoded_5d = cls._restore_reference_batch_time(decoded_5d, reference_full) - - b = min(decoded_5d.shape[0], reference_full.shape[0]) - t = min(decoded_5d.shape[1], reference_full.shape[1]) - reference_h = reference_full.shape[2] - reference_w = reference_full.shape[3] - - decoded_5d = decoded_5d[:b, :t, :, :, :] - target_h = min(decoded_5d.shape[2], reference_h) - target_w = min(decoded_5d.shape[3], reference_w) - decoded_5d = decoded_5d[:, :, :target_h, :target_w, :] - if color_correction_method in ("lab", "wavelet", "adain"): - reference_5d = reference_full[:b, :t, :, :, :] - reference_5d = cls._resize_reference(reference_5d, target_h, target_w) - output_device = decoded_5d.device - decoded_raw = cls._to_seedvr2_raw(decoded_5d) - reference_raw = cls._to_seedvr2_raw(reference_5d) - decoded_flat = rearrange(decoded_raw, "b t h w c -> (b t) c h w") - reference_flat = rearrange(reference_raw, "b t h w c -> (b t) c h w") - output = cls._color_transfer_chunked( - decoded_flat, reference_flat, output_device, color_correction_method, - ) - output = rearrange(output, "(b t) c h w -> b t h w c", b=b, t=t) - output = output.add(1.0).div(2.0).clamp(0.0, 1.0) - elif color_correction_method == "none": - output = decoded_5d - else: - raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") - - if alpha_input is not None: - alpha_5d, _ = cls._as_bthwc(alpha_input) - alpha_5d = alpha_5d[:output.shape[0], :output.shape[1], :output.shape[2], :output.shape[3], :] - output = torch.cat([output, alpha_5d.to(dtype=output.dtype, device=output.device)], dim=-1) - h2 = output.shape[-3] - (output.shape[-3] % 2) - w2 = output.shape[-2] - (output.shape[-2] % 2) - output = output[:, :, :h2, :w2, :] - if decoded_was_4d: - output = output.reshape(-1, output.shape[-3], output.shape[-2], output.shape[-1]) - return io.NodeOutput(output) - - @staticmethod - def _as_bthwc(images): - if images.ndim == 4: - return images.unsqueeze(0), True - if images.ndim == 5: - return images, False - raise ValueError( - f"SeedVR2PostProcessing: expected 4-D or 5-D IMAGE tensor, got shape {tuple(images.shape)}" - ) - - @staticmethod - def _restore_reference_batch_time(decoded, reference): - if decoded.shape[0] != 1: - return decoded - ref_b, ref_t = reference.shape[:2] - if ref_b < 1 or decoded.shape[1] % ref_b != 0: - return decoded - decoded_t = decoded.shape[1] // ref_b - if decoded_t < ref_t: - return decoded - return decoded.reshape(ref_b, decoded_t, decoded.shape[2], decoded.shape[3], decoded.shape[4]) - - @staticmethod - def _to_seedvr2_raw(images): - return images.mul(2.0).sub(1.0) - - @staticmethod - def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn): - color_device = comfy.model_management.vae_device() - decoded_flat = decoded_flat.to(device=color_device) - reference_flat = reference_flat.to(device=color_device) - output = transfer_fn(decoded_flat, reference_flat) - return output.to(device=output_device) - - @staticmethod - def _lab_color_transfer_on_vae_device(decoded_flat, reference_flat, output_device): - color_device = comfy.model_management.vae_device() - result = None - for start in range(decoded_flat.shape[0]): - decoded_frame = decoded_flat[start:start + 1].to(device=color_device).clone() - reference_frame = reference_flat[start:start + 1].to(device=color_device).clone() - output = lab_color_transfer(decoded_frame, reference_frame).to(device=output_device) - if result is None: - result = torch.empty( - (decoded_flat.shape[0],) + tuple(output.shape[1:]), - device=output_device, - dtype=output.dtype, - ) - result[start:start + 1].copy_(output) - if result is None: - raise ValueError("SeedVR2PostProcessing: LAB color correction requires at least one frame.") - return result - - @classmethod - def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method): - chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method) - while True: - next_chunk_size = None - try: - return cls._run_color_transfer_chunks( - decoded_flat, reference_flat, output_device, color_correction_method, chunk_size, - ) - except Exception as e: - comfy.model_management.raise_non_oom(e) - if chunk_size <= 1: - raise RuntimeError( - "SeedVR2PostProcessing: color correction OOM at one frame; " - f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}." - ) from e - next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR) - - comfy.model_management.soft_empty_cache() - chunk_size = next_chunk_size - - @classmethod - def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size): - result = None - for start in range(0, decoded_flat.shape[0], chunk_size): - end = min(start + chunk_size, decoded_flat.shape[0]) - decoded_chunk = decoded_flat[start:end] - reference_chunk = reference_flat[start:end] - if color_correction_method == "lab": - output = cls._lab_color_transfer_on_vae_device(decoded_chunk, reference_chunk, output_device) - elif color_correction_method == "wavelet": - output = cls._color_transfer_on_vae_device( - decoded_chunk, reference_chunk, output_device, wavelet_color_transfer, - ) - else: - output = cls._color_transfer_on_vae_device( - decoded_chunk, reference_chunk, output_device, adain_color_transfer, - ) - if result is None: - result = torch.empty( - (decoded_flat.shape[0],) + tuple(output.shape[1:]), - device=output_device, - dtype=output.dtype, - ) - result[start:end].copy_(output) - if result is None: - raise ValueError("SeedVR2PostProcessing: color correction requires at least one frame.") - return result - - @classmethod - def _estimate_color_correction_chunk_size(cls, decoded_flat, color_correction_method): - multiplier = cls._color_correction_memory_multiplier(color_correction_method) - frames = decoded_flat.shape[0] - _, channels, height, width = decoded_flat.shape - dtype_bytes = max(decoded_flat.element_size(), SEEDVR2_DTYPE_BYTES_FLOOR) - bytes_per_frame = height * width * channels * dtype_bytes * multiplier - if bytes_per_frame <= 0: - return frames - color_device = comfy.model_management.vae_device() - free_memory = comfy.model_management.get_free_memory(color_device) - chunk_size = int((free_memory * SEEDVR2_COLOR_MEM_HEADROOM) // bytes_per_frame) - return max(1, min(frames, chunk_size)) - - @staticmethod - def _color_correction_memory_multiplier(color_correction_method): - if color_correction_method == "lab": - return SEEDVR2_LAB_SCALE_MULTIPLIER - if color_correction_method == "wavelet": - return SEEDVR2_WAVELET_SCALE_MULTIPLIER - if color_correction_method == "adain": - return SEEDVR2_ADAIN_SCALE_MULTIPLIER - raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") - - @staticmethod - def _resize_reference(reference, height, width): - if reference.shape[2] == height and reference.shape[3] == width: - return reference - b, t = reference.shape[:2] - reference_flat = rearrange(reference, "b t h w c -> (b t) c h w") - resized = TVF.resize( - reference_flat, - size=(height, width), - interpolation=InterpolationMode.BICUBIC, - antialias=not (isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"), - ) - return rearrange(resized, "(b t) c h w -> b t h w c", b=b, t=t) - - -class SeedVR2Conditioning(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SeedVR2Conditioning", - display_name="Apply SeedVR2 Conditioning", - category="conditioning", - description="Build SeedVR2 positive/negative conditioning from a VAE latent.", - inputs=[ - io.Model.Input("model", tooltip="The SeedVR2 model."), - io.Latent.Input("vae_conditioning", display_name="latent"), - ], - outputs=[ - io.Model.Output(display_name = "model"), - io.Conditioning.Output(display_name = "positive"), - io.Conditioning.Output(display_name = "negative"), - io.Latent.Output(display_name = "latent"), - ], - ) - - @classmethod - def execute(cls, model, vae_conditioning) -> io.NodeOutput: - - vae_conditioning = vae_conditioning["samples"] - if vae_conditioning.ndim != 5: - raise ValueError( - "SeedVR2Conditioning expects a 5-D VAE latent in Comfy " - f"channel-first layout; got shape {tuple(vae_conditioning.shape)}." - ) - if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS: - raise ValueError( - "SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy " - f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); " - f"got channel-last shape {tuple(vae_conditioning.shape)}." - ) - vae_conditioning = vae_conditioning.movedim(1, -1).contiguous() - model_patcher = model - model = _resolve_seedvr2_diffusion_model(model_patcher) - pos_cond = model.positive_conditioning - neg_cond = model.negative_conditioning - - # Fail-loud guard against silently-wrong output when a - # DiT-only ``.safetensors`` (no ``positive_conditioning`` / - # ``negative_conditioning`` keys) is loaded via ``UNETLoader``. - # ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see - # ``comfy/ldm/seedvr/model.py``); ``load_state_dict(strict=False)`` - # leaves them at zero when the keys are absent. Detect that state - # here rather than at ``BaseModel.extra_conds`` (per sampling step, - # wasteful) or at the resolver helper (mixes structural shape with - # semantic content). Both buffers must be checked together — partial - # bake regressions could populate one but not the other. - if ( - pos_cond.float().abs().sum().item() == 0 - and neg_cond.float().abs().sum().item() == 0 - ): - raise RuntimeError( - f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning " - f"and negative_conditioning buffers are zero-valued — model " - f"file appears to be a DiT-only export missing " - f"the SeedVR2 conditioning tensors. " - f"Re-bake the file with ``positive_conditioning`` (58, 5120) " - f"and ``negative_conditioning`` (64, 5120) keys at top level, " - f"or load via CheckpointLoaderSimple from a bundled " - f"checkpoint." - ) - - _apply_rope_freqs_float32_cast(model) - - condition = torch.stack([get_conditions(c, c) for c in vae_conditioning]) - condition = condition.movedim(-1, 1) - latent = vae_conditioning.movedim(-1, 1) - - latent = rearrange(latent, "b c t h w -> b (c t) h w") - condition = rearrange(condition, "b c t h w -> b (c t) h w") - - negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] - positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] - - return io.NodeOutput(model_patcher, positive, negative, {"samples": latent}) - -def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int, - t_end: int, channels: int) -> torch.Tensor: - """Slice collapsed ``(B, channels*T, H, W)`` along latent T: reshape (accepts non-contiguous inputs), slice, ``.contiguous()`` (T-slice of 5D is a non-contiguous view; re-collapse needs contiguous), re-collapse.""" - B, CT, H, W = tensor_4d.shape - if CT % channels != 0: - raise ValueError( - f"_slice_collapsed_4d_along_t: collapsed channel dim {CT} is not " - f"divisible by channels={channels}; tensor shape {tuple(tensor_4d.shape)}." - ) - T = CT // channels - if not (0 <= t_start < t_end <= T): - raise ValueError( - f"_slice_collapsed_4d_along_t: slice [{t_start}:{t_end}] out of " - f"range for T={T}." - ) - new_T = t_end - t_start - sliced = tensor_4d.reshape(B, channels, T, H, W)[:, :, t_start:t_end, :, :].contiguous() - return sliced.reshape(B, channels * new_T, H, W) - - -def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int): - """Return a new conditioning list with each entry's ``options["condition"]`` (collapsed ``(B, 17*T, H, W)``) sliced along latent T; text tensors, other option keys, and condition-less entries pass through unchanged and inputs are not mutated.""" - new_list = [] - for entry in cond_list: - text_cond, options = entry[0], entry[1] - if "condition" not in options: - new_list.append(entry) - continue - new_options = options.copy() - new_options["condition"] = _slice_collapsed_4d_along_t( - new_options["condition"], t_start, t_end, - SEEDVR2_COND_CHANNELS, - ) - new_list.append([text_cond, new_options]) - return new_list - - -def _slice_seedvr2_noise_mask_along_t(noise_mask: torch.Tensor, - samples_4d: torch.Tensor, - t_start: int, - t_end: int): - """Slice only masks already expanded to collapsed ``(B, 16*T, H, W)``; pass standard ``(B, 1, H, W)`` ``SetLatentNoiseMask`` outputs through for KSampler to expand.""" - if noise_mask.ndim == samples_4d.ndim and noise_mask.shape[1] == samples_4d.shape[1]: - return _slice_collapsed_4d_along_t( - noise_mask, t_start, t_end, SEEDVR2_LATENT_CHANNELS, - ) - return noise_mask - - -def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor: - """Concatenate collapsed ``(B, channels*T_i, H, W)`` chunks along latent T: un-collapse to 5D, cat on ``dim=2``, re-collapse to 4D.""" - if len(chunks_4d) == 0: - raise ValueError("_concat_chunks_along_t: empty chunk list.") - fives = [] - for ch in chunks_4d: - B, CT, H, W = ch.shape - if CT % channels != 0: - raise ValueError( - f"_concat_chunks_along_t: chunk shape {tuple(ch.shape)} " - f"channel dim {CT} not divisible by channels={channels}." - ) - T = CT // channels - fives.append(ch.reshape(B, channels, T, H, W)) - cat = torch.cat(fives, dim=2).contiguous() - B, C, T_total, H, W = cat.shape - return cat.reshape(B, C * T_total, H, W) - - -def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor: - """1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``): - Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3`` - (dead-band would collapse a tiny transition). Window shape matched to the reference - overlapping-frame blend for parity; caller broadcasts across ``(B, C, T_overlap, H, W)``. - """ - if overlap < 1: - raise ValueError( - f"_hann_blend_weights_1d: overlap must be >= 1; got {overlap}." - ) - if overlap >= 3: - t = torch.linspace(0.0, 1.0, steps=overlap, device=device, dtype=dtype) - blend_start = 1.0 / 3.0 - blend_end = 2.0 / 3.0 - u = ((t - blend_start) / (blend_end - blend_start)).clamp(0.0, 1.0) - return 0.5 + 0.5 * torch.cos(torch.pi * u) - return torch.linspace(1.0, 0.0, steps=overlap, device=device, dtype=dtype) - - -def _blend_overlap_region(prev_tail_5d: torch.Tensor, - cur_head_5d: torch.Tensor) -> torch.Tensor: - """Blend two equal-shape 5D ``(B, C, T_overlap, H, W)`` tensors with a 1D Hann/linear T-ramp: ``prev_tail_5d`` takes the descending weight, ``cur_head_5d`` takes ``1 - w_prev`` (caller ensures matching shape/dtype/device).""" - if prev_tail_5d.shape != cur_head_5d.shape: - raise ValueError( - f"_blend_overlap_region: shape mismatch " - f"prev {tuple(prev_tail_5d.shape)} vs " - f"cur {tuple(cur_head_5d.shape)}." - ) - overlap = int(prev_tail_5d.shape[2]) - w_prev_1d = _hann_blend_weights_1d( - overlap, prev_tail_5d.device, prev_tail_5d.dtype, - ) - # Reshape to (1, 1, overlap, 1, 1) for broadcast across B, C, H, W. - w_prev = w_prev_1d.view(1, 1, overlap, 1, 1) - w_cur = 1.0 - w_prev - return prev_tail_5d * w_prev + cur_head_5d * w_cur - - -def _concat_chunks_with_overlap_blend(chunk_specs, channels: int, - overlap_latent: int) -> torch.Tensor: - """Concatenate overlapping ``(t_start, t_end, chunk_4d)`` specs (source-latent T coords) into one collapsed 4D tensor, Hann/linear-blending overlaps; ``overlap_latent == 0`` fast-paths to plain concat (bit-identical to ``_concat_chunks_along_t``). Each blend uses the actual width ``min(prev_end - cur_start, chunk length)``, smaller than ``overlap_latent`` for a runt final chunk.""" - if len(chunk_specs) == 0: - raise ValueError("_concat_chunks_with_overlap_blend: empty chunk list.") - if overlap_latent < 0: - raise ValueError( - f"_concat_chunks_with_overlap_blend: overlap_latent must be " - f">= 0; got {overlap_latent}." - ) - - # Validate channel divisibility once and capture per-chunk T. - chunk_5d = [] - for t_start, t_end, ch in chunk_specs: - B, CT, H, W = ch.shape - if CT % channels != 0: - raise ValueError( - f"_concat_chunks_with_overlap_blend: chunk shape " - f"{tuple(ch.shape)} channel dim {CT} not divisible " - f"by channels={channels}." - ) - T = CT // channels - if t_end - t_start != T: - raise ValueError( - f"_concat_chunks_with_overlap_blend: chunk T={T} mismatches " - f"declared range [{t_start}:{t_end}]." - ) - chunk_5d.append((t_start, t_end, ch.reshape(B, channels, T, H, W))) - - if overlap_latent == 0: - # Fast path: pure concat in the caller-provided chunk order. - return _concat_chunks_along_t( - [c.reshape(c.shape[0], channels * c.shape[2], c.shape[3], c.shape[4]) - for _, _, c in chunk_5d], - channels, - ) - - T_total = max(t_end for _, t_end, _ in chunk_5d) - first_5d = chunk_5d[0][2] - B = first_5d.shape[0] - H = first_5d.shape[3] - W = first_5d.shape[4] - result = torch.empty( - (B, channels, T_total, H, W), - device=first_5d.device, dtype=first_5d.dtype, - ) - filled_until = 0 - for i, (cs, ce, ct_5d) in enumerate(chunk_5d): - chunk_T = int(ct_5d.shape[2]) - if i == 0: - result[:, :, cs:ce, :, :] = ct_5d - filled_until = ce - continue - # Overlap region width is bounded by both the previous fill - # frontier and the current chunk's actual length (for runt - # final chunks shorter than the configured overlap). - overlap_len = min(filled_until - cs, chunk_T) - if overlap_len > 0: - prev_tail = result[:, :, cs:cs + overlap_len, :, :].contiguous() - cur_head = ct_5d[:, :, :overlap_len, :, :].contiguous() - blended = _blend_overlap_region(prev_tail, cur_head) - result[:, :, cs:cs + overlap_len, :, :] = blended - tail_start = cs + overlap_len - tail_end = ce - if tail_end > tail_start: - result[:, :, tail_start:tail_end, :, :] = ( - ct_5d[:, :, overlap_len:, :, :] - ) - else: - # Disjoint chunks (overlap_latent set but this pair did not - # actually overlap, e.g. step_latent equal to chunk_latent - # in a degenerate config). Treat as concat. - result[:, :, cs:ce, :, :] = ct_5d - filled_until = ce - - return result.contiguous().reshape(B, channels * T_total, H, W) - - -def _run_standard_sample(model, seed: int, steps: int, cfg: float, - sampler_name: str, scheduler: str, - positive, negative, latent: dict, - denoise: float) -> dict: - """Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk.""" - samples_in = latent["samples"] - samples_in = comfy.sample.fix_empty_latent_channels( - model, samples_in, latent.get("downscale_ratio_spacial", None), - ) - batch_inds = latent.get("batch_index", None) - noise = comfy.sample.prepare_noise(samples_in, seed, batch_inds) - noise_mask = latent.get("noise_mask", None) - samples = comfy.sample.sample( - model, noise, steps, cfg, sampler_name, scheduler, - positive, negative, samples_in, - denoise=denoise, noise_mask=noise_mask, seed=seed, - ) - out = latent.copy() - out.pop("downscale_ratio_spacial", None) - out["samples"] = samples - return out - - -class SeedVR2ProgressiveSampler(io.ComfyNode): - """Sequential temporal chunking sampler for SeedVR2 native. - - Drop-in replacement for ``KSampler`` in SeedVR2 native workflows that - OOM on long sequences. The latent enters the sampler in SeedVR2's - collapsed form ``(B, 16*T, H, W)`` (collapsed by ``SeedVR2Conditioning`` - at ``rearrange(b c t h w -> b (c t) h w)``); this node slices that - tensor along the temporal axis, runs the configured inner sampler - sequentially per chunk against the standard ``comfy.sample.sample`` - entry point, and concatenates per-chunk outputs back into a single - ``(B, 16*T_total, H, W)`` latent. - - ``frames_per_chunk`` is expressed in pixel-frame units to match the - SeedVR2 4n+1 constraint enforced upstream by ``cut_videos`` and the - VAE's ``temporal_downsample_factor=4``. A pixel chunk size ``F`` - maps to ``(F - 1) // 4 + 1`` latent-frame chunks. - - Determinism contract: a single noise tensor is generated once from - the user seed and sliced per chunk (rather than re-seeding each - chunk), so a workflow that fits in a single chunk produces output - identical to a workflow that fits in N chunks at the same seed, - modulo the inherent T-axis chunk-boundary independence of the model. - """ - - @classmethod - def define_schema(cls): - return io.Schema( - node_id="SeedVR2ProgressiveSampler", - display_name="Sample SeedVR2 (Progressive)", - category="sampling", - description="Sample a SeedVR2 latent in sequential temporal chunks to allow longer videos to fit into VRAM via frame blending the resulting upscaled latents.", - inputs=[ - io.Model.Input("model", tooltip="The model used for denoising the input latent."), - io.Int.Input("seed", default=0, min=0, - max=0xffffffffffffffff, - control_after_generate=True, - tooltip="The random seed used for creating the noise."), - io.Int.Input("steps", default=20, min=1, max=10000, - tooltip="The number of steps used in the denoising process."), - io.Float.Input("cfg", default=1.0, min=0.0, max=100.0, - step=0.1, round=0.01, - tooltip="The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."), - io.Combo.Input("sampler_name", - options=comfy.samplers.SAMPLER_NAMES, - tooltip="The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."), - io.Combo.Input("scheduler", - options=comfy.samplers.SCHEDULER_NAMES, - tooltip="The scheduler controls how noise is gradually removed to form the image."), - io.Conditioning.Input("positive", - tooltip="The conditioning describing the attributes you want to include in the image."), - io.Conditioning.Input("negative", - tooltip="The conditioning describing the attributes you want to exclude from the image."), - io.Latent.Input("latent", - tooltip="The latent image to denoise."), - io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, - step=0.01, - tooltip="The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."), - io.Int.Input("frames_per_chunk", default=21, min=1, - max=16384, step=4, - tooltip="Pixel frames per temporal chunk (4n+1: 1, 5, 9, 13, ...)."), - io.Int.Input("temporal_overlap", default=0, min=0, - max=16384, - tooltip="Latent frames blended between adjacent chunks to hide the seam; 0 = no blend."), - io.Combo.Input("chunking_mode", - options=["manual", "auto"], - default="manual", - tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."), - ], - outputs=[io.Latent.Output(display_name="latent")], - ) - - @classmethod - def execute(cls, model, seed, steps, cfg, sampler_name, scheduler, - positive, negative, latent, denoise, - frames_per_chunk, temporal_overlap, - chunking_mode="manual") -> io.NodeOutput: - # 4n+1 validation in pixel-frame domain. The SeedVR2 native pipeline - # requires pixel-frame counts of the form 4n+1 (1, 5, 9, 13, ...), - # imposed at ``cut_videos`` upstream and propagated through the VAE's - # temporal_downsample_factor=4. Reject violations explicitly before - # any model invocation; a silent rounding would mis-align chunk - # boundaries with the 4n+1 lattice. - if frames_per_chunk < 1 or (frames_per_chunk - 1) % 4 != 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: frames_per_chunk must be a " - f"4n+1 pixel-frame count (1, 5, 9, 13, 17, 21, ...); " - f"got {frames_per_chunk}." - ) - - samples_4d = latent["samples"] - samples_4d = comfy.sample.fix_empty_latent_channels( - model, samples_4d, - latent.get("downscale_ratio_spacial", None), - ) - if samples_4d.ndim != 4: - raise ValueError( - f"SeedVR2ProgressiveSampler: expected 4D collapsed latent " - f"(B, 16*T, H, W); got shape {tuple(samples_4d.shape)}." - ) - B, CT, H, W = samples_4d.shape - if CT % SEEDVR2_LATENT_CHANNELS != 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: collapsed channel dim {CT} is " - f"not divisible by SeedVR2 latent channels " - f"{SEEDVR2_LATENT_CHANNELS}; latent does not appear to be " - f"SeedVR2-shaped." - ) - T_latent = CT // SEEDVR2_LATENT_CHANNELS - T_pixel = 4 * (T_latent - 1) + 1 - - if chunking_mode not in ("manual", "auto"): - raise ValueError( - f"SeedVR2ProgressiveSampler: chunking_mode must be " - f"'manual' or 'auto'; got {chunking_mode!r}." - ) - - if chunking_mode == "auto": - attempts = _seedvr2_auto_chunk_attempts( - T_latent, T_pixel, frames_per_chunk, - ) - for i, attempt_frames_per_chunk in enumerate(attempts): - retry = False - try: - return cls.execute( - model=model, seed=seed, steps=steps, cfg=cfg, - sampler_name=sampler_name, scheduler=scheduler, - positive=positive, negative=negative, - latent=latent, denoise=denoise, - frames_per_chunk=attempt_frames_per_chunk, - temporal_overlap=temporal_overlap, - chunking_mode="manual", - ) - except Exception as e: - comfy.model_management.raise_non_oom(e) - if i == len(attempts) - 1: - raise RuntimeError( - "SeedVR2ProgressiveSampler: exhausted auto " - "chunking attempts after OOM. Tried " - f"frames_per_chunk values {attempts}." - ) from e - retry = True - - if retry: - logging.warning( - "SeedVR2ProgressiveSampler auto chunking OOM at " - "frames_per_chunk=%s; retrying with " - "frames_per_chunk=%s.", - attempt_frames_per_chunk, attempts[i + 1], - ) - comfy.model_management.soft_empty_cache() - - # Short-circuit: total fits in one chunk -> standard path with no - # chunking overhead. Output of this branch is byte-identical to the - # built-in KSampler given the same (model, seed, steps, cfg, - # sampler_name, scheduler, positive, negative, latent, - # denoise) tuple. - if T_pixel <= frames_per_chunk: - return io.NodeOutput(_run_standard_sample( - model, seed, steps, cfg, sampler_name, scheduler, - positive, negative, latent, denoise, - )) - - # Map pixel chunk -> latent chunk. Each chunk's latent length is - # at most ``chunk_latent``; the final chunk may be a runt that - # is automatically 4n+1-aligned in the pixel domain by the - # T_pixel = 4*(T_latent-1) + 1 mapping (every positive integer - # T_latent corresponds to a valid 4n+1 pixel count). - chunk_latent = (frames_per_chunk - 1) // 4 + 1 - - # ``temporal_overlap`` is exposed in latent-frame units, but users - # do not know the derived latent chunk length. Treat oversized - # values as "maximum valid overlap" while preserving a strictly - # positive chunk-loop stride. - if temporal_overlap < 0: - raise ValueError( - f"SeedVR2ProgressiveSampler: temporal_overlap must be >= 0; " - f"got {temporal_overlap}." - ) - temporal_overlap = min(temporal_overlap, chunk_latent - 1) - step_latent = chunk_latent - temporal_overlap - - # Generate full noise once from the user seed, then slice along T - # per chunk. Using one global noise tensor (rather than re-seeding - # per chunk) preserves seed-determinism across chunk-count - # variations: the same (seed, total T_latent) always produces the - # same noise samples regardless of how the work is partitioned. - batch_inds = latent.get("batch_index", None) - noise_full = comfy.sample.prepare_noise(samples_4d, seed, batch_inds) - - noise_mask = latent.get("noise_mask", None) - - # Build the flat list of chunk ranges first so the chunking - # geometry is fully known before any sample call. - chunk_ranges = [] - for chunk_start in range(0, T_latent, step_latent): - chunk_end = min(chunk_start + chunk_latent, T_latent) - if chunk_start >= chunk_end: - # The final iteration of a stride that lands exactly on - # T_latent produces a zero-length chunk; skip it. - break - chunk_ranges.append((chunk_start, chunk_end)) - if chunk_end >= T_latent: - break - - def _sample_one_chunk(chunk_start, chunk_end): - samples_chunk = _slice_collapsed_4d_along_t( - samples_4d, chunk_start, chunk_end, - SEEDVR2_LATENT_CHANNELS, - ) - noise_chunk = _slice_collapsed_4d_along_t( - noise_full, chunk_start, chunk_end, - SEEDVR2_LATENT_CHANNELS, - ) - positive_chunk = _slice_seedvr2_cond_along_t( - positive, chunk_start, chunk_end, - ) - negative_chunk = _slice_seedvr2_cond_along_t( - negative, chunk_start, chunk_end, - ) - - # Per-chunk noise_mask handling: standard masks are passed - # through for KSampler expansion; pre-expanded collapsed - # masks are sliced. - chunk_noise_mask = None - if noise_mask is not None: - chunk_noise_mask = _slice_seedvr2_noise_mask_along_t( - noise_mask, samples_4d, chunk_start, chunk_end, - ) - - return comfy.sample.sample( - model, noise_chunk, steps, cfg, sampler_name, scheduler, - positive_chunk, negative_chunk, samples_chunk, - denoise=denoise, noise_mask=chunk_noise_mask, seed=seed, - ) - - chunk_specs = [] - for chunk_start, chunk_end in chunk_ranges: - chunk_samples = _sample_one_chunk(chunk_start, chunk_end) - chunk_specs.append((chunk_start, chunk_end, chunk_samples)) - - final = _concat_chunks_with_overlap_blend( - chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap, - ) - - out = latent.copy() - out.pop("downscale_ratio_spacial", None) - out["samples"] = final - return io.NodeOutput(out) - - -class SeedVRExtension(ComfyExtension): - @override - async def get_node_list(self) -> list[type[io.ComfyNode]]: - return [ - SeedVR2Conditioning, - SeedVR2Preprocess, - SeedVR2PostProcessing, - SeedVR2ProgressiveSampler, - ] - -async def comfy_entrypoint() -> SeedVRExtension: - return SeedVRExtension() diff --git a/nodes.py b/nodes.py index d9ac53ede..2f5a478b5 100644 --- a/nodes.py +++ b/nodes.py @@ -47,18 +47,14 @@ import node_helpers if args.enable_manager: import comfyui_manager - def before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() - def interrupt_processing(value=True): comfy.model_management.interrupt_current_processing(value) - MAX_RESOLUTION=16384 - class CLIPTextEncode(ComfyNodeABC): @classmethod def INPUT_TYPES(s) -> InputTypeDict: @@ -327,8 +323,8 @@ class VAEDecodeTiled: return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32, "advanced": True}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}), - "temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}), - "temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), + "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time.", "advanced": True}), + "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" @@ -338,32 +334,18 @@ class VAEDecodeTiled: def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8): if tile_size < overlap * 4: overlap = tile_size // 4 + if temporal_size < temporal_overlap * 2: + temporal_overlap = temporal_overlap // 2 temporal_compression = vae.temporal_compression_decode() if temporal_compression is not None: - if temporal_size <= 0: - temporal_size = 0 - temporal_overlap = 0 - else: - requested_temporal_overlap = temporal_overlap - if temporal_size < temporal_overlap * 2: - temporal_overlap = temporal_overlap // 2 - temporal_size = max(2, temporal_size // temporal_compression) - temporal_overlap = min(temporal_size // 2, temporal_overlap // temporal_compression) - if requested_temporal_overlap > 0: - temporal_overlap = max(1, temporal_overlap) + temporal_size = max(2, temporal_size // temporal_compression) + temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression)) else: temporal_size = None temporal_overlap = None compression = vae.spacial_compression_decode() - images = vae.decode_tiled( - samples["samples"], - tile_x=tile_size // compression, - tile_y=tile_size // compression, - overlap=overlap // compression, - tile_t=temporal_size, - overlap_t=temporal_overlap, - ) + images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) return (images, ) @@ -380,7 +362,7 @@ class VAEEncode: def encode(self, vae, pixels): t = vae.encode(pixels) - return ({"samples": t}, ) + return ({"samples":t}, ) class VAEEncodeTiled: @classmethod @@ -388,8 +370,8 @@ class VAEEncodeTiled: return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64, "advanced": True}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}), - "temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}), - "temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), + "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time.", "advanced": True}), + "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), }} RETURN_TYPES = ("LATENT",) FUNCTION = "encode" @@ -397,9 +379,6 @@ class VAEEncodeTiled: CATEGORY = "experimental" def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8): - if temporal_size <= 0: - temporal_size = 0 - temporal_overlap = 0 t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) return ({"samples": t}, ) @@ -2439,7 +2418,6 @@ async def init_builtin_extra_nodes(): "nodes_camera_trajectory.py", "nodes_edit_model.py", "nodes_tcfg.py", - "nodes_seedvr.py", "nodes_context_windows.py", "nodes_qwen.py", "nodes_chroma_radiance.py", diff --git a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py deleted file mode 100644 index 2a6e3d430..000000000 --- a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py +++ /dev/null @@ -1,213 +0,0 @@ -"""Consolidated SeedVR2 conditioning and refactor regression tests. - -Merges the prior test_seedvr2_refactor_nodes.py and -test_seedvr_conditioning_hardening.py modules. Refactor tests use the -top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests -use _import_nodes_seedvr_isolated() for sys.modules isolation when -mocking comfy.model_management. -""" - -import importlib -import sys -from unittest.mock import MagicMock - -import pytest -import torch -import torch.nn as nn - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - - -_SENTINEL = object() -_TARGETS = ( - ("comfy.model_management", "comfy"), - ("comfy_extras.nodes_seedvr", "comfy_extras"), -) - - -def _import_nodes_seedvr_isolated(): - """Import comfy_extras.nodes_seedvr with comfy.model_management mocked.""" - priors = [] - for mod_name, parent_name in _TARGETS: - prior_mod = sys.modules.get(mod_name, _SENTINEL) - parent = sys.modules.get(parent_name) - attr = mod_name.split(".")[-1] - prior_attr = ( - getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL - ) - priors.append((mod_name, parent_name, attr, prior_mod, prior_attr)) - - mock_mm = MagicMock() - for fn in ( - "xformers_enabled", "xformers_enabled_vae", - "pytorch_attention_enabled", "pytorch_attention_enabled_vae", - "sage_attention_enabled", "flash_attention_enabled", - "is_intel_xpu", - ): - getattr(mock_mm, fn).return_value = False - tv = torch.version.__version__.split(".") - mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1])) - mock_mm.WINDOWS = False - sys.modules["comfy.model_management"] = mock_mm - if sys.modules.get("comfy") is None: - import comfy as _comfy_pkg # noqa: F401 - comfy_pkg = sys.modules.get("comfy") - if comfy_pkg is not None: - setattr(comfy_pkg, "model_management", mock_mm) - nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or ( - importlib.import_module("comfy_extras.nodes_seedvr") - ) - - def _restore(): - for mod_name, parent_name, attr, prior_mod, prior_attr in priors: - if prior_mod is _SENTINEL: - sys.modules.pop(mod_name, None) - else: - sys.modules[mod_name] = prior_mod - parent = sys.modules.get(parent_name) - if parent is None: - continue - if prior_attr is _SENTINEL: - if hasattr(parent, attr): - delattr(parent, attr) - else: - setattr(parent, attr, prior_attr) - - return nodes_seedvr, _restore - - -class _Rope(nn.Module): - """Minimal RoPE stub exposing a `freqs` parameter.""" - def __init__(self): - super().__init__() - self.freqs = nn.Parameter(torch.zeros(4)) - - -class _Block(nn.Module): - """Minimal transformer block stub holding a `_Rope`.""" - def __init__(self): - super().__init__() - self.rope = _Rope() - - -class _DiffusionModel(nn.Module): - """Stub diffusion model with N blocks and pos/neg conditioning buffers.""" - def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32): - super().__init__() - self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)]) - pos = torch.zeros if zero_conditioning else torch.ones - self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype)) - self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype)) - - -class _ModelInner: - """Inner model wrapper exposing `.diffusion_model`.""" - def __init__(self, diffusion_model): - self.diffusion_model = diffusion_model - - -class _ModelPatcher: - """ModelPatcher stub exposing `.model._ModelInner`.""" - def __init__(self, diffusion_model): - self.model = _ModelInner(diffusion_model) - - -def test_seedvr2_conditioning_schema_exposes_model_passthrough_output(): - nodes_seedvr, restore = _import_nodes_seedvr_isolated() - try: - schema = nodes_seedvr.SeedVR2Conditioning.define_schema() - assert [input_item.id for input_item in schema.inputs] == [ - "model", - "vae_conditioning", - ] - assert schema.inputs[1].display_name == "latent" - assert [output.display_name for output in schema.outputs] == [ - "model", - "positive", - "negative", - "latent", - ] - finally: - restore() - - -def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): - nodes_seedvr, restore = _import_nodes_seedvr_isolated() - try: - diffusion_model = _DiffusionModel() - patcher = _ModelPatcher(diffusion_model) - samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2) - vae_conditioning = {"samples": samples} - - _, first_positive, first_negative, first_latent = ( - nodes_seedvr.SeedVR2Conditioning.execute( - patcher, - vae_conditioning, - ) - ) - _, second_positive, second_negative, second_latent = ( - nodes_seedvr.SeedVR2Conditioning.execute( - patcher, - vae_conditioning, - ) - ) - - expected_latent = samples.reshape(1, 6, 2, 2) - channel_last = samples.movedim(1, -1).contiguous() - expected_condition = torch.cat( - [ - channel_last, - torch.ones((*channel_last.shape[:-1], 1)), - ], - dim=-1, - ).movedim(-1, 1).reshape(1, 9, 2, 2) - - assert torch.equal(first_latent["samples"], expected_latent) - assert torch.equal(second_latent["samples"], expected_latent) - assert torch.equal( - first_positive[0][1]["condition"], - expected_condition, - ) - assert torch.equal( - second_positive[0][1]["condition"], - expected_condition, - ) - assert torch.equal( - first_negative[0][1]["condition"], - expected_condition, - ) - assert torch.equal( - second_negative[0][1]["condition"], - expected_condition, - ) - finally: - restore() - - -def test_seedvr2_conditioning_fails_loud_on_zero_buffers(): - nodes_seedvr, restore = _import_nodes_seedvr_isolated() - try: - diffusion_model = _DiffusionModel(zero_conditioning=True) - patcher = _ModelPatcher(diffusion_model) - vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))} - - with pytest.raises(RuntimeError) as excinfo: - nodes_seedvr.SeedVR2Conditioning.execute( - patcher, vae_conditioning, - ) - - message = str(excinfo.value) - assert message.startswith( - nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX - ), ( - "Fail-loud message must use the standard " - "_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers " - f"can match it. Got: {message!r}" - ) - assert "positive_conditioning" in message - assert "negative_conditioning" in message - finally: - restore() diff --git a/tests-unit/comfy_extras_test/test_seedvr2_nodes.py b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py deleted file mode 100644 index f7d9a4f65..000000000 --- a/tests-unit/comfy_extras_test/test_seedvr2_nodes.py +++ /dev/null @@ -1,55 +0,0 @@ -import importlib -import inspect -import sys -from unittest.mock import MagicMock, patch - -import torch - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - - -def test_seedvr_node_signature_matches_schema(): - mock_mm = MagicMock() - mock_mm.xformers_enabled.return_value = False - mock_mm.xformers_enabled_vae.return_value = False - mock_mm.sage_attention_enabled.return_value = False - mock_mm.flash_attention_enabled.return_value = False - - sentinel = object() - prior_cpu = cli_args.cpu - cli_args.cpu = True - prior_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel) - comfy_pkg = sys.modules.get("comfy") - prior_mm_attr = getattr(comfy_pkg, "model_management", sentinel) if comfy_pkg else sentinel - - with patch.dict(sys.modules, {"comfy.model_management": mock_mm}): - if comfy_pkg is not None: - setattr(comfy_pkg, "model_management", mock_mm) - sys.modules.pop("comfy_extras.nodes_seedvr", None) - try: - nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr") - for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler): - schema_ids = [i.id for i in node_cls.define_schema().inputs] - exec_params = [ - p for p in inspect.signature(node_cls.execute).parameters.keys() - if p != "cls" - ] - assert schema_ids == exec_params, ( - f"{node_cls.__name__} schema/execute drift: " - f"schema_ids={schema_ids}, exec_params={exec_params}" - ) - finally: - cli_args.cpu = prior_cpu - if prior_module is sentinel: - sys.modules.pop("comfy_extras.nodes_seedvr", None) - else: - sys.modules["comfy_extras.nodes_seedvr"] = prior_module - if comfy_pkg is not None: - if prior_mm_attr is sentinel: - if hasattr(comfy_pkg, "model_management"): - delattr(comfy_pkg, "model_management") - else: - setattr(comfy_pkg, "model_management", prior_mm_attr) diff --git a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py deleted file mode 100644 index a27a8f8df..000000000 --- a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py +++ /dev/null @@ -1,57 +0,0 @@ -from unittest.mock import patch - -import torch - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -from comfy_extras import nodes_seedvr # noqa: E402 - - -def _schema_ids(items): - return [item.id for item in items] - - -def test_seedvr2_post_processing_schema(): - schema = nodes_seedvr.SeedVR2PostProcessing.define_schema() - - assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"] - assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"] - assert schema.inputs[2].default == "lab" - assert schema.outputs[0].get_io_type() == "IMAGE" - - -def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch): - decoded = torch.full((1, 3, 4, 4), 0.25) - reference = torch.full((1, 3, 4, 4), 0.75) - - def _lab(content, style): - raise torch.cuda.OutOfMemoryError("CUDA out of memory") - - monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu")) - monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000) - monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None) - - with patch.object(nodes_seedvr, "lab_color_transfer", _lab): - try: - nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked( - decoded, reference, torch.device("cpu"), "lab", - ) - except RuntimeError as exc: - assert "color_correction_method=lab" in str(exc) - assert " method=lab" not in str(exc) - else: - raise AssertionError("expected RuntimeError for one-frame LAB OOM") - - -def test_seedvr2_post_processing_unknown_color_correction_method_raises(): - decoded = torch.zeros(1, 2, 4, 4, 3) - original = torch.zeros(1, 2, 4, 4, 3) - try: - nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus") - except ValueError as exc: - assert "color_correction_method" in str(exc) - else: - raise AssertionError("expected ValueError for unknown color_correction_method") diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py index c63f69a0d..4e9350602 100644 --- a/tests-unit/comfy_test/model_detection_test.py +++ b/tests-unit/comfy_test/model_detection_test.py @@ -73,24 +73,6 @@ def _make_flux_schnell_comfyui_sd(): return sd -def _make_seedvr2_7b_separate_mm_sd(): - return { - "blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072), - } - - -def _make_seedvr2_7b_shared_mm_sd(): - return { - "blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1), - } - - -def _make_seedvr2_3b_shared_mm_sd(): - return { - "blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1), - } - - class TestModelDetection: """Verify that first-match model detection selects the correct model based on list ordering and unet_config specificity.""" @@ -143,48 +125,6 @@ class TestModelDetection: assert model_config is not None assert type(model_config).__name__ == "FluxSchnell" - def test_seedvr2_7b_separate_mm_detection_config(self): - sd = _make_seedvr2_7b_separate_mm_sd() - unet_config = detect_unet_config(sd, "") - - assert unet_config is not None - assert unet_config["image_model"] == "seedvr2" - assert unet_config["vid_dim"] == 3072 - assert unet_config["heads"] == 24 - assert unet_config["num_layers"] == 36 - assert unet_config["mm_layers"] == 36 - assert unet_config["mlp_type"] == "normal" - assert unet_config["qk_rope"] is True - assert unet_config["rope_type"] == "rope3d" - assert unet_config["rope_dim"] == 64 - - def test_seedvr2_7b_shared_mm_detection_config(self): - sd = _make_seedvr2_7b_shared_mm_sd() - unet_config = detect_unet_config(sd, "") - - assert unet_config is not None - assert unet_config["image_model"] == "seedvr2" - assert unet_config["vid_dim"] == 3072 - assert unet_config["heads"] == 24 - assert unet_config["num_layers"] == 36 - assert unet_config["mm_layers"] == 10 - assert unet_config["mlp_type"] == "swiglu" - assert unet_config["qk_rope"] is True - assert unet_config["rope_type"] == "rope3d" - assert unet_config["rope_dim"] == 64 - - def test_seedvr2_3b_shared_mm_detection_config(self): - sd = _make_seedvr2_3b_shared_mm_sd() - unet_config = detect_unet_config(sd, "") - - assert unet_config is not None - assert unet_config["image_model"] == "seedvr2" - assert unet_config["vid_dim"] == 2560 - assert unet_config["heads"] == 20 - assert unet_config["num_layers"] == 32 - assert unet_config["mlp_type"] == "swiglu" - assert unet_config["qk_rope"] is None - def test_unet_config_and_required_keys_combination_is_unique(self): """Each model in the registry must have a unique combination of ``unet_config`` and ``required_keys``. If two models share the same diff --git a/tests-unit/comfy_test/seedvr_vae_forward_test.py b/tests-unit/comfy_test/seedvr_vae_forward_test.py deleted file mode 100644 index f9dbd6890..000000000 --- a/tests-unit/comfy_test/seedvr_vae_forward_test.py +++ /dev/null @@ -1,90 +0,0 @@ -"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must -honor the actual tensor/tuple return contract of ``encode()`` and -``decode_()`` and must NOT dereference diffusers-style ``.latent_dist`` -or ``.sample`` attributes on those returns. - -The pre-fix body raised ``AttributeError: 'Tensor' object has no -attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and -``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'`` -for ``mode == "decode"`` (the class only defines ``decode_`` with a -trailing underscore). The post-fix body unwraps the optional one-element -tuple shape that ``return_dict=False`` produces and returns the tensor -directly. - -Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses -the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and -overrides ``encode``/``decode_`` with known tensors so the contract can -be probed without loading any real VAE weights. -""" - -import torch -import torch.nn as nn - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402 - - -_LATENT_SHAPE = (1, 16, 2, 2, 2) -_DECODED_SHAPE = (1, 3, 5, 16, 16) -_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16) -_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2) - - -class _StubVAE(VideoAutoencoderKL): - def __init__(self): - nn.Module.__init__(self) - self._encode_out = torch.zeros(*_LATENT_SHAPE) - self._decode_out = torch.zeros(*_DECODED_SHAPE) - - def encode(self, x, return_dict=True): - return self._encode_out - - def decode_(self, z, return_dict=True): - return self._decode_out - - -def test_forward_encode_returns_tensor(): - vae = _StubVAE() - x = torch.zeros(*_INPUT_ENCODE_SHAPE) - result = vae.forward(x, mode="encode") - assert type(result) is torch.Tensor - assert result.shape == torch.Size(_LATENT_SHAPE) - - -def test_forward_decode_returns_tensor(): - vae = _StubVAE() - z = torch.zeros(*_INPUT_DECODE_SHAPE) - result = vae.forward(z, mode="decode") - assert type(result) is torch.Tensor - assert result.shape == torch.Size(_DECODED_SHAPE) - - -class _TupleReturningStubVAE(VideoAutoencoderKL): - """Stub variant whose ``encode``/``decode_`` return the - ``(tensor,)`` one-element tuple shape ``return_dict=False`` produces - in the parent class. Exercises the unwrap branch of - ``VideoAutoencoderKL.forward``. - """ - - def __init__(self): - nn.Module.__init__(self) - self._encode_tensor = torch.zeros(*_LATENT_SHAPE) - self._decode_tensor = torch.zeros(*_DECODED_SHAPE) - - def encode(self, x, return_dict=True): - return (self._encode_tensor,) - - def decode_(self, z, return_dict=True): - return (self._decode_tensor,) - - -def test_forward_all_unwraps_one_tuple_at_each_step(): - vae = _TupleReturningStubVAE() - x = torch.zeros(*_INPUT_ENCODE_SHAPE) - result = vae.forward(x, mode="all") - assert type(result) is torch.Tensor - assert result.shape == torch.Size(_DECODED_SHAPE) diff --git a/tests-unit/comfy_test/test_seedvr2_dtype.py b/tests-unit/comfy_test/test_seedvr2_dtype.py deleted file mode 100644 index e5d79a306..000000000 --- a/tests-unit/comfy_test/test_seedvr2_dtype.py +++ /dev/null @@ -1,47 +0,0 @@ -import torch - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -import comfy.sd -import comfy.supported_models -import comfy.ldm.seedvr.model as seedvr_model - - -def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch): - bf16_device = object() - fp16_device = object() - - monkeypatch.setattr( - comfy.supported_models.comfy.model_management, - "should_use_bf16", - lambda device=None: device is bf16_device, - ) - - bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"}) - bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device) - assert bf16_config.manual_cast_dtype is torch.bfloat16 - - fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"}) - fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device) - assert fp16_config.manual_cast_dtype is None - - -def test_seedvr2_text_conditioning_accepts_cfg1_single_branch(): - context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2) - - txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0]) - - torch.testing.assert_close(txt, context.squeeze(0)) - torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device)) - - -def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer(): - estimate = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160)) - old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2 - - assert estimate == 101 * 960 * 1280 * 160 - assert estimate > 15 * 1024 ** 3 - assert estimate > old_estimate * 100 diff --git a/tests-unit/comfy_test/test_seedvr2_internals.py b/tests-unit/comfy_test/test_seedvr2_internals.py deleted file mode 100644 index 5b008ea6e..000000000 --- a/tests-unit/comfy_test/test_seedvr2_internals.py +++ /dev/null @@ -1,341 +0,0 @@ -"""Consolidated SeedVR2 internals regression tests. - -Sources (all merged verbatim, helper names disambiguated where colliding): - - * RoPE rewrite — NaMMRotaryEmbedding3d.forward must match the legacy - apply_rotary_emb wrapper oracle at fp32. - * GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare - memory_occupy against get_norm_limit(), not float('inf'). - * SeedVR2 variable-length attention split-loop contract. - -Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and -comfy.ldm.modules.attention transitively pull in comfy.model_management, -which probes torch.cuda.current_device() at import time unless args.cpu is -set first. -""" - -from __future__ import annotations - -from unittest.mock import patch - -import pytest -import torch - -from comfy.cli_args import args - -if not torch.cuda.is_available(): - args.cpu = True - -import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 -import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 -import comfy.ldm.modules.attention as attention # noqa: E402 -import comfy.ops as comfy_ops # noqa: E402 -from comfy.ldm.seedvr.model import ( # noqa: E402 - Cache, - NaMMRotaryEmbedding3d, -) -from comfy.ldm.seedvr.vae import ( # noqa: E402 - causal_norm_wrapper, - set_norm_limit, -) -from comfy.ldm.modules.attention import var_attention_optimized_split # noqa: E402 - - -# --------------------------------------------------------------------------- -# RoPE rewrite tests (test_seedvr_rope_rewrite.py) -# --------------------------------------------------------------------------- - -# Test rig dimensions. dim=192 → per-axis rope dim = 64 (even, lucidrains -# requirement). vid_shape=(2,4,4) → L_vid = 32. txt_shape=(8,) → L_txt = 8. -_DIM = 192 -_HEADS = 4 -_VID_T, _VID_H, _VID_W = 2, 4, 4 -_TXT_L = 8 -_L_VID = _VID_T * _VID_H * _VID_W -_SEED = 0 - - -def _make_inputs(dtype=torch.float32, device="cpu"): - """Construct the 6 forward inputs + cache. Deterministic via local - Generator so global RNG state is not mutated. - """ - g = torch.Generator(device=device).manual_seed(_SEED) - vid_q = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g) - vid_k = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g) - txt_q = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g) - txt_k = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g) - vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device) - txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device) - cache = Cache(disable=True) - return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache - - -def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape): - """Reproduce the pre-rewrite ``get_freqs`` body verbatim against - ``self.get_axial_freqs`` (parent ``RotaryEmbeddingBase`` method, - unchanged by the rewrite). - """ - max_temporal = 0 - max_height = 0 - max_width = 0 - max_txt_len = 0 - for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - max_temporal = max(max_temporal, l + f) - max_height = max(max_height, h) - max_width = max(max_width, w) - max_txt_len = max(max_txt_len, l) - with torch.amp.autocast(device_type="cuda", enabled=False): - vid_freqs_full = rope.get_axial_freqs( - min(max_temporal + 16, 1024), - min(max_height + 4, 128), - min(max_width + 4, 128), - ).float() - txt_freqs_full = rope.get_axial_freqs(min(max_txt_len + 16, 1024)) - vid_freq_list, txt_freq_list = [], [] - for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): - vid_freq = vid_freqs_full[l : l + f, :h, :w].reshape(-1, vid_freqs_full.size(-1)) - txt_freq = txt_freqs_full[:l].repeat(1, 3).reshape(-1, vid_freqs_full.size(-1)) - vid_freq_list.append(vid_freq) - txt_freq_list.append(txt_freq) - return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0) - - -def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape, - txt_q, txt_k, txt_shape): - """Compute expected forward output via the unchanged - ``apply_rotary_emb`` wrapper fed with legacy-shape freqs. This is the - oracle. The wrapper itself is out of scope for the rewrite (Shape B). - """ - vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape) - vid_freqs = vid_freqs.to(vid_q.device) - txt_freqs = txt_freqs.to(txt_q.device) - - from einops import rearrange - - vid_q = rearrange(vid_q, "L h d -> h L d") - vid_k = rearrange(vid_k, "L h d -> h L d") - vid_q_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype) - vid_k_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype) - vid_q_out = rearrange(vid_q_out, "h L d -> L h d") - vid_k_out = rearrange(vid_k_out, "h L d -> L h d") - - txt_q = rearrange(txt_q, "L h d -> h L d") - txt_k = rearrange(txt_k, "L h d -> h L d") - txt_q_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype) - txt_k_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype) - txt_q_out = rearrange(txt_q_out, "h L d -> L h d") - txt_k_out = rearrange(txt_k_out, "h L d -> L h d") - return vid_q_out, vid_k_out, txt_q_out, txt_k_out - - -def test_namm_forward_output_tensor_equal_against_legacy_oracle(): - rope = NaMMRotaryEmbedding3d(dim=_DIM) - vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs() - - expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward( - rope, - vid_q.clone(), vid_k.clone(), vid_shape, - txt_q.clone(), txt_k.clone(), txt_shape, - ) - - actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward( - vid_q.clone(), vid_k.clone(), vid_shape, - txt_q.clone(), txt_k.clone(), txt_shape, cache, - ) - - torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0, - msg="vid_q output diverges from wrapper oracle") - torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0, - msg="vid_k output diverges from wrapper oracle") - torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0, - msg="txt_q output diverges from wrapper oracle") - torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0, - msg="txt_k output diverges from wrapper oracle") - - -# --------------------------------------------------------------------------- -# GroupNorm limit tests (test_seedvr_groupnorm_limit.py) -# --------------------------------------------------------------------------- - -_NUM_CHANNELS = 8 -_NUM_GROUPS = 4 -_TENSOR_SHAPE = (1, 8, 2, 4, 4) - -_GROUPNORM_SUBCLASSES = [ - pytest.param(comfy_ops.disable_weight_init.GroupNorm, id="disable_weight_init"), - pytest.param(comfy_ops.manual_cast.GroupNorm, id="manual_cast"), -] - - -@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES) -def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls): - real_group_norm = vae_mod.F.group_norm - set_norm_limit(1e-9) - try: - gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS) - gn.eval() - - forward_hook_calls = [] - - def _hook(module, inputs, output): - forward_hook_calls.append(tuple(inputs[0].shape)) - - spy_calls = [] - - def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs): - spy_calls.append({"num_groups": int(num_groups_arg)}) - return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs) - - handle = gn.register_forward_hook(_hook) - try: - with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy): - out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE)) - finally: - handle.remove() - - full_calls = len(forward_hook_calls) - chunked_calls = sum(1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS) - - assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE - assert full_calls == 0, ( - f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}" - ) - assert chunked_calls > 0, ( - f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}" - ) - finally: - set_norm_limit(None) - - -# --------------------------------------------------------------------------- -# SeedVR2 var_attention split-loop tests -# --------------------------------------------------------------------------- - -def test_var_attention_registry_contains_always_available_entries(): - assert ( - attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_optimized_split"] - is attention.var_attention_optimized_split - ) - - -def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch): - dim = 8 - heads = 2 - head_dim = 4 - attn = seedvr_model.NaSwinAttention( - vid_dim=dim, - txt_dim=dim, - heads=heads, - head_dim=head_dim, - qk_bias=False, - qk_norm=seedvr_model.CustomRMSNorm, - qk_norm_eps=1e-6, - rope_type=None, - rope_dim=head_dim, - shared_weights=False, - window=(2, 1, 1), - window_method="720pwin_by_size_bysize", - version=True, - device="cpu", - dtype=torch.float32, - operations=comfy_ops.disable_weight_init, - ) - generator = torch.Generator(device="cpu").manual_seed(11) - vid = torch.randn(8, dim, generator=generator) - txt = torch.randn(3, dim, generator=generator) - vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long) - txt_shape = torch.tensor([[3]], dtype=torch.long) - calls = [] - - def fake_optimized_var_attention(**kwargs): - calls.append(kwargs) - return kwargs["q"] - - monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention) - - vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True)) - - assert tuple(vid_out.shape) == (8, dim) - assert tuple(txt_out.shape) == (3, dim) - assert len(calls) == 1 - call = calls[0] - assert tuple(call["q"].shape) == (14, heads, head_dim) - assert tuple(call["k"].shape) == (14, heads, head_dim) - assert tuple(call["v"].shape) == (14, heads, head_dim) - assert call["heads"] == heads - assert call["skip_reshape"] is True - assert call["skip_output_reshape"] is True - torch.testing.assert_close( - call["cu_seqlens_q"], - torch.tensor([0, 7, 14], dtype=torch.int32), - rtol=0, - atol=0, - ) - torch.testing.assert_close( - call["cu_seqlens_k"], - torch.tensor([0, 7, 14], dtype=torch.int32), - rtol=0, - atol=0, - ) - - -def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch): - heads = 2 - head_dim = 3 - q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim) - k = q + 100 - v = q + 200 - cu = torch.tensor([0, 2, 5], dtype=torch.int32) - calls = [] - - def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs): - calls.append( - { - "q_shape": tuple(q_arg.shape), - "k_shape": tuple(k_arg.shape), - "v_shape": tuple(v_arg.shape), - "heads": heads_arg, - "kwargs": kwargs, - } - ) - return q_arg + v_arg - - monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention) - - out = var_attention_optimized_split( - q, - k, - v, - heads, - cu, - cu, - skip_reshape=True, - skip_output_reshape=True, - ) - - assert tuple(out.shape) == (5, heads, head_dim) - assert len(calls) == 2 - assert calls[0]["q_shape"] == (1, heads, 2, head_dim) - assert calls[1]["q_shape"] == (1, heads, 3, head_dim) - assert all(call["heads"] == heads for call in calls) - assert all(call["kwargs"]["skip_reshape"] is True for call in calls) - assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls) - torch.testing.assert_close(out, q + v, rtol=0, atol=0) - - -def test_var_attention_optimized_split_rejects_bad_offsets(): - q = torch.randn(5, 2, 3) - cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32) - cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32) - - with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"): - var_attention_optimized_split( - q, - q, - q, - 2, - cu_bad, - cu_ok, - skip_reshape=True, - skip_output_reshape=True, - ) diff --git a/tests-unit/comfy_test/test_seedvr2_model.py b/tests-unit/comfy_test/test_seedvr2_model.py deleted file mode 100644 index f2b9bcbbe..000000000 --- a/tests-unit/comfy_test/test_seedvr2_model.py +++ /dev/null @@ -1,308 +0,0 @@ -"""Consolidated SeedVR2 model/graph/forward regression tests. - -Merged from: -- seedvr_model_test.py -- test_seedvr_7b_final_block_text_path.py -- test_seedvr_forward_no_device_cast.py -- test_seedvr_latent_format.py -- test_seedvr2_vae_graph_boundaries.py -""" - -from __future__ import annotations - -from unittest.mock import MagicMock - -import torch -from torch import nn - -from comfy.cli_args import args - -if not torch.cuda.is_available(): - args.cpu = True - -import comfy # noqa: E402 -import comfy.latent_formats # noqa: E402 -import comfy.ldm.seedvr.model # noqa: E402 -import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 -import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 -import comfy.model_management # noqa: E402 -import comfy.sample # noqa: E402 -import comfy.sd as sd_mod # noqa: E402 -import nodes as nodes_mod # noqa: E402 -from comfy.ldm.seedvr.model import NaDiT # noqa: E402 - - -# --------------------------------------------------------------------------- -# Helpers from seedvr_model_test.py -# --------------------------------------------------------------------------- - - -def _make_standin(positive_conditioning): - class _StandIn(torch.nn.Module): - def __init__(self): - super().__init__() - self.register_buffer( - "positive_conditioning", positive_conditioning - ) - - _resolve_text_conditioning = NaDiT._resolve_text_conditioning - - return _StandIn() - - -# --------------------------------------------------------------------------- -# Helpers from test_seedvr_7b_final_block_text_path.py -# --------------------------------------------------------------------------- - - -class _StubModule(nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - - -def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]: - flags = [] - - class _Block(_StubModule): - def __init__(self, *args, **kwargs): - flags.append(kwargs["is_last_layer"]) - super().__init__() - - monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule) - monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule) - monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule) - monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block) - - seedvr_model.NaDiT( - norm_eps=1e-5, - qk_rope=None, - num_layers=4, - mlp_type="normal", - vid_dim=vid_dim, - txt_in_dim=txt_in_dim, - heads=24, - mm_layers=3, - ) - - return flags - - -# --------------------------------------------------------------------------- -# Helpers from test_seedvr_latent_format.py -# --------------------------------------------------------------------------- - - -class _Model: - def __init__(self, latent_format): - self._latent_format = latent_format - - def get_model_object(self, name): - assert name == "latent_format" - return self._latent_format - - -# --------------------------------------------------------------------------- -# Helpers from test_seedvr2_vae_graph_boundaries.py -# --------------------------------------------------------------------------- - - -class _Patcher: - def get_free_memory(self, device): - return 1024 * 1024 * 1024 - - -class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): - def __init__(self, encoded): - nn.Module.__init__(self) - self.encoded = encoded - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.seen = [] - - def encode(self, x): - self.seen.append(tuple(x.shape)) - return self.encoded.to(device=x.device, dtype=x.dtype) - - -class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): - def __init__(self): - nn.Module.__init__(self) - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.calls = [] - - def decode(self, z, seedvr2_tiling=None): - self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling}) - if z.ndim == 4: - b, tc, h, w = z.shape - t = tc // 16 - else: - b, _, t, h, w = z.shape - return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device) - - -def _make_vae(wrapper): - vae = sd_mod.VAE.__new__(sd_mod.VAE) - vae.first_stage_model = wrapper - vae.device = torch.device("cpu") - vae.output_device = torch.device("cpu") - vae.vae_dtype = torch.float32 - vae.latent_channels = 16 - vae.latent_dim = 3 - vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8) - vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) - vae.output_channels = 3 - vae.disable_offload = True - vae.extra_1d_channel = None - vae.crop_input = False - vae.not_video = False - vae.patcher = _Patcher() - vae.process_input = lambda image: image - vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0) - vae.vae_output_dtype = lambda: torch.float32 - vae.memory_used_encode = lambda shape, dtype: 1 - vae.memory_used_decode = lambda shape, dtype: 1 - vae.throw_exception_if_invalid = lambda: None - vae.vae_encode_crop_pixels = lambda pixels: pixels - vae.spacial_compression_decode = lambda: 8 - vae.temporal_compression_decode = lambda: 4 - return vae - - -# --------------------------------------------------------------------------- -# Tests from seedvr_model_test.py -# --------------------------------------------------------------------------- - - -def test_missing_context_falls_back_to_positive_buffer(): - """AC: ``context is None`` falls back to the registered - ``positive_conditioning`` buffer and runs to completion — no - silent zero substitution, no raised exception. - """ - pos_buffer = torch.full((58, 5120), 7.0) - standin = _make_standin(pos_buffer) - txt, txt_shape = standin._resolve_text_conditioning(None) - assert txt.shape == (58, 5120) - assert (txt == 7.0).all(), ( - "fallback path must use the positive_conditioning buffer " - "verbatim, not a zero tensor" - ) - assert txt_shape.shape == (1, 1) - assert txt_shape[0, 0].item() == 58 - - -# --------------------------------------------------------------------------- -# Tests from test_seedvr_7b_final_block_text_path.py -# --------------------------------------------------------------------------- - - -def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch): - assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [ - False, - False, - False, - False, - ] - - -def test_seedvr2_7b_rope3d_matches_wrapper_oracle(): - rope = seedvr_model.get_na_rope("rope3d", dim=64) - generator = torch.Generator(device="cpu").manual_seed(0) - q = torch.randn(4, 2, 128, generator=generator) - k = torch.randn(4, 2, 128, generator=generator) - shape = torch.tensor([[1, 2, 2]], dtype=torch.long) - freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1) - - expected_q = seedvr_model._apply_seedvr2_rotary_emb( - freqs, - q.permute(1, 0, 2).float(), - ).to(q.dtype).permute(1, 0, 2) - expected_k = seedvr_model._apply_seedvr2_rotary_emb( - freqs, - k.permute(1, 0, 2).float(), - ).to(k.dtype).permute(1, 0, 2) - - actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True)) - - torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0) - torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0) - - -# --------------------------------------------------------------------------- -# Tests from test_seedvr_latent_format.py -# --------------------------------------------------------------------------- - - -def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion(): - latent_format = comfy.latent_formats.SeedVR2() - latent_image = torch.zeros(1, 1, 4, 5) - - fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image) - - assert latent_format.latent_channels == 16 - assert latent_format.latent_dimensions == 2 - assert fixed.shape == (1, 16, 4, 5) - - -# --------------------------------------------------------------------------- -# Tests from test_seedvr2_vae_graph_boundaries.py -# --------------------------------------------------------------------------- - - -def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch): - monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) - - encoded = torch.full((1, 16, 2, 4, 5), 2.0) - vae = _make_vae(_EncodeWrapper(encoded)) - pixels = torch.zeros(1, 5, 32, 40, 3) - - node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0] - node_latent = node_output["samples"] - assert set(node_output) == {"samples"} - assert tuple(node_latent.shape) == (1, 16, 2, 4, 5) - assert node_latent.dtype == torch.float32 - assert node_latent.stride()[-1] == 1 - assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152)) - - tiled = torch.full((1, 16, 2, 4, 5), 3.0) - monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled)) - tiled_output = nodes_mod.VAEEncodeTiled().encode( - vae, - pixels, - tile_size=512, - overlap=64, - temporal_size=16, - temporal_overlap=4, - )[0] - tiled_latent = tiled_output["samples"] - assert set(tiled_output) == {"samples"} - assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5) - assert tiled_latent.dtype == torch.float32 - assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152)) - - -def test_vaedecode_tiled_visible_inputs_are_seedvr2_decode_tiling_authority(monkeypatch): - monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) - vae = _make_vae(_DecodeWrapper()) - - nodes_mod.VAEDecodeTiled().decode( - vae, - {"samples": torch.zeros(1, 16, 2, 4, 5)}, - tile_size=512, - overlap=64, - temporal_size=16, - temporal_overlap=4, - ) - - assert vae.first_stage_model.calls == [ - { - "shape": (1, 16, 2, 4, 5), - "seedvr2_tiling": { - "enable_tiling": True, - "tile_size": (512, 512), - "tile_overlap": (64, 64), - "temporal_size": 16, - "temporal_overlap": 4, - }, - } - ] diff --git a/tests-unit/comfy_test/test_seedvr2_vae_decode.py b/tests-unit/comfy_test/test_seedvr2_vae_decode.py deleted file mode 100644 index ea9f978f3..000000000 --- a/tests-unit/comfy_test/test_seedvr2_vae_decode.py +++ /dev/null @@ -1,91 +0,0 @@ -from unittest.mock import patch - -import pytest -import torch -import torch.nn as nn - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 -from comfy_extras import nodes_seedvr # noqa: E402 - - -def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper: - wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( - vae_mod.VideoAutoencoderKLWrapper - ) - nn.Module.__init__(wrapper) - return wrapper - - -def _fingerprint_decode_(self, z, return_dict=True): - b = int(z.shape[0]) - t = int(z.shape[2]) - h = int(z.shape[3]) - w = int(z.shape[4]) - out = torch.empty(b, 3, t, h * 8, w * 8) - for batch_idx in range(b): - out[batch_idx].fill_(float(batch_idx + 1)) - return out - - -def _decode_with_patches(wrapper, z): - with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_): - return wrapper.decode(z) - - -def test_decode_b2_t3_multi_frame_batch_unchanged(): - wrapper = _make_wrapper() - - out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2)) - - assert tuple(out.shape) == (2, 3, 3, 16, 16) - - -class _Wrapper(vae_mod.VideoAutoencoderKLWrapper): - def __init__(self): - nn.Module.__init__(self) - self.calls = [] - - def parameters(self): - return iter([torch.nn.Parameter(torch.zeros(()))]) - -def _decode_stub(self, latent): - self.calls.append(tuple(latent.shape)) - return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8) - - -def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state(): - wrapper = _Wrapper() - - with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub): - out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5)) - - assert tuple(out.shape) == (1, 3, 2, 32, 40) - assert wrapper.calls == [(1, 16, 2, 4, 5)] - - -def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents(): - wrapper = _Wrapper() - - with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"): - wrapper.decode(torch.zeros(1, 16, 4)) - - -def _t_padded(t_in: int) -> int: - if t_in == 1: - return 1 - if t_in <= 4: - return 5 - if (t_in - 1) % 4 == 0: - return t_in - return t_in + (4 - ((t_in - 1) % 4)) - - -@pytest.mark.parametrize("t_in", [1, 5, 9]) -def test_t_padded_matches_cut_videos(t_in): - dummy = torch.zeros(1, t_in, 1, 1, 1) - assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in) diff --git a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py deleted file mode 100644 index 40079bbe2..000000000 --- a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py +++ /dev/null @@ -1,347 +0,0 @@ -from contextlib import ExitStack -from unittest.mock import MagicMock, patch - -import torch -import torch.nn as nn - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 -import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 -import comfy.sd as sd_mod # noqa: E402 -from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402 - - -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_decode_latent_min_size_override.py -# --------------------------------------------------------------------------- - - -def test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): - from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae - - class StubVAEModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.slicing_latent_min_size = 2 - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.device = torch.device("cpu") - self.use_slicing = True - self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) - self.decode_min_sizes = [] - self.memory_states = [] - - def decode_(self, t_chunk): - self.decode_min_sizes.append(self.slicing_latent_min_size) - return VideoAutoencoderKL.slicing_decode(self, t_chunk) - - def _decode(self, z, memory_state=MemoryState.DISABLED): - self.memory_states.append(memory_state) - b, c, d, h, w = z.shape - return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype) - - vae = StubVAEModel() - z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32) - - tiled_vae( - z, - vae, - tile_size=(64, 64), - tile_overlap=(0, 0), - temporal_size=0, - temporal_overlap=0, - encode=False, - ) - - assert vae.decode_min_sizes == [5] - assert vae.memory_states == [MemoryState.DISABLED] - assert vae.slicing_latent_min_size == 2 - - -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_encode_runt_slice_override.py -# --------------------------------------------------------------------------- - - -def test_zero_temporal_size_preserves_min_size_when_encode_raises(): - from comfy.ldm.seedvr.vae import tiled_vae - - class RaisingVAEModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.slicing_sample_min_size = 4 - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.device = torch.device("cpu") - self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) - - def encode(self, t_chunk): - raise RuntimeError("simulated encode failure") - - vae = RaisingVAEModel() - x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32) - - raised = False - try: - tiled_vae( - x, - vae, - tile_size=(64, 64), - tile_overlap=(0, 0), - temporal_size=0, - temporal_overlap=0, - encode=True, - ) - except RuntimeError as exc: - if "simulated encode failure" not in str(exc): - raise - raised = True - - assert raised - assert vae.slicing_sample_min_size == 4 - - -# --------------------------------------------------------------------------- -# From test_seedvr_vae_tiled_temporal_slicing.py -# --------------------------------------------------------------------------- - - -class _SlicingDecodeVAE(nn.Module): - def __init__(self, slicing_latent_min_size): - super().__init__() - self.slicing_latent_min_size = slicing_latent_min_size - self.spatial_downsample_factor = 8 - self.temporal_downsample_factor = 4 - self.device = torch.device("cpu") - self.use_slicing = True - self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32)) - self.decode_min_sizes = [] - self.memory_states = [] - - def decode_(self, z): - self.decode_min_sizes.append(self.slicing_latent_min_size) - return vae_mod.VideoAutoencoderKL.slicing_decode(self, z) - - def _decode(self, z, memory_state=MemoryState.DISABLED): - self.memory_states.append(memory_state) - x = z[:, :1].repeat( - 1, - 3, - 1, - self.spatial_downsample_factor, - self.spatial_downsample_factor, - ) - return x - - -def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size(): - vae = _SlicingDecodeVAE(slicing_latent_min_size=2) - z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8) - - tiled_vae( - z, - vae, - tile_size=(64, 64), - tile_overlap=(0, 0), - temporal_size=12, - temporal_overlap=4, - encode=False, - ) - - assert vae.decode_min_sizes == [2] - assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE] - assert vae.slicing_latent_min_size == 2 - - wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( - vae_mod.VideoAutoencoderKLWrapper - ) - nn.Module.__init__(wrapper) - seedvr2_tiling = { - "enable_tiling": True, - "tile_size": (64, 64), - "tile_overlap": (0, 0), - "temporal_size": 8, - "temporal_overlap": 7, - } - - captured = {} - - def _fake_tiled_vae(latent, model, **kwargs): - captured.update(kwargs) - return torch.zeros(1, 3, 1, 16, 16) - - with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae): - wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling) - - assert captured["temporal_overlap"] == 7 - - -# --------------------------------------------------------------------------- -# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py -# --------------------------------------------------------------------------- - - -def _force_oom(*a, **k): - raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") - - -def _make_vae(first_stage_model, latent_channels, latent_dim): - vae = sd_mod.VAE.__new__(sd_mod.VAE) - vae.first_stage_model = first_stage_model - vae.patcher = MagicMock() - vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) - vae.device = vae.output_device = torch.device("cpu") - vae.vae_dtype = torch.float32 - vae.disable_offload = True - vae.extra_1d_channel = None - vae.upscale_ratio = vae.downscale_ratio = 8 - vae.upscale_index_formula = vae.downscale_index_formula = None - vae.output_channels = 3 - vae.latent_channels = latent_channels - vae.latent_dim = latent_dim - vae.vae_output_dtype = lambda: torch.float32 - vae.spacial_compression_decode = lambda: 8 - vae.process_input = lambda x: x - vae.process_output = lambda x: x - vae.throw_exception_if_invalid = lambda: None - vae.memory_used_decode = lambda *a, **k: 1 - return vae - - -def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode): - mm = sd_mod.model_management - with ExitStack() as stack: - stack.enter_context(patch.object(mm, "raise_non_oom", lambda e: None)) - stack.enter_context(patch.object(mm, "load_models_gpu", lambda *a, **k: None)) - stack.enter_context(patch.object(mm, "soft_empty_cache", lambda: None)) - stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call)) - stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_", generic_call)) - if patch_wrapper_decode: - stack.enter_context(patch.object( - seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode", - side_effect=_force_oom)) - vae.decode(samples) - - -def test_4d_seedvr2_latent_routes_to_decode_tiled_seedvr2(): - wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( - seedvr_vae_mod.VideoAutoencoderKLWrapper) - vae = _make_vae(wrapper, latent_channels=16, latent_dim=3) - seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) - generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) - _dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True) - assert seedvr2_call.call_count == 1 - assert generic_call.call_count == 0 - - -def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled(): - first_stage = MagicMock() - first_stage.decode = MagicMock(side_effect=_force_oom) - vae = _make_vae(first_stage, latent_channels=4, latent_dim=2) - seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) - generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) - _dispatch(vae, torch.zeros(1, 4, 8, 8), seedvr2_call, generic_call, False) - assert generic_call.call_count == 1 - assert seedvr2_call.call_count == 0 - - -# --------------------------------------------------------------------------- -# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py -# --------------------------------------------------------------------------- - - -def _populate_common_vae_attrs_fallback(vae): - vae.patcher = MagicMock() - vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) - vae.device = torch.device("cpu") - vae.output_device = torch.device("cpu") - vae.vae_dtype = torch.float32 - vae.disable_offload = True - vae.extra_1d_channel = None - vae.upscale_ratio = 8 - vae.upscale_index_formula = None - vae.output_channels = 3 - vae.latent_channels = 16 - vae.latent_dim = 3 - vae.downscale_ratio = 8 - vae.downscale_index_formula = None - vae.not_video = False - vae.crop_input = False - vae.pad_channel_value = None - - vae.vae_output_dtype = lambda: torch.float32 - vae.spacial_compression_encode = lambda: 8 - vae.process_input = lambda x: x - vae.process_output = lambda x: x - vae.throw_exception_if_invalid = lambda: None - vae.memory_used_encode = lambda *a, **k: 1 - - -def _make_seedvr2_vae_fallback(): - vae = sd_mod.VAE.__new__(sd_mod.VAE) - wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( - seedvr_vae_mod.VideoAutoencoderKLWrapper - ) - vae.first_stage_model = wrapper - _populate_common_vae_attrs_fallback(vae) - return vae - - -def _make_non_seedvr2_vae_fallback(): - vae = sd_mod.VAE.__new__(sd_mod.VAE) - vae.first_stage_model = MagicMock() - _populate_common_vae_attrs_fallback(vae) - return vae - - -def _force_regular_encode_oom(*args, **kwargs): - raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") - - -def test_seedvr2_3d_routes_to_encode_tiled_seedvr2_on_oom(): - vae = _make_seedvr2_vae_fallback() - pixel_samples = torch.zeros((1, 8, 64, 64, 3)) - - seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) - generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) - - with patch.object(sd_mod.model_management, "raise_non_oom", - lambda e: None), \ - patch.object(sd_mod.model_management, "load_models_gpu", - lambda *a, **k: None), \ - patch.object(sd_mod.model_management, "soft_empty_cache", - lambda: None), \ - patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode", - side_effect=_force_regular_encode_oom), \ - patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call, - create=True), \ - patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): - vae.encode(pixel_samples) - - assert seedvr2_call.call_count == 1, ( - f"Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D " - f"input under OOM fallback; got {seedvr2_call.call_count} calls." - ) - assert generic_call.call_count == 0, ( - f"encode_tiled_3d must NOT be called for a SeedVR2 input; got " - f"{generic_call.call_count} calls." - ) - - -def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete(): - vae = _make_non_seedvr2_vae_fallback() - vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8) - vae.upscale_ratio = (lambda a: a * 4, 8, 8) - generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) - pixel_samples = torch.zeros((1, 8, 64, 64, 3)) - - with patch.object(sd_mod.model_management, "load_models_gpu", - lambda *a, **k: None), \ - patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): - vae.encode_tiled(pixel_samples) - - assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64) diff --git a/tests-unit/comfy_test/test_seedvr_progressive_sampler.py b/tests-unit/comfy_test/test_seedvr_progressive_sampler.py deleted file mode 100644 index 05291989e..000000000 --- a/tests-unit/comfy_test/test_seedvr_progressive_sampler.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``.""" - -from unittest.mock import patch - -import pytest -import torch - -from comfy.cli_args import args as cli_args - -if not torch.cuda.is_available(): - cli_args.cpu = True - -import comfy.sample # noqa: E402 -import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402 -from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler # noqa: E402 - -_LAT_C = 16 -_COND_C = 17 - - -def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8): - """Build minimal SeedVR2-shaped sampling inputs.""" - samples_5d = torch.arange( - B * _LAT_C * T * H * W, dtype=torch.float32 - ).reshape(B, _LAT_C, T, H, W) - samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous() - - cond_5d = torch.arange( - B * _COND_C * T * H * W, dtype=torch.float32 - ).reshape(B, _COND_C, T, H, W) + 10000.0 - cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous() - - text_pos = torch.zeros(1, 4, 32) - text_neg = torch.zeros(1, 4, 32) - positive = [[text_pos, {"condition": cond.clone()}]] - negative = [[text_neg, {"condition": cond.clone()}]] - latent_image = {"samples": samples} - return latent_image, positive, negative, samples_5d, cond_5d - - -def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None): - return latent_image - - -def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None): - """Return a tensor whose values encode ``(seed, position)``.""" - base = torch.arange( - latent_image.numel(), dtype=torch.float32 - ).reshape(latent_image.shape) - return base + float(seed) * 1e6 - - -def test_progressive_sampler_schema_exposes_manual_default_auto_chunking(): - schema = SeedVR2ProgressiveSampler.define_schema() - inputs = {item.id: item for item in schema.inputs} - - assert inputs["chunking_mode"].options == ["manual", "auto"] - assert inputs["chunking_mode"].default == "manual" - - -def test_auto_chunking_walks_two_three_four_chunk_ladder(): - """Auto mode must walk 2-, 3-, then 4-chunk geometries on OOM.""" - latent, pos, neg, _, _ = _make_inputs(T=17) - calls = [] - - def _oom_until_four_chunks(model, noise, steps, cfg, sampler_name, - scheduler, positive, negative, - latent_image, denoise=1.0, - noise_mask=None, seed=None): - calls.append(tuple(latent_image.shape)) - if latent_image.shape[1] > _LAT_C * 5: - raise torch.cuda.OutOfMemoryError("chunk too large") - return latent_image.clone() - - with patch.object(comfy.sample, "sample", - side_effect=_oom_until_four_chunks), \ - patch.object(comfy.sample, "fix_empty_latent_channels", - side_effect=_identity_fix_empty), \ - patch.object(comfy.sample, "prepare_noise", - side_effect=_fingerprinted_prepare_noise), \ - patch.object(nodes_seedvr_mod.comfy.model_management, - "soft_empty_cache") as soft_empty: - out = SeedVR2ProgressiveSampler.execute( - model=None, seed=0, steps=2, cfg=1.0, - sampler_name="euler", scheduler="simple", - positive=pos, negative=neg, latent=latent, - denoise=1.0, frames_per_chunk=65, temporal_overlap=0, - chunking_mode="auto", - ) - - assert calls[:4] == [ - (1, _LAT_C * 17, 8, 8), - (1, _LAT_C * 9, 8, 8), - (1, _LAT_C * 6, 8, 8), - (1, _LAT_C * 5, 8, 8), - ] - assert torch.equal(out.result[0]["samples"], latent["samples"]) - assert soft_empty.call_count == 3 - - -@pytest.mark.parametrize("bad_chunk", [0, -1, 2]) -def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk): - """``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation.""" - latent, pos, neg, _, _ = _make_inputs(T=5) - - sampler_called = {"n": 0} - - def _should_not_be_called(*args, **kwargs): - sampler_called["n"] += 1 - return torch.zeros(1) - - with patch.object(comfy.sample, "sample", - side_effect=_should_not_be_called), \ - patch.object(comfy.sample, "fix_empty_latent_channels", - side_effect=_identity_fix_empty), \ - patch.object(comfy.sample, "prepare_noise", - side_effect=_fingerprinted_prepare_noise): - with pytest.raises(ValueError) as excinfo: - SeedVR2ProgressiveSampler.execute( - model=None, seed=0, steps=2, cfg=1.0, - sampler_name="euler", scheduler="simple", - positive=pos, negative=neg, latent=latent, - denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0, - ) - assert str(bad_chunk) in str(excinfo.value) - assert sampler_called["n"] == 0