Merge branch 'master' into alexis/reorder_inputs
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled

This commit is contained in:
Alexis Rolland 2026-06-09 23:33:33 +08:00 committed by GitHub
commit d977079a7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
41 changed files with 1084 additions and 7602 deletions

View File

@ -33,6 +33,7 @@ from app.assets.services.file_utils import (
verify_file_unchanged,
)
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
from app.assets.services.image_dimensions import extract_image_dimensions
from app.assets.services.metadata_extract import extract_file_metadata
from app.assets.services.path_utils import (
compute_relative_filename,
@ -506,6 +507,10 @@ def enrich_asset(
if extract_metadata and metadata:
system_metadata = metadata.to_user_metadata()
if mime_type and mime_type.startswith("image/"):
dims = extract_image_dimensions(file_path, mime_type=mime_type)
if dims:
system_metadata.update(dims)
set_reference_system_metadata(session, reference_id, system_metadata)
if full_hash:

View File

@ -0,0 +1,63 @@
"""Image dimension extraction for asset ingest.
Reads only the image header via Pillow to capture width/height cheaply,
without a full pixel decode. Returns a metadata dict suitable for merging
into ``AssetReference.system_metadata``.
"""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
def extract_image_dimensions(
file_path: str, mime_type: str | None = None
) -> dict[str, Any] | None:
"""Extract image dimensions for the file at ``file_path``.
Args:
file_path: Absolute path to a file on disk.
mime_type: Optional MIME type hint. When provided and not prefixed
with ``image/``, extraction is skipped without touching the file.
Returns:
``{"kind": "image", "width": W, "height": H}`` when the file is a
recognizable image with positive dimensions, otherwise ``None``.
The dict shape is intended to be merged into ``system_metadata`` so the
asset response surfaces ``metadata.kind`` plus dimension fields for image
assets. Forward-compatible: future media kinds (e.g. ``"video"`` with
duration/fps) can extend this shape without schema changes.
"""
if mime_type is not None and not mime_type.startswith("image/"):
return None
try:
from PIL import Image, UnidentifiedImageError
except ImportError:
logger.debug(
"Pillow not available; skipping image dimension extraction for %s",
file_path,
)
return None
try:
with Image.open(file_path) as img:
width, height = img.size
except (OSError, UnidentifiedImageError, ValueError) as exc:
logger.debug(
"Failed to read image dimensions from %s: %s", file_path, exc
)
return None
if (
not isinstance(width, int)
or not isinstance(height, int)
or width <= 0
or height <= 0
):
return None
return {"kind": "image", "width": width, "height": height}

View File

@ -17,9 +17,11 @@ from app.assets.database.queries import (
get_reference_by_file_path,
get_reference_tags,
get_or_create_reference,
list_references_by_asset_id,
reference_exists,
remove_missing_tag_for_asset_id,
set_reference_metadata,
set_reference_system_metadata,
set_reference_tags,
update_asset_hash_and_mime,
upsert_asset,
@ -29,6 +31,7 @@ from app.assets.database.queries import (
from app.assets.helpers import get_utc_now, normalize_tags
from app.assets.services.bulk_ingest import batch_insert_seed_assets
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.image_dimensions import extract_image_dimensions
from app.assets.services.path_utils import (
compute_relative_filename,
get_name_and_tags_from_asset_path,
@ -118,6 +121,14 @@ def _ingest_file_from_path(
user_metadata=user_metadata,
)
_maybe_store_image_dimensions(
session,
reference_id=reference_id,
file_path=locator,
mime_type=mime_type,
current_system_metadata=ref.system_metadata,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
@ -288,6 +299,13 @@ def _register_existing_asset(
user_metadata=new_meta,
)
_backfill_image_dimensions_from_siblings(
session,
asset_id=asset.id,
new_reference_id=ref.id,
current_system_metadata=ref.system_metadata,
)
if tags is not None:
set_reference_tags(
session,
@ -334,6 +352,87 @@ def _update_metadata_with_filename(
)
_IMAGE_DIMENSION_KEYS = ("kind", "width", "height")
def _maybe_store_image_dimensions(
session: Session,
reference_id: str,
file_path: str,
mime_type: str | None,
current_system_metadata: dict | None,
) -> None:
"""Populate ``kind``/``width``/``height`` on system_metadata for image refs.
Non-image MIME types are a no-op. Pre-existing keys (e.g. enricher-written
safetensors metadata, download provenance) are preserved by merge.
"""
if not mime_type or not mime_type.startswith("image/"):
return
dims = extract_image_dimensions(file_path, mime_type=mime_type)
if not dims:
return
current = current_system_metadata or {}
merged = dict(current)
merged.update(dims)
if merged != current:
set_reference_system_metadata(
session,
reference_id=reference_id,
system_metadata=merged,
)
def _backfill_image_dimensions_from_siblings(
session: Session,
asset_id: str,
new_reference_id: str,
current_system_metadata: dict | None,
) -> None:
"""Copy image dimension keys from any sibling reference of the same asset.
The from-hash path doesn't read the file bytes, so dimensions can't be
extracted there directly. When another reference of the same asset already
carries image dimensions, copy them onto the new reference so consumers
see consistent metadata regardless of how the asset was registered.
Best-effort: missing siblings, non-image siblings, or absent dimension
keys leave the target reference unchanged.
"""
current = current_system_metadata or {}
if current.get("kind") == "image" and "width" in current and "height" in current:
return
for sibling in list_references_by_asset_id(session, asset_id):
if sibling.id == new_reference_id:
continue
meta = sibling.system_metadata or {}
if meta.get("kind") != "image":
continue
width = meta.get("width")
height = meta.get("height")
if (
type(width) is not int
or type(height) is not int
or width <= 0
or height <= 0
):
continue
merged = dict(current)
merged["kind"] = "image"
merged["width"] = width
merged["height"] = height
if merged != current:
set_reference_system_metadata(
session,
reference_id=new_reference_id,
system_metadata=merged,
)
return
def _sanitize_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
return n if n else fallback

View File

@ -166,6 +166,8 @@ class PerformanceFeature(enum.Enum):
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--debug-hang", action="store_true", help="Enable stack trace dumps on Ctrl-C for debugging hangs.")
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")

View File

@ -4,7 +4,6 @@ class LatentFormat:
scale_factor = 1.0
latent_channels = 4
latent_dimensions = 2
preserve_empty_channel_multiples = False
latent_rgb_factors = None
latent_rgb_factors_bias = None
latent_rgb_factors_reshape = None
@ -780,10 +779,6 @@ class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2
class SeedVR2(LatentFormat):
latent_channels = 16
preserve_empty_channel_multiples = True
class ACEAudio15(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@ -735,86 +735,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
)
return out
def _var_attention_qkv(q, k, v, heads, skip_reshape):
if skip_reshape:
return q, k, v, q.shape[-1]
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
return (
q.view(total_tokens, heads, head_dim),
k.view(k.shape[0], heads, head_dim),
v.view(v.shape[0], heads, head_dim),
head_dim,
)
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
if skip_output_reshape:
return out
return out.reshape(-1, heads * head_dim)
def _use_blackwell_attention():
device = model_management.get_torch_device()
if device.type != "cuda":
return False
major, minor = torch.cuda.get_device_capability(device)
return (major, minor) >= (12, 0)
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
if cu_seqlens.dtype not in (torch.int32, torch.int64):
raise ValueError(f"{name} must use an integer dtype")
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
if cu_seqlens[0].item() != 0:
raise ValueError(f"{name} must start at 0")
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
raise ValueError(f"{name} must be strictly increasing")
if cu_seqlens[-1].item() != token_count:
raise ValueError(f"{name} does not match token count")
def _split_indices(cu_seqlens):
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
if cu_seqlens_k[-1].item() != v.shape[0]:
raise ValueError("cu_seqlens_k does not match v token count")
q_split_indices = _split_indices(cu_seqlens_q)
k_split_indices = _split_indices(cu_seqlens_k)
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")
out = []
for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
out_dtype = q_i.dtype
if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16):
q_i = q_i.to(torch.bfloat16)
k_i = k_i.to(torch.bfloat16)
v_i = v_i.to(torch.bfloat16)
out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
if out_i.dtype != out_dtype:
out_i = out_i.to(out_dtype)
out.append(out_i.squeeze(0).permute(1, 0, 2))
out = torch.cat(out, dim=0)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
optimized_var_attention = var_attention_optimized_split
optimized_attention = attention_basic
if model_management.sage_attention_enabled():
@ -837,8 +758,6 @@ else:
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad
logging.info("Using optimized_attention split-loop for variable-length attention")
optimized_attention_masked = optimized_attention
@ -854,7 +773,6 @@ if model_management.xformers_enabled():
register_attention_function("pytorch", attention_pytorch)
register_attention_function("sub_quad", attention_sub_quad)
register_attention_function("split", attention_split)
register_attention_function("var_attention_optimized_split", var_attention_optimized_split)
def optimized_attention_for_device(device, mask=False, small_input=False):
@ -1291,3 +1209,5 @@ class SpatialVideoTransformer(SpatialTransformer):
x = self.proj_out(x)
out = x + x_in
return out

View File

@ -13,7 +13,6 @@ if model_management.xformers_enabled_vae():
import xformers
import xformers.ops
def torch_cat_if_needed(xl, dim):
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
if len(xl) > 1:
@ -23,8 +22,7 @@ def torch_cat_if_needed(xl, dim):
else:
return None
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
@ -35,13 +33,11 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - downscale_freq_shift)
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0))
return emb

View File

@ -51,6 +51,18 @@ class FeedForward(nn.Module):
return hidden_states
# Addin this back because Nunchaku custom nodes rely on it, see comment here:
# https://github.com/Comfy-Org/ComfyUI/pull/14178#issuecomment-4640475161
# TODO: Eventually remove this once we natively support SVDQuants
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape)
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
super().__init__()

View File

@ -1,340 +0,0 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from comfy.ldm.seedvr.model import safe_pad_operation
from comfy.ldm.seedvr.vae import safe_interpolate_operation
from comfy.ldm.seedvr.constants import (
CIELAB_DELTA,
CIELAB_KAPPA,
D65_WHITE_X,
D65_WHITE_Z,
WAVELET_DECOMP_LEVELS,
)
def wavelet_blur(image: Tensor, radius):
max_safe_radius = max(1, min(image.shape[-2:]) // 8)
if radius > max_safe_radius:
radius = max_safe_radius
num_channels = image.shape[1]
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS):
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq.add_(image).sub_(low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
if content_feat.shape != style_feat.shape:
# Resize style to match content spatial dimensions
if len(content_feat.shape) >= 3:
# safe_interpolate_operation handles FP16 conversion automatically
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False
)
# Decompose both features into frequency components
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq # Free memory immediately
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq # Free memory immediately
if content_high_freq.shape != style_low_freq.shape:
style_low_freq = safe_interpolate_operation(
style_low_freq,
size=content_high_freq.shape[-2:],
mode='bilinear',
align_corners=False
)
content_high_freq.add_(style_low_freq)
return content_high_freq.clamp_(-1.0, 1.0)
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
original_shape = source.shape
# Flatten
source_flat = source.flatten()
reference_flat = reference.flatten()
# Sort both arrays
source_sorted, source_indices = torch.sort(source_flat)
reference_sorted, _ = torch.sort(reference_flat)
del reference_flat
# Quantile mapping
n_source = len(source_sorted)
n_reference = len(reference_sorted)
if n_source == n_reference:
matched_sorted = reference_sorted
else:
# Interpolate reference to match source quantiles
source_quantiles = torch.linspace(0, 1, n_source, device=device)
ref_indices = (source_quantiles * (n_reference - 1)).long()
ref_indices.clamp_(0, n_reference - 1)
matched_sorted = reference_sorted[ref_indices]
del source_quantiles, ref_indices, reference_sorted
del source_sorted, source_flat
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
inverse_indices = torch.argsort(source_indices)
del source_indices
matched_flat = matched_sorted[inverse_indices]
del matched_sorted, inverse_indices
return matched_flat.reshape(original_shape)
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of CIELAB images to RGB color space."""
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
# LAB to XYZ
fy = (L + 16.0) / 116.0
fx = a.div(500.0).add_(fy)
fz = fy - b / 200.0
del L, a, b
# XYZ transformation
x = torch.where(
fx > epsilon,
torch.pow(fx, 3.0),
fx.mul(116.0).sub_(16.0).div_(kappa)
)
y = torch.where(
fy > epsilon,
torch.pow(fy, 3.0),
fy.mul(116.0).sub_(16.0).div_(kappa)
)
z = torch.where(
fz > epsilon,
torch.pow(fz, 3.0),
fz.mul(116.0).sub_(16.0).div_(kappa)
)
del fx, fy, fz
# Apply D65 white point (in-place)
x.mul_(D65_WHITE_X)
# y *= 1.00000 # (no-op, skip)
z.mul_(D65_WHITE_Z)
xyz = torch.stack([x, y, z], dim=1)
del x, y, z
# Matrix multiplication: XYZ -> RGB
B, C, H, W = xyz.shape
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
del xyz
# Ensure dtype consistency for matrix multiplication
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
del xyz_flat
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del rgb_linear_flat
# Apply inverse gamma correction (delinearize)
mask = rgb_linear > 0.0031308
rgb = torch.where(
mask,
torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055),
rgb_linear * 12.92
)
del mask, rgb_linear
return torch.clamp(rgb, 0.0, 1.0)
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
# Apply sRGB gamma correction (linearize)
mask = rgb > 0.04045
rgb_linear = torch.where(
mask,
torch.pow((rgb + 0.055) / 1.055, 2.4),
rgb / 12.92
)
del mask
# Matrix multiplication: RGB -> XYZ
B, C, H, W = rgb_linear.shape
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
del rgb_linear
# Ensure dtype consistency for matrix multiplication
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
xyz_flat = torch.matmul(rgb_flat, matrix.T)
del rgb_flat
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del xyz_flat
# Normalize by D65 white point (in-place)
xyz[:, 0].div_(D65_WHITE_X) # X
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
xyz[:, 2].div_(D65_WHITE_Z) # Z
# XYZ to LAB transformation
epsilon_cubed = epsilon ** 3
mask = xyz > epsilon_cubed
f_xyz = torch.where(
mask,
torch.pow(xyz, 1.0 / 3.0),
xyz.mul(kappa).add_(16.0).div_(116.0)
)
del xyz, mask
# Extract channels and compute LAB
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
del f_xyz
return torch.stack([L, a, b], dim=1)
def lab_color_transfer(
content_feat: Tensor,
style_feat: Tensor,
luminance_weight: float = 0.8
) -> Tensor:
content_feat = wavelet_reconstruction(content_feat, style_feat)
if content_feat.shape != style_feat.shape:
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False
)
device = content_feat.device
def ensure_float32_precision(c):
orig_dtype = c.dtype
c = c.float()
return c, orig_dtype
content_feat, original_dtype = ensure_float32_precision(content_feat)
style_feat, _ = ensure_float32_precision(style_feat)
rgb_to_xyz_matrix = torch.tensor([
[0.4124564, 0.3575761, 0.1804375],
[0.2126729, 0.7151522, 0.0721750],
[0.0193339, 0.1191920, 0.9503041]
], dtype=torch.float32, device=device)
xyz_to_rgb_matrix = torch.tensor([
[ 3.2404542, -1.5371385, -0.4985314],
[-0.9692660, 1.8760108, 0.0415560],
[ 0.0556434, -0.2040259, 1.0572252]
], dtype=torch.float32, device=device)
epsilon = CIELAB_DELTA
kappa = CIELAB_KAPPA
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
# Convert to LAB color space
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
del content_feat
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
del style_feat, rgb_to_xyz_matrix
# Match chrominance channels (a*, b*) for accurate color transfer
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
# Handle luminance with weighted blending
if luminance_weight < 1.0:
# Partially match luminance for better overall color accuracy
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
# Blend: preserve some content L* for detail, adopt some style L* for color
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
del matched_L
else:
# Fully preserve content luminance
result_L = content_lab[:, 0]
del content_lab, style_lab
# Reconstruct LAB with corrected channels
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
del result_L, matched_a, matched_b
# Convert back to RGB
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
del result_lab, xyz_to_rgb_matrix
# Convert back to [-1, 1] range (in-place)
result = result_rgb.mul_(2.0).sub_(1.0)
del result_rgb
result = result.to(original_dtype)
return result
def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor:
return wavelet_reconstruction(content_feat, style_feat)
def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor:
if content_feat.shape != style_feat.shape:
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False,
)
original_dtype = content_feat.dtype
content_feat = content_feat.float()
style_feat = style_feat.float()
b, c = content_feat.shape[:2]
content_flat = content_feat.reshape(b, c, -1)
style_flat = style_feat.reshape(b, c, -1)
content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1)
content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1)
style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
del content_flat, style_flat
normalized = (content_feat - content_mean) / content_std
del content_mean, content_std
result = normalized * style_std + style_mean
del normalized, style_mean, style_std
result = result.clamp_(-1.0, 1.0)
if result.dtype != original_dtype:
result = result.to(original_dtype)
return result

View File

@ -1,79 +0,0 @@
"""Named constants for the SeedVR2 integration, grouped by provenance.
Provenance prefixes:
- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline.
- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites
the upstream config/source path it was lifted from.
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
ISO / CIE values; cite the standard.
"""
# --------------------------------------------------------------------------------------
# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment)
# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN)
# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070
# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT).
# --------------------------------------------------------------------------------------
SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB)
SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB
# --------------------------------------------------------------------------------------
# B. Fork heuristics (SEEDVR2 - this integration)
# --------------------------------------------------------------------------------------
SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim.
# (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.)
SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry.
SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case).
SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM.
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk.
SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels).
SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16).
SEEDVR2_DEFAULT_TEMPORAL_SIZE = 16 # default VAE temporal tile when unset.
# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing)
SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk.
SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path.
SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path.
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
# --------------------------------------------------------------------------------------
# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR)
# --------------------------------------------------------------------------------------
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem).
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem).
BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16).
BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32).
BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD = 64 # attn_video_vae.py:308 (force .contiguous() above this b*t).
BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11.
BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size).
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor.
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor).
BYTEDANCE_SCHEDULE_T = 1000.0 # configs_3b/main.yaml:65 (schedule.T); timestep range.
BYTEDANCE_SPATIAL_DIVISOR = 16 # inference_seedvr2_3b.py:241 (DivisibleCrop((16,16))).
BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling).
BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames).
BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency).
BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim).
# Resolution-dependent timestep-shift linear fits: (x1, y1, x2, y2) for get_lin_function.
BYTEDANCE_IMG_SHIFT_FIT = (256 * 256, 1.0, 1024 * 1024, 3.2) # infer.py:242.
BYTEDANCE_VID_SHIFT_FIT = (256 * 256 * 37, 1.0, 1280 * 720 * 145, 5.0) # infer.py:243.
# --------------------------------------------------------------------------------------
# D. Published standards (cite the literature)
# --------------------------------------------------------------------------------------
ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864.
# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65).
CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta).
CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa).
D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1).
D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn.
WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR).
# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and
# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the
# exact existing coefficients move verbatim rather than being retyped here.

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1631,13 +1631,15 @@ class SCAILWanModel(WanModel):
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs):
if reference_latent is not None:
x = torch.cat((reference_latent, x), dim=2)
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
if ref_mask_latents is not None: # SCAIL-2 additive mask stream
x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)
@ -1645,6 +1647,8 @@ class SCAILWanModel(WanModel):
scail_pose_seq_len = 0
if pose_latents is not None:
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
if sam_latents is not None: # SCAIL-2 additive mask stream
scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype)
scail_x = scail_x.flatten(2).transpose(1, 2)
scail_pose_seq_len = scail_x.shape[1]
x = torch.cat([x, scail_x], dim=1)
@ -1695,7 +1699,36 @@ class SCAILWanModel(WanModel):
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
# ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode,
# which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset.
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}):
if ref_mask_flag is not None and not bool(ref_mask_flag):
REF_ROPE_H = 120.0
POSE_ROPE_W = 120.0
ref_t_patches = 0
if reference_latent is not None:
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
main_t_patches = t - ref_t_patches
parts = []
if ref_t_patches > 0:
ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}}
parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf))
if main_t_patches > 0:
parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options))
if pose_latents is not None:
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
h_scale = h / H_pose
w_scale = w / W_pose
h_shift = (h_scale - 1) / 2
w_shift = (w_scale - 1) / 2
pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf))
return torch.cat(parts, dim=1)
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
if pose_latents is None:
@ -1719,12 +1752,16 @@ class SCAILWanModel(WanModel):
return torch.cat([main_freqs, pose_freqs], dim=1)
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
if pose_latents is not None:
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
if ref_mask_latents is not None: # SCAIL-2
ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size)
if sam_latents is not None: # SCAIL-2
sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size)
t_len = t
if time_dim_concat is not None:
@ -1737,5 +1774,15 @@ class SCAILWanModel(WanModel):
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
t_len += reference_latent.shape[2]
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
ref_mask_flag = kwargs.pop("ref_mask_flag", None) # SCAIL-2
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_flag=ref_mask_flag)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w]
class SCAIL2WanModel(SCAILWanModel):
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs):
super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)

View File

@ -357,6 +357,12 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["transformer.{}".format(key_lora)] = k
if isinstance(model, (comfy.model_base.LTXV, comfy.model_base.LTXAV)):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
return key_map

View File

@ -54,8 +54,6 @@ import comfy.ldm.pixeldit.model
import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.seedvr.model
import comfy.ldm.qwen_image.model
import comfy.ldm.ideogram4.model
import comfy.ldm.kandinsky5.model
@ -930,16 +928,6 @@ class HunyuanDiT(BaseModel):
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out
class SeedVR2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
condition = kwargs.get("condition", None)
if condition is not None:
out["condition"] = comfy.conds.CONDRegular(condition)
return out
class PixArt(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
@ -1766,6 +1754,80 @@ class WAN21_SCAIL(WAN21):
return out
class WAN21_SCAIL2(WAN21_SCAIL):
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAIL2WanModel)
self.memory_usage_factor_conds = ("reference_latent", "pose_latents", "ref_mask_latents", "sam_latents")
self.memory_usage_shape_process = {
"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]],
"sam_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]],
}
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
driving_mask_28ch = kwargs.get("driving_mask_28ch", None)
if driving_mask_28ch is not None:
out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous())
ref_mask_28ch = kwargs.get("ref_mask_28ch", None)
if ref_mask_28ch is not None:
out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous())
ref_mask_flag = kwargs.get("ref_mask_flag", None)
if ref_mask_flag is not None:
out['ref_mask_flag'] = comfy.conds.CONDConstant(ref_mask_flag)
return out
def extra_conds_shapes(self, **kwargs):
out = super().extra_conds_shapes(**kwargs)
driving_mask_28ch = kwargs.get("driving_mask_28ch", None)
if driving_mask_28ch is not None:
s = driving_mask_28ch.shape
out['sam_latents'] = [s[0], 28, s[1], s[3], s[4]]
ref_mask_28ch = kwargs.get("ref_mask_28ch", None)
if ref_mask_28ch is not None:
s = ref_mask_28ch.shape
out['ref_mask_latents'] = [s[0], 28, s[1], s[3], s[4]]
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key in ("sam_latents", "pose_latents"):
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
def concat_cond(self, **kwargs):
# The 4 extra channels are the history_mask (1 at clean-anchor frames).
noise = kwargs.get("noise", None)
extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
if extra_channels != 4:
return super().concat_cond(**kwargs)
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
return torch.zeros_like(noise)[:, :4]
device = kwargs["device"]
if mask.shape[1] != 4:
mask = torch.mean(mask, dim=1, keepdim=True)
mask = 1.0 - mask
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
if mask.shape[1] == 1:
mask = mask.repeat(1, 4, 1, 1, 1)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return mask
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
# Hold anchor constant across all sigmas instead of base sigma*noise + (1-sigma)*latent_image.
return latent_image
class WAN22_WanDancer(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel)

View File

@ -598,56 +598,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config
if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
# 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.``
# submodules) at EVERY block — verified by inspecting the 7B
# state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means
# ``MMModule.shared_weights=False``). Native NaDiT computes
# per-block ``shared_weights = not (i < mm_layers)``, so to keep
# every block non-shared we set ``mm_layers = num_layers``.
# Without this, blocks at index >= mm_layers (default 10) try to
# load ``blocks.N.*.all.*`` keys that don't exist in the file,
# silently miss-load → all-black output.
dit_config["mm_layers"] = 36
dit_config["norm_eps"] = 1e-5
dit_config["qk_rope"] = True
dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "normal"
return dit_config
elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
# This checkpoint layout carries shared ``all.`` MMModule keys.
# Preserve the historical split: the initial blocks use separate
# vid/txt modules, later blocks use shared modules.
dit_config["mm_layers"] = 10
dit_config["norm_eps"] = 1e-5
dit_config["qk_rope"] = True
dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "swiglu"
return dit_config
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 2560
dit_config["heads"] = 20
dit_config["num_layers"] = 32
dit_config["norm_eps"] = 1.0e-05
dit_config["qk_rope"] = None
dit_config["mlp_type"] = "swiglu"
dit_config["vid_out_norm"] = True
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {}
dit_config["image_model"] = "wan2.1"
@ -680,6 +630,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "humo"
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "animate"
elif '{}patch_embedding_mask.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "scail2"
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "scail"
elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys:

View File

@ -44,13 +44,7 @@ def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None,
is_empty = torch.count_nonzero(latent_image) == 0
if is_empty:
if latent_format.latent_channels != latent_image.shape[1]:
preserves_collapsed_channels = (
getattr(latent_format, "preserve_empty_channel_multiples", False)
and latent_image.ndim == 4
and latent_image.shape[1] % latent_format.latent_channels == 0
)
if not preserves_collapsed_channels:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if downscale_ratio_spacial is not None:
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio

View File

@ -1,4 +1,3 @@
import inspect
import json
import torch
from enum import Enum
@ -17,7 +16,6 @@ import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.seedvr.vae
import comfy.ldm.triposplat.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae
@ -86,36 +84,6 @@ import comfy.latent_formats
import comfy.ldm.flux.redux
SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL = 160
def _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w):
output_t = max(1, (latent_t - 1) * 4 + 1)
return output_t * latent_h * 8 * latent_w * 8
def _seedvr2_vae_decode_memory_used(shape):
if len(shape) == 5:
candidates = []
if shape[1] == 16:
candidates.append((shape[2], shape[3], shape[4]))
if shape[-1] == 16:
candidates.append((shape[1], shape[2], shape[3]))
if len(candidates) == 0:
candidates.append((shape[2], shape[3], shape[4]))
output_pixels = max(_seedvr2_vae_decode_output_pixels(*candidate) for candidate in candidates)
elif len(shape) == 4:
latent_t = max(1, (shape[1] + 15) // 16)
latent_h, latent_w = shape[2], shape[3]
output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w)
else:
latent_t, latent_h, latent_w = 1, shape[-2], shape[-1]
output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w)
# SeedVR2 decode performs full-frame LAB histogram matching: fp32 channels
# plus int64 sort indices dominate peak memory, not the VAE weight dtype.
return output_pixels * SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
key_map = {}
if model is not None:
@ -499,10 +467,8 @@ class CLIP:
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd
if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
if metadata is None or metadata.get("keep_diffusers_format") != "true":
sd = diffusers_convert.convert_vae_state_dict(sd)
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
if model_management.is_amd():
VAE_KL_MEM_RATIO = 2.73
@ -574,20 +540,6 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
self.latent_channels = 16
self.latent_dim = 3
self.disable_offload = True
self.memory_used_decode = lambda shape, dtype: _seedvr2_vae_decode_memory_used(shape)
self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.process_input = lambda image: image * 2.0 - 1.0
self.crop_input = False
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
@ -715,7 +667,6 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
@ -1055,40 +1006,6 @@ class VAE:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def decode_tiled_seedvr2(self, samples, tile_x=32, tile_y=32, overlap=8, tile_t=16, overlap_t=4):
sf_s = getattr(self.first_stage_model, "spatial_downsample_factor", 8)
sf_t = getattr(self.first_stage_model, "temporal_downsample_factor", 4)
if tile_t is None:
tile_t = 16
if overlap_t is None:
overlap_t = 4
if tile_t > 0:
temporal_size = tile_t * sf_t
temporal_overlap = max(0, overlap_t) * sf_t
else:
temporal_size = 0
temporal_overlap = 0
args = {
"enable_tiling": True,
"tile_size": (tile_y * sf_s, tile_x * sf_s),
"tile_overlap": (overlap * sf_s, overlap * sf_s),
"temporal_size": temporal_size,
"temporal_overlap": temporal_overlap,
}
output = self.first_stage_model.decode(
samples.to(self.vae_dtype).to(self.device),
seedvr2_tiling=args,
)
return self.process_output(output.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
def _format_seedvr2_encoded_samples(self, samples):
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
if samples.ndim == 4:
samples = samples.unsqueeze(2)
samples = samples.contiguous()
samples = samples * 0.9152
return samples
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@ -1125,36 +1042,6 @@ class VAE:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def encode_tiled_seedvr2(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
if tile_y is None:
tile_y = 512
if tile_x is None:
tile_x = 512
if overlap is None:
overlap_y = 64
overlap_x = 64
else:
overlap_y = overlap
overlap_x = overlap
if tile_t is None:
tile_t = 9999
if overlap_t is None:
overlap_t = 0
overlap_y = min(overlap_y, max(0, tile_y - 8))
overlap_x = min(overlap_x, max(0, tile_x - 8))
self.first_stage_model.device = self.device
x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device)
output = comfy.ldm.seedvr.vae.tiled_vae(
x,
self.first_stage_model,
tile_size=(tile_y, tile_x),
tile_overlap=(overlap_y, overlap_x),
temporal_size=tile_t,
temporal_overlap=overlap_t,
encode=True,
)
return output.to(device=self.output_device, dtype=self.vae_output_dtype())
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
@ -1202,40 +1089,16 @@ class VAE:
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
# SeedVR2 latents arrive in 4D collapsed form ``(B, 16*T, H, W)``
# downstream of ``SeedVR2Conditioning`` (which performs the
# ``rearrange(b c t h w -> b (c t) h w)`` collapse). The
# generic ``decode_tiled_`` would treat the channel dim as
# spatial-only and crash on the collapsed (16, T) layout
# under ``tiled_scale``'s mask broadcast; route SeedVR2 4D
# latents to ``decode_tiled_seedvr2`` instead, whose wrapper
# dispatch handles both 4D and 5D inputs.
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
else:
pixel_samples = self.decode_tiled_(samples_in)
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
else:
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def decode_tiled(
self,
samples,
tile_x=None,
tile_y=None,
overlap=None,
tile_t=None,
overlap_t=None,
):
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -1249,20 +1112,7 @@ class VAE:
args["overlap"] = overlap
with model_management.cuda_device_context(self.device):
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper) and dims in (2, 3):
seedvr2_args = {}
if tile_x is not None:
seedvr2_args["tile_x"] = tile_x
if tile_y is not None:
seedvr2_args["tile_y"] = tile_y
if overlap is not None:
seedvr2_args["overlap"] = overlap
if tile_t is not None:
seedvr2_args["tile_t"] = tile_t
if overlap_t is not None:
seedvr2_args["overlap_t"] = overlap_t
output = self.decode_tiled_seedvr2(samples, **seedvr2_args)
elif dims == 1 or self.extra_1d_channel is not None:
if dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
@ -1304,8 +1154,6 @@ class VAE:
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
if isinstance(out, tuple):
out = out[0]
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
@ -1325,23 +1173,20 @@ class VAE:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
samples = self.encode_tiled_seedvr2(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap)
else:
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
else:
samples = self.encode_tiled_(pixel_samples)
return self._format_seedvr2_encoded_samples(samples)
return samples
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
if dims == 3 and pixel_samples.ndim < 5:
if dims == 3:
if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
else:
@ -1365,47 +1210,22 @@ class VAE:
elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
seedvr2_args = {}
if tile_x is not None:
seedvr2_args["tile_x"] = tile_x
else:
seedvr2_args["tile_x"] = 512
if tile_y is not None:
seedvr2_args["tile_y"] = tile_y
else:
seedvr2_args["tile_y"] = 512
if overlap is not None:
seedvr2_args["overlap"] = overlap
else:
seedvr2_args["overlap"] = 64
if tile_t is not None:
seedvr2_args["tile_t"] = tile_t
else:
seedvr2_args["tile_t"] = 9999
if overlap_t is not None:
seedvr2_args["overlap_t"] = overlap_t
else:
seedvr2_args["overlap_t"] = 0
samples = self.encode_tiled_seedvr2(pixel_samples, **seedvr2_args)
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
spatial_overlap = overlap if overlap is not None else 64
if overlap_t is None:
args["overlap"] = (1, spatial_overlap, spatial_overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
return self._format_seedvr2_encoded_samples(samples)
return samples
def get_sd(self):
return self.first_stage_model.state_dict()
@ -1932,17 +1752,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device):
set_dtype = model_config.set_inference_dtype
parameters = inspect.signature(set_dtype).parameters
supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values())
if supports_device:
set_dtype(dtype, manual_cast_dtype, device=device)
else:
set_dtype(dtype, manual_cast_dtype)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
@ -2050,7 +1859,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if model_config.clip_vision_prefix is not None:
if output_clipvision:
@ -2191,7 +2000,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
if custom_operations is not None:
model_config.custom_operations = custom_operations

View File

@ -1450,6 +1450,17 @@ class WAN21_SCAIL(WAN21_T2V):
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
return out
class WAN21_SCAIL2(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "scail2",
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_SCAIL2(self, image_to_video=False, device=device)
return out
class WAN22_WanDancer(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@ -1672,35 +1683,6 @@ class Chroma(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
class SeedVR2(supported_models_base.BASE):
unet_config = {
"image_model": "seedvr2"
}
latent_format = comfy.latent_formats.SeedVR2
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
sampling_settings = {
"shift": 1.0,
}
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
if (
dtype == torch.float16
and manual_cast_dtype is None
and comfy.model_management.should_use_bf16(device)
):
manual_cast_dtype = torch.bfloat16
super().set_inference_dtype(dtype, manual_cast_dtype, device=device)
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SeedVR2(self, device=device)
return out
def clip_target(self, state_dict={}):
return None
class ChromaRadiance(Chroma):
unet_config = {
"image_model": "chroma_radiance",
@ -2058,6 +2040,7 @@ class LongCatImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
class RT_DETR_v4(supported_models_base.BASE):
unet_config = {
"image_model": "RT_DETR_v4",
@ -2287,6 +2270,7 @@ models = [
WAN22_Animate,
WAN21_FlowRVS,
WAN21_SCAIL,
WAN21_SCAIL2,
WAN22_WanDancer,
Hunyuan3Dv2mini,
Hunyuan3Dv2,
@ -2295,7 +2279,6 @@ models = [
HiDream,
HiDreamO1,
Chroma,
SeedVR2,
ChromaRadiance,
ACEStep,
ACEStep15,

View File

@ -115,7 +115,7 @@ class BASE:
replace_prefix = {"": self.vae_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
def set_inference_dtype(self, dtype, manual_cast_dtype):
self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype

View File

@ -32,7 +32,9 @@ class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer):
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
if text.startswith('<|im_start|>'):
llama_text = text
elif llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)

321
comfy_extras/nodes_scail.py Normal file
View File

@ -0,0 +1,321 @@
"""SCAIL / SCAIL-2 nodes: the WanSCAILToVideo conditioning node and the SAM3
preprocessing that turns video tracks into the bundle the SCAIL-2 model consumes."""
from typing_extensions import override
import torch
import torch.nn.functional as F
import nodes
import node_helpers
import comfy.model_management
import comfy.utils
from comfy_api.latest import ComfyExtension, io
from comfy.ldm.sam3.tracker import unpack_masks
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
# Model was trained on these exact colors; deviating degrades multi-identity quality.
DEFAULT_PALETTE = [
(0.0, 0.0, 1.0), # Blue
(1.0, 0.0, 0.0), # Red
(0.0, 1.0, 0.0), # Green
(1.0, 0.0, 1.0), # Magenta
(0.0, 1.0, 1.0), # Cyan
(1.0, 1.0, 0.0), # Yellow
]
def _unpack(track_data):
packed = track_data["packed_masks"]
if packed is None or packed.shape[1] == 0:
return None
return unpack_masks(packed)
def _first_frame_cx_area(masks_bool):
first = masks_bool[0].float()
H, W = first.shape[-2], first.shape[-1]
n_pixels = H * W
grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W)
area = first.sum(dim=(-1, -2)).clamp_(min=1)
cx = (first * grid_x).sum(dim=(-1, -2)) / area
return (cx / W).tolist(), (area / n_pixels).tolist()
def _subset_track_data(track_data, obj_indices):
out = dict(track_data)
packed = track_data["packed_masks"]
if packed is None or not obj_indices:
out["packed_masks"] = None
if "scores" in out:
out["scores"] = []
return out
out["packed_masks"] = packed[:, obj_indices].contiguous()
scores = track_data.get("scores")
if scores is not None:
out["scores"] = [scores[i] for i in obj_indices if i < len(scores)]
return out
def _render_colored_masks(track_data, background="black"):
packed = track_data["packed_masks"]
H, W = track_data["orig_size"]
device = comfy.model_management.intermediate_device()
dtype = comfy.model_management.intermediate_dtype()
bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0)
if packed is None or packed.shape[1] == 0:
T = track_data.get("n_frames", 1) if packed is None else packed.shape[0]
out = torch.empty(T, H, W, 3, device=device, dtype=dtype)
out[..., 0], out[..., 1], out[..., 2] = bg_rgb[0], bg_rgb[1], bg_rgb[2]
return out
T, N_obj = packed.shape[0], packed.shape[1]
colors = torch.tensor(
[DEFAULT_PALETTE[i % len(DEFAULT_PALETTE)] for i in range(N_obj)],
device=device, dtype=dtype,
)
masks_full = unpack_masks(packed.to(device)).float()
Hm, Wm = masks_full.shape[-2], masks_full.shape[-1]
masks_full = F.interpolate(
masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest"
).view(T, N_obj, H, W) > 0.5
any_mask = masks_full.any(dim=1)
obj_idx_map = masks_full.to(torch.uint8).argmax(dim=1)
color_overlay = colors[obj_idx_map]
bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3)
return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay))
def _extract_mask_to_28ch(rgb_video):
"""Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent
(1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c)
threshold-extracted at 225/255, 8x spatial downsample, 4-frame temporal stacking."""
T, H, W, _ = rgb_video.shape
_ON_THRESH = 225.0 / 255.0
mask = rgb_video.movedim(-1, 1).float()
R = (mask[:, 0:1] > _ON_THRESH).float()
G = (mask[:, 1:2] > _ON_THRESH).float()
B = (mask[:, 2:3] > _ON_THRESH).float()
nR, nG, nB = 1 - R, 1 - G, 1 - B
binary_7ch = torch.cat([
R * G * B, # white
R * nG * nB, # red
nR * G * nB, # green
nR * nG * B, # blue
R * G * nB, # yellow
R * nG * B, # magenta
nR * G * B, # cyan
], dim=1)
H_lat, W_lat = H, W
for _ in range(3):
H_lat = (H_lat + 1) // 2
W_lat = (W_lat + 1) // 2
binary_7ch = torch.nn.functional.interpolate(binary_7ch, size=(H_lat, W_lat), mode='area')
T_latent = (T - 1) // 4 + 1
padded = torch.cat([binary_7ch[:1].repeat(4, 1, 1, 1), binary_7ch[1:]], dim=0)
out = padded.view(T_latent, 28, H_lat, W_lat)
return out.unsqueeze(0)
class WanSCAILToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanSCAILToVideo",
category="model/conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
io.Image.Input("pose_video_mask", optional=True, tooltip="SCAIL-2 only. Colored per-identity SAM3 mask video at the same resolution as pose_video."),
io.Boolean.Input("replacement_mode", default=False, optional=True, tooltip="SCAIL-2 only. False = Animation Mode (pose_video_mask should have black background). True = Replacement Mode (pose_video_mask should have white background)."),
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."),
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."),
io.Image.Input("reference_image", optional=True, tooltip="Reference image, for multiple references composite all on single image."),
io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask at the same resolution as reference_image."),
io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."),
io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."),
io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."),
io.Image.Input("previous_frames", optional=True, tooltip="SCAIL-2 only. Full decoded output of the previous chunk. Only the last previous_frame_count are used as the extension anchor."),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
io.Int.Output(display_name="video_frame_offset", tooltip="Adjusted offset + length. Wire into the next chunk."),
],
is_experimental=True,
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end,
video_frame_offset, previous_frame_count, replacement_mode=False, reference_image=None, clip_vision_output=None, pose_video=None,
pose_video_mask=None, reference_image_mask=None, previous_frames=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
noise_mask = None
ref_mask_flag = not replacement_mode
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_flag": ref_mask_flag})
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_flag": ref_mask_flag})
prev_trimmed = None
if previous_frames is not None and previous_frames.shape[0] > 0:
prev_trimmed = previous_frames[-previous_frame_count:]
video_frame_offset -= prev_trimmed.shape[0]
video_frame_offset = max(0, video_frame_offset)
ref_latent = None
if reference_image is not None:
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
# Replacement Mode: composite ref on black bg using reference_image_mask as alpha matte
if replacement_mode and reference_image_mask is not None:
rm = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype)
reference_image = reference_image * is_char
ref_latent = vae.encode(reference_image[:, :, :, :3])
if ref_latent is not None:
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
if pose_video is not None:
if pose_video.shape[0] <= video_frame_offset:
pose_video = None
else:
pose_video = pose_video[video_frame_offset:]
if pose_video_mask is not None:
if pose_video_mask.shape[0] <= video_frame_offset:
pose_video_mask = None
else:
pose_video_mask = pose_video_mask[video_frame_offset:]
# Truncate pose+mask jointly to the shorter of the two, capped at length.
ts = [v.shape[0] for v in (pose_video, pose_video_mask) if v is not None]
if ts:
T_kept = ((min(min(ts), length) - 1) // 4) * 4 + 1
if pose_video is not None:
pose_video = pose_video[:T_kept]
if pose_video_mask is not None:
pose_video_mask = pose_video_mask[:T_kept]
if pose_video is not None:
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
if pose_video_mask is not None:
mask_video_hw = comfy.utils.common_upscale(pose_video_mask[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
driving_mask_28ch = _extract_mask_to_28ch(mask_video_hw)
positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch})
negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch})
if reference_image_mask is not None:
ref_mask_hw = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw)
zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype)
ref_mask_28ch = torch.cat([ref_mask_1f, zeros], dim=1)
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch})
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch})
if prev_trimmed is not None:
pf = comfy.utils.common_upscale(prev_trimmed.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
prev_latent = vae.encode(pf[:, :, :, :3])
prev_latent_frames = min(prev_latent.shape[2], latent.shape[2])
latent[:, :, :prev_latent_frames] = prev_latent[:, :, :prev_latent_frames].to(latent.dtype)
noise_mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=latent.device, dtype=latent.dtype)
noise_mask[:, :, :prev_latent_frames] = 0.0
out_latent = {"samples": latent}
if noise_mask is not None:
out_latent["noise_mask"] = noise_mask
return io.NodeOutput(positive, negative, out_latent, video_frame_offset + length)
class SCAIL2ColoredMask(io.ComfyNode):
"""Render SAM3 tracks for the driving pose video and (optionally) the reference
image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by`
across both outputs guarantees identity K maps to the same color on both
sides, for multi-person workflow consistency.
reference_image_mask is always rendered black-bg (model convention)
pose_video_mask bg follows replacement_mode: black = Animation Mode, white = Replacement Mode
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SCAIL2ColoredMask",
display_name="Create SCAIL-2 Colored Mask",
category="conditioning/video_models/scail",
inputs=[
SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. Will be rendered into the pose_video_mask output."),
SAM3TrackData.Input("ref_track_data", optional=True,
tooltip="SAM3 track of the reference image."),
io.String.Input("object_indices", default="",
tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."),
io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right",
tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). left_to_right = leftmost object (by first-frame centroid) gets the first color; area = biggest object (by first-frame mask area) gets the first color; none = keep SAM3's order."),
io.Boolean.Input("replacement_mode", default=False,
tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). Set the matching replacement_mode on WanSCAILToVideo. reference_image_mask is always black-bg regardless."),
],
outputs=[
io.Image.Output("pose_video_mask"),
io.Image.Output("reference_image_mask"),
],
is_experimental=True,
)
@classmethod
def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None):
def _prep(td):
masks_bool = _unpack(td)
if sort_by != "none" and masks_bool is not None:
cx, area = _first_frame_cx_area(masks_bool)
if sort_by == "left_to_right":
order = sorted(range(len(cx)), key=lambda i: cx[i])
else: # "area"
order = sorted(range(len(area)), key=lambda i: -area[i])
td = _subset_track_data(td, order)
if object_indices.strip():
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
packed = td.get("packed_masks")
n_obj = packed.shape[1] if packed is not None else 0
indices = [i for i in indices if 0 <= i < n_obj]
td = _subset_track_data(td, indices)
return td
drv = _prep(driving_track_data)
mask_video = _render_colored_masks(drv, "white" if replacement_mode else "black")
if ref_track_data is not None:
ref = _prep(ref_track_data)
reference_image_mask = _render_colored_masks(ref, "black")
else:
H, W = drv["orig_size"]
reference_image_mask = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return io.NodeOutput(mask_video, reference_image_mask)
class SCAILExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
WanSCAILToVideo,
SCAIL2ColoredMask,
]
async def comfy_entrypoint() -> SCAILExtension:
return SCAILExtension()

File diff suppressed because it is too large Load Diff

View File

@ -1456,63 +1456,6 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
class WanSCAILToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanSCAILToVideo",
category="model/conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("reference_image", optional=True),
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."),
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
],
is_experimental=True,
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
ref_latent = None
if reference_image is not None:
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
ref_latent = vae.encode(reference_image[:, :, :, :3])
if ref_latent is not None:
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
if pose_video is not None:
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class WanExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -1533,7 +1476,6 @@ class WanExtension(ComfyExtension):
WanAnimateToVideo,
Wan22ImageToVideoLatent,
WanInfiniteTalkToVideo,
WanSCAILToVideo,
]
async def comfy_entrypoint() -> WanExtension:

15
main.py
View File

@ -26,6 +26,7 @@ import utils.extra_config
from utils.mime_types import init_mime_types
import faulthandler
import logging
import signal
import sys
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
@ -37,7 +38,19 @@ if __name__ == "__main__":
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
os.environ['DO_NOT_TRACK'] = '1'
faulthandler.enable(file=sys.stderr, all_threads=False)
faulthandler.enable(file=sys.stderr, all_threads=args.debug_hang)
if __name__ == "__main__" and args.debug_hang:
dumping_traceback = False
def dump_traceback_on_sigint(signum, frame):
global dumping_traceback
if dumping_traceback:
raise KeyboardInterrupt
dumping_traceback = True
faulthandler.dump_traceback(file=sys.stderr, all_threads=True)
raise KeyboardInterrupt
signal.signal(signal.SIGINT, dump_traceback_on_sigint)
import comfy_aimdo.control

View File

@ -47,18 +47,14 @@ import node_helpers
if args.enable_manager:
import comfyui_manager
def before_node_execution():
comfy.model_management.throw_exception_if_processing_interrupted()
def interrupt_processing(value=True):
comfy.model_management.interrupt_current_processing(value)
MAX_RESOLUTION=16384
class CLIPTextEncode(ComfyNodeABC):
@classmethod
def INPUT_TYPES(s) -> InputTypeDict:
@ -327,8 +323,8 @@ class VAEDecodeTiled:
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32, "advanced": True}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}),
"temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}),
"temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}),
"temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time.", "advanced": True}),
"temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "decode"
@ -338,32 +334,18 @@ class VAEDecodeTiled:
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
if tile_size < overlap * 4:
overlap = tile_size // 4
if temporal_size < temporal_overlap * 2:
temporal_overlap = temporal_overlap // 2
temporal_compression = vae.temporal_compression_decode()
if temporal_compression is not None:
if temporal_size <= 0:
temporal_size = 0
temporal_overlap = 0
else:
requested_temporal_overlap = temporal_overlap
if temporal_size < temporal_overlap * 2:
temporal_overlap = temporal_overlap // 2
temporal_size = max(2, temporal_size // temporal_compression)
temporal_overlap = min(temporal_size // 2, temporal_overlap // temporal_compression)
if requested_temporal_overlap > 0:
temporal_overlap = max(1, temporal_overlap)
temporal_size = max(2, temporal_size // temporal_compression)
temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression))
else:
temporal_size = None
temporal_overlap = None
compression = vae.spacial_compression_decode()
images = vae.decode_tiled(
samples["samples"],
tile_x=tile_size // compression,
tile_y=tile_size // compression,
overlap=overlap // compression,
tile_t=temporal_size,
overlap_t=temporal_overlap,
)
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap)
if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
return (images, )
@ -380,7 +362,7 @@ class VAEEncode:
def encode(self, vae, pixels):
t = vae.encode(pixels)
return ({"samples": t}, )
return ({"samples":t}, )
class VAEEncodeTiled:
@classmethod
@ -388,8 +370,8 @@ class VAEEncodeTiled:
return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64, "advanced": True}),
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}),
"temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}),
"temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}),
"temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time.", "advanced": True}),
"temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}),
}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "encode"
@ -397,9 +379,6 @@ class VAEEncodeTiled:
CATEGORY = "experimental"
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
if temporal_size <= 0:
temporal_size = 0
temporal_overlap = 0
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
return ({"samples": t}, )
@ -2439,7 +2418,6 @@ async def init_builtin_extra_nodes():
"nodes_camera_trajectory.py",
"nodes_edit_model.py",
"nodes_tcfg.py",
"nodes_seedvr.py",
"nodes_context_windows.py",
"nodes_qwen.py",
"nodes_chroma_radiance.py",
@ -2472,6 +2450,7 @@ async def init_builtin_extra_nodes():
"nodes_rtdetr.py",
"nodes_frame_interpolation.py",
"nodes_sam3.py",
"nodes_scail.py",
"nodes_void.py",
"nodes_wandancer.py",
"nodes_hidream_o1.py",

View File

@ -3,11 +3,6 @@ components:
Asset:
description: Represents a user-owned asset (image, video, or other generated output).
properties:
asset_hash:
deprecated: true
description: 'Deprecated: use hash instead. Blake3 hash of the asset content.'
pattern: ^blake3:[a-f0-9]{64}$
type: string
created_at:
description: Timestamp when the asset was created
format: date-time
@ -16,8 +11,12 @@ components:
description: Display name of the asset. Mirrors name for backwards compatibility.
nullable: true
type: string
file_path:
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
nullable: true
type: string
hash:
description: Blake3 hash of the asset content. Preferred over asset_hash.
description: Blake3 hash of the asset content.
pattern: ^blake3:[a-f0-9]{64}$
type: string
id:
@ -139,17 +138,16 @@ components:
AssetUpdated:
description: Response returned when an existing asset is successfully updated.
properties:
asset_hash:
deprecated: true
description: 'Deprecated: use hash instead. Blake3 hash of the asset content.'
pattern: ^blake3:[a-f0-9]{64}$
type: string
display_name:
description: Display name of the asset. Mirrors name for backwards compatibility.
nullable: true
type: string
file_path:
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
nullable: true
type: string
hash:
description: Blake3 hash of the asset content. Preferred over asset_hash.
description: Blake3 hash of the asset content.
pattern: ^blake3:[a-f0-9]{64}$
type: string
id:
@ -828,7 +826,11 @@ components:
type: string
type: object
PaginationInfo:
description: Offset/limit-based pagination metadata included in list responses.
description: |
Pagination metadata included in list responses. Supports both legacy
offset/limit pagination and cursor-based pagination. When cursor-based
pagination is used, `next_cursor` is the primary pagination token and
`offset`/`total` may be zero.
properties:
has_more:
description: Whether more items are available beyond this page
@ -837,12 +839,19 @@ components:
description: Items per page
minimum: 1
type: integer
next_cursor:
description: |
Opaque cursor for the next page. Pass this value as the `after`
query parameter on the next request. Empty or absent when there
are no more results.
type: string
offset:
description: Current offset (0-based)
deprecated: true
description: 'Current offset (0-based). Deprecated: use cursor-based pagination.'
minimum: 0
type: integer
total:
description: Total number of items matching filters
description: Total number of items matching filters (may be 0 when using cursor pagination)
minimum: 0
type: integer
required:
@ -1518,17 +1527,11 @@ paths:
schema:
default: true
type: boolean
- description: Filter assets by exact content hash. Preferred over asset_hash.
- description: Filter assets by exact content hash.
in: query
name: hash
schema:
type: string
- deprecated: true
description: 'Deprecated: use hash instead. Filter assets by exact content hash.'
in: query
name: asset_hash
schema:
type: string
- description: |
Opaque cursor for keyset pagination. Pass the `next_cursor` value
from the previous response to fetch the next page. When provided,
@ -1571,42 +1574,12 @@ paths:
- file
post:
description: |
Uploads a new asset to the system with associated metadata.
Supports two upload methods:
1. Direct file upload (multipart/form-data)
2. URL-based upload (application/json with source: "url")
Creates a new asset from a direct file upload (multipart/form-data) with associated metadata.
If an asset with the same hash already exists, returns the existing asset.
operationId: uploadAsset
operationId: createAsset
requestBody:
content:
application/json:
schema:
properties:
name:
description: Display name for the asset (used to determine file extension)
type: string
preview_id:
description: Optional preview asset ID
format: uuid
type: string
tags:
description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
items:
type: string
type: array
url:
description: HTTP/HTTPS URL to download the asset from
format: uri
type: string
user_metadata:
additionalProperties: true
description: Custom metadata to store with the asset
type: object
required:
- url
- name
type: object
multipart/form-data:
schema:
properties:
@ -1614,6 +1587,10 @@ paths:
description: The asset file to upload
format: binary
type: string
hash:
description: Content hash of the file.
pattern: ^(blake3|sha256):[a-f0-9]{64}$
type: string
id:
description: Optional asset ID for idempotent creation. If provided and asset exists, returns existing asset.
format: uuid
@ -1629,10 +1606,8 @@ paths:
format: uuid
type: string
tags:
description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
items:
type: string
type: array
description: JSON-encoded array of freeform tag strings, e.g. '["models","checkpoint"]'. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
type: string
user_metadata:
description: Custom JSON metadata as a string
type: string
@ -1641,36 +1616,32 @@ paths:
type: object
required: true
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/AssetCreated'
description: |
Asset already existed for this user (deduplicated by content hash); the
existing asset is returned with created_new=false.
"201":
content:
application/json:
schema:
$ref: '#/components/schemas/AssetCreated'
description: Asset created successfully
description: Asset created successfully (created_new=true)
"400":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Invalid request (bad file, invalid URL, invalid content type, etc.)
description: Invalid request (bad file, invalid content type, etc.)
"401":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Unauthorized
"403":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Source URL requires authentication or access denied
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Source URL not found
"413":
content:
application/json:
@ -1683,19 +1654,13 @@ paths:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Unsupported media type
"422":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Download failed due to network error or timeout
"500":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Internal server error
summary: Upload a new asset
summary: Create a new asset
tags:
- file
/api/assets/{id}:
@ -1730,7 +1695,7 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Asset cannot be deleted because it is referenced by another resource (e.g., workflow version)
description: 'Asset cannot be deleted because it is referenced by another resource, e.g. a workflow version (error code: ASSET_IN_USE)'
"500":
content:
application/json:
@ -1783,7 +1748,7 @@ paths:
description: |
Updates an asset's metadata. At least one field must be provided.
Only name, mime_type, preview_id, and user_metadata can be updated.
For tag management, use the dedicated PUT /api/assets/{id}/tags endpoint.
For tag management, use POST (add) and DELETE (remove) /api/assets/{id}/tags.
operationId: updateAsset
parameters:
- description: Asset ID
@ -1982,76 +1947,6 @@ paths:
summary: Add tags to asset
tags:
- file
put:
description: Adds and removes tags from an asset in a single operation
operationId: updateAssetTags
parameters:
- description: Asset ID
in: path
name: id
required: true
schema:
format: uuid
type: string
requestBody:
content:
application/json:
schema:
description: At least one of add or remove must contain items. Empty arrays are allowed when the other array has items.
minProperties: 1
properties:
add:
description: Tags to add to the asset. Can be empty if remove has items.
items:
type: string
type: array
remove:
description: Tags to remove from the asset. Can be empty if add has items.
items:
type: string
type: array
type: object
required: true
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/TagsModificationResponse'
description: Tags updated successfully
"400":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Invalid request
"401":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Unauthorized
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Asset not found
"422":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Reserved tag validation error
"500":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Internal server error
summary: Update asset tags
tags:
- file
/api/assets/from-hash:
post:
description: |
@ -2090,12 +1985,20 @@ paths:
type: object
required: true
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/AssetCreated'
description: |
Asset reference already existed for this user (deduplicated by content
hash); the existing asset is returned with created_new=false.
"201":
content:
application/json:
schema:
$ref: '#/components/schemas/AssetCreated'
description: Asset reference created successfully
description: Asset reference created successfully (created_new=true)
"400":
content:
application/json:
@ -2887,7 +2790,21 @@ paths:
- asc
- desc
type: string
- description: Pagination offset (0-based)
- description: |
Opaque cursor for keyset pagination. Pass the `next_cursor` value
from a previous response to fetch the next page.
Cursor pagination is supported only when `sort_by=create_time`
(default). If `sort_by=execution_time`, `after` is ignored and
offset/limit pagination is used.
Cursors are opaque base64url payloads — clients should treat them
as strings and not parse the contents.
example: eyJzIjoiY3JlYXRlX3RpbWUiLCJ2IjoiMTcxNjIwMDAwMDAwMDAwMCIsImlkIjoiYTFiMmMzZDQtZTVmNi03YTg5LWIwYzEtZDJlM2Y0YTViNmM3In0
in: query
name: after
schema:
type: string
- deprecated: true
description: 'Pagination offset (0-based). Deprecated: prefer cursor-based pagination via `after`.'
in: query
name: offset
schema:
@ -2909,6 +2826,12 @@ paths:
schema:
$ref: '#/components/schemas/JobsListResponse'
description: Success - Jobs retrieved
"400":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Bad request (e.g. malformed pagination cursor).
"401":
content:
application/json:

View File

@ -1253,6 +1253,15 @@ class PromptServer():
if verbose:
logging.info("Starting server\n")
if args.debug_hang:
logging.info(
f"{'-' * 80}\n"
"ComfyUI has been started in debug-hang mode. Run your workflow as normal up to\n"
"the point of the hang or freeze, then use ctrl-C in the cmd or controlling\n"
"terminal to dump the python backtraces for debugging. Please attach the extra\n"
"debug info to your bug report.\n"
f"{'-' * 80}"
)
for addr in addresses:
address = addr[0]
port = addr[1]

View File

@ -0,0 +1,86 @@
"""Tests for the image_dimensions service."""
from __future__ import annotations
from pathlib import Path
import pytest
from PIL import Image
from app.assets.services.image_dimensions import extract_image_dimensions
def _make_png(path: Path, size: tuple[int, int]) -> Path:
img = Image.new("RGB", size, color=(123, 45, 67))
img.save(path, format="PNG")
return path
def _make_jpeg(path: Path, size: tuple[int, int]) -> Path:
img = Image.new("RGB", size, color=(10, 20, 30))
img.save(path, format="JPEG", quality=80)
return path
class TestExtractImageDimensions:
def test_extracts_png_dimensions(self, tmp_path: Path):
f = _make_png(tmp_path / "rect.png", (320, 240))
result = extract_image_dimensions(str(f), mime_type="image/png")
assert result == {"kind": "image", "width": 320, "height": 240}
def test_extracts_jpeg_dimensions(self, tmp_path: Path):
f = _make_jpeg(tmp_path / "shot.jpg", (1920, 1080))
result = extract_image_dimensions(str(f), mime_type="image/jpeg")
assert result == {"kind": "image", "width": 1920, "height": 1080}
def test_works_when_mime_type_is_none(self, tmp_path: Path):
f = _make_png(tmp_path / "no_mime.png", (50, 100))
result = extract_image_dimensions(str(f), mime_type=None)
assert result == {"kind": "image", "width": 50, "height": 100}
def test_skips_non_image_mime_without_touching_file(self, tmp_path: Path):
# Path doesn't need to exist — non-image MIME short-circuits.
result = extract_image_dimensions(
str(tmp_path / "model.safetensors"),
mime_type="application/octet-stream",
)
assert result is None
@pytest.mark.parametrize(
"mime",
["application/json", "text/plain", "video/mp4", "audio/mpeg"],
)
def test_skips_all_non_image_mime_types(self, tmp_path: Path, mime: str):
f = tmp_path / "file.bin"
f.write_bytes(b"\x00\x01\x02")
assert extract_image_dimensions(str(f), mime_type=mime) is None
def test_returns_none_for_missing_file(self, tmp_path: Path):
result = extract_image_dimensions(
str(tmp_path / "does_not_exist.png"), mime_type="image/png"
)
assert result is None
def test_returns_none_for_corrupt_image(self, tmp_path: Path):
f = tmp_path / "corrupt.png"
f.write_bytes(b"not actually a png file")
result = extract_image_dimensions(str(f), mime_type="image/png")
assert result is None
def test_returns_none_for_empty_file(self, tmp_path: Path):
f = tmp_path / "empty.png"
f.write_bytes(b"")
result = extract_image_dimensions(str(f), mime_type="image/png")
assert result is None

View File

@ -4,10 +4,12 @@ from pathlib import Path
from unittest.mock import patch
import pytest
from PIL import Image
from sqlalchemy.orm import Session as SASession, Session
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
from app.assets.database.queries import get_reference_tags
from app.assets.helpers import get_utc_now
from app.assets.services.ingest import (
_ingest_file_from_path,
_register_existing_asset,
@ -15,6 +17,11 @@ from app.assets.services.ingest import (
)
def _make_png(path: Path, size: tuple[int, int]) -> Path:
Image.new("RGB", size, color=(80, 120, 200)).save(path, format="PNG")
return path
class TestIngestFileFromPath:
def test_creates_asset_and_reference(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "test_file.bin"
@ -279,4 +286,203 @@ class TestIngestExistingFileTagFK:
ref_tags = sess.query(AssetReferenceTag).all()
ref_tag_names = {rt.tag_name for rt in ref_tags}
assert "output" in ref_tag_names
assert "my-job" in ref_tag_names
class TestIngestImageDimensions:
"""system_metadata should carry {kind, width, height} for image assets."""
def test_image_asset_emits_dimensions(
self, mock_create_session, temp_dir: Path, session: Session
):
f = _make_png(temp_dir / "shot.png", (640, 480))
result = _ingest_file_from_path(
abs_path=str(f),
asset_hash="blake3:img1",
size_bytes=f.stat().st_size,
mtime_ns=1234567890000000000,
mime_type="image/png",
)
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
assert ref.system_metadata == {
"kind": "image",
"width": 640,
"height": 480,
}
def test_non_image_asset_leaves_system_metadata_empty(
self, mock_create_session, temp_dir: Path, session: Session
):
f = temp_dir / "model.safetensors"
f.write_bytes(b"not an image")
result = _ingest_file_from_path(
abs_path=str(f),
asset_hash="blake3:safetensors1",
size_bytes=f.stat().st_size,
mtime_ns=1234567890000000000,
mime_type="application/octet-stream",
)
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
assert ref.system_metadata in (None, {})
def test_preserves_existing_system_metadata_keys(
self, mock_create_session, temp_dir: Path, session: Session
):
f = _make_png(temp_dir / "annotated.png", (100, 200))
# First pass populates a sentinel system_metadata key (simulating prior
# enricher write).
result = _ingest_file_from_path(
abs_path=str(f),
asset_hash="blake3:img-merge",
size_bytes=f.stat().st_size,
mtime_ns=1234567890000000000,
mime_type="image/png",
)
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/x.png"}
session.commit()
# Second pass with the same path triggers the merge code path again.
_ingest_file_from_path(
abs_path=str(f),
asset_hash="blake3:img-merge",
size_bytes=f.stat().st_size,
mtime_ns=1234567890000000001,
mime_type="image/png",
)
session.refresh(ref)
assert ref.system_metadata["kind"] == "image"
assert ref.system_metadata["width"] == 100
assert ref.system_metadata["height"] == 200
assert ref.system_metadata["source_url"] == "https://example/x.png"
class TestRegisterExistingAssetBackfill:
"""The from-hash path back-fills dimensions from a sibling reference."""
def _add_reference(
self,
session: Session,
asset: Asset,
name: str,
system_metadata: dict | None = None,
) -> AssetReference:
now = get_utc_now()
ref = AssetReference(
asset_id=asset.id,
name=name,
owner_id="",
created_at=now,
updated_at=now,
last_access_time=now,
system_metadata=system_metadata or {},
)
session.add(ref)
session.flush()
return ref
def test_backfills_dimensions_from_sibling_image_reference(
self, mock_create_session, session: Session
):
asset = Asset(hash="blake3:shared", size_bytes=2048, mime_type="image/png")
session.add(asset)
session.flush()
self._add_reference(
session,
asset,
name="original.png",
system_metadata={"kind": "image", "width": 800, "height": 600},
)
session.commit()
result = _register_existing_asset(
asset_hash="blake3:shared",
name="from_hash.png",
owner_id="user-x",
)
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
assert ref.system_metadata.get("kind") == "image"
assert ref.system_metadata.get("width") == 800
assert ref.system_metadata.get("height") == 600
def test_no_backfill_when_sibling_has_no_image_metadata(
self, mock_create_session, session: Session
):
asset = Asset(hash="blake3:nodims", size_bytes=2048, mime_type="image/png")
session.add(asset)
session.flush()
self._add_reference(
session,
asset,
name="original.png",
system_metadata={"base_model": "flux"}, # no kind=image
)
session.commit()
result = _register_existing_asset(
asset_hash="blake3:nodims",
name="from_hash.png",
owner_id="user-x",
)
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
meta = ref.system_metadata or {}
assert "kind" not in meta
assert "width" not in meta
assert "height" not in meta
def test_no_backfill_when_no_sibling_exists(
self, mock_create_session, session: Session
):
asset = Asset(hash="blake3:lonely", size_bytes=1024, mime_type="image/png")
session.add(asset)
session.commit()
result = _register_existing_asset(
asset_hash="blake3:lonely",
name="solo.png",
owner_id="user-x",
)
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
assert ref.system_metadata in (None, {})
def test_backfill_preserves_caller_supplied_keys(
self, mock_create_session, session: Session
):
asset = Asset(hash="blake3:preserve", size_bytes=2048, mime_type="image/png")
session.add(asset)
session.flush()
self._add_reference(
session,
asset,
name="original.png",
system_metadata={"kind": "image", "width": 1024, "height": 768},
)
session.commit()
# Simulate a from-hash path where the new reference already carries
# some system_metadata (e.g. a download-provenance source_url written
# by an earlier step). The back-fill must merge dim keys without
# clobbering existing keys.
result = _register_existing_asset(
asset_hash="blake3:preserve",
name="from_hash.png",
owner_id="user-x",
)
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
# Seed a sentinel key and re-run back-fill via a second register call
# to exercise the merge path with pre-existing data.
ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/p"}
session.commit()
assert ref.system_metadata.get("source_url") == "https://example/p"
assert ref.system_metadata.get("kind") == "image"
assert ref.system_metadata.get("width") == 1024
assert ref.system_metadata.get("height") == 768

View File

@ -1,213 +0,0 @@
"""Consolidated SeedVR2 conditioning and refactor regression tests.
Merges the prior test_seedvr2_refactor_nodes.py and
test_seedvr_conditioning_hardening.py modules. Refactor tests use the
top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests
use _import_nodes_seedvr_isolated() for sys.modules isolation when
mocking comfy.model_management.
"""
import importlib
import sys
from unittest.mock import MagicMock
import pytest
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
_SENTINEL = object()
_TARGETS = (
("comfy.model_management", "comfy"),
("comfy_extras.nodes_seedvr", "comfy_extras"),
)
def _import_nodes_seedvr_isolated():
"""Import comfy_extras.nodes_seedvr with comfy.model_management mocked."""
priors = []
for mod_name, parent_name in _TARGETS:
prior_mod = sys.modules.get(mod_name, _SENTINEL)
parent = sys.modules.get(parent_name)
attr = mod_name.split(".")[-1]
prior_attr = (
getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL
)
priors.append((mod_name, parent_name, attr, prior_mod, prior_attr))
mock_mm = MagicMock()
for fn in (
"xformers_enabled", "xformers_enabled_vae",
"pytorch_attention_enabled", "pytorch_attention_enabled_vae",
"sage_attention_enabled", "flash_attention_enabled",
"is_intel_xpu",
):
getattr(mock_mm, fn).return_value = False
tv = torch.version.__version__.split(".")
mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1]))
mock_mm.WINDOWS = False
sys.modules["comfy.model_management"] = mock_mm
if sys.modules.get("comfy") is None:
import comfy as _comfy_pkg # noqa: F401
comfy_pkg = sys.modules.get("comfy")
if comfy_pkg is not None:
setattr(comfy_pkg, "model_management", mock_mm)
nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or (
importlib.import_module("comfy_extras.nodes_seedvr")
)
def _restore():
for mod_name, parent_name, attr, prior_mod, prior_attr in priors:
if prior_mod is _SENTINEL:
sys.modules.pop(mod_name, None)
else:
sys.modules[mod_name] = prior_mod
parent = sys.modules.get(parent_name)
if parent is None:
continue
if prior_attr is _SENTINEL:
if hasattr(parent, attr):
delattr(parent, attr)
else:
setattr(parent, attr, prior_attr)
return nodes_seedvr, _restore
class _Rope(nn.Module):
"""Minimal RoPE stub exposing a `freqs` parameter."""
def __init__(self):
super().__init__()
self.freqs = nn.Parameter(torch.zeros(4))
class _Block(nn.Module):
"""Minimal transformer block stub holding a `_Rope`."""
def __init__(self):
super().__init__()
self.rope = _Rope()
class _DiffusionModel(nn.Module):
"""Stub diffusion model with N blocks and pos/neg conditioning buffers."""
def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32):
super().__init__()
self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])
pos = torch.zeros if zero_conditioning else torch.ones
self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype))
self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype))
class _ModelInner:
"""Inner model wrapper exposing `.diffusion_model`."""
def __init__(self, diffusion_model):
self.diffusion_model = diffusion_model
class _ModelPatcher:
"""ModelPatcher stub exposing `.model._ModelInner`."""
def __init__(self, diffusion_model):
self.model = _ModelInner(diffusion_model)
def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
schema = nodes_seedvr.SeedVR2Conditioning.define_schema()
assert [input_item.id for input_item in schema.inputs] == [
"model",
"vae_conditioning",
]
assert schema.inputs[1].display_name == "latent"
assert [output.display_name for output in schema.outputs] == [
"model",
"positive",
"negative",
"latent",
]
finally:
restore()
def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
diffusion_model = _DiffusionModel()
patcher = _ModelPatcher(diffusion_model)
samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2)
vae_conditioning = {"samples": samples}
_, first_positive, first_negative, first_latent = (
nodes_seedvr.SeedVR2Conditioning.execute(
patcher,
vae_conditioning,
)
)
_, second_positive, second_negative, second_latent = (
nodes_seedvr.SeedVR2Conditioning.execute(
patcher,
vae_conditioning,
)
)
expected_latent = samples.reshape(1, 6, 2, 2)
channel_last = samples.movedim(1, -1).contiguous()
expected_condition = torch.cat(
[
channel_last,
torch.ones((*channel_last.shape[:-1], 1)),
],
dim=-1,
).movedim(-1, 1).reshape(1, 9, 2, 2)
assert torch.equal(first_latent["samples"], expected_latent)
assert torch.equal(second_latent["samples"], expected_latent)
assert torch.equal(
first_positive[0][1]["condition"],
expected_condition,
)
assert torch.equal(
second_positive[0][1]["condition"],
expected_condition,
)
assert torch.equal(
first_negative[0][1]["condition"],
expected_condition,
)
assert torch.equal(
second_negative[0][1]["condition"],
expected_condition,
)
finally:
restore()
def test_seedvr2_conditioning_fails_loud_on_zero_buffers():
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
try:
diffusion_model = _DiffusionModel(zero_conditioning=True)
patcher = _ModelPatcher(diffusion_model)
vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}
with pytest.raises(RuntimeError) as excinfo:
nodes_seedvr.SeedVR2Conditioning.execute(
patcher, vae_conditioning,
)
message = str(excinfo.value)
assert message.startswith(
nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX
), (
"Fail-loud message must use the standard "
"_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers "
f"can match it. Got: {message!r}"
)
assert "positive_conditioning" in message
assert "negative_conditioning" in message
finally:
restore()

View File

@ -1,55 +0,0 @@
import importlib
import inspect
import sys
from unittest.mock import MagicMock, patch
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
def test_seedvr_node_signature_matches_schema():
mock_mm = MagicMock()
mock_mm.xformers_enabled.return_value = False
mock_mm.xformers_enabled_vae.return_value = False
mock_mm.sage_attention_enabled.return_value = False
mock_mm.flash_attention_enabled.return_value = False
sentinel = object()
prior_cpu = cli_args.cpu
cli_args.cpu = True
prior_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel)
comfy_pkg = sys.modules.get("comfy")
prior_mm_attr = getattr(comfy_pkg, "model_management", sentinel) if comfy_pkg else sentinel
with patch.dict(sys.modules, {"comfy.model_management": mock_mm}):
if comfy_pkg is not None:
setattr(comfy_pkg, "model_management", mock_mm)
sys.modules.pop("comfy_extras.nodes_seedvr", None)
try:
nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler):
schema_ids = [i.id for i in node_cls.define_schema().inputs]
exec_params = [
p for p in inspect.signature(node_cls.execute).parameters.keys()
if p != "cls"
]
assert schema_ids == exec_params, (
f"{node_cls.__name__} schema/execute drift: "
f"schema_ids={schema_ids}, exec_params={exec_params}"
)
finally:
cli_args.cpu = prior_cpu
if prior_module is sentinel:
sys.modules.pop("comfy_extras.nodes_seedvr", None)
else:
sys.modules["comfy_extras.nodes_seedvr"] = prior_module
if comfy_pkg is not None:
if prior_mm_attr is sentinel:
if hasattr(comfy_pkg, "model_management"):
delattr(comfy_pkg, "model_management")
else:
setattr(comfy_pkg, "model_management", prior_mm_attr)

View File

@ -1,57 +0,0 @@
from unittest.mock import patch
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
from comfy_extras import nodes_seedvr # noqa: E402
def _schema_ids(items):
return [item.id for item in items]
def test_seedvr2_post_processing_schema():
schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()
assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"]
assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"]
assert schema.inputs[2].default == "lab"
assert schema.outputs[0].get_io_type() == "IMAGE"
def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch):
decoded = torch.full((1, 3, 4, 4), 0.25)
reference = torch.full((1, 3, 4, 4), 0.75)
def _lab(content, style):
raise torch.cuda.OutOfMemoryError("CUDA out of memory")
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000)
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None)
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
try:
nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
decoded, reference, torch.device("cpu"), "lab",
)
except RuntimeError as exc:
assert "color_correction_method=lab" in str(exc)
assert " method=lab" not in str(exc)
else:
raise AssertionError("expected RuntimeError for one-frame LAB OOM")
def test_seedvr2_post_processing_unknown_color_correction_method_raises():
decoded = torch.zeros(1, 2, 4, 4, 3)
original = torch.zeros(1, 2, 4, 4, 3)
try:
nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus")
except ValueError as exc:
assert "color_correction_method" in str(exc)
else:
raise AssertionError("expected ValueError for unknown color_correction_method")

View File

@ -73,24 +73,6 @@ def _make_flux_schnell_comfyui_sd():
return sd
def _make_seedvr2_7b_separate_mm_sd():
return {
"blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072),
}
def _make_seedvr2_7b_shared_mm_sd():
return {
"blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
}
def _make_seedvr2_3b_shared_mm_sd():
return {
"blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
}
class TestModelDetection:
"""Verify that first-match model detection selects the correct model
based on list ordering and unet_config specificity."""
@ -143,48 +125,6 @@ class TestModelDetection:
assert model_config is not None
assert type(model_config).__name__ == "FluxSchnell"
def test_seedvr2_7b_separate_mm_detection_config(self):
sd = _make_seedvr2_7b_separate_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 3072
assert unet_config["heads"] == 24
assert unet_config["num_layers"] == 36
assert unet_config["mm_layers"] == 36
assert unet_config["mlp_type"] == "normal"
assert unet_config["qk_rope"] is True
assert unet_config["rope_type"] == "rope3d"
assert unet_config["rope_dim"] == 64
def test_seedvr2_7b_shared_mm_detection_config(self):
sd = _make_seedvr2_7b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 3072
assert unet_config["heads"] == 24
assert unet_config["num_layers"] == 36
assert unet_config["mm_layers"] == 10
assert unet_config["mlp_type"] == "swiglu"
assert unet_config["qk_rope"] is True
assert unet_config["rope_type"] == "rope3d"
assert unet_config["rope_dim"] == 64
def test_seedvr2_3b_shared_mm_detection_config(self):
sd = _make_seedvr2_3b_shared_mm_sd()
unet_config = detect_unet_config(sd, "")
assert unet_config is not None
assert unet_config["image_model"] == "seedvr2"
assert unet_config["vid_dim"] == 2560
assert unet_config["heads"] == 20
assert unet_config["num_layers"] == 32
assert unet_config["mlp_type"] == "swiglu"
assert unet_config["qk_rope"] is None
def test_unet_config_and_required_keys_combination_is_unique(self):
"""Each model in the registry must have a unique combination of
``unet_config`` and ``required_keys``. If two models share the same

View File

@ -1,90 +0,0 @@
"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must
honor the actual tensor/tuple return contract of ``encode()`` and
``decode_()`` and must NOT dereference diffusers-style ``.latent_dist``
or ``.sample`` attributes on those returns.
The pre-fix body raised ``AttributeError: 'Tensor' object has no
attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and
``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'``
for ``mode == "decode"`` (the class only defines ``decode_`` with a
trailing underscore). The post-fix body unwraps the optional one-element
tuple shape that ``return_dict=False`` produces and returns the tensor
directly.
Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses
the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and
overrides ``encode``/``decode_`` with known tensors so the contract can
be probed without loading any real VAE weights.
"""
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402
_LATENT_SHAPE = (1, 16, 2, 2, 2)
_DECODED_SHAPE = (1, 3, 5, 16, 16)
_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16)
_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2)
class _StubVAE(VideoAutoencoderKL):
def __init__(self):
nn.Module.__init__(self)
self._encode_out = torch.zeros(*_LATENT_SHAPE)
self._decode_out = torch.zeros(*_DECODED_SHAPE)
def encode(self, x, return_dict=True):
return self._encode_out
def decode_(self, z, return_dict=True):
return self._decode_out
def test_forward_encode_returns_tensor():
vae = _StubVAE()
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
result = vae.forward(x, mode="encode")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_LATENT_SHAPE)
def test_forward_decode_returns_tensor():
vae = _StubVAE()
z = torch.zeros(*_INPUT_DECODE_SHAPE)
result = vae.forward(z, mode="decode")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_DECODED_SHAPE)
class _TupleReturningStubVAE(VideoAutoencoderKL):
"""Stub variant whose ``encode``/``decode_`` return the
``(tensor,)`` one-element tuple shape ``return_dict=False`` produces
in the parent class. Exercises the unwrap branch of
``VideoAutoencoderKL.forward``.
"""
def __init__(self):
nn.Module.__init__(self)
self._encode_tensor = torch.zeros(*_LATENT_SHAPE)
self._decode_tensor = torch.zeros(*_DECODED_SHAPE)
def encode(self, x, return_dict=True):
return (self._encode_tensor,)
def decode_(self, z, return_dict=True):
return (self._decode_tensor,)
def test_forward_all_unwraps_one_tuple_at_each_step():
vae = _TupleReturningStubVAE()
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
result = vae.forward(x, mode="all")
assert type(result) is torch.Tensor
assert result.shape == torch.Size(_DECODED_SHAPE)

View File

@ -1,47 +0,0 @@
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.sd
import comfy.supported_models
import comfy.ldm.seedvr.model as seedvr_model
def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch):
bf16_device = object()
fp16_device = object()
monkeypatch.setattr(
comfy.supported_models.comfy.model_management,
"should_use_bf16",
lambda device=None: device is bf16_device,
)
bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device)
assert bf16_config.manual_cast_dtype is torch.bfloat16
fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device)
assert fp16_config.manual_cast_dtype is None
def test_seedvr2_text_conditioning_accepts_cfg1_single_branch():
context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2)
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0])
torch.testing.assert_close(txt, context.squeeze(0))
torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device))
def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer():
estimate = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160))
old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2
assert estimate == 101 * 960 * 1280 * 160
assert estimate > 15 * 1024 ** 3
assert estimate > old_estimate * 100

View File

@ -1,341 +0,0 @@
"""Consolidated SeedVR2 internals regression tests.
Sources (all merged verbatim, helper names disambiguated where colliding):
* RoPE rewrite NaMMRotaryEmbedding3d.forward must match the legacy
apply_rotary_emb wrapper oracle at fp32.
* GroupNorm limit gate causal_norm_wrapper at vae.py:509 must compare
memory_occupy against get_norm_limit(), not float('inf').
* SeedVR2 variable-length attention split-loop contract.
Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
comfy.ldm.modules.attention transitively pull in comfy.model_management,
which probes torch.cuda.current_device() at import time unless args.cpu is
set first.
"""
from __future__ import annotations
from unittest.mock import patch
import pytest
import torch
from comfy.cli_args import args
if not torch.cuda.is_available():
args.cpu = True
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
import comfy.ldm.modules.attention as attention # noqa: E402
import comfy.ops as comfy_ops # noqa: E402
from comfy.ldm.seedvr.model import ( # noqa: E402
Cache,
NaMMRotaryEmbedding3d,
)
from comfy.ldm.seedvr.vae import ( # noqa: E402
causal_norm_wrapper,
set_norm_limit,
)
from comfy.ldm.modules.attention import var_attention_optimized_split # noqa: E402
# ---------------------------------------------------------------------------
# RoPE rewrite tests (test_seedvr_rope_rewrite.py)
# ---------------------------------------------------------------------------
# Test rig dimensions. dim=192 → per-axis rope dim = 64 (even, lucidrains
# requirement). vid_shape=(2,4,4) → L_vid = 32. txt_shape=(8,) → L_txt = 8.
_DIM = 192
_HEADS = 4
_VID_T, _VID_H, _VID_W = 2, 4, 4
_TXT_L = 8
_L_VID = _VID_T * _VID_H * _VID_W
_SEED = 0
def _make_inputs(dtype=torch.float32, device="cpu"):
"""Construct the 6 forward inputs + cache. Deterministic via local
Generator so global RNG state is not mutated.
"""
g = torch.Generator(device=device).manual_seed(_SEED)
vid_q = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
vid_k = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
txt_q = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
txt_k = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device)
txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device)
cache = Cache(disable=True)
return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache
def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape):
"""Reproduce the pre-rewrite ``get_freqs`` body verbatim against
``self.get_axial_freqs`` (parent ``RotaryEmbeddingBase`` method,
unchanged by the rewrite).
"""
max_temporal = 0
max_height = 0
max_width = 0
max_txt_len = 0
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
max_temporal = max(max_temporal, l + f)
max_height = max(max_height, h)
max_width = max(max_width, w)
max_txt_len = max(max_txt_len, l)
with torch.amp.autocast(device_type="cuda", enabled=False):
vid_freqs_full = rope.get_axial_freqs(
min(max_temporal + 16, 1024),
min(max_height + 4, 128),
min(max_width + 4, 128),
).float()
txt_freqs_full = rope.get_axial_freqs(min(max_txt_len + 16, 1024))
vid_freq_list, txt_freq_list = [], []
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
vid_freq = vid_freqs_full[l : l + f, :h, :w].reshape(-1, vid_freqs_full.size(-1))
txt_freq = txt_freqs_full[:l].repeat(1, 3).reshape(-1, vid_freqs_full.size(-1))
vid_freq_list.append(vid_freq)
txt_freq_list.append(txt_freq)
return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0)
def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape,
txt_q, txt_k, txt_shape):
"""Compute expected forward output via the unchanged
``apply_rotary_emb`` wrapper fed with legacy-shape freqs. This is the
oracle. The wrapper itself is out of scope for the rewrite (Shape B).
"""
vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape)
vid_freqs = vid_freqs.to(vid_q.device)
txt_freqs = txt_freqs.to(txt_q.device)
from einops import rearrange
vid_q = rearrange(vid_q, "L h d -> h L d")
vid_k = rearrange(vid_k, "L h d -> h L d")
vid_q_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype)
vid_k_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype)
vid_q_out = rearrange(vid_q_out, "h L d -> L h d")
vid_k_out = rearrange(vid_k_out, "h L d -> L h d")
txt_q = rearrange(txt_q, "L h d -> h L d")
txt_k = rearrange(txt_k, "L h d -> h L d")
txt_q_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype)
txt_k_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype)
txt_q_out = rearrange(txt_q_out, "h L d -> L h d")
txt_k_out = rearrange(txt_k_out, "h L d -> L h d")
return vid_q_out, vid_k_out, txt_q_out, txt_k_out
def test_namm_forward_output_tensor_equal_against_legacy_oracle():
rope = NaMMRotaryEmbedding3d(dim=_DIM)
vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs()
expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward(
rope,
vid_q.clone(), vid_k.clone(), vid_shape,
txt_q.clone(), txt_k.clone(), txt_shape,
)
actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward(
vid_q.clone(), vid_k.clone(), vid_shape,
txt_q.clone(), txt_k.clone(), txt_shape, cache,
)
torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0,
msg="vid_q output diverges from wrapper oracle")
torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0,
msg="vid_k output diverges from wrapper oracle")
torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0,
msg="txt_q output diverges from wrapper oracle")
torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0,
msg="txt_k output diverges from wrapper oracle")
# ---------------------------------------------------------------------------
# GroupNorm limit tests (test_seedvr_groupnorm_limit.py)
# ---------------------------------------------------------------------------
_NUM_CHANNELS = 8
_NUM_GROUPS = 4
_TENSOR_SHAPE = (1, 8, 2, 4, 4)
_GROUPNORM_SUBCLASSES = [
pytest.param(comfy_ops.disable_weight_init.GroupNorm, id="disable_weight_init"),
pytest.param(comfy_ops.manual_cast.GroupNorm, id="manual_cast"),
]
@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES)
def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
real_group_norm = vae_mod.F.group_norm
set_norm_limit(1e-9)
try:
gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS)
gn.eval()
forward_hook_calls = []
def _hook(module, inputs, output):
forward_hook_calls.append(tuple(inputs[0].shape))
spy_calls = []
def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs):
spy_calls.append({"num_groups": int(num_groups_arg)})
return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs)
handle = gn.register_forward_hook(_hook)
try:
with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy):
out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE))
finally:
handle.remove()
full_calls = len(forward_hook_calls)
chunked_calls = sum(1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS)
assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE
assert full_calls == 0, (
f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}"
)
assert chunked_calls > 0, (
f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}"
)
finally:
set_norm_limit(None)
# ---------------------------------------------------------------------------
# SeedVR2 var_attention split-loop tests
# ---------------------------------------------------------------------------
def test_var_attention_registry_contains_always_available_entries():
assert (
attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_optimized_split"]
is attention.var_attention_optimized_split
)
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
dim = 8
heads = 2
head_dim = 4
attn = seedvr_model.NaSwinAttention(
vid_dim=dim,
txt_dim=dim,
heads=heads,
head_dim=head_dim,
qk_bias=False,
qk_norm=seedvr_model.CustomRMSNorm,
qk_norm_eps=1e-6,
rope_type=None,
rope_dim=head_dim,
shared_weights=False,
window=(2, 1, 1),
window_method="720pwin_by_size_bysize",
version=True,
device="cpu",
dtype=torch.float32,
operations=comfy_ops.disable_weight_init,
)
generator = torch.Generator(device="cpu").manual_seed(11)
vid = torch.randn(8, dim, generator=generator)
txt = torch.randn(3, dim, generator=generator)
vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long)
txt_shape = torch.tensor([[3]], dtype=torch.long)
calls = []
def fake_optimized_var_attention(**kwargs):
calls.append(kwargs)
return kwargs["q"]
monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention)
vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True))
assert tuple(vid_out.shape) == (8, dim)
assert tuple(txt_out.shape) == (3, dim)
assert len(calls) == 1
call = calls[0]
assert tuple(call["q"].shape) == (14, heads, head_dim)
assert tuple(call["k"].shape) == (14, heads, head_dim)
assert tuple(call["v"].shape) == (14, heads, head_dim)
assert call["heads"] == heads
assert call["skip_reshape"] is True
assert call["skip_output_reshape"] is True
torch.testing.assert_close(
call["cu_seqlens_q"],
torch.tensor([0, 7, 14], dtype=torch.int32),
rtol=0,
atol=0,
)
torch.testing.assert_close(
call["cu_seqlens_k"],
torch.tensor([0, 7, 14], dtype=torch.int32),
rtol=0,
atol=0,
)
def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch):
heads = 2
head_dim = 3
q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim)
k = q + 100
v = q + 200
cu = torch.tensor([0, 2, 5], dtype=torch.int32)
calls = []
def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs):
calls.append(
{
"q_shape": tuple(q_arg.shape),
"k_shape": tuple(k_arg.shape),
"v_shape": tuple(v_arg.shape),
"heads": heads_arg,
"kwargs": kwargs,
}
)
return q_arg + v_arg
monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention)
out = var_attention_optimized_split(
q,
k,
v,
heads,
cu,
cu,
skip_reshape=True,
skip_output_reshape=True,
)
assert tuple(out.shape) == (5, heads, head_dim)
assert len(calls) == 2
assert calls[0]["q_shape"] == (1, heads, 2, head_dim)
assert calls[1]["q_shape"] == (1, heads, 3, head_dim)
assert all(call["heads"] == heads for call in calls)
assert all(call["kwargs"]["skip_reshape"] is True for call in calls)
assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls)
torch.testing.assert_close(out, q + v, rtol=0, atol=0)
def test_var_attention_optimized_split_rejects_bad_offsets():
q = torch.randn(5, 2, 3)
cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32)
cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32)
with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"):
var_attention_optimized_split(
q,
q,
q,
2,
cu_bad,
cu_ok,
skip_reshape=True,
skip_output_reshape=True,
)

View File

@ -1,308 +0,0 @@
"""Consolidated SeedVR2 model/graph/forward regression tests.
Merged from:
- seedvr_model_test.py
- test_seedvr_7b_final_block_text_path.py
- test_seedvr_forward_no_device_cast.py
- test_seedvr_latent_format.py
- test_seedvr2_vae_graph_boundaries.py
"""
from __future__ import annotations
from unittest.mock import MagicMock
import torch
from torch import nn
from comfy.cli_args import args
if not torch.cuda.is_available():
args.cpu = True
import comfy # noqa: E402
import comfy.latent_formats # noqa: E402
import comfy.ldm.seedvr.model # noqa: E402
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.model_management # noqa: E402
import comfy.sample # noqa: E402
import comfy.sd as sd_mod # noqa: E402
import nodes as nodes_mod # noqa: E402
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
# ---------------------------------------------------------------------------
# Helpers from seedvr_model_test.py
# ---------------------------------------------------------------------------
def _make_standin(positive_conditioning):
class _StandIn(torch.nn.Module):
def __init__(self):
super().__init__()
self.register_buffer(
"positive_conditioning", positive_conditioning
)
_resolve_text_conditioning = NaDiT._resolve_text_conditioning
return _StandIn()
# ---------------------------------------------------------------------------
# Helpers from test_seedvr_7b_final_block_text_path.py
# ---------------------------------------------------------------------------
class _StubModule(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]:
flags = []
class _Block(_StubModule):
def __init__(self, *args, **kwargs):
flags.append(kwargs["is_last_layer"])
super().__init__()
monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule)
monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule)
monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule)
monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block)
seedvr_model.NaDiT(
norm_eps=1e-5,
qk_rope=None,
num_layers=4,
mlp_type="normal",
vid_dim=vid_dim,
txt_in_dim=txt_in_dim,
heads=24,
mm_layers=3,
)
return flags
# ---------------------------------------------------------------------------
# Helpers from test_seedvr_latent_format.py
# ---------------------------------------------------------------------------
class _Model:
def __init__(self, latent_format):
self._latent_format = latent_format
def get_model_object(self, name):
assert name == "latent_format"
return self._latent_format
# ---------------------------------------------------------------------------
# Helpers from test_seedvr2_vae_graph_boundaries.py
# ---------------------------------------------------------------------------
class _Patcher:
def get_free_memory(self, device):
return 1024 * 1024 * 1024
class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
def __init__(self, encoded):
nn.Module.__init__(self)
self.encoded = encoded
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.seen = []
def encode(self, x):
self.seen.append(tuple(x.shape))
return self.encoded.to(device=x.device, dtype=x.dtype)
class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
def __init__(self):
nn.Module.__init__(self)
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.calls = []
def decode(self, z, seedvr2_tiling=None):
self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
if z.ndim == 4:
b, tc, h, w = z.shape
t = tc // 16
else:
b, _, t, h, w = z.shape
return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
def _make_vae(wrapper):
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = wrapper
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.latent_channels = 16
vae.latent_dim = 3
vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8)
vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
vae.output_channels = 3
vae.disable_offload = True
vae.extra_1d_channel = None
vae.crop_input = False
vae.not_video = False
vae.patcher = _Patcher()
vae.process_input = lambda image: image
vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0)
vae.vae_output_dtype = lambda: torch.float32
vae.memory_used_encode = lambda shape, dtype: 1
vae.memory_used_decode = lambda shape, dtype: 1
vae.throw_exception_if_invalid = lambda: None
vae.vae_encode_crop_pixels = lambda pixels: pixels
vae.spacial_compression_decode = lambda: 8
vae.temporal_compression_decode = lambda: 4
return vae
# ---------------------------------------------------------------------------
# Tests from seedvr_model_test.py
# ---------------------------------------------------------------------------
def test_missing_context_falls_back_to_positive_buffer():
"""AC: ``context is None`` falls back to the registered
``positive_conditioning`` buffer and runs to completion no
silent zero substitution, no raised exception.
"""
pos_buffer = torch.full((58, 5120), 7.0)
standin = _make_standin(pos_buffer)
txt, txt_shape = standin._resolve_text_conditioning(None)
assert txt.shape == (58, 5120)
assert (txt == 7.0).all(), (
"fallback path must use the positive_conditioning buffer "
"verbatim, not a zero tensor"
)
assert txt_shape.shape == (1, 1)
assert txt_shape[0, 0].item() == 58
# ---------------------------------------------------------------------------
# Tests from test_seedvr_7b_final_block_text_path.py
# ---------------------------------------------------------------------------
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [
False,
False,
False,
False,
]
def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
rope = seedvr_model.get_na_rope("rope3d", dim=64)
generator = torch.Generator(device="cpu").manual_seed(0)
q = torch.randn(4, 2, 128, generator=generator)
k = torch.randn(4, 2, 128, generator=generator)
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)
expected_q = seedvr_model._apply_seedvr2_rotary_emb(
freqs,
q.permute(1, 0, 2).float(),
).to(q.dtype).permute(1, 0, 2)
expected_k = seedvr_model._apply_seedvr2_rotary_emb(
freqs,
k.permute(1, 0, 2).float(),
).to(k.dtype).permute(1, 0, 2)
actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True))
torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0)
torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0)
# ---------------------------------------------------------------------------
# Tests from test_seedvr_latent_format.py
# ---------------------------------------------------------------------------
def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion():
latent_format = comfy.latent_formats.SeedVR2()
latent_image = torch.zeros(1, 1, 4, 5)
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
assert latent_format.latent_channels == 16
assert latent_format.latent_dimensions == 2
assert fixed.shape == (1, 16, 4, 5)
# ---------------------------------------------------------------------------
# Tests from test_seedvr2_vae_graph_boundaries.py
# ---------------------------------------------------------------------------
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
encoded = torch.full((1, 16, 2, 4, 5), 2.0)
vae = _make_vae(_EncodeWrapper(encoded))
pixels = torch.zeros(1, 5, 32, 40, 3)
node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
node_latent = node_output["samples"]
assert set(node_output) == {"samples"}
assert tuple(node_latent.shape) == (1, 16, 2, 4, 5)
assert node_latent.dtype == torch.float32
assert node_latent.stride()[-1] == 1
assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152))
tiled = torch.full((1, 16, 2, 4, 5), 3.0)
monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
tiled_output = nodes_mod.VAEEncodeTiled().encode(
vae,
pixels,
tile_size=512,
overlap=64,
temporal_size=16,
temporal_overlap=4,
)[0]
tiled_latent = tiled_output["samples"]
assert set(tiled_output) == {"samples"}
assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5)
assert tiled_latent.dtype == torch.float32
assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152))
def test_vaedecode_tiled_visible_inputs_are_seedvr2_decode_tiling_authority(monkeypatch):
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
vae = _make_vae(_DecodeWrapper())
nodes_mod.VAEDecodeTiled().decode(
vae,
{"samples": torch.zeros(1, 16, 2, 4, 5)},
tile_size=512,
overlap=64,
temporal_size=16,
temporal_overlap=4,
)
assert vae.first_stage_model.calls == [
{
"shape": (1, 16, 2, 4, 5),
"seedvr2_tiling": {
"enable_tiling": True,
"tile_size": (512, 512),
"tile_overlap": (64, 64),
"temporal_size": 16,
"temporal_overlap": 4,
},
}
]

View File

@ -1,91 +0,0 @@
from unittest.mock import patch
import pytest
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
from comfy_extras import nodes_seedvr # noqa: E402
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
vae_mod.VideoAutoencoderKLWrapper
)
nn.Module.__init__(wrapper)
return wrapper
def _fingerprint_decode_(self, z, return_dict=True):
b = int(z.shape[0])
t = int(z.shape[2])
h = int(z.shape[3])
w = int(z.shape[4])
out = torch.empty(b, 3, t, h * 8, w * 8)
for batch_idx in range(b):
out[batch_idx].fill_(float(batch_idx + 1))
return out
def _decode_with_patches(wrapper, z):
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_):
return wrapper.decode(z)
def test_decode_b2_t3_multi_frame_batch_unchanged():
wrapper = _make_wrapper()
out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2))
assert tuple(out.shape) == (2, 3, 3, 16, 16)
class _Wrapper(vae_mod.VideoAutoencoderKLWrapper):
def __init__(self):
nn.Module.__init__(self)
self.calls = []
def parameters(self):
return iter([torch.nn.Parameter(torch.zeros(()))])
def _decode_stub(self, latent):
self.calls.append(tuple(latent.shape))
return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8)
def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state():
wrapper = _Wrapper()
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5))
assert tuple(out.shape) == (1, 3, 2, 32, 40)
assert wrapper.calls == [(1, 16, 2, 4, 5)]
def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents():
wrapper = _Wrapper()
with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"):
wrapper.decode(torch.zeros(1, 16, 4))
def _t_padded(t_in: int) -> int:
if t_in == 1:
return 1
if t_in <= 4:
return 5
if (t_in - 1) % 4 == 0:
return t_in
return t_in + (4 - ((t_in - 1) % 4))
@pytest.mark.parametrize("t_in", [1, 5, 9])
def test_t_padded_matches_cut_videos(t_in):
dummy = torch.zeros(1, t_in, 1, 1, 1)
assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in)

View File

@ -1,347 +0,0 @@
from contextlib import ExitStack
from unittest.mock import MagicMock, patch
import torch
import torch.nn as nn
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
import comfy.sd as sd_mod # noqa: E402
from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402
# ---------------------------------------------------------------------------
# From test_seedvr_vae_tiled_decode_latent_min_size_override.py
# ---------------------------------------------------------------------------
def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae
class StubVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_latent_min_size = 2
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self.use_slicing = True
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.decode_min_sizes = []
self.memory_states = []
def decode_(self, t_chunk):
self.decode_min_sizes.append(self.slicing_latent_min_size)
return VideoAutoencoderKL.slicing_decode(self, t_chunk)
def _decode(self, z, memory_state=MemoryState.DISABLED):
self.memory_states.append(memory_state)
b, c, d, h, w = z.shape
return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype)
vae = StubVAEModel()
z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32)
tiled_vae(
z,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=False,
)
assert vae.decode_min_sizes == [5]
assert vae.memory_states == [MemoryState.DISABLED]
assert vae.slicing_latent_min_size == 2
# ---------------------------------------------------------------------------
# From test_seedvr_vae_tiled_encode_runt_slice_override.py
# ---------------------------------------------------------------------------
def test_zero_temporal_size_preserves_min_size_when_encode_raises():
from comfy.ldm.seedvr.vae import tiled_vae
class RaisingVAEModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.slicing_sample_min_size = 4
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
def encode(self, t_chunk):
raise RuntimeError("simulated encode failure")
vae = RaisingVAEModel()
x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)
raised = False
try:
tiled_vae(
x,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=0,
temporal_overlap=0,
encode=True,
)
except RuntimeError as exc:
if "simulated encode failure" not in str(exc):
raise
raised = True
assert raised
assert vae.slicing_sample_min_size == 4
# ---------------------------------------------------------------------------
# From test_seedvr_vae_tiled_temporal_slicing.py
# ---------------------------------------------------------------------------
class _SlicingDecodeVAE(nn.Module):
def __init__(self, slicing_latent_min_size):
super().__init__()
self.slicing_latent_min_size = slicing_latent_min_size
self.spatial_downsample_factor = 8
self.temporal_downsample_factor = 4
self.device = torch.device("cpu")
self.use_slicing = True
self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32))
self.decode_min_sizes = []
self.memory_states = []
def decode_(self, z):
self.decode_min_sizes.append(self.slicing_latent_min_size)
return vae_mod.VideoAutoencoderKL.slicing_decode(self, z)
def _decode(self, z, memory_state=MemoryState.DISABLED):
self.memory_states.append(memory_state)
x = z[:, :1].repeat(
1,
3,
1,
self.spatial_downsample_factor,
self.spatial_downsample_factor,
)
return x
def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
vae = _SlicingDecodeVAE(slicing_latent_min_size=2)
z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8)
tiled_vae(
z,
vae,
tile_size=(64, 64),
tile_overlap=(0, 0),
temporal_size=12,
temporal_overlap=4,
encode=False,
)
assert vae.decode_min_sizes == [2]
assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE]
assert vae.slicing_latent_min_size == 2
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
vae_mod.VideoAutoencoderKLWrapper
)
nn.Module.__init__(wrapper)
seedvr2_tiling = {
"enable_tiling": True,
"tile_size": (64, 64),
"tile_overlap": (0, 0),
"temporal_size": 8,
"temporal_overlap": 7,
}
captured = {}
def _fake_tiled_vae(latent, model, **kwargs):
captured.update(kwargs)
return torch.zeros(1, 3, 1, 16, 16)
with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae):
wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling)
assert captured["temporal_overlap"] == 7
# ---------------------------------------------------------------------------
# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py
# ---------------------------------------------------------------------------
def _force_oom(*a, **k):
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
def _make_vae(first_stage_model, latent_channels, latent_dim):
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = first_stage_model
vae.patcher = MagicMock()
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
vae.device = vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.disable_offload = True
vae.extra_1d_channel = None
vae.upscale_ratio = vae.downscale_ratio = 8
vae.upscale_index_formula = vae.downscale_index_formula = None
vae.output_channels = 3
vae.latent_channels = latent_channels
vae.latent_dim = latent_dim
vae.vae_output_dtype = lambda: torch.float32
vae.spacial_compression_decode = lambda: 8
vae.process_input = lambda x: x
vae.process_output = lambda x: x
vae.throw_exception_if_invalid = lambda: None
vae.memory_used_decode = lambda *a, **k: 1
return vae
def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode):
mm = sd_mod.model_management
with ExitStack() as stack:
stack.enter_context(patch.object(mm, "raise_non_oom", lambda e: None))
stack.enter_context(patch.object(mm, "load_models_gpu", lambda *a, **k: None))
stack.enter_context(patch.object(mm, "soft_empty_cache", lambda: None))
stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call))
stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_", generic_call))
if patch_wrapper_decode:
stack.enter_context(patch.object(
seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode",
side_effect=_force_oom))
vae.decode(samples)
def test_4d_seedvr2_latent_routes_to_decode_tiled_seedvr2():
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
seedvr_vae_mod.VideoAutoencoderKLWrapper)
vae = _make_vae(wrapper, latent_channels=16, latent_dim=3)
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
_dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True)
assert seedvr2_call.call_count == 1
assert generic_call.call_count == 0
def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled():
first_stage = MagicMock()
first_stage.decode = MagicMock(side_effect=_force_oom)
vae = _make_vae(first_stage, latent_channels=4, latent_dim=2)
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
_dispatch(vae, torch.zeros(1, 4, 8, 8), seedvr2_call, generic_call, False)
assert generic_call.call_count == 1
assert seedvr2_call.call_count == 0
# ---------------------------------------------------------------------------
# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py
# ---------------------------------------------------------------------------
def _populate_common_vae_attrs_fallback(vae):
vae.patcher = MagicMock()
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
vae.device = torch.device("cpu")
vae.output_device = torch.device("cpu")
vae.vae_dtype = torch.float32
vae.disable_offload = True
vae.extra_1d_channel = None
vae.upscale_ratio = 8
vae.upscale_index_formula = None
vae.output_channels = 3
vae.latent_channels = 16
vae.latent_dim = 3
vae.downscale_ratio = 8
vae.downscale_index_formula = None
vae.not_video = False
vae.crop_input = False
vae.pad_channel_value = None
vae.vae_output_dtype = lambda: torch.float32
vae.spacial_compression_encode = lambda: 8
vae.process_input = lambda x: x
vae.process_output = lambda x: x
vae.throw_exception_if_invalid = lambda: None
vae.memory_used_encode = lambda *a, **k: 1
def _make_seedvr2_vae_fallback():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
seedvr_vae_mod.VideoAutoencoderKLWrapper
)
vae.first_stage_model = wrapper
_populate_common_vae_attrs_fallback(vae)
return vae
def _make_non_seedvr2_vae_fallback():
vae = sd_mod.VAE.__new__(sd_mod.VAE)
vae.first_stage_model = MagicMock()
_populate_common_vae_attrs_fallback(vae)
return vae
def _force_regular_encode_oom(*args, **kwargs):
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
def test_seedvr2_3d_routes_to_encode_tiled_seedvr2_on_oom():
vae = _make_seedvr2_vae_fallback()
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
with patch.object(sd_mod.model_management, "raise_non_oom",
lambda e: None), \
patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.model_management, "soft_empty_cache",
lambda: None), \
patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode",
side_effect=_force_regular_encode_oom), \
patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call,
create=True), \
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
vae.encode(pixel_samples)
assert seedvr2_call.call_count == 1, (
f"Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D "
f"input under OOM fallback; got {seedvr2_call.call_count} calls."
)
assert generic_call.call_count == 0, (
f"encode_tiled_3d must NOT be called for a SeedVR2 input; got "
f"{generic_call.call_count} calls."
)
def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete():
vae = _make_non_seedvr2_vae_fallback()
vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8)
vae.upscale_ratio = (lambda a: a * 4, 8, 8)
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
with patch.object(sd_mod.model_management, "load_models_gpu",
lambda *a, **k: None), \
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
vae.encode_tiled(pixel_samples)
assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64)

View File

@ -1,126 +0,0 @@
"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``."""
from unittest.mock import patch
import pytest
import torch
from comfy.cli_args import args as cli_args
if not torch.cuda.is_available():
cli_args.cpu = True
import comfy.sample # noqa: E402
import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402
from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler # noqa: E402
_LAT_C = 16
_COND_C = 17
def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8):
"""Build minimal SeedVR2-shaped sampling inputs."""
samples_5d = torch.arange(
B * _LAT_C * T * H * W, dtype=torch.float32
).reshape(B, _LAT_C, T, H, W)
samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous()
cond_5d = torch.arange(
B * _COND_C * T * H * W, dtype=torch.float32
).reshape(B, _COND_C, T, H, W) + 10000.0
cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous()
text_pos = torch.zeros(1, 4, 32)
text_neg = torch.zeros(1, 4, 32)
positive = [[text_pos, {"condition": cond.clone()}]]
negative = [[text_neg, {"condition": cond.clone()}]]
latent_image = {"samples": samples}
return latent_image, positive, negative, samples_5d, cond_5d
def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None):
return latent_image
def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None):
"""Return a tensor whose values encode ``(seed, position)``."""
base = torch.arange(
latent_image.numel(), dtype=torch.float32
).reshape(latent_image.shape)
return base + float(seed) * 1e6
def test_progressive_sampler_schema_exposes_manual_default_auto_chunking():
schema = SeedVR2ProgressiveSampler.define_schema()
inputs = {item.id: item for item in schema.inputs}
assert inputs["chunking_mode"].options == ["manual", "auto"]
assert inputs["chunking_mode"].default == "manual"
def test_auto_chunking_walks_two_three_four_chunk_ladder():
"""Auto mode must walk 2-, 3-, then 4-chunk geometries on OOM."""
latent, pos, neg, _, _ = _make_inputs(T=17)
calls = []
def _oom_until_four_chunks(model, noise, steps, cfg, sampler_name,
scheduler, positive, negative,
latent_image, denoise=1.0,
noise_mask=None, seed=None):
calls.append(tuple(latent_image.shape))
if latent_image.shape[1] > _LAT_C * 5:
raise torch.cuda.OutOfMemoryError("chunk too large")
return latent_image.clone()
with patch.object(comfy.sample, "sample",
side_effect=_oom_until_four_chunks), \
patch.object(comfy.sample, "fix_empty_latent_channels",
side_effect=_identity_fix_empty), \
patch.object(comfy.sample, "prepare_noise",
side_effect=_fingerprinted_prepare_noise), \
patch.object(nodes_seedvr_mod.comfy.model_management,
"soft_empty_cache") as soft_empty:
out = SeedVR2ProgressiveSampler.execute(
model=None, seed=0, steps=2, cfg=1.0,
sampler_name="euler", scheduler="simple",
positive=pos, negative=neg, latent=latent,
denoise=1.0, frames_per_chunk=65, temporal_overlap=0,
chunking_mode="auto",
)
assert calls[:4] == [
(1, _LAT_C * 17, 8, 8),
(1, _LAT_C * 9, 8, 8),
(1, _LAT_C * 6, 8, 8),
(1, _LAT_C * 5, 8, 8),
]
assert torch.equal(out.result[0]["samples"], latent["samples"])
assert soft_empty.call_count == 3
@pytest.mark.parametrize("bad_chunk", [0, -1, 2])
def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk):
"""``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation."""
latent, pos, neg, _, _ = _make_inputs(T=5)
sampler_called = {"n": 0}
def _should_not_be_called(*args, **kwargs):
sampler_called["n"] += 1
return torch.zeros(1)
with patch.object(comfy.sample, "sample",
side_effect=_should_not_be_called), \
patch.object(comfy.sample, "fix_empty_latent_channels",
side_effect=_identity_fix_empty), \
patch.object(comfy.sample, "prepare_noise",
side_effect=_fingerprinted_prepare_noise):
with pytest.raises(ValueError) as excinfo:
SeedVR2ProgressiveSampler.execute(
model=None, seed=0, steps=2, cfg=1.0,
sampler_name="euler", scheduler="simple",
positive=pos, negative=neg, latent=latent,
denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0,
)
assert str(bad_chunk) in str(excinfo.value)
assert sampler_called["n"] == 0