mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 08:19:32 +08:00
Merge ad04a6199e into 6978a466b8
This commit is contained in:
commit
6b0a3b4b6e
@ -779,6 +779,9 @@ class ACEAudio(LatentFormat):
|
|||||||
latent_channels = 8
|
latent_channels = 8
|
||||||
latent_dimensions = 2
|
latent_dimensions = 2
|
||||||
|
|
||||||
|
class SeedVR2(LatentFormat):
|
||||||
|
latent_channels = 16
|
||||||
|
|
||||||
class ACEAudio15(LatentFormat):
|
class ACEAudio15(LatentFormat):
|
||||||
latent_channels = 64
|
latent_channels = 64
|
||||||
latent_dimensions = 1
|
latent_dimensions = 1
|
||||||
|
|||||||
@ -22,7 +22,7 @@ def torch_cat_if_needed(xl, dim):
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_timestep_embedding(timesteps, embedding_dim):
|
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
|
||||||
"""
|
"""
|
||||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||||
From Fairseq.
|
From Fairseq.
|
||||||
@ -33,11 +33,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
|||||||
assert len(timesteps.shape) == 1
|
assert len(timesteps.shape) == 1
|
||||||
|
|
||||||
half_dim = embedding_dim // 2
|
half_dim = embedding_dim // 2
|
||||||
emb = math.log(10000) / (half_dim - 1)
|
emb = math.log(10000) / (half_dim - downscale_freq_shift)
|
||||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||||
emb = emb.to(device=timesteps.device)
|
emb = emb.to(device=timesteps.device)
|
||||||
emb = timesteps.float()[:, None] * emb[None, :]
|
emb = timesteps.float()[:, None] * emb[None, :]
|
||||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||||
|
if flip_sin_to_cos:
|
||||||
|
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||||
if embedding_dim % 2 == 1: # zero pad
|
if embedding_dim % 2 == 1: # zero pad
|
||||||
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
||||||
return emb
|
return emb
|
||||||
|
|||||||
77
comfy/ldm/seedvr/attention.py
Normal file
77
comfy/ldm/seedvr/attention.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.ldm.modules import attention as _attention
|
||||||
|
|
||||||
|
|
||||||
|
def _var_attention_qkv(q, k, v, heads, skip_reshape):
|
||||||
|
if skip_reshape:
|
||||||
|
return q, k, v, q.shape[-1]
|
||||||
|
total_tokens, embed_dim = q.shape
|
||||||
|
head_dim = embed_dim // heads
|
||||||
|
return (
|
||||||
|
q.view(total_tokens, heads, head_dim),
|
||||||
|
k.view(k.shape[0], heads, head_dim),
|
||||||
|
v.view(v.shape[0], heads, head_dim),
|
||||||
|
head_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
|
||||||
|
if skip_output_reshape:
|
||||||
|
return out
|
||||||
|
return out.reshape(-1, heads * head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
|
||||||
|
if cu_seqlens.dtype not in (torch.int32, torch.int64):
|
||||||
|
raise ValueError(f"{name} must use an integer dtype")
|
||||||
|
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
|
||||||
|
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
|
||||||
|
if cu_seqlens[0].item() != 0:
|
||||||
|
raise ValueError(f"{name} must start at 0")
|
||||||
|
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
|
||||||
|
raise ValueError(f"{name} must be strictly increasing")
|
||||||
|
if cu_seqlens[-1].item() != token_count:
|
||||||
|
raise ValueError(f"{name} does not match token count")
|
||||||
|
|
||||||
|
|
||||||
|
def _split_indices(cu_seqlens):
|
||||||
|
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
|
||||||
|
|
||||||
|
|
||||||
|
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||||
|
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||||
|
|
||||||
|
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
|
||||||
|
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
|
||||||
|
if cu_seqlens_k[-1].item() != v.shape[0]:
|
||||||
|
raise ValueError("cu_seqlens_k does not match v token count")
|
||||||
|
|
||||||
|
q_split_indices = _split_indices(cu_seqlens_q)
|
||||||
|
k_split_indices = _split_indices(cu_seqlens_k)
|
||||||
|
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
|
||||||
|
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
|
||||||
|
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
|
||||||
|
if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
|
||||||
|
raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")
|
||||||
|
|
||||||
|
out = []
|
||||||
|
for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
|
||||||
|
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
|
||||||
|
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
|
||||||
|
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
|
||||||
|
out_dtype = q_i.dtype
|
||||||
|
if _attention.optimized_attention is _attention.attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16):
|
||||||
|
q_i = q_i.to(torch.bfloat16)
|
||||||
|
k_i = k_i.to(torch.bfloat16)
|
||||||
|
v_i = v_i.to(torch.bfloat16)
|
||||||
|
out_i = _attention.optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
|
||||||
|
if out_i.dtype != out_dtype:
|
||||||
|
out_i = out_i.to(out_dtype)
|
||||||
|
out.append(out_i.squeeze(0).permute(1, 0, 2))
|
||||||
|
|
||||||
|
out = torch.cat(out, dim=0)
|
||||||
|
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||||
|
|
||||||
|
|
||||||
|
optimized_var_attention = var_attention_optimized_split
|
||||||
340
comfy/ldm/seedvr/color_fix.py
Normal file
340
comfy/ldm/seedvr/color_fix.py
Normal file
@ -0,0 +1,340 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from comfy.ldm.seedvr.model import safe_pad_operation
|
||||||
|
from comfy.ldm.seedvr.vae import safe_interpolate_operation
|
||||||
|
from comfy.ldm.seedvr.constants import (
|
||||||
|
CIELAB_DELTA,
|
||||||
|
CIELAB_KAPPA,
|
||||||
|
D65_WHITE_X,
|
||||||
|
D65_WHITE_Z,
|
||||||
|
WAVELET_DECOMP_LEVELS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def wavelet_blur(image: Tensor, radius):
|
||||||
|
max_safe_radius = max(1, min(image.shape[-2:]) // 8)
|
||||||
|
if radius > max_safe_radius:
|
||||||
|
radius = max_safe_radius
|
||||||
|
|
||||||
|
num_channels = image.shape[1]
|
||||||
|
|
||||||
|
kernel_vals = [
|
||||||
|
[0.0625, 0.125, 0.0625],
|
||||||
|
[0.125, 0.25, 0.125],
|
||||||
|
[0.0625, 0.125, 0.0625],
|
||||||
|
]
|
||||||
|
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
||||||
|
kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
|
||||||
|
|
||||||
|
image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
|
||||||
|
output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS):
|
||||||
|
high_freq = torch.zeros_like(image)
|
||||||
|
|
||||||
|
for i in range(levels):
|
||||||
|
radius = 2 ** i
|
||||||
|
low_freq = wavelet_blur(image, radius)
|
||||||
|
high_freq.add_(image).sub_(low_freq)
|
||||||
|
image = low_freq
|
||||||
|
|
||||||
|
return high_freq, low_freq
|
||||||
|
|
||||||
|
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||||
|
|
||||||
|
if content_feat.shape != style_feat.shape:
|
||||||
|
# Resize style to match content spatial dimensions
|
||||||
|
if len(content_feat.shape) >= 3:
|
||||||
|
# safe_interpolate_operation handles FP16 conversion automatically
|
||||||
|
style_feat = safe_interpolate_operation(
|
||||||
|
style_feat,
|
||||||
|
size=content_feat.shape[-2:],
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# Decompose both features into frequency components
|
||||||
|
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
||||||
|
del content_low_freq # Free memory immediately
|
||||||
|
|
||||||
|
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
||||||
|
del style_high_freq # Free memory immediately
|
||||||
|
|
||||||
|
if content_high_freq.shape != style_low_freq.shape:
|
||||||
|
style_low_freq = safe_interpolate_operation(
|
||||||
|
style_low_freq,
|
||||||
|
size=content_high_freq.shape[-2:],
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=False
|
||||||
|
)
|
||||||
|
|
||||||
|
content_high_freq.add_(style_low_freq)
|
||||||
|
|
||||||
|
return content_high_freq.clamp_(-1.0, 1.0)
|
||||||
|
|
||||||
|
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
|
||||||
|
original_shape = source.shape
|
||||||
|
|
||||||
|
# Flatten
|
||||||
|
source_flat = source.flatten()
|
||||||
|
reference_flat = reference.flatten()
|
||||||
|
|
||||||
|
# Sort both arrays
|
||||||
|
source_sorted, source_indices = torch.sort(source_flat)
|
||||||
|
reference_sorted, _ = torch.sort(reference_flat)
|
||||||
|
del reference_flat
|
||||||
|
|
||||||
|
# Quantile mapping
|
||||||
|
n_source = len(source_sorted)
|
||||||
|
n_reference = len(reference_sorted)
|
||||||
|
|
||||||
|
if n_source == n_reference:
|
||||||
|
matched_sorted = reference_sorted
|
||||||
|
else:
|
||||||
|
# Interpolate reference to match source quantiles
|
||||||
|
source_quantiles = torch.linspace(0, 1, n_source, device=device)
|
||||||
|
ref_indices = (source_quantiles * (n_reference - 1)).long()
|
||||||
|
ref_indices.clamp_(0, n_reference - 1)
|
||||||
|
matched_sorted = reference_sorted[ref_indices]
|
||||||
|
del source_quantiles, ref_indices, reference_sorted
|
||||||
|
|
||||||
|
del source_sorted, source_flat
|
||||||
|
|
||||||
|
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
|
||||||
|
inverse_indices = torch.argsort(source_indices)
|
||||||
|
del source_indices
|
||||||
|
matched_flat = matched_sorted[inverse_indices]
|
||||||
|
del matched_sorted, inverse_indices
|
||||||
|
|
||||||
|
return matched_flat.reshape(original_shape)
|
||||||
|
|
||||||
|
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||||
|
"""Convert batch of CIELAB images to RGB color space."""
|
||||||
|
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
|
||||||
|
|
||||||
|
# LAB to XYZ
|
||||||
|
fy = (L + 16.0) / 116.0
|
||||||
|
fx = a.div(500.0).add_(fy)
|
||||||
|
fz = fy - b / 200.0
|
||||||
|
del L, a, b
|
||||||
|
|
||||||
|
# XYZ transformation
|
||||||
|
x = torch.where(
|
||||||
|
fx > epsilon,
|
||||||
|
torch.pow(fx, 3.0),
|
||||||
|
fx.mul(116.0).sub_(16.0).div_(kappa)
|
||||||
|
)
|
||||||
|
y = torch.where(
|
||||||
|
fy > epsilon,
|
||||||
|
torch.pow(fy, 3.0),
|
||||||
|
fy.mul(116.0).sub_(16.0).div_(kappa)
|
||||||
|
)
|
||||||
|
z = torch.where(
|
||||||
|
fz > epsilon,
|
||||||
|
torch.pow(fz, 3.0),
|
||||||
|
fz.mul(116.0).sub_(16.0).div_(kappa)
|
||||||
|
)
|
||||||
|
del fx, fy, fz
|
||||||
|
|
||||||
|
# Apply D65 white point (in-place)
|
||||||
|
x.mul_(D65_WHITE_X)
|
||||||
|
# y *= 1.00000 # (no-op, skip)
|
||||||
|
z.mul_(D65_WHITE_Z)
|
||||||
|
|
||||||
|
xyz = torch.stack([x, y, z], dim=1)
|
||||||
|
del x, y, z
|
||||||
|
|
||||||
|
# Matrix multiplication: XYZ -> RGB
|
||||||
|
B, C, H, W = xyz.shape
|
||||||
|
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||||
|
del xyz
|
||||||
|
|
||||||
|
# Ensure dtype consistency for matrix multiplication
|
||||||
|
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
|
||||||
|
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
|
||||||
|
del xyz_flat
|
||||||
|
|
||||||
|
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||||
|
del rgb_linear_flat
|
||||||
|
|
||||||
|
# Apply inverse gamma correction (delinearize)
|
||||||
|
mask = rgb_linear > 0.0031308
|
||||||
|
rgb = torch.where(
|
||||||
|
mask,
|
||||||
|
torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055),
|
||||||
|
rgb_linear * 12.92
|
||||||
|
)
|
||||||
|
del mask, rgb_linear
|
||||||
|
|
||||||
|
return torch.clamp(rgb, 0.0, 1.0)
|
||||||
|
|
||||||
|
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||||
|
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
|
||||||
|
# Apply sRGB gamma correction (linearize)
|
||||||
|
mask = rgb > 0.04045
|
||||||
|
rgb_linear = torch.where(
|
||||||
|
mask,
|
||||||
|
torch.pow((rgb + 0.055) / 1.055, 2.4),
|
||||||
|
rgb / 12.92
|
||||||
|
)
|
||||||
|
del mask
|
||||||
|
|
||||||
|
# Matrix multiplication: RGB -> XYZ
|
||||||
|
B, C, H, W = rgb_linear.shape
|
||||||
|
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||||
|
del rgb_linear
|
||||||
|
|
||||||
|
# Ensure dtype consistency for matrix multiplication
|
||||||
|
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
|
||||||
|
xyz_flat = torch.matmul(rgb_flat, matrix.T)
|
||||||
|
del rgb_flat
|
||||||
|
|
||||||
|
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||||
|
del xyz_flat
|
||||||
|
|
||||||
|
# Normalize by D65 white point (in-place)
|
||||||
|
xyz[:, 0].div_(D65_WHITE_X) # X
|
||||||
|
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
|
||||||
|
xyz[:, 2].div_(D65_WHITE_Z) # Z
|
||||||
|
|
||||||
|
# XYZ to LAB transformation
|
||||||
|
epsilon_cubed = epsilon ** 3
|
||||||
|
mask = xyz > epsilon_cubed
|
||||||
|
f_xyz = torch.where(
|
||||||
|
mask,
|
||||||
|
torch.pow(xyz, 1.0 / 3.0),
|
||||||
|
xyz.mul(kappa).add_(16.0).div_(116.0)
|
||||||
|
)
|
||||||
|
del xyz, mask
|
||||||
|
|
||||||
|
# Extract channels and compute LAB
|
||||||
|
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
|
||||||
|
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
|
||||||
|
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
|
||||||
|
del f_xyz
|
||||||
|
|
||||||
|
return torch.stack([L, a, b], dim=1)
|
||||||
|
|
||||||
|
def lab_color_transfer(
|
||||||
|
content_feat: Tensor,
|
||||||
|
style_feat: Tensor,
|
||||||
|
luminance_weight: float = 0.8
|
||||||
|
) -> Tensor:
|
||||||
|
content_feat = wavelet_reconstruction(content_feat, style_feat)
|
||||||
|
|
||||||
|
if content_feat.shape != style_feat.shape:
|
||||||
|
style_feat = safe_interpolate_operation(
|
||||||
|
style_feat,
|
||||||
|
size=content_feat.shape[-2:],
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=False
|
||||||
|
)
|
||||||
|
|
||||||
|
device = content_feat.device
|
||||||
|
|
||||||
|
def ensure_float32_precision(c):
|
||||||
|
orig_dtype = c.dtype
|
||||||
|
c = c.float()
|
||||||
|
return c, orig_dtype
|
||||||
|
content_feat, original_dtype = ensure_float32_precision(content_feat)
|
||||||
|
style_feat, _ = ensure_float32_precision(style_feat)
|
||||||
|
|
||||||
|
rgb_to_xyz_matrix = torch.tensor([
|
||||||
|
[0.4124564, 0.3575761, 0.1804375],
|
||||||
|
[0.2126729, 0.7151522, 0.0721750],
|
||||||
|
[0.0193339, 0.1191920, 0.9503041]
|
||||||
|
], dtype=torch.float32, device=device)
|
||||||
|
|
||||||
|
xyz_to_rgb_matrix = torch.tensor([
|
||||||
|
[ 3.2404542, -1.5371385, -0.4985314],
|
||||||
|
[-0.9692660, 1.8760108, 0.0415560],
|
||||||
|
[ 0.0556434, -0.2040259, 1.0572252]
|
||||||
|
], dtype=torch.float32, device=device)
|
||||||
|
|
||||||
|
epsilon = CIELAB_DELTA
|
||||||
|
kappa = CIELAB_KAPPA
|
||||||
|
|
||||||
|
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||||
|
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||||
|
|
||||||
|
# Convert to LAB color space
|
||||||
|
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||||
|
del content_feat
|
||||||
|
|
||||||
|
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||||
|
del style_feat, rgb_to_xyz_matrix
|
||||||
|
|
||||||
|
# Match chrominance channels (a*, b*) for accurate color transfer
|
||||||
|
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
|
||||||
|
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
|
||||||
|
|
||||||
|
# Handle luminance with weighted blending
|
||||||
|
if luminance_weight < 1.0:
|
||||||
|
# Partially match luminance for better overall color accuracy
|
||||||
|
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
|
||||||
|
# Blend: preserve some content L* for detail, adopt some style L* for color
|
||||||
|
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
|
||||||
|
del matched_L
|
||||||
|
else:
|
||||||
|
# Fully preserve content luminance
|
||||||
|
result_L = content_lab[:, 0]
|
||||||
|
|
||||||
|
del content_lab, style_lab
|
||||||
|
|
||||||
|
# Reconstruct LAB with corrected channels
|
||||||
|
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
|
||||||
|
del result_L, matched_a, matched_b
|
||||||
|
|
||||||
|
# Convert back to RGB
|
||||||
|
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
|
||||||
|
del result_lab, xyz_to_rgb_matrix
|
||||||
|
|
||||||
|
# Convert back to [-1, 1] range (in-place)
|
||||||
|
result = result_rgb.mul_(2.0).sub_(1.0)
|
||||||
|
del result_rgb
|
||||||
|
|
||||||
|
result = result.to(original_dtype)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||||
|
return wavelet_reconstruction(content_feat, style_feat)
|
||||||
|
|
||||||
|
|
||||||
|
def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor:
|
||||||
|
if content_feat.shape != style_feat.shape:
|
||||||
|
style_feat = safe_interpolate_operation(
|
||||||
|
style_feat,
|
||||||
|
size=content_feat.shape[-2:],
|
||||||
|
mode='bilinear',
|
||||||
|
align_corners=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
original_dtype = content_feat.dtype
|
||||||
|
content_feat = content_feat.float()
|
||||||
|
style_feat = style_feat.float()
|
||||||
|
|
||||||
|
b, c = content_feat.shape[:2]
|
||||||
|
content_flat = content_feat.reshape(b, c, -1)
|
||||||
|
style_flat = style_feat.reshape(b, c, -1)
|
||||||
|
|
||||||
|
content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1)
|
||||||
|
content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
|
||||||
|
style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1)
|
||||||
|
style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
|
||||||
|
del content_flat, style_flat
|
||||||
|
|
||||||
|
normalized = (content_feat - content_mean) / content_std
|
||||||
|
del content_mean, content_std
|
||||||
|
result = normalized * style_std + style_mean
|
||||||
|
del normalized, style_mean, style_std
|
||||||
|
|
||||||
|
result = result.clamp_(-1.0, 1.0)
|
||||||
|
if result.dtype != original_dtype:
|
||||||
|
result = result.to(original_dtype)
|
||||||
|
return result
|
||||||
72
comfy/ldm/seedvr/constants.py
Normal file
72
comfy/ldm/seedvr/constants.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
"""Named constants for the SeedVR2 integration, grouped by provenance.
|
||||||
|
|
||||||
|
Provenance prefixes:
|
||||||
|
- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline.
|
||||||
|
- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites
|
||||||
|
the upstream config/source path it was lifted from.
|
||||||
|
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
|
||||||
|
ISO / CIE values; cite the standard.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------------------
|
||||||
|
# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment)
|
||||||
|
# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN)
|
||||||
|
# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070
|
||||||
|
# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT).
|
||||||
|
# --------------------------------------------------------------------------------------
|
||||||
|
SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB)
|
||||||
|
SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------------------
|
||||||
|
# B. Fork heuristics (SEEDVR2 - this integration)
|
||||||
|
# --------------------------------------------------------------------------------------
|
||||||
|
SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim.
|
||||||
|
# (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.)
|
||||||
|
SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry.
|
||||||
|
SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case).
|
||||||
|
SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM.
|
||||||
|
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk.
|
||||||
|
SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels).
|
||||||
|
SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16).
|
||||||
|
|
||||||
|
# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing)
|
||||||
|
SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk.
|
||||||
|
SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path.
|
||||||
|
SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path.
|
||||||
|
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------------------
|
||||||
|
# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR)
|
||||||
|
# --------------------------------------------------------------------------------------
|
||||||
|
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
|
||||||
|
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
|
||||||
|
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem).
|
||||||
|
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem).
|
||||||
|
BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28.
|
||||||
|
BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28.
|
||||||
|
BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16).
|
||||||
|
BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32).
|
||||||
|
BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11.
|
||||||
|
BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size).
|
||||||
|
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor.
|
||||||
|
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor).
|
||||||
|
BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling).
|
||||||
|
BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames).
|
||||||
|
BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency).
|
||||||
|
BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim).
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------------------
|
||||||
|
# D. Published standards (cite the literature)
|
||||||
|
# --------------------------------------------------------------------------------------
|
||||||
|
ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864.
|
||||||
|
|
||||||
|
# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65).
|
||||||
|
CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta).
|
||||||
|
CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa).
|
||||||
|
D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1).
|
||||||
|
D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn.
|
||||||
|
WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR).
|
||||||
|
|
||||||
|
# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and
|
||||||
|
# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the
|
||||||
|
# exact existing coefficients move verbatim rather than being retyped here.
|
||||||
1487
comfy/ldm/seedvr/model.py
Normal file
1487
comfy/ldm/seedvr/model.py
Normal file
File diff suppressed because it is too large
Load Diff
1807
comfy/ldm/seedvr/vae.py
Normal file
1807
comfy/ldm/seedvr/vae.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -55,6 +55,7 @@ import comfy.ldm.pixeldit.model
|
|||||||
import comfy.ldm.pixeldit.pid
|
import comfy.ldm.pixeldit.pid
|
||||||
import comfy.ldm.ace.model
|
import comfy.ldm.ace.model
|
||||||
import comfy.ldm.omnigen.omnigen2
|
import comfy.ldm.omnigen.omnigen2
|
||||||
|
import comfy.ldm.seedvr.model
|
||||||
import comfy.ldm.boogu.model
|
import comfy.ldm.boogu.model
|
||||||
import comfy.ldm.qwen_image.model
|
import comfy.ldm.qwen_image.model
|
||||||
import comfy.ldm.ideogram4.model
|
import comfy.ldm.ideogram4.model
|
||||||
@ -931,6 +932,16 @@ class HunyuanDiT(BaseModel):
|
|||||||
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
class SeedVR2(BaseModel):
|
||||||
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
|
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
|
||||||
|
def extra_conds(self, **kwargs):
|
||||||
|
out = super().extra_conds(**kwargs)
|
||||||
|
condition = kwargs.get("condition", None)
|
||||||
|
if condition is not None:
|
||||||
|
out["condition"] = comfy.conds.CONDRegular(condition)
|
||||||
|
return out
|
||||||
|
|
||||||
class PixArt(BaseModel):
|
class PixArt(BaseModel):
|
||||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
|
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
|
||||||
|
|||||||
@ -598,6 +598,53 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
|
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
|
if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "seedvr2"
|
||||||
|
dit_config["vid_dim"] = 3072
|
||||||
|
dit_config["heads"] = 24
|
||||||
|
dit_config["num_layers"] = 36
|
||||||
|
# 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.``
|
||||||
|
# submodules) at EVERY block — verified by inspecting the 7B
|
||||||
|
# state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means
|
||||||
|
# ``MMModule.shared_weights=False``). Native NaDiT computes
|
||||||
|
# per-block ``shared_weights = not (i < mm_layers)``, so to keep
|
||||||
|
# every block non-shared we set ``mm_layers = num_layers``.
|
||||||
|
# Without this, blocks at index >= mm_layers (default 10) try to
|
||||||
|
# load ``blocks.N.*.all.*`` keys that don't exist in the file,
|
||||||
|
# silently miss-load → all-black output.
|
||||||
|
dit_config["mm_layers"] = 36
|
||||||
|
dit_config["norm_eps"] = 1e-5
|
||||||
|
dit_config["rope_type"] = "rope3d"
|
||||||
|
dit_config["rope_dim"] = 64
|
||||||
|
dit_config["mlp_type"] = "normal"
|
||||||
|
return dit_config
|
||||||
|
elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "seedvr2"
|
||||||
|
dit_config["vid_dim"] = 3072
|
||||||
|
dit_config["heads"] = 24
|
||||||
|
dit_config["num_layers"] = 36
|
||||||
|
# This checkpoint layout carries shared ``all.`` MMModule keys.
|
||||||
|
# Preserve the historical split: the initial blocks use separate
|
||||||
|
# vid/txt modules, later blocks use shared modules.
|
||||||
|
dit_config["mm_layers"] = 10
|
||||||
|
dit_config["norm_eps"] = 1e-5
|
||||||
|
dit_config["rope_type"] = "rope3d"
|
||||||
|
dit_config["rope_dim"] = 64
|
||||||
|
dit_config["mlp_type"] = "swiglu"
|
||||||
|
return dit_config
|
||||||
|
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
|
||||||
|
dit_config = {}
|
||||||
|
dit_config["image_model"] = "seedvr2"
|
||||||
|
dit_config["vid_dim"] = 2560
|
||||||
|
dit_config["heads"] = 20
|
||||||
|
dit_config["num_layers"] = 32
|
||||||
|
dit_config["norm_eps"] = 1.0e-05
|
||||||
|
dit_config["mlp_type"] = "swiglu"
|
||||||
|
dit_config["vid_out_norm"] = True
|
||||||
|
return dit_config
|
||||||
|
|
||||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||||
dit_config = {}
|
dit_config = {}
|
||||||
dit_config["image_model"] = "wan2.1"
|
dit_config["image_model"] = "wan2.1"
|
||||||
|
|||||||
125
comfy/sd.py
125
comfy/sd.py
@ -1,3 +1,4 @@
|
|||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import torch
|
import torch
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -16,6 +17,7 @@ import comfy.ldm.cosmos.vae
|
|||||||
import comfy.ldm.wan.vae
|
import comfy.ldm.wan.vae
|
||||||
import comfy.ldm.wan.vae2_2
|
import comfy.ldm.wan.vae2_2
|
||||||
import comfy.ldm.hunyuan3d.vae
|
import comfy.ldm.hunyuan3d.vae
|
||||||
|
import comfy.ldm.seedvr.vae
|
||||||
import comfy.ldm.triposplat.vae
|
import comfy.ldm.triposplat.vae
|
||||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||||
import comfy.ldm.cogvideo.vae
|
import comfy.ldm.cogvideo.vae
|
||||||
@ -469,8 +471,10 @@ class CLIP:
|
|||||||
|
|
||||||
class VAE:
|
class VAE:
|
||||||
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
||||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd
|
||||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||||
|
if metadata is None or metadata.get("keep_diffusers_format") != "true":
|
||||||
|
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||||
|
|
||||||
if model_management.is_amd():
|
if model_management.is_amd():
|
||||||
VAE_KL_MEM_RATIO = 2.73
|
VAE_KL_MEM_RATIO = 2.73
|
||||||
@ -542,6 +546,20 @@ class VAE:
|
|||||||
self.first_stage_model = StageC_coder()
|
self.first_stage_model = StageC_coder()
|
||||||
self.downscale_ratio = 32
|
self.downscale_ratio = 32
|
||||||
self.latent_channels = 16
|
self.latent_channels = 16
|
||||||
|
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
|
||||||
|
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
|
||||||
|
self.latent_channels = 16
|
||||||
|
self.latent_dim = 3
|
||||||
|
self.disable_offload = True
|
||||||
|
self.memory_used_decode = lambda shape, dtype: self.first_stage_model.comfy_memory_used_decode(shape)
|
||||||
|
self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype)
|
||||||
|
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||||
|
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||||
|
self.downscale_index_formula = (4, 8, 8)
|
||||||
|
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||||
|
self.upscale_index_formula = (4, 8, 8)
|
||||||
|
self.process_input = lambda image: image * 2.0 - 1.0
|
||||||
|
self.crop_input = False
|
||||||
elif "decoder.conv_in.weight" in sd:
|
elif "decoder.conv_in.weight" in sd:
|
||||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||||
@ -1008,6 +1026,10 @@ class VAE:
|
|||||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||||
|
|
||||||
|
def _decode_tiled_owned(self, samples, **kwargs):
|
||||||
|
out = self.first_stage_model.decode_tiled(samples.to(self.vae_dtype).to(self.device), **kwargs)
|
||||||
|
return self.process_output(out.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
|
||||||
|
|
||||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||||
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||||
@ -1044,6 +1066,11 @@ class VAE:
|
|||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||||
|
|
||||||
|
def _encode_tiled_owned(self, pixel_samples, **kwargs):
|
||||||
|
x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device)
|
||||||
|
out = self.first_stage_model.encode_tiled(x, **kwargs)
|
||||||
|
return out.to(device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
|
|
||||||
def decode(self, samples_in, vae_options={}):
|
def decode(self, samples_in, vae_options={}):
|
||||||
self.throw_exception_if_invalid()
|
self.throw_exception_if_invalid()
|
||||||
pixel_samples = None
|
pixel_samples = None
|
||||||
@ -1091,11 +1118,19 @@ class VAE:
|
|||||||
if dims == 1 or self.extra_1d_channel is not None:
|
if dims == 1 or self.extra_1d_channel is not None:
|
||||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||||
elif dims == 2:
|
elif dims == 2:
|
||||||
pixel_samples = self.decode_tiled_(samples_in)
|
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
|
||||||
|
tile = 256 // self.spacial_compression_decode()
|
||||||
|
overlap = tile // 4
|
||||||
|
pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||||
|
else:
|
||||||
|
pixel_samples = self.decode_tiled_(samples_in)
|
||||||
elif dims == 3:
|
elif dims == 3:
|
||||||
tile = 256 // self.spacial_compression_decode()
|
tile = 256 // self.spacial_compression_decode()
|
||||||
overlap = tile // 4
|
overlap = tile // 4
|
||||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
|
||||||
|
pixel_samples = self._decode_tiled_owned(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||||
|
else:
|
||||||
|
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||||
|
|
||||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||||
return pixel_samples
|
return pixel_samples
|
||||||
@ -1114,7 +1149,20 @@ class VAE:
|
|||||||
args["overlap"] = overlap
|
args["overlap"] = overlap
|
||||||
|
|
||||||
with model_management.cuda_device_context(self.device):
|
with model_management.cuda_device_context(self.device):
|
||||||
if dims == 1 or self.extra_1d_channel is not None:
|
if getattr(self.first_stage_model, "comfy_handles_tiling", False) and dims in (2, 3):
|
||||||
|
tiled_args = {}
|
||||||
|
if tile_x is not None:
|
||||||
|
tiled_args["tile_x"] = tile_x
|
||||||
|
if tile_y is not None:
|
||||||
|
tiled_args["tile_y"] = tile_y
|
||||||
|
if overlap is not None:
|
||||||
|
tiled_args["overlap"] = overlap
|
||||||
|
if tile_t is not None:
|
||||||
|
tiled_args["tile_t"] = tile_t
|
||||||
|
if overlap_t is not None:
|
||||||
|
tiled_args["overlap_t"] = overlap_t
|
||||||
|
output = self._decode_tiled_owned(samples, **tiled_args)
|
||||||
|
elif dims == 1 or self.extra_1d_channel is not None:
|
||||||
args.pop("tile_y")
|
args.pop("tile_y")
|
||||||
output = self.decode_tiled_1d(samples, **args)
|
output = self.decode_tiled_1d(samples, **args)
|
||||||
elif dims == 2:
|
elif dims == 2:
|
||||||
@ -1156,6 +1204,8 @@ class VAE:
|
|||||||
else:
|
else:
|
||||||
pixels_in = pixels_in.to(self.device)
|
pixels_in = pixels_in.to(self.device)
|
||||||
out = self.first_stage_model.encode(pixels_in)
|
out = self.first_stage_model.encode(pixels_in)
|
||||||
|
if isinstance(out, tuple):
|
||||||
|
out = out[0]
|
||||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||||
if samples is None:
|
if samples is None:
|
||||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
@ -1175,12 +1225,18 @@ class VAE:
|
|||||||
if self.latent_dim == 3:
|
if self.latent_dim == 3:
|
||||||
tile = 256
|
tile = 256
|
||||||
overlap = tile // 4
|
overlap = tile // 4
|
||||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
|
||||||
|
samples = self._encode_tiled_owned(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||||
|
else:
|
||||||
|
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||||
samples = self.encode_tiled_1d(pixel_samples)
|
samples = self.encode_tiled_1d(pixel_samples)
|
||||||
else:
|
else:
|
||||||
samples = self.encode_tiled_(pixel_samples)
|
samples = self.encode_tiled_(pixel_samples)
|
||||||
|
|
||||||
|
formatter = getattr(self.first_stage_model, "comfy_format_encoded", None)
|
||||||
|
if formatter is not None:
|
||||||
|
samples = formatter(samples)
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||||
@ -1188,7 +1244,7 @@ class VAE:
|
|||||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||||
dims = self.latent_dim
|
dims = self.latent_dim
|
||||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||||
if dims == 3:
|
if dims == 3 and pixel_samples.ndim < 5:
|
||||||
if not self.not_video:
|
if not self.not_video:
|
||||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||||
else:
|
else:
|
||||||
@ -1212,21 +1268,39 @@ class VAE:
|
|||||||
elif dims == 2:
|
elif dims == 2:
|
||||||
samples = self.encode_tiled_(pixel_samples, **args)
|
samples = self.encode_tiled_(pixel_samples, **args)
|
||||||
elif dims == 3:
|
elif dims == 3:
|
||||||
if tile_t is not None:
|
if getattr(self.first_stage_model, "comfy_handles_tiling", False):
|
||||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
tiled_args = {}
|
||||||
|
if tile_x is not None:
|
||||||
|
tiled_args["tile_x"] = tile_x
|
||||||
|
if tile_y is not None:
|
||||||
|
tiled_args["tile_y"] = tile_y
|
||||||
|
if overlap is not None:
|
||||||
|
tiled_args["overlap"] = overlap
|
||||||
|
if tile_t is not None:
|
||||||
|
tiled_args["tile_t"] = tile_t
|
||||||
|
if overlap_t is not None:
|
||||||
|
tiled_args["overlap_t"] = overlap_t
|
||||||
|
samples = self._encode_tiled_owned(pixel_samples, **tiled_args)
|
||||||
else:
|
else:
|
||||||
tile_t_latent = 9999
|
if tile_t is not None:
|
||||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||||
|
else:
|
||||||
|
tile_t_latent = 9999
|
||||||
|
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||||
|
|
||||||
if overlap_t is None:
|
spatial_overlap = overlap if overlap is not None else 64
|
||||||
args["overlap"] = (1, overlap, overlap)
|
if overlap_t is None:
|
||||||
else:
|
args["overlap"] = (1, spatial_overlap, spatial_overlap)
|
||||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
else:
|
||||||
maximum = pixel_samples.shape[2]
|
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap)
|
||||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
maximum = pixel_samples.shape[2]
|
||||||
|
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||||
|
|
||||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||||
|
|
||||||
|
formatter = getattr(self.first_stage_model, "comfy_format_encoded", None)
|
||||||
|
if formatter is not None:
|
||||||
|
samples = formatter(samples)
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def get_sd(self):
|
def get_sd(self):
|
||||||
@ -1777,6 +1851,17 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
|||||||
|
|
||||||
return (model, clip, vae)
|
return (model, clip, vae)
|
||||||
|
|
||||||
|
|
||||||
|
def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device):
|
||||||
|
set_dtype = model_config.set_inference_dtype
|
||||||
|
parameters = inspect.signature(set_dtype).parameters
|
||||||
|
supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values())
|
||||||
|
if supports_device:
|
||||||
|
set_dtype(dtype, manual_cast_dtype, device=device)
|
||||||
|
else:
|
||||||
|
set_dtype(dtype, manual_cast_dtype)
|
||||||
|
|
||||||
|
|
||||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
|
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||||
@ -1884,7 +1969,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
|||||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||||
else:
|
else:
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
|
||||||
|
|
||||||
if model_config.clip_vision_prefix is not None:
|
if model_config.clip_vision_prefix is not None:
|
||||||
if output_clipvision:
|
if output_clipvision:
|
||||||
@ -2025,7 +2110,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
|||||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||||
else:
|
else:
|
||||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
|
||||||
|
|
||||||
if custom_operations is not None:
|
if custom_operations is not None:
|
||||||
model_config.custom_operations = custom_operations
|
model_config.custom_operations = custom_operations
|
||||||
|
|||||||
@ -1684,6 +1684,35 @@ class Chroma(supported_models_base.BASE):
|
|||||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||||
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
||||||
|
|
||||||
|
class SeedVR2(supported_models_base.BASE):
|
||||||
|
unet_config = {
|
||||||
|
"image_model": "seedvr2"
|
||||||
|
}
|
||||||
|
latent_format = comfy.latent_formats.SeedVR2
|
||||||
|
|
||||||
|
vae_key_prefix = ["vae."]
|
||||||
|
text_encoder_key_prefix = ["text_encoders."]
|
||||||
|
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||||
|
sampling_settings = {
|
||||||
|
"shift": 1.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
|
||||||
|
if (
|
||||||
|
dtype == torch.float16
|
||||||
|
and manual_cast_dtype is None
|
||||||
|
and comfy.model_management.should_use_bf16(device)
|
||||||
|
):
|
||||||
|
manual_cast_dtype = torch.bfloat16
|
||||||
|
super().set_inference_dtype(dtype, manual_cast_dtype, device=device)
|
||||||
|
|
||||||
|
def get_model(self, state_dict, prefix="", device=None):
|
||||||
|
out = model_base.SeedVR2(self, device=device)
|
||||||
|
return out
|
||||||
|
|
||||||
|
def clip_target(self, state_dict={}):
|
||||||
|
return None
|
||||||
|
|
||||||
class ChromaRadiance(Chroma):
|
class ChromaRadiance(Chroma):
|
||||||
unet_config = {
|
unet_config = {
|
||||||
"image_model": "chroma_radiance",
|
"image_model": "chroma_radiance",
|
||||||
@ -2318,6 +2347,7 @@ models = [
|
|||||||
HiDream,
|
HiDream,
|
||||||
HiDreamO1,
|
HiDreamO1,
|
||||||
Chroma,
|
Chroma,
|
||||||
|
SeedVR2,
|
||||||
ChromaRadiance,
|
ChromaRadiance,
|
||||||
ACEStep,
|
ACEStep,
|
||||||
ACEStep15,
|
ACEStep15,
|
||||||
|
|||||||
@ -115,7 +115,7 @@ class BASE:
|
|||||||
replace_prefix = {"": self.vae_key_prefix[0]}
|
replace_prefix = {"": self.vae_key_prefix[0]}
|
||||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||||
|
|
||||||
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
|
||||||
self.unet_config['dtype'] = dtype
|
self.unet_config['dtype'] = dtype
|
||||||
self.manual_cast_dtype = manual_cast_dtype
|
self.manual_cast_dtype = manual_cast_dtype
|
||||||
|
|
||||||
|
|||||||
997
comfy_extras/nodes_seedvr.py
Normal file
997
comfy_extras/nodes_seedvr.py
Normal file
@ -0,0 +1,997 @@
|
|||||||
|
from typing_extensions import override
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import logging
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.sample
|
||||||
|
import comfy.samplers
|
||||||
|
from comfy.ldm.seedvr.color_fix import (
|
||||||
|
adain_color_transfer,
|
||||||
|
lab_color_transfer,
|
||||||
|
wavelet_color_transfer,
|
||||||
|
)
|
||||||
|
from comfy.ldm.seedvr.constants import (
|
||||||
|
SEEDVR2_ADAIN_SCALE_MULTIPLIER,
|
||||||
|
SEEDVR2_CHUNK_FRAMES_PER_GB,
|
||||||
|
SEEDVR2_CHUNK_GB_MARGIN,
|
||||||
|
SEEDVR2_COLOR_MEM_HEADROOM,
|
||||||
|
SEEDVR2_COND_CHANNELS,
|
||||||
|
SEEDVR2_DTYPE_BYTES_FLOOR,
|
||||||
|
SEEDVR2_LAB_SCALE_MULTIPLIER,
|
||||||
|
SEEDVR2_LATENT_CHANNELS,
|
||||||
|
SEEDVR2_OOM_BACKOFF_DIVISOR,
|
||||||
|
SEEDVR2_WAVELET_SCALE_MULTIPLIER,
|
||||||
|
)
|
||||||
|
|
||||||
|
from torchvision.transforms import functional as TVF
|
||||||
|
from torchvision.transforms import Lambda
|
||||||
|
from torchvision.transforms.functional import InterpolationMode
|
||||||
|
|
||||||
|
|
||||||
|
_SEEDVR2_INVALID_MODEL_MSG_PREFIX = (
|
||||||
|
"SeedVR2Conditioning: model object does not match expected SeedVR2 structure"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Private sentinel for getattr default: distinguishes "attribute missing"
|
||||||
|
# from "attribute present but None" so the failure message is accurate.
|
||||||
|
_ATTR_MISSING = object()
|
||||||
|
|
||||||
|
|
||||||
|
def _seedvr2_vram_seed_frames_per_chunk(free_bytes, t_pixel):
|
||||||
|
"""Predict the largest 4n+1 pixel-frame chunk that fits in free_bytes."""
|
||||||
|
free_gb = free_bytes / (1024 ** 3)
|
||||||
|
predicted = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_gb - SEEDVR2_CHUNK_GB_MARGIN)
|
||||||
|
# round (not floor) to 4n+1: the fit's central prediction lands on measured n_max
|
||||||
|
n = round((predicted - 1) / 4)
|
||||||
|
seed = 4 * int(n) + 1
|
||||||
|
seed = max(1, min(seed, t_pixel))
|
||||||
|
return seed
|
||||||
|
|
||||||
|
|
||||||
|
def _seedvr2_auto_chunk_attempts(t_latent, t_pixel, frames_per_chunk):
|
||||||
|
"""Return stricter 4n+1 frame chunk sizes for auto OOM retries."""
|
||||||
|
attempts = [frames_per_chunk]
|
||||||
|
current_chunk_latent = (
|
||||||
|
t_latent if t_pixel <= frames_per_chunk
|
||||||
|
else (frames_per_chunk - 1) // 4 + 1
|
||||||
|
)
|
||||||
|
current_chunk_count = max(1, math.ceil(t_latent / current_chunk_latent))
|
||||||
|
seen = {frames_per_chunk}
|
||||||
|
|
||||||
|
for target_chunks in range(max(2, current_chunk_count + 1), t_latent + 1):
|
||||||
|
chunk_latent = max(1, math.ceil(t_latent / target_chunks))
|
||||||
|
candidate = 4 * (chunk_latent - 1) + 1
|
||||||
|
if candidate in seen:
|
||||||
|
continue
|
||||||
|
if candidate >= attempts[-1]:
|
||||||
|
continue
|
||||||
|
attempts.append(candidate)
|
||||||
|
seen.add(candidate)
|
||||||
|
|
||||||
|
return attempts
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_seedvr2_diffusion_model(model):
|
||||||
|
"""Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message."""
|
||||||
|
inner = getattr(model, "model", _ATTR_MISSING)
|
||||||
|
if inner is _ATTR_MISSING:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input has no 'model' attribute "
|
||||||
|
f"(got type {type(model).__name__})."
|
||||||
|
)
|
||||||
|
if inner is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input.model is None "
|
||||||
|
f"(input type {type(model).__name__})."
|
||||||
|
)
|
||||||
|
diffusion_model = getattr(inner, "diffusion_model", _ATTR_MISSING)
|
||||||
|
if diffusion_model is _ATTR_MISSING:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model' has no "
|
||||||
|
f"'diffusion_model' attribute (got type {type(inner).__name__})."
|
||||||
|
)
|
||||||
|
if diffusion_model is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model.diffusion_model' "
|
||||||
|
f"is None (model.model type {type(inner).__name__})."
|
||||||
|
)
|
||||||
|
return diffusion_model
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_rope_freqs_float32_cast(diffusion_model):
|
||||||
|
"""Cast every module's ``rope.freqs`` to float32; the per-tensor dtype check (not a sentinel attr) self-corrects across Comfy's unload/reload, which would otherwise restore the archived fp16/bf16 dtype."""
|
||||||
|
for module in diffusion_model.modules():
|
||||||
|
if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'):
|
||||||
|
if module.rope.freqs.data.dtype != torch.float32:
|
||||||
|
module.rope.freqs.data = module.rope.freqs.data.to(torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def get_conditions(latent, latent_blur):
|
||||||
|
t, h, w, c = latent.shape
|
||||||
|
cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
|
||||||
|
cond[:, ..., :-1] = latent_blur[:]
|
||||||
|
cond[:, ..., -1:] = 1.0
|
||||||
|
return cond
|
||||||
|
|
||||||
|
def div_pad(image, factor):
|
||||||
|
|
||||||
|
height_factor, width_factor = factor
|
||||||
|
height, width = image.shape[-2:]
|
||||||
|
|
||||||
|
pad_height = (height_factor - (height % height_factor)) % height_factor
|
||||||
|
pad_width = (width_factor - (width % width_factor)) % width_factor
|
||||||
|
|
||||||
|
if pad_height == 0 and pad_width == 0:
|
||||||
|
return image
|
||||||
|
|
||||||
|
if isinstance(image, torch.Tensor):
|
||||||
|
padding = (0, pad_width, 0, pad_height)
|
||||||
|
image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0)
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
def cut_videos(videos):
|
||||||
|
t = videos.size(1)
|
||||||
|
if t == 1:
|
||||||
|
return videos
|
||||||
|
if t <= 4 :
|
||||||
|
padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1)
|
||||||
|
padding = torch.cat(padding, dim=1)
|
||||||
|
videos = torch.cat([videos, padding], dim=1)
|
||||||
|
return videos
|
||||||
|
if (t - 1) % (4) == 0:
|
||||||
|
return videos
|
||||||
|
else:
|
||||||
|
padding = [videos[:, -1].unsqueeze(1)] * (
|
||||||
|
4 - ((t - 1) % (4))
|
||||||
|
)
|
||||||
|
padding = torch.cat(padding, dim=1)
|
||||||
|
videos = torch.cat([videos, padding], dim=1)
|
||||||
|
assert (videos.size(1) - 1) % (4) == 0
|
||||||
|
return videos
|
||||||
|
|
||||||
|
def _seedvr2_input_shorter_edge(images, node_name):
|
||||||
|
if images.dim() == 4:
|
||||||
|
return min(images.shape[1], images.shape[2])
|
||||||
|
if images.dim() == 5:
|
||||||
|
return min(images.shape[2], images.shape[3])
|
||||||
|
raise ValueError(
|
||||||
|
f"{node_name}: expected 4-D or 5-D IMAGE tensor, "
|
||||||
|
f"got shape {tuple(images.shape)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _seedvr2_pad(images, upscaled_shorter_edge, node_name):
|
||||||
|
if upscaled_shorter_edge < 2:
|
||||||
|
raise ValueError(
|
||||||
|
f"{node_name}: input shorter edge must be at least 2 pixels; "
|
||||||
|
f"got {upscaled_shorter_edge}."
|
||||||
|
)
|
||||||
|
if images.shape[-1] > 3:
|
||||||
|
images = images[..., :3]
|
||||||
|
if images.dim() == 4:
|
||||||
|
# Comfy video components arrive as a 4-D IMAGE frame sequence:
|
||||||
|
# (frames, H, W, C). SeedVR2 consumes that as one video.
|
||||||
|
images = images.unsqueeze(0)
|
||||||
|
elif images.dim() != 5:
|
||||||
|
raise ValueError(
|
||||||
|
f"{node_name}: expected 4-D or 5-D IMAGE tensor, "
|
||||||
|
f"got shape {tuple(images.shape)}"
|
||||||
|
)
|
||||||
|
images = images.permute(0, 1, 4, 2, 3)
|
||||||
|
|
||||||
|
b, t, c, h, w = images.shape
|
||||||
|
images = images.reshape(b * t, c, h, w)
|
||||||
|
|
||||||
|
clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
|
||||||
|
images = clip(images)
|
||||||
|
images = div_pad(images, (16, 16))
|
||||||
|
_, _, new_h, new_w = images.shape
|
||||||
|
|
||||||
|
images = images.reshape(b, t, c, new_h, new_w)
|
||||||
|
images = cut_videos(images)
|
||||||
|
images_bthwc = rearrange(images, "b t c h w -> b t h w c")
|
||||||
|
|
||||||
|
return io.NodeOutput(images_bthwc)
|
||||||
|
|
||||||
|
|
||||||
|
class SeedVR2Preprocess(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="SeedVR2Preprocess",
|
||||||
|
display_name="Pre-Process SeedVR2 Input",
|
||||||
|
category="image/upscaling",
|
||||||
|
description="Pad a resized image for SeedVR2 model. Alpha channel is dropped. The node Post-Process SeedVR2 Output re-applies it from the original resized image.",
|
||||||
|
search_aliases=["seedvr2", "upscale", "video upscale", "pad", "preprocess"],
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("resized_images", tooltip="The resized image to process."),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Image.Output("images", tooltip="The padded image for VAE encoding."),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, resized_images):
|
||||||
|
upscaled_shorter_edge = _seedvr2_input_shorter_edge(resized_images, "SeedVR2Preprocess")
|
||||||
|
return _seedvr2_pad(
|
||||||
|
resized_images, upscaled_shorter_edge, "SeedVR2Preprocess",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SeedVR2PostProcessing(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="SeedVR2PostProcessing",
|
||||||
|
display_name="Post-Process SeedVR2 Output",
|
||||||
|
category="image/upscaling",
|
||||||
|
description="Align the generated image with the original resized image and apply color correction.",
|
||||||
|
search_aliases=["seedvr2", "upscale", "color correction", "color match", "postprocess"],
|
||||||
|
inputs=[
|
||||||
|
io.Image.Input("images", tooltip="The generated image to process."),
|
||||||
|
io.Image.Input("original_resized_images", tooltip="The original resized image before pre-processing, used as reference."),
|
||||||
|
io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="Method to match the generated image colors to the original image. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). none: skip color transfer (geometry alignment only)."),
|
||||||
|
],
|
||||||
|
outputs=[io.Image.Output(display_name="images", tooltip="The aligned, color-corrected image.")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, images, original_resized_images, color_correction_method):
|
||||||
|
alpha_input = None
|
||||||
|
if original_resized_images.shape[-1] == 4:
|
||||||
|
alpha_input = original_resized_images[..., 3:4]
|
||||||
|
original_resized_images = original_resized_images[..., :3]
|
||||||
|
decoded_5d, decoded_was_4d = cls._as_bthwc(images)
|
||||||
|
reference_full, _ = cls._as_bthwc(original_resized_images)
|
||||||
|
decoded_5d = cls._restore_reference_batch_time(decoded_5d, reference_full)
|
||||||
|
|
||||||
|
b = min(decoded_5d.shape[0], reference_full.shape[0])
|
||||||
|
t = min(decoded_5d.shape[1], reference_full.shape[1])
|
||||||
|
reference_h = reference_full.shape[2]
|
||||||
|
reference_w = reference_full.shape[3]
|
||||||
|
|
||||||
|
decoded_5d = decoded_5d[:b, :t, :, :, :]
|
||||||
|
target_h = min(decoded_5d.shape[2], reference_h)
|
||||||
|
target_w = min(decoded_5d.shape[3], reference_w)
|
||||||
|
decoded_5d = decoded_5d[:, :, :target_h, :target_w, :]
|
||||||
|
if color_correction_method in ("lab", "wavelet", "adain"):
|
||||||
|
reference_5d = reference_full[:b, :t, :, :, :]
|
||||||
|
reference_5d = cls._resize_reference(reference_5d, target_h, target_w)
|
||||||
|
output_device = decoded_5d.device
|
||||||
|
decoded_raw = cls._to_seedvr2_raw(decoded_5d)
|
||||||
|
reference_raw = cls._to_seedvr2_raw(reference_5d)
|
||||||
|
decoded_flat = rearrange(decoded_raw, "b t h w c -> (b t) c h w")
|
||||||
|
reference_flat = rearrange(reference_raw, "b t h w c -> (b t) c h w")
|
||||||
|
output = cls._color_transfer_chunked(
|
||||||
|
decoded_flat, reference_flat, output_device, color_correction_method,
|
||||||
|
)
|
||||||
|
output = rearrange(output, "(b t) c h w -> b t h w c", b=b, t=t)
|
||||||
|
output = output.add(1.0).div(2.0).clamp(0.0, 1.0)
|
||||||
|
elif color_correction_method == "none":
|
||||||
|
output = decoded_5d
|
||||||
|
else:
|
||||||
|
raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
|
||||||
|
|
||||||
|
if alpha_input is not None:
|
||||||
|
alpha_5d, _ = cls._as_bthwc(alpha_input)
|
||||||
|
alpha_5d = alpha_5d[:output.shape[0], :output.shape[1], :output.shape[2], :output.shape[3], :]
|
||||||
|
output = torch.cat([output, alpha_5d.to(dtype=output.dtype, device=output.device)], dim=-1)
|
||||||
|
h2 = output.shape[-3] - (output.shape[-3] % 2)
|
||||||
|
w2 = output.shape[-2] - (output.shape[-2] % 2)
|
||||||
|
output = output[:, :, :h2, :w2, :]
|
||||||
|
if decoded_was_4d:
|
||||||
|
output = output.reshape(-1, output.shape[-3], output.shape[-2], output.shape[-1])
|
||||||
|
return io.NodeOutput(output)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _as_bthwc(images):
|
||||||
|
if images.ndim == 4:
|
||||||
|
return images.unsqueeze(0), True
|
||||||
|
if images.ndim == 5:
|
||||||
|
return images, False
|
||||||
|
raise ValueError(
|
||||||
|
f"SeedVR2PostProcessing: expected 4-D or 5-D IMAGE tensor, got shape {tuple(images.shape)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _restore_reference_batch_time(decoded, reference):
|
||||||
|
if decoded.shape[0] != 1:
|
||||||
|
return decoded
|
||||||
|
ref_b, ref_t = reference.shape[:2]
|
||||||
|
if ref_b < 1 or decoded.shape[1] % ref_b != 0:
|
||||||
|
return decoded
|
||||||
|
decoded_t = decoded.shape[1] // ref_b
|
||||||
|
if decoded_t < ref_t:
|
||||||
|
return decoded
|
||||||
|
return decoded.reshape(ref_b, decoded_t, decoded.shape[2], decoded.shape[3], decoded.shape[4])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_seedvr2_raw(images):
|
||||||
|
return images.mul(2.0).sub(1.0)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn):
|
||||||
|
color_device = comfy.model_management.vae_device()
|
||||||
|
decoded_flat = decoded_flat.to(device=color_device)
|
||||||
|
reference_flat = reference_flat.to(device=color_device)
|
||||||
|
output = transfer_fn(decoded_flat, reference_flat)
|
||||||
|
return output.to(device=output_device)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _lab_color_transfer_on_vae_device(decoded_flat, reference_flat, output_device):
|
||||||
|
color_device = comfy.model_management.vae_device()
|
||||||
|
result = None
|
||||||
|
for start in range(decoded_flat.shape[0]):
|
||||||
|
decoded_frame = decoded_flat[start:start + 1].to(device=color_device).clone()
|
||||||
|
reference_frame = reference_flat[start:start + 1].to(device=color_device).clone()
|
||||||
|
output = lab_color_transfer(decoded_frame, reference_frame).to(device=output_device)
|
||||||
|
if result is None:
|
||||||
|
result = torch.empty(
|
||||||
|
(decoded_flat.shape[0],) + tuple(output.shape[1:]),
|
||||||
|
device=output_device,
|
||||||
|
dtype=output.dtype,
|
||||||
|
)
|
||||||
|
result[start:start + 1].copy_(output)
|
||||||
|
if result is None:
|
||||||
|
raise ValueError("SeedVR2PostProcessing: LAB color correction requires at least one frame.")
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method):
|
||||||
|
chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method)
|
||||||
|
while True:
|
||||||
|
next_chunk_size = None
|
||||||
|
try:
|
||||||
|
return cls._run_color_transfer_chunks(
|
||||||
|
decoded_flat, reference_flat, output_device, color_correction_method, chunk_size,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
comfy.model_management.raise_non_oom(e)
|
||||||
|
if chunk_size <= 1:
|
||||||
|
raise RuntimeError(
|
||||||
|
"SeedVR2PostProcessing: color correction OOM at one frame; "
|
||||||
|
f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}."
|
||||||
|
) from e
|
||||||
|
next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR)
|
||||||
|
|
||||||
|
comfy.model_management.soft_empty_cache()
|
||||||
|
chunk_size = next_chunk_size
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size):
|
||||||
|
result = None
|
||||||
|
for start in range(0, decoded_flat.shape[0], chunk_size):
|
||||||
|
end = min(start + chunk_size, decoded_flat.shape[0])
|
||||||
|
decoded_chunk = decoded_flat[start:end]
|
||||||
|
reference_chunk = reference_flat[start:end]
|
||||||
|
if color_correction_method == "lab":
|
||||||
|
output = cls._lab_color_transfer_on_vae_device(decoded_chunk, reference_chunk, output_device)
|
||||||
|
elif color_correction_method == "wavelet":
|
||||||
|
output = cls._color_transfer_on_vae_device(
|
||||||
|
decoded_chunk, reference_chunk, output_device, wavelet_color_transfer,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
output = cls._color_transfer_on_vae_device(
|
||||||
|
decoded_chunk, reference_chunk, output_device, adain_color_transfer,
|
||||||
|
)
|
||||||
|
if result is None:
|
||||||
|
result = torch.empty(
|
||||||
|
(decoded_flat.shape[0],) + tuple(output.shape[1:]),
|
||||||
|
device=output_device,
|
||||||
|
dtype=output.dtype,
|
||||||
|
)
|
||||||
|
result[start:end].copy_(output)
|
||||||
|
if result is None:
|
||||||
|
raise ValueError("SeedVR2PostProcessing: color correction requires at least one frame.")
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _estimate_color_correction_chunk_size(cls, decoded_flat, color_correction_method):
|
||||||
|
multiplier = cls._color_correction_memory_multiplier(color_correction_method)
|
||||||
|
frames = decoded_flat.shape[0]
|
||||||
|
_, channels, height, width = decoded_flat.shape
|
||||||
|
dtype_bytes = max(decoded_flat.element_size(), SEEDVR2_DTYPE_BYTES_FLOOR)
|
||||||
|
bytes_per_frame = height * width * channels * dtype_bytes * multiplier
|
||||||
|
if bytes_per_frame <= 0:
|
||||||
|
return frames
|
||||||
|
color_device = comfy.model_management.vae_device()
|
||||||
|
free_memory = comfy.model_management.get_free_memory(color_device)
|
||||||
|
chunk_size = int((free_memory * SEEDVR2_COLOR_MEM_HEADROOM) // bytes_per_frame)
|
||||||
|
return max(1, min(frames, chunk_size))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _color_correction_memory_multiplier(color_correction_method):
|
||||||
|
if color_correction_method == "lab":
|
||||||
|
return SEEDVR2_LAB_SCALE_MULTIPLIER
|
||||||
|
if color_correction_method == "wavelet":
|
||||||
|
return SEEDVR2_WAVELET_SCALE_MULTIPLIER
|
||||||
|
if color_correction_method == "adain":
|
||||||
|
return SEEDVR2_ADAIN_SCALE_MULTIPLIER
|
||||||
|
raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _resize_reference(reference, height, width):
|
||||||
|
if reference.shape[2] == height and reference.shape[3] == width:
|
||||||
|
return reference
|
||||||
|
b, t = reference.shape[:2]
|
||||||
|
reference_flat = rearrange(reference, "b t h w c -> (b t) c h w")
|
||||||
|
resized = TVF.resize(
|
||||||
|
reference_flat,
|
||||||
|
size=(height, width),
|
||||||
|
interpolation=InterpolationMode.BICUBIC,
|
||||||
|
antialias=not (isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"),
|
||||||
|
)
|
||||||
|
return rearrange(resized, "(b t) c h w -> b t h w c", b=b, t=t)
|
||||||
|
|
||||||
|
|
||||||
|
class SeedVR2Conditioning(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="SeedVR2Conditioning",
|
||||||
|
display_name="Apply SeedVR2 Conditioning",
|
||||||
|
category="conditioning",
|
||||||
|
description="Build SeedVR2 positive/negative conditioning from a VAE latent.",
|
||||||
|
search_aliases=["seedvr2", "upscale", "conditioning"],
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model", tooltip="The SeedVR2 model."),
|
||||||
|
io.Latent.Input("vae_conditioning", display_name="latent"),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(display_name="model", tooltip="The SeedVR2 model, passed through."),
|
||||||
|
io.Conditioning.Output(display_name="positive", tooltip="The positive conditioning for sampling."),
|
||||||
|
io.Conditioning.Output(display_name="negative", tooltip="The negative conditioning for sampling."),
|
||||||
|
io.Latent.Output(display_name="latent", tooltip="The latent to denoise."),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model, vae_conditioning) -> io.NodeOutput:
|
||||||
|
|
||||||
|
vae_conditioning = vae_conditioning["samples"]
|
||||||
|
if vae_conditioning.ndim != 5:
|
||||||
|
raise ValueError(
|
||||||
|
"SeedVR2Conditioning expects a 5-D VAE latent in Comfy "
|
||||||
|
f"channel-first layout; got shape {tuple(vae_conditioning.shape)}."
|
||||||
|
)
|
||||||
|
if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS:
|
||||||
|
raise ValueError(
|
||||||
|
"SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy "
|
||||||
|
f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); "
|
||||||
|
f"got channel-last shape {tuple(vae_conditioning.shape)}."
|
||||||
|
)
|
||||||
|
vae_conditioning = vae_conditioning.movedim(1, -1).contiguous()
|
||||||
|
model_patcher = model
|
||||||
|
model = _resolve_seedvr2_diffusion_model(model_patcher)
|
||||||
|
pos_cond = model.positive_conditioning
|
||||||
|
neg_cond = model.negative_conditioning
|
||||||
|
|
||||||
|
# Fail-loud guard against silently-wrong output when a
|
||||||
|
# DiT-only ``.safetensors`` (no ``positive_conditioning`` /
|
||||||
|
# ``negative_conditioning`` keys) is loaded via ``UNETLoader``.
|
||||||
|
# ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see
|
||||||
|
# ``comfy/ldm/seedvr/model.py``); ``load_state_dict(strict=False)``
|
||||||
|
# leaves them at zero when the keys are absent. Detect that state
|
||||||
|
# here rather than at ``BaseModel.extra_conds`` (per sampling step,
|
||||||
|
# wasteful) or at the resolver helper (mixes structural shape with
|
||||||
|
# semantic content). Both buffers must be checked together — partial
|
||||||
|
# bake regressions could populate one but not the other.
|
||||||
|
if (
|
||||||
|
pos_cond.float().abs().sum().item() == 0
|
||||||
|
and neg_cond.float().abs().sum().item() == 0
|
||||||
|
):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning "
|
||||||
|
f"and negative_conditioning buffers are zero-valued — model "
|
||||||
|
f"file appears to be a DiT-only export missing "
|
||||||
|
f"the SeedVR2 conditioning tensors. "
|
||||||
|
f"Re-bake the file with ``positive_conditioning`` (58, 5120) "
|
||||||
|
f"and ``negative_conditioning`` (64, 5120) keys at top level, "
|
||||||
|
f"or load via CheckpointLoaderSimple from a bundled "
|
||||||
|
f"checkpoint."
|
||||||
|
)
|
||||||
|
|
||||||
|
_apply_rope_freqs_float32_cast(model)
|
||||||
|
|
||||||
|
condition = torch.stack([get_conditions(c, c) for c in vae_conditioning])
|
||||||
|
condition = condition.movedim(-1, 1)
|
||||||
|
latent = vae_conditioning.movedim(-1, 1)
|
||||||
|
|
||||||
|
latent = rearrange(latent, "b c t h w -> b (c t) h w")
|
||||||
|
condition = rearrange(condition, "b c t h w -> b (c t) h w")
|
||||||
|
|
||||||
|
negative = [[neg_cond.unsqueeze(0), {"condition": condition}]]
|
||||||
|
positive = [[pos_cond.unsqueeze(0), {"condition": condition}]]
|
||||||
|
|
||||||
|
return io.NodeOutput(model_patcher, positive, negative, {"samples": latent})
|
||||||
|
|
||||||
|
def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int,
|
||||||
|
t_end: int, channels: int) -> torch.Tensor:
|
||||||
|
"""Slice collapsed ``(B, channels*T, H, W)`` along latent T: reshape (accepts non-contiguous inputs), slice, ``.contiguous()`` (T-slice of 5D is a non-contiguous view; re-collapse needs contiguous), re-collapse."""
|
||||||
|
B, CT, H, W = tensor_4d.shape
|
||||||
|
if CT % channels != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"_slice_collapsed_4d_along_t: collapsed channel dim {CT} is not "
|
||||||
|
f"divisible by channels={channels}; tensor shape {tuple(tensor_4d.shape)}."
|
||||||
|
)
|
||||||
|
T = CT // channels
|
||||||
|
if not (0 <= t_start < t_end <= T):
|
||||||
|
raise ValueError(
|
||||||
|
f"_slice_collapsed_4d_along_t: slice [{t_start}:{t_end}] out of "
|
||||||
|
f"range for T={T}."
|
||||||
|
)
|
||||||
|
new_T = t_end - t_start
|
||||||
|
sliced = tensor_4d.reshape(B, channels, T, H, W)[:, :, t_start:t_end, :, :].contiguous()
|
||||||
|
return sliced.reshape(B, channels * new_T, H, W)
|
||||||
|
|
||||||
|
|
||||||
|
def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int):
|
||||||
|
"""Return a new conditioning list with each entry's ``options["condition"]`` (collapsed ``(B, 17*T, H, W)``) sliced along latent T; text tensors, other option keys, and condition-less entries pass through unchanged and inputs are not mutated."""
|
||||||
|
new_list = []
|
||||||
|
for entry in cond_list:
|
||||||
|
text_cond, options = entry[0], entry[1]
|
||||||
|
if "condition" not in options:
|
||||||
|
new_list.append(entry)
|
||||||
|
continue
|
||||||
|
new_options = options.copy()
|
||||||
|
new_options["condition"] = _slice_collapsed_4d_along_t(
|
||||||
|
new_options["condition"], t_start, t_end,
|
||||||
|
SEEDVR2_COND_CHANNELS,
|
||||||
|
)
|
||||||
|
new_list.append([text_cond, new_options])
|
||||||
|
return new_list
|
||||||
|
|
||||||
|
|
||||||
|
def _slice_seedvr2_noise_mask_along_t(noise_mask: torch.Tensor,
|
||||||
|
samples_4d: torch.Tensor,
|
||||||
|
t_start: int,
|
||||||
|
t_end: int):
|
||||||
|
"""Slice only masks already expanded to collapsed ``(B, 16*T, H, W)``; pass standard ``(B, 1, H, W)`` ``SetLatentNoiseMask`` outputs through for KSampler to expand."""
|
||||||
|
if noise_mask.ndim == samples_4d.ndim and noise_mask.shape[1] == samples_4d.shape[1]:
|
||||||
|
return _slice_collapsed_4d_along_t(
|
||||||
|
noise_mask, t_start, t_end, SEEDVR2_LATENT_CHANNELS,
|
||||||
|
)
|
||||||
|
return noise_mask
|
||||||
|
|
||||||
|
|
||||||
|
def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor:
|
||||||
|
"""Concatenate collapsed ``(B, channels*T_i, H, W)`` chunks along latent T: un-collapse to 5D, cat on ``dim=2``, re-collapse to 4D."""
|
||||||
|
if len(chunks_4d) == 0:
|
||||||
|
raise ValueError("_concat_chunks_along_t: empty chunk list.")
|
||||||
|
fives = []
|
||||||
|
for ch in chunks_4d:
|
||||||
|
B, CT, H, W = ch.shape
|
||||||
|
if CT % channels != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"_concat_chunks_along_t: chunk shape {tuple(ch.shape)} "
|
||||||
|
f"channel dim {CT} not divisible by channels={channels}."
|
||||||
|
)
|
||||||
|
T = CT // channels
|
||||||
|
fives.append(ch.reshape(B, channels, T, H, W))
|
||||||
|
cat = torch.cat(fives, dim=2).contiguous()
|
||||||
|
B, C, T_total, H, W = cat.shape
|
||||||
|
return cat.reshape(B, C * T_total, H, W)
|
||||||
|
|
||||||
|
|
||||||
|
def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor:
|
||||||
|
"""1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``):
|
||||||
|
Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3``
|
||||||
|
(dead-band would collapse a tiny transition). Window shape matched to the reference
|
||||||
|
overlapping-frame blend for parity; caller broadcasts across ``(B, C, T_overlap, H, W)``.
|
||||||
|
"""
|
||||||
|
if overlap < 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"_hann_blend_weights_1d: overlap must be >= 1; got {overlap}."
|
||||||
|
)
|
||||||
|
if overlap >= 3:
|
||||||
|
t = torch.linspace(0.0, 1.0, steps=overlap, device=device, dtype=dtype)
|
||||||
|
blend_start = 1.0 / 3.0
|
||||||
|
blend_end = 2.0 / 3.0
|
||||||
|
u = ((t - blend_start) / (blend_end - blend_start)).clamp(0.0, 1.0)
|
||||||
|
return 0.5 + 0.5 * torch.cos(torch.pi * u)
|
||||||
|
return torch.linspace(1.0, 0.0, steps=overlap, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def _blend_overlap_region(prev_tail_5d: torch.Tensor,
|
||||||
|
cur_head_5d: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Blend two equal-shape 5D ``(B, C, T_overlap, H, W)`` tensors with a 1D Hann/linear T-ramp: ``prev_tail_5d`` takes the descending weight, ``cur_head_5d`` takes ``1 - w_prev`` (caller ensures matching shape/dtype/device)."""
|
||||||
|
if prev_tail_5d.shape != cur_head_5d.shape:
|
||||||
|
raise ValueError(
|
||||||
|
f"_blend_overlap_region: shape mismatch "
|
||||||
|
f"prev {tuple(prev_tail_5d.shape)} vs "
|
||||||
|
f"cur {tuple(cur_head_5d.shape)}."
|
||||||
|
)
|
||||||
|
overlap = int(prev_tail_5d.shape[2])
|
||||||
|
w_prev_1d = _hann_blend_weights_1d(
|
||||||
|
overlap, prev_tail_5d.device, prev_tail_5d.dtype,
|
||||||
|
)
|
||||||
|
# Reshape to (1, 1, overlap, 1, 1) for broadcast across B, C, H, W.
|
||||||
|
w_prev = w_prev_1d.view(1, 1, overlap, 1, 1)
|
||||||
|
w_cur = 1.0 - w_prev
|
||||||
|
return prev_tail_5d * w_prev + cur_head_5d * w_cur
|
||||||
|
|
||||||
|
|
||||||
|
def _concat_chunks_with_overlap_blend(chunk_specs, channels: int,
|
||||||
|
overlap_latent: int) -> torch.Tensor:
|
||||||
|
"""Concatenate overlapping ``(t_start, t_end, chunk_4d)`` specs (source-latent T coords) into one collapsed 4D tensor, Hann/linear-blending overlaps; ``overlap_latent == 0`` fast-paths to plain concat (bit-identical to ``_concat_chunks_along_t``). Each blend uses the actual width ``min(prev_end - cur_start, chunk length)``, smaller than ``overlap_latent`` for a runt final chunk."""
|
||||||
|
if len(chunk_specs) == 0:
|
||||||
|
raise ValueError("_concat_chunks_with_overlap_blend: empty chunk list.")
|
||||||
|
if overlap_latent < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"_concat_chunks_with_overlap_blend: overlap_latent must be "
|
||||||
|
f">= 0; got {overlap_latent}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate channel divisibility once and capture per-chunk T.
|
||||||
|
chunk_5d = []
|
||||||
|
for t_start, t_end, ch in chunk_specs:
|
||||||
|
B, CT, H, W = ch.shape
|
||||||
|
if CT % channels != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"_concat_chunks_with_overlap_blend: chunk shape "
|
||||||
|
f"{tuple(ch.shape)} channel dim {CT} not divisible "
|
||||||
|
f"by channels={channels}."
|
||||||
|
)
|
||||||
|
T = CT // channels
|
||||||
|
if t_end - t_start != T:
|
||||||
|
raise ValueError(
|
||||||
|
f"_concat_chunks_with_overlap_blend: chunk T={T} mismatches "
|
||||||
|
f"declared range [{t_start}:{t_end}]."
|
||||||
|
)
|
||||||
|
chunk_5d.append((t_start, t_end, ch.reshape(B, channels, T, H, W)))
|
||||||
|
|
||||||
|
if overlap_latent == 0:
|
||||||
|
# Fast path: pure concat in the caller-provided chunk order.
|
||||||
|
return _concat_chunks_along_t(
|
||||||
|
[c.reshape(c.shape[0], channels * c.shape[2], c.shape[3], c.shape[4])
|
||||||
|
for _, _, c in chunk_5d],
|
||||||
|
channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
T_total = max(t_end for _, t_end, _ in chunk_5d)
|
||||||
|
first_5d = chunk_5d[0][2]
|
||||||
|
B = first_5d.shape[0]
|
||||||
|
H = first_5d.shape[3]
|
||||||
|
W = first_5d.shape[4]
|
||||||
|
result = torch.empty(
|
||||||
|
(B, channels, T_total, H, W),
|
||||||
|
device=first_5d.device, dtype=first_5d.dtype,
|
||||||
|
)
|
||||||
|
filled_until = 0
|
||||||
|
for i, (cs, ce, ct_5d) in enumerate(chunk_5d):
|
||||||
|
chunk_T = int(ct_5d.shape[2])
|
||||||
|
if i == 0:
|
||||||
|
result[:, :, cs:ce, :, :] = ct_5d
|
||||||
|
filled_until = ce
|
||||||
|
continue
|
||||||
|
# Overlap region width is bounded by both the previous fill
|
||||||
|
# frontier and the current chunk's actual length (for runt
|
||||||
|
# final chunks shorter than the configured overlap).
|
||||||
|
overlap_len = min(filled_until - cs, chunk_T)
|
||||||
|
if overlap_len > 0:
|
||||||
|
prev_tail = result[:, :, cs:cs + overlap_len, :, :].contiguous()
|
||||||
|
cur_head = ct_5d[:, :, :overlap_len, :, :].contiguous()
|
||||||
|
blended = _blend_overlap_region(prev_tail, cur_head)
|
||||||
|
result[:, :, cs:cs + overlap_len, :, :] = blended
|
||||||
|
tail_start = cs + overlap_len
|
||||||
|
tail_end = ce
|
||||||
|
if tail_end > tail_start:
|
||||||
|
result[:, :, tail_start:tail_end, :, :] = (
|
||||||
|
ct_5d[:, :, overlap_len:, :, :]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Disjoint chunks (overlap_latent set but this pair did not
|
||||||
|
# actually overlap, e.g. step_latent equal to chunk_latent
|
||||||
|
# in a degenerate config). Treat as concat.
|
||||||
|
result[:, :, cs:ce, :, :] = ct_5d
|
||||||
|
filled_until = ce
|
||||||
|
|
||||||
|
return result.contiguous().reshape(B, channels * T_total, H, W)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_standard_sample(model, seed: int, steps: int, cfg: float,
|
||||||
|
sampler_name: str, scheduler: str,
|
||||||
|
positive, negative, latent: dict,
|
||||||
|
denoise: float) -> dict:
|
||||||
|
"""Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk."""
|
||||||
|
samples_in = latent["samples"]
|
||||||
|
samples_in = comfy.sample.fix_empty_latent_channels(
|
||||||
|
model, samples_in, latent.get("downscale_ratio_spacial", None),
|
||||||
|
)
|
||||||
|
batch_inds = latent.get("batch_index", None)
|
||||||
|
noise = comfy.sample.prepare_noise(samples_in, seed, batch_inds)
|
||||||
|
noise_mask = latent.get("noise_mask", None)
|
||||||
|
samples = comfy.sample.sample(
|
||||||
|
model, noise, steps, cfg, sampler_name, scheduler,
|
||||||
|
positive, negative, samples_in,
|
||||||
|
denoise=denoise, noise_mask=noise_mask, seed=seed,
|
||||||
|
)
|
||||||
|
out = latent.copy()
|
||||||
|
out.pop("downscale_ratio_spacial", None)
|
||||||
|
out["samples"] = samples
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class SeedVR2ProgressiveSampler(io.ComfyNode):
|
||||||
|
"""Sequential temporal chunking sampler for SeedVR2 native.
|
||||||
|
|
||||||
|
Drop-in replacement for ``KSampler`` in SeedVR2 native workflows that
|
||||||
|
OOM on long sequences. The latent enters the sampler in SeedVR2's
|
||||||
|
collapsed form ``(B, 16*T, H, W)`` (collapsed by ``SeedVR2Conditioning``
|
||||||
|
at ``rearrange(b c t h w -> b (c t) h w)``); this node slices that
|
||||||
|
tensor along the temporal axis, runs the configured inner sampler
|
||||||
|
sequentially per chunk against the standard ``comfy.sample.sample``
|
||||||
|
entry point, and concatenates per-chunk outputs back into a single
|
||||||
|
``(B, 16*T_total, H, W)`` latent.
|
||||||
|
|
||||||
|
``frames_per_chunk`` is expressed in pixel-frame units to match the
|
||||||
|
SeedVR2 4n+1 constraint enforced upstream by ``cut_videos`` and the
|
||||||
|
VAE's ``temporal_downsample_factor=4``. A pixel chunk size ``F``
|
||||||
|
maps to ``(F - 1) // 4 + 1`` latent-frame chunks.
|
||||||
|
|
||||||
|
Determinism contract: a single noise tensor is generated once from
|
||||||
|
the user seed and sliced per chunk (rather than re-seeding each
|
||||||
|
chunk), so a workflow that fits in a single chunk produces output
|
||||||
|
identical to a workflow that fits in N chunks at the same seed,
|
||||||
|
modulo the inherent T-axis chunk-boundary independence of the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return io.Schema(
|
||||||
|
node_id="SeedVR2ProgressiveSampler",
|
||||||
|
display_name="Sample SeedVR2 (Progressive)",
|
||||||
|
category="sampling",
|
||||||
|
description="Sample a SeedVR2 latent in sequential temporal chunks to allow longer videos to fit into VRAM via frame blending the resulting upscaled latents.",
|
||||||
|
search_aliases=["seedvr2", "upscale", "video upscale", "sampler", "chunk"],
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model", tooltip="The model used for denoising the input latent."),
|
||||||
|
io.Int.Input("seed", default=0, min=0,
|
||||||
|
max=0xffffffffffffffff,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="The random seed used for creating the noise."),
|
||||||
|
io.Int.Input("steps", default=20, min=1, max=10000,
|
||||||
|
tooltip="The number of steps used in the denoising process."),
|
||||||
|
io.Float.Input("cfg", default=1.0, min=0.0, max=100.0,
|
||||||
|
step=0.1, round=0.01,
|
||||||
|
tooltip="The Classifier-Free Guidance scale balances creativity and adherence to the prompt. Higher values result in images more closely matching the prompt however too high values will negatively impact quality."),
|
||||||
|
io.Combo.Input("sampler_name",
|
||||||
|
options=comfy.samplers.SAMPLER_NAMES,
|
||||||
|
tooltip="The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."),
|
||||||
|
io.Combo.Input("scheduler",
|
||||||
|
options=comfy.samplers.SCHEDULER_NAMES,
|
||||||
|
tooltip="The scheduler controls how noise is gradually removed to form the image."),
|
||||||
|
io.Conditioning.Input("positive",
|
||||||
|
tooltip="The conditioning describing the attributes you want to include in the image."),
|
||||||
|
io.Conditioning.Input("negative",
|
||||||
|
tooltip="The conditioning describing the attributes you want to exclude from the image."),
|
||||||
|
io.Latent.Input("latent",
|
||||||
|
tooltip="The latent image to denoise."),
|
||||||
|
io.Float.Input("denoise", default=1.0, min=0.0, max=1.0,
|
||||||
|
step=0.01,
|
||||||
|
tooltip="The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."),
|
||||||
|
io.Int.Input("frames_per_chunk", default=21, min=1,
|
||||||
|
max=16384, step=4,
|
||||||
|
tooltip="Pixel frames per temporal chunk (4n+1: 1, 5, 9, 13, ...)."),
|
||||||
|
io.Int.Input("temporal_overlap", default=0, min=0,
|
||||||
|
max=16384,
|
||||||
|
tooltip="Latent frames blended between adjacent chunks to hide the seam; 0 = no blend."),
|
||||||
|
io.Combo.Input("chunking_mode",
|
||||||
|
options=["manual", "auto"],
|
||||||
|
default="manual",
|
||||||
|
tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."),
|
||||||
|
],
|
||||||
|
outputs=[io.Latent.Output(display_name="latent", tooltip="The upscaled latent.")],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model, seed, steps, cfg, sampler_name, scheduler,
|
||||||
|
positive, negative, latent, denoise,
|
||||||
|
frames_per_chunk, temporal_overlap,
|
||||||
|
chunking_mode="manual") -> io.NodeOutput:
|
||||||
|
# 4n+1 validation in pixel-frame domain. The SeedVR2 native pipeline
|
||||||
|
# requires pixel-frame counts of the form 4n+1 (1, 5, 9, 13, ...),
|
||||||
|
# imposed at ``cut_videos`` upstream and propagated through the VAE's
|
||||||
|
# temporal_downsample_factor=4. Reject violations explicitly before
|
||||||
|
# any model invocation; a silent rounding would mis-align chunk
|
||||||
|
# boundaries with the 4n+1 lattice.
|
||||||
|
if frames_per_chunk < 1 or (frames_per_chunk - 1) % 4 != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"SeedVR2ProgressiveSampler: frames_per_chunk must be a "
|
||||||
|
f"4n+1 pixel-frame count (1, 5, 9, 13, 17, 21, ...); "
|
||||||
|
f"got {frames_per_chunk}."
|
||||||
|
)
|
||||||
|
|
||||||
|
samples_4d = latent["samples"]
|
||||||
|
if torch.count_nonzero(samples_4d) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
"SeedVR2ProgressiveSampler: input latent is empty (all zeros). "
|
||||||
|
"SeedVR2 is an upscaler; connect an encoded latent from "
|
||||||
|
"'Apply SeedVR2 conditioning' rather than an empty latent."
|
||||||
|
)
|
||||||
|
samples_4d = comfy.sample.fix_empty_latent_channels(
|
||||||
|
model, samples_4d,
|
||||||
|
latent.get("downscale_ratio_spacial", None),
|
||||||
|
)
|
||||||
|
if samples_4d.ndim != 4:
|
||||||
|
raise ValueError(
|
||||||
|
f"SeedVR2ProgressiveSampler: expected 4D collapsed latent "
|
||||||
|
f"(B, 16*T, H, W); got shape {tuple(samples_4d.shape)}."
|
||||||
|
)
|
||||||
|
B, CT, H, W = samples_4d.shape
|
||||||
|
if CT % SEEDVR2_LATENT_CHANNELS != 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"SeedVR2ProgressiveSampler: collapsed channel dim {CT} is "
|
||||||
|
f"not divisible by SeedVR2 latent channels "
|
||||||
|
f"{SEEDVR2_LATENT_CHANNELS}; latent does not appear to be "
|
||||||
|
f"SeedVR2-shaped."
|
||||||
|
)
|
||||||
|
T_latent = CT // SEEDVR2_LATENT_CHANNELS
|
||||||
|
T_pixel = 4 * (T_latent - 1) + 1
|
||||||
|
|
||||||
|
if chunking_mode not in ("manual", "auto"):
|
||||||
|
raise ValueError(
|
||||||
|
f"SeedVR2ProgressiveSampler: chunking_mode must be "
|
||||||
|
f"'manual' or 'auto'; got {chunking_mode!r}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if chunking_mode == "auto":
|
||||||
|
free_memory = comfy.model_management.get_free_memory(model.load_device)
|
||||||
|
seed_frames_per_chunk = _seedvr2_vram_seed_frames_per_chunk(
|
||||||
|
free_memory, T_pixel,
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"SeedVR2ProgressiveSampler auto: free=%.2fGB -> seeding "
|
||||||
|
"frames_per_chunk=%s (4n+1; T_pixel=%s).",
|
||||||
|
free_memory / (1024 ** 3), seed_frames_per_chunk, T_pixel,
|
||||||
|
)
|
||||||
|
attempts = _seedvr2_auto_chunk_attempts(
|
||||||
|
T_latent, T_pixel, seed_frames_per_chunk,
|
||||||
|
)
|
||||||
|
for i, attempt_frames_per_chunk in enumerate(attempts):
|
||||||
|
retry = False
|
||||||
|
try:
|
||||||
|
return cls.execute(
|
||||||
|
model=model, seed=seed, steps=steps, cfg=cfg,
|
||||||
|
sampler_name=sampler_name, scheduler=scheduler,
|
||||||
|
positive=positive, negative=negative,
|
||||||
|
latent=latent, denoise=denoise,
|
||||||
|
frames_per_chunk=attempt_frames_per_chunk,
|
||||||
|
temporal_overlap=temporal_overlap,
|
||||||
|
chunking_mode="manual",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
comfy.model_management.raise_non_oom(e)
|
||||||
|
if i == len(attempts) - 1:
|
||||||
|
raise RuntimeError(
|
||||||
|
"SeedVR2ProgressiveSampler: exhausted auto "
|
||||||
|
"chunking attempts after OOM. Tried "
|
||||||
|
f"frames_per_chunk values {attempts}."
|
||||||
|
) from e
|
||||||
|
retry = True
|
||||||
|
|
||||||
|
if retry:
|
||||||
|
logging.warning(
|
||||||
|
"SeedVR2ProgressiveSampler auto chunking OOM at "
|
||||||
|
"frames_per_chunk=%s; retrying with "
|
||||||
|
"frames_per_chunk=%s.",
|
||||||
|
attempt_frames_per_chunk, attempts[i + 1],
|
||||||
|
)
|
||||||
|
comfy.model_management.soft_empty_cache()
|
||||||
|
|
||||||
|
# Short-circuit: total fits in one chunk -> standard path with no
|
||||||
|
# chunking overhead. Output of this branch is byte-identical to the
|
||||||
|
# built-in KSampler given the same (model, seed, steps, cfg,
|
||||||
|
# sampler_name, scheduler, positive, negative, latent,
|
||||||
|
# denoise) tuple.
|
||||||
|
if T_pixel <= frames_per_chunk:
|
||||||
|
return io.NodeOutput(_run_standard_sample(
|
||||||
|
model, seed, steps, cfg, sampler_name, scheduler,
|
||||||
|
positive, negative, latent, denoise,
|
||||||
|
))
|
||||||
|
|
||||||
|
# Map pixel chunk -> latent chunk. Each chunk's latent length is
|
||||||
|
# at most ``chunk_latent``; the final chunk may be a runt that
|
||||||
|
# is automatically 4n+1-aligned in the pixel domain by the
|
||||||
|
# T_pixel = 4*(T_latent-1) + 1 mapping (every positive integer
|
||||||
|
# T_latent corresponds to a valid 4n+1 pixel count).
|
||||||
|
chunk_latent = (frames_per_chunk - 1) // 4 + 1
|
||||||
|
|
||||||
|
# ``temporal_overlap`` is exposed in latent-frame units, but users
|
||||||
|
# do not know the derived latent chunk length. Treat oversized
|
||||||
|
# values as "maximum valid overlap" while preserving a strictly
|
||||||
|
# positive chunk-loop stride.
|
||||||
|
if temporal_overlap < 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"SeedVR2ProgressiveSampler: temporal_overlap must be >= 0; "
|
||||||
|
f"got {temporal_overlap}."
|
||||||
|
)
|
||||||
|
temporal_overlap = min(temporal_overlap, chunk_latent - 1)
|
||||||
|
step_latent = chunk_latent - temporal_overlap
|
||||||
|
|
||||||
|
# Generate full noise once from the user seed, then slice along T
|
||||||
|
# per chunk. Using one global noise tensor (rather than re-seeding
|
||||||
|
# per chunk) preserves seed-determinism across chunk-count
|
||||||
|
# variations: the same (seed, total T_latent) always produces the
|
||||||
|
# same noise samples regardless of how the work is partitioned.
|
||||||
|
batch_inds = latent.get("batch_index", None)
|
||||||
|
noise_full = comfy.sample.prepare_noise(samples_4d, seed, batch_inds)
|
||||||
|
|
||||||
|
noise_mask = latent.get("noise_mask", None)
|
||||||
|
|
||||||
|
# Build the flat list of chunk ranges first so the chunking
|
||||||
|
# geometry is fully known before any sample call.
|
||||||
|
chunk_ranges = []
|
||||||
|
for chunk_start in range(0, T_latent, step_latent):
|
||||||
|
chunk_end = min(chunk_start + chunk_latent, T_latent)
|
||||||
|
if chunk_start >= chunk_end:
|
||||||
|
# The final iteration of a stride that lands exactly on
|
||||||
|
# T_latent produces a zero-length chunk; skip it.
|
||||||
|
break
|
||||||
|
chunk_ranges.append((chunk_start, chunk_end))
|
||||||
|
if chunk_end >= T_latent:
|
||||||
|
break
|
||||||
|
|
||||||
|
def _sample_one_chunk(chunk_start, chunk_end):
|
||||||
|
samples_chunk = _slice_collapsed_4d_along_t(
|
||||||
|
samples_4d, chunk_start, chunk_end,
|
||||||
|
SEEDVR2_LATENT_CHANNELS,
|
||||||
|
)
|
||||||
|
noise_chunk = _slice_collapsed_4d_along_t(
|
||||||
|
noise_full, chunk_start, chunk_end,
|
||||||
|
SEEDVR2_LATENT_CHANNELS,
|
||||||
|
)
|
||||||
|
positive_chunk = _slice_seedvr2_cond_along_t(
|
||||||
|
positive, chunk_start, chunk_end,
|
||||||
|
)
|
||||||
|
negative_chunk = _slice_seedvr2_cond_along_t(
|
||||||
|
negative, chunk_start, chunk_end,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Per-chunk noise_mask handling: standard masks are passed
|
||||||
|
# through for KSampler expansion; pre-expanded collapsed
|
||||||
|
# masks are sliced.
|
||||||
|
chunk_noise_mask = None
|
||||||
|
if noise_mask is not None:
|
||||||
|
chunk_noise_mask = _slice_seedvr2_noise_mask_along_t(
|
||||||
|
noise_mask, samples_4d, chunk_start, chunk_end,
|
||||||
|
)
|
||||||
|
|
||||||
|
return comfy.sample.sample(
|
||||||
|
model, noise_chunk, steps, cfg, sampler_name, scheduler,
|
||||||
|
positive_chunk, negative_chunk, samples_chunk,
|
||||||
|
denoise=denoise, noise_mask=chunk_noise_mask, seed=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk_specs = []
|
||||||
|
for chunk_start, chunk_end in chunk_ranges:
|
||||||
|
chunk_samples = _sample_one_chunk(chunk_start, chunk_end)
|
||||||
|
chunk_specs.append((chunk_start, chunk_end, chunk_samples))
|
||||||
|
|
||||||
|
final = _concat_chunks_with_overlap_blend(
|
||||||
|
chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap,
|
||||||
|
)
|
||||||
|
|
||||||
|
out = latent.copy()
|
||||||
|
out.pop("downscale_ratio_spacial", None)
|
||||||
|
out["samples"] = final
|
||||||
|
return io.NodeOutput(out)
|
||||||
|
|
||||||
|
|
||||||
|
class SeedVRExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
SeedVR2Conditioning,
|
||||||
|
SeedVR2Preprocess,
|
||||||
|
SeedVR2PostProcessing,
|
||||||
|
SeedVR2ProgressiveSampler,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> SeedVRExtension:
|
||||||
|
return SeedVRExtension()
|
||||||
1
nodes.py
1
nodes.py
@ -2430,6 +2430,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_camera_trajectory.py",
|
"nodes_camera_trajectory.py",
|
||||||
"nodes_edit_model.py",
|
"nodes_edit_model.py",
|
||||||
"nodes_tcfg.py",
|
"nodes_tcfg.py",
|
||||||
|
"nodes_seedvr.py",
|
||||||
"nodes_context_windows.py",
|
"nodes_context_windows.py",
|
||||||
"nodes_qwen.py",
|
"nodes_qwen.py",
|
||||||
"nodes_boogu.py",
|
"nodes_boogu.py",
|
||||||
|
|||||||
213
tests-unit/comfy_extras_test/test_seedvr2_conditioning.py
Normal file
213
tests-unit/comfy_extras_test/test_seedvr2_conditioning.py
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
"""Consolidated SeedVR2 conditioning and refactor regression tests.
|
||||||
|
|
||||||
|
Merges the prior test_seedvr2_refactor_nodes.py and
|
||||||
|
test_seedvr_conditioning_hardening.py modules. Refactor tests use the
|
||||||
|
top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests
|
||||||
|
use _import_nodes_seedvr_isolated() for sys.modules isolation when
|
||||||
|
mocking comfy.model_management.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
|
||||||
|
_SENTINEL = object()
|
||||||
|
_TARGETS = (
|
||||||
|
("comfy.model_management", "comfy"),
|
||||||
|
("comfy_extras.nodes_seedvr", "comfy_extras"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _import_nodes_seedvr_isolated():
|
||||||
|
"""Import comfy_extras.nodes_seedvr with comfy.model_management mocked."""
|
||||||
|
priors = []
|
||||||
|
for mod_name, parent_name in _TARGETS:
|
||||||
|
prior_mod = sys.modules.get(mod_name, _SENTINEL)
|
||||||
|
parent = sys.modules.get(parent_name)
|
||||||
|
attr = mod_name.split(".")[-1]
|
||||||
|
prior_attr = (
|
||||||
|
getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL
|
||||||
|
)
|
||||||
|
priors.append((mod_name, parent_name, attr, prior_mod, prior_attr))
|
||||||
|
|
||||||
|
mock_mm = MagicMock()
|
||||||
|
for fn in (
|
||||||
|
"xformers_enabled", "xformers_enabled_vae",
|
||||||
|
"pytorch_attention_enabled", "pytorch_attention_enabled_vae",
|
||||||
|
"sage_attention_enabled", "flash_attention_enabled",
|
||||||
|
"is_intel_xpu",
|
||||||
|
):
|
||||||
|
getattr(mock_mm, fn).return_value = False
|
||||||
|
tv = torch.version.__version__.split(".")
|
||||||
|
mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1]))
|
||||||
|
mock_mm.WINDOWS = False
|
||||||
|
sys.modules["comfy.model_management"] = mock_mm
|
||||||
|
if sys.modules.get("comfy") is None:
|
||||||
|
import comfy as _comfy_pkg # noqa: F401
|
||||||
|
comfy_pkg = sys.modules.get("comfy")
|
||||||
|
if comfy_pkg is not None:
|
||||||
|
setattr(comfy_pkg, "model_management", mock_mm)
|
||||||
|
nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or (
|
||||||
|
importlib.import_module("comfy_extras.nodes_seedvr")
|
||||||
|
)
|
||||||
|
|
||||||
|
def _restore():
|
||||||
|
for mod_name, parent_name, attr, prior_mod, prior_attr in priors:
|
||||||
|
if prior_mod is _SENTINEL:
|
||||||
|
sys.modules.pop(mod_name, None)
|
||||||
|
else:
|
||||||
|
sys.modules[mod_name] = prior_mod
|
||||||
|
parent = sys.modules.get(parent_name)
|
||||||
|
if parent is None:
|
||||||
|
continue
|
||||||
|
if prior_attr is _SENTINEL:
|
||||||
|
if hasattr(parent, attr):
|
||||||
|
delattr(parent, attr)
|
||||||
|
else:
|
||||||
|
setattr(parent, attr, prior_attr)
|
||||||
|
|
||||||
|
return nodes_seedvr, _restore
|
||||||
|
|
||||||
|
|
||||||
|
class _Rope(nn.Module):
|
||||||
|
"""Minimal RoPE stub exposing a `freqs` parameter."""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.freqs = nn.Parameter(torch.zeros(4))
|
||||||
|
|
||||||
|
|
||||||
|
class _Block(nn.Module):
|
||||||
|
"""Minimal transformer block stub holding a `_Rope`."""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.rope = _Rope()
|
||||||
|
|
||||||
|
|
||||||
|
class _DiffusionModel(nn.Module):
|
||||||
|
"""Stub diffusion model with N blocks and pos/neg conditioning buffers."""
|
||||||
|
def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32):
|
||||||
|
super().__init__()
|
||||||
|
self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)])
|
||||||
|
pos = torch.zeros if zero_conditioning else torch.ones
|
||||||
|
self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype))
|
||||||
|
self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype))
|
||||||
|
|
||||||
|
|
||||||
|
class _ModelInner:
|
||||||
|
"""Inner model wrapper exposing `.diffusion_model`."""
|
||||||
|
def __init__(self, diffusion_model):
|
||||||
|
self.diffusion_model = diffusion_model
|
||||||
|
|
||||||
|
|
||||||
|
class _ModelPatcher:
|
||||||
|
"""ModelPatcher stub exposing `.model._ModelInner`."""
|
||||||
|
def __init__(self, diffusion_model):
|
||||||
|
self.model = _ModelInner(diffusion_model)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_conditioning_schema_exposes_model_passthrough_output():
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
schema = nodes_seedvr.SeedVR2Conditioning.define_schema()
|
||||||
|
assert [input_item.id for input_item in schema.inputs] == [
|
||||||
|
"model",
|
||||||
|
"vae_conditioning",
|
||||||
|
]
|
||||||
|
assert schema.inputs[1].display_name == "latent"
|
||||||
|
assert [output.display_name for output in schema.outputs] == [
|
||||||
|
"model",
|
||||||
|
"positive",
|
||||||
|
"negative",
|
||||||
|
"latent",
|
||||||
|
]
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_conditioning_returns_packed_input_latent_deterministically():
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
diffusion_model = _DiffusionModel()
|
||||||
|
patcher = _ModelPatcher(diffusion_model)
|
||||||
|
samples = torch.arange(1, 25, dtype=torch.float32).reshape(1, 2, 3, 2, 2)
|
||||||
|
vae_conditioning = {"samples": samples}
|
||||||
|
|
||||||
|
_, first_positive, first_negative, first_latent = (
|
||||||
|
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||||
|
patcher,
|
||||||
|
vae_conditioning,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_, second_positive, second_negative, second_latent = (
|
||||||
|
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||||
|
patcher,
|
||||||
|
vae_conditioning,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_latent = samples.reshape(1, 6, 2, 2)
|
||||||
|
channel_last = samples.movedim(1, -1).contiguous()
|
||||||
|
expected_condition = torch.cat(
|
||||||
|
[
|
||||||
|
channel_last,
|
||||||
|
torch.ones((*channel_last.shape[:-1], 1)),
|
||||||
|
],
|
||||||
|
dim=-1,
|
||||||
|
).movedim(-1, 1).reshape(1, 9, 2, 2)
|
||||||
|
|
||||||
|
assert torch.equal(first_latent["samples"], expected_latent)
|
||||||
|
assert torch.equal(second_latent["samples"], expected_latent)
|
||||||
|
assert torch.equal(
|
||||||
|
first_positive[0][1]["condition"],
|
||||||
|
expected_condition,
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
second_positive[0][1]["condition"],
|
||||||
|
expected_condition,
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
first_negative[0][1]["condition"],
|
||||||
|
expected_condition,
|
||||||
|
)
|
||||||
|
assert torch.equal(
|
||||||
|
second_negative[0][1]["condition"],
|
||||||
|
expected_condition,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_conditioning_fails_loud_on_zero_buffers():
|
||||||
|
nodes_seedvr, restore = _import_nodes_seedvr_isolated()
|
||||||
|
try:
|
||||||
|
diffusion_model = _DiffusionModel(zero_conditioning=True)
|
||||||
|
patcher = _ModelPatcher(diffusion_model)
|
||||||
|
vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))}
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
nodes_seedvr.SeedVR2Conditioning.execute(
|
||||||
|
patcher, vae_conditioning,
|
||||||
|
)
|
||||||
|
|
||||||
|
message = str(excinfo.value)
|
||||||
|
assert message.startswith(
|
||||||
|
nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX
|
||||||
|
), (
|
||||||
|
"Fail-loud message must use the standard "
|
||||||
|
"_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers "
|
||||||
|
f"can match it. Got: {message!r}"
|
||||||
|
)
|
||||||
|
assert "positive_conditioning" in message
|
||||||
|
assert "negative_conditioning" in message
|
||||||
|
finally:
|
||||||
|
restore()
|
||||||
55
tests-unit/comfy_extras_test/test_seedvr2_nodes.py
Normal file
55
tests-unit/comfy_extras_test/test_seedvr2_nodes.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import sys
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr_node_signature_matches_schema():
|
||||||
|
mock_mm = MagicMock()
|
||||||
|
mock_mm.xformers_enabled.return_value = False
|
||||||
|
mock_mm.xformers_enabled_vae.return_value = False
|
||||||
|
mock_mm.sage_attention_enabled.return_value = False
|
||||||
|
mock_mm.flash_attention_enabled.return_value = False
|
||||||
|
|
||||||
|
sentinel = object()
|
||||||
|
prior_cpu = cli_args.cpu
|
||||||
|
cli_args.cpu = True
|
||||||
|
prior_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel)
|
||||||
|
comfy_pkg = sys.modules.get("comfy")
|
||||||
|
prior_mm_attr = getattr(comfy_pkg, "model_management", sentinel) if comfy_pkg else sentinel
|
||||||
|
|
||||||
|
with patch.dict(sys.modules, {"comfy.model_management": mock_mm}):
|
||||||
|
if comfy_pkg is not None:
|
||||||
|
setattr(comfy_pkg, "model_management", mock_mm)
|
||||||
|
sys.modules.pop("comfy_extras.nodes_seedvr", None)
|
||||||
|
try:
|
||||||
|
nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr")
|
||||||
|
for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler):
|
||||||
|
schema_ids = [i.id for i in node_cls.define_schema().inputs]
|
||||||
|
exec_params = [
|
||||||
|
p for p in inspect.signature(node_cls.execute).parameters.keys()
|
||||||
|
if p != "cls"
|
||||||
|
]
|
||||||
|
assert schema_ids == exec_params, (
|
||||||
|
f"{node_cls.__name__} schema/execute drift: "
|
||||||
|
f"schema_ids={schema_ids}, exec_params={exec_params}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
cli_args.cpu = prior_cpu
|
||||||
|
if prior_module is sentinel:
|
||||||
|
sys.modules.pop("comfy_extras.nodes_seedvr", None)
|
||||||
|
else:
|
||||||
|
sys.modules["comfy_extras.nodes_seedvr"] = prior_module
|
||||||
|
if comfy_pkg is not None:
|
||||||
|
if prior_mm_attr is sentinel:
|
||||||
|
if hasattr(comfy_pkg, "model_management"):
|
||||||
|
delattr(comfy_pkg, "model_management")
|
||||||
|
else:
|
||||||
|
setattr(comfy_pkg, "model_management", prior_mm_attr)
|
||||||
57
tests-unit/comfy_extras_test/test_seedvr2_post_processing.py
Normal file
57
tests-unit/comfy_extras_test/test_seedvr2_post_processing.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
from comfy_extras import nodes_seedvr # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def _schema_ids(items):
|
||||||
|
return [item.id for item in items]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_schema():
|
||||||
|
schema = nodes_seedvr.SeedVR2PostProcessing.define_schema()
|
||||||
|
|
||||||
|
assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"]
|
||||||
|
assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"]
|
||||||
|
assert schema.inputs[2].default == "lab"
|
||||||
|
assert schema.outputs[0].get_io_type() == "IMAGE"
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch):
|
||||||
|
decoded = torch.full((1, 3, 4, 4), 0.25)
|
||||||
|
reference = torch.full((1, 3, 4, 4), 0.75)
|
||||||
|
|
||||||
|
def _lab(content, style):
|
||||||
|
raise torch.cuda.OutOfMemoryError("CUDA out of memory")
|
||||||
|
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu"))
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000)
|
||||||
|
monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None)
|
||||||
|
|
||||||
|
with patch.object(nodes_seedvr, "lab_color_transfer", _lab):
|
||||||
|
try:
|
||||||
|
nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked(
|
||||||
|
decoded, reference, torch.device("cpu"), "lab",
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
assert "color_correction_method=lab" in str(exc)
|
||||||
|
assert " method=lab" not in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError("expected RuntimeError for one-frame LAB OOM")
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_post_processing_unknown_color_correction_method_raises():
|
||||||
|
decoded = torch.zeros(1, 2, 4, 4, 3)
|
||||||
|
original = torch.zeros(1, 2, 4, 4, 3)
|
||||||
|
try:
|
||||||
|
nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus")
|
||||||
|
except ValueError as exc:
|
||||||
|
assert "color_correction_method" in str(exc)
|
||||||
|
else:
|
||||||
|
raise AssertionError("expected ValueError for unknown color_correction_method")
|
||||||
@ -73,6 +73,24 @@ def _make_flux_schnell_comfyui_sd():
|
|||||||
return sd
|
return sd
|
||||||
|
|
||||||
|
|
||||||
|
def _make_seedvr2_7b_separate_mm_sd():
|
||||||
|
return {
|
||||||
|
"blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_seedvr2_7b_shared_mm_sd():
|
||||||
|
return {
|
||||||
|
"blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_seedvr2_3b_shared_mm_sd():
|
||||||
|
return {
|
||||||
|
"blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class TestModelDetection:
|
class TestModelDetection:
|
||||||
"""Verify that first-match model detection selects the correct model
|
"""Verify that first-match model detection selects the correct model
|
||||||
based on list ordering and unet_config specificity."""
|
based on list ordering and unet_config specificity."""
|
||||||
@ -125,6 +143,45 @@ class TestModelDetection:
|
|||||||
assert model_config is not None
|
assert model_config is not None
|
||||||
assert type(model_config).__name__ == "FluxSchnell"
|
assert type(model_config).__name__ == "FluxSchnell"
|
||||||
|
|
||||||
|
def test_seedvr2_7b_separate_mm_detection_config(self):
|
||||||
|
sd = _make_seedvr2_7b_separate_mm_sd()
|
||||||
|
unet_config = detect_unet_config(sd, "")
|
||||||
|
|
||||||
|
assert unet_config is not None
|
||||||
|
assert unet_config["image_model"] == "seedvr2"
|
||||||
|
assert unet_config["vid_dim"] == 3072
|
||||||
|
assert unet_config["heads"] == 24
|
||||||
|
assert unet_config["num_layers"] == 36
|
||||||
|
assert unet_config["mm_layers"] == 36
|
||||||
|
assert unet_config["mlp_type"] == "normal"
|
||||||
|
assert unet_config["rope_type"] == "rope3d"
|
||||||
|
assert unet_config["rope_dim"] == 64
|
||||||
|
|
||||||
|
def test_seedvr2_7b_shared_mm_detection_config(self):
|
||||||
|
sd = _make_seedvr2_7b_shared_mm_sd()
|
||||||
|
unet_config = detect_unet_config(sd, "")
|
||||||
|
|
||||||
|
assert unet_config is not None
|
||||||
|
assert unet_config["image_model"] == "seedvr2"
|
||||||
|
assert unet_config["vid_dim"] == 3072
|
||||||
|
assert unet_config["heads"] == 24
|
||||||
|
assert unet_config["num_layers"] == 36
|
||||||
|
assert unet_config["mm_layers"] == 10
|
||||||
|
assert unet_config["mlp_type"] == "swiglu"
|
||||||
|
assert unet_config["rope_type"] == "rope3d"
|
||||||
|
assert unet_config["rope_dim"] == 64
|
||||||
|
|
||||||
|
def test_seedvr2_3b_shared_mm_detection_config(self):
|
||||||
|
sd = _make_seedvr2_3b_shared_mm_sd()
|
||||||
|
unet_config = detect_unet_config(sd, "")
|
||||||
|
|
||||||
|
assert unet_config is not None
|
||||||
|
assert unet_config["image_model"] == "seedvr2"
|
||||||
|
assert unet_config["vid_dim"] == 2560
|
||||||
|
assert unet_config["heads"] == 20
|
||||||
|
assert unet_config["num_layers"] == 32
|
||||||
|
assert unet_config["mlp_type"] == "swiglu"
|
||||||
|
|
||||||
def test_unet_config_and_required_keys_combination_is_unique(self):
|
def test_unet_config_and_required_keys_combination_is_unique(self):
|
||||||
"""Each model in the registry must have a unique combination of
|
"""Each model in the registry must have a unique combination of
|
||||||
``unet_config`` and ``required_keys``. If two models share the same
|
``unet_config`` and ``required_keys``. If two models share the same
|
||||||
|
|||||||
86
tests-unit/comfy_test/seedvr_vae_forward_test.py
Normal file
86
tests-unit/comfy_test/seedvr_vae_forward_test.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must
|
||||||
|
honor the actual tensor/tuple return contract of ``encode()`` and
|
||||||
|
``decode_()`` and must NOT dereference diffusers-style ``.latent_dist``
|
||||||
|
or ``.sample`` attributes on those returns.
|
||||||
|
|
||||||
|
The pre-fix body raised ``AttributeError: 'Tensor' object has no
|
||||||
|
attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and
|
||||||
|
``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'``
|
||||||
|
for ``mode == "decode"`` (the class only defines ``decode_`` with a
|
||||||
|
trailing underscore). The post-fix body unwraps the optional one-element
|
||||||
|
tuple shape that ``return_dict=False`` produces and returns the tensor
|
||||||
|
directly.
|
||||||
|
|
||||||
|
Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses
|
||||||
|
the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and
|
||||||
|
overrides ``encode``/``decode_`` with known tensors so the contract can
|
||||||
|
be probed without loading any real VAE weights.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
_LATENT_SHAPE = (1, 16, 2, 2, 2)
|
||||||
|
_DECODED_SHAPE = (1, 3, 5, 16, 16)
|
||||||
|
_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16)
|
||||||
|
_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2)
|
||||||
|
|
||||||
|
|
||||||
|
class _StubVAE(VideoAutoencoderKL):
|
||||||
|
def __init__(self):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self._encode_out = torch.zeros(*_LATENT_SHAPE)
|
||||||
|
self._decode_out = torch.zeros(*_DECODED_SHAPE)
|
||||||
|
|
||||||
|
def encode(self, x, return_dict=True):
|
||||||
|
return self._encode_out
|
||||||
|
|
||||||
|
def decode_(self, z, return_dict=True):
|
||||||
|
return self._decode_out
|
||||||
|
|
||||||
|
|
||||||
|
def test_forward_encode_returns_tensor():
|
||||||
|
vae = _StubVAE()
|
||||||
|
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
|
||||||
|
result = vae.forward(x, mode="encode")
|
||||||
|
assert type(result) is torch.Tensor
|
||||||
|
assert result.shape == torch.Size(_LATENT_SHAPE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_forward_decode_returns_tensor():
|
||||||
|
vae = _StubVAE()
|
||||||
|
z = torch.zeros(*_INPUT_DECODE_SHAPE)
|
||||||
|
result = vae.forward(z, mode="decode")
|
||||||
|
assert type(result) is torch.Tensor
|
||||||
|
assert result.shape == torch.Size(_DECODED_SHAPE)
|
||||||
|
|
||||||
|
|
||||||
|
class _TupleReturningStubVAE(VideoAutoencoderKL):
|
||||||
|
"""Stub whose ``encode``/``decode_`` return the ``(tensor,)`` tuple of ``return_dict=False``, exercising the unwrap branch of ``VideoAutoencoderKL.forward``."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self._encode_tensor = torch.zeros(*_LATENT_SHAPE)
|
||||||
|
self._decode_tensor = torch.zeros(*_DECODED_SHAPE)
|
||||||
|
|
||||||
|
def encode(self, x, return_dict=True):
|
||||||
|
return (self._encode_tensor,)
|
||||||
|
|
||||||
|
def decode_(self, z, return_dict=True):
|
||||||
|
return (self._decode_tensor,)
|
||||||
|
|
||||||
|
|
||||||
|
def test_forward_all_unwraps_one_tuple_at_each_step():
|
||||||
|
vae = _TupleReturningStubVAE()
|
||||||
|
x = torch.zeros(*_INPUT_ENCODE_SHAPE)
|
||||||
|
result = vae.forward(x, mode="all")
|
||||||
|
assert type(result) is torch.Tensor
|
||||||
|
assert result.shape == torch.Size(_DECODED_SHAPE)
|
||||||
49
tests-unit/comfy_test/test_seedvr2_dtype.py
Normal file
49
tests-unit/comfy_test/test_seedvr2_dtype.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
import comfy.sd
|
||||||
|
import comfy.supported_models
|
||||||
|
import comfy.ldm.seedvr.model as seedvr_model
|
||||||
|
import comfy.ldm.seedvr.vae as seedvr_vae
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch):
|
||||||
|
bf16_device = object()
|
||||||
|
fp16_device = object()
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
comfy.supported_models.comfy.model_management,
|
||||||
|
"should_use_bf16",
|
||||||
|
lambda device=None: device is bf16_device,
|
||||||
|
)
|
||||||
|
|
||||||
|
bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
|
||||||
|
bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device)
|
||||||
|
assert bf16_config.manual_cast_dtype is torch.bfloat16
|
||||||
|
|
||||||
|
fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"})
|
||||||
|
fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device)
|
||||||
|
assert fp16_config.manual_cast_dtype is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_text_conditioning_accepts_cfg1_single_branch():
|
||||||
|
context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2)
|
||||||
|
|
||||||
|
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0])
|
||||||
|
|
||||||
|
torch.testing.assert_close(txt, context.squeeze(0))
|
||||||
|
torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device))
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer():
|
||||||
|
wrapper = seedvr_vae.VideoAutoencoderKLWrapper.__new__(seedvr_vae.VideoAutoencoderKLWrapper)
|
||||||
|
estimate = wrapper.comfy_memory_used_decode((1, 16, 26, 120, 160))
|
||||||
|
old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2
|
||||||
|
|
||||||
|
assert estimate == 101 * 960 * 1280 * 160
|
||||||
|
assert estimate > 15 * 1024 ** 3
|
||||||
|
assert estimate > old_estimate * 100
|
||||||
216
tests-unit/comfy_test/test_seedvr2_internals.py
Normal file
216
tests-unit/comfy_test/test_seedvr2_internals.py
Normal file
@ -0,0 +1,216 @@
|
|||||||
|
"""Consolidated SeedVR2 internals regression tests.
|
||||||
|
|
||||||
|
Sources (all merged verbatim, helper names disambiguated where colliding):
|
||||||
|
|
||||||
|
* GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare
|
||||||
|
memory_occupy against get_norm_limit(), not float('inf').
|
||||||
|
* SeedVR2 variable-length attention split-loop contract.
|
||||||
|
|
||||||
|
Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and
|
||||||
|
comfy.ldm.modules.attention transitively pull in comfy.model_management,
|
||||||
|
which probes torch.cuda.current_device() at import time unless args.cpu is
|
||||||
|
set first.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
args.cpu = True
|
||||||
|
|
||||||
|
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
||||||
|
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||||
|
import comfy.ldm.modules.attention as attention # noqa: E402
|
||||||
|
import comfy.ops as comfy_ops # noqa: E402
|
||||||
|
from comfy.ldm.seedvr.vae import ( # noqa: E402
|
||||||
|
causal_norm_wrapper,
|
||||||
|
set_norm_limit,
|
||||||
|
)
|
||||||
|
from comfy.ldm.seedvr.attention import var_attention_optimized_split # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GroupNorm limit tests (test_seedvr_groupnorm_limit.py)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_NUM_CHANNELS = 8
|
||||||
|
_NUM_GROUPS = 4
|
||||||
|
_TENSOR_SHAPE = (1, 8, 2, 4, 4)
|
||||||
|
|
||||||
|
_GROUPNORM_SUBCLASSES = [
|
||||||
|
pytest.param(comfy_ops.disable_weight_init.GroupNorm, id="disable_weight_init"),
|
||||||
|
pytest.param(comfy_ops.manual_cast.GroupNorm, id="manual_cast"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES)
|
||||||
|
def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls):
|
||||||
|
real_group_norm = vae_mod.F.group_norm
|
||||||
|
set_norm_limit(1e-9)
|
||||||
|
try:
|
||||||
|
gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS)
|
||||||
|
gn.eval()
|
||||||
|
|
||||||
|
forward_hook_calls = []
|
||||||
|
|
||||||
|
def _hook(module, inputs, output):
|
||||||
|
forward_hook_calls.append(tuple(inputs[0].shape))
|
||||||
|
|
||||||
|
spy_calls = []
|
||||||
|
|
||||||
|
def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs):
|
||||||
|
spy_calls.append({"num_groups": int(num_groups_arg)})
|
||||||
|
return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs)
|
||||||
|
|
||||||
|
handle = gn.register_forward_hook(_hook)
|
||||||
|
try:
|
||||||
|
with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy):
|
||||||
|
out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE))
|
||||||
|
finally:
|
||||||
|
handle.remove()
|
||||||
|
|
||||||
|
full_calls = len(forward_hook_calls)
|
||||||
|
chunked_calls = sum(1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS)
|
||||||
|
|
||||||
|
assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE
|
||||||
|
assert full_calls == 0, (
|
||||||
|
f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}"
|
||||||
|
)
|
||||||
|
assert chunked_calls > 0, (
|
||||||
|
f"low-limit GroupNorm gate must take the chunked path; got chunked_calls={chunked_calls}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
set_norm_limit(None)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SeedVR2 var_attention split-loop tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch):
|
||||||
|
dim = 8
|
||||||
|
heads = 2
|
||||||
|
head_dim = 4
|
||||||
|
attn = seedvr_model.NaSwinAttention(
|
||||||
|
vid_dim=dim,
|
||||||
|
txt_dim=dim,
|
||||||
|
heads=heads,
|
||||||
|
head_dim=head_dim,
|
||||||
|
qk_bias=False,
|
||||||
|
qk_norm=seedvr_model.CustomRMSNorm,
|
||||||
|
qk_norm_eps=1e-6,
|
||||||
|
rope_type=None,
|
||||||
|
rope_dim=head_dim,
|
||||||
|
shared_weights=False,
|
||||||
|
window=(2, 1, 1),
|
||||||
|
window_method="720pwin_by_size_bysize",
|
||||||
|
version=True,
|
||||||
|
device="cpu",
|
||||||
|
dtype=torch.float32,
|
||||||
|
operations=comfy_ops.disable_weight_init,
|
||||||
|
)
|
||||||
|
generator = torch.Generator(device="cpu").manual_seed(11)
|
||||||
|
vid = torch.randn(8, dim, generator=generator)
|
||||||
|
txt = torch.randn(3, dim, generator=generator)
|
||||||
|
vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long)
|
||||||
|
txt_shape = torch.tensor([[3]], dtype=torch.long)
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def fake_optimized_var_attention(**kwargs):
|
||||||
|
calls.append(kwargs)
|
||||||
|
return kwargs["q"]
|
||||||
|
|
||||||
|
monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention)
|
||||||
|
|
||||||
|
vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True))
|
||||||
|
|
||||||
|
assert tuple(vid_out.shape) == (8, dim)
|
||||||
|
assert tuple(txt_out.shape) == (3, dim)
|
||||||
|
assert len(calls) == 1
|
||||||
|
call = calls[0]
|
||||||
|
assert tuple(call["q"].shape) == (14, heads, head_dim)
|
||||||
|
assert tuple(call["k"].shape) == (14, heads, head_dim)
|
||||||
|
assert tuple(call["v"].shape) == (14, heads, head_dim)
|
||||||
|
assert call["heads"] == heads
|
||||||
|
assert call["skip_reshape"] is True
|
||||||
|
assert call["skip_output_reshape"] is True
|
||||||
|
torch.testing.assert_close(
|
||||||
|
call["cu_seqlens_q"],
|
||||||
|
torch.tensor([0, 7, 14], dtype=torch.int32),
|
||||||
|
rtol=0,
|
||||||
|
atol=0,
|
||||||
|
)
|
||||||
|
torch.testing.assert_close(
|
||||||
|
call["cu_seqlens_k"],
|
||||||
|
torch.tensor([0, 7, 14], dtype=torch.int32),
|
||||||
|
rtol=0,
|
||||||
|
atol=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch):
|
||||||
|
heads = 2
|
||||||
|
head_dim = 3
|
||||||
|
q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim)
|
||||||
|
k = q + 100
|
||||||
|
v = q + 200
|
||||||
|
cu = torch.tensor([0, 2, 5], dtype=torch.int32)
|
||||||
|
calls = []
|
||||||
|
|
||||||
|
def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs):
|
||||||
|
calls.append(
|
||||||
|
{
|
||||||
|
"q_shape": tuple(q_arg.shape),
|
||||||
|
"k_shape": tuple(k_arg.shape),
|
||||||
|
"v_shape": tuple(v_arg.shape),
|
||||||
|
"heads": heads_arg,
|
||||||
|
"kwargs": kwargs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return q_arg + v_arg
|
||||||
|
|
||||||
|
monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention)
|
||||||
|
|
||||||
|
out = var_attention_optimized_split(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
heads,
|
||||||
|
cu,
|
||||||
|
cu,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tuple(out.shape) == (5, heads, head_dim)
|
||||||
|
assert len(calls) == 2
|
||||||
|
assert calls[0]["q_shape"] == (1, heads, 2, head_dim)
|
||||||
|
assert calls[1]["q_shape"] == (1, heads, 3, head_dim)
|
||||||
|
assert all(call["heads"] == heads for call in calls)
|
||||||
|
assert all(call["kwargs"]["skip_reshape"] is True for call in calls)
|
||||||
|
assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls)
|
||||||
|
torch.testing.assert_close(out, q + v, rtol=0, atol=0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_var_attention_optimized_split_rejects_bad_offsets():
|
||||||
|
q = torch.randn(5, 2, 3)
|
||||||
|
cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32)
|
||||||
|
cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"):
|
||||||
|
var_attention_optimized_split(
|
||||||
|
q,
|
||||||
|
q,
|
||||||
|
q,
|
||||||
|
2,
|
||||||
|
cu_bad,
|
||||||
|
cu_ok,
|
||||||
|
skip_reshape=True,
|
||||||
|
skip_output_reshape=True,
|
||||||
|
)
|
||||||
307
tests-unit/comfy_test/test_seedvr2_model.py
Normal file
307
tests-unit/comfy_test/test_seedvr2_model.py
Normal file
@ -0,0 +1,307 @@
|
|||||||
|
"""Consolidated SeedVR2 model/graph/forward regression tests.
|
||||||
|
|
||||||
|
Merged from:
|
||||||
|
- seedvr_model_test.py
|
||||||
|
- test_seedvr_7b_final_block_text_path.py
|
||||||
|
- test_seedvr_forward_no_device_cast.py
|
||||||
|
- test_seedvr_latent_format.py
|
||||||
|
- test_seedvr2_vae_graph_boundaries.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
args.cpu = True
|
||||||
|
|
||||||
|
import comfy # noqa: E402
|
||||||
|
import comfy.latent_formats # noqa: E402
|
||||||
|
import comfy.ldm.seedvr.model # noqa: E402
|
||||||
|
import comfy.ldm.seedvr.model as seedvr_model # noqa: E402
|
||||||
|
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
||||||
|
import comfy.model_management # noqa: E402
|
||||||
|
import comfy.sample # noqa: E402
|
||||||
|
import comfy.sd as sd_mod # noqa: E402
|
||||||
|
import nodes as nodes_mod # noqa: E402
|
||||||
|
from comfy.ldm.seedvr.model import NaDiT # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers from seedvr_model_test.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_standin(positive_conditioning):
|
||||||
|
class _StandIn(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.register_buffer(
|
||||||
|
"positive_conditioning", positive_conditioning
|
||||||
|
)
|
||||||
|
|
||||||
|
_resolve_text_conditioning = NaDiT._resolve_text_conditioning
|
||||||
|
|
||||||
|
return _StandIn()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers from test_seedvr_7b_final_block_text_path.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _StubModule(nn.Module):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]:
|
||||||
|
flags = []
|
||||||
|
|
||||||
|
class _Block(_StubModule):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
flags.append(kwargs["is_last_layer"])
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule)
|
||||||
|
monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule)
|
||||||
|
monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule)
|
||||||
|
monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block)
|
||||||
|
|
||||||
|
seedvr_model.NaDiT(
|
||||||
|
norm_eps=1e-5,
|
||||||
|
num_layers=4,
|
||||||
|
mlp_type="normal",
|
||||||
|
vid_dim=vid_dim,
|
||||||
|
txt_in_dim=txt_in_dim,
|
||||||
|
heads=24,
|
||||||
|
mm_layers=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
return flags
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers from test_seedvr_latent_format.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _Model:
|
||||||
|
def __init__(self, latent_format):
|
||||||
|
self._latent_format = latent_format
|
||||||
|
|
||||||
|
def get_model_object(self, name):
|
||||||
|
assert name == "latent_format"
|
||||||
|
return self._latent_format
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers from test_seedvr2_vae_graph_boundaries.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _Patcher:
|
||||||
|
def get_free_memory(self, device):
|
||||||
|
return 1024 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
|
class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
|
||||||
|
def __init__(self, encoded):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.encoded = encoded
|
||||||
|
self.spatial_downsample_factor = 8
|
||||||
|
self.temporal_downsample_factor = 4
|
||||||
|
self.seen = []
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
self.seen.append(tuple(x.shape))
|
||||||
|
return self.encoded.to(device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper):
|
||||||
|
def __init__(self):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.spatial_downsample_factor = 8
|
||||||
|
self.temporal_downsample_factor = 4
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
def decode(self, z, seedvr2_tiling=None):
|
||||||
|
self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling})
|
||||||
|
if z.ndim == 4:
|
||||||
|
b, tc, h, w = z.shape
|
||||||
|
t = tc // 16
|
||||||
|
else:
|
||||||
|
b, _, t, h, w = z.shape
|
||||||
|
return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_vae(wrapper):
|
||||||
|
vae = sd_mod.VAE.__new__(sd_mod.VAE)
|
||||||
|
vae.first_stage_model = wrapper
|
||||||
|
vae.device = torch.device("cpu")
|
||||||
|
vae.output_device = torch.device("cpu")
|
||||||
|
vae.vae_dtype = torch.float32
|
||||||
|
vae.latent_channels = 16
|
||||||
|
vae.latent_dim = 3
|
||||||
|
vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8)
|
||||||
|
vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||||
|
vae.output_channels = 3
|
||||||
|
vae.disable_offload = True
|
||||||
|
vae.extra_1d_channel = None
|
||||||
|
vae.crop_input = False
|
||||||
|
vae.not_video = False
|
||||||
|
vae.patcher = _Patcher()
|
||||||
|
vae.process_input = lambda image: image
|
||||||
|
vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0)
|
||||||
|
vae.vae_output_dtype = lambda: torch.float32
|
||||||
|
vae.memory_used_encode = lambda shape, dtype: 1
|
||||||
|
vae.memory_used_decode = lambda shape, dtype: 1
|
||||||
|
vae.throw_exception_if_invalid = lambda: None
|
||||||
|
vae.vae_encode_crop_pixels = lambda pixels: pixels
|
||||||
|
vae.spacial_compression_decode = lambda: 8
|
||||||
|
vae.temporal_compression_decode = lambda: 4
|
||||||
|
return vae
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests from seedvr_model_test.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_context_falls_back_to_positive_buffer():
|
||||||
|
"""``context is None`` falls back to the registered ``positive_conditioning`` buffer and runs to completion."""
|
||||||
|
pos_buffer = torch.full((58, 5120), 7.0)
|
||||||
|
standin = _make_standin(pos_buffer)
|
||||||
|
txt, txt_shape = standin._resolve_text_conditioning(None)
|
||||||
|
assert txt.shape == (58, 5120)
|
||||||
|
assert (txt == 7.0).all(), (
|
||||||
|
"fallback path must use the positive_conditioning buffer "
|
||||||
|
"verbatim, not a zero tensor"
|
||||||
|
)
|
||||||
|
assert txt_shape.shape == (1, 1)
|
||||||
|
assert txt_shape[0, 0].item() == 58
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests from test_seedvr_7b_final_block_text_path.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch):
|
||||||
|
assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
False,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_7b_rope3d_matches_wrapper_oracle():
|
||||||
|
rope = seedvr_model.get_na_rope("rope3d", dim=64)
|
||||||
|
generator = torch.Generator(device="cpu").manual_seed(0)
|
||||||
|
q = torch.randn(4, 2, 128, generator=generator)
|
||||||
|
k = torch.randn(4, 2, 128, generator=generator)
|
||||||
|
shape = torch.tensor([[1, 2, 2]], dtype=torch.long)
|
||||||
|
freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1)
|
||||||
|
|
||||||
|
expected_q = seedvr_model._apply_seedvr2_rotary_emb(
|
||||||
|
freqs,
|
||||||
|
q.permute(1, 0, 2).float(),
|
||||||
|
).to(q.dtype).permute(1, 0, 2)
|
||||||
|
expected_k = seedvr_model._apply_seedvr2_rotary_emb(
|
||||||
|
freqs,
|
||||||
|
k.permute(1, 0, 2).float(),
|
||||||
|
).to(k.dtype).permute(1, 0, 2)
|
||||||
|
|
||||||
|
actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True))
|
||||||
|
|
||||||
|
torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0)
|
||||||
|
torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests from test_seedvr_latent_format.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion():
|
||||||
|
latent_format = comfy.latent_formats.SeedVR2()
|
||||||
|
latent_image = torch.zeros(1, 1, 4, 5)
|
||||||
|
|
||||||
|
fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image)
|
||||||
|
|
||||||
|
assert latent_format.latent_channels == 16
|
||||||
|
assert latent_format.latent_dimensions == 2
|
||||||
|
assert fixed.shape == (1, 16, 4, 5)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests from test_seedvr2_vae_graph_boundaries.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch):
|
||||||
|
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||||
|
|
||||||
|
encoded = torch.full((1, 16, 2, 4, 5), 2.0)
|
||||||
|
vae = _make_vae(_EncodeWrapper(encoded))
|
||||||
|
pixels = torch.zeros(1, 5, 32, 40, 3)
|
||||||
|
|
||||||
|
node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0]
|
||||||
|
node_latent = node_output["samples"]
|
||||||
|
assert set(node_output) == {"samples"}
|
||||||
|
assert tuple(node_latent.shape) == (1, 16, 2, 4, 5)
|
||||||
|
assert node_latent.dtype == torch.float32
|
||||||
|
assert node_latent.stride()[-1] == 1
|
||||||
|
assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152))
|
||||||
|
|
||||||
|
tiled = torch.full((1, 16, 2, 4, 5), 3.0)
|
||||||
|
monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled))
|
||||||
|
tiled_output = nodes_mod.VAEEncodeTiled().encode(
|
||||||
|
vae,
|
||||||
|
pixels,
|
||||||
|
tile_size=512,
|
||||||
|
overlap=64,
|
||||||
|
temporal_size=16,
|
||||||
|
temporal_overlap=4,
|
||||||
|
)[0]
|
||||||
|
tiled_latent = tiled_output["samples"]
|
||||||
|
assert set(tiled_output) == {"samples"}
|
||||||
|
assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5)
|
||||||
|
assert tiled_latent.dtype == torch.float32
|
||||||
|
assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152))
|
||||||
|
|
||||||
|
|
||||||
|
def test_vaedecode_tiled_spatial_applies_temporal_discarded(monkeypatch):
|
||||||
|
monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None)
|
||||||
|
vae = _make_vae(_DecodeWrapper())
|
||||||
|
|
||||||
|
nodes_mod.VAEDecodeTiled().decode(
|
||||||
|
vae,
|
||||||
|
{"samples": torch.zeros(1, 16, 2, 4, 5)},
|
||||||
|
tile_size=512,
|
||||||
|
overlap=64,
|
||||||
|
temporal_size=16,
|
||||||
|
temporal_overlap=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Spatial inputs flow through; temporal inputs are discarded — SeedVR2 owns
|
||||||
|
# temporal via the MemoryState causal cache, so VAEDecodeTiled's temporal
|
||||||
|
# knobs are no-ops at the wrapper.
|
||||||
|
assert vae.first_stage_model.calls == [
|
||||||
|
{
|
||||||
|
"shape": (1, 16, 2, 4, 5),
|
||||||
|
"seedvr2_tiling": {
|
||||||
|
"enable_tiling": True,
|
||||||
|
"tile_size": (512, 512),
|
||||||
|
"tile_overlap": (64, 64),
|
||||||
|
"temporal_size": 0,
|
||||||
|
"temporal_overlap": 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
91
tests-unit/comfy_test/test_seedvr2_vae_decode.py
Normal file
91
tests-unit/comfy_test/test_seedvr2_vae_decode.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||||
|
from comfy_extras import nodes_seedvr # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper:
|
||||||
|
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
|
||||||
|
vae_mod.VideoAutoencoderKLWrapper
|
||||||
|
)
|
||||||
|
nn.Module.__init__(wrapper)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def _fingerprint_decode_(self, z, return_dict=True):
|
||||||
|
b = int(z.shape[0])
|
||||||
|
t = int(z.shape[2])
|
||||||
|
h = int(z.shape[3])
|
||||||
|
w = int(z.shape[4])
|
||||||
|
out = torch.empty(b, 3, t, h * 8, w * 8)
|
||||||
|
for batch_idx in range(b):
|
||||||
|
out[batch_idx].fill_(float(batch_idx + 1))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_with_patches(wrapper, z):
|
||||||
|
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_):
|
||||||
|
return wrapper.decode(z)
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_b2_t3_multi_frame_batch_unchanged():
|
||||||
|
wrapper = _make_wrapper()
|
||||||
|
|
||||||
|
out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2))
|
||||||
|
|
||||||
|
assert tuple(out.shape) == (2, 3, 3, 16, 16)
|
||||||
|
|
||||||
|
|
||||||
|
class _Wrapper(vae_mod.VideoAutoencoderKLWrapper):
|
||||||
|
def __init__(self):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
self.calls = []
|
||||||
|
|
||||||
|
def parameters(self):
|
||||||
|
return iter([torch.nn.Parameter(torch.zeros(()))])
|
||||||
|
|
||||||
|
def _decode_stub(self, latent):
|
||||||
|
self.calls.append(tuple(latent.shape))
|
||||||
|
return torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8)
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state():
|
||||||
|
wrapper = _Wrapper()
|
||||||
|
|
||||||
|
with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub):
|
||||||
|
out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5))
|
||||||
|
|
||||||
|
assert tuple(out.shape) == (1, 3, 2, 32, 40)
|
||||||
|
assert wrapper.calls == [(1, 16, 2, 4, 5)]
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents():
|
||||||
|
wrapper = _Wrapper()
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"):
|
||||||
|
wrapper.decode(torch.zeros(1, 16, 4))
|
||||||
|
|
||||||
|
|
||||||
|
def _t_padded(t_in: int) -> int:
|
||||||
|
if t_in == 1:
|
||||||
|
return 1
|
||||||
|
if t_in <= 4:
|
||||||
|
return 5
|
||||||
|
if (t_in - 1) % 4 == 0:
|
||||||
|
return t_in
|
||||||
|
return t_in + (4 - ((t_in - 1) % 4))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("t_in", [1, 5, 9])
|
||||||
|
def test_t_padded_matches_cut_videos(t_in):
|
||||||
|
dummy = torch.zeros(1, t_in, 1, 1, 1)
|
||||||
|
assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in)
|
||||||
348
tests-unit/comfy_test/test_seedvr2_vae_tiled.py
Normal file
348
tests-unit/comfy_test/test_seedvr2_vae_tiled.py
Normal file
@ -0,0 +1,348 @@
|
|||||||
|
from contextlib import ExitStack
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
import comfy.ldm.seedvr.vae as vae_mod # noqa: E402
|
||||||
|
import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402
|
||||||
|
import comfy.sd as sd_mod # noqa: E402
|
||||||
|
from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# From test_seedvr_vae_tiled_decode_latent_min_size_override.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_decode_zero_temporal_size_disables_slicing_for_call():
|
||||||
|
from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae
|
||||||
|
|
||||||
|
class StubVAEModel(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.slicing_latent_min_size = 2
|
||||||
|
self.spatial_downsample_factor = 8
|
||||||
|
self.temporal_downsample_factor = 4
|
||||||
|
self.device = torch.device("cpu")
|
||||||
|
self.use_slicing = True
|
||||||
|
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
|
||||||
|
self.decode_min_sizes = []
|
||||||
|
self.memory_states = []
|
||||||
|
|
||||||
|
def decode_(self, t_chunk):
|
||||||
|
self.decode_min_sizes.append(self.slicing_latent_min_size)
|
||||||
|
return VideoAutoencoderKL.slicing_decode(self, t_chunk)
|
||||||
|
|
||||||
|
def _decode(self, z, memory_state=MemoryState.DISABLED):
|
||||||
|
self.memory_states.append(memory_state)
|
||||||
|
b, c, d, h, w = z.shape
|
||||||
|
return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype)
|
||||||
|
|
||||||
|
vae = StubVAEModel()
|
||||||
|
z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32)
|
||||||
|
|
||||||
|
tiled_vae(
|
||||||
|
z,
|
||||||
|
vae,
|
||||||
|
tile_size=(64, 64),
|
||||||
|
tile_overlap=(0, 0),
|
||||||
|
temporal_size=0,
|
||||||
|
temporal_overlap=0,
|
||||||
|
encode=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert vae.decode_min_sizes == [5]
|
||||||
|
assert vae.memory_states == [MemoryState.DISABLED]
|
||||||
|
assert vae.slicing_latent_min_size == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# From test_seedvr_vae_tiled_encode_runt_slice_override.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_zero_temporal_size_preserves_min_size_when_encode_raises():
|
||||||
|
from comfy.ldm.seedvr.vae import tiled_vae
|
||||||
|
|
||||||
|
class RaisingVAEModel(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.slicing_sample_min_size = 4
|
||||||
|
self.spatial_downsample_factor = 8
|
||||||
|
self.temporal_downsample_factor = 4
|
||||||
|
self.device = torch.device("cpu")
|
||||||
|
self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32))
|
||||||
|
|
||||||
|
def encode(self, t_chunk):
|
||||||
|
raise RuntimeError("simulated encode failure")
|
||||||
|
|
||||||
|
vae = RaisingVAEModel()
|
||||||
|
x = torch.zeros((1, 3, 12, 64, 64), dtype=torch.float32)
|
||||||
|
|
||||||
|
raised = False
|
||||||
|
try:
|
||||||
|
tiled_vae(
|
||||||
|
x,
|
||||||
|
vae,
|
||||||
|
tile_size=(64, 64),
|
||||||
|
tile_overlap=(0, 0),
|
||||||
|
temporal_size=0,
|
||||||
|
temporal_overlap=0,
|
||||||
|
encode=True,
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
if "simulated encode failure" not in str(exc):
|
||||||
|
raise
|
||||||
|
raised = True
|
||||||
|
|
||||||
|
assert raised
|
||||||
|
assert vae.slicing_sample_min_size == 4
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# From test_seedvr_vae_tiled_temporal_slicing.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _SlicingDecodeVAE(nn.Module):
|
||||||
|
def __init__(self, slicing_latent_min_size):
|
||||||
|
super().__init__()
|
||||||
|
self.slicing_latent_min_size = slicing_latent_min_size
|
||||||
|
self.spatial_downsample_factor = 8
|
||||||
|
self.temporal_downsample_factor = 4
|
||||||
|
self.device = torch.device("cpu")
|
||||||
|
self.use_slicing = True
|
||||||
|
self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32))
|
||||||
|
self.decode_min_sizes = []
|
||||||
|
self.memory_states = []
|
||||||
|
|
||||||
|
def decode_(self, z):
|
||||||
|
self.decode_min_sizes.append(self.slicing_latent_min_size)
|
||||||
|
return vae_mod.VideoAutoencoderKL.slicing_decode(self, z)
|
||||||
|
|
||||||
|
def _decode(self, z, memory_state=MemoryState.DISABLED):
|
||||||
|
self.memory_states.append(memory_state)
|
||||||
|
x = z[:, :1].repeat(
|
||||||
|
1,
|
||||||
|
3,
|
||||||
|
1,
|
||||||
|
self.spatial_downsample_factor,
|
||||||
|
self.spatial_downsample_factor,
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size():
|
||||||
|
vae = _SlicingDecodeVAE(slicing_latent_min_size=2)
|
||||||
|
z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8)
|
||||||
|
|
||||||
|
tiled_vae(
|
||||||
|
z,
|
||||||
|
vae,
|
||||||
|
tile_size=(64, 64),
|
||||||
|
tile_overlap=(0, 0),
|
||||||
|
temporal_size=12,
|
||||||
|
temporal_overlap=4,
|
||||||
|
encode=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert vae.decode_min_sizes == [2]
|
||||||
|
assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE]
|
||||||
|
assert vae.slicing_latent_min_size == 2
|
||||||
|
|
||||||
|
wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__(
|
||||||
|
vae_mod.VideoAutoencoderKLWrapper
|
||||||
|
)
|
||||||
|
nn.Module.__init__(wrapper)
|
||||||
|
seedvr2_tiling = {
|
||||||
|
"enable_tiling": True,
|
||||||
|
"tile_size": (64, 64),
|
||||||
|
"tile_overlap": (0, 0),
|
||||||
|
"temporal_size": 8,
|
||||||
|
"temporal_overlap": 7,
|
||||||
|
}
|
||||||
|
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
def _fake_tiled_vae(latent, model, **kwargs):
|
||||||
|
captured.update(kwargs)
|
||||||
|
return torch.zeros(1, 3, 1, 16, 16)
|
||||||
|
|
||||||
|
with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae):
|
||||||
|
wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling)
|
||||||
|
|
||||||
|
assert captured["temporal_overlap"] == 7
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _force_oom(*a, **k):
|
||||||
|
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_vae(first_stage_model, latent_channels, latent_dim):
|
||||||
|
vae = sd_mod.VAE.__new__(sd_mod.VAE)
|
||||||
|
vae.first_stage_model = first_stage_model
|
||||||
|
vae.patcher = MagicMock()
|
||||||
|
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
|
||||||
|
vae.device = vae.output_device = torch.device("cpu")
|
||||||
|
vae.vae_dtype = torch.float32
|
||||||
|
vae.disable_offload = True
|
||||||
|
vae.extra_1d_channel = None
|
||||||
|
vae.upscale_ratio = vae.downscale_ratio = 8
|
||||||
|
vae.upscale_index_formula = vae.downscale_index_formula = None
|
||||||
|
vae.output_channels = 3
|
||||||
|
vae.latent_channels = latent_channels
|
||||||
|
vae.latent_dim = latent_dim
|
||||||
|
vae.vae_output_dtype = lambda: torch.float32
|
||||||
|
vae.spacial_compression_decode = lambda: 8
|
||||||
|
vae.process_input = lambda x: x
|
||||||
|
vae.process_output = lambda x: x
|
||||||
|
vae.throw_exception_if_invalid = lambda: None
|
||||||
|
vae.memory_used_decode = lambda *a, **k: 1
|
||||||
|
return vae
|
||||||
|
|
||||||
|
|
||||||
|
def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode):
|
||||||
|
mm = sd_mod.model_management
|
||||||
|
with ExitStack() as stack:
|
||||||
|
stack.enter_context(patch.object(mm, "raise_non_oom", lambda e: None))
|
||||||
|
stack.enter_context(patch.object(mm, "load_models_gpu", lambda *a, **k: None))
|
||||||
|
stack.enter_context(patch.object(mm, "soft_empty_cache", lambda: None))
|
||||||
|
stack.enter_context(patch.object(sd_mod.VAE, "_decode_tiled_owned", seedvr2_call))
|
||||||
|
stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_", generic_call))
|
||||||
|
if patch_wrapper_decode:
|
||||||
|
stack.enter_context(patch.object(
|
||||||
|
seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode",
|
||||||
|
side_effect=_force_oom))
|
||||||
|
vae.decode(samples)
|
||||||
|
|
||||||
|
|
||||||
|
def test_4d_seedvr2_latent_routes_to_owned_decode_tiled():
|
||||||
|
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
|
||||||
|
seedvr_vae_mod.VideoAutoencoderKLWrapper)
|
||||||
|
vae = _make_vae(wrapper, latent_channels=16, latent_dim=3)
|
||||||
|
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
|
||||||
|
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
|
||||||
|
_dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True)
|
||||||
|
assert seedvr2_call.call_count == 1
|
||||||
|
assert generic_call.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled():
|
||||||
|
first_stage = MagicMock()
|
||||||
|
first_stage.comfy_handles_tiling = False
|
||||||
|
first_stage.decode = MagicMock(side_effect=_force_oom)
|
||||||
|
vae = _make_vae(first_stage, latent_channels=4, latent_dim=2)
|
||||||
|
seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64))
|
||||||
|
generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64))
|
||||||
|
_dispatch(vae, torch.zeros(1, 4, 8, 8), seedvr2_call, generic_call, False)
|
||||||
|
assert generic_call.call_count == 1
|
||||||
|
assert seedvr2_call.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _populate_common_vae_attrs_fallback(vae):
|
||||||
|
vae.patcher = MagicMock()
|
||||||
|
vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024)
|
||||||
|
vae.device = torch.device("cpu")
|
||||||
|
vae.output_device = torch.device("cpu")
|
||||||
|
vae.vae_dtype = torch.float32
|
||||||
|
vae.disable_offload = True
|
||||||
|
vae.extra_1d_channel = None
|
||||||
|
vae.upscale_ratio = 8
|
||||||
|
vae.upscale_index_formula = None
|
||||||
|
vae.output_channels = 3
|
||||||
|
vae.latent_channels = 16
|
||||||
|
vae.latent_dim = 3
|
||||||
|
vae.downscale_ratio = 8
|
||||||
|
vae.downscale_index_formula = None
|
||||||
|
vae.not_video = False
|
||||||
|
vae.crop_input = False
|
||||||
|
vae.pad_channel_value = None
|
||||||
|
|
||||||
|
vae.vae_output_dtype = lambda: torch.float32
|
||||||
|
vae.spacial_compression_encode = lambda: 8
|
||||||
|
vae.process_input = lambda x: x
|
||||||
|
vae.process_output = lambda x: x
|
||||||
|
vae.throw_exception_if_invalid = lambda: None
|
||||||
|
vae.memory_used_encode = lambda *a, **k: 1
|
||||||
|
|
||||||
|
|
||||||
|
def _make_seedvr2_vae_fallback():
|
||||||
|
vae = sd_mod.VAE.__new__(sd_mod.VAE)
|
||||||
|
wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__(
|
||||||
|
seedvr_vae_mod.VideoAutoencoderKLWrapper
|
||||||
|
)
|
||||||
|
vae.first_stage_model = wrapper
|
||||||
|
_populate_common_vae_attrs_fallback(vae)
|
||||||
|
return vae
|
||||||
|
|
||||||
|
|
||||||
|
def _make_non_seedvr2_vae_fallback():
|
||||||
|
vae = sd_mod.VAE.__new__(sd_mod.VAE)
|
||||||
|
vae.first_stage_model = MagicMock()
|
||||||
|
vae.first_stage_model.comfy_handles_tiling = False
|
||||||
|
_populate_common_vae_attrs_fallback(vae)
|
||||||
|
return vae
|
||||||
|
|
||||||
|
|
||||||
|
def _force_regular_encode_oom(*args, **kwargs):
|
||||||
|
raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test")
|
||||||
|
|
||||||
|
|
||||||
|
def test_seedvr2_3d_routes_to_owned_encode_tiled_on_oom():
|
||||||
|
vae = _make_seedvr2_vae_fallback()
|
||||||
|
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
|
||||||
|
|
||||||
|
seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
|
||||||
|
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
|
||||||
|
|
||||||
|
with patch.object(sd_mod.model_management, "raise_non_oom",
|
||||||
|
lambda e: None), \
|
||||||
|
patch.object(sd_mod.model_management, "load_models_gpu",
|
||||||
|
lambda *a, **k: None), \
|
||||||
|
patch.object(sd_mod.model_management, "soft_empty_cache",
|
||||||
|
lambda: None), \
|
||||||
|
patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode",
|
||||||
|
side_effect=_force_regular_encode_oom), \
|
||||||
|
patch.object(sd_mod.VAE, "_encode_tiled_owned", seedvr2_call), \
|
||||||
|
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
|
||||||
|
vae.encode(pixel_samples)
|
||||||
|
|
||||||
|
assert seedvr2_call.call_count == 1, (
|
||||||
|
f"Expected _encode_tiled_owned to be called once for a SeedVR2 3D "
|
||||||
|
f"input under OOM fallback; got {seedvr2_call.call_count} calls."
|
||||||
|
)
|
||||||
|
assert generic_call.call_count == 0, (
|
||||||
|
f"encode_tiled_3d must NOT be called for a SeedVR2 input; got "
|
||||||
|
f"{generic_call.call_count} calls."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete():
|
||||||
|
vae = _make_non_seedvr2_vae_fallback()
|
||||||
|
vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8)
|
||||||
|
vae.upscale_ratio = (lambda a: a * 4, 8, 8)
|
||||||
|
generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8))
|
||||||
|
pixel_samples = torch.zeros((1, 8, 64, 64, 3))
|
||||||
|
|
||||||
|
with patch.object(sd_mod.model_management, "load_models_gpu",
|
||||||
|
lambda *a, **k: None), \
|
||||||
|
patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call):
|
||||||
|
vae.encode_tiled(pixel_samples)
|
||||||
|
|
||||||
|
assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64)
|
||||||
95
tests-unit/comfy_test/test_seedvr_progressive_sampler.py
Normal file
95
tests-unit/comfy_test/test_seedvr_progressive_sampler.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``."""
|
||||||
|
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args as cli_args
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
cli_args.cpu = True
|
||||||
|
|
||||||
|
import comfy.sample # noqa: E402
|
||||||
|
import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402
|
||||||
|
from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler # noqa: E402
|
||||||
|
|
||||||
|
_LAT_C = 16
|
||||||
|
_COND_C = 17
|
||||||
|
|
||||||
|
|
||||||
|
def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8):
|
||||||
|
"""Build minimal SeedVR2-shaped sampling inputs."""
|
||||||
|
samples_5d = torch.arange(
|
||||||
|
B * _LAT_C * T * H * W, dtype=torch.float32
|
||||||
|
).reshape(B, _LAT_C, T, H, W)
|
||||||
|
samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous()
|
||||||
|
|
||||||
|
cond_5d = torch.arange(
|
||||||
|
B * _COND_C * T * H * W, dtype=torch.float32
|
||||||
|
).reshape(B, _COND_C, T, H, W) + 10000.0
|
||||||
|
cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous()
|
||||||
|
|
||||||
|
text_pos = torch.zeros(1, 4, 32)
|
||||||
|
text_neg = torch.zeros(1, 4, 32)
|
||||||
|
positive = [[text_pos, {"condition": cond.clone()}]]
|
||||||
|
negative = [[text_neg, {"condition": cond.clone()}]]
|
||||||
|
latent_image = {"samples": samples}
|
||||||
|
return latent_image, positive, negative, samples_5d, cond_5d
|
||||||
|
|
||||||
|
|
||||||
|
def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None):
|
||||||
|
return latent_image
|
||||||
|
|
||||||
|
|
||||||
|
def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None):
|
||||||
|
"""Return a tensor whose values encode ``(seed, position)``."""
|
||||||
|
base = torch.arange(
|
||||||
|
latent_image.numel(), dtype=torch.float32
|
||||||
|
).reshape(latent_image.shape)
|
||||||
|
return base + float(seed) * 1e6
|
||||||
|
|
||||||
|
|
||||||
|
def test_progressive_sampler_schema_exposes_manual_default_auto_chunking():
|
||||||
|
schema = SeedVR2ProgressiveSampler.define_schema()
|
||||||
|
inputs = {item.id: item for item in schema.inputs}
|
||||||
|
|
||||||
|
assert inputs["chunking_mode"].options == ["manual", "auto"]
|
||||||
|
assert inputs["chunking_mode"].default == "manual"
|
||||||
|
|
||||||
|
|
||||||
|
def test_vram_seed_frames_per_chunk_predicts_4n1_clamped_to_t_pixel():
|
||||||
|
"""VRAM chunk-size law: seed = nearest 4n+1 to 4*(free_GB - 3), clamped to [1, t_pixel]."""
|
||||||
|
gib = 1024 ** 3
|
||||||
|
seed = nodes_seedvr_mod._seedvr2_vram_seed_frames_per_chunk
|
||||||
|
assert seed(20 * gib, 65) == 65 # 4*(20-3)=68 -> 4n+1 69 -> clamp to t_pixel 65
|
||||||
|
assert seed(6 * gib, 97) == 13 # 4*(6-3)=12 -> nearest 4n+1 13
|
||||||
|
assert seed(2 * gib, 97) == 1 # below margin -> floor at 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("bad_chunk", [0, -1, 2])
|
||||||
|
def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk):
|
||||||
|
"""``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation."""
|
||||||
|
latent, pos, neg, _, _ = _make_inputs(T=5)
|
||||||
|
|
||||||
|
sampler_called = {"n": 0}
|
||||||
|
|
||||||
|
def _should_not_be_called(*args, **kwargs):
|
||||||
|
sampler_called["n"] += 1
|
||||||
|
return torch.zeros(1)
|
||||||
|
|
||||||
|
with patch.object(comfy.sample, "sample",
|
||||||
|
side_effect=_should_not_be_called), \
|
||||||
|
patch.object(comfy.sample, "fix_empty_latent_channels",
|
||||||
|
side_effect=_identity_fix_empty), \
|
||||||
|
patch.object(comfy.sample, "prepare_noise",
|
||||||
|
side_effect=_fingerprinted_prepare_noise):
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
SeedVR2ProgressiveSampler.execute(
|
||||||
|
model=None, seed=0, steps=2, cfg=1.0,
|
||||||
|
sampler_name="euler", scheduler="simple",
|
||||||
|
positive=pos, negative=neg, latent=latent,
|
||||||
|
denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0,
|
||||||
|
)
|
||||||
|
assert str(bad_chunk) in str(excinfo.value)
|
||||||
|
assert sampler_called["n"] == 0
|
||||||
Loading…
Reference in New Issue
Block a user