Revert "Add SeedVR2 support (CORE-6) (#14110)" (#14359)
Some checks are pending
Detect Unreviewed Merge / detect (push) Waiting to run
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run

This reverts commit 7863cf0e53.
This commit is contained in:
comfyanonymous 2026-06-08 15:00:20 -07:00 committed by GitHub
parent a0a055bc4e
commit 00b633f368
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 40 additions and 7383 deletions

View File

@ -4,7 +4,6 @@ class LatentFormat:
scale_factor = 1.0 scale_factor = 1.0
latent_channels = 4 latent_channels = 4
latent_dimensions = 2 latent_dimensions = 2
preserve_empty_channel_multiples = False
latent_rgb_factors = None latent_rgb_factors = None
latent_rgb_factors_bias = None latent_rgb_factors_bias = None
latent_rgb_factors_reshape = None latent_rgb_factors_reshape = None
@ -780,10 +779,6 @@ class ACEAudio(LatentFormat):
latent_channels = 8 latent_channels = 8
latent_dimensions = 2 latent_dimensions = 2
class SeedVR2(LatentFormat):
latent_channels = 16
preserve_empty_channel_multiples = True
class ACEAudio15(LatentFormat): class ACEAudio15(LatentFormat):
latent_channels = 64 latent_channels = 64
latent_dimensions = 1 latent_dimensions = 1

View File

@ -735,86 +735,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
) )
return out return out
def _var_attention_qkv(q, k, v, heads, skip_reshape):
if skip_reshape:
return q, k, v, q.shape[-1]
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
return (
q.view(total_tokens, heads, head_dim),
k.view(k.shape[0], heads, head_dim),
v.view(v.shape[0], heads, head_dim),
head_dim,
)
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
if skip_output_reshape:
return out
return out.reshape(-1, heads * head_dim)
def _use_blackwell_attention():
device = model_management.get_torch_device()
if device.type != "cuda":
return False
major, minor = torch.cuda.get_device_capability(device)
return (major, minor) >= (12, 0)
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
if cu_seqlens.dtype not in (torch.int32, torch.int64):
raise ValueError(f"{name} must use an integer dtype")
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
if cu_seqlens[0].item() != 0:
raise ValueError(f"{name} must start at 0")
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
raise ValueError(f"{name} must be strictly increasing")
if cu_seqlens[-1].item() != token_count:
raise ValueError(f"{name} does not match token count")
def _split_indices(cu_seqlens):
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
if cu_seqlens_k[-1].item() != v.shape[0]:
raise ValueError("cu_seqlens_k does not match v token count")
q_split_indices = _split_indices(cu_seqlens_q)
k_split_indices = _split_indices(cu_seqlens_k)
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")
out = []
for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
out_dtype = q_i.dtype
if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16):
q_i = q_i.to(torch.bfloat16)
k_i = k_i.to(torch.bfloat16)
v_i = v_i.to(torch.bfloat16)
out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
if out_i.dtype != out_dtype:
out_i = out_i.to(out_dtype)
out.append(out_i.squeeze(0).permute(1, 0, 2))
out = torch.cat(out, dim=0)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
optimized_var_attention = var_attention_optimized_split
optimized_attention = attention_basic optimized_attention = attention_basic
if model_management.sage_attention_enabled(): if model_management.sage_attention_enabled():
@ -837,8 +758,6 @@ else:
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention") logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad optimized_attention = attention_sub_quad
logging.info("Using optimized_attention split-loop for variable-length attention")
optimized_attention_masked = optimized_attention optimized_attention_masked = optimized_attention
@ -854,7 +773,6 @@ if model_management.xformers_enabled():
register_attention_function("pytorch", attention_pytorch) register_attention_function("pytorch", attention_pytorch)
register_attention_function("sub_quad", attention_sub_quad) register_attention_function("sub_quad", attention_sub_quad)
register_attention_function("split", attention_split) register_attention_function("split", attention_split)
register_attention_function("var_attention_optimized_split", var_attention_optimized_split)
def optimized_attention_for_device(device, mask=False, small_input=False): def optimized_attention_for_device(device, mask=False, small_input=False):
@ -1291,3 +1209,5 @@ class SpatialVideoTransformer(SpatialTransformer):
x = self.proj_out(x) x = self.proj_out(x)
out = x + x_in out = x + x_in
return out return out

View File

@ -13,7 +13,6 @@ if model_management.xformers_enabled_vae():
import xformers import xformers
import xformers.ops import xformers.ops
def torch_cat_if_needed(xl, dim): def torch_cat_if_needed(xl, dim):
xl = [x for x in xl if x is not None and x.shape[dim] > 0] xl = [x for x in xl if x is not None and x.shape[dim] > 0]
if len(xl) > 1: if len(xl) > 1:
@ -23,8 +22,7 @@ def torch_cat_if_needed(xl, dim):
else: else:
return None return None
def get_timestep_embedding(timesteps, embedding_dim):
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
""" """
This matches the implementation in Denoising Diffusion Probabilistic Models: This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq. From Fairseq.
@ -35,13 +33,11 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down
assert len(timesteps.shape) == 1 assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2 half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - downscale_freq_shift) emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device) emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :] emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
if embedding_dim % 2 == 1: # zero pad if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0)) emb = torch.nn.functional.pad(emb, (0,1,0,0))
return emb return emb

View File

@ -1,340 +0,0 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from comfy.ldm.seedvr.model import safe_pad_operation
from comfy.ldm.seedvr.vae import safe_interpolate_operation
from comfy.ldm.seedvr.constants import (
CIELAB_DELTA,
CIELAB_KAPPA,
D65_WHITE_X,
D65_WHITE_Z,
WAVELET_DECOMP_LEVELS,
)
def wavelet_blur(image: Tensor, radius):
max_safe_radius = max(1, min(image.shape[-2:]) // 8)
if radius > max_safe_radius:
radius = max_safe_radius
num_channels = image.shape[1]
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS):
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq.add_(image).sub_(low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
if content_feat.shape != style_feat.shape:
# Resize style to match content spatial dimensions
if len(content_feat.shape) >= 3:
# safe_interpolate_operation handles FP16 conversion automatically
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False
)
# Decompose both features into frequency components
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq # Free memory immediately
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq # Free memory immediately
if content_high_freq.shape != style_low_freq.shape:
style_low_freq = safe_interpolate_operation(
style_low_freq,
size=content_high_freq.shape[-2:],
mode='bilinear',
align_corners=False
)
content_high_freq.add_(style_low_freq)
return content_high_freq.clamp_(-1.0, 1.0)
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
original_shape = source.shape
# Flatten
source_flat = source.flatten()
reference_flat = reference.flatten()
# Sort both arrays
source_sorted, source_indices = torch.sort(source_flat)
reference_sorted, _ = torch.sort(reference_flat)
del reference_flat
# Quantile mapping
n_source = len(source_sorted)
n_reference = len(reference_sorted)
if n_source == n_reference:
matched_sorted = reference_sorted
else:
# Interpolate reference to match source quantiles
source_quantiles = torch.linspace(0, 1, n_source, device=device)
ref_indices = (source_quantiles * (n_reference - 1)).long()
ref_indices.clamp_(0, n_reference - 1)
matched_sorted = reference_sorted[ref_indices]
del source_quantiles, ref_indices, reference_sorted
del source_sorted, source_flat
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
inverse_indices = torch.argsort(source_indices)
del source_indices
matched_flat = matched_sorted[inverse_indices]
del matched_sorted, inverse_indices
return matched_flat.reshape(original_shape)
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of CIELAB images to RGB color space."""
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
# LAB to XYZ
fy = (L + 16.0) / 116.0
fx = a.div(500.0).add_(fy)
fz = fy - b / 200.0
del L, a, b
# XYZ transformation
x = torch.where(
fx > epsilon,
torch.pow(fx, 3.0),
fx.mul(116.0).sub_(16.0).div_(kappa)
)
y = torch.where(
fy > epsilon,
torch.pow(fy, 3.0),
fy.mul(116.0).sub_(16.0).div_(kappa)
)
z = torch.where(
fz > epsilon,
torch.pow(fz, 3.0),
fz.mul(116.0).sub_(16.0).div_(kappa)
)
del fx, fy, fz
# Apply D65 white point (in-place)
x.mul_(D65_WHITE_X)
# y *= 1.00000 # (no-op, skip)
z.mul_(D65_WHITE_Z)
xyz = torch.stack([x, y, z], dim=1)
del x, y, z
# Matrix multiplication: XYZ -> RGB
B, C, H, W = xyz.shape
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
del xyz
# Ensure dtype consistency for matrix multiplication
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
del xyz_flat
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del rgb_linear_flat
# Apply inverse gamma correction (delinearize)
mask = rgb_linear > 0.0031308
rgb = torch.where(
mask,
torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055),
rgb_linear * 12.92
)
del mask, rgb_linear
return torch.clamp(rgb, 0.0, 1.0)
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
# Apply sRGB gamma correction (linearize)
mask = rgb > 0.04045
rgb_linear = torch.where(
mask,
torch.pow((rgb + 0.055) / 1.055, 2.4),
rgb / 12.92
)
del mask
# Matrix multiplication: RGB -> XYZ
B, C, H, W = rgb_linear.shape
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
del rgb_linear
# Ensure dtype consistency for matrix multiplication
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
xyz_flat = torch.matmul(rgb_flat, matrix.T)
del rgb_flat
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del xyz_flat
# Normalize by D65 white point (in-place)
xyz[:, 0].div_(D65_WHITE_X) # X
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
xyz[:, 2].div_(D65_WHITE_Z) # Z
# XYZ to LAB transformation
epsilon_cubed = epsilon ** 3
mask = xyz > epsilon_cubed
f_xyz = torch.where(
mask,
torch.pow(xyz, 1.0 / 3.0),
xyz.mul(kappa).add_(16.0).div_(116.0)
)
del xyz, mask
# Extract channels and compute LAB
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
del f_xyz
return torch.stack([L, a, b], dim=1)
def lab_color_transfer(
content_feat: Tensor,
style_feat: Tensor,
luminance_weight: float = 0.8
) -> Tensor:
content_feat = wavelet_reconstruction(content_feat, style_feat)
if content_feat.shape != style_feat.shape:
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False
)
device = content_feat.device
def ensure_float32_precision(c):
orig_dtype = c.dtype
c = c.float()
return c, orig_dtype
content_feat, original_dtype = ensure_float32_precision(content_feat)
style_feat, _ = ensure_float32_precision(style_feat)
rgb_to_xyz_matrix = torch.tensor([
[0.4124564, 0.3575761, 0.1804375],
[0.2126729, 0.7151522, 0.0721750],
[0.0193339, 0.1191920, 0.9503041]
], dtype=torch.float32, device=device)
xyz_to_rgb_matrix = torch.tensor([
[ 3.2404542, -1.5371385, -0.4985314],
[-0.9692660, 1.8760108, 0.0415560],
[ 0.0556434, -0.2040259, 1.0572252]
], dtype=torch.float32, device=device)
epsilon = CIELAB_DELTA
kappa = CIELAB_KAPPA
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
# Convert to LAB color space
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
del content_feat
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
del style_feat, rgb_to_xyz_matrix
# Match chrominance channels (a*, b*) for accurate color transfer
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
# Handle luminance with weighted blending
if luminance_weight < 1.0:
# Partially match luminance for better overall color accuracy
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
# Blend: preserve some content L* for detail, adopt some style L* for color
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
del matched_L
else:
# Fully preserve content luminance
result_L = content_lab[:, 0]
del content_lab, style_lab
# Reconstruct LAB with corrected channels
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
del result_L, matched_a, matched_b
# Convert back to RGB
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
del result_lab, xyz_to_rgb_matrix
# Convert back to [-1, 1] range (in-place)
result = result_rgb.mul_(2.0).sub_(1.0)
del result_rgb
result = result.to(original_dtype)
return result
def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor:
return wavelet_reconstruction(content_feat, style_feat)
def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor:
if content_feat.shape != style_feat.shape:
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False,
)
original_dtype = content_feat.dtype
content_feat = content_feat.float()
style_feat = style_feat.float()
b, c = content_feat.shape[:2]
content_flat = content_feat.reshape(b, c, -1)
style_flat = style_feat.reshape(b, c, -1)
content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1)
content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1)
style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
del content_flat, style_flat
normalized = (content_feat - content_mean) / content_std
del content_mean, content_std
result = normalized * style_std + style_mean
del normalized, style_mean, style_std
result = result.clamp_(-1.0, 1.0)
if result.dtype != original_dtype:
result = result.to(original_dtype)
return result

View File

@ -1,79 +0,0 @@
"""Named constants for the SeedVR2 integration, grouped by provenance.
Provenance prefixes:
- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline.
- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites
the upstream config/source path it was lifted from.
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
ISO / CIE values; cite the standard.
"""
# --------------------------------------------------------------------------------------
# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment)
# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN)
# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070
# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT).
# --------------------------------------------------------------------------------------
SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB)
SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB
# --------------------------------------------------------------------------------------
# B. Fork heuristics (SEEDVR2 - this integration)
# --------------------------------------------------------------------------------------
SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim.
# (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.)
SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry.
SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case).
SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM.
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk.
SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels).
SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16).
SEEDVR2_DEFAULT_TEMPORAL_SIZE = 16 # default VAE temporal tile when unset.
# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing)
SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk.
SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path.
SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path.
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
# --------------------------------------------------------------------------------------
# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR)
# --------------------------------------------------------------------------------------
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem).
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem).
BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16).
BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32).
BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD = 64 # attn_video_vae.py:308 (force .contiguous() above this b*t).
BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11.
BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size).
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor.
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor).
BYTEDANCE_SCHEDULE_T = 1000.0 # configs_3b/main.yaml:65 (schedule.T); timestep range.
BYTEDANCE_SPATIAL_DIVISOR = 16 # inference_seedvr2_3b.py:241 (DivisibleCrop((16,16))).
BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling).
BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames).
BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency).
BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim).
# Resolution-dependent timestep-shift linear fits: (x1, y1, x2, y2) for get_lin_function.
BYTEDANCE_IMG_SHIFT_FIT = (256 * 256, 1.0, 1024 * 1024, 3.2) # infer.py:242.
BYTEDANCE_VID_SHIFT_FIT = (256 * 256 * 37, 1.0, 1280 * 720 * 145, 5.0) # infer.py:243.
# --------------------------------------------------------------------------------------
# D. Published standards (cite the literature)
# --------------------------------------------------------------------------------------
ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864.
# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65).
CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta).
CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa).
D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1).
D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn.
WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR).
# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and
# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the
# exact existing coefficients move verbatim rather than being retyped here.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -54,8 +54,6 @@ import comfy.ldm.pixeldit.model
import comfy.ldm.pixeldit.pid import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2 import comfy.ldm.omnigen.omnigen2
import comfy.ldm.seedvr.model
import comfy.ldm.qwen_image.model import comfy.ldm.qwen_image.model
import comfy.ldm.ideogram4.model import comfy.ldm.ideogram4.model
import comfy.ldm.kandinsky5.model import comfy.ldm.kandinsky5.model
@ -930,16 +928,6 @@ class HunyuanDiT(BaseModel):
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]])) out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out return out
class SeedVR2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
condition = kwargs.get("condition", None)
if condition is not None:
out["condition"] = comfy.conds.CONDRegular(condition)
return out
class PixArt(BaseModel): class PixArt(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)

View File

@ -598,56 +598,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config return dit_config
if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
# 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.``
# submodules) at EVERY block — verified by inspecting the 7B
# state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means
# ``MMModule.shared_weights=False``). Native NaDiT computes
# per-block ``shared_weights = not (i < mm_layers)``, so to keep
# every block non-shared we set ``mm_layers = num_layers``.
# Without this, blocks at index >= mm_layers (default 10) try to
# load ``blocks.N.*.all.*`` keys that don't exist in the file,
# silently miss-load → all-black output.
dit_config["mm_layers"] = 36
dit_config["norm_eps"] = 1e-5
dit_config["qk_rope"] = True
dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "normal"
return dit_config
elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
# This checkpoint layout carries shared ``all.`` MMModule keys.
# Preserve the historical split: the initial blocks use separate
# vid/txt modules, later blocks use shared modules.
dit_config["mm_layers"] = 10
dit_config["norm_eps"] = 1e-5
dit_config["qk_rope"] = True
dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "swiglu"
return dit_config
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 2560
dit_config["heads"] = 20
dit_config["num_layers"] = 32
dit_config["norm_eps"] = 1.0e-05
dit_config["qk_rope"] = None
dit_config["mlp_type"] = "swiglu"
dit_config["vid_out_norm"] = True
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {} dit_config = {}
dit_config["image_model"] = "wan2.1" dit_config["image_model"] = "wan2.1"

View File

@ -44,13 +44,7 @@ def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None,
is_empty = torch.count_nonzero(latent_image) == 0 is_empty = torch.count_nonzero(latent_image) == 0
if is_empty: if is_empty:
if latent_format.latent_channels != latent_image.shape[1]: if latent_format.latent_channels != latent_image.shape[1]:
preserves_collapsed_channels = ( latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
getattr(latent_format, "preserve_empty_channel_multiples", False)
and latent_image.ndim == 4
and latent_image.shape[1] % latent_format.latent_channels == 0
)
if not preserves_collapsed_channels:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if downscale_ratio_spacial is not None: if downscale_ratio_spacial is not None:
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio: if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio

View File

@ -1,4 +1,3 @@
import inspect
import json import json
import torch import torch
from enum import Enum from enum import Enum
@ -17,7 +16,6 @@ import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2 import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae import comfy.ldm.hunyuan3d.vae
import comfy.ldm.seedvr.vae
import comfy.ldm.triposplat.vae import comfy.ldm.triposplat.vae
import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae import comfy.ldm.cogvideo.vae
@ -86,36 +84,6 @@ import comfy.latent_formats
import comfy.ldm.flux.redux import comfy.ldm.flux.redux
SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL = 160
def _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w):
output_t = max(1, (latent_t - 1) * 4 + 1)
return output_t * latent_h * 8 * latent_w * 8
def _seedvr2_vae_decode_memory_used(shape):
if len(shape) == 5:
candidates = []
if shape[1] == 16:
candidates.append((shape[2], shape[3], shape[4]))
if shape[-1] == 16:
candidates.append((shape[1], shape[2], shape[3]))
if len(candidates) == 0:
candidates.append((shape[2], shape[3], shape[4]))
output_pixels = max(_seedvr2_vae_decode_output_pixels(*candidate) for candidate in candidates)
elif len(shape) == 4:
latent_t = max(1, (shape[1] + 15) // 16)
latent_h, latent_w = shape[2], shape[3]
output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w)
else:
latent_t, latent_h, latent_w = 1, shape[-2], shape[-1]
output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w)
# SeedVR2 decode performs full-frame LAB histogram matching: fp32 channels
# plus int64 sort indices dominate peak memory, not the VAE weight dtype.
return output_pixels * SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None): def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
key_map = {} key_map = {}
if model is not None: if model is not None:
@ -499,10 +467,8 @@ class CLIP:
class VAE: class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format sd = diffusers_convert.convert_vae_state_dict(sd)
if metadata is None or metadata.get("keep_diffusers_format") != "true":
sd = diffusers_convert.convert_vae_state_dict(sd)
if model_management.is_amd(): if model_management.is_amd():
VAE_KL_MEM_RATIO = 2.73 VAE_KL_MEM_RATIO = 2.73
@ -574,20 +540,6 @@ class VAE:
self.first_stage_model = StageC_coder() self.first_stage_model = StageC_coder()
self.downscale_ratio = 32 self.downscale_ratio = 32
self.latent_channels = 16 self.latent_channels = 16
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
self.latent_channels = 16
self.latent_dim = 3
self.disable_offload = True
self.memory_used_decode = lambda shape, dtype: _seedvr2_vae_decode_memory_used(shape)
self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.process_input = lambda image: image * 2.0 - 1.0
self.crop_input = False
elif "decoder.conv_in.weight" in sd: elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64: if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
@ -715,7 +667,6 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32) self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
@ -1055,40 +1006,6 @@ class VAE:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def decode_tiled_seedvr2(self, samples, tile_x=32, tile_y=32, overlap=8, tile_t=16, overlap_t=4):
sf_s = getattr(self.first_stage_model, "spatial_downsample_factor", 8)
sf_t = getattr(self.first_stage_model, "temporal_downsample_factor", 4)
if tile_t is None:
tile_t = 16
if overlap_t is None:
overlap_t = 4
if tile_t > 0:
temporal_size = tile_t * sf_t
temporal_overlap = max(0, overlap_t) * sf_t
else:
temporal_size = 0
temporal_overlap = 0
args = {
"enable_tiling": True,
"tile_size": (tile_y * sf_s, tile_x * sf_s),
"tile_overlap": (overlap * sf_s, overlap * sf_s),
"temporal_size": temporal_size,
"temporal_overlap": temporal_overlap,
}
output = self.first_stage_model.decode(
samples.to(self.vae_dtype).to(self.device),
seedvr2_tiling=args,
)
return self.process_output(output.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
def _format_seedvr2_encoded_samples(self, samples):
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
if samples.ndim == 4:
samples = samples.unsqueeze(2)
samples = samples.contiguous()
samples = samples * 0.9152
return samples
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@ -1125,36 +1042,6 @@ class VAE:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def encode_tiled_seedvr2(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
if tile_y is None:
tile_y = 512
if tile_x is None:
tile_x = 512
if overlap is None:
overlap_y = 64
overlap_x = 64
else:
overlap_y = overlap
overlap_x = overlap
if tile_t is None:
tile_t = 9999
if overlap_t is None:
overlap_t = 0
overlap_y = min(overlap_y, max(0, tile_y - 8))
overlap_x = min(overlap_x, max(0, tile_x - 8))
self.first_stage_model.device = self.device
x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device)
output = comfy.ldm.seedvr.vae.tiled_vae(
x,
self.first_stage_model,
tile_size=(tile_y, tile_x),
tile_overlap=(overlap_y, overlap_x),
temporal_size=tile_t,
temporal_overlap=overlap_t,
encode=True,
)
return output.to(device=self.output_device, dtype=self.vae_output_dtype())
def decode(self, samples_in, vae_options={}): def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid() self.throw_exception_if_invalid()
pixel_samples = None pixel_samples = None
@ -1202,40 +1089,16 @@ class VAE:
if dims == 1 or self.extra_1d_channel is not None: if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in) pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2: elif dims == 2:
# SeedVR2 latents arrive in 4D collapsed form ``(B, 16*T, H, W)`` pixel_samples = self.decode_tiled_(samples_in)
# downstream of ``SeedVR2Conditioning`` (which performs the
# ``rearrange(b c t h w -> b (c t) h w)`` collapse). The
# generic ``decode_tiled_`` would treat the channel dim as
# spatial-only and crash on the collapsed (16, T) layout
# under ``tiled_scale``'s mask broadcast; route SeedVR2 4D
# latents to ``decode_tiled_seedvr2`` instead, whose wrapper
# dispatch handles both 4D and 5D inputs.
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
else:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3: elif dims == 3:
tile = 256 // self.spacial_compression_decode() tile = 256 // self.spacial_compression_decode()
overlap = tile // 4 overlap = tile // 4
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
else:
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples return pixel_samples
def decode_tiled( def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self,
samples,
tile_x=None,
tile_y=None,
overlap=None,
tile_t=None,
overlap_t=None,
):
self.throw_exception_if_invalid() self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -1249,20 +1112,7 @@ class VAE:
args["overlap"] = overlap args["overlap"] = overlap
with model_management.cuda_device_context(self.device): with model_management.cuda_device_context(self.device):
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper) and dims in (2, 3): if dims == 1 or self.extra_1d_channel is not None:
seedvr2_args = {}
if tile_x is not None:
seedvr2_args["tile_x"] = tile_x
if tile_y is not None:
seedvr2_args["tile_y"] = tile_y
if overlap is not None:
seedvr2_args["overlap"] = overlap
if tile_t is not None:
seedvr2_args["tile_t"] = tile_t
if overlap_t is not None:
seedvr2_args["overlap_t"] = overlap_t
output = self.decode_tiled_seedvr2(samples, **seedvr2_args)
elif dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y") args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args) output = self.decode_tiled_1d(samples, **args)
elif dims == 2: elif dims == 2:
@ -1304,8 +1154,6 @@ class VAE:
else: else:
pixels_in = pixels_in.to(self.device) pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in) out = self.first_stage_model.encode(pixels_in)
if isinstance(out, tuple):
out = out[0]
out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None: if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
@ -1325,23 +1173,20 @@ class VAE:
if self.latent_dim == 3: if self.latent_dim == 3:
tile = 256 tile = 256
overlap = tile // 4 overlap = tile // 4
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
samples = self.encode_tiled_seedvr2(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap)
else:
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1 or self.extra_1d_channel is not None: elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples) samples = self.encode_tiled_1d(pixel_samples)
else: else:
samples = self.encode_tiled_(pixel_samples) samples = self.encode_tiled_(pixel_samples)
return self._format_seedvr2_encoded_samples(samples) return samples
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid() self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
dims = self.latent_dim dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1) pixel_samples = pixel_samples.movedim(-1, 1)
if dims == 3 and pixel_samples.ndim < 5: if dims == 3:
if not self.not_video: if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
else: else:
@ -1365,47 +1210,22 @@ class VAE:
elif dims == 2: elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args) samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3: elif dims == 3:
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): if tile_t is not None:
seedvr2_args = {} tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
if tile_x is not None:
seedvr2_args["tile_x"] = tile_x
else:
seedvr2_args["tile_x"] = 512
if tile_y is not None:
seedvr2_args["tile_y"] = tile_y
else:
seedvr2_args["tile_y"] = 512
if overlap is not None:
seedvr2_args["overlap"] = overlap
else:
seedvr2_args["overlap"] = 64
if tile_t is not None:
seedvr2_args["tile_t"] = tile_t
else:
seedvr2_args["tile_t"] = 9999
if overlap_t is not None:
seedvr2_args["overlap_t"] = overlap_t
else:
seedvr2_args["overlap_t"] = 0
samples = self.encode_tiled_seedvr2(pixel_samples, **seedvr2_args)
else: else:
if tile_t is not None: tile_t_latent = 9999
tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
spatial_overlap = overlap if overlap is not None else 64 if overlap_t is None:
if overlap_t is None: args["overlap"] = (1, overlap, overlap)
args["overlap"] = (1, spatial_overlap, spatial_overlap) else:
else: args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap) maximum = pixel_samples.shape[2]
maximum = pixel_samples.shape[2] maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
return self._format_seedvr2_encoded_samples(samples) return samples
def get_sd(self): def get_sd(self):
return self.first_stage_model.state_dict() return self.first_stage_model.state_dict()
@ -1932,17 +1752,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae) return (model, clip, vae)
def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device):
set_dtype = model_config.set_inference_dtype
parameters = inspect.signature(set_dtype).parameters
supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values())
if supports_device:
set_dtype(dtype, manual_cast_dtype, device=device)
else:
set_dtype(dtype, manual_cast_dtype)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
@ -2050,7 +1859,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else: else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if model_config.clip_vision_prefix is not None: if model_config.clip_vision_prefix is not None:
if output_clipvision: if output_clipvision:
@ -2191,7 +2000,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else: else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if custom_operations is not None: if custom_operations is not None:
model_config.custom_operations = custom_operations model_config.custom_operations = custom_operations

View File

@ -1672,35 +1672,6 @@ class Chroma(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
class SeedVR2(supported_models_base.BASE):
unet_config = {
"image_model": "seedvr2"
}
latent_format = comfy.latent_formats.SeedVR2
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
sampling_settings = {
"shift": 1.0,
}
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
if (
dtype == torch.float16
and manual_cast_dtype is None
and comfy.model_management.should_use_bf16(device)
):
manual_cast_dtype = torch.bfloat16
super().set_inference_dtype(dtype, manual_cast_dtype, device=device)
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SeedVR2(self, device=device)
return out
def clip_target(self, state_dict={}):
return None
class ChromaRadiance(Chroma): class ChromaRadiance(Chroma):
unet_config = { unet_config = {
"image_model": "chroma_radiance", "image_model": "chroma_radiance",
@ -2058,6 +2029,7 @@ class LongCatImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
class RT_DETR_v4(supported_models_base.BASE): class RT_DETR_v4(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "RT_DETR_v4", "image_model": "RT_DETR_v4",
@ -2295,7 +2267,6 @@ models = [
HiDream, HiDream,
HiDreamO1, HiDreamO1,
Chroma, Chroma,
SeedVR2,
ChromaRadiance, ChromaRadiance,
ACEStep, ACEStep,
ACEStep15, ACEStep15,

View File

@ -115,7 +115,7 @@ class BASE:
replace_prefix = {"": self.vae_key_prefix[0]} replace_prefix = {"": self.vae_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None): def set_inference_dtype(self, dtype, manual_cast_dtype):
self.unet_config['dtype'] = dtype self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype self.manual_cast_dtype = manual_cast_dtype

File diff suppressed because it is too large Load Diff

View File

@ -47,18 +47,14 @@ import node_helpers
if args.enable_manager: if args.enable_manager:
import comfyui_manager import comfyui_manager
def before_node_execution(): def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted() comfy.model_management.throw_exception_if_processing_interrupted()
def interrupt_processing(value=True): def interrupt_processing(value=True):
comfy.model_management.interrupt_current_processing(value) comfy.model_management.interrupt_current_processing(value)
MAX_RESOLUTION=16384 MAX_RESOLUTION=16384
class CLIPTextEncode(ComfyNodeABC): class CLIPTextEncode(ComfyNodeABC):
@classmethod @classmethod
def INPUT_TYPES(s) -> InputTypeDict: def INPUT_TYPES(s) -> InputTypeDict:
@ -327,8 +323,8 @@ class VAEDecodeTiled:
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32, "advanced": True}), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32, "advanced": True}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}),
"temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}), "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time.", "advanced": True}),
"temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}),
}} }}
RETURN_TYPES = ("IMAGE",) RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode" FUNCTION = "decode"
@ -338,32 +334,18 @@ class VAEDecodeTiled:
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8): def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
if tile_size < overlap * 4: if tile_size < overlap * 4:
overlap = tile_size // 4 overlap = tile_size // 4
if temporal_size < temporal_overlap * 2:
temporal_overlap = temporal_overlap // 2
temporal_compression = vae.temporal_compression_decode() temporal_compression = vae.temporal_compression_decode()
if temporal_compression is not None: if temporal_compression is not None:
if temporal_size <= 0: temporal_size = max(2, temporal_size // temporal_compression)
temporal_size = 0 temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression))
temporal_overlap = 0
else:
requested_temporal_overlap = temporal_overlap
if temporal_size < temporal_overlap * 2:
temporal_overlap = temporal_overlap // 2
temporal_size = max(2, temporal_size // temporal_compression)
temporal_overlap = min(temporal_size // 2, temporal_overlap // temporal_compression)
if requested_temporal_overlap > 0:
temporal_overlap = max(1, temporal_overlap)
else: else:
temporal_size = None temporal_size = None
temporal_overlap = None temporal_overlap = None
compression = vae.spacial_compression_decode() compression = vae.spacial_compression_decode()
images = vae.decode_tiled( images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap)
samples["samples"],
tile_x=tile_size // compression,
tile_y=tile_size // compression,
overlap=overlap // compression,
tile_t=temporal_size,
overlap_t=temporal_overlap,
)
if len(images.shape) == 5: #Combine batches if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
return (images, ) return (images, )
@ -380,7 +362,7 @@ class VAEEncode:
def encode(self, vae, pixels): def encode(self, vae, pixels):
t = vae.encode(pixels) t = vae.encode(pixels)
return ({"samples": t}, ) return ({"samples":t}, )
class VAEEncodeTiled: class VAEEncodeTiled:
@classmethod @classmethod
@ -388,8 +370,8 @@ class VAEEncodeTiled:
return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ), return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64, "advanced": True}), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64, "advanced": True}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}),
"temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}), "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time.", "advanced": True}),
"temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}),
}} }}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "encode" FUNCTION = "encode"
@ -397,9 +379,6 @@ class VAEEncodeTiled:
CATEGORY = "experimental" CATEGORY = "experimental"
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8): def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
if temporal_size <= 0:
temporal_size = 0
temporal_overlap = 0
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
return ({"samples": t}, ) return ({"samples": t}, )
@ -2439,7 +2418,6 @@ async def init_builtin_extra_nodes():
"nodes_camera_trajectory.py", "nodes_camera_trajectory.py",
"nodes_edit_model.py", "nodes_edit_model.py",
"nodes_tcfg.py", "nodes_tcfg.py",
"nodes_seedvr.py",
"nodes_context_windows.py", "nodes_context_windows.py",
"nodes_qwen.py", "nodes_qwen.py",
"nodes_chroma_radiance.py", "nodes_chroma_radiance.py",

View File

@ -1,213 +0,0 @@
"""Consolidated SeedVR2 conditioning and refactor regression tests.
Merges the prior test_seedvr2_refactor_nodes.py and
test_seedvr_conditioning_hardening.py modules. Refactor tests use the
top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests
use _import_nodes_seedvr_isolated() for sys.modules isolation when
mocking comfy.model_management.
"""
import importlib
import sys
from unittest.mock import MagicMock
import pytest
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
_SENTINEL = object()
_TARGETS = (
("comfy.model_management", "comfy"),
("comfy_extras.nodes_seedvr", "comfy_extras"),
)
def _import_nodes_seedvr_isolated():
"""Import comfy_extras.nodes_seedvr with comfy.model_management mocked."""
priors = []
for mod_name, parent_name in _TARGETS:
prior_mod = sys.modules.get(mod_name, _SENTINEL)
parent = sys.modules.get(parent_name)
attr = mod_name.split(".")[-1]
prior_attr = (
getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL
)
priors.append((mod_name, parent_name, attr, prior_mod, prior_attr))
mock_mm = MagicMock()
for fn in (
"xformers_enabled", "xformers_enabled_vae",
"pytorch_attention_enabled", "pytorch_attention_enabled_vae",
"sage_attention_enabled", "flash_attention_enabled",
"is_intel_xpu",
):
getattr(mock_mm, fn).return_value = False
tv = torch.version.__version__.split(".")
mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1]))
mock_mm.WINDOWS = False
sys.modules["comfy.model_management"] = mock_mm
if sys.modules.get("comfy") is None:
import comfy as _comfy_pkg # noqa: F401
comfy_pkg = sys.modules.get("comfy")
if comfy_pkg is not None:
setattr(comfy_pkg, "model_management", mock_mm)
nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or (
importlib.import_module("comfy_extras.nodes_seedvr")
)
def _restore():
for mod_name, parent_name, attr, prior_mod, prior_attr in priors:
if prior_mod is _SENTINEL:
sys.modules.pop(mod_name, None)
else:
sys.modules[mod_name] = prior_mod
parent = sys.modules.get(parent_name)
if parent is None:
continue
if prior_attr is _SENTINEL:
if hasattr(parent, attr):
delattr(parent, attr)
else:
setattr(parent, attr, prior_attr)
return nodes_seedvr, _restore
class _Rope(nn.Module):
"""Minimal RoPE stub exposing a `freqs` parameter."""
def __init__(self):
super().__init__()
self.freqs = nn.Parameter(torch.zeros(4))
class _Block(nn.Module):
"""Minimal transformer block stub holding a `_Rope`."""
def __init__(self):
super().__init__()
self.rope = _Rope()
class _DiffusionModel(nn.Module):
"""Stub diffusion model with N blocks and pos/neg conditioning buffers."""
def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32):
super().__init__()
self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])
pos = torch.zeros if zero_conditioning else torch.ones
self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype))
self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype))
class _ModelInner:
"""Inner model wrapper exposing `.diffusion_model`."""
def __init__(self, diffusion_model):
self.diffusion_model = diffusion_model
class _ModelPatcher:
"""ModelPatcher stub exposing `.model._ModelInner`."""
def __init__(self, diffusion_model):
self.model = _ModelInner(diffusion_model)
def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
schema = nodes_seedvr.SeedVR2Conditioning.define_schema()
assert [input_item.id for input_item in schema.inputs] == [
"model",
"vae_conditioning",
]
assert schema.inputs[1].display_name == "latent"
assert [output.display_name for output in schema.outputs] == [
"model",
"positive",
"negative",
"latent",
]
finally:
restore()
def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
diffusion_model = _DiffusionModel()
patcher = _ModelPatcher(diffusion_model)
samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2)
vae_conditioning = {"samples": samples}
_, first_positive, first_negative, first_latent = (
nodes_seedvr.SeedVR2Conditioning.execute(
patcher,
vae_conditioning,
)
)
_, second_positive, second_negative, second_latent = (
nodes_seedvr.SeedVR2Conditioning.execute(
patcher,
vae_conditioning,
)
)
expected_latent = samples.reshape(1, 6, 2, 2)
channel_last = samples.movedim(1, -1).contiguous()
expected_condition = torch.cat(
[
channel_last,
torch.ones((*channel_last.shape[:-1], 1)),
],
dim=-1,
).movedim(-1, 1).reshape(1, 9, 2, 2)
assert torch.equal(first_latent["samples"], expected_latent)
assert torch.equal(second_latent["samples"], expected_latent)
assert torch.equal(
first_positive[0][1]["condition"],
expected_condition,
)
assert torch.equal(
second_positive[0][1]["condition"],
expected_condition,
)
assert torch.equal(
first_negative[0][1]["condition"],
expected_condition,
)
assert torch.equal(
second_negative[0][1]["condition"],
expected_condition,
)
finally:
restore()
def test_seedvr2_conditioning_fails_loud_on_zero_buffers():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
diffusion_model = _DiffusionModel(zero_conditioning=True)
patcher = _ModelPatcher(diffusion_model)
vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}
with pytest.raises(RuntimeError) as excinfo:
nodes_seedvr.SeedVR2Conditioning.execute(
patcher, vae_conditioning,
)
message = str(excinfo.value)
assert message.startswith(
nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX
), (
"Fail-loud message must use the standard "
"_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers "
f"can match it. Got: {message!r}"
)
assert "positive_conditioning" in message
assert "negative_conditioning" in message
finally:
restore()

View File

@ -1,55 +0,0 @@
import importlib
import inspect
import sys
from unittest.mock import MagicMock, patch
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
def test_seedvr_node_signature_matches_schema():
mock_mm = MagicMock()
mock_mm.xformers_enabled.return_value = False
mock_mm.xformers_enabled_vae.return_value = False
mock_mm.sage_attention_enabled.return_value = False
mock_mm.flash_attention_enabled.return_value = False
sentinel = object()
prior_cpu = cli_args.cpu
cli_args.cpu = True
prior_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel)
comfy_pkg = sys.modules.get("comfy")
prior_mm_attr = getattr(comfy_pkg, "model_management", sentinel) if comfy_pkg else sentinel
with patch.dict(sys.modules, {"comfy.model_management": mock_mm}):
if comfy_pkg is not None:
setattr(comfy_pkg, "model_management", mock_mm)
sys.modules.pop("comfy_extras.nodes_seedvr", None)
try:
nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler):
schema_ids = [i.id for i in node_cls.define_schema().inputs]
exec_params = [
p for p in inspect.signature(node_cls.execute).parameters.keys()
if p != "cls"
]
assert schema_ids == exec_params, (
f"{node_cls.__name__} schema/execute drift: "
f"schema_ids={schema_ids}, exec_params={exec_params}"
)
finally:
cli_args.cpu = prior_cpu
if prior_module is sentinel:
sys.modules.pop("comfy_extras.nodes_seedvr", None)
else:
sys.modules["comfy_extras.nodes_seedvr"] = prior_module
if comfy_pkg is not None:
if prior_mm_attr is sentinel:
if hasattr(comfy_pkg, "model_management"):
delattr(comfy_pkg, "model_management")
else:
setattr(comfy_pkg, "model_management", prior_mm_attr)

View File

@ -1,57 +0,0 @@
from unittest.mock import patch
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
from comfy_extras import nodes_seedvr # noqa: E402
def _schema_ids(items):
return [item.id for item in items]
def test_seedvr2_post_processing_schema():
schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()
assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"]
assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"]
assert schema.inputs[2].default == "lab"
assert schema.outputs[0].get_io_type() == "IMAGE"
def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch):
decoded = torch.full((1, 3, 4, 4), 0.25)
reference = torch.full((1, 3, 4, 4), 0.75)
def _lab(content, style):
raise torch.cuda.OutOfMemoryError("CUDA out of memory")
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000)
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None)
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
try:
nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
decoded, reference, torch.device("cpu"), "lab",
)
except RuntimeError as exc:
assert "color_correction_method=lab" in str(exc)
assert " method=lab" not in str(exc)
else:
raise AssertionError("expected RuntimeError for one-frame LAB OOM")
def test_seedvr2_post_processing_unknown_color_correction_method_raises():
decoded = torch.zeros(1, 2, 4, 4, 3)
original = torch.zeros(1, 2, 4, 4, 3)
try:
nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus")
except ValueError as exc:
assert "color_correction_method" in str(exc)
else:
raise AssertionError("expected ValueError for unknown color_correction_method")

View File

@ -73,24 +73,6 @@ def _make_flux_schnell_comfyui_sd():
return sd return sd
def _make_seedvr2_7b_separate_mm_sd():
return {
"blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072),
}
def _make_seedvr2_7b_shared_mm_sd():
return {
"blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
}
def _make_seedvr2_3b_shared_mm_sd():
return {
"blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
}
class TestModelDetection: class TestModelDetection:
"""Verify that first-match model detection selects the correct model """Verify that first-match model detection selects the correct model
based on list ordering and unet_config specificity.""" based on list ordering and unet_config specificity."""
@ -143,48 +125,6 @@ class TestModelDetection:
assert model_config is not None assert model_config is not None
assert type(model_config).__name__ == "FluxSchnell" assert type(model_config).__name__ == "FluxSchnell"
def test_seedvr2_7b_separate_mm_detection_config(self):
sd = _make_seedvr2_7b_separate_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 3072
assert unet_config["heads"] == 24
assert unet_config["num_layers"] == 36
assert unet_config["mm_layers"] == 36
assert unet_config["mlp_type"] == "normal"
assert unet_config["qk_rope"] is True
assert unet_config["rope_type"] == "rope3d"
assert unet_config["rope_dim"] == 64
def test_seedvr2_7b_shared_mm_detection_config(self):
sd = _make_seedvr2_7b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 3072
assert unet_config["heads"] == 24
assert unet_config["num_layers"] == 36
assert unet_config["mm_layers"] == 10
assert unet_config["mlp_type"] == "swiglu"
assert unet_config["qk_rope"] is True
assert unet_config["rope_type"] == "rope3d"
assert unet_config["rope_dim"] == 64
def test_seedvr2_3b_shared_mm_detection_config(self):
sd = _make_seedvr2_3b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 2560
assert unet_config["heads"] == 20
assert unet_config["num_layers"] == 32
assert unet_config["mlp_type"] == "swiglu"
assert unet_config["qk_rope"] is None
def test_unet_config_and_required_keys_combination_is_unique(self): def test_unet_config_and_required_keys_combination_is_unique(self):
"""Each model in the registry must have a unique combination of """Each model in the registry must have a unique combination of
``unet_config`` and ``required_keys``. If two models share the same ``unet_config`` and ``required_keys``. If two models share the same

View File

@ -1,90 +0,0 @@
"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must
honor the actual tensor/tuple return contract of ``encode()`` and
``decode_()`` and must NOT dereference diffusers-style ``.latent_dist``
or ``.sample`` attributes on those returns.
The pre-fix body raised ``AttributeError: 'Tensor' object has no
attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and
``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'``
for ``mode == "decode"`` (the class only defines ``decode_`` with a
trailing underscore). The post-fix body unwraps the optional one-element
tuple shape that ``return_dict=False`` produces and returns the tensor
directly.
Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses
the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and
overrides ``encode``/``decode_`` with known tensors so the contract can
be probed without loading any real VAE weights.
"""
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402
_LATENT_SHAPE = (1, 16, 2, 2, 2)
_DECODED_SHAPE = (1, 3, 5, 16, 16)
_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16)
_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2)
class _StubVAE(VideoAutoencoderKL):
def __init__(self):
nn.Module.__init__(self)
self._encode_out = torch.zeros(*_LATENT_SHAPE)
self._decode_out = torch.zeros(*_DECODED_SHAPE)
def encode(self, x, return_dict=True):
return self._encode_out
def decode_(self, z, return_dict=True):
return self._decode_out
def test_forward_encode_returns_tensor():
vae = _StubVAE()
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
result = vae.forward(x, mode="encode")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_LATENT_SHAPE)
def test_forward_decode_returns_tensor():
vae = _StubVAE()
z = torch.zeros(*_INPUT_DECODE_SHAPE)
result = vae.forward(z, mode="decode")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_DECODED_SHAPE)
class _TupleReturningStubVAE(VideoAutoencoderKL):
"""Stub variant whose ``encode``/``decode_`` return the
``(tensor,)`` one-element tuple shape ``return_dict=False`` produces
in the parent class. Exercises the unwrap branch of
``VideoAutoencoderKL.forward``.
"""
def __init__(self):
nn.Module.__init__(self)
self._encode_tensor = torch.zeros(*_LATENT_SHAPE)
self._decode_tensor = torch.zeros(*_DECODED_SHAPE)
def encode(self, x, return_dict=True):
return (self._encode_tensor,)
def decode_(self, z, return_dict=True):
return (self._decode_tensor,)
def test_forward_all_unwraps_one_tuple_at_each_step():
vae = _TupleReturningStubVAE()
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
result = vae.forward(x, mode="all")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_DECODED_SHAPE)

View File

@ -1,47 +0,0 @@
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.sd
import comfy.supported_models
import comfy.ldm.seedvr.model as seedvr_model
def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch):
bf16_device = object()
fp16_device = object()
monkeypatch.setattr(
comfy.supported_models.comfy.model_management,
"should_use_bf16",
lambda device=None: device is bf16_device,
)
bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device)
assert bf16_config.manual_cast_dtype is torch.bfloat16
fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device)
assert fp16_config.manual_cast_dtype is None
def test_seedvr2_text_conditioning_accepts_cfg1_single_branch():
context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2)
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0])
torch.testing.assert_close(txt, context.squeeze(0))
torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device))
def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer():
estimate = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160))
old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2
assert estimate == 101 * 960 * 1280 * 160
assert estimate > 15 * 1024 ** 3
assert estimate > old_estimate * 100

View File

@ -1,341 +0,0 @@
"""Consolidated SeedVR2 internals regression tests.
Sources (all merged verbatim, helper names disambiguated where colliding):
* RoPE rewrite NaMMRotaryEmbedding3d.forward must match the legacy
apply_rotary_emb wrapper oracle at fp32.
* GroupNorm limit gate causal_norm_wrapper at vae.py:509 must compare
memory_occupy against get_norm_limit(), not float('inf').
* SeedVR2 variable-length attention split-loop contract.
Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
comfy.ldm.modules.attention transitively pull in comfy.model_management,
which probes torch.cuda.current_device() at import time unless args.cpu is
set first.
"""
from __future__ import annotations
from unittest.mock import patch
import pytest
import torch
from comfy.cli_args import args
if not torch.cuda.is_available():
args.cpu = True
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
import comfy.ldm.modules.attention as attention # noqa: E402
import comfy.ops as comfy_ops # noqa: E402
from comfy.ldm.seedvr.model import ( # noqa: E402
Cache,
NaMMRotaryEmbedding3d,
)
from comfy.ldm.seedvr.vae import ( # noqa: E402
causal_norm_wrapper,
set_norm_limit,
)
from comfy.ldm.modules.attention import var_attention_optimized_split # noqa: E402
# ---------------------------------------------------------------------------
# RoPE rewrite tests (test_seedvr_rope_rewrite.py)
# ---------------------------------------------------------------------------
# Test rig dimensions. dim=192 → per-axis rope dim = 64 (even, lucidrains
# requirement). vid_shape=(2,4,4) → L_vid = 32. txt_shape=(8,) → L_txt = 8.
_DIM = 192
_HEADS = 4
_VID_T, _VID_H, _VID_W = 2, 4, 4
_TXT_L = 8
_L_VID = _VID_T * _VID_H * _VID_W
_SEED = 0
def _make_inputs(dtype=torch.float32, device="cpu"):
"""Construct the 6 forward inputs + cache. Deterministic via local
Generator so global RNG state is not mutated.
"""
g = torch.Generator(device=device).manual_seed(_SEED)
vid_q = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
vid_k = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
txt_q = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
txt_k = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device)
txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device)
cache = Cache(disable=True)
return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache
def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape):
"""Reproduce the pre-rewrite ``get_freqs`` body verbatim against
``self.get_axial_freqs`` (parent ``RotaryEmbeddingBase`` method,
unchanged by the rewrite).
"""
max_temporal = 0
max_height = 0
max_width = 0
max_txt_len = 0
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
max_temporal = max(max_temporal, l + f)
max_height = max(max_height, h)
max_width = max(max_width, w)
max_txt_len = max(max_txt_len, l)
with torch.amp.autocast(device_type="cuda", enabled=False):
vid_freqs_full = rope.get_axial_freqs(
min(max_temporal + 16, 1024),
min(max_height + 4, 128),
min(max_width + 4, 128),
).float()
txt_freqs_full = rope.get_axial_freqs(min(max_txt_len + 16, 1024))
vid_freq_list, txt_freq_list = [], []
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
vid_freq = vid_freqs_full[l : l + f, :h, :w].reshape(-1, vid_freqs_full.size(-1))
txt_freq = txt_freqs_full[:l].repeat(1, 3).reshape(-1, vid_freqs_full.size(-1))
vid_freq_list.append(vid_freq)
txt_freq_list.append(txt_freq)
return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0)
def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape,
txt_q, txt_k, txt_shape):
"""Compute expected forward output via the unchanged
``apply_rotary_emb`` wrapper fed with legacy-shape freqs. This is the
oracle. The wrapper itself is out of scope for the rewrite (Shape B).
"""
vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape)
vid_freqs = vid_freqs.to(vid_q.device)
txt_freqs = txt_freqs.to(txt_q.device)
from einops import rearrange
vid_q = rearrange(vid_q, "L h d -> h L d")
vid_k = rearrange(vid_k, "L h d -> h L d")
vid_q_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype)
vid_k_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype)
vid_q_out = rearrange(vid_q_out, "h L d -> L h d")
vid_k_out = rearrange(vid_k_out, "h L d -> L h d")
txt_q = rearrange(txt_q, "L h d -> h L d")
txt_k = rearrange(txt_k, "L h d -> h L d")
txt_q_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype)
txt_k_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype)
txt_q_out = rearrange(txt_q_out, "h L d -> L h d")
txt_k_out = rearrange(txt_k_out, "h L d -> L h d")
return vid_q_out, vid_k_out, txt_q_out, txt_k_out
def test_namm_forward_output_tensor_equal_against_legacy_oracle():
rope = NaMMRotaryEmbedding3d(dim=_DIM)
vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs()
expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward(
rope,
vid_q.clone(), vid_k.clone(), vid_shape,
txt_q.clone(), txt_k.clone(), txt_shape,
)
actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward(
vid_q.clone(), vid_k.clone(), vid_shape,
txt_q.clone(), txt_k.clone(), txt_shape, cache,
)
torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0,
msg="vid_q output diverges from wrapper oracle")
torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0,
msg="vid_k output diverges from wrapper oracle")
torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0,
msg="txt_q output diverges from wrapper oracle")
torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0,
msg="txt_k output diverges from wrapper oracle")
# ---------------------------------------------------------------------------
# GroupNorm limit tests (test_seedvr_groupnorm_limit.py)
# ---------------------------------------------------------------------------
_NUM_CHANNELS = 8
_NUM_GROUPS = 4
_TENSOR_SHAPE = (1, 8, 2, 4, 4)
_GROUPNORM_SUBCLASSES = [
pytest.param(comfy_ops.disable_weight_init.GroupNorm, id="disable_weight_init"),
pytest.param(comfy_ops.manual_cast.GroupNorm, id="manual_cast"),
]
@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES)
def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
real_group_norm = vae_mod.F.group_norm
set_norm_limit(1e-9)
try:
gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS)
gn.eval()
forward_hook_calls = []
def _hook(module, inputs, output):
forward_hook_calls.append(tuple(inputs[0].shape))
spy_calls = []
def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs):
spy_calls.append({"num_groups": int(num_groups_arg)})
return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs)
handle = gn.register_forward_hook(_hook)
try:
with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy):
out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE))
finally:
handle.remove()
full_calls = len(forward_hook_calls)
chunked_calls = sum(1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS)
assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE
assert full_calls == 0, (
f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}"
)
assert chunked_calls > 0, (
f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}"
)
finally:
set_norm_limit(None)
# ---------------------------------------------------------------------------
# SeedVR2 var_attention split-loop tests
# ---------------------------------------------------------------------------
def test_var_attention_registry_contains_always_available_entries():
assert (
attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_optimized_split"]
is attention.var_attention_optimized_split
)
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
dim = 8
heads = 2
head_dim = 4
attn = seedvr_model.NaSwinAttention(
vid_dim=dim,
txt_dim=dim,
heads=heads,
head_dim=head_dim,
qk_bias=False,
qk_norm=seedvr_model.CustomRMSNorm,
qk_norm_eps=1e-6,
rope_type=None,
rope_dim=head_dim,
shared_weights=False,
window=(2, 1, 1),
window_method="720pwin_by_size_bysize",
version=True,
device="cpu",
dtype=torch.float32,
operations=comfy_ops.disable_weight_init,
)
generator = torch.Generator(device="cpu").manual_seed(11)
vid = torch.randn(8, dim, generator=generator)
txt = torch.randn(3, dim, generator=generator)
vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long)
txt_shape = torch.tensor([[3]], dtype=torch.long)
calls = []
def fake_optimized_var_attention(**kwargs):
calls.append(kwargs)
return kwargs["q"]
monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention)
vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True))
assert tuple(vid_out.shape) == (8, dim)
assert tuple(txt_out.shape) == (3, dim)
assert len(calls) == 1
call = calls[0]
assert tuple(call["q"].shape) == (14, heads, head_dim)
assert tuple(call["k"].shape) == (14, heads, head_dim)
assert tuple(call["v"].shape) == (14, heads, head_dim)
assert call["heads"] == heads
assert call["skip_reshape"] is True
assert call["skip_output_reshape"] is True
torch.testing.assert_close(
call["cu_seqlens_q"],
torch.tensor([0, 7, 14], dtype=torch.int32),
rtol=0,
atol=0,
)
torch.testing.assert_close(
call["cu_seqlens_k"],
torch.tensor([0, 7, 14], dtype=torch.int32),
rtol=0,
atol=0,
)
def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch):
heads = 2
head_dim = 3
q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim)
k = q + 100
v = q + 200
cu = torch.tensor([0, 2, 5], dtype=torch.int32)
calls = []
def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs):
calls.append(
{
"q_shape": tuple(q_arg.shape),
"k_shape": tuple(k_arg.shape),
"v_shape": tuple(v_arg.shape),
"heads": heads_arg,
"kwargs": kwargs,
}
)
return q_arg + v_arg
monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention)
out = var_attention_optimized_split(
q,
k,
v,
heads,
cu,
cu,
skip_reshape=True,
skip_output_reshape=True,
)
assert tuple(out.shape) == (5, heads, head_dim)
assert len(calls) == 2
assert calls[0]["q_shape"] == (1, heads, 2, head_dim)
assert calls[1]["q_shape"] == (1, heads, 3, head_dim)
assert all(call["heads"] == heads for call in calls)
assert all(call["kwargs"]["skip_reshape"] is True for call in calls)
assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls)
torch.testing.assert_close(out, q + v, rtol=0, atol=0)
def test_var_attention_optimized_split_rejects_bad_offsets():
q = torch.randn(5, 2, 3)
cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32)
cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32)
with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"):
var_attention_optimized_split(
q,
q,
q,
2,
cu_bad,
cu_ok,
skip_reshape=True,
skip_output_reshape=True,
)

View File

@ -1,308 +0,0 @@
"""Consolidated SeedVR2 model/graph/forward regression tests.
Merged from:
- seedvr_model_test.py
- test_seedvr_7b_final_block_text_path.py
- test_seedvr_forward_no_device_cast.py
- test_seedvr_latent_format.py
- test_seedvr2_vae_graph_boundaries.py
"""
from __future__ import annotations
from unittest.mock import MagicMock
import torch
from torch import nn
from comfy.cli_args import args
if not torch.cuda.is_available():
args.cpu = True
import comfy # noqa: E402
import comfy.latent_formats # noqa: E402
import comfy.ldm.seedvr.model # noqa: E402
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.model_management # noqa: E402
import comfy.sample # noqa: E402
import comfy.sd as sd_mod # noqa: E402
import nodes as nodes_mod # noqa: E402
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
# ---------------------------------------------------------------------------
# Helpers from seedvr_model_test.py
# ---------------------------------------------------------------------------
def _make_standin(positive_conditioning):
class _StandIn(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"positive_conditioning", positive_conditioning
)
_resolve_text_conditioning = NaDiT._resolve_text_conditioning
return _StandIn()
# ---------------------------------------------------------------------------
# Helpers from test_seedvr_7b_final_block_text_path.py
# ---------------------------------------------------------------------------
class _StubModule(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]:
flags = []
class _Block(_StubModule):
def __init__(self, *args, **kwargs):
flags.append(kwargs["is_last_layer"])
super().__init__()
monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule)
monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule)
monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule)
monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block)
seedvr_model.NaDiT(
norm_eps=1e-5,
qk_rope=None,
num_layers=4,
mlp_type="normal",
vid_dim=vid_dim,
txt_in_dim=txt_in_dim,
heads=24,
mm_layers=3,
)
return flags
# ---------------------------------------------------------------------------
# Helpers from test_seedvr_latent_format.py
# ---------------------------------------------------------------------------
class _Model:
def __init__(self, latent_format):
self._latent_format = latent_format
def get_model_object(self, name):
assert name == "latent_format"
return self._latent_format
# ---------------------------------------------------------------------------
# Helpers from test_seedvr2_vae_graph_boundaries.py
# ---------------------------------------------------------------------------
class _Patcher:
def get_free_memory(self, device):
return 1024 * 1024 * 1024
class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
def __init__(self, encoded):
nn.Module.__init__(self)
self.encoded = encoded
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.seen = []
def encode(self, x):
self.seen.append(tuple(x.shape))
return self.encoded.to(device=x.device, dtype=x.dtype)
class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
def __init__(self):
nn.Module.__init__(self)
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.calls = []
def decode(self, z, seedvr2_tiling=None):
self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
if z.ndim == 4:
b, tc, h, w = z.shape
t = tc // 16
else:
b, _, t, h, w = z.shape
return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
def _make_vae(wrapper):
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = wrapper
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.latent_channels = 16
vae.latent_dim = 3
vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8)
vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
vae.output_channels = 3
vae.disable_offload = True
vae.extra_1d_channel = None
vae.crop_input = False
vae.not_video = False
vae.patcher = _Patcher()
vae.process_input = lambda image: image
vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0)
vae.vae_output_dtype = lambda: torch.float32
vae.memory_used_encode = lambda shape, dtype: 1
vae.memory_used_decode = lambda shape, dtype: 1
vae.throw_exception_if_invalid = lambda: None
vae.vae_encode_crop_pixels = lambda pixels: pixels
vae.spacial_compression_decode = lambda: 8
vae.temporal_compression_decode = lambda: 4
return vae
# ---------------------------------------------------------------------------
# Tests from seedvr_model_test.py
# ---------------------------------------------------------------------------
def test_missing_context_falls_back_to_positive_buffer():
"""AC: ``context is None`` falls back to the registered
``positive_conditioning`` buffer and runs to completion no
silent zero substitution, no raised exception.
"""
pos_buffer = torch.full((58, 5120), 7.0)
standin = _make_standin(pos_buffer)
txt, txt_shape = standin._resolve_text_conditioning(None)
assert txt.shape == (58, 5120)
assert (txt == 7.0).all(), (
"fallback path must use the positive_conditioning buffer "
"verbatim, not a zero tensor"
)
assert txt_shape.shape == (1, 1)
assert txt_shape[0, 0].item() == 58
# ---------------------------------------------------------------------------
# Tests from test_seedvr_7b_final_block_text_path.py
# ---------------------------------------------------------------------------
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [
False,
False,
False,
False,
]
def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
rope = seedvr_model.get_na_rope("rope3d", dim=64)
generator = torch.Generator(device="cpu").manual_seed(0)
q = torch.randn(4, 2, 128, generator=generator)
k = torch.randn(4, 2, 128, generator=generator)
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)
expected_q = seedvr_model._apply_seedvr2_rotary_emb(
freqs,
q.permute(1, 0, 2).float(),
).to(q.dtype).permute(1, 0, 2)
expected_k = seedvr_model._apply_seedvr2_rotary_emb(
freqs,
k.permute(1, 0, 2).float(),
).to(k.dtype).permute(1, 0, 2)
actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True))
torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0)
torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0)
# ---------------------------------------------------------------------------
# Tests from test_seedvr_latent_format.py
# ---------------------------------------------------------------------------
def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion():
latent_format = comfy.latent_formats.SeedVR2()
latent_image = torch.zeros(1, 1, 4, 5)
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
assert latent_format.latent_channels == 16
assert latent_format.latent_dimensions == 2
assert fixed.shape == (1, 16, 4, 5)
# ---------------------------------------------------------------------------
# Tests from test_seedvr2_vae_graph_boundaries.py
# ---------------------------------------------------------------------------
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
encoded = torch.full((1, 16, 2, 4, 5), 2.0)
vae = _make_vae(_EncodeWrapper(encoded))
pixels = torch.zeros(1, 5, 32, 40, 3)
node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
node_latent = node_output["samples"]
assert set(node_output) == {"samples"}
assert tuple(node_latent.shape) == (1, 16, 2, 4, 5)
assert node_latent.dtype == torch.float32
assert node_latent.stride()[-1] == 1
assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152))
tiled = torch.full((1, 16, 2, 4, 5), 3.0)
monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
tiled_output = nodes_mod.VAEEncodeTiled().encode(
vae,
pixels,
tile_size=512,
overlap=64,
temporal_size=16,
temporal_overlap=4,
)[0]
tiled_latent = tiled_output["samples"]
assert set(tiled_output) == {"samples"}
assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5)
assert tiled_latent.dtype == torch.float32
assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152))
def test_vaedecode_tiled_visible_inputs_are_seedvr2_decode_tiling_authority(monkeypatch):
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
vae = _make_vae(_DecodeWrapper())
nodes_mod.VAEDecodeTiled().decode(
vae,
{"samples": torch.zeros(1, 16, 2, 4, 5)},
tile_size=512,
overlap=64,
temporal_size=16,
temporal_overlap=4,
)
assert vae.first_stage_model.calls == [
{
"shape": (1, 16, 2, 4, 5),
"seedvr2_tiling": {
"enable_tiling": True,
"tile_size": (512, 512),
"tile_overlap": (64, 64),
"temporal_size": 16,
"temporal_overlap": 4,
},
}
]

View File

@ -1,91 +0,0 @@
from unittest.mock import patch
import pytest
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
from comfy_extras import nodes_seedvr # noqa: E402
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
vae_mod.VideoAutoencoderKLWrapper
)
nn.Module.__init__(wrapper)
return wrapper
def _fingerprint_decode_(self, z, return_dict=True):
b = int(z.shape[0])
t = int(z.shape[2])
h = int(z.shape[3])
w = int(z.shape[4])
out = torch.empty(b, 3, t, h * 8, w * 8)
for batch_idx in range(b):
out[batch_idx].fill_(float(batch_idx + 1))
return out
def _decode_with_patches(wrapper, z):
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_):
return wrapper.decode(z)
def test_decode_b2_t3_multi_frame_batch_unchanged():
wrapper = _make_wrapper()
out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2))
assert tuple(out.shape) == (2, 3, 3, 16, 16)
class _Wrapper(vae_mod.VideoAutoencoderKLWrapper):
def __init__(self):
nn.Module.__init__(self)
self.calls = []
def parameters(self):
return iter([torch.nn.Parameter(torch.zeros(()))])
def _decode_stub(self, latent):
self.calls.append(tuple(latent.shape))
return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8)
def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state():
wrapper = _Wrapper()
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5))
assert tuple(out.shape) == (1, 3, 2, 32, 40)
assert wrapper.calls == [(1, 16, 2, 4, 5)]
def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents():
wrapper = _Wrapper()
with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"):
wrapper.decode(torch.zeros(1, 16, 4))
def _t_padded(t_in: int) -> int:
if t_in == 1:
return 1
if t_in <= 4:
return 5
if (t_in - 1) % 4 == 0:
return t_in
return t_in + (4 - ((t_in - 1) % 4))
@pytest.mark.parametrize("t_in", [1, 5, 9])
def test_t_padded_matches_cut_videos(t_in):
dummy = torch.zeros(1, t_in, 1, 1, 1)
assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in)

View File

@ -1,347 +0,0 @@
from contextlib import ExitStack
from unittest.mock import MagicMock, patch
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.sd as sd_mod # noqa: E402
from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402
# ---------------------------------------------------------------------------
# From test_seedvr_vae_tiled_decode_latent_min_size_override.py
# ---------------------------------------------------------------------------
def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae
class StubVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_latent_min_size = 2
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self.use_slicing = True
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.decode_min_sizes = []
self.memory_states = []
def decode_(self, t_chunk):
self.decode_min_sizes.append(self.slicing_latent_min_size)
return VideoAutoencoderKL.slicing_decode(self, t_chunk)
def _decode(self, z, memory_state=MemoryState.DISABLED):
self.memory_states.append(memory_state)
b, c, d, h, w = z.shape
return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype)
vae = StubVAEModel()
z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32)
tiled_vae(
z,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=False,
)
assert vae.decode_min_sizes == [5]
assert vae.memory_states == [MemoryState.DISABLED]
assert vae.slicing_latent_min_size == 2
# ---------------------------------------------------------------------------
# From test_seedvr_vae_tiled_encode_runt_slice_override.py
# ---------------------------------------------------------------------------
def test_zero_temporal_size_preserves_min_size_when_encode_raises():
from comfy.ldm.seedvr.vae import tiled_vae
class RaisingVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_sample_min_size = 4
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
def encode(self, t_chunk):
raise RuntimeError("simulated encode failure")
vae = RaisingVAEModel()
x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)
raised = False
try:
tiled_vae(
x,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=True,
)
except RuntimeError as exc:
if "simulated encode failure" not in str(exc):
raise
raised = True
assert raised
assert vae.slicing_sample_min_size == 4
# ---------------------------------------------------------------------------
# From test_seedvr_vae_tiled_temporal_slicing.py
# ---------------------------------------------------------------------------
class _SlicingDecodeVAE(nn.Module):
def __init__(self, slicing_latent_min_size):
super().__init__()
self.slicing_latent_min_size = slicing_latent_min_size
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self.use_slicing = True
self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.decode_min_sizes = []
self.memory_states = []
def decode_(self, z):
self.decode_min_sizes.append(self.slicing_latent_min_size)
return vae_mod.VideoAutoencoderKL.slicing_decode(self, z)
def _decode(self, z, memory_state=MemoryState.DISABLED):
self.memory_states.append(memory_state)
x = z[:, :1].repeat(
1,
3,
1,
self.spatial_downsample_factor,
self.spatial_downsample_factor,
)
return x
def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
vae = _SlicingDecodeVAE(slicing_latent_min_size=2)
z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8)
tiled_vae(
z,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=12,
temporal_overlap=4,
encode=False,
)
assert vae.decode_min_sizes == [2]
assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE]
assert vae.slicing_latent_min_size == 2
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
vae_mod.VideoAutoencoderKLWrapper
)
nn.Module.__init__(wrapper)
seedvr2_tiling = {
"enable_tiling": True,
"tile_size": (64, 64),
"tile_overlap": (0, 0),
"temporal_size": 8,
"temporal_overlap": 7,
}
captured = {}
def _fake_tiled_vae(latent, model, **kwargs):
captured.update(kwargs)
return torch.zeros(1, 3, 1, 16, 16)
with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae):
wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling)
assert captured["temporal_overlap"] == 7
# ---------------------------------------------------------------------------
# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py
# ---------------------------------------------------------------------------
def _force_oom(*a, **k):
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
def _make_vae(first_stage_model, latent_channels, latent_dim):
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = first_stage_model
vae.patcher = MagicMock()
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
vae.device = vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.disable_offload = True
vae.extra_1d_channel = None
vae.upscale_ratio = vae.downscale_ratio = 8
vae.upscale_index_formula = vae.downscale_index_formula = None
vae.output_channels = 3
vae.latent_channels = latent_channels
vae.latent_dim = latent_dim
vae.vae_output_dtype = lambda: torch.float32
vae.spacial_compression_decode = lambda: 8
vae.process_input = lambda x: x
vae.process_output = lambda x: x
vae.throw_exception_if_invalid = lambda: None
vae.memory_used_decode = lambda *a, **k: 1
return vae
def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode):
mm = sd_mod.model_management
with ExitStack() as stack:
stack.enter_context(patch.object(mm, "raise_non_oom", lambda e: None))
stack.enter_context(patch.object(mm, "load_models_gpu", lambda *a, **k: None))
stack.enter_context(patch.object(mm, "soft_empty_cache", lambda: None))
stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call))
stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_", generic_call))
if patch_wrapper_decode:
stack.enter_context(patch.object(
seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode",
side_effect=_force_oom))
vae.decode(samples)
def test_4d_seedvr2_latent_routes_to_decode_tiled_seedvr2():
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
seedvr_vae_mod.VideoAutoencoderKLWrapper)
vae = _make_vae(wrapper, latent_channels=16, latent_dim=3)
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
_dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True)
assert seedvr2_call.call_count == 1
assert generic_call.call_count == 0
def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled():
first_stage = MagicMock()
first_stage.decode = MagicMock(side_effect=_force_oom)
vae = _make_vae(first_stage, latent_channels=4, latent_dim=2)
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
_dispatch(vae, torch.zeros(1, 4, 8, 8), seedvr2_call, generic_call, False)
assert generic_call.call_count == 1
assert seedvr2_call.call_count == 0
# ---------------------------------------------------------------------------
# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py
# ---------------------------------------------------------------------------
def _populate_common_vae_attrs_fallback(vae):
vae.patcher = MagicMock()
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.disable_offload = True
vae.extra_1d_channel = None
vae.upscale_ratio = 8
vae.upscale_index_formula = None
vae.output_channels = 3
vae.latent_channels = 16
vae.latent_dim = 3
vae.downscale_ratio = 8
vae.downscale_index_formula = None
vae.not_video = False
vae.crop_input = False
vae.pad_channel_value = None
vae.vae_output_dtype = lambda: torch.float32
vae.spacial_compression_encode = lambda: 8
vae.process_input = lambda x: x
vae.process_output = lambda x: x
vae.throw_exception_if_invalid = lambda: None
vae.memory_used_encode = lambda *a, **k: 1
def _make_seedvr2_vae_fallback():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
seedvr_vae_mod.VideoAutoencoderKLWrapper
)
vae.first_stage_model = wrapper
_populate_common_vae_attrs_fallback(vae)
return vae
def _make_non_seedvr2_vae_fallback():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = MagicMock()
_populate_common_vae_attrs_fallback(vae)
return vae
def _force_regular_encode_oom(*args, **kwargs):
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
def test_seedvr2_3d_routes_to_encode_tiled_seedvr2_on_oom():
vae = _make_seedvr2_vae_fallback()
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
with patch.object(sd_mod.model_management, "raise_non_oom",
lambda e: None), \
patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.model_management, "soft_empty_cache",
lambda: None), \
patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode",
side_effect=_force_regular_encode_oom), \
patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call,
create=True), \
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
vae.encode(pixel_samples)
assert seedvr2_call.call_count == 1, (
f"Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D "
f"input under OOM fallback; got {seedvr2_call.call_count} calls."
)
assert generic_call.call_count == 0, (
f"encode_tiled_3d must NOT be called for a SeedVR2 input; got "
f"{generic_call.call_count} calls."
)
def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete():
vae = _make_non_seedvr2_vae_fallback()
vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8)
vae.upscale_ratio = (lambda a: a * 4, 8, 8)
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
with patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
vae.encode_tiled(pixel_samples)
assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64)

View File

@ -1,126 +0,0 @@
"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``."""
from unittest.mock import patch
import pytest
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.sample # noqa: E402
import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402
from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler # noqa: E402
_LAT_C = 16
_COND_C = 17
def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8):
"""Build minimal SeedVR2-shaped sampling inputs."""
samples_5d = torch.arange(
B * _LAT_C * T * H * W, dtype=torch.float32
).reshape(B, _LAT_C, T, H, W)
samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous()
cond_5d = torch.arange(
B * _COND_C * T * H * W, dtype=torch.float32
).reshape(B, _COND_C, T, H, W) + 10000.0
cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous()
text_pos = torch.zeros(1, 4, 32)
text_neg = torch.zeros(1, 4, 32)
positive = [[text_pos, {"condition": cond.clone()}]]
negative = [[text_neg, {"condition": cond.clone()}]]
latent_image = {"samples": samples}
return latent_image, positive, negative, samples_5d, cond_5d
def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None):
return latent_image
def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None):
"""Return a tensor whose values encode ``(seed, position)``."""
base = torch.arange(
latent_image.numel(), dtype=torch.float32
).reshape(latent_image.shape)
return base + float(seed) * 1e6
def test_progressive_sampler_schema_exposes_manual_default_auto_chunking():
schema = SeedVR2ProgressiveSampler.define_schema()
inputs = {item.id: item for item in schema.inputs}
assert inputs["chunking_mode"].options == ["manual", "auto"]
assert inputs["chunking_mode"].default == "manual"
def test_auto_chunking_walks_two_three_four_chunk_ladder():
"""Auto mode must walk 2-, 3-, then 4-chunk geometries on OOM."""
latent, pos, neg, _, _ = _make_inputs(T=17)
calls = []
def _oom_until_four_chunks(model, noise, steps, cfg, sampler_name,
scheduler, positive, negative,
latent_image, denoise=1.0,
noise_mask=None, seed=None):
calls.append(tuple(latent_image.shape))
if latent_image.shape[1] > _LAT_C * 5:
raise torch.cuda.OutOfMemoryError("chunk too large")
return latent_image.clone()
with patch.object(comfy.sample, "sample",
side_effect=_oom_until_four_chunks), \
patch.object(comfy.sample, "fix_empty_latent_channels",
side_effect=_identity_fix_empty), \
patch.object(comfy.sample, "prepare_noise",
side_effect=_fingerprinted_prepare_noise), \
patch.object(nodes_seedvr_mod.comfy.model_management,
"soft_empty_cache") as soft_empty:
out = SeedVR2ProgressiveSampler.execute(
model=None, seed=0, steps=2, cfg=1.0,
sampler_name="euler", scheduler="simple",
positive=pos, negative=neg, latent=latent,
denoise=1.0, frames_per_chunk=65, temporal_overlap=0,
chunking_mode="auto",
)
assert calls[:4] == [
(1, _LAT_C * 17, 8, 8),
(1, _LAT_C * 9, 8, 8),
(1, _LAT_C * 6, 8, 8),
(1, _LAT_C * 5, 8, 8),
]
assert torch.equal(out.result[0]["samples"], latent["samples"])
assert soft_empty.call_count == 3
@pytest.mark.parametrize("bad_chunk", [0, -1, 2])
def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk):
"""``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation."""
latent, pos, neg, _, _ = _make_inputs(T=5)
sampler_called = {"n": 0}
def _should_not_be_called(*args, **kwargs):
sampler_called["n"] += 1
return torch.zeros(1)
with patch.object(comfy.sample, "sample",
side_effect=_should_not_be_called), \
patch.object(comfy.sample, "fix_empty_latent_channels",
side_effect=_identity_fix_empty), \
patch.object(comfy.sample, "prepare_noise",
side_effect=_fingerprinted_prepare_noise):
with pytest.raises(ValueError) as excinfo:
SeedVR2ProgressiveSampler.execute(
model=None, seed=0, steps=2, cfg=1.0,
sampler_name="euler", scheduler="simple",
positive=pos, negative=neg, latent=latent,
denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0,
)
assert str(bad_chunk) in str(excinfo.value)
assert sampler_called["n"] == 0