This commit is contained in:
John Pollock 2026-05-30 01:11:55 +09:00 committed by GitHub
commit 32cd663883
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 7814 additions and 40 deletions

View File

@ -4,6 +4,7 @@ class LatentFormat:
scale_factor = 1.0
latent_channels = 4
latent_dimensions = 2
preserve_empty_channel_multiples = False
latent_rgb_factors = None
latent_rgb_factors_bias = None
latent_rgb_factors_reshape = None
@ -769,6 +770,10 @@ class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2
class SeedVR2(LatentFormat):
latent_channels = 16
preserve_empty_channel_multiples = True
class ACEAudio15(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@ -32,6 +32,14 @@ except ImportError as e:
raise e
exit(-1)
SAGE_ATTENTION_VARLEN_IS_AVAILABLE = False
try:
from sageattention import sageattn_varlen
SAGE_ATTENTION_VARLEN_IS_AVAILABLE = True
except ImportError:
if model_management.sage_attention_enabled():
logging.warning("SageAttention variable-length attention is unavailable, using pytorch var-len attention instead.")
SAGE_ATTENTION3_IS_AVAILABLE = False
try:
from sageattn3 import sageattn3_blackwell
@ -40,6 +48,7 @@ except ImportError:
pass
FLASH_ATTENTION_IS_AVAILABLE = False
FLASH_ATTENTION_VARLEN_IS_AVAILABLE = False
try:
from flash_attn import flash_attn_func
FLASH_ATTENTION_IS_AVAILABLE = True
@ -48,6 +57,20 @@ except ImportError:
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
try:
from flash_attn import flash_attn_varlen_func
FLASH_ATTENTION_VARLEN_IS_AVAILABLE = True
except ImportError:
if model_management.flash_attention_enabled() and FLASH_ATTENTION_IS_AVAILABLE:
logging.warning("Flash Attention variable-length attention is unavailable, using pytorch var-len attention instead.")
FLASH_ATTENTION3_IS_AVAILABLE = False
try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn3_varlen_func
FLASH_ATTENTION3_IS_AVAILABLE = True
except ImportError:
pass
REGISTERED_ATTENTION_FUNCTIONS = {}
def register_attention_function(name: str, func: Callable):
# avoid replacing existing functions
@ -735,28 +758,434 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
)
return out
_VAR_ATTENTION_NESTED_API_NAME = "nested_tensor_from_jagged"
_VAR_ATTENTION_GUARD_MESSAGE = (
"SeedVR2 var_attention_pytorch: torch.nested.nested_tensor_from_jagged "
"is required by this attention path; the installed PyTorch build "
"does not provide it"
)
def _var_attention_max_seqlen(cu_seqlens):
return int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
def _var_attention_qkv(q, k, v, heads, skip_reshape):
if skip_reshape:
return q, k, v, q.shape[-1]
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
return (
q.view(total_tokens, heads, head_dim),
k.view(k.shape[0], heads, head_dim),
v.view(v.shape[0], heads, head_dim),
head_dim,
)
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
if skip_output_reshape:
return out
return out.reshape(-1, heads * head_dim)
def _use_blackwell_attention():
device = model_management.get_torch_device()
if device.type != "cuda":
return False
major, minor = torch.cuda.get_device_capability(device)
return (major, minor) >= (12, 0)
def var_attention_pytorch(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
_nested = getattr(torch, "nested", None)
if _nested is None or not hasattr(_nested, _VAR_ATTENTION_NESTED_API_NAME):
raise RuntimeError(_VAR_ATTENTION_GUARD_MESSAGE)
if not skip_reshape:
# assumes 2D q, k,v [total_tokens, embed_dim]
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
q = q.view(total_tokens, heads, head_dim)
k = k.view(k.shape[0], heads, head_dim)
v = v.view(v.shape[0], heads, head_dim)
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2)
if not skip_output_reshape:
return out.values().reshape(-1, heads * (q.shape[-1]))
return out.values()
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
if cu_seqlens.dtype not in (torch.int32, torch.int64):
raise ValueError(f"{name} must use an integer dtype")
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
if cu_seqlens[0].item() != 0:
raise ValueError(f"{name} must start at 0")
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
raise ValueError(f"{name} must be strictly increasing")
if cu_seqlens[-1].item() != token_count:
raise ValueError(f"{name} does not match token count")
def _split_indices(cu_seqlens):
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
def var_attention_pytorch_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, skip_reshape=False, skip_output_reshape=False):
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
if cu_seqlens_k[-1].item() != v.shape[0]:
raise ValueError("cu_seqlens_k does not match v token count")
q_split_indices = _split_indices(cu_seqlens_q)
k_split_indices = _split_indices(cu_seqlens_k)
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")
out = []
for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
out_i = comfy.ops.scaled_dot_product_attention(q_i, k_i, v_i, attn_mask=None, dropout_p=0.0, is_causal=False)
out.append(out_i.squeeze(0).permute(1, 0, 2))
out = torch.cat(out, dim=0)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_sage(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
if not SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
out_dtype = q.dtype
if not (q.dtype == k.dtype == v.dtype):
k = k.to(q.dtype)
v = v.to(q.dtype)
fallback_q, fallback_k, fallback_v = q, k, v
if q.dtype not in (torch.float16, torch.bfloat16):
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)
v = v.to(torch.bfloat16)
sm_scale = kwargs.get("softmax_scale")
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(head_dim)
try:
out = sageattn_varlen(
q,
k,
v,
cu_seqlens_q.int(),
cu_seqlens_k.int(),
_var_attention_max_seqlen(cu_seqlens_q),
_var_attention_max_seqlen(cu_seqlens_k),
kwargs.get("causal", False),
sm_scale,
)
except Exception as e:
logging.error("Error running sage var-len attention: %s, using pytorch var-len attention instead.", e)
out = var_attention_pytorch(
fallback_q,
fallback_k,
fallback_v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
if out.dtype != out_dtype:
out = out.to(out_dtype)
return out
if out.dtype != out_dtype:
out = out.to(out_dtype)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_sage3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
if not SAGE_ATTENTION3_IS_AVAILABLE:
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_sage(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs,
)
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
seq_lens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
seq_lens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
uniform_q = bool((seq_lens_q == seq_lens_q[0]).all().item())
uniform_k = bool((seq_lens_k == seq_lens_k[0]).all().item())
if not (uniform_q and uniform_k and seq_lens_q[0] == seq_lens_k[0]):
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_sage(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
**kwargs,
)
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
out_dtype = q.dtype
if not (q.dtype == k.dtype == v.dtype):
k = k.to(q.dtype)
v = v.to(q.dtype)
fallback_q, fallback_k, fallback_v = q, k, v
if q.dtype not in (torch.float16, torch.bfloat16):
q = q.to(torch.bfloat16)
k = k.to(torch.bfloat16)
v = v.to(torch.bfloat16)
batch_size = len(cu_seqlens_q) - 1
seq_len = int(seq_lens_q[0].item())
q = q.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, heads, head_dim).transpose(1, 2)
try:
out = sageattn3_blackwell(q, k, v, is_causal=kwargs.get("causal", False))
except Exception as e:
logging.error("Error running SageAttention3 var-len attention: %s, using fallback var-len attention instead.", e)
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_sage(
fallback_q,
fallback_k,
fallback_v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
**kwargs,
)
return var_attention_pytorch(
fallback_q,
fallback_k,
fallback_v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
out = out.transpose(1, 2).reshape(-1, heads, head_dim).contiguous()
if out.dtype != out_dtype:
out = out.to(out_dtype)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_flash(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
if not FLASH_ATTENTION_VARLEN_IS_AVAILABLE:
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q)
max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k)
try:
out = flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q.int(),
cu_seqlens_k=cu_seqlens_k.int(),
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=kwargs.get("dropout_p", 0.0),
causal=kwargs.get("causal", False),
deterministic=torch.are_deterministic_algorithms_enabled(),
)
except Exception as e:
logging.error("Error running Flash Attention var-len attention: %s, using pytorch var-len attention instead.", e)
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_flash3(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
if not FLASH_ATTENTION3_IS_AVAILABLE:
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
max_seqlen_q = _var_attention_max_seqlen(cu_seqlens_q)
max_seqlen_k = _var_attention_max_seqlen(cu_seqlens_k)
try:
out = flash_attn3_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q.int(),
cu_seqlens_k=cu_seqlens_k.int(),
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
seqused_q=None,
seqused_k=None,
softmax_scale=kwargs.get("softmax_scale"),
causal=kwargs.get("causal", False),
deterministic=torch.are_deterministic_algorithms_enabled(),
)
except Exception as e:
logging.error("Error running Flash Attention 3 var-len attention: %s, using pytorch var-len attention instead.", e)
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
)
if isinstance(out, tuple):
out = out[0]
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
@torch._dynamo.disable
def var_attention_sub_quad(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
return var_attention_pytorch(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
@torch._dynamo.disable
def var_attention_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
return var_attention_pytorch_split(
q,
k,
v,
heads,
cu_seqlens_q,
cu_seqlens_k,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
)
optimized_var_attention = var_attention_pytorch
optimized_attention = attention_basic
if model_management.sage_attention_enabled():
logging.info("Using sage attention")
optimized_attention = attention_sage
if SAGE_ATTENTION3_IS_AVAILABLE and _use_blackwell_attention():
logging.info("Using SageAttention3 for variable-length attention")
optimized_var_attention = var_attention_sage3
elif SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
logging.info("Using SageAttention for variable-length attention")
optimized_var_attention = var_attention_sage
else:
logging.info("Using pytorch attention for variable-length attention")
optimized_var_attention = var_attention_pytorch
elif model_management.flash_attention_enabled():
logging.info("Using Flash Attention")
optimized_attention = attention_flash
if FLASH_ATTENTION_VARLEN_IS_AVAILABLE and model_management.get_torch_device().type == "cuda":
logging.info("Using Flash Attention 2 for variable-length attention")
optimized_var_attention = var_attention_flash
else:
logging.info("Using pytorch attention for variable-length attention")
optimized_var_attention = var_attention_pytorch
elif model_management.xformers_enabled():
logging.info("Using xformers attention")
optimized_attention = attention_xformers
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention")
optimized_attention = attention_pytorch
optimized_var_attention = var_attention_pytorch
else:
if args.use_split_cross_attention:
logging.info("Using split optimization for attention")
optimized_attention = attention_split
optimized_var_attention = var_attention_split
else:
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad
optimized_var_attention = var_attention_sub_quad
optimized_attention_masked = optimized_attention
@ -764,15 +1193,25 @@ optimized_attention_masked = optimized_attention
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage)
if SAGE_ATTENTION_VARLEN_IS_AVAILABLE:
register_attention_function("var_attention_sage", var_attention_sage)
if SAGE_ATTENTION3_IS_AVAILABLE:
register_attention_function("sage3", attention3_sage)
register_attention_function("var_attention_sage3", var_attention_sage3)
if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash)
if FLASH_ATTENTION_VARLEN_IS_AVAILABLE:
register_attention_function("var_attention_flash", var_attention_flash)
if FLASH_ATTENTION3_IS_AVAILABLE:
register_attention_function("var_attention_flash3", var_attention_flash3)
if model_management.xformers_enabled():
register_attention_function("xformers", attention_xformers)
register_attention_function("pytorch", attention_pytorch)
register_attention_function("var_attention_pytorch", var_attention_pytorch)
register_attention_function("sub_quad", attention_sub_quad)
register_attention_function("var_attention_sub_quad", var_attention_sub_quad)
register_attention_function("split", attention_split)
register_attention_function("var_attention_split", var_attention_split)
def optimized_attention_for_device(device, mask=False, small_input=False):
@ -1209,5 +1648,3 @@ class SpatialVideoTransformer(SpatialTransformer):
x = self.proj_out(x)
out = x + x_in
return out

View File

@ -13,6 +13,7 @@ if model_management.xformers_enabled_vae():
import xformers
import xformers.ops
def torch_cat_if_needed(xl, dim):
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
if len(xl) > 1:
@ -22,7 +23,8 @@ def torch_cat_if_needed(xl, dim):
else:
return None
def get_timestep_embedding(timesteps, embedding_dim):
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
@ -33,11 +35,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = math.log(10000) / (half_dim - downscale_freq_shift)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0))
return emb

1629
comfy/ldm/seedvr/model.py Normal file

File diff suppressed because it is too large Load Diff

2421
comfy/ldm/seedvr/vae.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -53,6 +53,8 @@ import comfy.ldm.pixeldit.model
import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.seedvr.model
import comfy.ldm.qwen_image.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
@ -926,6 +928,16 @@ class HunyuanDiT(BaseModel):
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out
class SeedVR2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
condition = kwargs.get("condition", None)
if condition is not None:
out["condition"] = comfy.conds.CONDRegular(condition)
return out
class PixArt(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)

View File

@ -594,6 +594,56 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config
if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
# 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.``
# submodules) at EVERY block — verified by inspecting the 7B
# state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means
# ``MMModule.shared_weights=False``). Native NaDiT computes
# per-block ``shared_weights = not (i < mm_layers)``, so to keep
# every block non-shared we set ``mm_layers = num_layers``.
# Without this, blocks at index >= mm_layers (default 10) try to
# load ``blocks.N.*.all.*`` keys that don't exist in the file,
# silently miss-load → all-black output.
dit_config["mm_layers"] = 36
dit_config["norm_eps"] = 1e-5
dit_config["qk_rope"] = True
dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "normal"
return dit_config
elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
# This checkpoint layout carries shared ``all.`` MMModule keys.
# Preserve the historical split: the initial blocks use separate
# vid/txt modules, later blocks use shared modules.
dit_config["mm_layers"] = 10
dit_config["norm_eps"] = 1e-5
dit_config["qk_rope"] = True
dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "swiglu"
return dit_config
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 2560
dit_config["heads"] = 20
dit_config["num_layers"] = 32
dit_config["norm_eps"] = 1.0e-05
dit_config["qk_rope"] = None
dit_config["mlp_type"] = "swiglu"
dit_config["vid_out_norm"] = True
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {}
dit_config["image_model"] = "wan2.1"

View File

@ -44,7 +44,13 @@ def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None,
is_empty = torch.count_nonzero(latent_image) == 0
if is_empty:
if latent_format.latent_channels != latent_image.shape[1]:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
preserves_collapsed_channels = (
getattr(latent_format, "preserve_empty_channel_multiples", False)
and latent_image.ndim == 4
and latent_image.shape[1] % latent_format.latent_channels == 0
)
if not preserves_collapsed_channels:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if downscale_ratio_spacial is not None:
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio

View File

@ -1,3 +1,4 @@
import inspect
import json
import torch
from enum import Enum
@ -16,6 +17,7 @@ import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.seedvr.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae
import comfy.ldm.hunyuan_video.vae
@ -82,6 +84,36 @@ import comfy.latent_formats
import comfy.ldm.flux.redux
SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL = 160
def _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w):
output_t = max(1, (latent_t - 1) * 4 + 1)
return output_t * latent_h * 8 * latent_w * 8
def _seedvr2_vae_decode_memory_used(shape):
if len(shape) == 5:
candidates = []
if shape[1] == 16:
candidates.append((shape[2], shape[3], shape[4]))
if shape[-1] == 16:
candidates.append((shape[1], shape[2], shape[3]))
if len(candidates) == 0:
candidates.append((shape[2], shape[3], shape[4]))
output_pixels = max(_seedvr2_vae_decode_output_pixels(*candidate) for candidate in candidates)
elif len(shape) == 4:
latent_t = max(1, (shape[1] + 15) // 16)
latent_h, latent_w = shape[2], shape[3]
output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w)
else:
latent_t, latent_h, latent_w = 1, shape[-2], shape[-1]
output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w)
# SeedVR2 decode performs full-frame LAB histogram matching: fp32 channels
# plus int64 sort indices dominate peak memory, not the VAE weight dtype.
return output_pixels * SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
key_map = {}
if model is not None:
@ -465,8 +497,10 @@ class CLIP:
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd
if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
if metadata is None or metadata.get("keep_diffusers_format") != "true":
sd = diffusers_convert.convert_vae_state_dict(sd)
if model_management.is_amd():
VAE_KL_MEM_RATIO = 2.73
@ -538,6 +572,20 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
self.latent_channels = 16
self.latent_dim = 3
self.disable_offload = True
self.memory_used_decode = lambda shape, dtype: _seedvr2_vae_decode_memory_used(shape)
self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.process_input = lambda image: image * 2.0 - 1.0
self.crop_input = False
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
@ -665,6 +713,7 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
@ -994,6 +1043,40 @@ class VAE:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def decode_tiled_seedvr2(self, samples, tile_x=32, tile_y=32, overlap=8, tile_t=16, overlap_t=4):
sf_s = getattr(self.first_stage_model, "spatial_downsample_factor", 8)
sf_t = getattr(self.first_stage_model, "temporal_downsample_factor", 4)
if tile_t is None:
tile_t = 16
if overlap_t is None:
overlap_t = 4
if tile_t > 0:
temporal_size = tile_t * sf_t
temporal_overlap = max(0, overlap_t) * sf_t
else:
temporal_size = 0
temporal_overlap = 0
args = {
"enable_tiling": True,
"tile_size": (tile_y * sf_s, tile_x * sf_s),
"tile_overlap": (overlap * sf_s, overlap * sf_s),
"temporal_size": temporal_size,
"temporal_overlap": temporal_overlap,
}
output = self.first_stage_model.decode(
samples.to(self.vae_dtype).to(self.device),
seedvr2_tiling=args,
)
return self.process_output(output.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
def _format_seedvr2_encoded_samples(self, samples):
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
if samples.ndim == 4:
samples = samples.unsqueeze(2)
samples = samples.contiguous()
samples = samples * 0.9152
return samples
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@ -1030,6 +1113,36 @@ class VAE:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def encode_tiled_seedvr2(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
if tile_y is None:
tile_y = 512
if tile_x is None:
tile_x = 512
if overlap is None:
overlap_y = 64
overlap_x = 64
else:
overlap_y = overlap
overlap_x = overlap
if tile_t is None:
tile_t = 9999
if overlap_t is None:
overlap_t = 0
overlap_y = min(overlap_y, max(0, tile_y - 8))
overlap_x = min(overlap_x, max(0, tile_x - 8))
self.first_stage_model.device = self.device
x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device)
output = comfy.ldm.seedvr.vae.tiled_vae(
x,
self.first_stage_model,
tile_size=(tile_y, tile_x),
tile_overlap=(overlap_y, overlap_x),
temporal_size=tile_t,
temporal_overlap=overlap_t,
encode=True,
)
return output.to(device=self.output_device, dtype=self.vae_output_dtype())
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
@ -1077,16 +1190,40 @@ class VAE:
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
# SeedVR2 latents arrive in 4D collapsed form ``(B, 16*T, H, W)``
# downstream of ``SeedVR2Conditioning`` (which performs the
# ``rearrange(b c t h w -> b (c t) h w)`` collapse). The
# generic ``decode_tiled_`` would treat the channel dim as
# spatial-only and crash on the collapsed (16, T) layout
# under ``tiled_scale``'s mask broadcast; route SeedVR2 4D
# latents to ``decode_tiled_seedvr2`` instead, whose wrapper
# dispatch handles both 4D and 5D inputs.
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
else:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
else:
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
def decode_tiled(
self,
samples,
tile_x=None,
tile_y=None,
overlap=None,
tile_t=None,
overlap_t=None,
):
self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -1100,7 +1237,20 @@ class VAE:
args["overlap"] = overlap
with model_management.cuda_device_context(self.device):
if dims == 1 or self.extra_1d_channel is not None:
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper) and dims in (2, 3):
seedvr2_args = {}
if tile_x is not None:
seedvr2_args["tile_x"] = tile_x
if tile_y is not None:
seedvr2_args["tile_y"] = tile_y
if overlap is not None:
seedvr2_args["overlap"] = overlap
if tile_t is not None:
seedvr2_args["tile_t"] = tile_t
if overlap_t is not None:
seedvr2_args["overlap_t"] = overlap_t
output = self.decode_tiled_seedvr2(samples, **seedvr2_args)
elif dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
@ -1142,6 +1292,8 @@ class VAE:
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
if isinstance(out, tuple):
out = out[0]
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
@ -1161,20 +1313,23 @@ class VAE:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
samples = self.encode_tiled_seedvr2(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap)
else:
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
else:
samples = self.encode_tiled_(pixel_samples)
return samples
return self._format_seedvr2_encoded_samples(samples)
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
if dims == 3:
if dims == 3 and pixel_samples.ndim < 5:
if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
else:
@ -1198,22 +1353,47 @@ class VAE:
elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
seedvr2_args = {}
if tile_x is not None:
seedvr2_args["tile_x"] = tile_x
else:
seedvr2_args["tile_x"] = 512
if tile_y is not None:
seedvr2_args["tile_y"] = tile_y
else:
seedvr2_args["tile_y"] = 512
if overlap is not None:
seedvr2_args["overlap"] = overlap
else:
seedvr2_args["overlap"] = 64
if tile_t is not None:
seedvr2_args["tile_t"] = tile_t
else:
seedvr2_args["tile_t"] = 9999
if overlap_t is not None:
seedvr2_args["overlap_t"] = overlap_t
else:
seedvr2_args["overlap_t"] = 0
samples = self.encode_tiled_seedvr2(pixel_samples, **seedvr2_args)
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
spatial_overlap = overlap if overlap is not None else 64
if overlap_t is None:
args["overlap"] = (1, spatial_overlap, spatial_overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
return samples
return self._format_seedvr2_encoded_samples(samples)
def get_sd(self):
return self.first_stage_model.state_dict()
@ -1735,6 +1915,17 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device):
set_dtype = model_config.set_inference_dtype
parameters = inspect.signature(set_dtype).parameters
supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values())
if supports_device:
set_dtype(dtype, manual_cast_dtype, device=device)
else:
set_dtype(dtype, manual_cast_dtype)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
@ -1842,7 +2033,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
if model_config.clip_vision_prefix is not None:
if output_clipvision:
@ -1983,7 +2174,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
if custom_operations is not None:
model_config.custom_operations = custom_operations

View File

@ -1647,6 +1647,35 @@ class Chroma(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
class SeedVR2(supported_models_base.BASE):
unet_config = {
"image_model": "seedvr2"
}
latent_format = comfy.latent_formats.SeedVR2
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
sampling_settings = {
"shift": 1.0,
}
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
if (
dtype == torch.float16
and manual_cast_dtype is None
and comfy.model_management.should_use_bf16(device)
):
manual_cast_dtype = torch.bfloat16
super().set_inference_dtype(dtype, manual_cast_dtype, device=device)
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SeedVR2(self, device=device)
return out
def clip_target(self, state_dict={}):
return None
class ChromaRadiance(Chroma):
unet_config = {
"image_model": "chroma_radiance",
@ -1966,7 +1995,6 @@ class LongCatImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
class RT_DETR_v4(supported_models_base.BASE):
unet_config = {
"image_model": "RT_DETR_v4",
@ -2203,6 +2231,7 @@ models = [
HiDream,
HiDreamO1,
Chroma,
SeedVR2,
ChromaRadiance,
ACEStep,
ACEStep15,

View File

@ -115,7 +115,7 @@ class BASE:
replace_prefix = {"": self.vae_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def set_inference_dtype(self, dtype, manual_cast_dtype):
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype

1164
comfy_extras/nodes_seedvr.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -47,14 +47,18 @@ import node_helpers
if args.enable_manager:
import comfyui_manager
def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted()
def interrupt_processing(value=True):
comfy.model_management.interrupt_current_processing(value)
MAX_RESOLUTION=16384
class CLIPTextEncode(ComfyNodeABC):
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
@ -323,8 +327,8 @@ class VAEDecodeTiled:
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32, "advanced": True}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}),
"temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time.", "advanced": True}),
"temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}),
"temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}),
"temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode"
@ -334,18 +338,32 @@ class VAEDecodeTiled:
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
if tile_size < overlap * 4:
overlap = tile_size // 4
if temporal_size < temporal_overlap * 2:
temporal_overlap = temporal_overlap // 2
temporal_compression = vae.temporal_compression_decode()
if temporal_compression is not None:
temporal_size = max(2, temporal_size // temporal_compression)
temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression))
if temporal_size <= 0:
temporal_size = 0
temporal_overlap = 0
else:
requested_temporal_overlap = temporal_overlap
if temporal_size < temporal_overlap * 2:
temporal_overlap = temporal_overlap // 2
temporal_size = max(2, temporal_size // temporal_compression)
temporal_overlap = min(temporal_size // 2, temporal_overlap // temporal_compression)
if requested_temporal_overlap > 0:
temporal_overlap = max(1, temporal_overlap)
else:
temporal_size = None
temporal_overlap = None
compression = vae.spacial_compression_decode()
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap)
images = vae.decode_tiled(
samples["samples"],
tile_x=tile_size // compression,
tile_y=tile_size // compression,
overlap=overlap // compression,
tile_t=temporal_size,
overlap_t=temporal_overlap,
)
if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
return (images, )
@ -362,7 +380,7 @@ class VAEEncode:
def encode(self, vae, pixels):
t = vae.encode(pixels)
return ({"samples":t}, )
return ({"samples": t}, )
class VAEEncodeTiled:
@classmethod
@ -370,8 +388,8 @@ class VAEEncodeTiled:
return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64, "advanced": True}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}),
"temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time.", "advanced": True}),
"temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}),
"temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}),
"temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
@ -379,6 +397,9 @@ class VAEEncodeTiled:
CATEGORY = "experimental"
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
if temporal_size <= 0:
temporal_size = 0
temporal_overlap = 0
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
return ({"samples": t}, )
@ -2417,6 +2438,7 @@ async def init_builtin_extra_nodes():
"nodes_camera_trajectory.py",
"nodes_edit_model.py",
"nodes_tcfg.py",
"nodes_seedvr.py",
"nodes_context_windows.py",
"nodes_qwen.py",
"nodes_chroma_radiance.py",

View File

@ -0,0 +1,213 @@
"""Consolidated SeedVR2 conditioning and refactor regression tests.
Merges the prior test_seedvr2_refactor_nodes.py and
test_seedvr_conditioning_hardening.py modules. Refactor tests use the
top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests
use _import_nodes_seedvr_isolated() for sys.modules isolation when
mocking comfy.model_management.
"""
import importlib
import sys
from unittest.mock import MagicMock
import pytest
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
_SENTINEL = object()
_TARGETS = (
("comfy.model_management", "comfy"),
("comfy_extras.nodes_seedvr", "comfy_extras"),
)
def _import_nodes_seedvr_isolated():
"""Import comfy_extras.nodes_seedvr with comfy.model_management mocked."""
priors = []
for mod_name, parent_name in _TARGETS:
prior_mod = sys.modules.get(mod_name, _SENTINEL)
parent = sys.modules.get(parent_name)
attr = mod_name.split(".")[-1]
prior_attr = (
getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL
)
priors.append((mod_name, parent_name, attr, prior_mod, prior_attr))
mock_mm = MagicMock()
for fn in (
"xformers_enabled", "xformers_enabled_vae",
"pytorch_attention_enabled", "pytorch_attention_enabled_vae",
"sage_attention_enabled", "flash_attention_enabled",
"is_intel_xpu",
):
getattr(mock_mm, fn).return_value = False
tv = torch.version.__version__.split(".")
mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1]))
mock_mm.WINDOWS = False
sys.modules["comfy.model_management"] = mock_mm
if sys.modules.get("comfy") is None:
import comfy as _comfy_pkg # noqa: F401
comfy_pkg = sys.modules.get("comfy")
if comfy_pkg is not None:
setattr(comfy_pkg, "model_management", mock_mm)
nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or (
importlib.import_module("comfy_extras.nodes_seedvr")
)
def _restore():
for mod_name, parent_name, attr, prior_mod, prior_attr in priors:
if prior_mod is _SENTINEL:
sys.modules.pop(mod_name, None)
else:
sys.modules[mod_name] = prior_mod
parent = sys.modules.get(parent_name)
if parent is None:
continue
if prior_attr is _SENTINEL:
if hasattr(parent, attr):
delattr(parent, attr)
else:
setattr(parent, attr, prior_attr)
return nodes_seedvr, _restore
class _Rope(nn.Module):
"""Minimal RoPE stub exposing a `freqs` parameter."""
def __init__(self):
super().__init__()
self.freqs = nn.Parameter(torch.zeros(4))
class _Block(nn.Module):
"""Minimal transformer block stub holding a `_Rope`."""
def __init__(self):
super().__init__()
self.rope = _Rope()
class _DiffusionModel(nn.Module):
"""Stub diffusion model with N blocks and pos/neg conditioning buffers."""
def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32):
super().__init__()
self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])
pos = torch.zeros if zero_conditioning else torch.ones
self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype))
self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype))
class _ModelInner:
"""Inner model wrapper exposing `.diffusion_model`."""
def __init__(self, diffusion_model):
self.diffusion_model = diffusion_model
class _ModelPatcher:
"""ModelPatcher stub exposing `.model._ModelInner`."""
def __init__(self, diffusion_model):
self.model = _ModelInner(diffusion_model)
def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
schema = nodes_seedvr.SeedVR2Conditioning.define_schema()
assert [input_item.id for input_item in schema.inputs] == [
"model",
"vae_conditioning",
]
assert schema.inputs[1].display_name == "LATENT"
assert [output.display_name for output in schema.outputs] == [
"model",
"positive",
"negative",
"latent",
]
finally:
restore()
def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
diffusion_model = _DiffusionModel()
patcher = _ModelPatcher(diffusion_model)
samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2)
vae_conditioning = {"samples": samples}
_, first_positive, first_negative, first_latent = (
nodes_seedvr.SeedVR2Conditioning.execute(
patcher,
vae_conditioning,
)
)
_, second_positive, second_negative, second_latent = (
nodes_seedvr.SeedVR2Conditioning.execute(
patcher,
vae_conditioning,
)
)
expected_latent = samples.reshape(1, 6, 2, 2)
channel_last = samples.movedim(1, -1).contiguous()
expected_condition = torch.cat(
[
channel_last,
torch.ones((*channel_last.shape[:-1], 1)),
],
dim=-1,
).movedim(-1, 1).reshape(1, 9, 2, 2)
assert torch.equal(first_latent["samples"], expected_latent)
assert torch.equal(second_latent["samples"], expected_latent)
assert torch.equal(
first_positive[0][1]["condition"],
expected_condition,
)
assert torch.equal(
second_positive[0][1]["condition"],
expected_condition,
)
assert torch.equal(
first_negative[0][1]["condition"],
expected_condition,
)
assert torch.equal(
second_negative[0][1]["condition"],
expected_condition,
)
finally:
restore()
def test_seedvr2_conditioning_fails_loud_on_zero_buffers():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
diffusion_model = _DiffusionModel(zero_conditioning=True)
patcher = _ModelPatcher(diffusion_model)
vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}
with pytest.raises(RuntimeError) as excinfo:
nodes_seedvr.SeedVR2Conditioning.execute(
patcher, vae_conditioning,
)
message = str(excinfo.value)
assert message.startswith(
nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX
), (
"Fail-loud message must use the standard "
"_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers "
f"can match it. Got: {message!r}"
)
assert "positive_conditioning" in message
assert "negative_conditioning" in message
finally:
restore()

View File

@ -0,0 +1,70 @@
import importlib
import inspect
import sys
from unittest.mock import MagicMock, patch
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy_extras.nodes_seedvr as nodes_seedvr # noqa: E402
def test_resize_simple_multiplier_resolves_upscaled_shorter_edge():
images = torch.zeros(1, 3, 16, 20, 3)
output = nodes_seedvr.SeedVR2Resize.execute(images, 4.0)
input_pixels, original_image, upscaled_shorter_edge = output.result
assert tuple(input_pixels.shape) == (1, 5, 64, 80, 3)
assert input_pixels.min().item() == 0.0
assert input_pixels.max().item() == 0.0
assert original_image is images
assert upscaled_shorter_edge == 64
def test_seedvr_node_signature_matches_schema():
mock_mm = MagicMock()
mock_mm.xformers_enabled.return_value = False
mock_mm.xformers_enabled_vae.return_value = False
mock_mm.sage_attention_enabled.return_value = False
mock_mm.flash_attention_enabled.return_value = False
sentinel = object()
prior_cpu = cli_args.cpu
cli_args.cpu = True
prior_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel)
comfy_pkg = sys.modules.get("comfy")
prior_mm_attr = getattr(comfy_pkg, "model_management", sentinel) if comfy_pkg else sentinel
with patch.dict(sys.modules, {"comfy.model_management": mock_mm}):
if comfy_pkg is not None:
setattr(comfy_pkg, "model_management", mock_mm)
sys.modules.pop("comfy_extras.nodes_seedvr", None)
try:
nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
for node_cls in (nodes_seedvr.SeedVR2Resize, nodes_seedvr.SeedVR2ResizeAdvanced):
schema_ids = [i.id for i in node_cls.define_schema().inputs]
exec_params = [
p for p in inspect.signature(node_cls.execute).parameters.keys()
if p != "cls"
]
assert schema_ids == exec_params, (
f"{node_cls.__name__} schema/execute drift: "
f"schema_ids={schema_ids}, exec_params={exec_params}"
)
finally:
cli_args.cpu = prior_cpu
if prior_module is sentinel:
sys.modules.pop("comfy_extras.nodes_seedvr", None)
else:
sys.modules["comfy_extras.nodes_seedvr"] = prior_module
if comfy_pkg is not None:
if prior_mm_attr is sentinel:
if hasattr(comfy_pkg, "model_management"):
delattr(comfy_pkg, "model_management")
else:
setattr(comfy_pkg, "model_management", prior_mm_attr)

View File

@ -0,0 +1,60 @@
from unittest.mock import patch
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
from comfy_extras import nodes_seedvr # noqa: E402
def _schema_ids(items):
return [item.id for item in items]
def test_seedvr2_post_processing_schema():
schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()
assert _schema_ids(schema.inputs) == ["decoded", "original_image", "upscaled_shorter_edge", "color_correction_method"]
assert schema.inputs[2].default is None
assert schema.inputs[2].min == 2
assert schema.inputs[2].force_input is True
assert schema.inputs[3].options == ["lab", "wavelet", "adain", "none"]
assert schema.inputs[3].default == "lab"
assert schema.outputs[0].get_io_type() == "IMAGE"
def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch):
decoded = torch.full((1, 3, 4, 4), 0.25)
reference = torch.full((1, 3, 4, 4), 0.75)
def _lab(content, style):
raise torch.cuda.OutOfMemoryError("CUDA out of memory")
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000)
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None)
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
try:
nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
decoded, reference, torch.device("cpu"), "lab",
)
except RuntimeError as exc:
assert "color_correction_method=lab" in str(exc)
assert " method=lab" not in str(exc)
else:
raise AssertionError("expected RuntimeError for one-frame LAB OOM")
def test_seedvr2_post_processing_unknown_color_correction_method_raises():
decoded = torch.zeros(1, 2, 4, 4, 3)
original = torch.zeros(1, 2, 4, 4, 3)
try:
nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, 4, "bogus")
except ValueError as exc:
assert "color_correction_method" in str(exc)
else:
raise AssertionError("expected ValueError for unknown color_correction_method")

View File

@ -73,6 +73,24 @@ def _make_flux_schnell_comfyui_sd():
return sd
def _make_seedvr2_7b_separate_mm_sd():
return {
"blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072),
}
def _make_seedvr2_7b_shared_mm_sd():
return {
"blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
}
def _make_seedvr2_3b_shared_mm_sd():
return {
"blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
}
class TestModelDetection:
"""Verify that first-match model detection selects the correct model
based on list ordering and unet_config specificity."""
@ -125,6 +143,48 @@ class TestModelDetection:
assert model_config is not None
assert type(model_config).__name__ == "FluxSchnell"
def test_seedvr2_7b_separate_mm_detection_config(self):
sd = _make_seedvr2_7b_separate_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 3072
assert unet_config["heads"] == 24
assert unet_config["num_layers"] == 36
assert unet_config["mm_layers"] == 36
assert unet_config["mlp_type"] == "normal"
assert unet_config["qk_rope"] is True
assert unet_config["rope_type"] == "rope3d"
assert unet_config["rope_dim"] == 64
def test_seedvr2_7b_shared_mm_detection_config(self):
sd = _make_seedvr2_7b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 3072
assert unet_config["heads"] == 24
assert unet_config["num_layers"] == 36
assert unet_config["mm_layers"] == 10
assert unet_config["mlp_type"] == "swiglu"
assert unet_config["qk_rope"] is True
assert unet_config["rope_type"] == "rope3d"
assert unet_config["rope_dim"] == 64
def test_seedvr2_3b_shared_mm_detection_config(self):
sd = _make_seedvr2_3b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 2560
assert unet_config["heads"] == 20
assert unet_config["num_layers"] == 32
assert unet_config["mlp_type"] == "swiglu"
assert unet_config["qk_rope"] is None
def test_unet_config_and_required_keys_combination_is_unique(self):
"""Each model in the registry must have a unique combination of
``unet_config`` and ``required_keys``. If two models share the same

View File

@ -0,0 +1,90 @@
"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must
honor the actual tensor/tuple return contract of ``encode()`` and
``decode_()`` and must NOT dereference diffusers-style ``.latent_dist``
or ``.sample`` attributes on those returns.
The pre-fix body raised ``AttributeError: 'Tensor' object has no
attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and
``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'``
for ``mode == "decode"`` (the class only defines ``decode_`` with a
trailing underscore). The post-fix body unwraps the optional one-element
tuple shape that ``return_dict=False`` produces and returns the tensor
directly.
Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses
the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and
overrides ``encode``/``decode_`` with known tensors so the contract can
be probed without loading any real VAE weights.
"""
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402
_LATENT_SHAPE = (1, 16, 2, 2, 2)
_DECODED_SHAPE = (1, 3, 5, 16, 16)
_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16)
_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2)
class _StubVAE(VideoAutoencoderKL):
def __init__(self):
nn.Module.__init__(self)
self._encode_out = torch.zeros(*_LATENT_SHAPE)
self._decode_out = torch.zeros(*_DECODED_SHAPE)
def encode(self, x, return_dict=True):
return self._encode_out
def decode_(self, z, return_dict=True):
return self._decode_out
def test_forward_encode_returns_tensor():
vae = _StubVAE()
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
result = vae.forward(x, mode="encode")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_LATENT_SHAPE)
def test_forward_decode_returns_tensor():
vae = _StubVAE()
z = torch.zeros(*_INPUT_DECODE_SHAPE)
result = vae.forward(z, mode="decode")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_DECODED_SHAPE)
class _TupleReturningStubVAE(VideoAutoencoderKL):
"""Stub variant whose ``encode``/``decode_`` return the
``(tensor,)`` one-element tuple shape ``return_dict=False`` produces
in the parent class. Exercises the unwrap branch of
``VideoAutoencoderKL.forward``.
"""
def __init__(self):
nn.Module.__init__(self)
self._encode_tensor = torch.zeros(*_LATENT_SHAPE)
self._decode_tensor = torch.zeros(*_DECODED_SHAPE)
def encode(self, x, return_dict=True):
return (self._encode_tensor,)
def decode_(self, z, return_dict=True):
return (self._decode_tensor,)
def test_forward_all_unwraps_one_tuple_at_each_step():
vae = _TupleReturningStubVAE()
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
result = vae.forward(x, mode="all")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_DECODED_SHAPE)

View File

@ -0,0 +1,47 @@
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.sd
import comfy.supported_models
import comfy.ldm.seedvr.model as seedvr_model
def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch):
bf16_device = object()
fp16_device = object()
monkeypatch.setattr(
comfy.supported_models.comfy.model_management,
"should_use_bf16",
lambda device=None: device is bf16_device,
)
bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device)
assert bf16_config.manual_cast_dtype is torch.bfloat16
fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device)
assert fp16_config.manual_cast_dtype is None
def test_seedvr2_text_conditioning_accepts_cfg1_single_branch():
context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2)
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0])
torch.testing.assert_close(txt, context.squeeze(0))
torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device))
def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer():
estimate = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160))
old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2
assert estimate == 101 * 960 * 1280 * 160
assert estimate > 15 * 1024 ** 3
assert estimate > old_estimate * 100

View File

@ -0,0 +1,389 @@
"""Consolidated SeedVR2 internals regression tests.
Sources (all merged verbatim, helper names disambiguated where colliding):
* RoPE rewrite NaMMRotaryEmbedding3d.forward must match the legacy
apply_rotary_emb wrapper oracle at fp32.
* GroupNorm limit gate causal_norm_wrapper at vae.py:509 must compare
memory_occupy against get_norm_limit(), not float('inf').
* var_attention backend registry.
* var_attention_pytorch SeedVR2-named guard present-API shape contract
with AST-level pinning of the guard ordering.
Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
comfy.ldm.modules.attention transitively pull in comfy.model_management,
which probes torch.cuda.current_device() at import time unless args.cpu is
set first.
"""
from __future__ import annotations
import ast
import inspect
import logging
import textwrap
import warnings
from unittest.mock import patch
import pytest
import torch
from comfy.cli_args import args
if not torch.cuda.is_available():
args.cpu = True
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
import comfy.ldm.modules.attention as attention # noqa: E402
import comfy.ops as comfy_ops # noqa: E402
from comfy.ldm.seedvr.model import ( # noqa: E402
Cache,
NaMMRotaryEmbedding3d,
)
from comfy.ldm.seedvr.vae import ( # noqa: E402
causal_norm_wrapper,
set_norm_limit,
)
from comfy.ldm.modules.attention import var_attention_pytorch # noqa: E402
# ---------------------------------------------------------------------------
# RoPE rewrite tests (test_seedvr_rope_rewrite.py)
# ---------------------------------------------------------------------------
# Test rig dimensions. dim=192 → per-axis rope dim = 64 (even, lucidrains
# requirement). vid_shape=(2,4,4) → L_vid = 32. txt_shape=(8,) → L_txt = 8.
_DIM = 192
_HEADS = 4
_VID_T, _VID_H, _VID_W = 2, 4, 4
_TXT_L = 8
_L_VID = _VID_T * _VID_H * _VID_W
_SEED = 0
def _make_inputs(dtype=torch.float32, device="cpu"):
"""Construct the 6 forward inputs + cache. Deterministic via local
Generator so global RNG state is not mutated.
"""
g = torch.Generator(device=device).manual_seed(_SEED)
vid_q = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
vid_k = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
txt_q = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
txt_k = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device)
txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device)
cache = Cache(disable=True)
return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache
def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape):
"""Reproduce the pre-rewrite ``get_freqs`` body verbatim against
``self.get_axial_freqs`` (parent ``RotaryEmbeddingBase`` method,
unchanged by the rewrite).
"""
max_temporal = 0
max_height = 0
max_width = 0
max_txt_len = 0
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
max_temporal = max(max_temporal, l + f)
max_height = max(max_height, h)
max_width = max(max_width, w)
max_txt_len = max(max_txt_len, l)
with torch.amp.autocast(device_type="cuda", enabled=False):
vid_freqs_full = rope.get_axial_freqs(
min(max_temporal + 16, 1024),
min(max_height + 4, 128),
min(max_width + 4, 128),
).float()
txt_freqs_full = rope.get_axial_freqs(min(max_txt_len + 16, 1024))
vid_freq_list, txt_freq_list = [], []
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
vid_freq = vid_freqs_full[l : l + f, :h, :w].reshape(-1, vid_freqs_full.size(-1))
txt_freq = txt_freqs_full[:l].repeat(1, 3).reshape(-1, vid_freqs_full.size(-1))
vid_freq_list.append(vid_freq)
txt_freq_list.append(txt_freq)
return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0)
def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape,
txt_q, txt_k, txt_shape):
"""Compute expected forward output via the unchanged
``apply_rotary_emb`` wrapper fed with legacy-shape freqs. This is the
oracle. The wrapper itself is out of scope for the rewrite (Shape B).
"""
vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape)
vid_freqs = vid_freqs.to(vid_q.device)
txt_freqs = txt_freqs.to(txt_q.device)
from einops import rearrange
vid_q = rearrange(vid_q, "L h d -> h L d")
vid_k = rearrange(vid_k, "L h d -> h L d")
vid_q_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype)
vid_k_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype)
vid_q_out = rearrange(vid_q_out, "h L d -> L h d")
vid_k_out = rearrange(vid_k_out, "h L d -> L h d")
txt_q = rearrange(txt_q, "L h d -> h L d")
txt_k = rearrange(txt_k, "L h d -> h L d")
txt_q_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype)
txt_k_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype)
txt_q_out = rearrange(txt_q_out, "h L d -> L h d")
txt_k_out = rearrange(txt_k_out, "h L d -> L h d")
return vid_q_out, vid_k_out, txt_q_out, txt_k_out
def test_namm_forward_output_tensor_equal_against_legacy_oracle():
rope = NaMMRotaryEmbedding3d(dim=_DIM)
vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs()
expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward(
rope,
vid_q.clone(), vid_k.clone(), vid_shape,
txt_q.clone(), txt_k.clone(), txt_shape,
)
actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward(
vid_q.clone(), vid_k.clone(), vid_shape,
txt_q.clone(), txt_k.clone(), txt_shape, cache,
)
torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0,
msg="vid_q output diverges from wrapper oracle")
torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0,
msg="vid_k output diverges from wrapper oracle")
torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0,
msg="txt_q output diverges from wrapper oracle")
torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0,
msg="txt_k output diverges from wrapper oracle")
# ---------------------------------------------------------------------------
# GroupNorm limit tests (test_seedvr_groupnorm_limit.py)
# ---------------------------------------------------------------------------
_NUM_CHANNELS = 8
_NUM_GROUPS = 4
_TENSOR_SHAPE = (1, 8, 2, 4, 4)
_GROUPNORM_SUBCLASSES = [
pytest.param(comfy_ops.disable_weight_init.GroupNorm, id="disable_weight_init"),
pytest.param(comfy_ops.manual_cast.GroupNorm, id="manual_cast"),
]
@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES)
def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
real_group_norm = vae_mod.F.group_norm
set_norm_limit(1e-9)
try:
gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS)
gn.eval()
forward_hook_calls = []
def _hook(module, inputs, output):
forward_hook_calls.append(tuple(inputs[0].shape))
spy_calls = []
def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs):
spy_calls.append({"num_groups": int(num_groups_arg)})
return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs)
handle = gn.register_forward_hook(_hook)
try:
with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy):
out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE))
finally:
handle.remove()
full_calls = len(forward_hook_calls)
chunked_calls = sum(1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS)
assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE
assert full_calls == 0, (
f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}"
)
assert chunked_calls > 0, (
f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}"
)
finally:
set_norm_limit(None)
# ---------------------------------------------------------------------------
# var_attention backend tests (test_seedvr_var_attention_backends.py)
# ---------------------------------------------------------------------------
def test_var_attention_registry_contains_always_available_entries():
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_pytorch"] is attention.var_attention_pytorch
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_sub_quad"] is attention.var_attention_sub_quad
assert attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_split"] is attention.var_attention_split
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
dim = 8
heads = 2
head_dim = 4
attn = seedvr_model.NaSwinAttention(
vid_dim=dim,
txt_dim=dim,
heads=heads,
head_dim=head_dim,
qk_bias=False,
qk_norm=seedvr_model.CustomRMSNorm,
qk_norm_eps=1e-6,
rope_type=None,
rope_dim=head_dim,
shared_weights=False,
window=(2, 1, 1),
window_method="720pwin_by_size_bysize",
version=True,
device="cpu",
dtype=torch.float32,
operations=comfy_ops.disable_weight_init,
)
generator = torch.Generator(device="cpu").manual_seed(11)
vid = torch.randn(8, dim, generator=generator)
txt = torch.randn(3, dim, generator=generator)
vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long)
txt_shape = torch.tensor([[3]], dtype=torch.long)
calls = []
def fake_optimized_var_attention(**kwargs):
calls.append(kwargs)
return kwargs["q"]
monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention)
vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True))
assert tuple(vid_out.shape) == (8, dim)
assert tuple(txt_out.shape) == (3, dim)
assert len(calls) == 1
call = calls[0]
assert tuple(call["q"].shape) == (14, heads, head_dim)
assert tuple(call["k"].shape) == (14, heads, head_dim)
assert tuple(call["v"].shape) == (14, heads, head_dim)
assert call["heads"] == heads
assert call["skip_reshape"] is True
assert call["skip_output_reshape"] is True
torch.testing.assert_close(
call["cu_seqlens_q"],
torch.tensor([0, 7, 14], dtype=torch.int32),
rtol=0,
atol=0,
)
torch.testing.assert_close(
call["cu_seqlens_k"],
torch.tensor([0, 7, 14], dtype=torch.int32),
rtol=0,
atol=0,
)
# ---------------------------------------------------------------------------
# var_attention_pytorch SeedVR2 guard tests
# (test_var_attention_pytorch_seedvr2_guard.py)
# ---------------------------------------------------------------------------
def _pytorch_guard_inputs():
heads, head_dim, total_tokens = 2, 8, 6
embed_dim = heads * head_dim
q = torch.randn(total_tokens, embed_dim)
k = torch.randn(total_tokens, embed_dim)
v = torch.randn(total_tokens, embed_dim)
cu = torch.tensor([0, 3, 6], dtype=torch.int32)
return q, k, v, heads, cu, cu, total_tokens, embed_dim
def _assert_guard_source_pin():
src = textwrap.dedent(inspect.getsource(var_attention_pytorch))
tree = ast.parse(src)
raise_lines = []
nested_lines = []
for node in ast.walk(tree):
if isinstance(node, ast.Raise) and isinstance(node.exc, ast.Call):
func = node.exc.func
if isinstance(func, ast.Name) and func.id == "RuntimeError":
raise_lines.append(node.lineno)
if isinstance(node, ast.Attribute) and node.attr == "nested_tensor_from_jagged":
nested_lines.append(node.lineno)
assert raise_lines, (
"var_attention_pytorch has no `raise RuntimeError(...)` AST node; "
f"the SeedVR2-named guard is missing.\n--- source ---\n{src}"
)
assert nested_lines, (
"var_attention_pytorch source has no `nested_tensor_from_jagged` "
f"attribute access; cannot pin guard ordering.\n"
f"--- source ---\n{src}"
)
first_raise = min(raise_lines)
first_nested = min(nested_lines)
assert first_raise < first_nested, (
f"`raise RuntimeError(...)` first appears at line {first_raise}, "
f"but `torch.nested.nested_tensor_from_jagged` is referenced first "
f"at line {first_nested}; the guard must precede the lookup.\n"
f"--- source ---\n{src}"
)
def test_missing_api_raises_seedvr2_runtime_error(monkeypatch):
monkeypatch.delattr(torch.nested, "nested_tensor_from_jagged", raising=False)
q, k, v, heads, cu_q, cu_k, _, _ = _pytorch_guard_inputs()
with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"):
var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
_assert_guard_source_pin()
def test_missing_namespace_raises_seedvr2_runtime_error(monkeypatch):
monkeypatch.delattr(torch, "nested", raising=False)
q, k, v, heads, cu_q, cu_k, _, _ = _pytorch_guard_inputs()
with pytest.raises(RuntimeError, match=r"SeedVR2.*nested_tensor_from_jagged"):
var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
_assert_guard_source_pin()
def test_present_api_returns_expected_shape():
q, k, v, heads, cu_q, cu_k, total_tokens, embed_dim = _pytorch_guard_inputs()
torch_fx_logger = logging.getLogger("torch.fx._symbolic_trace")
old_torch_fx_level = torch_fx_logger.level
torch_fx_logger.setLevel(logging.ERROR)
try:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="The PyTorch API of nested tensors is in prototype stage.*",
category=UserWarning,
)
out = var_attention_pytorch(q, k, v, heads, cu_q, cu_k)
finally:
torch_fx_logger.setLevel(old_torch_fx_level)
assert tuple(out.shape) == (total_tokens, embed_dim), (
f"expected ({total_tokens}, {embed_dim}); got {tuple(out.shape)}"
)
_assert_guard_source_pin()
def test_malformed_offsets_propagates_torch_runtime_error():
q, k, v, heads, _, _, _, _ = _pytorch_guard_inputs()
cu_q_bad = torch.tensor([0, 3, 7], dtype=torch.int32)
cu_k_ok = torch.tensor([0, 3, 6], dtype=torch.int32)
with pytest.raises(RuntimeError) as exc_info:
var_attention_pytorch(q, k, v, heads, cu_q_bad, cu_k_ok)
msg = str(exc_info.value)
assert "SeedVR2" not in msg
_assert_guard_source_pin()

View File

@ -0,0 +1,308 @@
"""Consolidated SeedVR2 model/graph/forward regression tests.
Merged from:
- seedvr_model_test.py
- test_seedvr_7b_final_block_text_path.py
- test_seedvr_forward_no_device_cast.py
- test_seedvr_latent_format.py
- test_seedvr2_vae_graph_boundaries.py
"""
from __future__ import annotations
from unittest.mock import MagicMock
import torch
from torch import nn
from comfy.cli_args import args
if not torch.cuda.is_available():
args.cpu = True
import comfy # noqa: E402
import comfy.latent_formats # noqa: E402
import comfy.ldm.seedvr.model # noqa: E402
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.model_management # noqa: E402
import comfy.sample # noqa: E402
import comfy.sd as sd_mod # noqa: E402
import nodes as nodes_mod # noqa: E402
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
# ---------------------------------------------------------------------------
# Helpers from seedvr_model_test.py
# ---------------------------------------------------------------------------
def _make_standin(positive_conditioning):
class _StandIn(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"positive_conditioning", positive_conditioning
)
_resolve_text_conditioning = NaDiT._resolve_text_conditioning
return _StandIn()
# ---------------------------------------------------------------------------
# Helpers from test_seedvr_7b_final_block_text_path.py
# ---------------------------------------------------------------------------
class _StubModule(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]:
flags = []
class _Block(_StubModule):
def __init__(self, *args, **kwargs):
flags.append(kwargs["is_last_layer"])
super().__init__()
monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule)
monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule)
monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule)
monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block)
seedvr_model.NaDiT(
norm_eps=1e-5,
qk_rope=None,
num_layers=4,
mlp_type="normal",
vid_dim=vid_dim,
txt_in_dim=txt_in_dim,
heads=24,
mm_layers=3,
)
return flags
# ---------------------------------------------------------------------------
# Helpers from test_seedvr_latent_format.py
# ---------------------------------------------------------------------------
class _Model:
def __init__(self, latent_format):
self._latent_format = latent_format
def get_model_object(self, name):
assert name == "latent_format"
return self._latent_format
# ---------------------------------------------------------------------------
# Helpers from test_seedvr2_vae_graph_boundaries.py
# ---------------------------------------------------------------------------
class _Patcher:
def get_free_memory(self, device):
return 1024 * 1024 * 1024
class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
def __init__(self, encoded):
nn.Module.__init__(self)
self.encoded = encoded
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.seen = []
def encode(self, x):
self.seen.append(tuple(x.shape))
return self.encoded.to(device=x.device, dtype=x.dtype)
class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
def __init__(self):
nn.Module.__init__(self)
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.calls = []
def decode(self, z, seedvr2_tiling=None):
self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
if z.ndim == 4:
b, tc, h, w = z.shape
t = tc // 16
else:
b, _, t, h, w = z.shape
return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
def _make_vae(wrapper):
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = wrapper
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.latent_channels = 16
vae.latent_dim = 3
vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8)
vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
vae.output_channels = 3
vae.disable_offload = True
vae.extra_1d_channel = None
vae.crop_input = False
vae.not_video = False
vae.patcher = _Patcher()
vae.process_input = lambda image: image
vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0)
vae.vae_output_dtype = lambda: torch.float32
vae.memory_used_encode = lambda shape, dtype: 1
vae.memory_used_decode = lambda shape, dtype: 1
vae.throw_exception_if_invalid = lambda: None
vae.vae_encode_crop_pixels = lambda pixels: pixels
vae.spacial_compression_decode = lambda: 8
vae.temporal_compression_decode = lambda: 4
return vae
# ---------------------------------------------------------------------------
# Tests from seedvr_model_test.py
# ---------------------------------------------------------------------------
def test_missing_context_falls_back_to_positive_buffer():
"""AC: ``context is None`` falls back to the registered
``positive_conditioning`` buffer and runs to completion no
silent zero substitution, no raised exception.
"""
pos_buffer = torch.full((58, 5120), 7.0)
standin = _make_standin(pos_buffer)
txt, txt_shape = standin._resolve_text_conditioning(None)
assert txt.shape == (58, 5120)
assert (txt == 7.0).all(), (
"fallback path must use the positive_conditioning buffer "
"verbatim, not a zero tensor"
)
assert txt_shape.shape == (1, 1)
assert txt_shape[0, 0].item() == 58
# ---------------------------------------------------------------------------
# Tests from test_seedvr_7b_final_block_text_path.py
# ---------------------------------------------------------------------------
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [
False,
False,
False,
False,
]
def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
rope = seedvr_model.get_na_rope("rope3d", dim=64)
generator = torch.Generator(device="cpu").manual_seed(0)
q = torch.randn(4, 2, 128, generator=generator)
k = torch.randn(4, 2, 128, generator=generator)
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)
expected_q = seedvr_model.apply_rotary_emb(
freqs,
q.permute(1, 0, 2).float(),
).to(q.dtype).permute(1, 0, 2)
expected_k = seedvr_model.apply_rotary_emb(
freqs,
k.permute(1, 0, 2).float(),
).to(k.dtype).permute(1, 0, 2)
actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True))
torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0)
torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0)
# ---------------------------------------------------------------------------
# Tests from test_seedvr_latent_format.py
# ---------------------------------------------------------------------------
def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion():
latent_format = comfy.latent_formats.SeedVR2()
latent_image = torch.zeros(1, 1, 4, 5)
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
assert latent_format.latent_channels == 16
assert latent_format.latent_dimensions == 2
assert fixed.shape == (1, 16, 4, 5)
# ---------------------------------------------------------------------------
# Tests from test_seedvr2_vae_graph_boundaries.py
# ---------------------------------------------------------------------------
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
encoded = torch.full((1, 16, 2, 4, 5), 2.0)
vae = _make_vae(_EncodeWrapper(encoded))
pixels = torch.zeros(1, 5, 32, 40, 3)
node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
node_latent = node_output["samples"]
assert set(node_output) == {"samples"}
assert tuple(node_latent.shape) == (1, 16, 2, 4, 5)
assert node_latent.dtype == torch.float32
assert node_latent.stride()[-1] == 1
assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152))
tiled = torch.full((1, 16, 2, 4, 5), 3.0)
monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
tiled_output = nodes_mod.VAEEncodeTiled().encode(
vae,
pixels,
tile_size=512,
overlap=64,
temporal_size=16,
temporal_overlap=4,
)[0]
tiled_latent = tiled_output["samples"]
assert set(tiled_output) == {"samples"}
assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5)
assert tiled_latent.dtype == torch.float32
assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152))
def test_vaedecode_tiled_visible_inputs_are_seedvr2_decode_tiling_authority(monkeypatch):
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
vae = _make_vae(_DecodeWrapper())
nodes_mod.VAEDecodeTiled().decode(
vae,
{"samples": torch.zeros(1, 16, 2, 4, 5)},
tile_size=512,
overlap=64,
temporal_size=16,
temporal_overlap=4,
)
assert vae.first_stage_model.calls == [
{
"shape": (1, 16, 2, 4, 5),
"seedvr2_tiling": {
"enable_tiling": True,
"tile_size": (512, 512),
"tile_overlap": (64, 64),
"temporal_size": 16,
"temporal_overlap": 4,
},
}
]

View File

@ -0,0 +1,91 @@
from unittest.mock import patch
import pytest
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
from comfy_extras import nodes_seedvr # noqa: E402
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
vae_mod.VideoAutoencoderKLWrapper
)
nn.Module.__init__(wrapper)
return wrapper
def _fingerprint_decode_(self, z, return_dict=True):
b = int(z.shape[0])
t = int(z.shape[2])
h = int(z.shape[3])
w = int(z.shape[4])
out = torch.empty(b, 3, t, h * 8, w * 8)
for batch_idx in range(b):
out[batch_idx].fill_(float(batch_idx + 1))
return out
def _decode_with_patches(wrapper, z):
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_):
return wrapper.decode(z)
def test_decode_b2_t3_multi_frame_batch_unchanged():
wrapper = _make_wrapper()
out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2))
assert tuple(out.shape) == (2, 3, 3, 16, 16)
class _Wrapper(vae_mod.VideoAutoencoderKLWrapper):
def __init__(self):
nn.Module.__init__(self)
self.calls = []
def parameters(self):
return iter([torch.nn.Parameter(torch.zeros(()))])
def _decode_stub(self, latent):
self.calls.append(tuple(latent.shape))
return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8)
def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state():
wrapper = _Wrapper()
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5))
assert tuple(out.shape) == (1, 3, 2, 32, 40)
assert wrapper.calls == [(1, 16, 2, 4, 5)]
def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents():
wrapper = _Wrapper()
with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"):
wrapper.decode(torch.zeros(1, 16, 4))
def _t_padded(t_in: int) -> int:
if t_in == 1:
return 1
if t_in <= 4:
return 5
if (t_in - 1) % 4 == 0:
return t_in
return t_in + (4 - ((t_in - 1) % 4))
@pytest.mark.parametrize("t_in", [1, 5, 9])
def test_t_padded_matches_cut_videos(t_in):
dummy = torch.zeros(1, t_in, 1, 1, 1)
assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in)

View File

@ -0,0 +1,350 @@
from contextlib import ExitStack
from unittest.mock import MagicMock, patch
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.sd as sd_mod # noqa: E402
from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402
# ---------------------------------------------------------------------------
# From test_seedvr_vae_tiled_decode_latent_min_size_override.py
# ---------------------------------------------------------------------------
def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae
class StubVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_latent_min_size = 2
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self.use_slicing = True
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.decode_min_sizes = []
self.memory_states = []
def decode_(self, t_chunk):
self.decode_min_sizes.append(self.slicing_latent_min_size)
return VideoAutoencoderKL.slicing_decode(self, t_chunk)
def _decode(self, z, memory_state=MemoryState.DISABLED):
self.memory_states.append(memory_state)
b, c, d, h, w = z.shape
return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype)
vae = StubVAEModel()
z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32)
tiled_vae(
z,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=False,
)
assert vae.decode_min_sizes == [5]
assert vae.memory_states == [MemoryState.DISABLED]
assert vae.slicing_latent_min_size == 2
# ---------------------------------------------------------------------------
# From test_seedvr_vae_tiled_encode_runt_slice_override.py
# ---------------------------------------------------------------------------
def test_zero_temporal_size_preserves_min_size_when_encode_raises():
from comfy.ldm.seedvr.vae import tiled_vae
class RaisingVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_sample_min_size = 4
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
def encode(self, t_chunk):
raise RuntimeError("simulated encode failure")
vae = RaisingVAEModel()
x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)
raised = False
try:
tiled_vae(
x,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=True,
)
except RuntimeError as exc:
if "simulated encode failure" not in str(exc):
raise
raised = True
assert raised
assert vae.slicing_sample_min_size == 4
# ---------------------------------------------------------------------------
# From test_seedvr_vae_tiled_temporal_slicing.py
# ---------------------------------------------------------------------------
class _SlicingDecodeVAE(nn.Module):
def __init__(self, slicing_latent_min_size):
super().__init__()
self.slicing_latent_min_size = slicing_latent_min_size
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self.use_slicing = True
self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.decode_min_sizes = []
self.memory_states = []
def decode_(self, z):
self.decode_min_sizes.append(self.slicing_latent_min_size)
return vae_mod.VideoAutoencoderKL.slicing_decode(self, z)
def _decode(self, z, memory_state=MemoryState.DISABLED):
self.memory_states.append(memory_state)
x = z[:, :1].repeat(
1,
3,
1,
self.spatial_downsample_factor,
self.spatial_downsample_factor,
)
return x
def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
vae = _SlicingDecodeVAE(slicing_latent_min_size=2)
z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8)
tiled_vae(
z,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=12,
temporal_overlap=4,
encode=False,
)
assert vae.decode_min_sizes == [2]
assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE]
assert vae.slicing_latent_min_size == 2
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
vae_mod.VideoAutoencoderKLWrapper
)
nn.Module.__init__(wrapper)
seedvr2_tiling = {
"enable_tiling": True,
"tile_size": (64, 64),
"tile_overlap": (0, 0),
"temporal_size": 8,
"temporal_overlap": 7,
}
captured = {}
def _fake_tiled_vae(latent, model, **kwargs):
captured.update(kwargs)
return torch.zeros(1, 3, 1, 16, 16)
with (
patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae),
patch.object(vae_mod, "lab_color_transfer", side_effect=lambda content, style: content),
):
wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling)
assert captured["temporal_overlap"] == 7
# ---------------------------------------------------------------------------
# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py
# ---------------------------------------------------------------------------
def _force_oom(*a, **k):
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
def _make_vae(first_stage_model, latent_channels, latent_dim):
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = first_stage_model
vae.patcher = MagicMock()
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
vae.device = vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.disable_offload = True
vae.extra_1d_channel = None
vae.upscale_ratio = vae.downscale_ratio = 8
vae.upscale_index_formula = vae.downscale_index_formula = None
vae.output_channels = 3
vae.latent_channels = latent_channels
vae.latent_dim = latent_dim
vae.vae_output_dtype = lambda: torch.float32
vae.spacial_compression_decode = lambda: 8
vae.process_input = lambda x: x
vae.process_output = lambda x: x
vae.throw_exception_if_invalid = lambda: None
vae.memory_used_decode = lambda *a, **k: 1
return vae
def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode):
mm = sd_mod.model_management
with ExitStack() as stack:
stack.enter_context(patch.object(mm, "raise_non_oom", lambda e: None))
stack.enter_context(patch.object(mm, "load_models_gpu", lambda *a, **k: None))
stack.enter_context(patch.object(mm, "soft_empty_cache", lambda: None))
stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call))
stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_", generic_call))
if patch_wrapper_decode:
stack.enter_context(patch.object(
seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode",
side_effect=_force_oom))
vae.decode(samples)
def test_4d_seedvr2_latent_routes_to_decode_tiled_seedvr2():
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
seedvr_vae_mod.VideoAutoencoderKLWrapper)
vae = _make_vae(wrapper, latent_channels=16, latent_dim=3)
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
_dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True)
assert seedvr2_call.call_count == 1
assert generic_call.call_count == 0
def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled():
first_stage = MagicMock()
first_stage.decode = MagicMock(side_effect=_force_oom)
vae = _make_vae(first_stage, latent_channels=4, latent_dim=2)
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
_dispatch(vae, torch.zeros(1, 4, 8, 8), seedvr2_call, generic_call, False)
assert generic_call.call_count == 1
assert seedvr2_call.call_count == 0
# ---------------------------------------------------------------------------
# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py
# ---------------------------------------------------------------------------
def _populate_common_vae_attrs_fallback(vae):
vae.patcher = MagicMock()
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.disable_offload = True
vae.extra_1d_channel = None
vae.upscale_ratio = 8
vae.upscale_index_formula = None
vae.output_channels = 3
vae.latent_channels = 16
vae.latent_dim = 3
vae.downscale_ratio = 8
vae.downscale_index_formula = None
vae.not_video = False
vae.crop_input = False
vae.pad_channel_value = None
vae.vae_output_dtype = lambda: torch.float32
vae.spacial_compression_encode = lambda: 8
vae.process_input = lambda x: x
vae.process_output = lambda x: x
vae.throw_exception_if_invalid = lambda: None
vae.memory_used_encode = lambda *a, **k: 1
def _make_seedvr2_vae_fallback():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
seedvr_vae_mod.VideoAutoencoderKLWrapper
)
vae.first_stage_model = wrapper
_populate_common_vae_attrs_fallback(vae)
return vae
def _make_non_seedvr2_vae_fallback():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = MagicMock()
_populate_common_vae_attrs_fallback(vae)
return vae
def _force_regular_encode_oom(*args, **kwargs):
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
def test_seedvr2_3d_routes_to_encode_tiled_seedvr2_on_oom():
vae = _make_seedvr2_vae_fallback()
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
with patch.object(sd_mod.model_management, "raise_non_oom",
lambda e: None), \
patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.model_management, "soft_empty_cache",
lambda: None), \
patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode",
side_effect=_force_regular_encode_oom), \
patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call,
create=True), \
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
vae.encode(pixel_samples)
assert seedvr2_call.call_count == 1, (
f"Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D "
f"input under OOM fallback; got {seedvr2_call.call_count} calls."
)
assert generic_call.call_count == 0, (
f"encode_tiled_3d must NOT be called for a SeedVR2 input; got "
f"{generic_call.call_count} calls."
)
def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete():
vae = _make_non_seedvr2_vae_fallback()
vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8)
vae.upscale_ratio = (lambda a: a * 4, 8, 8)
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
with patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
vae.encode_tiled(pixel_samples)
assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64)

View File

@ -0,0 +1,126 @@
"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``."""
from unittest.mock import patch
import pytest
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.sample # noqa: E402
import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402
from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler # noqa: E402
_LAT_C = 16
_COND_C = 17
def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8):
"""Build minimal SeedVR2-shaped sampling inputs."""
samples_5d = torch.arange(
B * _LAT_C * T * H * W, dtype=torch.float32
).reshape(B, _LAT_C, T, H, W)
samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous()
cond_5d = torch.arange(
B * _COND_C * T * H * W, dtype=torch.float32
).reshape(B, _COND_C, T, H, W) + 10000.0
cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous()
text_pos = torch.zeros(1, 4, 32)
text_neg = torch.zeros(1, 4, 32)
positive = [[text_pos, {"condition": cond.clone()}]]
negative = [[text_neg, {"condition": cond.clone()}]]
latent_image = {"samples": samples}
return latent_image, positive, negative, samples_5d, cond_5d
def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None):
return latent_image
def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None):
"""Return a tensor whose values encode ``(seed, position)``."""
base = torch.arange(
latent_image.numel(), dtype=torch.float32
).reshape(latent_image.shape)
return base + float(seed) * 1e6
def test_progressive_sampler_schema_exposes_manual_default_auto_chunking():
schema = SeedVR2ProgressiveSampler.define_schema()
inputs = {item.id: item for item in schema.inputs}
assert inputs["chunking_mode"].options == ["manual", "auto"]
assert inputs["chunking_mode"].default == "manual"
def test_auto_chunking_walks_two_three_four_chunk_ladder():
"""Auto mode must walk 2-, 3-, then 4-chunk geometries on OOM."""
latent, pos, neg, _, _ = _make_inputs(T=17)
calls = []
def _oom_until_four_chunks(model, noise, steps, cfg, sampler_name,
scheduler, positive, negative,
latent_image, denoise=1.0,
noise_mask=None, seed=None):
calls.append(tuple(latent_image.shape))
if latent_image.shape[1] > _LAT_C * 5:
raise torch.cuda.OutOfMemoryError("chunk too large")
return latent_image.clone()
with patch.object(comfy.sample, "sample",
side_effect=_oom_until_four_chunks), \
patch.object(comfy.sample, "fix_empty_latent_channels",
side_effect=_identity_fix_empty), \
patch.object(comfy.sample, "prepare_noise",
side_effect=_fingerprinted_prepare_noise), \
patch.object(nodes_seedvr_mod.comfy.model_management,
"soft_empty_cache") as soft_empty:
out = SeedVR2ProgressiveSampler.execute(
model=None, seed=0, steps=2, cfg=1.0,
sampler_name="euler", scheduler="simple",
positive=pos, negative=neg, latent_image=latent,
denoise=1.0, frames_per_chunk=65, temporal_overlap=0,
chunking_mode="auto",
)
assert calls[:4] == [
(1, _LAT_C * 17, 8, 8),
(1, _LAT_C * 9, 8, 8),
(1, _LAT_C * 6, 8, 8),
(1, _LAT_C * 5, 8, 8),
]
assert torch.equal(out.result[0]["samples"], latent["samples"])
assert soft_empty.call_count == 3
@pytest.mark.parametrize("bad_chunk", [0, -1, 2])
def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk):
"""``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation."""
latent, pos, neg, _, _ = _make_inputs(T=5)
sampler_called = {"n": 0}
def _should_not_be_called(*args, **kwargs):
sampler_called["n"] += 1
return torch.zeros(1)
with patch.object(comfy.sample, "sample",
side_effect=_should_not_be_called), \
patch.object(comfy.sample, "fix_empty_latent_channels",
side_effect=_identity_fix_empty), \
patch.object(comfy.sample, "prepare_noise",
side_effect=_fingerprinted_prepare_noise):
with pytest.raises(ValueError) as excinfo:
SeedVR2ProgressiveSampler.execute(
model=None, seed=0, steps=2, cfg=1.0,
sampler_name="euler", scheduler="simple",
positive=pos, negative=neg, latent_image=latent,
denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0,
)
assert str(bad_chunk) in str(excinfo.value)
assert sampler_called["n"] == 0