mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-11 00:37:53 +08:00
Merge branch 'master' into alexis/reorder_inputs
This commit is contained in:
commit
d977079a7d
@ -33,6 +33,7 @@ from app.assets.services.file_utils import (
|
||||
verify_file_unchanged,
|
||||
)
|
||||
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
|
||||
from app.assets.services.image_dimensions import extract_image_dimensions
|
||||
from app.assets.services.metadata_extract import extract_file_metadata
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
@ -506,6 +507,10 @@ def enrich_asset(
|
||||
|
||||
if extract_metadata and metadata:
|
||||
system_metadata = metadata.to_user_metadata()
|
||||
if mime_type and mime_type.startswith("image/"):
|
||||
dims = extract_image_dimensions(file_path, mime_type=mime_type)
|
||||
if dims:
|
||||
system_metadata.update(dims)
|
||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||
|
||||
if full_hash:
|
||||
|
||||
63
app/assets/services/image_dimensions.py
Normal file
63
app/assets/services/image_dimensions.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Image dimension extraction for asset ingest.
|
||||
|
||||
Reads only the image header via Pillow to capture width/height cheaply,
|
||||
without a full pixel decode. Returns a metadata dict suitable for merging
|
||||
into ``AssetReference.system_metadata``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_image_dimensions(
|
||||
file_path: str, mime_type: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Extract image dimensions for the file at ``file_path``.
|
||||
|
||||
Args:
|
||||
file_path: Absolute path to a file on disk.
|
||||
mime_type: Optional MIME type hint. When provided and not prefixed
|
||||
with ``image/``, extraction is skipped without touching the file.
|
||||
|
||||
Returns:
|
||||
``{"kind": "image", "width": W, "height": H}`` when the file is a
|
||||
recognizable image with positive dimensions, otherwise ``None``.
|
||||
|
||||
The dict shape is intended to be merged into ``system_metadata`` so the
|
||||
asset response surfaces ``metadata.kind`` plus dimension fields for image
|
||||
assets. Forward-compatible: future media kinds (e.g. ``"video"`` with
|
||||
duration/fps) can extend this shape without schema changes.
|
||||
"""
|
||||
if mime_type is not None and not mime_type.startswith("image/"):
|
||||
return None
|
||||
|
||||
try:
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
"Pillow not available; skipping image dimension extraction for %s",
|
||||
file_path,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
with Image.open(file_path) as img:
|
||||
width, height = img.size
|
||||
except (OSError, UnidentifiedImageError, ValueError) as exc:
|
||||
logger.debug(
|
||||
"Failed to read image dimensions from %s: %s", file_path, exc
|
||||
)
|
||||
return None
|
||||
|
||||
if (
|
||||
not isinstance(width, int)
|
||||
or not isinstance(height, int)
|
||||
or width <= 0
|
||||
or height <= 0
|
||||
):
|
||||
return None
|
||||
|
||||
return {"kind": "image", "width": width, "height": height}
|
||||
@ -17,9 +17,11 @@ from app.assets.database.queries import (
|
||||
get_reference_by_file_path,
|
||||
get_reference_tags,
|
||||
get_or_create_reference,
|
||||
list_references_by_asset_id,
|
||||
reference_exists,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_system_metadata,
|
||||
set_reference_tags,
|
||||
update_asset_hash_and_mime,
|
||||
upsert_asset,
|
||||
@ -29,6 +31,7 @@ from app.assets.database.queries import (
|
||||
from app.assets.helpers import get_utc_now, normalize_tags
|
||||
from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.image_dimensions import extract_image_dimensions
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_name_and_tags_from_asset_path,
|
||||
@ -118,6 +121,14 @@ def _ingest_file_from_path(
|
||||
user_metadata=user_metadata,
|
||||
)
|
||||
|
||||
_maybe_store_image_dimensions(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
file_path=locator,
|
||||
mime_type=mime_type,
|
||||
current_system_metadata=ref.system_metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
@ -288,6 +299,13 @@ def _register_existing_asset(
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
_backfill_image_dimensions_from_siblings(
|
||||
session,
|
||||
asset_id=asset.id,
|
||||
new_reference_id=ref.id,
|
||||
current_system_metadata=ref.system_metadata,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_reference_tags(
|
||||
session,
|
||||
@ -334,6 +352,87 @@ def _update_metadata_with_filename(
|
||||
)
|
||||
|
||||
|
||||
_IMAGE_DIMENSION_KEYS = ("kind", "width", "height")
|
||||
|
||||
|
||||
def _maybe_store_image_dimensions(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
file_path: str,
|
||||
mime_type: str | None,
|
||||
current_system_metadata: dict | None,
|
||||
) -> None:
|
||||
"""Populate ``kind``/``width``/``height`` on system_metadata for image refs.
|
||||
|
||||
Non-image MIME types are a no-op. Pre-existing keys (e.g. enricher-written
|
||||
safetensors metadata, download provenance) are preserved by merge.
|
||||
"""
|
||||
if not mime_type or not mime_type.startswith("image/"):
|
||||
return
|
||||
|
||||
dims = extract_image_dimensions(file_path, mime_type=mime_type)
|
||||
if not dims:
|
||||
return
|
||||
|
||||
current = current_system_metadata or {}
|
||||
merged = dict(current)
|
||||
merged.update(dims)
|
||||
if merged != current:
|
||||
set_reference_system_metadata(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
system_metadata=merged,
|
||||
)
|
||||
|
||||
|
||||
def _backfill_image_dimensions_from_siblings(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
new_reference_id: str,
|
||||
current_system_metadata: dict | None,
|
||||
) -> None:
|
||||
"""Copy image dimension keys from any sibling reference of the same asset.
|
||||
|
||||
The from-hash path doesn't read the file bytes, so dimensions can't be
|
||||
extracted there directly. When another reference of the same asset already
|
||||
carries image dimensions, copy them onto the new reference so consumers
|
||||
see consistent metadata regardless of how the asset was registered.
|
||||
|
||||
Best-effort: missing siblings, non-image siblings, or absent dimension
|
||||
keys leave the target reference unchanged.
|
||||
"""
|
||||
current = current_system_metadata or {}
|
||||
if current.get("kind") == "image" and "width" in current and "height" in current:
|
||||
return
|
||||
|
||||
for sibling in list_references_by_asset_id(session, asset_id):
|
||||
if sibling.id == new_reference_id:
|
||||
continue
|
||||
meta = sibling.system_metadata or {}
|
||||
if meta.get("kind") != "image":
|
||||
continue
|
||||
width = meta.get("width")
|
||||
height = meta.get("height")
|
||||
if (
|
||||
type(width) is not int
|
||||
or type(height) is not int
|
||||
or width <= 0
|
||||
or height <= 0
|
||||
):
|
||||
continue
|
||||
merged = dict(current)
|
||||
merged["kind"] = "image"
|
||||
merged["width"] = width
|
||||
merged["height"] = height
|
||||
if merged != current:
|
||||
set_reference_system_metadata(
|
||||
session,
|
||||
reference_id=new_reference_id,
|
||||
system_metadata=merged,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def _sanitize_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
return n if n else fallback
|
||||
|
||||
@ -166,6 +166,8 @@ class PerformanceFeature(enum.Enum):
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
parser.add_argument("--debug-hang", action="store_true", help="Enable stack trace dumps on Ctrl-C for debugging hangs.")
|
||||
|
||||
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
|
||||
@ -4,7 +4,6 @@ class LatentFormat:
|
||||
scale_factor = 1.0
|
||||
latent_channels = 4
|
||||
latent_dimensions = 2
|
||||
preserve_empty_channel_multiples = False
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
latent_rgb_factors_reshape = None
|
||||
@ -780,10 +779,6 @@ class ACEAudio(LatentFormat):
|
||||
latent_channels = 8
|
||||
latent_dimensions = 2
|
||||
|
||||
class SeedVR2(LatentFormat):
|
||||
latent_channels = 16
|
||||
preserve_empty_channel_multiples = True
|
||||
|
||||
class ACEAudio15(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
@ -735,86 +735,7 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
)
|
||||
return out
|
||||
|
||||
def _var_attention_qkv(q, k, v, heads, skip_reshape):
|
||||
if skip_reshape:
|
||||
return q, k, v, q.shape[-1]
|
||||
total_tokens, embed_dim = q.shape
|
||||
head_dim = embed_dim // heads
|
||||
return (
|
||||
q.view(total_tokens, heads, head_dim),
|
||||
k.view(k.shape[0], heads, head_dim),
|
||||
v.view(v.shape[0], heads, head_dim),
|
||||
head_dim,
|
||||
)
|
||||
|
||||
|
||||
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
|
||||
if skip_output_reshape:
|
||||
return out
|
||||
return out.reshape(-1, heads * head_dim)
|
||||
|
||||
|
||||
def _use_blackwell_attention():
|
||||
device = model_management.get_torch_device()
|
||||
if device.type != "cuda":
|
||||
return False
|
||||
major, minor = torch.cuda.get_device_capability(device)
|
||||
return (major, minor) >= (12, 0)
|
||||
|
||||
|
||||
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
|
||||
if cu_seqlens.dtype not in (torch.int32, torch.int64):
|
||||
raise ValueError(f"{name} must use an integer dtype")
|
||||
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
|
||||
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
|
||||
if cu_seqlens[0].item() != 0:
|
||||
raise ValueError(f"{name} must start at 0")
|
||||
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
|
||||
raise ValueError(f"{name} must be strictly increasing")
|
||||
if cu_seqlens[-1].item() != token_count:
|
||||
raise ValueError(f"{name} does not match token count")
|
||||
|
||||
|
||||
def _split_indices(cu_seqlens):
|
||||
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
|
||||
|
||||
|
||||
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||
|
||||
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
|
||||
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
|
||||
if cu_seqlens_k[-1].item() != v.shape[0]:
|
||||
raise ValueError("cu_seqlens_k does not match v token count")
|
||||
|
||||
q_split_indices = _split_indices(cu_seqlens_q)
|
||||
k_split_indices = _split_indices(cu_seqlens_k)
|
||||
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
|
||||
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
|
||||
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
|
||||
if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
|
||||
raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")
|
||||
|
||||
out = []
|
||||
for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
|
||||
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
|
||||
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
|
||||
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
|
||||
out_dtype = q_i.dtype
|
||||
if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16):
|
||||
q_i = q_i.to(torch.bfloat16)
|
||||
k_i = k_i.to(torch.bfloat16)
|
||||
v_i = v_i.to(torch.bfloat16)
|
||||
out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
|
||||
if out_i.dtype != out_dtype:
|
||||
out_i = out_i.to(out_dtype)
|
||||
out.append(out_i.squeeze(0).permute(1, 0, 2))
|
||||
|
||||
out = torch.cat(out, dim=0)
|
||||
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||
|
||||
|
||||
optimized_var_attention = var_attention_optimized_split
|
||||
optimized_attention = attention_basic
|
||||
|
||||
if model_management.sage_attention_enabled():
|
||||
@ -837,8 +758,6 @@ else:
|
||||
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
optimized_attention = attention_sub_quad
|
||||
|
||||
logging.info("Using optimized_attention split-loop for variable-length attention")
|
||||
|
||||
optimized_attention_masked = optimized_attention
|
||||
|
||||
|
||||
@ -854,7 +773,6 @@ if model_management.xformers_enabled():
|
||||
register_attention_function("pytorch", attention_pytorch)
|
||||
register_attention_function("sub_quad", attention_sub_quad)
|
||||
register_attention_function("split", attention_split)
|
||||
register_attention_function("var_attention_optimized_split", var_attention_optimized_split)
|
||||
|
||||
|
||||
def optimized_attention_for_device(device, mask=False, small_input=False):
|
||||
@ -1291,3 +1209,5 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
x = self.proj_out(x)
|
||||
out = x + x_in
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@ -13,7 +13,6 @@ if model_management.xformers_enabled_vae():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
|
||||
def torch_cat_if_needed(xl, dim):
|
||||
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
|
||||
if len(xl) > 1:
|
||||
@ -23,8 +22,7 @@ def torch_cat_if_needed(xl, dim):
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
@ -35,13 +33,11 @@ def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, down
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - downscale_freq_shift)
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
||||
return emb
|
||||
|
||||
@ -51,6 +51,18 @@ class FeedForward(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Addin this back because Nunchaku custom nodes rely on it, see comment here:
|
||||
# https://github.com/Comfy-Org/ComfyUI/pull/14178#issuecomment-4640475161
|
||||
# TODO: Eventually remove this once we natively support SVDQuants
|
||||
def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape)
|
||||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
@ -1,340 +0,0 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.ldm.seedvr.model import safe_pad_operation
|
||||
from comfy.ldm.seedvr.vae import safe_interpolate_operation
|
||||
from comfy.ldm.seedvr.constants import (
|
||||
CIELAB_DELTA,
|
||||
CIELAB_KAPPA,
|
||||
D65_WHITE_X,
|
||||
D65_WHITE_Z,
|
||||
WAVELET_DECOMP_LEVELS,
|
||||
)
|
||||
|
||||
|
||||
def wavelet_blur(image: Tensor, radius):
|
||||
max_safe_radius = max(1, min(image.shape[-2:]) // 8)
|
||||
if radius > max_safe_radius:
|
||||
radius = max_safe_radius
|
||||
|
||||
num_channels = image.shape[1]
|
||||
|
||||
kernel_vals = [
|
||||
[0.0625, 0.125, 0.0625],
|
||||
[0.125, 0.25, 0.125],
|
||||
[0.0625, 0.125, 0.0625],
|
||||
]
|
||||
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
||||
kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
|
||||
|
||||
image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
|
||||
output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
|
||||
|
||||
return output
|
||||
|
||||
def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS):
|
||||
high_freq = torch.zeros_like(image)
|
||||
|
||||
for i in range(levels):
|
||||
radius = 2 ** i
|
||||
low_freq = wavelet_blur(image, radius)
|
||||
high_freq.add_(image).sub_(low_freq)
|
||||
image = low_freq
|
||||
|
||||
return high_freq, low_freq
|
||||
|
||||
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
|
||||
if content_feat.shape != style_feat.shape:
|
||||
# Resize style to match content spatial dimensions
|
||||
if len(content_feat.shape) >= 3:
|
||||
# safe_interpolate_operation handles FP16 conversion automatically
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat,
|
||||
size=content_feat.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
# Decompose both features into frequency components
|
||||
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
||||
del content_low_freq # Free memory immediately
|
||||
|
||||
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
||||
del style_high_freq # Free memory immediately
|
||||
|
||||
if content_high_freq.shape != style_low_freq.shape:
|
||||
style_low_freq = safe_interpolate_operation(
|
||||
style_low_freq,
|
||||
size=content_high_freq.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
content_high_freq.add_(style_low_freq)
|
||||
|
||||
return content_high_freq.clamp_(-1.0, 1.0)
|
||||
|
||||
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
|
||||
original_shape = source.shape
|
||||
|
||||
# Flatten
|
||||
source_flat = source.flatten()
|
||||
reference_flat = reference.flatten()
|
||||
|
||||
# Sort both arrays
|
||||
source_sorted, source_indices = torch.sort(source_flat)
|
||||
reference_sorted, _ = torch.sort(reference_flat)
|
||||
del reference_flat
|
||||
|
||||
# Quantile mapping
|
||||
n_source = len(source_sorted)
|
||||
n_reference = len(reference_sorted)
|
||||
|
||||
if n_source == n_reference:
|
||||
matched_sorted = reference_sorted
|
||||
else:
|
||||
# Interpolate reference to match source quantiles
|
||||
source_quantiles = torch.linspace(0, 1, n_source, device=device)
|
||||
ref_indices = (source_quantiles * (n_reference - 1)).long()
|
||||
ref_indices.clamp_(0, n_reference - 1)
|
||||
matched_sorted = reference_sorted[ref_indices]
|
||||
del source_quantiles, ref_indices, reference_sorted
|
||||
|
||||
del source_sorted, source_flat
|
||||
|
||||
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
|
||||
inverse_indices = torch.argsort(source_indices)
|
||||
del source_indices
|
||||
matched_flat = matched_sorted[inverse_indices]
|
||||
del matched_sorted, inverse_indices
|
||||
|
||||
return matched_flat.reshape(original_shape)
|
||||
|
||||
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
"""Convert batch of CIELAB images to RGB color space."""
|
||||
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
|
||||
|
||||
# LAB to XYZ
|
||||
fy = (L + 16.0) / 116.0
|
||||
fx = a.div(500.0).add_(fy)
|
||||
fz = fy - b / 200.0
|
||||
del L, a, b
|
||||
|
||||
# XYZ transformation
|
||||
x = torch.where(
|
||||
fx > epsilon,
|
||||
torch.pow(fx, 3.0),
|
||||
fx.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
y = torch.where(
|
||||
fy > epsilon,
|
||||
torch.pow(fy, 3.0),
|
||||
fy.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
z = torch.where(
|
||||
fz > epsilon,
|
||||
torch.pow(fz, 3.0),
|
||||
fz.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
del fx, fy, fz
|
||||
|
||||
# Apply D65 white point (in-place)
|
||||
x.mul_(D65_WHITE_X)
|
||||
# y *= 1.00000 # (no-op, skip)
|
||||
z.mul_(D65_WHITE_Z)
|
||||
|
||||
xyz = torch.stack([x, y, z], dim=1)
|
||||
del x, y, z
|
||||
|
||||
# Matrix multiplication: XYZ -> RGB
|
||||
B, C, H, W = xyz.shape
|
||||
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||
del xyz
|
||||
|
||||
# Ensure dtype consistency for matrix multiplication
|
||||
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
|
||||
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
|
||||
del xyz_flat
|
||||
|
||||
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||
del rgb_linear_flat
|
||||
|
||||
# Apply inverse gamma correction (delinearize)
|
||||
mask = rgb_linear > 0.0031308
|
||||
rgb = torch.where(
|
||||
mask,
|
||||
torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055),
|
||||
rgb_linear * 12.92
|
||||
)
|
||||
del mask, rgb_linear
|
||||
|
||||
return torch.clamp(rgb, 0.0, 1.0)
|
||||
|
||||
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
|
||||
# Apply sRGB gamma correction (linearize)
|
||||
mask = rgb > 0.04045
|
||||
rgb_linear = torch.where(
|
||||
mask,
|
||||
torch.pow((rgb + 0.055) / 1.055, 2.4),
|
||||
rgb / 12.92
|
||||
)
|
||||
del mask
|
||||
|
||||
# Matrix multiplication: RGB -> XYZ
|
||||
B, C, H, W = rgb_linear.shape
|
||||
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||
del rgb_linear
|
||||
|
||||
# Ensure dtype consistency for matrix multiplication
|
||||
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
|
||||
xyz_flat = torch.matmul(rgb_flat, matrix.T)
|
||||
del rgb_flat
|
||||
|
||||
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||
del xyz_flat
|
||||
|
||||
# Normalize by D65 white point (in-place)
|
||||
xyz[:, 0].div_(D65_WHITE_X) # X
|
||||
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
|
||||
xyz[:, 2].div_(D65_WHITE_Z) # Z
|
||||
|
||||
# XYZ to LAB transformation
|
||||
epsilon_cubed = epsilon ** 3
|
||||
mask = xyz > epsilon_cubed
|
||||
f_xyz = torch.where(
|
||||
mask,
|
||||
torch.pow(xyz, 1.0 / 3.0),
|
||||
xyz.mul(kappa).add_(16.0).div_(116.0)
|
||||
)
|
||||
del xyz, mask
|
||||
|
||||
# Extract channels and compute LAB
|
||||
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
|
||||
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
|
||||
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
|
||||
del f_xyz
|
||||
|
||||
return torch.stack([L, a, b], dim=1)
|
||||
|
||||
def lab_color_transfer(
|
||||
content_feat: Tensor,
|
||||
style_feat: Tensor,
|
||||
luminance_weight: float = 0.8
|
||||
) -> Tensor:
|
||||
content_feat = wavelet_reconstruction(content_feat, style_feat)
|
||||
|
||||
if content_feat.shape != style_feat.shape:
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat,
|
||||
size=content_feat.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
device = content_feat.device
|
||||
|
||||
def ensure_float32_precision(c):
|
||||
orig_dtype = c.dtype
|
||||
c = c.float()
|
||||
return c, orig_dtype
|
||||
content_feat, original_dtype = ensure_float32_precision(content_feat)
|
||||
style_feat, _ = ensure_float32_precision(style_feat)
|
||||
|
||||
rgb_to_xyz_matrix = torch.tensor([
|
||||
[0.4124564, 0.3575761, 0.1804375],
|
||||
[0.2126729, 0.7151522, 0.0721750],
|
||||
[0.0193339, 0.1191920, 0.9503041]
|
||||
], dtype=torch.float32, device=device)
|
||||
|
||||
xyz_to_rgb_matrix = torch.tensor([
|
||||
[ 3.2404542, -1.5371385, -0.4985314],
|
||||
[-0.9692660, 1.8760108, 0.0415560],
|
||||
[ 0.0556434, -0.2040259, 1.0572252]
|
||||
], dtype=torch.float32, device=device)
|
||||
|
||||
epsilon = CIELAB_DELTA
|
||||
kappa = CIELAB_KAPPA
|
||||
|
||||
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||
|
||||
# Convert to LAB color space
|
||||
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||
del content_feat
|
||||
|
||||
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||
del style_feat, rgb_to_xyz_matrix
|
||||
|
||||
# Match chrominance channels (a*, b*) for accurate color transfer
|
||||
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
|
||||
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
|
||||
|
||||
# Handle luminance with weighted blending
|
||||
if luminance_weight < 1.0:
|
||||
# Partially match luminance for better overall color accuracy
|
||||
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
|
||||
# Blend: preserve some content L* for detail, adopt some style L* for color
|
||||
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
|
||||
del matched_L
|
||||
else:
|
||||
# Fully preserve content luminance
|
||||
result_L = content_lab[:, 0]
|
||||
|
||||
del content_lab, style_lab
|
||||
|
||||
# Reconstruct LAB with corrected channels
|
||||
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
|
||||
del result_L, matched_a, matched_b
|
||||
|
||||
# Convert back to RGB
|
||||
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
|
||||
del result_lab, xyz_to_rgb_matrix
|
||||
|
||||
# Convert back to [-1, 1] range (in-place)
|
||||
result = result_rgb.mul_(2.0).sub_(1.0)
|
||||
del result_rgb
|
||||
|
||||
result = result.to(original_dtype)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
return wavelet_reconstruction(content_feat, style_feat)
|
||||
|
||||
|
||||
def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor:
|
||||
if content_feat.shape != style_feat.shape:
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat,
|
||||
size=content_feat.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
original_dtype = content_feat.dtype
|
||||
content_feat = content_feat.float()
|
||||
style_feat = style_feat.float()
|
||||
|
||||
b, c = content_feat.shape[:2]
|
||||
content_flat = content_feat.reshape(b, c, -1)
|
||||
style_flat = style_feat.reshape(b, c, -1)
|
||||
|
||||
content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1)
|
||||
content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
|
||||
style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1)
|
||||
style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
|
||||
del content_flat, style_flat
|
||||
|
||||
normalized = (content_feat - content_mean) / content_std
|
||||
del content_mean, content_std
|
||||
result = normalized * style_std + style_mean
|
||||
del normalized, style_mean, style_std
|
||||
|
||||
result = result.clamp_(-1.0, 1.0)
|
||||
if result.dtype != original_dtype:
|
||||
result = result.to(original_dtype)
|
||||
return result
|
||||
@ -1,79 +0,0 @@
|
||||
"""Named constants for the SeedVR2 integration, grouped by provenance.
|
||||
|
||||
Provenance prefixes:
|
||||
- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline.
|
||||
- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites
|
||||
the upstream config/source path it was lifted from.
|
||||
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
|
||||
ISO / CIE values; cite the standard.
|
||||
"""
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment)
|
||||
# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN)
|
||||
# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070
|
||||
# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT).
|
||||
# --------------------------------------------------------------------------------------
|
||||
SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB)
|
||||
SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# B. Fork heuristics (SEEDVR2 - this integration)
|
||||
# --------------------------------------------------------------------------------------
|
||||
SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim.
|
||||
# (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.)
|
||||
SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry.
|
||||
SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case).
|
||||
SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM.
|
||||
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk.
|
||||
SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels).
|
||||
SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16).
|
||||
SEEDVR2_DEFAULT_TEMPORAL_SIZE = 16 # default VAE temporal tile when unset.
|
||||
|
||||
# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing)
|
||||
SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk.
|
||||
SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path.
|
||||
SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path.
|
||||
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR)
|
||||
# --------------------------------------------------------------------------------------
|
||||
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
|
||||
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
|
||||
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem).
|
||||
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem).
|
||||
BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28.
|
||||
BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28.
|
||||
BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16).
|
||||
BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32).
|
||||
BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD = 64 # attn_video_vae.py:308 (force .contiguous() above this b*t).
|
||||
BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11.
|
||||
BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size).
|
||||
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor.
|
||||
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor).
|
||||
BYTEDANCE_SCHEDULE_T = 1000.0 # configs_3b/main.yaml:65 (schedule.T); timestep range.
|
||||
BYTEDANCE_SPATIAL_DIVISOR = 16 # inference_seedvr2_3b.py:241 (DivisibleCrop((16,16))).
|
||||
BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling).
|
||||
BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames).
|
||||
BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency).
|
||||
BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim).
|
||||
# Resolution-dependent timestep-shift linear fits: (x1, y1, x2, y2) for get_lin_function.
|
||||
BYTEDANCE_IMG_SHIFT_FIT = (256 * 256, 1.0, 1024 * 1024, 3.2) # infer.py:242.
|
||||
BYTEDANCE_VID_SHIFT_FIT = (256 * 256 * 37, 1.0, 1280 * 720 * 145, 5.0) # infer.py:243.
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# D. Published standards (cite the literature)
|
||||
# --------------------------------------------------------------------------------------
|
||||
ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864.
|
||||
|
||||
# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65).
|
||||
CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta).
|
||||
CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa).
|
||||
D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1).
|
||||
D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn.
|
||||
WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR).
|
||||
|
||||
# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and
|
||||
# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the
|
||||
# exact existing coefficients move verbatim rather than being retyped here.
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1631,13 +1631,15 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
||||
|
||||
if reference_latent is not None:
|
||||
x = torch.cat((reference_latent, x), dim=2)
|
||||
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if ref_mask_latents is not None: # SCAIL-2 additive mask stream
|
||||
x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
@ -1645,6 +1647,8 @@ class SCAILWanModel(WanModel):
|
||||
scail_pose_seq_len = 0
|
||||
if pose_latents is not None:
|
||||
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
|
||||
if sam_latents is not None: # SCAIL-2 additive mask stream
|
||||
scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype)
|
||||
scail_x = scail_x.flatten(2).transpose(1, 2)
|
||||
scail_pose_seq_len = scail_x.shape[1]
|
||||
x = torch.cat([x, scail_x], dim=1)
|
||||
@ -1695,7 +1699,36 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
|
||||
# ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode,
|
||||
# which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset.
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}):
|
||||
if ref_mask_flag is not None and not bool(ref_mask_flag):
|
||||
REF_ROPE_H = 120.0
|
||||
POSE_ROPE_W = 120.0
|
||||
|
||||
ref_t_patches = 0
|
||||
if reference_latent is not None:
|
||||
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
|
||||
main_t_patches = t - ref_t_patches
|
||||
|
||||
parts = []
|
||||
if ref_t_patches > 0:
|
||||
ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}}
|
||||
parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf))
|
||||
if main_t_patches > 0:
|
||||
parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options))
|
||||
|
||||
if pose_latents is not None:
|
||||
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
||||
h_scale = h / H_pose
|
||||
w_scale = w / W_pose
|
||||
h_shift = (h_scale - 1) / 2
|
||||
w_shift = (w_scale - 1) / 2
|
||||
pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
|
||||
parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf))
|
||||
|
||||
return torch.cat(parts, dim=1)
|
||||
|
||||
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
|
||||
|
||||
if pose_latents is None:
|
||||
@ -1719,12 +1752,16 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
return torch.cat([main_freqs, pose_freqs], dim=1)
|
||||
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if pose_latents is not None:
|
||||
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
|
||||
if ref_mask_latents is not None: # SCAIL-2
|
||||
ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size)
|
||||
if sam_latents is not None: # SCAIL-2
|
||||
sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size)
|
||||
|
||||
t_len = t
|
||||
if time_dim_concat is not None:
|
||||
@ -1737,5 +1774,15 @@ class SCAILWanModel(WanModel):
|
||||
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
|
||||
t_len += reference_latent.shape[2]
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
|
||||
ref_mask_flag = kwargs.pop("ref_mask_flag", None) # SCAIL-2
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_flag=ref_mask_flag)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
|
||||
class SCAIL2WanModel(SCAILWanModel):
|
||||
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
|
||||
|
||||
def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs):
|
||||
super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
|
||||
self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||
|
||||
@ -357,6 +357,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, (comfy.model_base.LTXV, comfy.model_base.LTXAV)):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
@ -54,8 +54,6 @@ import comfy.ldm.pixeldit.model
|
||||
import comfy.ldm.pixeldit.pid
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.seedvr.model
|
||||
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.ideogram4.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
@ -930,16 +928,6 @@ class HunyuanDiT(BaseModel):
|
||||
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
||||
return out
|
||||
|
||||
class SeedVR2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
condition = kwargs.get("condition", None)
|
||||
if condition is not None:
|
||||
out["condition"] = comfy.conds.CONDRegular(condition)
|
||||
return out
|
||||
|
||||
class PixArt(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
|
||||
@ -1766,6 +1754,80 @@ class WAN21_SCAIL(WAN21):
|
||||
|
||||
return out
|
||||
|
||||
class WAN21_SCAIL2(WAN21_SCAIL):
|
||||
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
|
||||
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAIL2WanModel)
|
||||
self.memory_usage_factor_conds = ("reference_latent", "pose_latents", "ref_mask_latents", "sam_latents")
|
||||
self.memory_usage_shape_process = {
|
||||
"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]],
|
||||
"sam_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]],
|
||||
}
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
|
||||
driving_mask_28ch = kwargs.get("driving_mask_28ch", None)
|
||||
if driving_mask_28ch is not None:
|
||||
out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous())
|
||||
|
||||
ref_mask_28ch = kwargs.get("ref_mask_28ch", None)
|
||||
if ref_mask_28ch is not None:
|
||||
out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous())
|
||||
|
||||
ref_mask_flag = kwargs.get("ref_mask_flag", None)
|
||||
if ref_mask_flag is not None:
|
||||
out['ref_mask_flag'] = comfy.conds.CONDConstant(ref_mask_flag)
|
||||
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = super().extra_conds_shapes(**kwargs)
|
||||
driving_mask_28ch = kwargs.get("driving_mask_28ch", None)
|
||||
if driving_mask_28ch is not None:
|
||||
s = driving_mask_28ch.shape
|
||||
out['sam_latents'] = [s[0], 28, s[1], s[3], s[4]]
|
||||
ref_mask_28ch = kwargs.get("ref_mask_28ch", None)
|
||||
if ref_mask_28ch is not None:
|
||||
s = ref_mask_28ch.shape
|
||||
out['ref_mask_latents'] = [s[0], 28, s[1], s[3], s[4]]
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key in ("sam_latents", "pose_latents"):
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
# The 4 extra channels are the history_mask (1 at clean-anchor frames).
|
||||
noise = kwargs.get("noise", None)
|
||||
extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
|
||||
if extra_channels != 4:
|
||||
return super().concat_cond(**kwargs)
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
return torch.zeros_like(noise)[:, :4]
|
||||
|
||||
device = kwargs["device"]
|
||||
if mask.shape[1] != 4:
|
||||
mask = torch.mean(mask, dim=1, keepdim=True)
|
||||
mask = 1.0 - mask
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if mask.shape[-3] < noise.shape[-3]:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||
if mask.shape[1] == 1:
|
||||
mask = mask.repeat(1, 4, 1, 1, 1)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
return mask
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
# Hold anchor constant across all sigmas instead of base sigma*noise + (1-sigma)*latent_image.
|
||||
return latent_image
|
||||
|
||||
|
||||
class WAN22_WanDancer(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel)
|
||||
|
||||
@ -598,56 +598,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
dit_config["vid_dim"] = 3072
|
||||
dit_config["heads"] = 24
|
||||
dit_config["num_layers"] = 36
|
||||
# 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.``
|
||||
# submodules) at EVERY block — verified by inspecting the 7B
|
||||
# state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means
|
||||
# ``MMModule.shared_weights=False``). Native NaDiT computes
|
||||
# per-block ``shared_weights = not (i < mm_layers)``, so to keep
|
||||
# every block non-shared we set ``mm_layers = num_layers``.
|
||||
# Without this, blocks at index >= mm_layers (default 10) try to
|
||||
# load ``blocks.N.*.all.*`` keys that don't exist in the file,
|
||||
# silently miss-load → all-black output.
|
||||
dit_config["mm_layers"] = 36
|
||||
dit_config["norm_eps"] = 1e-5
|
||||
dit_config["qk_rope"] = True
|
||||
dit_config["rope_type"] = "rope3d"
|
||||
dit_config["rope_dim"] = 64
|
||||
dit_config["mlp_type"] = "normal"
|
||||
return dit_config
|
||||
elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
dit_config["vid_dim"] = 3072
|
||||
dit_config["heads"] = 24
|
||||
dit_config["num_layers"] = 36
|
||||
# This checkpoint layout carries shared ``all.`` MMModule keys.
|
||||
# Preserve the historical split: the initial blocks use separate
|
||||
# vid/txt modules, later blocks use shared modules.
|
||||
dit_config["mm_layers"] = 10
|
||||
dit_config["norm_eps"] = 1e-5
|
||||
dit_config["qk_rope"] = True
|
||||
dit_config["rope_type"] = "rope3d"
|
||||
dit_config["rope_dim"] = 64
|
||||
dit_config["mlp_type"] = "swiglu"
|
||||
return dit_config
|
||||
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
dit_config["vid_dim"] = 2560
|
||||
dit_config["heads"] = 20
|
||||
dit_config["num_layers"] = 32
|
||||
dit_config["norm_eps"] = 1.0e-05
|
||||
dit_config["qk_rope"] = None
|
||||
dit_config["mlp_type"] = "swiglu"
|
||||
dit_config["vid_out_norm"] = True
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "wan2.1"
|
||||
@ -680,6 +630,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "humo"
|
||||
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "animate"
|
||||
elif '{}patch_embedding_mask.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "scail2"
|
||||
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "scail"
|
||||
elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys:
|
||||
|
||||
@ -44,13 +44,7 @@ def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None,
|
||||
is_empty = torch.count_nonzero(latent_image) == 0
|
||||
if is_empty:
|
||||
if latent_format.latent_channels != latent_image.shape[1]:
|
||||
preserves_collapsed_channels = (
|
||||
getattr(latent_format, "preserve_empty_channel_multiples", False)
|
||||
and latent_image.ndim == 4
|
||||
and latent_image.shape[1] % latent_format.latent_channels == 0
|
||||
)
|
||||
if not preserves_collapsed_channels:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
if downscale_ratio_spacial is not None:
|
||||
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
|
||||
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio
|
||||
|
||||
237
comfy/sd.py
237
comfy/sd.py
@ -1,4 +1,3 @@
|
||||
import inspect
|
||||
import json
|
||||
import torch
|
||||
from enum import Enum
|
||||
@ -17,7 +16,6 @@ import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.seedvr.vae
|
||||
import comfy.ldm.triposplat.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.cogvideo.vae
|
||||
@ -86,36 +84,6 @@ import comfy.latent_formats
|
||||
|
||||
import comfy.ldm.flux.redux
|
||||
|
||||
SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL = 160
|
||||
|
||||
|
||||
def _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w):
|
||||
output_t = max(1, (latent_t - 1) * 4 + 1)
|
||||
return output_t * latent_h * 8 * latent_w * 8
|
||||
|
||||
|
||||
def _seedvr2_vae_decode_memory_used(shape):
|
||||
if len(shape) == 5:
|
||||
candidates = []
|
||||
if shape[1] == 16:
|
||||
candidates.append((shape[2], shape[3], shape[4]))
|
||||
if shape[-1] == 16:
|
||||
candidates.append((shape[1], shape[2], shape[3]))
|
||||
if len(candidates) == 0:
|
||||
candidates.append((shape[2], shape[3], shape[4]))
|
||||
output_pixels = max(_seedvr2_vae_decode_output_pixels(*candidate) for candidate in candidates)
|
||||
elif len(shape) == 4:
|
||||
latent_t = max(1, (shape[1] + 15) // 16)
|
||||
latent_h, latent_w = shape[2], shape[3]
|
||||
output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w)
|
||||
else:
|
||||
latent_t, latent_h, latent_w = 1, shape[-2], shape[-1]
|
||||
output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w)
|
||||
# SeedVR2 decode performs full-frame LAB histogram matching: fp32 channels
|
||||
# plus int64 sort indices dominate peak memory, not the VAE weight dtype.
|
||||
return output_pixels * SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL
|
||||
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
@ -499,10 +467,8 @@ class CLIP:
|
||||
|
||||
class VAE:
|
||||
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
||||
is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd
|
||||
if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
if metadata is None or metadata.get("keep_diffusers_format") != "true":
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
|
||||
if model_management.is_amd():
|
||||
VAE_KL_MEM_RATIO = 2.73
|
||||
@ -574,20 +540,6 @@ class VAE:
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
|
||||
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
|
||||
self.latent_channels = 16
|
||||
self.latent_dim = 3
|
||||
self.disable_offload = True
|
||||
self.memory_used_decode = lambda shape, dtype: _seedvr2_vae_decode_memory_used(shape)
|
||||
self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.crop_input = False
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
@ -715,7 +667,6 @@ class VAE:
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
|
||||
self.downscale_index_formula = (8, 32, 32)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
@ -1055,40 +1006,6 @@ class VAE:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||
|
||||
def decode_tiled_seedvr2(self, samples, tile_x=32, tile_y=32, overlap=8, tile_t=16, overlap_t=4):
|
||||
sf_s = getattr(self.first_stage_model, "spatial_downsample_factor", 8)
|
||||
sf_t = getattr(self.first_stage_model, "temporal_downsample_factor", 4)
|
||||
if tile_t is None:
|
||||
tile_t = 16
|
||||
if overlap_t is None:
|
||||
overlap_t = 4
|
||||
if tile_t > 0:
|
||||
temporal_size = tile_t * sf_t
|
||||
temporal_overlap = max(0, overlap_t) * sf_t
|
||||
else:
|
||||
temporal_size = 0
|
||||
temporal_overlap = 0
|
||||
args = {
|
||||
"enable_tiling": True,
|
||||
"tile_size": (tile_y * sf_s, tile_x * sf_s),
|
||||
"tile_overlap": (overlap * sf_s, overlap * sf_s),
|
||||
"temporal_size": temporal_size,
|
||||
"temporal_overlap": temporal_overlap,
|
||||
}
|
||||
output = self.first_stage_model.decode(
|
||||
samples.to(self.vae_dtype).to(self.device),
|
||||
seedvr2_tiling=args,
|
||||
)
|
||||
return self.process_output(output.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
|
||||
|
||||
def _format_seedvr2_encoded_samples(self, samples):
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
if samples.ndim == 4:
|
||||
samples = samples.unsqueeze(2)
|
||||
samples = samples.contiguous()
|
||||
samples = samples * 0.9152
|
||||
return samples
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
@ -1125,36 +1042,6 @@ class VAE:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
|
||||
def encode_tiled_seedvr2(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
if tile_y is None:
|
||||
tile_y = 512
|
||||
if tile_x is None:
|
||||
tile_x = 512
|
||||
if overlap is None:
|
||||
overlap_y = 64
|
||||
overlap_x = 64
|
||||
else:
|
||||
overlap_y = overlap
|
||||
overlap_x = overlap
|
||||
if tile_t is None:
|
||||
tile_t = 9999
|
||||
if overlap_t is None:
|
||||
overlap_t = 0
|
||||
overlap_y = min(overlap_y, max(0, tile_y - 8))
|
||||
overlap_x = min(overlap_x, max(0, tile_x - 8))
|
||||
self.first_stage_model.device = self.device
|
||||
x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device)
|
||||
output = comfy.ldm.seedvr.vae.tiled_vae(
|
||||
x,
|
||||
self.first_stage_model,
|
||||
tile_size=(tile_y, tile_x),
|
||||
tile_overlap=(overlap_y, overlap_x),
|
||||
temporal_size=tile_t,
|
||||
temporal_overlap=overlap_t,
|
||||
encode=True,
|
||||
)
|
||||
return output.to(device=self.output_device, dtype=self.vae_output_dtype())
|
||||
|
||||
def decode(self, samples_in, vae_options={}):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = None
|
||||
@ -1202,40 +1089,16 @@ class VAE:
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
# SeedVR2 latents arrive in 4D collapsed form ``(B, 16*T, H, W)``
|
||||
# downstream of ``SeedVR2Conditioning`` (which performs the
|
||||
# ``rearrange(b c t h w -> b (c t) h w)`` collapse). The
|
||||
# generic ``decode_tiled_`` would treat the channel dim as
|
||||
# spatial-only and crash on the collapsed (16, T) layout
|
||||
# under ``tiled_scale``'s mask broadcast; route SeedVR2 4D
|
||||
# latents to ``decode_tiled_seedvr2`` instead, whose wrapper
|
||||
# dispatch handles both 4D and 5D inputs.
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
else:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
else:
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(
|
||||
self,
|
||||
samples,
|
||||
tile_x=None,
|
||||
tile_y=None,
|
||||
overlap=None,
|
||||
tile_t=None,
|
||||
overlap_t=None,
|
||||
):
|
||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
self.throw_exception_if_invalid()
|
||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
@ -1249,20 +1112,7 @@ class VAE:
|
||||
args["overlap"] = overlap
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper) and dims in (2, 3):
|
||||
seedvr2_args = {}
|
||||
if tile_x is not None:
|
||||
seedvr2_args["tile_x"] = tile_x
|
||||
if tile_y is not None:
|
||||
seedvr2_args["tile_y"] = tile_y
|
||||
if overlap is not None:
|
||||
seedvr2_args["overlap"] = overlap
|
||||
if tile_t is not None:
|
||||
seedvr2_args["tile_t"] = tile_t
|
||||
if overlap_t is not None:
|
||||
seedvr2_args["overlap_t"] = overlap_t
|
||||
output = self.decode_tiled_seedvr2(samples, **seedvr2_args)
|
||||
elif dims == 1 or self.extra_1d_channel is not None:
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
@ -1304,8 +1154,6 @@ class VAE:
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
@ -1325,23 +1173,20 @@ class VAE:
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
samples = self.encode_tiled_seedvr2(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
else:
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
return self._format_seedvr2_encoded_samples(samples)
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
dims = self.latent_dim
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if dims == 3 and pixel_samples.ndim < 5:
|
||||
if dims == 3:
|
||||
if not self.not_video:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
else:
|
||||
@ -1365,47 +1210,22 @@ class VAE:
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
seedvr2_args = {}
|
||||
if tile_x is not None:
|
||||
seedvr2_args["tile_x"] = tile_x
|
||||
else:
|
||||
seedvr2_args["tile_x"] = 512
|
||||
if tile_y is not None:
|
||||
seedvr2_args["tile_y"] = tile_y
|
||||
else:
|
||||
seedvr2_args["tile_y"] = 512
|
||||
if overlap is not None:
|
||||
seedvr2_args["overlap"] = overlap
|
||||
else:
|
||||
seedvr2_args["overlap"] = 64
|
||||
if tile_t is not None:
|
||||
seedvr2_args["tile_t"] = tile_t
|
||||
else:
|
||||
seedvr2_args["tile_t"] = 9999
|
||||
if overlap_t is not None:
|
||||
seedvr2_args["overlap_t"] = overlap_t
|
||||
else:
|
||||
seedvr2_args["overlap_t"] = 0
|
||||
samples = self.encode_tiled_seedvr2(pixel_samples, **seedvr2_args)
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
|
||||
spatial_overlap = overlap if overlap is not None else 64
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, spatial_overlap, spatial_overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
|
||||
return self._format_seedvr2_encoded_samples(samples)
|
||||
return samples
|
||||
|
||||
def get_sd(self):
|
||||
return self.first_stage_model.state_dict()
|
||||
@ -1932,17 +1752,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
return (model, clip, vae)
|
||||
|
||||
|
||||
def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device):
|
||||
set_dtype = model_config.set_inference_dtype
|
||||
parameters = inspect.signature(set_dtype).parameters
|
||||
supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values())
|
||||
if supports_device:
|
||||
set_dtype(dtype, manual_cast_dtype, device=device)
|
||||
else:
|
||||
set_dtype(dtype, manual_cast_dtype)
|
||||
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||
@ -2050,7 +1859,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
if model_config.clip_vision_prefix is not None:
|
||||
if output_clipvision:
|
||||
@ -2191,7 +2000,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
if custom_operations is not None:
|
||||
model_config.custom_operations = custom_operations
|
||||
|
||||
@ -1450,6 +1450,17 @@ class WAN21_SCAIL(WAN21_T2V):
|
||||
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
|
||||
class WAN21_SCAIL2(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "scail2",
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_SCAIL2(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_WanDancer(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -1672,35 +1683,6 @@ class Chroma(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
||||
|
||||
class SeedVR2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "seedvr2"
|
||||
}
|
||||
latent_format = comfy.latent_formats.SeedVR2
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
sampling_settings = {
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
|
||||
if (
|
||||
dtype == torch.float16
|
||||
and manual_cast_dtype is None
|
||||
and comfy.model_management.should_use_bf16(device)
|
||||
):
|
||||
manual_cast_dtype = torch.bfloat16
|
||||
super().set_inference_dtype(dtype, manual_cast_dtype, device=device)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SeedVR2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
unet_config = {
|
||||
"image_model": "chroma_radiance",
|
||||
@ -2058,6 +2040,7 @@ class LongCatImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class RT_DETR_v4(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "RT_DETR_v4",
|
||||
@ -2287,6 +2270,7 @@ models = [
|
||||
WAN22_Animate,
|
||||
WAN21_FlowRVS,
|
||||
WAN21_SCAIL,
|
||||
WAN21_SCAIL2,
|
||||
WAN22_WanDancer,
|
||||
Hunyuan3Dv2mini,
|
||||
Hunyuan3Dv2,
|
||||
@ -2295,7 +2279,6 @@ models = [
|
||||
HiDream,
|
||||
HiDreamO1,
|
||||
Chroma,
|
||||
SeedVR2,
|
||||
ChromaRadiance,
|
||||
ACEStep,
|
||||
ACEStep15,
|
||||
|
||||
@ -115,7 +115,7 @@ class BASE:
|
||||
replace_prefix = {"": self.vae_key_prefix[0]}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
||||
self.unet_config['dtype'] = dtype
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
|
||||
|
||||
@ -32,7 +32,9 @@ class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||
if llama_template is None:
|
||||
if text.startswith('<|im_start|>'):
|
||||
llama_text = text
|
||||
elif llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
|
||||
321
comfy_extras/nodes_scail.py
Normal file
321
comfy_extras/nodes_scail.py
Normal file
@ -0,0 +1,321 @@
|
||||
"""SCAIL / SCAIL-2 nodes: the WanSCAILToVideo conditioning node and the SAM3
|
||||
preprocessing that turns video tracks into the bundle the SCAIL-2 model consumes."""
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import nodes
|
||||
import node_helpers
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy.ldm.sam3.tracker import unpack_masks
|
||||
|
||||
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
|
||||
|
||||
|
||||
# Model was trained on these exact colors; deviating degrades multi-identity quality.
|
||||
DEFAULT_PALETTE = [
|
||||
(0.0, 0.0, 1.0), # Blue
|
||||
(1.0, 0.0, 0.0), # Red
|
||||
(0.0, 1.0, 0.0), # Green
|
||||
(1.0, 0.0, 1.0), # Magenta
|
||||
(0.0, 1.0, 1.0), # Cyan
|
||||
(1.0, 1.0, 0.0), # Yellow
|
||||
]
|
||||
|
||||
|
||||
def _unpack(track_data):
|
||||
packed = track_data["packed_masks"]
|
||||
if packed is None or packed.shape[1] == 0:
|
||||
return None
|
||||
return unpack_masks(packed)
|
||||
|
||||
|
||||
def _first_frame_cx_area(masks_bool):
|
||||
first = masks_bool[0].float()
|
||||
H, W = first.shape[-2], first.shape[-1]
|
||||
n_pixels = H * W
|
||||
grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W)
|
||||
area = first.sum(dim=(-1, -2)).clamp_(min=1)
|
||||
cx = (first * grid_x).sum(dim=(-1, -2)) / area
|
||||
return (cx / W).tolist(), (area / n_pixels).tolist()
|
||||
|
||||
|
||||
def _subset_track_data(track_data, obj_indices):
|
||||
out = dict(track_data)
|
||||
packed = track_data["packed_masks"]
|
||||
if packed is None or not obj_indices:
|
||||
out["packed_masks"] = None
|
||||
if "scores" in out:
|
||||
out["scores"] = []
|
||||
return out
|
||||
out["packed_masks"] = packed[:, obj_indices].contiguous()
|
||||
scores = track_data.get("scores")
|
||||
if scores is not None:
|
||||
out["scores"] = [scores[i] for i in obj_indices if i < len(scores)]
|
||||
return out
|
||||
|
||||
|
||||
def _render_colored_masks(track_data, background="black"):
|
||||
packed = track_data["packed_masks"]
|
||||
H, W = track_data["orig_size"]
|
||||
device = comfy.model_management.intermediate_device()
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0)
|
||||
if packed is None or packed.shape[1] == 0:
|
||||
T = track_data.get("n_frames", 1) if packed is None else packed.shape[0]
|
||||
out = torch.empty(T, H, W, 3, device=device, dtype=dtype)
|
||||
out[..., 0], out[..., 1], out[..., 2] = bg_rgb[0], bg_rgb[1], bg_rgb[2]
|
||||
return out
|
||||
T, N_obj = packed.shape[0], packed.shape[1]
|
||||
colors = torch.tensor(
|
||||
[DEFAULT_PALETTE[i % len(DEFAULT_PALETTE)] for i in range(N_obj)],
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
masks_full = unpack_masks(packed.to(device)).float()
|
||||
Hm, Wm = masks_full.shape[-2], masks_full.shape[-1]
|
||||
masks_full = F.interpolate(
|
||||
masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest"
|
||||
).view(T, N_obj, H, W) > 0.5
|
||||
any_mask = masks_full.any(dim=1)
|
||||
obj_idx_map = masks_full.to(torch.uint8).argmax(dim=1)
|
||||
color_overlay = colors[obj_idx_map]
|
||||
bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3)
|
||||
return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay))
|
||||
|
||||
|
||||
def _extract_mask_to_28ch(rgb_video):
|
||||
"""Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent
|
||||
(1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c)
|
||||
threshold-extracted at 225/255, 8x spatial downsample, 4-frame temporal stacking."""
|
||||
T, H, W, _ = rgb_video.shape
|
||||
_ON_THRESH = 225.0 / 255.0
|
||||
mask = rgb_video.movedim(-1, 1).float()
|
||||
R = (mask[:, 0:1] > _ON_THRESH).float()
|
||||
G = (mask[:, 1:2] > _ON_THRESH).float()
|
||||
B = (mask[:, 2:3] > _ON_THRESH).float()
|
||||
nR, nG, nB = 1 - R, 1 - G, 1 - B
|
||||
binary_7ch = torch.cat([
|
||||
R * G * B, # white
|
||||
R * nG * nB, # red
|
||||
nR * G * nB, # green
|
||||
nR * nG * B, # blue
|
||||
R * G * nB, # yellow
|
||||
R * nG * B, # magenta
|
||||
nR * G * B, # cyan
|
||||
], dim=1)
|
||||
H_lat, W_lat = H, W
|
||||
for _ in range(3):
|
||||
H_lat = (H_lat + 1) // 2
|
||||
W_lat = (W_lat + 1) // 2
|
||||
binary_7ch = torch.nn.functional.interpolate(binary_7ch, size=(H_lat, W_lat), mode='area')
|
||||
T_latent = (T - 1) // 4 + 1
|
||||
padded = torch.cat([binary_7ch[:1].repeat(4, 1, 1, 1), binary_7ch[1:]], dim=0)
|
||||
out = padded.view(T_latent, 28, H_lat, W_lat)
|
||||
return out.unsqueeze(0)
|
||||
|
||||
|
||||
class WanSCAILToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanSCAILToVideo",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
|
||||
io.Image.Input("pose_video_mask", optional=True, tooltip="SCAIL-2 only. Colored per-identity SAM3 mask video at the same resolution as pose_video."),
|
||||
io.Boolean.Input("replacement_mode", default=False, optional=True, tooltip="SCAIL-2 only. False = Animation Mode (pose_video_mask should have black background). True = Replacement Mode (pose_video_mask should have white background)."),
|
||||
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
|
||||
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."),
|
||||
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."),
|
||||
io.Image.Input("reference_image", optional=True, tooltip="Reference image, for multiple references composite all on single image."),
|
||||
io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask at the same resolution as reference_image."),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."),
|
||||
io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."),
|
||||
io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."),
|
||||
io.Image.Input("previous_frames", optional=True, tooltip="SCAIL-2 only. Full decoded output of the previous chunk. Only the last previous_frame_count are used as the extension anchor."),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
|
||||
io.Int.Output(display_name="video_frame_offset", tooltip="Adjusted offset + length. Wire into the next chunk."),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end,
|
||||
video_frame_offset, previous_frame_count, replacement_mode=False, reference_image=None, clip_vision_output=None, pose_video=None,
|
||||
pose_video_mask=None, reference_image_mask=None, previous_frames=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
noise_mask = None
|
||||
|
||||
ref_mask_flag = not replacement_mode
|
||||
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_flag": ref_mask_flag})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_flag": ref_mask_flag})
|
||||
|
||||
prev_trimmed = None
|
||||
if previous_frames is not None and previous_frames.shape[0] > 0:
|
||||
prev_trimmed = previous_frames[-previous_frame_count:]
|
||||
video_frame_offset -= prev_trimmed.shape[0]
|
||||
video_frame_offset = max(0, video_frame_offset)
|
||||
|
||||
ref_latent = None
|
||||
if reference_image is not None:
|
||||
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
|
||||
# Replacement Mode: composite ref on black bg using reference_image_mask as alpha matte
|
||||
if replacement_mode and reference_image_mask is not None:
|
||||
rm = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
|
||||
is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype)
|
||||
reference_image = reference_image * is_char
|
||||
ref_latent = vae.encode(reference_image[:, :, :, :3])
|
||||
|
||||
if ref_latent is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
if pose_video is not None:
|
||||
if pose_video.shape[0] <= video_frame_offset:
|
||||
pose_video = None
|
||||
else:
|
||||
pose_video = pose_video[video_frame_offset:]
|
||||
if pose_video_mask is not None:
|
||||
if pose_video_mask.shape[0] <= video_frame_offset:
|
||||
pose_video_mask = None
|
||||
else:
|
||||
pose_video_mask = pose_video_mask[video_frame_offset:]
|
||||
|
||||
# Truncate pose+mask jointly to the shorter of the two, capped at length.
|
||||
ts = [v.shape[0] for v in (pose_video, pose_video_mask) if v is not None]
|
||||
if ts:
|
||||
T_kept = ((min(min(ts), length) - 1) // 4) * 4 + 1
|
||||
if pose_video is not None:
|
||||
pose_video = pose_video[:T_kept]
|
||||
if pose_video_mask is not None:
|
||||
pose_video_mask = pose_video_mask[:T_kept]
|
||||
|
||||
if pose_video is not None:
|
||||
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
|
||||
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
|
||||
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
||||
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
||||
|
||||
if pose_video_mask is not None:
|
||||
mask_video_hw = comfy.utils.common_upscale(pose_video_mask[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
|
||||
driving_mask_28ch = _extract_mask_to_28ch(mask_video_hw)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch})
|
||||
|
||||
if reference_image_mask is not None:
|
||||
ref_mask_hw = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
|
||||
ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw)
|
||||
zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype)
|
||||
ref_mask_28ch = torch.cat([ref_mask_1f, zeros], dim=1)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch})
|
||||
|
||||
if prev_trimmed is not None:
|
||||
pf = comfy.utils.common_upscale(prev_trimmed.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
|
||||
prev_latent = vae.encode(pf[:, :, :, :3])
|
||||
prev_latent_frames = min(prev_latent.shape[2], latent.shape[2])
|
||||
latent[:, :, :prev_latent_frames] = prev_latent[:, :, :prev_latent_frames].to(latent.dtype)
|
||||
noise_mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=latent.device, dtype=latent.dtype)
|
||||
noise_mask[:, :, :prev_latent_frames] = 0.0
|
||||
|
||||
out_latent = {"samples": latent}
|
||||
if noise_mask is not None:
|
||||
out_latent["noise_mask"] = noise_mask
|
||||
return io.NodeOutput(positive, negative, out_latent, video_frame_offset + length)
|
||||
|
||||
|
||||
class SCAIL2ColoredMask(io.ComfyNode):
|
||||
"""Render SAM3 tracks for the driving pose video and (optionally) the reference
|
||||
image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by`
|
||||
across both outputs guarantees identity K maps to the same color on both
|
||||
sides, for multi-person workflow consistency.
|
||||
reference_image_mask is always rendered black-bg (model convention)
|
||||
pose_video_mask bg follows replacement_mode: black = Animation Mode, white = Replacement Mode
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SCAIL2ColoredMask",
|
||||
display_name="Create SCAIL-2 Colored Mask",
|
||||
category="conditioning/video_models/scail",
|
||||
inputs=[
|
||||
SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. Will be rendered into the pose_video_mask output."),
|
||||
SAM3TrackData.Input("ref_track_data", optional=True,
|
||||
tooltip="SAM3 track of the reference image."),
|
||||
io.String.Input("object_indices", default="",
|
||||
tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."),
|
||||
io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right",
|
||||
tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). left_to_right = leftmost object (by first-frame centroid) gets the first color; area = biggest object (by first-frame mask area) gets the first color; none = keep SAM3's order."),
|
||||
io.Boolean.Input("replacement_mode", default=False,
|
||||
tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). Set the matching replacement_mode on WanSCAILToVideo. reference_image_mask is always black-bg regardless."),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("pose_video_mask"),
|
||||
io.Image.Output("reference_image_mask"),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None):
|
||||
def _prep(td):
|
||||
masks_bool = _unpack(td)
|
||||
if sort_by != "none" and masks_bool is not None:
|
||||
cx, area = _first_frame_cx_area(masks_bool)
|
||||
if sort_by == "left_to_right":
|
||||
order = sorted(range(len(cx)), key=lambda i: cx[i])
|
||||
else: # "area"
|
||||
order = sorted(range(len(area)), key=lambda i: -area[i])
|
||||
td = _subset_track_data(td, order)
|
||||
if object_indices.strip():
|
||||
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
|
||||
packed = td.get("packed_masks")
|
||||
n_obj = packed.shape[1] if packed is not None else 0
|
||||
indices = [i for i in indices if 0 <= i < n_obj]
|
||||
td = _subset_track_data(td, indices)
|
||||
return td
|
||||
|
||||
drv = _prep(driving_track_data)
|
||||
mask_video = _render_colored_masks(drv, "white" if replacement_mode else "black")
|
||||
|
||||
if ref_track_data is not None:
|
||||
ref = _prep(ref_track_data)
|
||||
reference_image_mask = _render_colored_masks(ref, "black")
|
||||
else:
|
||||
H, W = drv["orig_size"]
|
||||
reference_image_mask = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
|
||||
return io.NodeOutput(mask_video, reference_image_mask)
|
||||
|
||||
|
||||
class SCAILExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
WanSCAILToVideo,
|
||||
SCAIL2ColoredMask,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> SCAILExtension:
|
||||
return SCAILExtension()
|
||||
File diff suppressed because it is too large
Load Diff
@ -1456,63 +1456,6 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
|
||||
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
|
||||
|
||||
|
||||
class WanSCAILToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanSCAILToVideo",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
io.Image.Input("reference_image", optional=True),
|
||||
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
|
||||
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
|
||||
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."),
|
||||
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
|
||||
ref_latent = None
|
||||
if reference_image is not None:
|
||||
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
ref_latent = vae.encode(reference_image[:, :, :, :3])
|
||||
|
||||
if ref_latent is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
if pose_video is not None:
|
||||
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
|
||||
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
|
||||
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
||||
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
class WanExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@ -1533,7 +1476,6 @@ class WanExtension(ComfyExtension):
|
||||
WanAnimateToVideo,
|
||||
Wan22ImageToVideoLatent,
|
||||
WanInfiniteTalkToVideo,
|
||||
WanSCAILToVideo,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> WanExtension:
|
||||
|
||||
15
main.py
15
main.py
@ -26,6 +26,7 @@ import utils.extra_config
|
||||
from utils.mime_types import init_mime_types
|
||||
import faulthandler
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
@ -37,7 +38,19 @@ if __name__ == "__main__":
|
||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
||||
os.environ['DO_NOT_TRACK'] = '1'
|
||||
|
||||
faulthandler.enable(file=sys.stderr, all_threads=False)
|
||||
faulthandler.enable(file=sys.stderr, all_threads=args.debug_hang)
|
||||
if __name__ == "__main__" and args.debug_hang:
|
||||
dumping_traceback = False
|
||||
|
||||
def dump_traceback_on_sigint(signum, frame):
|
||||
global dumping_traceback
|
||||
if dumping_traceback:
|
||||
raise KeyboardInterrupt
|
||||
dumping_traceback = True
|
||||
faulthandler.dump_traceback(file=sys.stderr, all_threads=True)
|
||||
raise KeyboardInterrupt
|
||||
|
||||
signal.signal(signal.SIGINT, dump_traceback_on_sigint)
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
|
||||
43
nodes.py
43
nodes.py
@ -47,18 +47,14 @@ import node_helpers
|
||||
if args.enable_manager:
|
||||
import comfyui_manager
|
||||
|
||||
|
||||
def before_node_execution():
|
||||
comfy.model_management.throw_exception_if_processing_interrupted()
|
||||
|
||||
|
||||
def interrupt_processing(value=True):
|
||||
comfy.model_management.interrupt_current_processing(value)
|
||||
|
||||
|
||||
MAX_RESOLUTION=16384
|
||||
|
||||
|
||||
class CLIPTextEncode(ComfyNodeABC):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s) -> InputTypeDict:
|
||||
@ -327,8 +323,8 @@ class VAEDecodeTiled:
|
||||
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
|
||||
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32, "advanced": True}),
|
||||
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}),
|
||||
"temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}),
|
||||
"temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}),
|
||||
"temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time.", "advanced": True}),
|
||||
"temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "decode"
|
||||
@ -338,32 +334,18 @@ class VAEDecodeTiled:
|
||||
def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
|
||||
if tile_size < overlap * 4:
|
||||
overlap = tile_size // 4
|
||||
if temporal_size < temporal_overlap * 2:
|
||||
temporal_overlap = temporal_overlap // 2
|
||||
temporal_compression = vae.temporal_compression_decode()
|
||||
if temporal_compression is not None:
|
||||
if temporal_size <= 0:
|
||||
temporal_size = 0
|
||||
temporal_overlap = 0
|
||||
else:
|
||||
requested_temporal_overlap = temporal_overlap
|
||||
if temporal_size < temporal_overlap * 2:
|
||||
temporal_overlap = temporal_overlap // 2
|
||||
temporal_size = max(2, temporal_size // temporal_compression)
|
||||
temporal_overlap = min(temporal_size // 2, temporal_overlap // temporal_compression)
|
||||
if requested_temporal_overlap > 0:
|
||||
temporal_overlap = max(1, temporal_overlap)
|
||||
temporal_size = max(2, temporal_size // temporal_compression)
|
||||
temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression))
|
||||
else:
|
||||
temporal_size = None
|
||||
temporal_overlap = None
|
||||
|
||||
compression = vae.spacial_compression_decode()
|
||||
images = vae.decode_tiled(
|
||||
samples["samples"],
|
||||
tile_x=tile_size // compression,
|
||||
tile_y=tile_size // compression,
|
||||
overlap=overlap // compression,
|
||||
tile_t=temporal_size,
|
||||
overlap_t=temporal_overlap,
|
||||
)
|
||||
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap)
|
||||
if len(images.shape) == 5: #Combine batches
|
||||
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
|
||||
return (images, )
|
||||
@ -380,7 +362,7 @@ class VAEEncode:
|
||||
|
||||
def encode(self, vae, pixels):
|
||||
t = vae.encode(pixels)
|
||||
return ({"samples": t}, )
|
||||
return ({"samples":t}, )
|
||||
|
||||
class VAEEncodeTiled:
|
||||
@classmethod
|
||||
@ -388,8 +370,8 @@ class VAEEncodeTiled:
|
||||
return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
|
||||
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64, "advanced": True}),
|
||||
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}),
|
||||
"temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time. SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}),
|
||||
"temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}),
|
||||
"temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time.", "advanced": True}),
|
||||
"temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
@ -397,9 +379,6 @@ class VAEEncodeTiled:
|
||||
CATEGORY = "experimental"
|
||||
|
||||
def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
|
||||
if temporal_size <= 0:
|
||||
temporal_size = 0
|
||||
temporal_overlap = 0
|
||||
t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
|
||||
return ({"samples": t}, )
|
||||
|
||||
@ -2439,7 +2418,6 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_camera_trajectory.py",
|
||||
"nodes_edit_model.py",
|
||||
"nodes_tcfg.py",
|
||||
"nodes_seedvr.py",
|
||||
"nodes_context_windows.py",
|
||||
"nodes_qwen.py",
|
||||
"nodes_chroma_radiance.py",
|
||||
@ -2472,6 +2450,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_rtdetr.py",
|
||||
"nodes_frame_interpolation.py",
|
||||
"nodes_sam3.py",
|
||||
"nodes_scail.py",
|
||||
"nodes_void.py",
|
||||
"nodes_wandancer.py",
|
||||
"nodes_hidream_o1.py",
|
||||
|
||||
229
openapi.yaml
229
openapi.yaml
@ -3,11 +3,6 @@ components:
|
||||
Asset:
|
||||
description: Represents a user-owned asset (image, video, or other generated output).
|
||||
properties:
|
||||
asset_hash:
|
||||
deprecated: true
|
||||
description: 'Deprecated: use hash instead. Blake3 hash of the asset content.'
|
||||
pattern: ^blake3:[a-f0-9]{64}$
|
||||
type: string
|
||||
created_at:
|
||||
description: Timestamp when the asset was created
|
||||
format: date-time
|
||||
@ -16,8 +11,12 @@ components:
|
||||
description: Display name of the asset. Mirrors name for backwards compatibility.
|
||||
nullable: true
|
||||
type: string
|
||||
file_path:
|
||||
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
|
||||
nullable: true
|
||||
type: string
|
||||
hash:
|
||||
description: Blake3 hash of the asset content. Preferred over asset_hash.
|
||||
description: Blake3 hash of the asset content.
|
||||
pattern: ^blake3:[a-f0-9]{64}$
|
||||
type: string
|
||||
id:
|
||||
@ -139,17 +138,16 @@ components:
|
||||
AssetUpdated:
|
||||
description: Response returned when an existing asset is successfully updated.
|
||||
properties:
|
||||
asset_hash:
|
||||
deprecated: true
|
||||
description: 'Deprecated: use hash instead. Blake3 hash of the asset content.'
|
||||
pattern: ^blake3:[a-f0-9]{64}$
|
||||
type: string
|
||||
display_name:
|
||||
description: Display name of the asset. Mirrors name for backwards compatibility.
|
||||
nullable: true
|
||||
type: string
|
||||
file_path:
|
||||
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
|
||||
nullable: true
|
||||
type: string
|
||||
hash:
|
||||
description: Blake3 hash of the asset content. Preferred over asset_hash.
|
||||
description: Blake3 hash of the asset content.
|
||||
pattern: ^blake3:[a-f0-9]{64}$
|
||||
type: string
|
||||
id:
|
||||
@ -828,7 +826,11 @@ components:
|
||||
type: string
|
||||
type: object
|
||||
PaginationInfo:
|
||||
description: Offset/limit-based pagination metadata included in list responses.
|
||||
description: |
|
||||
Pagination metadata included in list responses. Supports both legacy
|
||||
offset/limit pagination and cursor-based pagination. When cursor-based
|
||||
pagination is used, `next_cursor` is the primary pagination token and
|
||||
`offset`/`total` may be zero.
|
||||
properties:
|
||||
has_more:
|
||||
description: Whether more items are available beyond this page
|
||||
@ -837,12 +839,19 @@ components:
|
||||
description: Items per page
|
||||
minimum: 1
|
||||
type: integer
|
||||
next_cursor:
|
||||
description: |
|
||||
Opaque cursor for the next page. Pass this value as the `after`
|
||||
query parameter on the next request. Empty or absent when there
|
||||
are no more results.
|
||||
type: string
|
||||
offset:
|
||||
description: Current offset (0-based)
|
||||
deprecated: true
|
||||
description: 'Current offset (0-based). Deprecated: use cursor-based pagination.'
|
||||
minimum: 0
|
||||
type: integer
|
||||
total:
|
||||
description: Total number of items matching filters
|
||||
description: Total number of items matching filters (may be 0 when using cursor pagination)
|
||||
minimum: 0
|
||||
type: integer
|
||||
required:
|
||||
@ -1518,17 +1527,11 @@ paths:
|
||||
schema:
|
||||
default: true
|
||||
type: boolean
|
||||
- description: Filter assets by exact content hash. Preferred over asset_hash.
|
||||
- description: Filter assets by exact content hash.
|
||||
in: query
|
||||
name: hash
|
||||
schema:
|
||||
type: string
|
||||
- deprecated: true
|
||||
description: 'Deprecated: use hash instead. Filter assets by exact content hash.'
|
||||
in: query
|
||||
name: asset_hash
|
||||
schema:
|
||||
type: string
|
||||
- description: |
|
||||
Opaque cursor for keyset pagination. Pass the `next_cursor` value
|
||||
from the previous response to fetch the next page. When provided,
|
||||
@ -1571,42 +1574,12 @@ paths:
|
||||
- file
|
||||
post:
|
||||
description: |
|
||||
Uploads a new asset to the system with associated metadata.
|
||||
Supports two upload methods:
|
||||
1. Direct file upload (multipart/form-data)
|
||||
2. URL-based upload (application/json with source: "url")
|
||||
Creates a new asset from a direct file upload (multipart/form-data) with associated metadata.
|
||||
|
||||
If an asset with the same hash already exists, returns the existing asset.
|
||||
operationId: uploadAsset
|
||||
operationId: createAsset
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
name:
|
||||
description: Display name for the asset (used to determine file extension)
|
||||
type: string
|
||||
preview_id:
|
||||
description: Optional preview asset ID
|
||||
format: uuid
|
||||
type: string
|
||||
tags:
|
||||
description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
url:
|
||||
description: HTTP/HTTPS URL to download the asset from
|
||||
format: uri
|
||||
type: string
|
||||
user_metadata:
|
||||
additionalProperties: true
|
||||
description: Custom metadata to store with the asset
|
||||
type: object
|
||||
required:
|
||||
- url
|
||||
- name
|
||||
type: object
|
||||
multipart/form-data:
|
||||
schema:
|
||||
properties:
|
||||
@ -1614,6 +1587,10 @@ paths:
|
||||
description: The asset file to upload
|
||||
format: binary
|
||||
type: string
|
||||
hash:
|
||||
description: Content hash of the file.
|
||||
pattern: ^(blake3|sha256):[a-f0-9]{64}$
|
||||
type: string
|
||||
id:
|
||||
description: Optional asset ID for idempotent creation. If provided and asset exists, returns existing asset.
|
||||
format: uuid
|
||||
@ -1629,10 +1606,8 @@ paths:
|
||||
format: uuid
|
||||
type: string
|
||||
tags:
|
||||
description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
description: JSON-encoded array of freeform tag strings, e.g. '["models","checkpoint"]'. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
|
||||
type: string
|
||||
user_metadata:
|
||||
description: Custom JSON metadata as a string
|
||||
type: string
|
||||
@ -1641,36 +1616,32 @@ paths:
|
||||
type: object
|
||||
required: true
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AssetCreated'
|
||||
description: |
|
||||
Asset already existed for this user (deduplicated by content hash); the
|
||||
existing asset is returned with created_new=false.
|
||||
"201":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AssetCreated'
|
||||
description: Asset created successfully
|
||||
description: Asset created successfully (created_new=true)
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Invalid request (bad file, invalid URL, invalid content type, etc.)
|
||||
description: Invalid request (bad file, invalid content type, etc.)
|
||||
"401":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Unauthorized
|
||||
"403":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Source URL requires authentication or access denied
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Source URL not found
|
||||
"413":
|
||||
content:
|
||||
application/json:
|
||||
@ -1683,19 +1654,13 @@ paths:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Unsupported media type
|
||||
"422":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Download failed due to network error or timeout
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Internal server error
|
||||
summary: Upload a new asset
|
||||
summary: Create a new asset
|
||||
tags:
|
||||
- file
|
||||
/api/assets/{id}:
|
||||
@ -1730,7 +1695,7 @@ paths:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Asset cannot be deleted because it is referenced by another resource (e.g., workflow version)
|
||||
description: 'Asset cannot be deleted because it is referenced by another resource, e.g. a workflow version (error code: ASSET_IN_USE)'
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
@ -1783,7 +1748,7 @@ paths:
|
||||
description: |
|
||||
Updates an asset's metadata. At least one field must be provided.
|
||||
Only name, mime_type, preview_id, and user_metadata can be updated.
|
||||
For tag management, use the dedicated PUT /api/assets/{id}/tags endpoint.
|
||||
For tag management, use POST (add) and DELETE (remove) /api/assets/{id}/tags.
|
||||
operationId: updateAsset
|
||||
parameters:
|
||||
- description: Asset ID
|
||||
@ -1982,76 +1947,6 @@ paths:
|
||||
summary: Add tags to asset
|
||||
tags:
|
||||
- file
|
||||
put:
|
||||
description: Adds and removes tags from an asset in a single operation
|
||||
operationId: updateAssetTags
|
||||
parameters:
|
||||
- description: Asset ID
|
||||
in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
description: At least one of add or remove must contain items. Empty arrays are allowed when the other array has items.
|
||||
minProperties: 1
|
||||
properties:
|
||||
add:
|
||||
description: Tags to add to the asset. Can be empty if remove has items.
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
remove:
|
||||
description: Tags to remove from the asset. Can be empty if add has items.
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
type: object
|
||||
required: true
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/TagsModificationResponse'
|
||||
description: Tags updated successfully
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Invalid request
|
||||
"401":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Unauthorized
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Asset not found
|
||||
"422":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Reserved tag validation error
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Internal server error
|
||||
summary: Update asset tags
|
||||
tags:
|
||||
- file
|
||||
/api/assets/from-hash:
|
||||
post:
|
||||
description: |
|
||||
@ -2090,12 +1985,20 @@ paths:
|
||||
type: object
|
||||
required: true
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AssetCreated'
|
||||
description: |
|
||||
Asset reference already existed for this user (deduplicated by content
|
||||
hash); the existing asset is returned with created_new=false.
|
||||
"201":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AssetCreated'
|
||||
description: Asset reference created successfully
|
||||
description: Asset reference created successfully (created_new=true)
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
@ -2887,7 +2790,21 @@ paths:
|
||||
- asc
|
||||
- desc
|
||||
type: string
|
||||
- description: Pagination offset (0-based)
|
||||
- description: |
|
||||
Opaque cursor for keyset pagination. Pass the `next_cursor` value
|
||||
from a previous response to fetch the next page.
|
||||
Cursor pagination is supported only when `sort_by=create_time`
|
||||
(default). If `sort_by=execution_time`, `after` is ignored and
|
||||
offset/limit pagination is used.
|
||||
Cursors are opaque base64url payloads — clients should treat them
|
||||
as strings and not parse the contents.
|
||||
example: eyJzIjoiY3JlYXRlX3RpbWUiLCJ2IjoiMTcxNjIwMDAwMDAwMDAwMCIsImlkIjoiYTFiMmMzZDQtZTVmNi03YTg5LWIwYzEtZDJlM2Y0YTViNmM3In0
|
||||
in: query
|
||||
name: after
|
||||
schema:
|
||||
type: string
|
||||
- deprecated: true
|
||||
description: 'Pagination offset (0-based). Deprecated: prefer cursor-based pagination via `after`.'
|
||||
in: query
|
||||
name: offset
|
||||
schema:
|
||||
@ -2909,6 +2826,12 @@ paths:
|
||||
schema:
|
||||
$ref: '#/components/schemas/JobsListResponse'
|
||||
description: Success - Jobs retrieved
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Bad request (e.g. malformed pagination cursor).
|
||||
"401":
|
||||
content:
|
||||
application/json:
|
||||
|
||||
@ -1253,6 +1253,15 @@ class PromptServer():
|
||||
|
||||
if verbose:
|
||||
logging.info("Starting server\n")
|
||||
if args.debug_hang:
|
||||
logging.info(
|
||||
f"{'-' * 80}\n"
|
||||
"ComfyUI has been started in debug-hang mode. Run your workflow as normal up to\n"
|
||||
"the point of the hang or freeze, then use ctrl-C in the cmd or controlling\n"
|
||||
"terminal to dump the python backtraces for debugging. Please attach the extra\n"
|
||||
"debug info to your bug report.\n"
|
||||
f"{'-' * 80}"
|
||||
)
|
||||
for addr in addresses:
|
||||
address = addr[0]
|
||||
port = addr[1]
|
||||
|
||||
86
tests-unit/assets_test/services/test_image_dimensions.py
Normal file
86
tests-unit/assets_test/services/test_image_dimensions.py
Normal file
@ -0,0 +1,86 @@
|
||||
"""Tests for the image_dimensions service."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from app.assets.services.image_dimensions import extract_image_dimensions
|
||||
|
||||
|
||||
def _make_png(path: Path, size: tuple[int, int]) -> Path:
|
||||
img = Image.new("RGB", size, color=(123, 45, 67))
|
||||
img.save(path, format="PNG")
|
||||
return path
|
||||
|
||||
|
||||
def _make_jpeg(path: Path, size: tuple[int, int]) -> Path:
|
||||
img = Image.new("RGB", size, color=(10, 20, 30))
|
||||
img.save(path, format="JPEG", quality=80)
|
||||
return path
|
||||
|
||||
|
||||
class TestExtractImageDimensions:
|
||||
def test_extracts_png_dimensions(self, tmp_path: Path):
|
||||
f = _make_png(tmp_path / "rect.png", (320, 240))
|
||||
|
||||
result = extract_image_dimensions(str(f), mime_type="image/png")
|
||||
|
||||
assert result == {"kind": "image", "width": 320, "height": 240}
|
||||
|
||||
def test_extracts_jpeg_dimensions(self, tmp_path: Path):
|
||||
f = _make_jpeg(tmp_path / "shot.jpg", (1920, 1080))
|
||||
|
||||
result = extract_image_dimensions(str(f), mime_type="image/jpeg")
|
||||
|
||||
assert result == {"kind": "image", "width": 1920, "height": 1080}
|
||||
|
||||
def test_works_when_mime_type_is_none(self, tmp_path: Path):
|
||||
f = _make_png(tmp_path / "no_mime.png", (50, 100))
|
||||
|
||||
result = extract_image_dimensions(str(f), mime_type=None)
|
||||
|
||||
assert result == {"kind": "image", "width": 50, "height": 100}
|
||||
|
||||
def test_skips_non_image_mime_without_touching_file(self, tmp_path: Path):
|
||||
# Path doesn't need to exist — non-image MIME short-circuits.
|
||||
result = extract_image_dimensions(
|
||||
str(tmp_path / "model.safetensors"),
|
||||
mime_type="application/octet-stream",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mime",
|
||||
["application/json", "text/plain", "video/mp4", "audio/mpeg"],
|
||||
)
|
||||
def test_skips_all_non_image_mime_types(self, tmp_path: Path, mime: str):
|
||||
f = tmp_path / "file.bin"
|
||||
f.write_bytes(b"\x00\x01\x02")
|
||||
|
||||
assert extract_image_dimensions(str(f), mime_type=mime) is None
|
||||
|
||||
def test_returns_none_for_missing_file(self, tmp_path: Path):
|
||||
result = extract_image_dimensions(
|
||||
str(tmp_path / "does_not_exist.png"), mime_type="image/png"
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_corrupt_image(self, tmp_path: Path):
|
||||
f = tmp_path / "corrupt.png"
|
||||
f.write_bytes(b"not actually a png file")
|
||||
|
||||
result = extract_image_dimensions(str(f), mime_type="image/png")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_empty_file(self, tmp_path: Path):
|
||||
f = tmp_path / "empty.png"
|
||||
f.write_bytes(b"")
|
||||
|
||||
result = extract_image_dimensions(str(f), mime_type="image/png")
|
||||
|
||||
assert result is None
|
||||
@ -4,10 +4,12 @@ from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from sqlalchemy.orm import Session as SASession, Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
|
||||
from app.assets.database.queries import get_reference_tags
|
||||
from app.assets.helpers import get_utc_now
|
||||
from app.assets.services.ingest import (
|
||||
_ingest_file_from_path,
|
||||
_register_existing_asset,
|
||||
@ -15,6 +17,11 @@ from app.assets.services.ingest import (
|
||||
)
|
||||
|
||||
|
||||
def _make_png(path: Path, size: tuple[int, int]) -> Path:
|
||||
Image.new("RGB", size, color=(80, 120, 200)).save(path, format="PNG")
|
||||
return path
|
||||
|
||||
|
||||
class TestIngestFileFromPath:
|
||||
def test_creates_asset_and_reference(self, mock_create_session, temp_dir: Path, session: Session):
|
||||
file_path = temp_dir / "test_file.bin"
|
||||
@ -279,4 +286,203 @@ class TestIngestExistingFileTagFK:
|
||||
ref_tags = sess.query(AssetReferenceTag).all()
|
||||
ref_tag_names = {rt.tag_name for rt in ref_tags}
|
||||
assert "output" in ref_tag_names
|
||||
assert "my-job" in ref_tag_names
|
||||
|
||||
|
||||
class TestIngestImageDimensions:
|
||||
"""system_metadata should carry {kind, width, height} for image assets."""
|
||||
|
||||
def test_image_asset_emits_dimensions(
|
||||
self, mock_create_session, temp_dir: Path, session: Session
|
||||
):
|
||||
f = _make_png(temp_dir / "shot.png", (640, 480))
|
||||
|
||||
result = _ingest_file_from_path(
|
||||
abs_path=str(f),
|
||||
asset_hash="blake3:img1",
|
||||
size_bytes=f.stat().st_size,
|
||||
mtime_ns=1234567890000000000,
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
|
||||
assert ref.system_metadata == {
|
||||
"kind": "image",
|
||||
"width": 640,
|
||||
"height": 480,
|
||||
}
|
||||
|
||||
def test_non_image_asset_leaves_system_metadata_empty(
|
||||
self, mock_create_session, temp_dir: Path, session: Session
|
||||
):
|
||||
f = temp_dir / "model.safetensors"
|
||||
f.write_bytes(b"not an image")
|
||||
|
||||
result = _ingest_file_from_path(
|
||||
abs_path=str(f),
|
||||
asset_hash="blake3:safetensors1",
|
||||
size_bytes=f.stat().st_size,
|
||||
mtime_ns=1234567890000000000,
|
||||
mime_type="application/octet-stream",
|
||||
)
|
||||
|
||||
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
|
||||
assert ref.system_metadata in (None, {})
|
||||
|
||||
def test_preserves_existing_system_metadata_keys(
|
||||
self, mock_create_session, temp_dir: Path, session: Session
|
||||
):
|
||||
f = _make_png(temp_dir / "annotated.png", (100, 200))
|
||||
|
||||
# First pass populates a sentinel system_metadata key (simulating prior
|
||||
# enricher write).
|
||||
result = _ingest_file_from_path(
|
||||
abs_path=str(f),
|
||||
asset_hash="blake3:img-merge",
|
||||
size_bytes=f.stat().st_size,
|
||||
mtime_ns=1234567890000000000,
|
||||
mime_type="image/png",
|
||||
)
|
||||
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
|
||||
ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/x.png"}
|
||||
session.commit()
|
||||
|
||||
# Second pass with the same path triggers the merge code path again.
|
||||
_ingest_file_from_path(
|
||||
abs_path=str(f),
|
||||
asset_hash="blake3:img-merge",
|
||||
size_bytes=f.stat().st_size,
|
||||
mtime_ns=1234567890000000001,
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
session.refresh(ref)
|
||||
assert ref.system_metadata["kind"] == "image"
|
||||
assert ref.system_metadata["width"] == 100
|
||||
assert ref.system_metadata["height"] == 200
|
||||
assert ref.system_metadata["source_url"] == "https://example/x.png"
|
||||
|
||||
|
||||
class TestRegisterExistingAssetBackfill:
|
||||
"""The from-hash path back-fills dimensions from a sibling reference."""
|
||||
|
||||
def _add_reference(
|
||||
self,
|
||||
session: Session,
|
||||
asset: Asset,
|
||||
name: str,
|
||||
system_metadata: dict | None = None,
|
||||
) -> AssetReference:
|
||||
now = get_utc_now()
|
||||
ref = AssetReference(
|
||||
asset_id=asset.id,
|
||||
name=name,
|
||||
owner_id="",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
system_metadata=system_metadata or {},
|
||||
)
|
||||
session.add(ref)
|
||||
session.flush()
|
||||
return ref
|
||||
|
||||
def test_backfills_dimensions_from_sibling_image_reference(
|
||||
self, mock_create_session, session: Session
|
||||
):
|
||||
asset = Asset(hash="blake3:shared", size_bytes=2048, mime_type="image/png")
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
self._add_reference(
|
||||
session,
|
||||
asset,
|
||||
name="original.png",
|
||||
system_metadata={"kind": "image", "width": 800, "height": 600},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
result = _register_existing_asset(
|
||||
asset_hash="blake3:shared",
|
||||
name="from_hash.png",
|
||||
owner_id="user-x",
|
||||
)
|
||||
|
||||
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
|
||||
assert ref.system_metadata.get("kind") == "image"
|
||||
assert ref.system_metadata.get("width") == 800
|
||||
assert ref.system_metadata.get("height") == 600
|
||||
|
||||
def test_no_backfill_when_sibling_has_no_image_metadata(
|
||||
self, mock_create_session, session: Session
|
||||
):
|
||||
asset = Asset(hash="blake3:nodims", size_bytes=2048, mime_type="image/png")
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
self._add_reference(
|
||||
session,
|
||||
asset,
|
||||
name="original.png",
|
||||
system_metadata={"base_model": "flux"}, # no kind=image
|
||||
)
|
||||
session.commit()
|
||||
|
||||
result = _register_existing_asset(
|
||||
asset_hash="blake3:nodims",
|
||||
name="from_hash.png",
|
||||
owner_id="user-x",
|
||||
)
|
||||
|
||||
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
|
||||
meta = ref.system_metadata or {}
|
||||
assert "kind" not in meta
|
||||
assert "width" not in meta
|
||||
assert "height" not in meta
|
||||
|
||||
def test_no_backfill_when_no_sibling_exists(
|
||||
self, mock_create_session, session: Session
|
||||
):
|
||||
asset = Asset(hash="blake3:lonely", size_bytes=1024, mime_type="image/png")
|
||||
session.add(asset)
|
||||
session.commit()
|
||||
|
||||
result = _register_existing_asset(
|
||||
asset_hash="blake3:lonely",
|
||||
name="solo.png",
|
||||
owner_id="user-x",
|
||||
)
|
||||
|
||||
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
|
||||
assert ref.system_metadata in (None, {})
|
||||
|
||||
def test_backfill_preserves_caller_supplied_keys(
|
||||
self, mock_create_session, session: Session
|
||||
):
|
||||
asset = Asset(hash="blake3:preserve", size_bytes=2048, mime_type="image/png")
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
self._add_reference(
|
||||
session,
|
||||
asset,
|
||||
name="original.png",
|
||||
system_metadata={"kind": "image", "width": 1024, "height": 768},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Simulate a from-hash path where the new reference already carries
|
||||
# some system_metadata (e.g. a download-provenance source_url written
|
||||
# by an earlier step). The back-fill must merge dim keys without
|
||||
# clobbering existing keys.
|
||||
result = _register_existing_asset(
|
||||
asset_hash="blake3:preserve",
|
||||
name="from_hash.png",
|
||||
owner_id="user-x",
|
||||
)
|
||||
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
|
||||
# Seed a sentinel key and re-run back-fill via a second register call
|
||||
# to exercise the merge path with pre-existing data.
|
||||
ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/p"}
|
||||
session.commit()
|
||||
|
||||
assert ref.system_metadata.get("source_url") == "https://example/p"
|
||||
assert ref.system_metadata.get("kind") == "image"
|
||||
assert ref.system_metadata.get("width") == 1024
|
||||
assert ref.system_metadata.get("height") == 768
|
||||
|
||||
@ -1,213 +0,0 @@
|
||||
"""Consolidated SeedVR2 conditioning and refactor regression tests.
|
||||
|
||||
Merges the prior test_seedvr2_refactor_nodes.py and
|
||||
test_seedvr_conditioning_hardening.py modules. Refactor tests use the
|
||||
top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests
|
||||
use _import_nodes_seedvr_isolated() for sys.modules isolation when
|
||||
mocking comfy.model_management.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
|
||||
_SENTINEL = object()
|
||||
_TARGETS = (
|
||||
("comfy.model_management", "comfy"),
|
||||
("comfy_extras.nodes_seedvr", "comfy_extras"),
|
||||
)
|
||||
|
||||
|
||||
def _import_nodes_seedvr_isolated():
|
||||
"""Import comfy_extras.nodes_seedvr with comfy.model_management mocked."""
|
||||
priors = []
|
||||
for mod_name, parent_name in _TARGETS:
|
||||
prior_mod = sys.modules.get(mod_name, _SENTINEL)
|
||||
parent = sys.modules.get(parent_name)
|
||||
attr = mod_name.split(".")[-1]
|
||||
prior_attr = (
|
||||
getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL
|
||||
)
|
||||
priors.append((mod_name, parent_name, attr, prior_mod, prior_attr))
|
||||
|
||||
mock_mm = MagicMock()
|
||||
for fn in (
|
||||
"xformers_enabled", "xformers_enabled_vae",
|
||||
"pytorch_attention_enabled", "pytorch_attention_enabled_vae",
|
||||
"sage_attention_enabled", "flash_attention_enabled",
|
||||
"is_intel_xpu",
|
||||
):
|
||||
getattr(mock_mm, fn).return_value = False
|
||||
tv = torch.version.__version__.split(".")
|
||||
mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1]))
|
||||
mock_mm.WINDOWS = False
|
||||
sys.modules["comfy.model_management"] = mock_mm
|
||||
if sys.modules.get("comfy") is None:
|
||||
import comfy as _comfy_pkg # noqa: F401
|
||||
comfy_pkg = sys.modules.get("comfy")
|
||||
if comfy_pkg is not None:
|
||||
setattr(comfy_pkg, "model_management", mock_mm)
|
||||
nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or (
|
||||
importlib.import_module("comfy_extras.nodes_seedvr")
|
||||
)
|
||||
|
||||
def _restore():
|
||||
for mod_name, parent_name, attr, prior_mod, prior_attr in priors:
|
||||
if prior_mod is _SENTINEL:
|
||||
sys.modules.pop(mod_name, None)
|
||||
else:
|
||||
sys.modules[mod_name] = prior_mod
|
||||
parent = sys.modules.get(parent_name)
|
||||
if parent is None:
|
||||
continue
|
||||
if prior_attr is _SENTINEL:
|
||||
if hasattr(parent, attr):
|
||||
delattr(parent, attr)
|
||||
else:
|
||||
setattr(parent, attr, prior_attr)
|
||||
|
||||
return nodes_seedvr, _restore
|
||||
|
||||
|
||||
class _Rope(nn.Module):
|
||||
"""Minimal RoPE stub exposing a `freqs` parameter."""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.freqs = nn.Parameter(torch.zeros(4))
|
||||
|
||||
|
||||
class _Block(nn.Module):
|
||||
"""Minimal transformer block stub holding a `_Rope`."""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.rope = _Rope()
|
||||
|
||||
|
||||
class _DiffusionModel(nn.Module):
|
||||
"""Stub diffusion model with N blocks and pos/neg conditioning buffers."""
|
||||
def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])
|
||||
pos = torch.zeros if zero_conditioning else torch.ones
|
||||
self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype))
|
||||
self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype))
|
||||
|
||||
|
||||
class _ModelInner:
|
||||
"""Inner model wrapper exposing `.diffusion_model`."""
|
||||
def __init__(self, diffusion_model):
|
||||
self.diffusion_model = diffusion_model
|
||||
|
||||
|
||||
class _ModelPatcher:
|
||||
"""ModelPatcher stub exposing `.model._ModelInner`."""
|
||||
def __init__(self, diffusion_model):
|
||||
self.model = _ModelInner(diffusion_model)
|
||||
|
||||
|
||||
def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
|
||||
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||
try:
|
||||
schema = nodes_seedvr.SeedVR2Conditioning.define_schema()
|
||||
assert [input_item.id for input_item in schema.inputs] == [
|
||||
"model",
|
||||
"vae_conditioning",
|
||||
]
|
||||
assert schema.inputs[1].display_name == "latent"
|
||||
assert [output.display_name for output in schema.outputs] == [
|
||||
"model",
|
||||
"positive",
|
||||
"negative",
|
||||
"latent",
|
||||
]
|
||||
finally:
|
||||
restore()
|
||||
|
||||
|
||||
def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
|
||||
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||
try:
|
||||
diffusion_model = _DiffusionModel()
|
||||
patcher = _ModelPatcher(diffusion_model)
|
||||
samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2)
|
||||
vae_conditioning = {"samples": samples}
|
||||
|
||||
_, first_positive, first_negative, first_latent = (
|
||||
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||
patcher,
|
||||
vae_conditioning,
|
||||
)
|
||||
)
|
||||
_, second_positive, second_negative, second_latent = (
|
||||
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||
patcher,
|
||||
vae_conditioning,
|
||||
)
|
||||
)
|
||||
|
||||
expected_latent = samples.reshape(1, 6, 2, 2)
|
||||
channel_last = samples.movedim(1, -1).contiguous()
|
||||
expected_condition = torch.cat(
|
||||
[
|
||||
channel_last,
|
||||
torch.ones((*channel_last.shape[:-1], 1)),
|
||||
],
|
||||
dim=-1,
|
||||
).movedim(-1, 1).reshape(1, 9, 2, 2)
|
||||
|
||||
assert torch.equal(first_latent["samples"], expected_latent)
|
||||
assert torch.equal(second_latent["samples"], expected_latent)
|
||||
assert torch.equal(
|
||||
first_positive[0][1]["condition"],
|
||||
expected_condition,
|
||||
)
|
||||
assert torch.equal(
|
||||
second_positive[0][1]["condition"],
|
||||
expected_condition,
|
||||
)
|
||||
assert torch.equal(
|
||||
first_negative[0][1]["condition"],
|
||||
expected_condition,
|
||||
)
|
||||
assert torch.equal(
|
||||
second_negative[0][1]["condition"],
|
||||
expected_condition,
|
||||
)
|
||||
finally:
|
||||
restore()
|
||||
|
||||
|
||||
def test_seedvr2_conditioning_fails_loud_on_zero_buffers():
|
||||
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||
try:
|
||||
diffusion_model = _DiffusionModel(zero_conditioning=True)
|
||||
patcher = _ModelPatcher(diffusion_model)
|
||||
vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}
|
||||
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||
patcher, vae_conditioning,
|
||||
)
|
||||
|
||||
message = str(excinfo.value)
|
||||
assert message.startswith(
|
||||
nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX
|
||||
), (
|
||||
"Fail-loud message must use the standard "
|
||||
"_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers "
|
||||
f"can match it. Got: {message!r}"
|
||||
)
|
||||
assert "positive_conditioning" in message
|
||||
assert "negative_conditioning" in message
|
||||
finally:
|
||||
restore()
|
||||
@ -1,55 +0,0 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
|
||||
def test_seedvr_node_signature_matches_schema():
|
||||
mock_mm = MagicMock()
|
||||
mock_mm.xformers_enabled.return_value = False
|
||||
mock_mm.xformers_enabled_vae.return_value = False
|
||||
mock_mm.sage_attention_enabled.return_value = False
|
||||
mock_mm.flash_attention_enabled.return_value = False
|
||||
|
||||
sentinel = object()
|
||||
prior_cpu = cli_args.cpu
|
||||
cli_args.cpu = True
|
||||
prior_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel)
|
||||
comfy_pkg = sys.modules.get("comfy")
|
||||
prior_mm_attr = getattr(comfy_pkg, "model_management", sentinel) if comfy_pkg else sentinel
|
||||
|
||||
with patch.dict(sys.modules, {"comfy.model_management": mock_mm}):
|
||||
if comfy_pkg is not None:
|
||||
setattr(comfy_pkg, "model_management", mock_mm)
|
||||
sys.modules.pop("comfy_extras.nodes_seedvr", None)
|
||||
try:
|
||||
nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
|
||||
for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler):
|
||||
schema_ids = [i.id for i in node_cls.define_schema().inputs]
|
||||
exec_params = [
|
||||
p for p in inspect.signature(node_cls.execute).parameters.keys()
|
||||
if p != "cls"
|
||||
]
|
||||
assert schema_ids == exec_params, (
|
||||
f"{node_cls.__name__} schema/execute drift: "
|
||||
f"schema_ids={schema_ids}, exec_params={exec_params}"
|
||||
)
|
||||
finally:
|
||||
cli_args.cpu = prior_cpu
|
||||
if prior_module is sentinel:
|
||||
sys.modules.pop("comfy_extras.nodes_seedvr", None)
|
||||
else:
|
||||
sys.modules["comfy_extras.nodes_seedvr"] = prior_module
|
||||
if comfy_pkg is not None:
|
||||
if prior_mm_attr is sentinel:
|
||||
if hasattr(comfy_pkg, "model_management"):
|
||||
delattr(comfy_pkg, "model_management")
|
||||
else:
|
||||
setattr(comfy_pkg, "model_management", prior_mm_attr)
|
||||
@ -1,57 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
from comfy_extras import nodes_seedvr # noqa: E402
|
||||
|
||||
|
||||
def _schema_ids(items):
|
||||
return [item.id for item in items]
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_schema():
|
||||
schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()
|
||||
|
||||
assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"]
|
||||
assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"]
|
||||
assert schema.inputs[2].default == "lab"
|
||||
assert schema.outputs[0].get_io_type() == "IMAGE"
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch):
|
||||
decoded = torch.full((1, 3, 4, 4), 0.25)
|
||||
reference = torch.full((1, 3, 4, 4), 0.75)
|
||||
|
||||
def _lab(content, style):
|
||||
raise torch.cuda.OutOfMemoryError("CUDA out of memory")
|
||||
|
||||
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
|
||||
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000)
|
||||
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None)
|
||||
|
||||
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
|
||||
try:
|
||||
nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
|
||||
decoded, reference, torch.device("cpu"), "lab",
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
assert "color_correction_method=lab" in str(exc)
|
||||
assert " method=lab" not in str(exc)
|
||||
else:
|
||||
raise AssertionError("expected RuntimeError for one-frame LAB OOM")
|
||||
|
||||
|
||||
def test_seedvr2_post_processing_unknown_color_correction_method_raises():
|
||||
decoded = torch.zeros(1, 2, 4, 4, 3)
|
||||
original = torch.zeros(1, 2, 4, 4, 3)
|
||||
try:
|
||||
nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus")
|
||||
except ValueError as exc:
|
||||
assert "color_correction_method" in str(exc)
|
||||
else:
|
||||
raise AssertionError("expected ValueError for unknown color_correction_method")
|
||||
@ -73,24 +73,6 @@ def _make_flux_schnell_comfyui_sd():
|
||||
return sd
|
||||
|
||||
|
||||
def _make_seedvr2_7b_separate_mm_sd():
|
||||
return {
|
||||
"blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072),
|
||||
}
|
||||
|
||||
|
||||
def _make_seedvr2_7b_shared_mm_sd():
|
||||
return {
|
||||
"blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
|
||||
}
|
||||
|
||||
|
||||
def _make_seedvr2_3b_shared_mm_sd():
|
||||
return {
|
||||
"blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
|
||||
}
|
||||
|
||||
|
||||
class TestModelDetection:
|
||||
"""Verify that first-match model detection selects the correct model
|
||||
based on list ordering and unet_config specificity."""
|
||||
@ -143,48 +125,6 @@ class TestModelDetection:
|
||||
assert model_config is not None
|
||||
assert type(model_config).__name__ == "FluxSchnell"
|
||||
|
||||
def test_seedvr2_7b_separate_mm_detection_config(self):
|
||||
sd = _make_seedvr2_7b_separate_mm_sd()
|
||||
unet_config = detect_unet_config(sd, "")
|
||||
|
||||
assert unet_config is not None
|
||||
assert unet_config["image_model"] == "seedvr2"
|
||||
assert unet_config["vid_dim"] == 3072
|
||||
assert unet_config["heads"] == 24
|
||||
assert unet_config["num_layers"] == 36
|
||||
assert unet_config["mm_layers"] == 36
|
||||
assert unet_config["mlp_type"] == "normal"
|
||||
assert unet_config["qk_rope"] is True
|
||||
assert unet_config["rope_type"] == "rope3d"
|
||||
assert unet_config["rope_dim"] == 64
|
||||
|
||||
def test_seedvr2_7b_shared_mm_detection_config(self):
|
||||
sd = _make_seedvr2_7b_shared_mm_sd()
|
||||
unet_config = detect_unet_config(sd, "")
|
||||
|
||||
assert unet_config is not None
|
||||
assert unet_config["image_model"] == "seedvr2"
|
||||
assert unet_config["vid_dim"] == 3072
|
||||
assert unet_config["heads"] == 24
|
||||
assert unet_config["num_layers"] == 36
|
||||
assert unet_config["mm_layers"] == 10
|
||||
assert unet_config["mlp_type"] == "swiglu"
|
||||
assert unet_config["qk_rope"] is True
|
||||
assert unet_config["rope_type"] == "rope3d"
|
||||
assert unet_config["rope_dim"] == 64
|
||||
|
||||
def test_seedvr2_3b_shared_mm_detection_config(self):
|
||||
sd = _make_seedvr2_3b_shared_mm_sd()
|
||||
unet_config = detect_unet_config(sd, "")
|
||||
|
||||
assert unet_config is not None
|
||||
assert unet_config["image_model"] == "seedvr2"
|
||||
assert unet_config["vid_dim"] == 2560
|
||||
assert unet_config["heads"] == 20
|
||||
assert unet_config["num_layers"] == 32
|
||||
assert unet_config["mlp_type"] == "swiglu"
|
||||
assert unet_config["qk_rope"] is None
|
||||
|
||||
def test_unet_config_and_required_keys_combination_is_unique(self):
|
||||
"""Each model in the registry must have a unique combination of
|
||||
``unet_config`` and ``required_keys``. If two models share the same
|
||||
|
||||
@ -1,90 +0,0 @@
|
||||
"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must
|
||||
honor the actual tensor/tuple return contract of ``encode()`` and
|
||||
``decode_()`` and must NOT dereference diffusers-style ``.latent_dist``
|
||||
or ``.sample`` attributes on those returns.
|
||||
|
||||
The pre-fix body raised ``AttributeError: 'Tensor' object has no
|
||||
attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and
|
||||
``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'``
|
||||
for ``mode == "decode"`` (the class only defines ``decode_`` with a
|
||||
trailing underscore). The post-fix body unwraps the optional one-element
|
||||
tuple shape that ``return_dict=False`` produces and returns the tensor
|
||||
directly.
|
||||
|
||||
Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses
|
||||
the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and
|
||||
overrides ``encode``/``decode_`` with known tensors so the contract can
|
||||
be probed without loading any real VAE weights.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402
|
||||
|
||||
|
||||
_LATENT_SHAPE = (1, 16, 2, 2, 2)
|
||||
_DECODED_SHAPE = (1, 3, 5, 16, 16)
|
||||
_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16)
|
||||
_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2)
|
||||
|
||||
|
||||
class _StubVAE(VideoAutoencoderKL):
|
||||
def __init__(self):
|
||||
nn.Module.__init__(self)
|
||||
self._encode_out = torch.zeros(*_LATENT_SHAPE)
|
||||
self._decode_out = torch.zeros(*_DECODED_SHAPE)
|
||||
|
||||
def encode(self, x, return_dict=True):
|
||||
return self._encode_out
|
||||
|
||||
def decode_(self, z, return_dict=True):
|
||||
return self._decode_out
|
||||
|
||||
|
||||
def test_forward_encode_returns_tensor():
|
||||
vae = _StubVAE()
|
||||
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
|
||||
result = vae.forward(x, mode="encode")
|
||||
assert type(result) is torch.Tensor
|
||||
assert result.shape == torch.Size(_LATENT_SHAPE)
|
||||
|
||||
|
||||
def test_forward_decode_returns_tensor():
|
||||
vae = _StubVAE()
|
||||
z = torch.zeros(*_INPUT_DECODE_SHAPE)
|
||||
result = vae.forward(z, mode="decode")
|
||||
assert type(result) is torch.Tensor
|
||||
assert result.shape == torch.Size(_DECODED_SHAPE)
|
||||
|
||||
|
||||
class _TupleReturningStubVAE(VideoAutoencoderKL):
|
||||
"""Stub variant whose ``encode``/``decode_`` return the
|
||||
``(tensor,)`` one-element tuple shape ``return_dict=False`` produces
|
||||
in the parent class. Exercises the unwrap branch of
|
||||
``VideoAutoencoderKL.forward``.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
nn.Module.__init__(self)
|
||||
self._encode_tensor = torch.zeros(*_LATENT_SHAPE)
|
||||
self._decode_tensor = torch.zeros(*_DECODED_SHAPE)
|
||||
|
||||
def encode(self, x, return_dict=True):
|
||||
return (self._encode_tensor,)
|
||||
|
||||
def decode_(self, z, return_dict=True):
|
||||
return (self._decode_tensor,)
|
||||
|
||||
|
||||
def test_forward_all_unwraps_one_tuple_at_each_step():
|
||||
vae = _TupleReturningStubVAE()
|
||||
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
|
||||
result = vae.forward(x, mode="all")
|
||||
assert type(result) is torch.Tensor
|
||||
assert result.shape == torch.Size(_DECODED_SHAPE)
|
||||
@ -1,47 +0,0 @@
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.sd
|
||||
import comfy.supported_models
|
||||
import comfy.ldm.seedvr.model as seedvr_model
|
||||
|
||||
|
||||
def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch):
|
||||
bf16_device = object()
|
||||
fp16_device = object()
|
||||
|
||||
monkeypatch.setattr(
|
||||
comfy.supported_models.comfy.model_management,
|
||||
"should_use_bf16",
|
||||
lambda device=None: device is bf16_device,
|
||||
)
|
||||
|
||||
bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
|
||||
bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device)
|
||||
assert bf16_config.manual_cast_dtype is torch.bfloat16
|
||||
|
||||
fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
|
||||
fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device)
|
||||
assert fp16_config.manual_cast_dtype is None
|
||||
|
||||
|
||||
def test_seedvr2_text_conditioning_accepts_cfg1_single_branch():
|
||||
context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2)
|
||||
|
||||
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0])
|
||||
|
||||
torch.testing.assert_close(txt, context.squeeze(0))
|
||||
torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device))
|
||||
|
||||
|
||||
def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer():
|
||||
estimate = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160))
|
||||
old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2
|
||||
|
||||
assert estimate == 101 * 960 * 1280 * 160
|
||||
assert estimate > 15 * 1024 ** 3
|
||||
assert estimate > old_estimate * 100
|
||||
@ -1,341 +0,0 @@
|
||||
"""Consolidated SeedVR2 internals regression tests.
|
||||
|
||||
Sources (all merged verbatim, helper names disambiguated where colliding):
|
||||
|
||||
* RoPE rewrite — NaMMRotaryEmbedding3d.forward must match the legacy
|
||||
apply_rotary_emb wrapper oracle at fp32.
|
||||
* GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare
|
||||
memory_occupy against get_norm_limit(), not float('inf').
|
||||
* SeedVR2 variable-length attention split-loop contract.
|
||||
|
||||
Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
|
||||
comfy.ldm.modules.attention transitively pull in comfy.model_management,
|
||||
which probes torch.cuda.current_device() at import time unless args.cpu is
|
||||
set first.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
args.cpu = True
|
||||
|
||||
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
||||
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||
import comfy.ldm.modules.attention as attention # noqa: E402
|
||||
import comfy.ops as comfy_ops # noqa: E402
|
||||
from comfy.ldm.seedvr.model import ( # noqa: E402
|
||||
Cache,
|
||||
NaMMRotaryEmbedding3d,
|
||||
)
|
||||
from comfy.ldm.seedvr.vae import ( # noqa: E402
|
||||
causal_norm_wrapper,
|
||||
set_norm_limit,
|
||||
)
|
||||
from comfy.ldm.modules.attention import var_attention_optimized_split # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RoPE rewrite tests (test_seedvr_rope_rewrite.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Test rig dimensions. dim=192 → per-axis rope dim = 64 (even, lucidrains
|
||||
# requirement). vid_shape=(2,4,4) → L_vid = 32. txt_shape=(8,) → L_txt = 8.
|
||||
_DIM = 192
|
||||
_HEADS = 4
|
||||
_VID_T, _VID_H, _VID_W = 2, 4, 4
|
||||
_TXT_L = 8
|
||||
_L_VID = _VID_T * _VID_H * _VID_W
|
||||
_SEED = 0
|
||||
|
||||
|
||||
def _make_inputs(dtype=torch.float32, device="cpu"):
|
||||
"""Construct the 6 forward inputs + cache. Deterministic via local
|
||||
Generator so global RNG state is not mutated.
|
||||
"""
|
||||
g = torch.Generator(device=device).manual_seed(_SEED)
|
||||
vid_q = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
|
||||
vid_k = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
|
||||
txt_q = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
|
||||
txt_k = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g)
|
||||
vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device)
|
||||
txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device)
|
||||
cache = Cache(disable=True)
|
||||
return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache
|
||||
|
||||
|
||||
def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape):
|
||||
"""Reproduce the pre-rewrite ``get_freqs`` body verbatim against
|
||||
``self.get_axial_freqs`` (parent ``RotaryEmbeddingBase`` method,
|
||||
unchanged by the rewrite).
|
||||
"""
|
||||
max_temporal = 0
|
||||
max_height = 0
|
||||
max_width = 0
|
||||
max_txt_len = 0
|
||||
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
|
||||
max_temporal = max(max_temporal, l + f)
|
||||
max_height = max(max_height, h)
|
||||
max_width = max(max_width, w)
|
||||
max_txt_len = max(max_txt_len, l)
|
||||
with torch.amp.autocast(device_type="cuda", enabled=False):
|
||||
vid_freqs_full = rope.get_axial_freqs(
|
||||
min(max_temporal + 16, 1024),
|
||||
min(max_height + 4, 128),
|
||||
min(max_width + 4, 128),
|
||||
).float()
|
||||
txt_freqs_full = rope.get_axial_freqs(min(max_txt_len + 16, 1024))
|
||||
vid_freq_list, txt_freq_list = [], []
|
||||
for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()):
|
||||
vid_freq = vid_freqs_full[l : l + f, :h, :w].reshape(-1, vid_freqs_full.size(-1))
|
||||
txt_freq = txt_freqs_full[:l].repeat(1, 3).reshape(-1, vid_freqs_full.size(-1))
|
||||
vid_freq_list.append(vid_freq)
|
||||
txt_freq_list.append(txt_freq)
|
||||
return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0)
|
||||
|
||||
|
||||
def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape,
|
||||
txt_q, txt_k, txt_shape):
|
||||
"""Compute expected forward output via the unchanged
|
||||
``apply_rotary_emb`` wrapper fed with legacy-shape freqs. This is the
|
||||
oracle. The wrapper itself is out of scope for the rewrite (Shape B).
|
||||
"""
|
||||
vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape)
|
||||
vid_freqs = vid_freqs.to(vid_q.device)
|
||||
txt_freqs = txt_freqs.to(txt_q.device)
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
vid_q = rearrange(vid_q, "L h d -> h L d")
|
||||
vid_k = rearrange(vid_k, "L h d -> h L d")
|
||||
vid_q_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype)
|
||||
vid_k_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype)
|
||||
vid_q_out = rearrange(vid_q_out, "h L d -> L h d")
|
||||
vid_k_out = rearrange(vid_k_out, "h L d -> L h d")
|
||||
|
||||
txt_q = rearrange(txt_q, "L h d -> h L d")
|
||||
txt_k = rearrange(txt_k, "L h d -> h L d")
|
||||
txt_q_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype)
|
||||
txt_k_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype)
|
||||
txt_q_out = rearrange(txt_q_out, "h L d -> L h d")
|
||||
txt_k_out = rearrange(txt_k_out, "h L d -> L h d")
|
||||
return vid_q_out, vid_k_out, txt_q_out, txt_k_out
|
||||
|
||||
|
||||
def test_namm_forward_output_tensor_equal_against_legacy_oracle():
|
||||
rope = NaMMRotaryEmbedding3d(dim=_DIM)
|
||||
vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs()
|
||||
|
||||
expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward(
|
||||
rope,
|
||||
vid_q.clone(), vid_k.clone(), vid_shape,
|
||||
txt_q.clone(), txt_k.clone(), txt_shape,
|
||||
)
|
||||
|
||||
actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward(
|
||||
vid_q.clone(), vid_k.clone(), vid_shape,
|
||||
txt_q.clone(), txt_k.clone(), txt_shape, cache,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0,
|
||||
msg="vid_q output diverges from wrapper oracle")
|
||||
torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0,
|
||||
msg="vid_k output diverges from wrapper oracle")
|
||||
torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0,
|
||||
msg="txt_q output diverges from wrapper oracle")
|
||||
torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0,
|
||||
msg="txt_k output diverges from wrapper oracle")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GroupNorm limit tests (test_seedvr_groupnorm_limit.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_NUM_CHANNELS = 8
|
||||
_NUM_GROUPS = 4
|
||||
_TENSOR_SHAPE = (1, 8, 2, 4, 4)
|
||||
|
||||
_GROUPNORM_SUBCLASSES = [
|
||||
pytest.param(comfy_ops.disable_weight_init.GroupNorm, id="disable_weight_init"),
|
||||
pytest.param(comfy_ops.manual_cast.GroupNorm, id="manual_cast"),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES)
|
||||
def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
|
||||
real_group_norm = vae_mod.F.group_norm
|
||||
set_norm_limit(1e-9)
|
||||
try:
|
||||
gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS)
|
||||
gn.eval()
|
||||
|
||||
forward_hook_calls = []
|
||||
|
||||
def _hook(module, inputs, output):
|
||||
forward_hook_calls.append(tuple(inputs[0].shape))
|
||||
|
||||
spy_calls = []
|
||||
|
||||
def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs):
|
||||
spy_calls.append({"num_groups": int(num_groups_arg)})
|
||||
return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs)
|
||||
|
||||
handle = gn.register_forward_hook(_hook)
|
||||
try:
|
||||
with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy):
|
||||
out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE))
|
||||
finally:
|
||||
handle.remove()
|
||||
|
||||
full_calls = len(forward_hook_calls)
|
||||
chunked_calls = sum(1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS)
|
||||
|
||||
assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE
|
||||
assert full_calls == 0, (
|
||||
f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}"
|
||||
)
|
||||
assert chunked_calls > 0, (
|
||||
f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}"
|
||||
)
|
||||
finally:
|
||||
set_norm_limit(None)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SeedVR2 var_attention split-loop tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_var_attention_registry_contains_always_available_entries():
|
||||
assert (
|
||||
attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_optimized_split"]
|
||||
is attention.var_attention_optimized_split
|
||||
)
|
||||
|
||||
|
||||
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
|
||||
dim = 8
|
||||
heads = 2
|
||||
head_dim = 4
|
||||
attn = seedvr_model.NaSwinAttention(
|
||||
vid_dim=dim,
|
||||
txt_dim=dim,
|
||||
heads=heads,
|
||||
head_dim=head_dim,
|
||||
qk_bias=False,
|
||||
qk_norm=seedvr_model.CustomRMSNorm,
|
||||
qk_norm_eps=1e-6,
|
||||
rope_type=None,
|
||||
rope_dim=head_dim,
|
||||
shared_weights=False,
|
||||
window=(2, 1, 1),
|
||||
window_method="720pwin_by_size_bysize",
|
||||
version=True,
|
||||
device="cpu",
|
||||
dtype=torch.float32,
|
||||
operations=comfy_ops.disable_weight_init,
|
||||
)
|
||||
generator = torch.Generator(device="cpu").manual_seed(11)
|
||||
vid = torch.randn(8, dim, generator=generator)
|
||||
txt = torch.randn(3, dim, generator=generator)
|
||||
vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long)
|
||||
txt_shape = torch.tensor([[3]], dtype=torch.long)
|
||||
calls = []
|
||||
|
||||
def fake_optimized_var_attention(**kwargs):
|
||||
calls.append(kwargs)
|
||||
return kwargs["q"]
|
||||
|
||||
monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention)
|
||||
|
||||
vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True))
|
||||
|
||||
assert tuple(vid_out.shape) == (8, dim)
|
||||
assert tuple(txt_out.shape) == (3, dim)
|
||||
assert len(calls) == 1
|
||||
call = calls[0]
|
||||
assert tuple(call["q"].shape) == (14, heads, head_dim)
|
||||
assert tuple(call["k"].shape) == (14, heads, head_dim)
|
||||
assert tuple(call["v"].shape) == (14, heads, head_dim)
|
||||
assert call["heads"] == heads
|
||||
assert call["skip_reshape"] is True
|
||||
assert call["skip_output_reshape"] is True
|
||||
torch.testing.assert_close(
|
||||
call["cu_seqlens_q"],
|
||||
torch.tensor([0, 7, 14], dtype=torch.int32),
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
call["cu_seqlens_k"],
|
||||
torch.tensor([0, 7, 14], dtype=torch.int32),
|
||||
rtol=0,
|
||||
atol=0,
|
||||
)
|
||||
|
||||
|
||||
def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch):
|
||||
heads = 2
|
||||
head_dim = 3
|
||||
q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim)
|
||||
k = q + 100
|
||||
v = q + 200
|
||||
cu = torch.tensor([0, 2, 5], dtype=torch.int32)
|
||||
calls = []
|
||||
|
||||
def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs):
|
||||
calls.append(
|
||||
{
|
||||
"q_shape": tuple(q_arg.shape),
|
||||
"k_shape": tuple(k_arg.shape),
|
||||
"v_shape": tuple(v_arg.shape),
|
||||
"heads": heads_arg,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
)
|
||||
return q_arg + v_arg
|
||||
|
||||
monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention)
|
||||
|
||||
out = var_attention_optimized_split(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
heads,
|
||||
cu,
|
||||
cu,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=True,
|
||||
)
|
||||
|
||||
assert tuple(out.shape) == (5, heads, head_dim)
|
||||
assert len(calls) == 2
|
||||
assert calls[0]["q_shape"] == (1, heads, 2, head_dim)
|
||||
assert calls[1]["q_shape"] == (1, heads, 3, head_dim)
|
||||
assert all(call["heads"] == heads for call in calls)
|
||||
assert all(call["kwargs"]["skip_reshape"] is True for call in calls)
|
||||
assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls)
|
||||
torch.testing.assert_close(out, q + v, rtol=0, atol=0)
|
||||
|
||||
|
||||
def test_var_attention_optimized_split_rejects_bad_offsets():
|
||||
q = torch.randn(5, 2, 3)
|
||||
cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32)
|
||||
cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32)
|
||||
|
||||
with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"):
|
||||
var_attention_optimized_split(
|
||||
q,
|
||||
q,
|
||||
q,
|
||||
2,
|
||||
cu_bad,
|
||||
cu_ok,
|
||||
skip_reshape=True,
|
||||
skip_output_reshape=True,
|
||||
)
|
||||
@ -1,308 +0,0 @@
|
||||
"""Consolidated SeedVR2 model/graph/forward regression tests.
|
||||
|
||||
Merged from:
|
||||
- seedvr_model_test.py
|
||||
- test_seedvr_7b_final_block_text_path.py
|
||||
- test_seedvr_forward_no_device_cast.py
|
||||
- test_seedvr_latent_format.py
|
||||
- test_seedvr2_vae_graph_boundaries.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
args.cpu = True
|
||||
|
||||
import comfy # noqa: E402
|
||||
import comfy.latent_formats # noqa: E402
|
||||
import comfy.ldm.seedvr.model # noqa: E402
|
||||
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
||||
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
||||
import comfy.model_management # noqa: E402
|
||||
import comfy.sample # noqa: E402
|
||||
import comfy.sd as sd_mod # noqa: E402
|
||||
import nodes as nodes_mod # noqa: E402
|
||||
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers from seedvr_model_test.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_standin(positive_conditioning):
|
||||
class _StandIn(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.register_buffer(
|
||||
"positive_conditioning", positive_conditioning
|
||||
)
|
||||
|
||||
_resolve_text_conditioning = NaDiT._resolve_text_conditioning
|
||||
|
||||
return _StandIn()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers from test_seedvr_7b_final_block_text_path.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StubModule(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
|
||||
def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]:
|
||||
flags = []
|
||||
|
||||
class _Block(_StubModule):
|
||||
def __init__(self, *args, **kwargs):
|
||||
flags.append(kwargs["is_last_layer"])
|
||||
super().__init__()
|
||||
|
||||
monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule)
|
||||
monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule)
|
||||
monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule)
|
||||
monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block)
|
||||
|
||||
seedvr_model.NaDiT(
|
||||
norm_eps=1e-5,
|
||||
qk_rope=None,
|
||||
num_layers=4,
|
||||
mlp_type="normal",
|
||||
vid_dim=vid_dim,
|
||||
txt_in_dim=txt_in_dim,
|
||||
heads=24,
|
||||
mm_layers=3,
|
||||
)
|
||||
|
||||
return flags
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers from test_seedvr_latent_format.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _Model:
|
||||
def __init__(self, latent_format):
|
||||
self._latent_format = latent_format
|
||||
|
||||
def get_model_object(self, name):
|
||||
assert name == "latent_format"
|
||||
return self._latent_format
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers from test_seedvr2_vae_graph_boundaries.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _Patcher:
|
||||
def get_free_memory(self, device):
|
||||
return 1024 * 1024 * 1024
|
||||
|
||||
|
||||
class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
|
||||
def __init__(self, encoded):
|
||||
nn.Module.__init__(self)
|
||||
self.encoded = encoded
|
||||
self.spatial_downsample_factor = 8
|
||||
self.temporal_downsample_factor = 4
|
||||
self.seen = []
|
||||
|
||||
def encode(self, x):
|
||||
self.seen.append(tuple(x.shape))
|
||||
return self.encoded.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
|
||||
class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
|
||||
def __init__(self):
|
||||
nn.Module.__init__(self)
|
||||
self.spatial_downsample_factor = 8
|
||||
self.temporal_downsample_factor = 4
|
||||
self.calls = []
|
||||
|
||||
def decode(self, z, seedvr2_tiling=None):
|
||||
self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
|
||||
if z.ndim == 4:
|
||||
b, tc, h, w = z.shape
|
||||
t = tc // 16
|
||||
else:
|
||||
b, _, t, h, w = z.shape
|
||||
return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
|
||||
|
||||
|
||||
def _make_vae(wrapper):
|
||||
vae = sd_mod.VAE.__new__(sd_mod.VAE)
|
||||
vae.first_stage_model = wrapper
|
||||
vae.device = torch.device("cpu")
|
||||
vae.output_device = torch.device("cpu")
|
||||
vae.vae_dtype = torch.float32
|
||||
vae.latent_channels = 16
|
||||
vae.latent_dim = 3
|
||||
vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8)
|
||||
vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
vae.output_channels = 3
|
||||
vae.disable_offload = True
|
||||
vae.extra_1d_channel = None
|
||||
vae.crop_input = False
|
||||
vae.not_video = False
|
||||
vae.patcher = _Patcher()
|
||||
vae.process_input = lambda image: image
|
||||
vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0)
|
||||
vae.vae_output_dtype = lambda: torch.float32
|
||||
vae.memory_used_encode = lambda shape, dtype: 1
|
||||
vae.memory_used_decode = lambda shape, dtype: 1
|
||||
vae.throw_exception_if_invalid = lambda: None
|
||||
vae.vae_encode_crop_pixels = lambda pixels: pixels
|
||||
vae.spacial_compression_decode = lambda: 8
|
||||
vae.temporal_compression_decode = lambda: 4
|
||||
return vae
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests from seedvr_model_test.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_missing_context_falls_back_to_positive_buffer():
|
||||
"""AC: ``context is None`` falls back to the registered
|
||||
``positive_conditioning`` buffer and runs to completion — no
|
||||
silent zero substitution, no raised exception.
|
||||
"""
|
||||
pos_buffer = torch.full((58, 5120), 7.0)
|
||||
standin = _make_standin(pos_buffer)
|
||||
txt, txt_shape = standin._resolve_text_conditioning(None)
|
||||
assert txt.shape == (58, 5120)
|
||||
assert (txt == 7.0).all(), (
|
||||
"fallback path must use the positive_conditioning buffer "
|
||||
"verbatim, not a zero tensor"
|
||||
)
|
||||
assert txt_shape.shape == (1, 1)
|
||||
assert txt_shape[0, 0].item() == 58
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests from test_seedvr_7b_final_block_text_path.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
|
||||
assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
]
|
||||
|
||||
|
||||
def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
|
||||
rope = seedvr_model.get_na_rope("rope3d", dim=64)
|
||||
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||
q = torch.randn(4, 2, 128, generator=generator)
|
||||
k = torch.randn(4, 2, 128, generator=generator)
|
||||
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
|
||||
freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)
|
||||
|
||||
expected_q = seedvr_model._apply_seedvr2_rotary_emb(
|
||||
freqs,
|
||||
q.permute(1, 0, 2).float(),
|
||||
).to(q.dtype).permute(1, 0, 2)
|
||||
expected_k = seedvr_model._apply_seedvr2_rotary_emb(
|
||||
freqs,
|
||||
k.permute(1, 0, 2).float(),
|
||||
).to(k.dtype).permute(1, 0, 2)
|
||||
|
||||
actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True))
|
||||
|
||||
torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0)
|
||||
torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests from test_seedvr_latent_format.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion():
|
||||
latent_format = comfy.latent_formats.SeedVR2()
|
||||
latent_image = torch.zeros(1, 1, 4, 5)
|
||||
|
||||
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
|
||||
|
||||
assert latent_format.latent_channels == 16
|
||||
assert latent_format.latent_dimensions == 2
|
||||
assert fixed.shape == (1, 16, 4, 5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests from test_seedvr2_vae_graph_boundaries.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
|
||||
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||
|
||||
encoded = torch.full((1, 16, 2, 4, 5), 2.0)
|
||||
vae = _make_vae(_EncodeWrapper(encoded))
|
||||
pixels = torch.zeros(1, 5, 32, 40, 3)
|
||||
|
||||
node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
|
||||
node_latent = node_output["samples"]
|
||||
assert set(node_output) == {"samples"}
|
||||
assert tuple(node_latent.shape) == (1, 16, 2, 4, 5)
|
||||
assert node_latent.dtype == torch.float32
|
||||
assert node_latent.stride()[-1] == 1
|
||||
assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152))
|
||||
|
||||
tiled = torch.full((1, 16, 2, 4, 5), 3.0)
|
||||
monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
|
||||
tiled_output = nodes_mod.VAEEncodeTiled().encode(
|
||||
vae,
|
||||
pixels,
|
||||
tile_size=512,
|
||||
overlap=64,
|
||||
temporal_size=16,
|
||||
temporal_overlap=4,
|
||||
)[0]
|
||||
tiled_latent = tiled_output["samples"]
|
||||
assert set(tiled_output) == {"samples"}
|
||||
assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5)
|
||||
assert tiled_latent.dtype == torch.float32
|
||||
assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152))
|
||||
|
||||
|
||||
def test_vaedecode_tiled_visible_inputs_are_seedvr2_decode_tiling_authority(monkeypatch):
|
||||
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||
vae = _make_vae(_DecodeWrapper())
|
||||
|
||||
nodes_mod.VAEDecodeTiled().decode(
|
||||
vae,
|
||||
{"samples": torch.zeros(1, 16, 2, 4, 5)},
|
||||
tile_size=512,
|
||||
overlap=64,
|
||||
temporal_size=16,
|
||||
temporal_overlap=4,
|
||||
)
|
||||
|
||||
assert vae.first_stage_model.calls == [
|
||||
{
|
||||
"shape": (1, 16, 2, 4, 5),
|
||||
"seedvr2_tiling": {
|
||||
"enable_tiling": True,
|
||||
"tile_size": (512, 512),
|
||||
"tile_overlap": (64, 64),
|
||||
"temporal_size": 16,
|
||||
"temporal_overlap": 4,
|
||||
},
|
||||
}
|
||||
]
|
||||
@ -1,91 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||
from comfy_extras import nodes_seedvr # noqa: E402
|
||||
|
||||
|
||||
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
|
||||
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
|
||||
vae_mod.VideoAutoencoderKLWrapper
|
||||
)
|
||||
nn.Module.__init__(wrapper)
|
||||
return wrapper
|
||||
|
||||
|
||||
def _fingerprint_decode_(self, z, return_dict=True):
|
||||
b = int(z.shape[0])
|
||||
t = int(z.shape[2])
|
||||
h = int(z.shape[3])
|
||||
w = int(z.shape[4])
|
||||
out = torch.empty(b, 3, t, h * 8, w * 8)
|
||||
for batch_idx in range(b):
|
||||
out[batch_idx].fill_(float(batch_idx + 1))
|
||||
return out
|
||||
|
||||
|
||||
def _decode_with_patches(wrapper, z):
|
||||
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_):
|
||||
return wrapper.decode(z)
|
||||
|
||||
|
||||
def test_decode_b2_t3_multi_frame_batch_unchanged():
|
||||
wrapper = _make_wrapper()
|
||||
|
||||
out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2))
|
||||
|
||||
assert tuple(out.shape) == (2, 3, 3, 16, 16)
|
||||
|
||||
|
||||
class _Wrapper(vae_mod.VideoAutoencoderKLWrapper):
|
||||
def __init__(self):
|
||||
nn.Module.__init__(self)
|
||||
self.calls = []
|
||||
|
||||
def parameters(self):
|
||||
return iter([torch.nn.Parameter(torch.zeros(()))])
|
||||
|
||||
def _decode_stub(self, latent):
|
||||
self.calls.append(tuple(latent.shape))
|
||||
return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8)
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state():
|
||||
wrapper = _Wrapper()
|
||||
|
||||
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
|
||||
out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5))
|
||||
|
||||
assert tuple(out.shape) == (1, 3, 2, 32, 40)
|
||||
assert wrapper.calls == [(1, 16, 2, 4, 5)]
|
||||
|
||||
|
||||
def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents():
|
||||
wrapper = _Wrapper()
|
||||
|
||||
with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"):
|
||||
wrapper.decode(torch.zeros(1, 16, 4))
|
||||
|
||||
|
||||
def _t_padded(t_in: int) -> int:
|
||||
if t_in == 1:
|
||||
return 1
|
||||
if t_in <= 4:
|
||||
return 5
|
||||
if (t_in - 1) % 4 == 0:
|
||||
return t_in
|
||||
return t_in + (4 - ((t_in - 1) % 4))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("t_in", [1, 5, 9])
|
||||
def test_t_padded_matches_cut_videos(t_in):
|
||||
dummy = torch.zeros(1, t_in, 1, 1, 1)
|
||||
assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in)
|
||||
@ -1,347 +0,0 @@
|
||||
from contextlib import ExitStack
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
||||
import comfy.sd as sd_mod # noqa: E402
|
||||
from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_seedvr_vae_tiled_decode_latent_min_size_override.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
|
||||
from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae
|
||||
|
||||
class StubVAEModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.slicing_latent_min_size = 2
|
||||
self.spatial_downsample_factor = 8
|
||||
self.temporal_downsample_factor = 4
|
||||
self.device = torch.device("cpu")
|
||||
self.use_slicing = True
|
||||
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
|
||||
self.decode_min_sizes = []
|
||||
self.memory_states = []
|
||||
|
||||
def decode_(self, t_chunk):
|
||||
self.decode_min_sizes.append(self.slicing_latent_min_size)
|
||||
return VideoAutoencoderKL.slicing_decode(self, t_chunk)
|
||||
|
||||
def _decode(self, z, memory_state=MemoryState.DISABLED):
|
||||
self.memory_states.append(memory_state)
|
||||
b, c, d, h, w = z.shape
|
||||
return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype)
|
||||
|
||||
vae = StubVAEModel()
|
||||
z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32)
|
||||
|
||||
tiled_vae(
|
||||
z,
|
||||
vae,
|
||||
tile_size=(64, 64),
|
||||
tile_overlap=(0, 0),
|
||||
temporal_size=0,
|
||||
temporal_overlap=0,
|
||||
encode=False,
|
||||
)
|
||||
|
||||
assert vae.decode_min_sizes == [5]
|
||||
assert vae.memory_states == [MemoryState.DISABLED]
|
||||
assert vae.slicing_latent_min_size == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_seedvr_vae_tiled_encode_runt_slice_override.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_zero_temporal_size_preserves_min_size_when_encode_raises():
|
||||
from comfy.ldm.seedvr.vae import tiled_vae
|
||||
|
||||
class RaisingVAEModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.slicing_sample_min_size = 4
|
||||
self.spatial_downsample_factor = 8
|
||||
self.temporal_downsample_factor = 4
|
||||
self.device = torch.device("cpu")
|
||||
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
|
||||
|
||||
def encode(self, t_chunk):
|
||||
raise RuntimeError("simulated encode failure")
|
||||
|
||||
vae = RaisingVAEModel()
|
||||
x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)
|
||||
|
||||
raised = False
|
||||
try:
|
||||
tiled_vae(
|
||||
x,
|
||||
vae,
|
||||
tile_size=(64, 64),
|
||||
tile_overlap=(0, 0),
|
||||
temporal_size=0,
|
||||
temporal_overlap=0,
|
||||
encode=True,
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
if "simulated encode failure" not in str(exc):
|
||||
raise
|
||||
raised = True
|
||||
|
||||
assert raised
|
||||
assert vae.slicing_sample_min_size == 4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_seedvr_vae_tiled_temporal_slicing.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _SlicingDecodeVAE(nn.Module):
|
||||
def __init__(self, slicing_latent_min_size):
|
||||
super().__init__()
|
||||
self.slicing_latent_min_size = slicing_latent_min_size
|
||||
self.spatial_downsample_factor = 8
|
||||
self.temporal_downsample_factor = 4
|
||||
self.device = torch.device("cpu")
|
||||
self.use_slicing = True
|
||||
self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32))
|
||||
self.decode_min_sizes = []
|
||||
self.memory_states = []
|
||||
|
||||
def decode_(self, z):
|
||||
self.decode_min_sizes.append(self.slicing_latent_min_size)
|
||||
return vae_mod.VideoAutoencoderKL.slicing_decode(self, z)
|
||||
|
||||
def _decode(self, z, memory_state=MemoryState.DISABLED):
|
||||
self.memory_states.append(memory_state)
|
||||
x = z[:, :1].repeat(
|
||||
1,
|
||||
3,
|
||||
1,
|
||||
self.spatial_downsample_factor,
|
||||
self.spatial_downsample_factor,
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
|
||||
vae = _SlicingDecodeVAE(slicing_latent_min_size=2)
|
||||
z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8)
|
||||
|
||||
tiled_vae(
|
||||
z,
|
||||
vae,
|
||||
tile_size=(64, 64),
|
||||
tile_overlap=(0, 0),
|
||||
temporal_size=12,
|
||||
temporal_overlap=4,
|
||||
encode=False,
|
||||
)
|
||||
|
||||
assert vae.decode_min_sizes == [2]
|
||||
assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE]
|
||||
assert vae.slicing_latent_min_size == 2
|
||||
|
||||
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
|
||||
vae_mod.VideoAutoencoderKLWrapper
|
||||
)
|
||||
nn.Module.__init__(wrapper)
|
||||
seedvr2_tiling = {
|
||||
"enable_tiling": True,
|
||||
"tile_size": (64, 64),
|
||||
"tile_overlap": (0, 0),
|
||||
"temporal_size": 8,
|
||||
"temporal_overlap": 7,
|
||||
}
|
||||
|
||||
captured = {}
|
||||
|
||||
def _fake_tiled_vae(latent, model, **kwargs):
|
||||
captured.update(kwargs)
|
||||
return torch.zeros(1, 3, 1, 16, 16)
|
||||
|
||||
with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae):
|
||||
wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling)
|
||||
|
||||
assert captured["temporal_overlap"] == 7
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _force_oom(*a, **k):
|
||||
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
|
||||
|
||||
|
||||
def _make_vae(first_stage_model, latent_channels, latent_dim):
|
||||
vae = sd_mod.VAE.__new__(sd_mod.VAE)
|
||||
vae.first_stage_model = first_stage_model
|
||||
vae.patcher = MagicMock()
|
||||
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
|
||||
vae.device = vae.output_device = torch.device("cpu")
|
||||
vae.vae_dtype = torch.float32
|
||||
vae.disable_offload = True
|
||||
vae.extra_1d_channel = None
|
||||
vae.upscale_ratio = vae.downscale_ratio = 8
|
||||
vae.upscale_index_formula = vae.downscale_index_formula = None
|
||||
vae.output_channels = 3
|
||||
vae.latent_channels = latent_channels
|
||||
vae.latent_dim = latent_dim
|
||||
vae.vae_output_dtype = lambda: torch.float32
|
||||
vae.spacial_compression_decode = lambda: 8
|
||||
vae.process_input = lambda x: x
|
||||
vae.process_output = lambda x: x
|
||||
vae.throw_exception_if_invalid = lambda: None
|
||||
vae.memory_used_decode = lambda *a, **k: 1
|
||||
return vae
|
||||
|
||||
|
||||
def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode):
|
||||
mm = sd_mod.model_management
|
||||
with ExitStack() as stack:
|
||||
stack.enter_context(patch.object(mm, "raise_non_oom", lambda e: None))
|
||||
stack.enter_context(patch.object(mm, "load_models_gpu", lambda *a, **k: None))
|
||||
stack.enter_context(patch.object(mm, "soft_empty_cache", lambda: None))
|
||||
stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call))
|
||||
stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_", generic_call))
|
||||
if patch_wrapper_decode:
|
||||
stack.enter_context(patch.object(
|
||||
seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode",
|
||||
side_effect=_force_oom))
|
||||
vae.decode(samples)
|
||||
|
||||
|
||||
def test_4d_seedvr2_latent_routes_to_decode_tiled_seedvr2():
|
||||
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
|
||||
seedvr_vae_mod.VideoAutoencoderKLWrapper)
|
||||
vae = _make_vae(wrapper, latent_channels=16, latent_dim=3)
|
||||
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
|
||||
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
|
||||
_dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True)
|
||||
assert seedvr2_call.call_count == 1
|
||||
assert generic_call.call_count == 0
|
||||
|
||||
|
||||
def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled():
|
||||
first_stage = MagicMock()
|
||||
first_stage.decode = MagicMock(side_effect=_force_oom)
|
||||
vae = _make_vae(first_stage, latent_channels=4, latent_dim=2)
|
||||
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
|
||||
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
|
||||
_dispatch(vae, torch.zeros(1, 4, 8, 8), seedvr2_call, generic_call, False)
|
||||
assert generic_call.call_count == 1
|
||||
assert seedvr2_call.call_count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _populate_common_vae_attrs_fallback(vae):
|
||||
vae.patcher = MagicMock()
|
||||
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
|
||||
vae.device = torch.device("cpu")
|
||||
vae.output_device = torch.device("cpu")
|
||||
vae.vae_dtype = torch.float32
|
||||
vae.disable_offload = True
|
||||
vae.extra_1d_channel = None
|
||||
vae.upscale_ratio = 8
|
||||
vae.upscale_index_formula = None
|
||||
vae.output_channels = 3
|
||||
vae.latent_channels = 16
|
||||
vae.latent_dim = 3
|
||||
vae.downscale_ratio = 8
|
||||
vae.downscale_index_formula = None
|
||||
vae.not_video = False
|
||||
vae.crop_input = False
|
||||
vae.pad_channel_value = None
|
||||
|
||||
vae.vae_output_dtype = lambda: torch.float32
|
||||
vae.spacial_compression_encode = lambda: 8
|
||||
vae.process_input = lambda x: x
|
||||
vae.process_output = lambda x: x
|
||||
vae.throw_exception_if_invalid = lambda: None
|
||||
vae.memory_used_encode = lambda *a, **k: 1
|
||||
|
||||
|
||||
def _make_seedvr2_vae_fallback():
|
||||
vae = sd_mod.VAE.__new__(sd_mod.VAE)
|
||||
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
|
||||
seedvr_vae_mod.VideoAutoencoderKLWrapper
|
||||
)
|
||||
vae.first_stage_model = wrapper
|
||||
_populate_common_vae_attrs_fallback(vae)
|
||||
return vae
|
||||
|
||||
|
||||
def _make_non_seedvr2_vae_fallback():
|
||||
vae = sd_mod.VAE.__new__(sd_mod.VAE)
|
||||
vae.first_stage_model = MagicMock()
|
||||
_populate_common_vae_attrs_fallback(vae)
|
||||
return vae
|
||||
|
||||
|
||||
def _force_regular_encode_oom(*args, **kwargs):
|
||||
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
|
||||
|
||||
|
||||
def test_seedvr2_3d_routes_to_encode_tiled_seedvr2_on_oom():
|
||||
vae = _make_seedvr2_vae_fallback()
|
||||
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
|
||||
|
||||
seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
|
||||
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
|
||||
|
||||
with patch.object(sd_mod.model_management, "raise_non_oom",
|
||||
lambda e: None), \
|
||||
patch.object(sd_mod.model_management, "load_models_gpu",
|
||||
lambda *a, **k: None), \
|
||||
patch.object(sd_mod.model_management, "soft_empty_cache",
|
||||
lambda: None), \
|
||||
patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode",
|
||||
side_effect=_force_regular_encode_oom), \
|
||||
patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call,
|
||||
create=True), \
|
||||
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
|
||||
vae.encode(pixel_samples)
|
||||
|
||||
assert seedvr2_call.call_count == 1, (
|
||||
f"Expected encode_tiled_seedvr2 to be called once for a SeedVR2 3D "
|
||||
f"input under OOM fallback; got {seedvr2_call.call_count} calls."
|
||||
)
|
||||
assert generic_call.call_count == 0, (
|
||||
f"encode_tiled_3d must NOT be called for a SeedVR2 input; got "
|
||||
f"{generic_call.call_count} calls."
|
||||
)
|
||||
|
||||
|
||||
def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete():
|
||||
vae = _make_non_seedvr2_vae_fallback()
|
||||
vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8)
|
||||
vae.upscale_ratio = (lambda a: a * 4, 8, 8)
|
||||
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
|
||||
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
|
||||
|
||||
with patch.object(sd_mod.model_management, "load_models_gpu",
|
||||
lambda *a, **k: None), \
|
||||
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
|
||||
vae.encode_tiled(pixel_samples)
|
||||
|
||||
assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64)
|
||||
@ -1,126 +0,0 @@
|
||||
"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args as cli_args
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
cli_args.cpu = True
|
||||
|
||||
import comfy.sample # noqa: E402
|
||||
import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402
|
||||
from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler # noqa: E402
|
||||
|
||||
_LAT_C = 16
|
||||
_COND_C = 17
|
||||
|
||||
|
||||
def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8):
|
||||
"""Build minimal SeedVR2-shaped sampling inputs."""
|
||||
samples_5d = torch.arange(
|
||||
B * _LAT_C * T * H * W, dtype=torch.float32
|
||||
).reshape(B, _LAT_C, T, H, W)
|
||||
samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous()
|
||||
|
||||
cond_5d = torch.arange(
|
||||
B * _COND_C * T * H * W, dtype=torch.float32
|
||||
).reshape(B, _COND_C, T, H, W) + 10000.0
|
||||
cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous()
|
||||
|
||||
text_pos = torch.zeros(1, 4, 32)
|
||||
text_neg = torch.zeros(1, 4, 32)
|
||||
positive = [[text_pos, {"condition": cond.clone()}]]
|
||||
negative = [[text_neg, {"condition": cond.clone()}]]
|
||||
latent_image = {"samples": samples}
|
||||
return latent_image, positive, negative, samples_5d, cond_5d
|
||||
|
||||
|
||||
def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None):
|
||||
return latent_image
|
||||
|
||||
|
||||
def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None):
|
||||
"""Return a tensor whose values encode ``(seed, position)``."""
|
||||
base = torch.arange(
|
||||
latent_image.numel(), dtype=torch.float32
|
||||
).reshape(latent_image.shape)
|
||||
return base + float(seed) * 1e6
|
||||
|
||||
|
||||
def test_progressive_sampler_schema_exposes_manual_default_auto_chunking():
|
||||
schema = SeedVR2ProgressiveSampler.define_schema()
|
||||
inputs = {item.id: item for item in schema.inputs}
|
||||
|
||||
assert inputs["chunking_mode"].options == ["manual", "auto"]
|
||||
assert inputs["chunking_mode"].default == "manual"
|
||||
|
||||
|
||||
def test_auto_chunking_walks_two_three_four_chunk_ladder():
|
||||
"""Auto mode must walk 2-, 3-, then 4-chunk geometries on OOM."""
|
||||
latent, pos, neg, _, _ = _make_inputs(T=17)
|
||||
calls = []
|
||||
|
||||
def _oom_until_four_chunks(model, noise, steps, cfg, sampler_name,
|
||||
scheduler, positive, negative,
|
||||
latent_image, denoise=1.0,
|
||||
noise_mask=None, seed=None):
|
||||
calls.append(tuple(latent_image.shape))
|
||||
if latent_image.shape[1] > _LAT_C * 5:
|
||||
raise torch.cuda.OutOfMemoryError("chunk too large")
|
||||
return latent_image.clone()
|
||||
|
||||
with patch.object(comfy.sample, "sample",
|
||||
side_effect=_oom_until_four_chunks), \
|
||||
patch.object(comfy.sample, "fix_empty_latent_channels",
|
||||
side_effect=_identity_fix_empty), \
|
||||
patch.object(comfy.sample, "prepare_noise",
|
||||
side_effect=_fingerprinted_prepare_noise), \
|
||||
patch.object(nodes_seedvr_mod.comfy.model_management,
|
||||
"soft_empty_cache") as soft_empty:
|
||||
out = SeedVR2ProgressiveSampler.execute(
|
||||
model=None, seed=0, steps=2, cfg=1.0,
|
||||
sampler_name="euler", scheduler="simple",
|
||||
positive=pos, negative=neg, latent=latent,
|
||||
denoise=1.0, frames_per_chunk=65, temporal_overlap=0,
|
||||
chunking_mode="auto",
|
||||
)
|
||||
|
||||
assert calls[:4] == [
|
||||
(1, _LAT_C * 17, 8, 8),
|
||||
(1, _LAT_C * 9, 8, 8),
|
||||
(1, _LAT_C * 6, 8, 8),
|
||||
(1, _LAT_C * 5, 8, 8),
|
||||
]
|
||||
assert torch.equal(out.result[0]["samples"], latent["samples"])
|
||||
assert soft_empty.call_count == 3
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bad_chunk", [0, -1, 2])
|
||||
def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk):
|
||||
"""``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation."""
|
||||
latent, pos, neg, _, _ = _make_inputs(T=5)
|
||||
|
||||
sampler_called = {"n": 0}
|
||||
|
||||
def _should_not_be_called(*args, **kwargs):
|
||||
sampler_called["n"] += 1
|
||||
return torch.zeros(1)
|
||||
|
||||
with patch.object(comfy.sample, "sample",
|
||||
side_effect=_should_not_be_called), \
|
||||
patch.object(comfy.sample, "fix_empty_latent_channels",
|
||||
side_effect=_identity_fix_empty), \
|
||||
patch.object(comfy.sample, "prepare_noise",
|
||||
side_effect=_fingerprinted_prepare_noise):
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
SeedVR2ProgressiveSampler.execute(
|
||||
model=None, seed=0, steps=2, cfg=1.0,
|
||||
sampler_name="euler", scheduler="simple",
|
||||
positive=pos, negative=neg, latent=latent,
|
||||
denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0,
|
||||
)
|
||||
assert str(bad_chunk) in str(excinfo.value)
|
||||
assert sampler_called["n"] == 0
|
||||
Loading…
Reference in New Issue
Block a user