This commit is contained in:
John Pollock 2026-07-03 12:36:43 +08:00 committed by GitHub
commit cd26e58f2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 5451 additions and 28 deletions

View File

@ -779,6 +779,10 @@ class ACEAudio(LatentFormat):
latent_channels = 8 latent_channels = 8
latent_dimensions = 2 latent_dimensions = 2
class SeedVR2(LatentFormat):
latent_channels = 16
latent_dimensions = 3
class ACEAudio15(LatentFormat): class ACEAudio15(LatentFormat):
latent_channels = 64 latent_channels = 64
latent_dimensions = 1 latent_dimensions = 1

View File

@ -22,7 +22,7 @@ def torch_cat_if_needed(xl, dim):
else: else:
return None return None
def get_timestep_embedding(timesteps, embedding_dim): def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
""" """
This matches the implementation in Denoising Diffusion Probabilistic Models: This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq. From Fairseq.
@ -33,11 +33,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
assert len(timesteps.shape) == 1 assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2 half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1) emb = math.log(10000) / (half_dim - downscale_freq_shift)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device) emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :] emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
if embedding_dim % 2 == 1: # zero pad if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0)) emb = torch.nn.functional.pad(emb, (0,1,0,0))
return emb return emb

View File

@ -0,0 +1,51 @@
import torch
from comfy.ldm.modules import attention as _attention
def _var_attention_qkv(q, k, v, heads, skip_reshape):
if skip_reshape:
return q, k, v, q.shape[-1]
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
return (
q.view(total_tokens, heads, head_dim),
k.view(k.shape[0], heads, head_dim),
v.view(v.shape[0], heads, head_dim),
head_dim,
)
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
if skip_output_reshape:
return out
return out.reshape(-1, heads * head_dim)
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
q_split_indices = cu_seqlens_q[1:-1]
k_split_indices = cu_seqlens_k[1:-1]
if k.shape[0] != v.shape[0]:
raise ValueError("cu_seqlens_k does not match v token count")
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")
out = []
for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
out_i = _attention.optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
out.append(out_i.squeeze(0).permute(1, 0, 2))
out = torch.cat(out, dim=0)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
optimized_var_attention = var_attention_optimized_split

View File

@ -0,0 +1,301 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from comfy.ldm.seedvr.constants import (
CIELAB_DELTA,
CIELAB_KAPPA,
D65_WHITE_X,
D65_WHITE_Z,
WAVELET_DECOMP_LEVELS,
)
def wavelet_blur(image: Tensor, radius):
max_safe_radius = max(1, min(image.shape[-2:]) // 8)
if radius > max_safe_radius:
radius = max_safe_radius
num_channels = image.shape[1]
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
image = F.pad(image, (radius, radius, radius, radius), mode='replicate')
output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS):
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq.add_(image).sub_(low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
if content_feat.shape != style_feat.shape:
if len(content_feat.shape) >= 3:
style_feat = F.interpolate(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False
)
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq
if content_high_freq.shape != style_low_freq.shape:
style_low_freq = F.interpolate(
style_low_freq,
size=content_high_freq.shape[-2:],
mode='bilinear',
align_corners=False
)
content_high_freq.add_(style_low_freq)
return content_high_freq.clamp_(-1.0, 1.0)
def _histogram_matching_channel(source: Tensor, reference: Tensor) -> Tensor:
original_shape = source.shape
source_flat = source.flatten()
reference_flat = reference.flatten()
source_sorted, source_indices = torch.sort(source_flat)
reference_sorted, _ = torch.sort(reference_flat)
del reference_flat
n_source = len(source_sorted)
n_reference = len(reference_sorted)
if n_source == n_reference:
matched_sorted = reference_sorted
else:
source_quantiles = torch.linspace(0, 1, n_source, device=source.device)
ref_indices = (source_quantiles * (n_reference - 1)).long()
ref_indices.clamp_(0, n_reference - 1)
matched_sorted = reference_sorted[ref_indices]
del source_quantiles, ref_indices, reference_sorted
del source_sorted, source_flat
inverse_indices = torch.argsort(source_indices)
del source_indices
matched_flat = matched_sorted[inverse_indices]
del matched_sorted, inverse_indices
return matched_flat.reshape(original_shape)
def _lab_to_rgb_batch(lab: Tensor, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
fy = (L + 16.0) / 116.0
fx = a.div(500.0).add_(fy)
fz = fy - b / 200.0
del L, a, b
x = torch.where(
fx > epsilon,
torch.pow(fx, 3.0),
fx.mul(116.0).sub_(16.0).div_(kappa)
)
y = torch.where(
fy > epsilon,
torch.pow(fy, 3.0),
fy.mul(116.0).sub_(16.0).div_(kappa)
)
z = torch.where(
fz > epsilon,
torch.pow(fz, 3.0),
fz.mul(116.0).sub_(16.0).div_(kappa)
)
del fx, fy, fz
x.mul_(D65_WHITE_X)
z.mul_(D65_WHITE_Z)
xyz = torch.stack([x, y, z], dim=1)
del x, y, z
B, _, H, W = xyz.shape
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
del xyz
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
del xyz_flat
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del rgb_linear_flat
mask = rgb_linear > 0.0031308
rgb = torch.where(
mask,
torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055),
rgb_linear * 12.92
)
del mask, rgb_linear
return torch.clamp(rgb, 0.0, 1.0)
def _rgb_to_lab_batch(rgb: Tensor, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
mask = rgb > 0.04045
rgb_linear = torch.where(
mask,
torch.pow((rgb + 0.055) / 1.055, 2.4),
rgb / 12.92
)
del mask
B, _, H, W = rgb_linear.shape
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
del rgb_linear
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
xyz_flat = torch.matmul(rgb_flat, matrix.T)
del rgb_flat
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del xyz_flat
xyz[:, 0].div_(D65_WHITE_X)
xyz[:, 2].div_(D65_WHITE_Z)
epsilon_cubed = epsilon ** 3
mask = xyz > epsilon_cubed
f_xyz = torch.where(
mask,
torch.pow(xyz, 1.0 / 3.0),
xyz.mul(kappa).add_(16.0).div_(116.0)
)
del xyz, mask
L = f_xyz[:, 1].mul(116.0).sub_(16.0)
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0)
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0)
del f_xyz
return torch.stack([L, a, b], dim=1)
def lab_color_transfer(
content_feat: Tensor,
style_feat: Tensor,
luminance_weight: float = 0.8
) -> Tensor:
content_feat = wavelet_reconstruction(content_feat, style_feat)
if content_feat.shape != style_feat.shape:
style_feat = F.interpolate(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False
)
device = content_feat.device
original_dtype = content_feat.dtype
content_feat = content_feat.float()
style_feat = style_feat.float()
rgb_to_xyz_matrix = torch.tensor([
[0.4124564, 0.3575761, 0.1804375],
[0.2126729, 0.7151522, 0.0721750],
[0.0193339, 0.1191920, 0.9503041]
], dtype=torch.float32, device=device)
xyz_to_rgb_matrix = torch.tensor([
[ 3.2404542, -1.5371385, -0.4985314],
[-0.9692660, 1.8760108, 0.0415560],
[ 0.0556434, -0.2040259, 1.0572252]
], dtype=torch.float32, device=device)
epsilon = CIELAB_DELTA
kappa = CIELAB_KAPPA
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
content_lab = _rgb_to_lab_batch(content_feat, rgb_to_xyz_matrix, epsilon, kappa)
del content_feat
style_lab = _rgb_to_lab_batch(style_feat, rgb_to_xyz_matrix, epsilon, kappa)
del style_feat, rgb_to_xyz_matrix
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1])
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2])
if luminance_weight < 1.0:
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0])
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
del matched_L
else:
result_L = content_lab[:, 0]
del content_lab, style_lab
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
del result_L, matched_a, matched_b
result_rgb = _lab_to_rgb_batch(result_lab, xyz_to_rgb_matrix, epsilon, kappa)
del result_lab, xyz_to_rgb_matrix
result = result_rgb.mul_(2.0).sub_(1.0)
del result_rgb
result = result.to(original_dtype)
return result
def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor:
return wavelet_reconstruction(content_feat, style_feat)
def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor:
if content_feat.shape != style_feat.shape:
style_feat = F.interpolate(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False,
)
original_dtype = content_feat.dtype
content_feat = content_feat.float()
style_feat = style_feat.float()
b, c = content_feat.shape[:2]
content_flat = content_feat.reshape(b, c, -1)
style_flat = style_feat.reshape(b, c, -1)
content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1)
content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1)
style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
del content_flat, style_flat
normalized = (content_feat - content_mean) / content_std
del content_mean, content_std
result = normalized * style_std + style_mean
del normalized, style_mean, style_std
result = result.clamp_(-1.0, 1.0)
if result.dtype != original_dtype:
result = result.to(original_dtype)
return result

View File

@ -0,0 +1,38 @@
"""SeedVR2 constants."""
SEEDVR2_7B_VID_DIM = 3072
SEEDVR2_OOM_BACKOFF_DIVISOR = 2
SEEDVR2_DTYPE_BYTES_FLOOR = 4
SEEDVR2_7B_MLP_CHUNK = 8192
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk.
SEEDVR2_LATENT_CHANNELS = 16
SEEDVR2_COLOR_MEM_HEADROOM = 0.75
SEEDVR2_LAB_SCALE_MULTIPLIER = 13
SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path.
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57.
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5
BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16).
BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32).
BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11.
BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size).
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor.
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor).
BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling).
BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames).
BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency).
BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim).
ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864.
CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta).
CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa).
D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1).
D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn.
WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR).

1357
comfy/ldm/seedvr/model.py Normal file

File diff suppressed because it is too large Load Diff

1613
comfy/ldm/seedvr/vae.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -55,6 +55,7 @@ import comfy.ldm.pixeldit.model
import comfy.ldm.pixeldit.pid import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2 import comfy.ldm.omnigen.omnigen2
import comfy.ldm.seedvr.model
import comfy.ldm.boogu.model import comfy.ldm.boogu.model
import comfy.ldm.qwen_image.model import comfy.ldm.qwen_image.model
import comfy.ldm.ideogram4.model import comfy.ldm.ideogram4.model
@ -932,6 +933,17 @@ class HunyuanDiT(BaseModel):
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]])) out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out return out
class SeedVR2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.seedvr.model.NaDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
condition = kwargs.get("condition", None)
if condition is not None:
out["condition"] = comfy.conds.CONDRegular(condition)
return out
class PixArt(BaseModel): class PixArt(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None): def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS) super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)

View File

@ -598,6 +598,44 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config return dit_config
seedvr2_7b_separate_key = "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)
if seedvr2_7b_separate_key in state_dict_keys and state_dict[seedvr2_7b_separate_key].shape[1] == 3072: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
# This checkpoint uses separate vid/txt MMModule keys in every block.
dit_config["mm_layers"] = 36
dit_config["norm_eps"] = 1e-5
dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "normal"
return dit_config
if "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
# This checkpoint uses shared all.* MMModule keys after the initial blocks.
dit_config["mm_layers"] = 10
dit_config["norm_eps"] = 1e-5
dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "swiglu"
return dit_config
if "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 2560
dit_config["heads"] = 20
dit_config["num_layers"] = 32
dit_config["norm_eps"] = 1.0e-05
dit_config["mlp_type"] = "swiglu"
dit_config["vid_out_norm"] = True
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {} dit_config = {}
dit_config["image_model"] = "wan2.1" dit_config["image_model"] = "wan2.1"
@ -1118,10 +1156,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
unet_config["heatmap_head"] = True unet_config["heatmap_head"] = True
return unet_config return unet_config
def normalize_seedvr2_unet_config(unet_config):
if unet_config.get("image_model") != "seedvr2" or "num_heads" not in unet_config:
return unet_config
def model_config_from_unet_config(unet_config, state_dict=None): unet_config = dict(unet_config)
num_heads = unet_config.pop("num_heads")
if "heads" in unet_config and unet_config["heads"] != num_heads:
raise ValueError(
f"SeedVR2 config has conflicting heads={unet_config['heads']} and num_heads={num_heads}."
)
unet_config["heads"] = num_heads
return unet_config
def model_config_from_unet_config(unet_config, state_dict=None, unet_key_prefix=""):
unet_config = normalize_seedvr2_unet_config(unet_config)
for model_config in comfy.supported_models.models: for model_config in comfy.supported_models.models:
if model_config.matches(unet_config, state_dict): if model_config.matches(unet_config, state_dict, unet_key_prefix=unet_key_prefix):
return model_config(unet_config) return model_config(unet_config)
logging.error("no match {}".format(unet_config)) logging.error("no match {}".format(unet_config))
@ -1131,7 +1183,7 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata) unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
if unet_config is None: if unet_config is None:
return None return None
model_config = model_config_from_unet_config(unet_config, state_dict) model_config = model_config_from_unet_config(unet_config, state_dict, unet_key_prefix)
if model_config is None and use_base_if_no_match: if model_config is None and use_base_if_no_match:
model_config = comfy.supported_models_base.BASE(unet_config) model_config = comfy.supported_models_base.BASE(unet_config)

View File

@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2 import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae import comfy.ldm.hunyuan3d.vae
import comfy.ldm.seedvr.vae
import comfy.ldm.triposplat.vae import comfy.ldm.triposplat.vae
import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae import comfy.ldm.cogvideo.vae
@ -470,7 +471,8 @@ class CLIP:
class VAE: class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd
if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd) sd = diffusers_convert.convert_vae_state_dict(sd)
if model_management.is_amd(): if model_management.is_amd():
@ -497,6 +499,8 @@ class VAE:
self.upscale_index_formula = None self.upscale_index_formula = None
self.extra_1d_channel = None self.extra_1d_channel = None
self.crop_input = True self.crop_input = True
self.handles_tiling = False
self.format_encoded = None
self.audio_sample_rate = 44100 self.audio_sample_rate = 44100
@ -543,6 +547,22 @@ class VAE:
self.first_stage_model = StageC_coder() self.first_stage_model = StageC_coder()
self.downscale_ratio = 32 self.downscale_ratio = 32
self.latent_channels = 16 self.latent_channels = 16
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
self.latent_channels = comfy.ldm.seedvr.vae.SEEDVR2_LATENT_CHANNELS
self.latent_dim = 3
self.disable_offload = True
self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape)
self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.handles_tiling = True
self.format_encoded = self.first_stage_model.comfy_format_encoded
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.process_input = lambda image: image * 2.0 - 1.0
self.crop_input = False
elif "decoder.conv_in.weight" in sd: elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64: if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
@ -1009,6 +1029,10 @@ class VAE:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def _decode_tiled_owned(self, samples, **kwargs):
out = self.first_stage_model.decode_tiled(samples.to(self.vae_dtype).to(self.device), **kwargs)
return self.process_output(out.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@ -1045,6 +1069,25 @@ class VAE:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def _encode_tiled_owned(self, pixel_samples, **kwargs):
x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device)
out = self.first_stage_model.encode_tiled(x, **kwargs)
return out.to(device=self.output_device, dtype=self.vae_output_dtype())
def _owned_tiled_args(self, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
args = {}
if tile_x is not None:
args["tile_x"] = tile_x
if tile_y is not None:
args["tile_y"] = tile_y
if overlap is not None:
args["overlap"] = overlap
if tile_t is not None:
args["tile_t"] = tile_t
if overlap_t is not None:
args["overlap_t"] = overlap_t
return args
def decode(self, samples_in, vae_options={}): def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid() self.throw_exception_if_invalid()
pixel_samples = None pixel_samples = None
@ -1092,11 +1135,19 @@ class VAE:
if dims == 1 or self.extra_1d_channel is not None: if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in) pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2: elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in) if self.handles_tiling:
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
else:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3: elif dims == 3:
tile = 256 // self.spacial_compression_decode() tile = 256 // self.spacial_compression_decode()
overlap = tile // 4 overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) if self.handles_tiling:
pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
else:
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples return pixel_samples
@ -1115,7 +1166,9 @@ class VAE:
args["overlap"] = overlap args["overlap"] = overlap
with model_management.cuda_device_context(self.device): with model_management.cuda_device_context(self.device):
if dims == 1 or self.extra_1d_channel is not None: if self.handles_tiling and dims in (2, 3):
output = self._decode_tiled_owned(samples, **self._owned_tiled_args(tile_x, tile_y, overlap, tile_t, overlap_t))
elif dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y") args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args) output = self.decode_tiled_1d(samples, **args)
elif dims == 2: elif dims == 2:
@ -1176,12 +1229,17 @@ class VAE:
if self.latent_dim == 3: if self.latent_dim == 3:
tile = 256 tile = 256
overlap = tile // 4 overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) if self.handles_tiling:
samples = self._encode_tiled_owned(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap)
else:
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1 or self.extra_1d_channel is not None: elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples) samples = self.encode_tiled_1d(pixel_samples)
else: else:
samples = self.encode_tiled_(pixel_samples) samples = self.encode_tiled_(pixel_samples)
if self.format_encoded is not None:
samples = self.format_encoded(samples)
return samples return samples
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
@ -1189,7 +1247,7 @@ class VAE:
pixel_samples = self.vae_encode_crop_pixels(pixel_samples) pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
dims = self.latent_dim dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1) pixel_samples = pixel_samples.movedim(-1, 1)
if dims == 3: if dims == 3 and pixel_samples.ndim < 5:
if not self.not_video: if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
else: else:
@ -1213,21 +1271,27 @@ class VAE:
elif dims == 2: elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args) samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3: elif dims == 3:
if tile_t is not None: if self.handles_tiling:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) samples = self._encode_tiled_owned(pixel_samples, **self._owned_tiled_args(tile_x, tile_y, overlap, tile_t, overlap_t))
else: else:
tile_t_latent = 9999 if tile_t is not None:
args["tile_t"] = self.upscale_ratio[0](tile_t_latent) tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
if overlap_t is None: spatial_overlap = overlap if overlap is not None else 64
args["overlap"] = (1, overlap, overlap) if overlap_t is None:
else: args["overlap"] = (1, spatial_overlap, spatial_overlap)
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) else:
maximum = pixel_samples.shape[2] args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap)
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
if self.format_encoded is not None:
samples = self.format_encoded(samples)
return samples return samples
def get_sd(self): def get_sd(self):
@ -1890,7 +1954,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else: else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype, device=load_device)
if model_config.clip_vision_prefix is not None: if model_config.clip_vision_prefix is not None:
if output_clipvision: if output_clipvision:
@ -2031,7 +2095,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else: else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype, device=load_device)
if custom_operations is not None: if custom_operations is not None:
model_config.custom_operations = custom_operations model_config.custom_operations = custom_operations

View File

@ -1685,6 +1685,40 @@ class Chroma(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
class SeedVR2(supported_models_base.BASE):
unet_config = {
"image_model": "seedvr2"
}
unet_extra_config = {}
required_keys = {
"{}positive_conditioning",
"{}negative_conditioning",
}
latent_format = comfy.latent_formats.SeedVR2
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
sampling_settings = {
"shift": 1.0,
}
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
if (
dtype == torch.float16
and manual_cast_dtype is None
and comfy.model_management.should_use_bf16(device)
):
manual_cast_dtype = torch.bfloat16
super().set_inference_dtype(dtype, manual_cast_dtype, device=device)
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SeedVR2(self, device=device)
return out
def clip_target(self, state_dict={}):
return None
class ChromaRadiance(Chroma): class ChromaRadiance(Chroma):
unet_config = { unet_config = {
"image_model": "chroma_radiance", "image_model": "chroma_radiance",
@ -2348,6 +2382,7 @@ models = [
HiDream, HiDream,
HiDreamO1, HiDreamO1,
Chroma, Chroma,
SeedVR2,
ChromaRadiance, ChromaRadiance,
ACEStep, ACEStep,
ACEStep15, ACEStep15,

View File

@ -54,13 +54,13 @@ class BASE:
optimizations = {"fp8": False} optimizations = {"fp8": False}
@classmethod @classmethod
def matches(s, unet_config, state_dict=None): def matches(s, unet_config, state_dict=None, unet_key_prefix=""):
for k in s.unet_config: for k in s.unet_config:
if k not in unet_config or s.unet_config[k] != unet_config[k]: if k not in unet_config or s.unet_config[k] != unet_config[k]:
return False return False
if state_dict is not None: if state_dict is not None:
for k in s.required_keys: for k in s.required_keys:
if k not in state_dict: if k.format(unet_key_prefix) not in state_dict:
return False return False
return True return True
@ -115,7 +115,7 @@ class BASE:
replace_prefix = {"": self.vae_key_prefix[0]} replace_prefix = {"": self.vae_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix) return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def set_inference_dtype(self, dtype, manual_cast_dtype): def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
self.unet_config['dtype'] = dtype self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype self.manual_cast_dtype = manual_cast_dtype

View File

@ -0,0 +1,419 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import torch
import comfy.model_management
from comfy.ldm.seedvr.color_fix import (
adain_color_transfer,
lab_color_transfer,
wavelet_color_transfer,
)
from comfy.ldm.seedvr.constants import (
SEEDVR2_ADAIN_SCALE_MULTIPLIER,
SEEDVR2_COLOR_MEM_HEADROOM,
SEEDVR2_DTYPE_BYTES_FLOOR,
SEEDVR2_LAB_SCALE_MULTIPLIER,
SEEDVR2_LATENT_CHANNELS,
SEEDVR2_OOM_BACKOFF_DIVISOR,
SEEDVR2_WAVELET_SCALE_MULTIPLIER,
)
from torchvision.transforms import functional as TVF
from torchvision.transforms.functional import InterpolationMode
_SEEDVR2_INVALID_MODEL_MSG_PREFIX = "SeedVR2Conditioning: model object does not match expected SeedVR2 structure"
_ATTR_MISSING = object()
def _resolve_seedvr2_diffusion_model(model):
inner = getattr(model, "model", _ATTR_MISSING)
if inner is _ATTR_MISSING:
raise RuntimeError(
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input has no 'model' attribute "
f"(got type {type(model).__name__})."
)
if inner is None:
raise RuntimeError(
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input.model is None "
f"(input type {type(model).__name__})."
)
diffusion_model = getattr(inner, "diffusion_model", _ATTR_MISSING)
if diffusion_model is _ATTR_MISSING:
raise RuntimeError(
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model' has no "
f"'diffusion_model' attribute (got type {type(inner).__name__})."
)
if diffusion_model is None:
raise RuntimeError(
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model.diffusion_model' "
f"is None (model.model type {type(inner).__name__})."
)
return diffusion_model
def div_pad(image, factor):
height_factor, width_factor = factor
height, width = image.shape[-2:]
pad_height = (height_factor - (height % height_factor)) % height_factor
pad_width = (width_factor - (width % width_factor)) % width_factor
if pad_height == 0 and pad_width == 0:
return image
padding = (0, pad_width, 0, pad_height)
return torch.nn.functional.pad(image, padding, mode='constant', value=0.0)
def cut_videos(videos):
t = videos.size(1)
if t < 1:
raise ValueError("SeedVR2Preprocess expected at least one frame.")
if t == 1:
return videos
if t <= 4:
padding = videos[:, -1:].repeat(1, 4 - t + 1, 1, 1, 1)
return torch.cat([videos, padding], dim=1)
if (t - 1) % 4 == 0:
return videos
padding = videos[:, -1:].repeat(1, 4 - ((t - 1) % 4), 1, 1, 1)
videos = torch.cat([videos, padding], dim=1)
if (videos.size(1) - 1) % 4 != 0:
raise ValueError(f"SeedVR2Preprocess failed to pad video length to 4n+1; got {videos.size(1)} frames.")
return videos
def _seedvr2_input_shorter_edge(images, node_name):
if images.dim() == 4:
return min(images.shape[1], images.shape[2])
if images.dim() == 5:
return min(images.shape[2], images.shape[3])
raise ValueError(
f"{node_name}: expected 4-D or 5-D IMAGE tensor, "
f"got shape {tuple(images.shape)}"
)
def _seedvr2_pad(images, upscaled_shorter_edge, node_name):
if upscaled_shorter_edge < 2:
raise ValueError(
f"{node_name}: input shorter edge must be at least 2 pixels; "
f"got {upscaled_shorter_edge}."
)
if images.shape[-1] > 3:
images = images[..., :3]
if images.dim() == 4:
# Comfy video components arrive as a 4-D IMAGE frame sequence:
# (frames, H, W, C). SeedVR2 consumes that as one video.
images = images.unsqueeze(0)
elif images.dim() != 5:
raise ValueError(
f"{node_name}: expected 4-D or 5-D IMAGE tensor, "
f"got shape {tuple(images.shape)}"
)
images = images.permute(0, 1, 4, 2, 3)
b, t, c, h, w = images.shape
images = images.reshape(b * t, c, h, w)
images = torch.clamp(images, 0.0, 1.0)
images = div_pad(images, (16, 16))
_, _, new_h, new_w = images.shape
images = images.reshape(b, t, c, new_h, new_w)
images = cut_videos(images)
images_bthwc = images.permute(0, 1, 3, 4, 2).contiguous()
return io.NodeOutput(images_bthwc)
class SeedVR2Preprocess(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SeedVR2Preprocess",
display_name="Pre-Process SeedVR2 Input",
category="image/upscaling",
description="Pad a resized image for SeedVR2 model. Alpha channel is dropped. The node Post-Process SeedVR2 Output re-applies it from the original resized image.",
search_aliases=["seedvr2", "upscale", "video upscale", "pad", "preprocess"],
inputs=[
io.Image.Input("resized_images", tooltip="The resized image to process."),
],
outputs=[
io.Image.Output("images", tooltip="The padded image for VAE encoding."),
]
)
@classmethod
def execute(cls, resized_images):
upscaled_shorter_edge = _seedvr2_input_shorter_edge(resized_images, "SeedVR2Preprocess")
return _seedvr2_pad(
resized_images, upscaled_shorter_edge, "SeedVR2Preprocess",
)
class SeedVR2PostProcessing(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SeedVR2PostProcessing",
display_name="Post-Process SeedVR2 Output",
category="image/upscaling",
description="Align the generated image with the original resized image and apply color correction.",
search_aliases=["seedvr2", "upscale", "color correction", "color match", "postprocess"],
inputs=[
io.Image.Input("images", tooltip="The generated image to process."),
io.Image.Input("original_resized_images", tooltip="The original resized image before pre-processing, used as reference."),
io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="Method to match the generated image colors to the original image. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."),
],
outputs=[io.Image.Output(display_name="images", tooltip="The aligned, color-corrected image.")],
)
@classmethod
def execute(cls, images, original_resized_images, color_correction_method):
alpha_input = None
if original_resized_images.shape[-1] == 4:
alpha_input = original_resized_images[..., 3:4]
original_resized_images = original_resized_images[..., :3]
decoded_5d, decoded_was_4d = cls._as_bthwc(images)
reference_full, _ = cls._as_bthwc(original_resized_images)
decoded_5d = cls._restore_reference_batch_time(decoded_5d, reference_full)
b = min(decoded_5d.shape[0], reference_full.shape[0])
t = min(decoded_5d.shape[1], reference_full.shape[1])
reference_h = reference_full.shape[2]
reference_w = reference_full.shape[3]
decoded_5d = decoded_5d[:b, :t, :, :, :]
target_h = min(decoded_5d.shape[2], reference_h)
target_w = min(decoded_5d.shape[3], reference_w)
decoded_5d = decoded_5d[:, :, :target_h, :target_w, :]
if color_correction_method in ("lab", "wavelet", "adain"):
reference_5d = reference_full[:b, :t, :, :, :]
reference_5d = cls._resize_reference(reference_5d, target_h, target_w)
output_device = decoded_5d.device
decoded_raw = cls._to_seedvr2_raw(decoded_5d)
reference_raw = cls._to_seedvr2_raw(reference_5d)
decoded_flat = decoded_raw.permute(0, 1, 4, 2, 3).reshape(b * t, decoded_raw.shape[4], target_h, target_w)
reference_flat = reference_raw.permute(0, 1, 4, 2, 3).reshape(b * t, reference_raw.shape[4], target_h, target_w)
output = cls._color_transfer_chunked(
decoded_flat, reference_flat, output_device, color_correction_method,
)
output = output.reshape(b, t, output.shape[1], output.shape[2], output.shape[3]).permute(0, 1, 3, 4, 2)
output = output.add(1.0).div(2.0).clamp(0.0, 1.0)
elif color_correction_method == "none":
output = decoded_5d
else:
raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
if alpha_input is not None:
alpha_5d, _ = cls._as_bthwc(alpha_input)
alpha_5d = alpha_5d[:output.shape[0], :output.shape[1], :output.shape[2], :output.shape[3], :]
output = torch.cat([output, alpha_5d.to(dtype=output.dtype, device=output.device)], dim=-1)
h2 = output.shape[-3] - (output.shape[-3] % 2)
w2 = output.shape[-2] - (output.shape[-2] % 2)
output = output[:, :, :h2, :w2, :]
if decoded_was_4d:
output = output.reshape(-1, output.shape[-3], output.shape[-2], output.shape[-1])
return io.NodeOutput(output)
@staticmethod
def _as_bthwc(images):
if images.ndim == 4:
return images.unsqueeze(0), True
if images.ndim == 5:
return images, False
raise ValueError(
f"SeedVR2PostProcessing: expected 4-D or 5-D IMAGE tensor, got shape {tuple(images.shape)}"
)
@staticmethod
def _restore_reference_batch_time(decoded, reference):
if decoded.shape[0] != 1:
return decoded
ref_b, ref_t = reference.shape[:2]
if ref_b < 1 or decoded.shape[1] % ref_b != 0:
return decoded
decoded_t = decoded.shape[1] // ref_b
if decoded_t < ref_t:
return decoded
return decoded.reshape(ref_b, decoded_t, decoded.shape[2], decoded.shape[3], decoded.shape[4])
@staticmethod
def _to_seedvr2_raw(images):
return images.mul(2.0).sub(1.0)
@staticmethod
def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn):
color_device = comfy.model_management.vae_device()
decoded_flat = decoded_flat.to(device=color_device)
reference_flat = reference_flat.to(device=color_device)
output = transfer_fn(decoded_flat, reference_flat)
return output.to(device=output_device)
@staticmethod
def _lab_color_transfer_on_vae_device(decoded_flat, reference_flat, output_device):
color_device = comfy.model_management.vae_device()
result = None
for start in range(decoded_flat.shape[0]):
decoded_frame = decoded_flat[start:start + 1].to(device=color_device).clone()
reference_frame = reference_flat[start:start + 1].to(device=color_device).clone()
output = lab_color_transfer(decoded_frame, reference_frame).to(device=output_device)
if result is None:
result = torch.empty(
(decoded_flat.shape[0],) + tuple(output.shape[1:]),
device=output_device,
dtype=output.dtype,
)
result[start:start + 1].copy_(output)
if result is None:
raise ValueError("SeedVR2PostProcessing: LAB color correction requires at least one frame.")
return result
@classmethod
def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method):
chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method)
while True:
try:
return cls._run_color_transfer_chunks(
decoded_flat, reference_flat, output_device, color_correction_method, chunk_size,
)
except Exception as e:
comfy.model_management.raise_non_oom(e)
if chunk_size <= 1:
raise RuntimeError(
"SeedVR2PostProcessing: color correction OOM at one frame; "
f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}."
) from e
chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR)
@classmethod
def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size):
result = None
for start in range(0, decoded_flat.shape[0], chunk_size):
end = min(start + chunk_size, decoded_flat.shape[0])
decoded_chunk = decoded_flat[start:end]
reference_chunk = reference_flat[start:end]
if color_correction_method == "lab":
output = cls._lab_color_transfer_on_vae_device(decoded_chunk, reference_chunk, output_device)
elif color_correction_method == "wavelet":
output = cls._color_transfer_on_vae_device(
decoded_chunk, reference_chunk, output_device, wavelet_color_transfer,
)
else:
output = cls._color_transfer_on_vae_device(
decoded_chunk, reference_chunk, output_device, adain_color_transfer,
)
if result is None:
result = torch.empty(
(decoded_flat.shape[0],) + tuple(output.shape[1:]),
device=output_device,
dtype=output.dtype,
)
result[start:end].copy_(output)
if result is None:
raise ValueError("SeedVR2PostProcessing: color correction requires at least one frame.")
return result
@classmethod
def _estimate_color_correction_chunk_size(cls, decoded_flat, color_correction_method):
multiplier = cls._color_correction_memory_multiplier(color_correction_method)
frames = decoded_flat.shape[0]
_, channels, height, width = decoded_flat.shape
dtype_bytes = max(decoded_flat.element_size(), SEEDVR2_DTYPE_BYTES_FLOOR)
bytes_per_frame = height * width * channels * dtype_bytes * multiplier
if bytes_per_frame <= 0:
return frames
color_device = comfy.model_management.vae_device()
free_memory = comfy.model_management.get_free_memory(color_device)
chunk_size = int((free_memory * SEEDVR2_COLOR_MEM_HEADROOM) // bytes_per_frame)
return max(1, min(frames, chunk_size))
@staticmethod
def _color_correction_memory_multiplier(color_correction_method):
if color_correction_method == "lab":
return SEEDVR2_LAB_SCALE_MULTIPLIER
if color_correction_method == "wavelet":
return SEEDVR2_WAVELET_SCALE_MULTIPLIER
if color_correction_method == "adain":
return SEEDVR2_ADAIN_SCALE_MULTIPLIER
raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
@staticmethod
def _resize_reference(reference, height, width):
if reference.shape[2] == height and reference.shape[3] == width:
return reference
b, t = reference.shape[:2]
reference_flat = reference.permute(0, 1, 4, 2, 3).reshape(b * t, reference.shape[4], reference.shape[2], reference.shape[3])
resized = TVF.resize(
reference_flat,
size=(height, width),
interpolation=InterpolationMode.BICUBIC,
antialias=not (isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"),
)
return resized.reshape(b, t, resized.shape[1], height, width).permute(0, 1, 3, 4, 2)
class SeedVR2Conditioning(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SeedVR2Conditioning",
display_name="Apply SeedVR2 Conditioning",
category="conditioning",
description="Build SeedVR2 positive/negative conditioning from a VAE latent.",
search_aliases=["seedvr2", "upscale", "conditioning"],
inputs=[
io.Model.Input("model", tooltip="The SeedVR2 model."),
io.Latent.Input("vae_conditioning", display_name="latent"),
],
outputs=[
io.Conditioning.Output(display_name="positive", tooltip="The positive conditioning for sampling."),
io.Conditioning.Output(display_name="negative", tooltip="The negative conditioning for sampling."),
],
)
@classmethod
def execute(cls, model, vae_conditioning) -> io.NodeOutput:
vae_conditioning = vae_conditioning["samples"]
if vae_conditioning.ndim != 5:
raise ValueError(
"SeedVR2Conditioning expects a 5-D VAE latent in Comfy "
f"channel-first layout; got shape {tuple(vae_conditioning.shape)}."
)
if vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS:
if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS:
raise ValueError(
"SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy "
f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
f"got channel-last shape {tuple(vae_conditioning.shape)}."
)
raise ValueError(
"SeedVR2Conditioning expects SeedVR2 VAE latents with "
f"{SEEDVR2_LATENT_CHANNELS} channels; got shape {tuple(vae_conditioning.shape)}."
)
vae_conditioning = vae_conditioning.movedim(1, -1).contiguous()
model = _resolve_seedvr2_diffusion_model(model)
pos_cond = model.positive_conditioning
neg_cond = model.negative_conditioning
mask = vae_conditioning.new_ones(vae_conditioning.shape[:-1] + (1,))
condition = torch.cat((vae_conditioning, mask), dim=-1)
condition = condition.movedim(-1, 1)
negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
return io.NodeOutput(positive, negative)
class SeedVRExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
SeedVR2Conditioning,
SeedVR2Preprocess,
SeedVR2PostProcessing,
]
async def comfy_entrypoint() -> SeedVRExtension:
return SeedVRExtension()

View File

@ -2458,6 +2458,7 @@ async def init_builtin_extra_nodes():
"nodes_camera_trajectory.py", "nodes_camera_trajectory.py",
"nodes_edit_model.py", "nodes_edit_model.py",
"nodes_tcfg.py", "nodes_tcfg.py",
"nodes_seedvr.py",
"nodes_context_windows.py", "nodes_context_windows.py",
"nodes_qwen.py", "nodes_qwen.py",
"nodes_boogu.py", "nodes_boogu.py",

View File

@ -0,0 +1,186 @@
"""SeedVR2 conditioning node regression tests."""
import importlib
import sys
from unittest.mock import MagicMock
import pytest
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
from comfy.ldm.seedvr.constants import SEEDVR2_LATENT_CHANNELS
if not torch.cuda.is_available():
cli_args.cpu = True
_SENTINEL = object()
_TARGETS = (
("comfy.model_management", "comfy"),
("comfy_extras.nodes_seedvr", "comfy_extras"),
)
def _import_nodes_seedvr_isolated():
"""Import comfy_extras.nodes_seedvr with comfy.model_management mocked."""
priors = []
for mod_name, parent_name in _TARGETS:
prior_mod = sys.modules.get(mod_name, _SENTINEL)
parent = sys.modules.get(parent_name)
attr = mod_name.split(".")[-1]
prior_attr = (
getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL
)
priors.append((mod_name, parent_name, attr, prior_mod, prior_attr))
mock_mm = MagicMock()
for fn in (
"xformers_enabled", "xformers_enabled_vae",
"pytorch_attention_enabled", "pytorch_attention_enabled_vae",
"sage_attention_enabled", "flash_attention_enabled",
"is_intel_xpu",
):
getattr(mock_mm, fn).return_value = False
tv = torch.version.__version__.split(".")
mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1]))
mock_mm.WINDOWS = False
sys.modules["comfy.model_management"] = mock_mm
if sys.modules.get("comfy") is None:
importlib.import_module("comfy")
comfy_pkg = sys.modules.get("comfy")
if comfy_pkg is not None:
setattr(comfy_pkg, "model_management", mock_mm)
nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or (
importlib.import_module("comfy_extras.nodes_seedvr")
)
def _restore():
for mod_name, parent_name, attr, prior_mod, prior_attr in priors:
if prior_mod is _SENTINEL:
sys.modules.pop(mod_name, None)
else:
sys.modules[mod_name] = prior_mod
parent = sys.modules.get(parent_name)
if parent is None:
continue
if prior_attr is _SENTINEL:
if hasattr(parent, attr):
delattr(parent, attr)
else:
setattr(parent, attr, prior_attr)
return nodes_seedvr, _restore
class _Rope(nn.Module):
def __init__(self):
super().__init__()
self.freqs = nn.Parameter(torch.zeros(4))
class _Block(nn.Module):
def __init__(self):
super().__init__()
self.rope = _Rope()
class _DiffusionModel(nn.Module):
def __init__(self, n_blocks=3, conditioning_dtype=torch.float32):
super().__init__()
self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])
self.register_buffer("positive_conditioning", torch.ones((2, 4), dtype=conditioning_dtype))
self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype))
class _ModelInner:
def __init__(self, diffusion_model):
self.diffusion_model = diffusion_model
class _ModelPatcher:
def __init__(self, diffusion_model):
self.model = _ModelInner(diffusion_model)
def test_seedvr2_conditioning_schema_exposes_conditioning_outputs():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
schema = nodes_seedvr.SeedVR2Conditioning.define_schema()
assert [input_item.id for input_item in schema.inputs] == [
"model",
"vae_conditioning",
]
assert schema.inputs[1].display_name == "latent"
assert [output.display_name for output in schema.outputs] == [
"positive",
"negative",
]
finally:
restore()
def test_seedvr2_conditioning_rejects_wrong_latent_channels():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
patcher = _ModelPatcher(_DiffusionModel())
vae_conditioning = {"samples": torch.zeros(1, 8, 2, 2, 2)}
with pytest.raises(ValueError, match=f"{SEEDVR2_LATENT_CHANNELS} channels"):
nodes_seedvr.SeedVR2Conditioning.execute(patcher, vae_conditioning)
finally:
restore()
def test_seedvr2_conditioning_returns_conditioning_deterministically():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
diffusion_model = _DiffusionModel()
patcher = _ModelPatcher(diffusion_model)
samples = torch.arange(
1,
1 + SEEDVR2_LATENT_CHANNELS * 3 * 2 * 2,
dtype=torch.float32,
).reshape(1, SEEDVR2_LATENT_CHANNELS, 3, 2, 2)
vae_conditioning = {"samples": samples}
first_positive, first_negative = (
nodes_seedvr.SeedVR2Conditioning.execute(
patcher,
vae_conditioning,
)
)
second_positive, second_negative = (
nodes_seedvr.SeedVR2Conditioning.execute(
patcher,
vae_conditioning,
)
)
channel_last = samples.movedim(1, -1).contiguous()
expected_condition = torch.cat(
[
channel_last,
torch.ones((*channel_last.shape[:-1], 1)),
],
dim=-1,
).movedim(-1, 1)
assert torch.equal(
first_positive[0][1]["condition"],
expected_condition,
)
assert torch.equal(
second_positive[0][1]["condition"],
expected_condition,
)
assert torch.equal(
first_negative[0][1]["condition"],
expected_condition,
)
assert torch.equal(
second_negative[0][1]["condition"],
expected_condition,
)
finally:
restore()

View File

@ -0,0 +1,55 @@
import importlib
import inspect
import sys
from unittest.mock import MagicMock, patch
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
def test_seedvr_node_signature_matches_schema():
mock_mm = MagicMock()
mock_mm.xformers_enabled.return_value = False
mock_mm.xformers_enabled_vae.return_value = False
mock_mm.sage_attention_enabled.return_value = False
mock_mm.flash_attention_enabled.return_value = False
sentinel = object()
prior_cpu = cli_args.cpu
cli_args.cpu = True
prior_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel)
comfy_pkg = sys.modules.get("comfy")
prior_mm_attr = getattr(comfy_pkg, "model_management", sentinel) if comfy_pkg else sentinel
with patch.dict(sys.modules, {"comfy.model_management": mock_mm}):
if comfy_pkg is not None:
setattr(comfy_pkg, "model_management", mock_mm)
sys.modules.pop("comfy_extras.nodes_seedvr", None)
try:
nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning):
schema_ids = [i.id for i in node_cls.define_schema().inputs]
exec_params = [
p for p in inspect.signature(node_cls.execute).parameters.keys()
if p != "cls"
]
assert schema_ids == exec_params, (
f"{node_cls.__name__} schema/execute drift: "
f"schema_ids={schema_ids}, exec_params={exec_params}"
)
finally:
cli_args.cpu = prior_cpu
if prior_module is sentinel:
sys.modules.pop("comfy_extras.nodes_seedvr", None)
else:
sys.modules["comfy_extras.nodes_seedvr"] = prior_module
if comfy_pkg is not None:
if prior_mm_attr is sentinel:
if hasattr(comfy_pkg, "model_management"):
delattr(comfy_pkg, "model_management")
else:
setattr(comfy_pkg, "model_management", prior_mm_attr)

View File

@ -0,0 +1,51 @@
from unittest.mock import patch
import pytest
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
from comfy_extras import nodes_seedvr # noqa: E402
def _schema_ids(items):
return [item.id for item in items]
def test_seedvr2_post_processing_schema():
schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()
assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"]
assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"]
assert schema.inputs[2].default == "lab"
assert schema.outputs[0].get_io_type() == "IMAGE"
def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch):
decoded = torch.full((1, 3, 4, 4), 0.25)
reference = torch.full((1, 3, 4, 4), 0.75)
def _lab(content, style):
raise torch.cuda.OutOfMemoryError("CUDA out of memory")
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000)
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
with pytest.raises(RuntimeError) as excinfo:
nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
decoded, reference, torch.device("cpu"), "lab",
)
assert "color_correction_method=lab" in str(excinfo.value)
assert " method=lab" not in str(excinfo.value)
def test_seedvr2_post_processing_unknown_color_correction_method_raises():
decoded = torch.zeros(1, 2, 4, 4, 3)
original = torch.zeros(1, 2, 4, 4, 3)
with pytest.raises(ValueError) as excinfo:
nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus")
assert "color_correction_method" in str(excinfo.value)

View File

@ -2,7 +2,7 @@ from collections import defaultdict
import torch import torch
from comfy.model_detection import detect_unet_config, model_config_from_unet_config from comfy.model_detection import detect_unet_config, model_config_from_unet, model_config_from_unet_config
import comfy.supported_models import comfy.supported_models
@ -73,6 +73,34 @@ def _make_flux_schnell_comfyui_sd():
return sd return sd
def _make_seedvr2_7b_separate_mm_sd():
return {
"blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072),
"positive_conditioning": torch.empty(58, 5120),
"negative_conditioning": torch.empty(64, 5120),
}
def _make_seedvr2_7b_shared_mm_sd():
return {
"blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
"positive_conditioning": torch.empty(58, 5120),
"negative_conditioning": torch.empty(64, 5120),
}
def _make_seedvr2_3b_shared_mm_sd():
return {
"blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
"positive_conditioning": torch.empty(58, 5120),
"negative_conditioning": torch.empty(64, 5120),
}
def _add_model_diffusion_prefix(sd):
return {f"model.diffusion_model.{k}": v for k, v in sd.items()}
class TestModelDetection: class TestModelDetection:
"""Verify that first-match model detection selects the correct model """Verify that first-match model detection selects the correct model
based on list ordering and unet_config specificity.""" based on list ordering and unet_config specificity."""
@ -125,6 +153,70 @@ class TestModelDetection:
assert model_config is not None assert model_config is not None
assert type(model_config).__name__ == "FluxSchnell" assert type(model_config).__name__ == "FluxSchnell"
def test_seedvr2_7b_separate_mm_detection_config(self):
sd = _make_seedvr2_7b_separate_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 3072
assert unet_config["heads"] == 24
assert unet_config["num_layers"] == 36
assert unet_config["mm_layers"] == 36
assert unet_config["mlp_type"] == "normal"
assert unet_config["rope_type"] == "rope3d"
assert unet_config["rope_dim"] == 64
def test_seedvr2_7b_shared_mm_detection_config(self):
sd = _make_seedvr2_7b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 3072
assert unet_config["heads"] == 24
assert unet_config["num_layers"] == 36
assert unet_config["mm_layers"] == 10
assert unet_config["mlp_type"] == "swiglu"
assert unet_config["rope_type"] == "rope3d"
assert unet_config["rope_dim"] == 64
def test_seedvr2_3b_shared_mm_detection_config(self):
sd = _make_seedvr2_3b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 2560
assert unet_config["heads"] == 20
assert unet_config["num_layers"] == 32
assert unet_config["mlp_type"] == "swiglu"
def test_seedvr2_model_match_requires_conditioning_tensors(self):
sd = _make_seedvr2_7b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
assert type(model_config_from_unet_config(unet_config, sd)).__name__ == "SeedVR2"
del sd["positive_conditioning"]
assert model_config_from_unet_config(unet_config, sd) is None
def test_seedvr2_model_match_normalizes_num_heads(self):
sd = _make_seedvr2_7b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
unet_config["num_heads"] = unet_config.pop("heads")
model_config = model_config_from_unet_config(unet_config, sd)
assert type(model_config).__name__ == "SeedVR2"
assert model_config.unet_config["heads"] == 24
assert "num_heads" not in model_config.unet_config
def test_seedvr2_model_match_accepts_full_checkpoint_prefix(self):
sd = _add_model_diffusion_prefix(_make_seedvr2_7b_shared_mm_sd())
assert type(model_config_from_unet(sd, "model.diffusion_model.")).__name__ == "SeedVR2"
def test_unet_config_and_required_keys_combination_is_unique(self): def test_unet_config_and_required_keys_combination_is_unique(self):
"""Each model in the registry must have a unique combination of """Each model in the registry must have a unique combination of
``unet_config`` and ``required_keys``. If two models share the same ``unet_config`` and ``required_keys``. If two models share the same

View File

@ -0,0 +1,74 @@
"""Regression tests for the SeedVR2 VAE forward return contract."""
import pytest
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
from comfy.ldm.seedvr.vae import SEEDVR2_LATENT_CHANNELS, VideoAutoencoderKL # noqa: E402
_LATENT_SHAPE = (1, SEEDVR2_LATENT_CHANNELS, 2, 2, 2)
_DECODED_SHAPE = (1, 3, 5, 16, 16)
_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16)
_INPUT_DECODE_SHAPE = _LATENT_SHAPE
class _StubVAE(VideoAutoencoderKL):
def __init__(self):
nn.Module.__init__(self)
self._encode_out = torch.zeros(*_LATENT_SHAPE)
self._decode_out = torch.zeros(*_DECODED_SHAPE)
def encode(self, x, return_dict=True):
return self._encode_out
def decode_(self, z, return_dict=True):
return self._decode_out
def test_forward_encode_returns_tensor():
vae = _StubVAE()
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
result = vae.forward(x, mode="encode")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_LATENT_SHAPE)
def test_forward_decode_returns_tensor():
vae = _StubVAE()
z = torch.zeros(*_INPUT_DECODE_SHAPE)
result = vae.forward(z, mode="decode")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_DECODED_SHAPE)
class _TupleReturningStubVAE(VideoAutoencoderKL):
def __init__(self):
nn.Module.__init__(self)
self._encode_tensor = torch.zeros(*_LATENT_SHAPE)
self._decode_tensor = torch.zeros(*_DECODED_SHAPE)
def encode(self, x, return_dict=True):
return (self._encode_tensor,)
def decode_(self, z, return_dict=True):
return (self._decode_tensor,)
def test_forward_all_unwraps_one_tuple_at_each_step():
vae = _TupleReturningStubVAE()
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
result = vae.forward(x, mode="all")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_DECODED_SHAPE)
def test_forward_rejects_unknown_mode():
vae = _StubVAE()
with pytest.raises(ValueError, match="Unknown SeedVR2 VAE forward mode"):
vae.forward(torch.zeros(*_INPUT_ENCODE_SHAPE), mode="bogus")

View File

@ -0,0 +1,50 @@
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.sd
import comfy.supported_models
import comfy.ldm.seedvr.model as seedvr_model
import comfy.ldm.seedvr.vae as seedvr_vae
def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch):
bf16_device = object()
fp16_device = object()
monkeypatch.setattr(
comfy.supported_models.comfy.model_management,
"should_use_bf16",
lambda device=None: device is bf16_device,
)
bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device)
assert bf16_config.manual_cast_dtype is torch.bfloat16
fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device)
assert fp16_config.manual_cast_dtype is None
def test_seedvr2_text_conditioning_accepts_cfg1_single_branch():
context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2)
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0])
torch.testing.assert_close(txt, context.squeeze(0))
torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device))
def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer():
wrapper = seedvr_vae.VideoAutoencoderKLWrapper.__new__(seedvr_vae.VideoAutoencoderKLWrapper)
latent_channels = seedvr_vae.SEEDVR2_LATENT_CHANNELS
estimate = wrapper.comfy_memory_used_decode((1, latent_channels, 26, 120, 160))
old_estimate = latent_channels * 120 * 160 * (4 * 8 * 8) * 2
assert estimate == 101 * 960 * 1280 * 160
assert estimate > 15 * 1024 ** 3
assert estimate > old_estimate * 100

View File

@ -0,0 +1,169 @@
"""SeedVR2 internals regression tests."""
from __future__ import annotations
from unittest.mock import patch
import pytest
import torch
from comfy.cli_args import args
if not torch.cuda.is_available():
args.cpu = True
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
import comfy.ldm.modules.attention as attention # noqa: E402
import comfy.ops as comfy_ops # noqa: E402
from comfy.ldm.seedvr.vae import ( # noqa: E402
causal_norm_wrapper,
set_norm_limit,
)
from comfy.ldm.seedvr.attention import var_attention_optimized_split # noqa: E402
_NUM_CHANNELS = 8
_NUM_GROUPS = 4
_TENSOR_SHAPE = (1, 8, 2, 4, 4)
_GROUPNORM_SUBCLASSES = [
pytest.param(comfy_ops.disable_weight_init.GroupNorm, id="disable_weight_init"),
pytest.param(comfy_ops.manual_cast.GroupNorm, id="manual_cast"),
]
@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES)
def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
real_group_norm = vae_mod.F.group_norm
set_norm_limit(1e-9)
try:
gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS)
gn.eval()
forward_hook_calls = []
def _hook(module, inputs, output):
forward_hook_calls.append(tuple(inputs[0].shape))
spy_calls = []
def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs):
spy_calls.append({"num_groups": int(num_groups_arg)})
return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs)
handle = gn.register_forward_hook(_hook)
try:
with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy):
out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE))
finally:
handle.remove()
full_calls = len(forward_hook_calls)
chunked_calls = sum(1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS)
assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE
assert full_calls == 0, (
f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}"
)
assert chunked_calls > 0, (
f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}"
)
finally:
set_norm_limit(None)
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
dim = 8
heads = 2
head_dim = 4
attn = seedvr_model.NaSwinAttention(
vid_dim=dim,
txt_dim=dim,
heads=heads,
head_dim=head_dim,
qk_bias=False,
qk_norm=comfy_ops.disable_weight_init.RMSNorm,
qk_norm_eps=1e-6,
rope_type=None,
rope_dim=head_dim,
shared_weights=False,
window=(2, 1, 1),
window_method="720pwin_by_size_bysize",
version=True,
device="cpu",
dtype=torch.float32,
operations=comfy_ops.disable_weight_init,
)
generator = torch.Generator(device="cpu").manual_seed(11)
vid = torch.randn(8, dim, generator=generator)
txt = torch.randn(3, dim, generator=generator)
vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long)
txt_shape = torch.tensor([[3]], dtype=torch.long)
calls = []
def fake_optimized_var_attention(**kwargs):
calls.append(kwargs)
return kwargs["q"]
monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention)
vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True))
assert tuple(vid_out.shape) == (8, dim)
assert tuple(txt_out.shape) == (3, dim)
assert len(calls) == 1
call = calls[0]
assert tuple(call["q"].shape) == (14, heads, head_dim)
assert tuple(call["k"].shape) == (14, heads, head_dim)
assert tuple(call["v"].shape) == (14, heads, head_dim)
assert call["heads"] == heads
assert call["skip_reshape"] is True
assert call["skip_output_reshape"] is True
assert call["cu_seqlens_q"] == [0, 7, 14]
assert call["cu_seqlens_k"] == [0, 7, 14]
def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch):
heads = 2
head_dim = 3
q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim)
k = q + 100
v = q + 200
cu = [0, 2, 5]
calls = []
def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs):
calls.append(
{
"q_shape": tuple(q_arg.shape),
"k_shape": tuple(k_arg.shape),
"v_shape": tuple(v_arg.shape),
"heads": heads_arg,
"kwargs": kwargs,
}
)
return q_arg + v_arg
monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention)
out = var_attention_optimized_split(
q,
k,
v,
heads,
cu,
cu,
skip_reshape=True,
skip_output_reshape=True,
)
assert tuple(out.shape) == (5, heads, head_dim)
assert len(calls) == 2
assert calls[0]["q_shape"] == (1, heads, 2, head_dim)
assert calls[1]["q_shape"] == (1, heads, 3, head_dim)
assert all(call["heads"] == heads for call in calls)
assert all(call["kwargs"]["skip_reshape"] is True for call in calls)
assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls)
torch.testing.assert_close(out, q + v, rtol=0, atol=0)

View File

@ -0,0 +1,321 @@
"""SeedVR2 model, latent-format, and VAE graph regression tests."""
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
import torch
from torch import nn
from comfy.cli_args import args
if not torch.cuda.is_available():
args.cpu = True
import comfy # noqa: E402
import comfy.latent_formats # noqa: E402
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.model_management # noqa: E402
import comfy.ops as comfy_ops # noqa: E402
import comfy.sample # noqa: E402
import comfy.sd as sd_mod # noqa: E402
import nodes as nodes_mod # noqa: E402
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
_LATENT_CHANNELS = seedvr_vae_mod.SEEDVR2_LATENT_CHANNELS
def _make_standin(positive_conditioning):
class _StandIn(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"positive_conditioning", positive_conditioning
)
_resolve_text_conditioning = NaDiT._resolve_text_conditioning
return _StandIn()
class _StubModule(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]:
flags = []
class _Block(_StubModule):
def __init__(self, *args, **kwargs):
flags.append(kwargs["is_last_layer"])
super().__init__()
monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule)
monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule)
monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule)
monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block)
seedvr_model.NaDiT(
norm_eps=1e-5,
num_layers=4,
mlp_type="normal",
vid_dim=vid_dim,
txt_in_dim=txt_in_dim,
heads=24,
mm_layers=3,
operations=comfy_ops.disable_weight_init,
)
return flags
class _Model:
def __init__(self, latent_format):
self._latent_format = latent_format
def get_model_object(self, name):
assert name == "latent_format"
return self._latent_format
class _Patcher:
def get_free_memory(self, device):
return 1024 * 1024 * 1024
class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
def __init__(self, encoded):
nn.Module.__init__(self)
self.encoded = encoded
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.seen = []
def encode(self, x):
self.seen.append(tuple(x.shape))
return self.encoded.to(device=x.device, dtype=x.dtype)
class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
def __init__(self):
nn.Module.__init__(self)
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.calls = []
def decode(self, z, seedvr2_tiling=None):
self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
if z.ndim == 4:
b, tc, h, w = z.shape
t = tc // _LATENT_CHANNELS
else:
b, _, t, h, w = z.shape
return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
def test_seedvr2_wrapper_public_encode_returns_tensor(monkeypatch):
raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 2.0)
seen_shapes = []
def base_encode(self, x):
seen_shapes.append(tuple(x.shape))
return raw_latent.to(device=x.device, dtype=x.dtype)
monkeypatch.setattr(seedvr_vae_mod.VideoAutoencoderKL, "encode", base_encode)
vae = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(seedvr_vae_mod.VideoAutoencoderKLWrapper)
nn.Module.__init__(vae)
vae._dummy = nn.Parameter(torch.zeros((), dtype=torch.float32))
latent = vae.encode(torch.zeros(1, 3, 32, 40))
assert type(latent) is torch.Tensor
assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5)
assert seen_shapes == [(1, 3, 1, 32, 40)]
def test_seedvr2_wrapper_private_encode_helper_keeps_raw_latent(monkeypatch):
raw_latent = torch.full((1, _LATENT_CHANNELS, 1, 4, 5), 3.0)
def base_encode(self, x):
return raw_latent.to(device=x.device, dtype=x.dtype)
monkeypatch.setattr(seedvr_vae_mod.VideoAutoencoderKL, "encode", base_encode)
vae = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(seedvr_vae_mod.VideoAutoencoderKLWrapper)
nn.Module.__init__(vae)
vae._dummy = nn.Parameter(torch.zeros((), dtype=torch.float32))
latent, raw = vae._encode_with_raw_latent(torch.zeros(1, 3, 32, 40))
assert tuple(latent.shape) == (1, _LATENT_CHANNELS, 4, 5)
assert tuple(raw.shape) == (1, _LATENT_CHANNELS, 1, 4, 5)
assert torch.equal(raw, raw_latent)
def _make_vae(wrapper):
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = wrapper
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.latent_channels = _LATENT_CHANNELS
vae.latent_dim = 3
vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8)
vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
vae.output_channels = 3
vae.disable_offload = True
vae.extra_1d_channel = None
vae.crop_input = False
vae.not_video = False
vae.handles_tiling = isinstance(wrapper, seedvr_vae_mod.VideoAutoencoderKLWrapper)
vae.format_encoded = wrapper.comfy_format_encoded
vae.patcher = _Patcher()
vae.process_input = lambda image: image
vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0)
vae.vae_output_dtype = lambda: torch.float32
vae.memory_used_encode = lambda shape, dtype: 1
vae.memory_used_decode = lambda shape, dtype: 1
vae.throw_exception_if_invalid = lambda: None
vae.vae_encode_crop_pixels = lambda pixels: pixels
vae.spacial_compression_decode = lambda: 8
vae.temporal_compression_decode = lambda: 4
return vae
def test_missing_context_falls_back_to_positive_buffer():
pos_buffer = torch.full((58, 5120), 7.0)
standin = _make_standin(pos_buffer)
txt, txt_shape = standin._resolve_text_conditioning(None)
assert txt.shape == (58, 5120)
assert (txt == 7.0).all(), (
"fallback path must use the positive_conditioning buffer "
"verbatim, not a zero tensor"
)
assert txt_shape.shape == (1, 1)
assert txt_shape[0, 0].item() == 58
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [
False,
False,
False,
False,
]
def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
rope = seedvr_model.get_na_rope("rope3d", dim=64)
generator = torch.Generator(device="cpu").manual_seed(0)
q = torch.randn(4, 2, 128, generator=generator)
k = torch.randn(4, 2, 128, generator=generator)
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)
expected_q = seedvr_model._apply_seedvr2_rotary_emb(
freqs,
q.permute(1, 0, 2).float(),
).to(q.dtype).permute(1, 0, 2)
expected_k = seedvr_model._apply_seedvr2_rotary_emb(
freqs,
k.permute(1, 0, 2).float(),
).to(k.dtype).permute(1, 0, 2)
actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True))
torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0)
torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0)
def test_seedvr2_forward_requires_conditioning_latents():
model = NaDiT.__new__(NaDiT)
x = torch.zeros(1, _LATENT_CHANNELS, 1, 4, 5)
with pytest.raises(ValueError, match="requires conditioning latents"):
NaDiT.forward(model, x, timestep=torch.tensor([1.0]), context=None)
def test_seedvr2_latent_format_uses_native_video_latent_shape():
latent_format = comfy.latent_formats.SeedVR2()
latent_image = torch.zeros(1, 1, 4, 5)
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
assert latent_format.latent_channels == _LATENT_CHANNELS
assert latent_format.latent_dimensions == 3
assert fixed.shape == (1, _LATENT_CHANNELS, 1, 4, 5)
def test_seedvr2_model_requires_native_5d_latent():
latent = torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)
assert NaDiT._check_seedvr2_video_latent(latent, _LATENT_CHANNELS, "latent") is latent
with pytest.raises(ValueError, match="5-D native latent"):
NaDiT._check_seedvr2_video_latent(torch.zeros(1, _LATENT_CHANNELS * 2, 4, 5), _LATENT_CHANNELS, "latent")
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
encoded = torch.full((1, _LATENT_CHANNELS, 2, 4, 5), 2.0)
vae = _make_vae(_EncodeWrapper(encoded))
pixels = torch.zeros(1, 5, 32, 40, 3)
node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
node_latent = node_output["samples"]
assert set(node_output) == {"samples"}
assert tuple(node_latent.shape) == (1, _LATENT_CHANNELS, 2, 4, 5)
assert node_latent.dtype == torch.float32
assert node_latent.stride()[-1] == 1
assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR))
tiled = torch.full((1, _LATENT_CHANNELS, 2, 4, 5), 3.0)
monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
tiled_output = nodes_mod.VAEEncodeTiled().encode(
vae,
pixels,
tile_size=512,
overlap=64,
temporal_size=16,
temporal_overlap=4,
)[0]
tiled_latent = tiled_output["samples"]
assert set(tiled_output) == {"samples"}
assert tuple(tiled_latent.shape) == (1, _LATENT_CHANNELS, 2, 4, 5)
assert tiled_latent.dtype == torch.float32
assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * seedvr_vae_mod.BYTEDANCE_VAE_SCALING_FACTOR))
def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch):
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
vae = _make_vae(_DecodeWrapper())
nodes_mod.VAEDecodeTiled().decode(
vae,
{"samples": torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5)},
tile_size=512,
overlap=64,
temporal_size=16,
temporal_overlap=4,
)
# Spatial inputs flow through; temporal inputs are discarded — SeedVR2 owns
# temporal via the MemoryState causal cache, so VAEDecodeTiled's temporal
# knobs are no-ops at the wrapper.
assert vae.first_stage_model.calls == [
{
"shape": (1, _LATENT_CHANNELS, 2, 4, 5),
"seedvr2_tiling": {
"enable_tiling": True,
"tile_size": (512, 512),
"tile_overlap": (64, 64),
"temporal_size": 0,
"temporal_overlap": 0,
},
}
]

View File

@ -0,0 +1,94 @@
from unittest.mock import patch
import pytest
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
from comfy_extras import nodes_seedvr # noqa: E402
_LATENT_CHANNELS = vae_mod.SEEDVR2_LATENT_CHANNELS
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
vae_mod.VideoAutoencoderKLWrapper
)
nn.Module.__init__(wrapper)
return wrapper
def _fingerprint_decode_(self, z, return_dict=True):
b = int(z.shape[0])
t = int(z.shape[2])
h = int(z.shape[3])
w = int(z.shape[4])
out = torch.empty(b, 3, t, h * 8, w * 8)
for batch_idx in range(b):
out[batch_idx].fill_(float(batch_idx + 1))
return out
def _decode_with_patches(wrapper, z):
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_):
return wrapper.decode(z)
def test_decode_b2_t3_multi_frame_batch_unchanged():
wrapper = _make_wrapper()
out = _decode_with_patches(wrapper, torch.zeros(2, _LATENT_CHANNELS * 3, 2, 2))
assert tuple(out.shape) == (2, 3, 3, 16, 16)
class _Wrapper(vae_mod.VideoAutoencoderKLWrapper):
def __init__(self):
nn.Module.__init__(self)
self.calls = []
def parameters(self):
return iter([torch.nn.Parameter(torch.zeros(()))])
def _decode_stub(self, latent):
self.calls.append(tuple(latent.shape))
return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8)
def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state():
wrapper = _Wrapper()
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
out = wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 2, 4, 5))
assert tuple(out.shape) == (1, 3, 2, 32, 40)
assert wrapper.calls == [(1, _LATENT_CHANNELS, 2, 4, 5)]
def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents():
wrapper = _Wrapper()
with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"):
wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 4))
def _t_padded(t_in: int) -> int:
if t_in == 1:
return 1
if t_in <= 4:
return 5
if (t_in - 1) % 4 == 0:
return t_in
return t_in + (4 - ((t_in - 1) % 4))
@pytest.mark.parametrize("t_in", [1, 5, 9])
def test_t_padded_matches_cut_videos(t_in):
dummy = torch.zeros(1, t_in, 1, 1, 1)
assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in)

View File

@ -0,0 +1,382 @@
from contextlib import ExitStack
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.sd as sd_mod # noqa: E402
from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402
_LATENT_CHANNELS = seedvr_vae_mod.SEEDVR2_LATENT_CHANNELS
def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
class StubVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_latent_min_size = 2
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self.use_slicing = True
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.decode_min_sizes = []
self.memory_states = []
def decode_(self, t_chunk):
self.decode_min_sizes.append(self.slicing_latent_min_size)
return vae_mod.VideoAutoencoderKL.slicing_decode(self, t_chunk)
def _decode(self, z, memory_state=MemoryState.DISABLED, memory_cache=None):
self.memory_states.append(memory_state)
b, c, d, h, w = z.shape
return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype)
vae = StubVAEModel()
z = torch.zeros((1, _LATENT_CHANNELS, 5, 8, 8), dtype=torch.float32)
tiled_vae(
z,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=False,
)
assert vae.decode_min_sizes == [5]
assert vae.memory_states == [MemoryState.DISABLED]
assert vae.slicing_latent_min_size == 2
def test_zero_temporal_size_preserves_min_size_when_encode_raises():
class RaisingVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_sample_min_size = 4
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
def encode(self, t_chunk):
raise RuntimeError("simulated encode failure")
vae = RaisingVAEModel()
x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)
with pytest.raises(RuntimeError, match="simulated encode failure"):
tiled_vae(
x,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=True,
)
assert vae.slicing_sample_min_size == 4
def test_tiled_vae_encode_uses_tensor_return_without_indexing():
class TensorEncodeVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_sample_min_size = 4
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.calls = []
def encode(self, t_chunk):
self.calls.append(tuple(t_chunk.shape))
b, _, _, h, w = t_chunk.shape
return torch.ones((b, _LATENT_CHANNELS, 1, h // 8, w // 8), dtype=t_chunk.dtype)
vae = TensorEncodeVAEModel()
x = torch.zeros((2, 3, 1, 64, 64), dtype=torch.float32)
out = tiled_vae(
x,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=True,
)
assert vae.calls == [(2, 3, 1, 64, 64)]
assert tuple(out.shape) == (2, _LATENT_CHANNELS, 1, 8, 8)
def test_tiled_vae_preserves_input_dtype_on_single_tile():
class FloatOutputVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_sample_min_size = 4
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
def encode(self, t_chunk):
b, _, _, h, w = t_chunk.shape
return torch.ones((b, _LATENT_CHANNELS, 1, h // 8, w // 8), dtype=torch.float32)
out = tiled_vae(
torch.zeros((1, 3, 1, 64, 64), dtype=torch.float16),
FloatOutputVAEModel(),
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=True,
)
assert out.dtype == torch.float16
class _SlicingDecodeVAE(nn.Module):
def __init__(self, slicing_latent_min_size):
super().__init__()
self.slicing_latent_min_size = slicing_latent_min_size
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self.use_slicing = True
self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.decode_min_sizes = []
self.memory_states = []
def decode_(self, z):
self.decode_min_sizes.append(self.slicing_latent_min_size)
return vae_mod.VideoAutoencoderKL.slicing_decode(self, z)
def _decode(self, z, memory_state=MemoryState.DISABLED, memory_cache=None):
self.memory_states.append(memory_state)
x = z[:, :1].repeat(
1,
3,
1,
self.spatial_downsample_factor,
self.spatial_downsample_factor,
)
return x
def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
vae = _SlicingDecodeVAE(slicing_latent_min_size=2)
z = torch.arange(
_LATENT_CHANNELS * 5 * 8 * 8,
dtype=torch.float32,
).reshape(1, _LATENT_CHANNELS, 5, 8, 8)
tiled_vae(
z,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=12,
temporal_overlap=4,
encode=False,
)
assert vae.decode_min_sizes == [2]
assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE]
assert vae.slicing_latent_min_size == 2
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
vae_mod.VideoAutoencoderKLWrapper
)
nn.Module.__init__(wrapper)
seedvr2_tiling = {
"enable_tiling": True,
"tile_size": (64, 64),
"tile_overlap": (0, 0),
"temporal_size": 8,
"temporal_overlap": 7,
}
captured = {}
def _fake_tiled_vae(latent, model, **kwargs):
captured.update(kwargs)
return torch.zeros(1, 3, 1, 16, 16)
with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae):
wrapper.decode(torch.zeros(1, _LATENT_CHANNELS, 2, 2), seedvr2_tiling=seedvr2_tiling)
assert captured["temporal_overlap"] == 7
def _force_oom(*a, **k):
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
def _make_vae(first_stage_model, latent_channels, latent_dim):
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = first_stage_model
vae.patcher = MagicMock()
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
vae.device = vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.disable_offload = True
vae.extra_1d_channel = None
vae.upscale_ratio = vae.downscale_ratio = 8
vae.upscale_index_formula = vae.downscale_index_formula = None
vae.output_channels = 3
vae.latent_channels = latent_channels
vae.latent_dim = latent_dim
vae.vae_output_dtype = lambda: torch.float32
vae.spacial_compression_decode = lambda: 8
vae.handles_tiling = isinstance(first_stage_model, seedvr_vae_mod.VideoAutoencoderKLWrapper)
vae.format_encoded = None
vae.process_input = lambda x: x
vae.process_output = lambda x: x
vae.throw_exception_if_invalid = lambda: None
vae.memory_used_decode = lambda *a, **k: 1
return vae
def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode):
mm = sd_mod.model_management
with ExitStack() as stack:
stack.enter_context(patch.object(mm, "raise_non_oom", lambda e: None))
stack.enter_context(patch.object(mm, "load_models_gpu", lambda *a, **k: None))
stack.enter_context(patch.object(mm, "soft_empty_cache", lambda: None))
stack.enter_context(patch.object(sd_mod.VAE, "_decode_tiled_owned", seedvr2_call))
stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_", generic_call))
if patch_wrapper_decode:
stack.enter_context(patch.object(
seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode",
side_effect=_force_oom))
vae.decode(samples)
def test_4d_seedvr2_latent_routes_to_owned_decode_tiled():
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
seedvr_vae_mod.VideoAutoencoderKLWrapper)
vae = _make_vae(wrapper, latent_channels=_LATENT_CHANNELS, latent_dim=3)
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
_dispatch(vae, torch.zeros(1, _LATENT_CHANNELS * 3, 8, 8), seedvr2_call, generic_call, True)
assert seedvr2_call.call_count == 1
assert generic_call.call_count == 0
def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled():
first_stage = MagicMock()
first_stage.decode = MagicMock(side_effect=_force_oom)
vae = _make_vae(first_stage, latent_channels=4, latent_dim=2)
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
_dispatch(vae, torch.zeros(1, 4, 8, 8), seedvr2_call, generic_call, False)
assert generic_call.call_count == 1
assert seedvr2_call.call_count == 0
def _populate_common_vae_attrs_fallback(vae):
vae.patcher = MagicMock()
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.disable_offload = True
vae.extra_1d_channel = None
vae.upscale_ratio = 8
vae.upscale_index_formula = None
vae.output_channels = 3
vae.latent_channels = _LATENT_CHANNELS
vae.latent_dim = 3
vae.downscale_ratio = 8
vae.downscale_index_formula = None
vae.not_video = False
vae.crop_input = False
vae.pad_channel_value = None
vae.handles_tiling = isinstance(vae.first_stage_model, seedvr_vae_mod.VideoAutoencoderKLWrapper)
vae.format_encoded = None
vae.vae_output_dtype = lambda: torch.float32
vae.spacial_compression_encode = lambda: 8
vae.process_input = lambda x: x
vae.process_output = lambda x: x
vae.throw_exception_if_invalid = lambda: None
vae.memory_used_encode = lambda *a, **k: 1
def _make_seedvr2_vae_fallback():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
seedvr_vae_mod.VideoAutoencoderKLWrapper
)
vae.first_stage_model = wrapper
_populate_common_vae_attrs_fallback(vae)
return vae
def _make_non_seedvr2_vae_fallback():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = MagicMock()
_populate_common_vae_attrs_fallback(vae)
return vae
def _force_regular_encode_oom(*args, **kwargs):
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
def test_seedvr2_3d_routes_to_owned_encode_tiled_on_oom():
vae = _make_seedvr2_vae_fallback()
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
seedvr2_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8))
generic_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8))
with patch.object(sd_mod.model_management, "raise_non_oom",
lambda e: None), \
patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.model_management, "soft_empty_cache",
lambda: None), \
patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode",
side_effect=_force_regular_encode_oom), \
patch.object(sd_mod.VAE, "_encode_tiled_owned", seedvr2_call), \
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
vae.encode(pixel_samples)
assert seedvr2_call.call_count == 1, (
f"Expected _encode_tiled_owned to be called once for a SeedVR2 3D "
f"input under OOM fallback; got {seedvr2_call.call_count} calls."
)
assert generic_call.call_count == 0, (
f"encode_tiled_3d must NOT be called for a SeedVR2 input; got "
f"{generic_call.call_count} calls."
)
def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete():
vae = _make_non_seedvr2_vae_fallback()
vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8)
vae.upscale_ratio = (lambda a: a * 4, 8, 8)
generic_call = MagicMock(return_value=torch.zeros(1, _LATENT_CHANNELS, 2, 8, 8))
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
with patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
vae.encode_tiled(pixel_samples)
assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64)